diff --git a/pool/pool.go b/pool/pool.go index 079f05f..e2bcc90 100644 --- a/pool/pool.go +++ b/pool/pool.go @@ -219,6 +219,7 @@ type wrapperPrm struct { timeout time.Duration errorThreshold uint32 responseInfoCallback func(sdkClient.ResponseMetaInfo) error + dialCtx context.Context } // setAddress sets endpoint to connect in NeoFS network. @@ -247,6 +248,11 @@ func (x *wrapperPrm) setResponseInfoCallback(f func(sdkClient.ResponseMetaInfo) x.responseInfoCallback = f } +// setDialContext specifies context for client dial. +func (x *wrapperPrm) setDialContext(ctx context.Context) { + x.dialCtx = ctx +} + // newWrapper creates a clientWrapper that implements the client interface. func newWrapper(prm wrapperPrm) (*clientWrapper, error) { var prmInit sdkClient.PrmInit @@ -263,6 +269,7 @@ func newWrapper(prm wrapperPrm) (*clientWrapper, error) { var prmDial sdkClient.PrmDial prmDial.SetServerURI(prm.address) prmDial.SetTimeout(prm.timeout) + prmDial.SetContext(prm.dialCtx) err := res.client.Dial(prmDial) if err != nil { @@ -818,6 +825,14 @@ func (c *clientStatusMonitor) handleError(st apistatus.Status, err error) error return err } +// clientBuilder is a type alias of client constructors which open connection +// to the given endpoint. +type clientBuilder = func(endpoint string) (client, error) + +// clientBuilderContext is a type alias of client constructors which open +// connection to the given endpoint using provided context. +type clientBuilderContext = func(ctx context.Context, endpoint string) (client, error) + // InitParameters contains values used to initialize connection Pool. type InitParameters struct { key *ecdsa.PrivateKey @@ -829,7 +844,7 @@ type InitParameters struct { errorThreshold uint32 nodeParams []NodeParam - clientBuilder func(endpoint string) (client, error) + clientBuilder clientBuilderContext } // SetKey specifies default key to be used for the protocol communication by default. @@ -876,6 +891,24 @@ func (x *InitParameters) AddNode(nodeParam NodeParam) { x.nodeParams = append(x.nodeParams, nodeParam) } +// setClientBuilder sets clientBuilder used for client construction. +// Wraps setClientBuilderContext without a context. +func (x *InitParameters) setClientBuilder(builder clientBuilder) { + x.setClientBuilderContext(func(_ context.Context, endpoint string) (client, error) { + return builder(endpoint) + }) +} + +// setClientBuilderContext sets clientBuilderContext used for client construction. +func (x *InitParameters) setClientBuilderContext(builder clientBuilderContext) { + x.clientBuilder = builder +} + +// isMissingClientBuilder checks if client constructor was not specified. +func (x *InitParameters) isMissingClientBuilder() bool { + return x.clientBuilder == nil +} + type rebalanceParameters struct { nodesParams []*nodesParam nodeRequestTimeout time.Duration @@ -1303,7 +1336,7 @@ type Pool struct { cache *sessionCache stokenDuration uint64 rebalanceParams rebalanceParameters - clientBuilder func(endpoint string) (client, error) + clientBuilder clientBuilderContext logger *zap.Logger } @@ -1371,7 +1404,7 @@ func (p *Pool) Dial(ctx context.Context) error { for i, params := range p.rebalanceParams.nodesParams { clients := make([]client, len(params.weights)) for j, addr := range params.addresses { - c, err := p.clientBuilder(addr) + c, err := p.clientBuilder(ctx, addr) if err != nil { return err } @@ -1428,8 +1461,8 @@ func fillDefaultInitParams(params *InitParameters, cache *sessionCache) { params.healthcheckTimeout = defaultRequestTimeout } - if params.clientBuilder == nil { - params.clientBuilder = func(addr string) (client, error) { + if params.isMissingClientBuilder() { + params.setClientBuilderContext(func(ctx context.Context, addr string) (client, error) { var prm wrapperPrm prm.setAddress(addr) prm.setKey(*params.key) @@ -1439,8 +1472,9 @@ func fillDefaultInitParams(params *InitParameters, cache *sessionCache) { cache.updateEpoch(info.Epoch()) return nil }) + prm.setDialContext(ctx) return newWrapper(prm) - } + }) } } diff --git a/pool/pool_test.go b/pool/pool_test.go index 7275835..6a93989 100644 --- a/pool/pool_test.go +++ b/pool/pool_test.go @@ -27,10 +27,10 @@ func TestBuildPoolClientFailed(t *testing.T) { } opts := InitParameters{ - key: newPrivateKey(t), - nodeParams: []NodeParam{{1, "peer0", 1}}, - clientBuilder: clientBuilder, + key: newPrivateKey(t), + nodeParams: []NodeParam{{1, "peer0", 1}}, } + opts.setClientBuilder(clientBuilder) pool, err := NewPool(opts) require.NoError(t, err) @@ -46,10 +46,10 @@ func TestBuildPoolCreateSessionFailed(t *testing.T) { } opts := InitParameters{ - key: newPrivateKey(t), - nodeParams: []NodeParam{{1, "peer0", 1}}, - clientBuilder: clientBuilder, + key: newPrivateKey(t), + nodeParams: []NodeParam{{1, "peer0", 1}}, } + opts.setClientBuilder(clientBuilder) pool, err := NewPool(opts) require.NoError(t, err) @@ -87,11 +87,11 @@ func TestBuildPoolOneNodeFailed(t *testing.T) { require.NoError(t, err) opts := InitParameters{ key: newPrivateKey(t), - clientBuilder: clientBuilder, clientRebalanceInterval: 1000 * time.Millisecond, logger: log, nodeParams: nodes, } + opts.setClientBuilder(clientBuilder) clientPool, err := NewPool(opts) require.NoError(t, err) @@ -127,10 +127,10 @@ func TestOneNode(t *testing.T) { } opts := InitParameters{ - key: newPrivateKey(t), - nodeParams: []NodeParam{{1, "peer0", 1}}, - clientBuilder: clientBuilder, + key: newPrivateKey(t), + nodeParams: []NodeParam{{1, "peer0", 1}}, } + opts.setClientBuilder(clientBuilder) pool, err := NewPool(opts) require.NoError(t, err) @@ -159,8 +159,8 @@ func TestTwoNodes(t *testing.T) { {1, "peer0", 1}, {1, "peer1", 1}, }, - clientBuilder: clientBuilder, } + opts.setClientBuilder(clientBuilder) pool, err := NewPool(opts) require.NoError(t, err) @@ -209,8 +209,8 @@ func TestOneOfTwoFailed(t *testing.T) { key: newPrivateKey(t), nodeParams: nodes, clientRebalanceInterval: 200 * time.Millisecond, - clientBuilder: clientBuilder, } + opts.setClientBuilder(clientBuilder) pool, err := NewPool(opts) require.NoError(t, err) @@ -247,8 +247,8 @@ func TestTwoFailed(t *testing.T) { {1, "peer1", 1}, }, clientRebalanceInterval: 200 * time.Millisecond, - clientBuilder: clientBuilder, } + opts.setClientBuilder(clientBuilder) pool, err := NewPool(opts) require.NoError(t, err) @@ -280,8 +280,8 @@ func TestSessionCache(t *testing.T) { {1, "peer0", 1}, }, clientRebalanceInterval: 30 * time.Second, - clientBuilder: clientBuilder, } + opts.setClientBuilder(clientBuilder) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -348,8 +348,8 @@ func TestPriority(t *testing.T) { key: newPrivateKey(t), nodeParams: nodes, clientRebalanceInterval: 1500 * time.Millisecond, - clientBuilder: clientBuilder, } + opts.setClientBuilder(clientBuilder) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -395,8 +395,8 @@ func TestSessionCacheWithKey(t *testing.T) { {1, "peer0", 1}, }, clientRebalanceInterval: 30 * time.Second, - clientBuilder: clientBuilder, } + opts.setClientBuilder(clientBuilder) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -434,8 +434,8 @@ func TestSessionTokenOwner(t *testing.T) { nodeParams: []NodeParam{ {1, "peer0", 1}, }, - clientBuilder: clientBuilder, } + opts.setClientBuilder(clientBuilder) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -638,8 +638,8 @@ func TestSwitchAfterErrorThreshold(t *testing.T) { key: newPrivateKey(t), nodeParams: nodes, clientRebalanceInterval: 30 * time.Second, - clientBuilder: clientBuilder, } + opts.setClientBuilder(clientBuilder) ctx, cancel := context.WithCancel(context.Background()) defer cancel()