forked from TrueCloudLab/neoneo-go
rpcclient: Support mTLS
Signed-off-by: Evgenii Stratonikov <fyfyrchik@runbox.com>
This commit is contained in:
parent
90efaa4771
commit
925ba49d92
3 changed files with 60 additions and 8 deletions
|
@ -3,6 +3,7 @@ package rpcclient
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
@ -68,6 +69,7 @@ type Options struct {
|
||||||
RequestTimeout time.Duration
|
RequestTimeout time.Duration
|
||||||
// Limit total number of connections per host. No limit by default.
|
// Limit total number of connections per host. No limit by default.
|
||||||
MaxConnsPerHost int
|
MaxConnsPerHost int
|
||||||
|
TLSClientConfig *tls.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
// cache stores cache values for the RPC client methods.
|
// cache stores cache values for the RPC client methods.
|
||||||
|
@ -104,13 +106,16 @@ func initClient(ctx context.Context, cl *Client, endpoint string, opts Options)
|
||||||
opts.RequestTimeout = defaultRequestTimeout
|
opts.RequestTimeout = defaultRequestTimeout
|
||||||
}
|
}
|
||||||
|
|
||||||
httpClient := &http.Client{
|
tr := &http.Transport{
|
||||||
Transport: &http.Transport{
|
|
||||||
DialContext: (&net.Dialer{
|
DialContext: (&net.Dialer{
|
||||||
Timeout: opts.DialTimeout,
|
Timeout: opts.DialTimeout,
|
||||||
}).DialContext,
|
}).DialContext,
|
||||||
MaxConnsPerHost: opts.MaxConnsPerHost,
|
MaxConnsPerHost: opts.MaxConnsPerHost,
|
||||||
},
|
TLSClientConfig: opts.TLSClientConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
httpClient := &http.Client{
|
||||||
|
Transport: tr,
|
||||||
Timeout: opts.RequestTimeout,
|
Timeout: opts.RequestTimeout,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
47
pkg/rpcclient/mtls_hook.go
Normal file
47
pkg/rpcclient/mtls_hook.go
Normal file
|
@ -0,0 +1,47 @@
|
||||||
|
package rpcclient
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TransportHook ...
|
||||||
|
type TransportHook = func(*http.Transport)
|
||||||
|
|
||||||
|
func TLSClientConfig(rootCAs []string, certFile, keyFile string) (*tls.Config, error) {
|
||||||
|
certificate, err := tls.LoadX509KeyPair(certFile, keyFile)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("read client certificate: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
caCertPool := x509.NewCertPool()
|
||||||
|
for _, name := range rootCAs {
|
||||||
|
caCertFile, err := os.ReadFile(name)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("read CA certificate: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
caCertPool.AppendCertsFromPEM(caCertFile)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &tls.Config{
|
||||||
|
RootCAs: caCertPool,
|
||||||
|
Certificates: []tls.Certificate{certificate},
|
||||||
|
InsecureSkipVerify: len(rootCAs) == 0,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MTLSTransportHook enables client certificate advertising as well as retricting the set of rootCA we accept.
|
||||||
|
func MTLSTransportHook(rootCAs []string, certFile, keyFile string) (func(*http.Transport), error) {
|
||||||
|
cfg, err := TLSClientConfig(rootCAs, certFile, keyFile)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return func(tr *http.Transport) {
|
||||||
|
tr.TLSClientConfig = cfg
|
||||||
|
}, nil
|
||||||
|
}
|
|
@ -453,7 +453,7 @@ var errConnClosedByUser = errors.New("connection closed by user")
|
||||||
// You should call Init method to initialize the network magic the client is
|
// You should call Init method to initialize the network magic the client is
|
||||||
// operating on.
|
// operating on.
|
||||||
func NewWS(ctx context.Context, endpoint string, opts WSOptions) (*WSClient, error) {
|
func NewWS(ctx context.Context, endpoint string, opts WSOptions) (*WSClient, error) {
|
||||||
dialer := websocket.Dialer{HandshakeTimeout: opts.DialTimeout}
|
dialer := websocket.Dialer{HandshakeTimeout: opts.DialTimeout, TLSClientConfig: opts.TLSClientConfig}
|
||||||
ws, resp, err := dialer.DialContext(ctx, endpoint, nil)
|
ws, resp, err := dialer.DialContext(ctx, endpoint, nil)
|
||||||
if resp != nil && resp.Body != nil { // Can be non-nil even with error returned.
|
if resp != nil && resp.Body != nil { // Can be non-nil even with error returned.
|
||||||
defer resp.Body.Close() // Not exactly required by websocket, but let's do this for bodyclose checker.
|
defer resp.Body.Close() // Not exactly required by websocket, but let's do this for bodyclose checker.
|
||||||
|
|
Loading…
Reference in a new issue