forked from TrueCloudLab/neoneo-go
rpcclient: Support mTLS
Signed-off-by: Evgenii Stratonikov <fyfyrchik@runbox.com>
This commit is contained in:
parent
90efaa4771
commit
925ba49d92
3 changed files with 60 additions and 8 deletions
|
@ -3,6 +3,7 @@ package rpcclient
|
|||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
@ -68,6 +69,7 @@ type Options struct {
|
|||
RequestTimeout time.Duration
|
||||
// Limit total number of connections per host. No limit by default.
|
||||
MaxConnsPerHost int
|
||||
TLSClientConfig *tls.Config
|
||||
}
|
||||
|
||||
// cache stores cache values for the RPC client methods.
|
||||
|
@ -104,14 +106,17 @@ func initClient(ctx context.Context, cl *Client, endpoint string, opts Options)
|
|||
opts.RequestTimeout = defaultRequestTimeout
|
||||
}
|
||||
|
||||
tr := &http.Transport{
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: opts.DialTimeout,
|
||||
}).DialContext,
|
||||
MaxConnsPerHost: opts.MaxConnsPerHost,
|
||||
TLSClientConfig: opts.TLSClientConfig,
|
||||
}
|
||||
|
||||
httpClient := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: opts.DialTimeout,
|
||||
}).DialContext,
|
||||
MaxConnsPerHost: opts.MaxConnsPerHost,
|
||||
},
|
||||
Timeout: opts.RequestTimeout,
|
||||
Transport: tr,
|
||||
Timeout: opts.RequestTimeout,
|
||||
}
|
||||
|
||||
// TODO(@antdm): Enable SSL.
|
||||
|
|
47
pkg/rpcclient/mtls_hook.go
Normal file
47
pkg/rpcclient/mtls_hook.go
Normal file
|
@ -0,0 +1,47 @@
|
|||
package rpcclient
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
)
|
||||
|
||||
// TransportHook ...
|
||||
type TransportHook = func(*http.Transport)
|
||||
|
||||
func TLSClientConfig(rootCAs []string, certFile, keyFile string) (*tls.Config, error) {
|
||||
certificate, err := tls.LoadX509KeyPair(certFile, keyFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read client certificate: %w", err)
|
||||
}
|
||||
|
||||
caCertPool := x509.NewCertPool()
|
||||
for _, name := range rootCAs {
|
||||
caCertFile, err := os.ReadFile(name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read CA certificate: %w", err)
|
||||
}
|
||||
|
||||
caCertPool.AppendCertsFromPEM(caCertFile)
|
||||
}
|
||||
|
||||
return &tls.Config{
|
||||
RootCAs: caCertPool,
|
||||
Certificates: []tls.Certificate{certificate},
|
||||
InsecureSkipVerify: len(rootCAs) == 0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// MTLSTransportHook enables client certificate advertising as well as retricting the set of rootCA we accept.
|
||||
func MTLSTransportHook(rootCAs []string, certFile, keyFile string) (func(*http.Transport), error) {
|
||||
cfg, err := TLSClientConfig(rootCAs, certFile, keyFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return func(tr *http.Transport) {
|
||||
tr.TLSClientConfig = cfg
|
||||
}, nil
|
||||
}
|
|
@ -453,7 +453,7 @@ var errConnClosedByUser = errors.New("connection closed by user")
|
|||
// You should call Init method to initialize the network magic the client is
|
||||
// operating on.
|
||||
func NewWS(ctx context.Context, endpoint string, opts WSOptions) (*WSClient, error) {
|
||||
dialer := websocket.Dialer{HandshakeTimeout: opts.DialTimeout}
|
||||
dialer := websocket.Dialer{HandshakeTimeout: opts.DialTimeout, TLSClientConfig: opts.TLSClientConfig}
|
||||
ws, resp, err := dialer.DialContext(ctx, endpoint, nil)
|
||||
if resp != nil && resp.Body != nil { // Can be non-nil even with error returned.
|
||||
defer resp.Body.Close() // Not exactly required by websocket, but let's do this for bodyclose checker.
|
||||
|
|
Loading…
Reference in a new issue