From 8aee80dbdc072492e350e1a751d0d8cd501b556f Mon Sep 17 00:00:00 2001 From: Dmitrii Stepanov Date: Tue, 15 Oct 2024 16:38:23 +0300 Subject: [PATCH] [#2] rpcclient: Allow to specify custom DialContext func Signed-off-by: Dmitrii Stepanov --- pkg/rpcclient/client.go | 15 ++++++++++++--- pkg/rpcclient/wsclient.go | 2 +- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/pkg/rpcclient/client.go b/pkg/rpcclient/client.go index 2073946eb..07d293cb5 100644 --- a/pkg/rpcclient/client.go +++ b/pkg/rpcclient/client.go @@ -70,6 +70,7 @@ type Options struct { // Limit total number of connections per host. No limit by default. MaxConnsPerHost int TLSClientConfig *tls.Config + NetDialContext func(ctx context.Context, network, addr string) (net.Conn, error) } // cache stores cache values for the RPC client methods. @@ -105,11 +106,19 @@ func initClient(ctx context.Context, cl *Client, endpoint string, opts Options) if opts.RequestTimeout <= 0 { opts.RequestTimeout = defaultRequestTimeout } + dialContext := (&net.Dialer{ + Timeout: opts.DialTimeout, + }).DialContext + if opts.NetDialContext != nil { + dialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + ctx, cancel := context.WithTimeout(ctx, opts.DialTimeout) + defer cancel() + return opts.NetDialContext(ctx, network, addr) + } + } tr := &http.Transport{ - DialContext: (&net.Dialer{ - Timeout: opts.DialTimeout, - }).DialContext, + DialContext: dialContext, MaxConnsPerHost: opts.MaxConnsPerHost, TLSClientConfig: opts.TLSClientConfig, } diff --git a/pkg/rpcclient/wsclient.go b/pkg/rpcclient/wsclient.go index 455e4d4f8..8d88bcbff 100644 --- a/pkg/rpcclient/wsclient.go +++ b/pkg/rpcclient/wsclient.go @@ -453,7 +453,7 @@ var errConnClosedByUser = errors.New("connection closed by user") // You should call Init method to initialize the network magic the client is // operating on. func NewWS(ctx context.Context, endpoint string, opts WSOptions) (*WSClient, error) { - dialer := websocket.Dialer{HandshakeTimeout: opts.DialTimeout, TLSClientConfig: opts.TLSClientConfig} + dialer := websocket.Dialer{HandshakeTimeout: opts.DialTimeout, TLSClientConfig: opts.TLSClientConfig, NetDialContext: opts.NetDialContext} ws, resp, err := dialer.DialContext(ctx, endpoint, nil) if resp != nil && resp.Body != nil { // Can be non-nil even with error returned. defer resp.Body.Close() // Not exactly required by websocket, but let's do this for bodyclose checker. -- 2.45.2