From fa9573e85756918fd11aeec0b931989bd213bcb0 Mon Sep 17 00:00:00 2001 From: Evgenii Stratonikov Date: Fri, 7 Apr 2023 08:55:12 +0300 Subject: [PATCH] [#47] client: Pass context to Dial() explicitly Signed-off-by: Evgenii Stratonikov --- client/client.go | 20 ++------------------ client/client_test.go | 5 +---- pool/pool.go | 6 ++---- 3 files changed, 5 insertions(+), 26 deletions(-) diff --git a/client/client.go b/client/client.go index 59de8abd..601f5574 100644 --- a/client/client.go +++ b/client/client.go @@ -76,7 +76,7 @@ func (c *Client) Init(prm PrmInit) { // Calling multiple times leads to undefined behavior. // // See also Init / Close. -func (c *Client) Dial(prm PrmDial) error { +func (c *Client) Dial(ctx context.Context, prm PrmDial) error { if prm.endpoint == "" { return errorServerAddrUnset } @@ -105,13 +105,9 @@ func (c *Client) Dial(prm PrmDial) error { c.setFrostFSAPIServer((*coreServer)(&c.c)) - if prm.parentCtx == nil { - prm.parentCtx = context.Background() - } - // TODO: (neofs-api-go#382) perform generic dial stage of the client.Client _, err := rpc.Balance(&c.c, new(v2accounting.BalanceRequest), - client.WithContext(prm.parentCtx), + client.WithContext(ctx), ) // return context errors since they signal about dial problem if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { @@ -180,8 +176,6 @@ func (x *PrmInit) SetResponseInfoCallback(f func(ResponseMetaInfo) error) { // PrmDial groups connection parameters for the Client. // // See also Dial. -// -// nolint: containedctx type PrmDial struct { endpoint string @@ -192,8 +186,6 @@ type PrmDial struct { streamTimeoutSet bool streamTimeout time.Duration - - parentCtx context.Context } // SetServerURI sets server URI in the FrostFS network. @@ -234,11 +226,3 @@ func (x *PrmDial) SetStreamTimeout(timeout time.Duration) { x.streamTimeoutSet = true x.streamTimeout = timeout } - -// SetContext allows to specify optional base context within which connection -// should be established. -// -// Context SHOULD NOT be nil. -func (x *PrmDial) SetContext(ctx context.Context) { - x.parentCtx = ctx -} diff --git a/client/client_test.go b/client/client_test.go index 98c9ab80..16f66e2d 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -47,11 +47,8 @@ func TestClient_DialContext(t *testing.T) { prm.SetServerURI("localhost:8080") assert := func(ctx context.Context, errExpected error) { - // use the particular context - prm.SetContext(ctx) - // expect particular context error according to Dial docs - require.ErrorIs(t, c.Dial(prm), errExpected) + require.ErrorIs(t, c.Dial(ctx, prm), errExpected) } // create pre-abandoned context diff --git a/pool/pool.go b/pool/pool.go index 8dac8a04..657afbf5 100644 --- a/pool/pool.go +++ b/pool/pool.go @@ -304,9 +304,8 @@ func (c *clientWrapper) dial(ctx context.Context) error { prmDial.SetServerURI(c.prm.address) prmDial.SetTimeout(c.prm.dialTimeout) prmDial.SetStreamTimeout(c.prm.streamTimeout) - prmDial.SetContext(ctx) - if err = cl.Dial(prmDial); err != nil { + if err = cl.Dial(ctx, prmDial); err != nil { c.setUnhealthy() return err } @@ -335,9 +334,8 @@ func (c *clientWrapper) restartIfUnhealthy(ctx context.Context) (healthy, change prmDial.SetServerURI(c.prm.address) prmDial.SetTimeout(c.prm.dialTimeout) prmDial.SetStreamTimeout(c.prm.streamTimeout) - prmDial.SetContext(ctx) - if err := cl.Dial(prmDial); err != nil { + if err := cl.Dial(ctx, prmDial); err != nil { c.setUnhealthy() return false, wasHealthy }