[#343] pool: Provide dial context to clients

After recent changes `client.Client` accepts dial context. There is a
need to forward the context passed into `Pool.Dial` to the underlying
`Client` instances.

Define type aliases of different client constructors: context-based and
non-context. Use context-based constructor in `Pool`. Pass `ctx`
parameter of `Pool.Dial` method to the client builder.

Signed-off-by: Leonard Lyubich <ctulhurider@gmail.com>
This commit is contained in:
Leonard Lyubich 2022-10-03 15:04:30 +04:00 committed by fyrchik
parent 452a50e9d5
commit 8c682641bf
2 changed files with 58 additions and 24 deletions

View file

@ -219,6 +219,7 @@ type wrapperPrm struct {
timeout time.Duration timeout time.Duration
errorThreshold uint32 errorThreshold uint32
responseInfoCallback func(sdkClient.ResponseMetaInfo) error responseInfoCallback func(sdkClient.ResponseMetaInfo) error
dialCtx context.Context
} }
// setAddress sets endpoint to connect in NeoFS network. // setAddress sets endpoint to connect in NeoFS network.
@ -247,6 +248,11 @@ func (x *wrapperPrm) setResponseInfoCallback(f func(sdkClient.ResponseMetaInfo)
x.responseInfoCallback = f x.responseInfoCallback = f
} }
// setDialContext specifies context for client dial.
func (x *wrapperPrm) setDialContext(ctx context.Context) {
x.dialCtx = ctx
}
// newWrapper creates a clientWrapper that implements the client interface. // newWrapper creates a clientWrapper that implements the client interface.
func newWrapper(prm wrapperPrm) (*clientWrapper, error) { func newWrapper(prm wrapperPrm) (*clientWrapper, error) {
var prmInit sdkClient.PrmInit var prmInit sdkClient.PrmInit
@ -263,6 +269,7 @@ func newWrapper(prm wrapperPrm) (*clientWrapper, error) {
var prmDial sdkClient.PrmDial var prmDial sdkClient.PrmDial
prmDial.SetServerURI(prm.address) prmDial.SetServerURI(prm.address)
prmDial.SetTimeout(prm.timeout) prmDial.SetTimeout(prm.timeout)
prmDial.SetContext(prm.dialCtx)
err := res.client.Dial(prmDial) err := res.client.Dial(prmDial)
if err != nil { if err != nil {
@ -818,6 +825,14 @@ func (c *clientStatusMonitor) handleError(st apistatus.Status, err error) error
return err return err
} }
// clientBuilder is a type alias of client constructors which open connection
// to the given endpoint.
type clientBuilder = func(endpoint string) (client, error)
// clientBuilderContext is a type alias of client constructors which open
// connection to the given endpoint using provided context.
type clientBuilderContext = func(ctx context.Context, endpoint string) (client, error)
// InitParameters contains values used to initialize connection Pool. // InitParameters contains values used to initialize connection Pool.
type InitParameters struct { type InitParameters struct {
key *ecdsa.PrivateKey key *ecdsa.PrivateKey
@ -829,7 +844,7 @@ type InitParameters struct {
errorThreshold uint32 errorThreshold uint32
nodeParams []NodeParam nodeParams []NodeParam
clientBuilder func(endpoint string) (client, error) clientBuilder clientBuilderContext
} }
// SetKey specifies default key to be used for the protocol communication by default. // SetKey specifies default key to be used for the protocol communication by default.
@ -876,6 +891,24 @@ func (x *InitParameters) AddNode(nodeParam NodeParam) {
x.nodeParams = append(x.nodeParams, nodeParam) x.nodeParams = append(x.nodeParams, nodeParam)
} }
// setClientBuilder sets clientBuilder used for client construction.
// Wraps setClientBuilderContext without a context.
func (x *InitParameters) setClientBuilder(builder clientBuilder) {
x.setClientBuilderContext(func(_ context.Context, endpoint string) (client, error) {
return builder(endpoint)
})
}
// setClientBuilderContext sets clientBuilderContext used for client construction.
func (x *InitParameters) setClientBuilderContext(builder clientBuilderContext) {
x.clientBuilder = builder
}
// isMissingClientBuilder checks if client constructor was not specified.
func (x *InitParameters) isMissingClientBuilder() bool {
return x.clientBuilder == nil
}
type rebalanceParameters struct { type rebalanceParameters struct {
nodesParams []*nodesParam nodesParams []*nodesParam
nodeRequestTimeout time.Duration nodeRequestTimeout time.Duration
@ -1303,7 +1336,7 @@ type Pool struct {
cache *sessionCache cache *sessionCache
stokenDuration uint64 stokenDuration uint64
rebalanceParams rebalanceParameters rebalanceParams rebalanceParameters
clientBuilder func(endpoint string) (client, error) clientBuilder clientBuilderContext
logger *zap.Logger logger *zap.Logger
} }
@ -1371,7 +1404,7 @@ func (p *Pool) Dial(ctx context.Context) error {
for i, params := range p.rebalanceParams.nodesParams { for i, params := range p.rebalanceParams.nodesParams {
clients := make([]client, len(params.weights)) clients := make([]client, len(params.weights))
for j, addr := range params.addresses { for j, addr := range params.addresses {
c, err := p.clientBuilder(addr) c, err := p.clientBuilder(ctx, addr)
if err != nil { if err != nil {
return err return err
} }
@ -1428,8 +1461,8 @@ func fillDefaultInitParams(params *InitParameters, cache *sessionCache) {
params.healthcheckTimeout = defaultRequestTimeout params.healthcheckTimeout = defaultRequestTimeout
} }
if params.clientBuilder == nil { if params.isMissingClientBuilder() {
params.clientBuilder = func(addr string) (client, error) { params.setClientBuilderContext(func(ctx context.Context, addr string) (client, error) {
var prm wrapperPrm var prm wrapperPrm
prm.setAddress(addr) prm.setAddress(addr)
prm.setKey(*params.key) prm.setKey(*params.key)
@ -1439,8 +1472,9 @@ func fillDefaultInitParams(params *InitParameters, cache *sessionCache) {
cache.updateEpoch(info.Epoch()) cache.updateEpoch(info.Epoch())
return nil return nil
}) })
prm.setDialContext(ctx)
return newWrapper(prm) return newWrapper(prm)
} })
} }
} }

View file

@ -29,8 +29,8 @@ func TestBuildPoolClientFailed(t *testing.T) {
opts := InitParameters{ opts := InitParameters{
key: newPrivateKey(t), key: newPrivateKey(t),
nodeParams: []NodeParam{{1, "peer0", 1}}, nodeParams: []NodeParam{{1, "peer0", 1}},
clientBuilder: clientBuilder,
} }
opts.setClientBuilder(clientBuilder)
pool, err := NewPool(opts) pool, err := NewPool(opts)
require.NoError(t, err) require.NoError(t, err)
@ -48,8 +48,8 @@ func TestBuildPoolCreateSessionFailed(t *testing.T) {
opts := InitParameters{ opts := InitParameters{
key: newPrivateKey(t), key: newPrivateKey(t),
nodeParams: []NodeParam{{1, "peer0", 1}}, nodeParams: []NodeParam{{1, "peer0", 1}},
clientBuilder: clientBuilder,
} }
opts.setClientBuilder(clientBuilder)
pool, err := NewPool(opts) pool, err := NewPool(opts)
require.NoError(t, err) require.NoError(t, err)
@ -87,11 +87,11 @@ func TestBuildPoolOneNodeFailed(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
opts := InitParameters{ opts := InitParameters{
key: newPrivateKey(t), key: newPrivateKey(t),
clientBuilder: clientBuilder,
clientRebalanceInterval: 1000 * time.Millisecond, clientRebalanceInterval: 1000 * time.Millisecond,
logger: log, logger: log,
nodeParams: nodes, nodeParams: nodes,
} }
opts.setClientBuilder(clientBuilder)
clientPool, err := NewPool(opts) clientPool, err := NewPool(opts)
require.NoError(t, err) require.NoError(t, err)
@ -129,8 +129,8 @@ func TestOneNode(t *testing.T) {
opts := InitParameters{ opts := InitParameters{
key: newPrivateKey(t), key: newPrivateKey(t),
nodeParams: []NodeParam{{1, "peer0", 1}}, nodeParams: []NodeParam{{1, "peer0", 1}},
clientBuilder: clientBuilder,
} }
opts.setClientBuilder(clientBuilder)
pool, err := NewPool(opts) pool, err := NewPool(opts)
require.NoError(t, err) require.NoError(t, err)
@ -159,8 +159,8 @@ func TestTwoNodes(t *testing.T) {
{1, "peer0", 1}, {1, "peer0", 1},
{1, "peer1", 1}, {1, "peer1", 1},
}, },
clientBuilder: clientBuilder,
} }
opts.setClientBuilder(clientBuilder)
pool, err := NewPool(opts) pool, err := NewPool(opts)
require.NoError(t, err) require.NoError(t, err)
@ -209,8 +209,8 @@ func TestOneOfTwoFailed(t *testing.T) {
key: newPrivateKey(t), key: newPrivateKey(t),
nodeParams: nodes, nodeParams: nodes,
clientRebalanceInterval: 200 * time.Millisecond, clientRebalanceInterval: 200 * time.Millisecond,
clientBuilder: clientBuilder,
} }
opts.setClientBuilder(clientBuilder)
pool, err := NewPool(opts) pool, err := NewPool(opts)
require.NoError(t, err) require.NoError(t, err)
@ -247,8 +247,8 @@ func TestTwoFailed(t *testing.T) {
{1, "peer1", 1}, {1, "peer1", 1},
}, },
clientRebalanceInterval: 200 * time.Millisecond, clientRebalanceInterval: 200 * time.Millisecond,
clientBuilder: clientBuilder,
} }
opts.setClientBuilder(clientBuilder)
pool, err := NewPool(opts) pool, err := NewPool(opts)
require.NoError(t, err) require.NoError(t, err)
@ -280,8 +280,8 @@ func TestSessionCache(t *testing.T) {
{1, "peer0", 1}, {1, "peer0", 1},
}, },
clientRebalanceInterval: 30 * time.Second, clientRebalanceInterval: 30 * time.Second,
clientBuilder: clientBuilder,
} }
opts.setClientBuilder(clientBuilder)
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
@ -348,8 +348,8 @@ func TestPriority(t *testing.T) {
key: newPrivateKey(t), key: newPrivateKey(t),
nodeParams: nodes, nodeParams: nodes,
clientRebalanceInterval: 1500 * time.Millisecond, clientRebalanceInterval: 1500 * time.Millisecond,
clientBuilder: clientBuilder,
} }
opts.setClientBuilder(clientBuilder)
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
@ -395,8 +395,8 @@ func TestSessionCacheWithKey(t *testing.T) {
{1, "peer0", 1}, {1, "peer0", 1},
}, },
clientRebalanceInterval: 30 * time.Second, clientRebalanceInterval: 30 * time.Second,
clientBuilder: clientBuilder,
} }
opts.setClientBuilder(clientBuilder)
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
@ -434,8 +434,8 @@ func TestSessionTokenOwner(t *testing.T) {
nodeParams: []NodeParam{ nodeParams: []NodeParam{
{1, "peer0", 1}, {1, "peer0", 1},
}, },
clientBuilder: clientBuilder,
} }
opts.setClientBuilder(clientBuilder)
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
@ -638,8 +638,8 @@ func TestSwitchAfterErrorThreshold(t *testing.T) {
key: newPrivateKey(t), key: newPrivateKey(t),
nodeParams: nodes, nodeParams: nodes,
clientRebalanceInterval: 30 * time.Second, clientRebalanceInterval: 30 * time.Second,
clientBuilder: clientBuilder,
} }
opts.setClientBuilder(clientBuilder)
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()