From 925ba49d92fa0370a2ea81f117017d0666b8ab8a Mon Sep 17 00:00:00 2001 From: Evgenii Stratonikov Date: Wed, 5 Jun 2024 14:50:25 +0300 Subject: [PATCH] rpcclient: Support mTLS Signed-off-by: Evgenii Stratonikov --- pkg/rpcclient/client.go | 19 +++++++++------ pkg/rpcclient/mtls_hook.go | 47 ++++++++++++++++++++++++++++++++++++++ pkg/rpcclient/wsclient.go | 2 +- 3 files changed, 60 insertions(+), 8 deletions(-) create mode 100644 pkg/rpcclient/mtls_hook.go diff --git a/pkg/rpcclient/client.go b/pkg/rpcclient/client.go index 12af5f0b8..2073946eb 100644 --- a/pkg/rpcclient/client.go +++ b/pkg/rpcclient/client.go @@ -3,6 +3,7 @@ package rpcclient import ( "bytes" "context" + "crypto/tls" "encoding/json" "errors" "fmt" @@ -68,6 +69,7 @@ type Options struct { RequestTimeout time.Duration // Limit total number of connections per host. No limit by default. MaxConnsPerHost int + TLSClientConfig *tls.Config } // cache stores cache values for the RPC client methods. @@ -104,14 +106,17 @@ func initClient(ctx context.Context, cl *Client, endpoint string, opts Options) opts.RequestTimeout = defaultRequestTimeout } + tr := &http.Transport{ + DialContext: (&net.Dialer{ + Timeout: opts.DialTimeout, + }).DialContext, + MaxConnsPerHost: opts.MaxConnsPerHost, + TLSClientConfig: opts.TLSClientConfig, + } + httpClient := &http.Client{ - Transport: &http.Transport{ - DialContext: (&net.Dialer{ - Timeout: opts.DialTimeout, - }).DialContext, - MaxConnsPerHost: opts.MaxConnsPerHost, - }, - Timeout: opts.RequestTimeout, + Transport: tr, + Timeout: opts.RequestTimeout, } // TODO(@antdm): Enable SSL. diff --git a/pkg/rpcclient/mtls_hook.go b/pkg/rpcclient/mtls_hook.go new file mode 100644 index 000000000..944e8996f --- /dev/null +++ b/pkg/rpcclient/mtls_hook.go @@ -0,0 +1,47 @@ +package rpcclient + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "net/http" + "os" +) + +// TransportHook ... +type TransportHook = func(*http.Transport) + +func TLSClientConfig(rootCAs []string, certFile, keyFile string) (*tls.Config, error) { + certificate, err := tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + return nil, fmt.Errorf("read client certificate: %w", err) + } + + caCertPool := x509.NewCertPool() + for _, name := range rootCAs { + caCertFile, err := os.ReadFile(name) + if err != nil { + return nil, fmt.Errorf("read CA certificate: %w", err) + } + + caCertPool.AppendCertsFromPEM(caCertFile) + } + + return &tls.Config{ + RootCAs: caCertPool, + Certificates: []tls.Certificate{certificate}, + InsecureSkipVerify: len(rootCAs) == 0, + }, nil +} + +// MTLSTransportHook enables client certificate advertising as well as retricting the set of rootCA we accept. +func MTLSTransportHook(rootCAs []string, certFile, keyFile string) (func(*http.Transport), error) { + cfg, err := TLSClientConfig(rootCAs, certFile, keyFile) + if err != nil { + return nil, err + } + + return func(tr *http.Transport) { + tr.TLSClientConfig = cfg + }, nil +} diff --git a/pkg/rpcclient/wsclient.go b/pkg/rpcclient/wsclient.go index 3b131f890..dc877c407 100644 --- a/pkg/rpcclient/wsclient.go +++ b/pkg/rpcclient/wsclient.go @@ -453,7 +453,7 @@ var errConnClosedByUser = errors.New("connection closed by user") // You should call Init method to initialize the network magic the client is // operating on. func NewWS(ctx context.Context, endpoint string, opts WSOptions) (*WSClient, error) { - dialer := websocket.Dialer{HandshakeTimeout: opts.DialTimeout} + dialer := websocket.Dialer{HandshakeTimeout: opts.DialTimeout, TLSClientConfig: opts.TLSClientConfig} ws, resp, err := dialer.DialContext(ctx, endpoint, nil) if resp != nil && resp.Body != nil { // Can be non-nil even with error returned. defer resp.Body.Close() // Not exactly required by websocket, but let's do this for bodyclose checker.