rpcclient: Support mTLS

Signed-off-by: Evgenii Stratonikov <fyfyrchik@runbox.com>
This commit is contained in:
Evgenii Stratonikov 2024-06-05 14:50:25 +03:00
parent 90efaa4771
commit 925ba49d92
3 changed files with 60 additions and 8 deletions

View file

@ -3,6 +3,7 @@ package rpcclient
import ( import (
"bytes" "bytes"
"context" "context"
"crypto/tls"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@ -68,6 +69,7 @@ type Options struct {
RequestTimeout time.Duration RequestTimeout time.Duration
// Limit total number of connections per host. No limit by default. // Limit total number of connections per host. No limit by default.
MaxConnsPerHost int MaxConnsPerHost int
TLSClientConfig *tls.Config
} }
// cache stores cache values for the RPC client methods. // cache stores cache values for the RPC client methods.
@ -104,13 +106,16 @@ func initClient(ctx context.Context, cl *Client, endpoint string, opts Options)
opts.RequestTimeout = defaultRequestTimeout opts.RequestTimeout = defaultRequestTimeout
} }
httpClient := &http.Client{ tr := &http.Transport{
Transport: &http.Transport{
DialContext: (&net.Dialer{ DialContext: (&net.Dialer{
Timeout: opts.DialTimeout, Timeout: opts.DialTimeout,
}).DialContext, }).DialContext,
MaxConnsPerHost: opts.MaxConnsPerHost, MaxConnsPerHost: opts.MaxConnsPerHost,
}, TLSClientConfig: opts.TLSClientConfig,
}
httpClient := &http.Client{
Transport: tr,
Timeout: opts.RequestTimeout, Timeout: opts.RequestTimeout,
} }

View file

@ -0,0 +1,47 @@
package rpcclient
import (
"crypto/tls"
"crypto/x509"
"fmt"
"net/http"
"os"
)
// TransportHook ...
type TransportHook = func(*http.Transport)
func TLSClientConfig(rootCAs []string, certFile, keyFile string) (*tls.Config, error) {
certificate, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return nil, fmt.Errorf("read client certificate: %w", err)
}
caCertPool := x509.NewCertPool()
for _, name := range rootCAs {
caCertFile, err := os.ReadFile(name)
if err != nil {
return nil, fmt.Errorf("read CA certificate: %w", err)
}
caCertPool.AppendCertsFromPEM(caCertFile)
}
return &tls.Config{
RootCAs: caCertPool,
Certificates: []tls.Certificate{certificate},
InsecureSkipVerify: len(rootCAs) == 0,
}, nil
}
// MTLSTransportHook enables client certificate advertising as well as retricting the set of rootCA we accept.
func MTLSTransportHook(rootCAs []string, certFile, keyFile string) (func(*http.Transport), error) {
cfg, err := TLSClientConfig(rootCAs, certFile, keyFile)
if err != nil {
return nil, err
}
return func(tr *http.Transport) {
tr.TLSClientConfig = cfg
}, nil
}

View file

@ -453,7 +453,7 @@ var errConnClosedByUser = errors.New("connection closed by user")
// You should call Init method to initialize the network magic the client is // You should call Init method to initialize the network magic the client is
// operating on. // operating on.
func NewWS(ctx context.Context, endpoint string, opts WSOptions) (*WSClient, error) { func NewWS(ctx context.Context, endpoint string, opts WSOptions) (*WSClient, error) {
dialer := websocket.Dialer{HandshakeTimeout: opts.DialTimeout} dialer := websocket.Dialer{HandshakeTimeout: opts.DialTimeout, TLSClientConfig: opts.TLSClientConfig}
ws, resp, err := dialer.DialContext(ctx, endpoint, nil) ws, resp, err := dialer.DialContext(ctx, endpoint, nil)
if resp != nil && resp.Body != nil { // Can be non-nil even with error returned. if resp != nil && resp.Body != nil { // Can be non-nil even with error returned.
defer resp.Body.Close() // Not exactly required by websocket, but let's do this for bodyclose checker. defer resp.Body.Close() // Not exactly required by websocket, but let's do this for bodyclose checker.