[#47] client: Pass context to Dial() explicitly

Signed-off-by: Evgenii Stratonikov <e.stratonikov@yadro.com>
This commit is contained in:
Evgenii Stratonikov 2023-04-07 08:55:12 +03:00
parent bc62e2f712
commit fa9573e857
3 changed files with 5 additions and 26 deletions

View file

@ -76,7 +76,7 @@ func (c *Client) Init(prm PrmInit) {
// Calling multiple times leads to undefined behavior. // Calling multiple times leads to undefined behavior.
// //
// See also Init / Close. // See also Init / Close.
func (c *Client) Dial(prm PrmDial) error { func (c *Client) Dial(ctx context.Context, prm PrmDial) error {
if prm.endpoint == "" { if prm.endpoint == "" {
return errorServerAddrUnset return errorServerAddrUnset
} }
@ -105,13 +105,9 @@ func (c *Client) Dial(prm PrmDial) error {
c.setFrostFSAPIServer((*coreServer)(&c.c)) c.setFrostFSAPIServer((*coreServer)(&c.c))
if prm.parentCtx == nil {
prm.parentCtx = context.Background()
}
// TODO: (neofs-api-go#382) perform generic dial stage of the client.Client // TODO: (neofs-api-go#382) perform generic dial stage of the client.Client
_, err := rpc.Balance(&c.c, new(v2accounting.BalanceRequest), _, err := rpc.Balance(&c.c, new(v2accounting.BalanceRequest),
client.WithContext(prm.parentCtx), client.WithContext(ctx),
) )
// return context errors since they signal about dial problem // return context errors since they signal about dial problem
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
@ -180,8 +176,6 @@ func (x *PrmInit) SetResponseInfoCallback(f func(ResponseMetaInfo) error) {
// PrmDial groups connection parameters for the Client. // PrmDial groups connection parameters for the Client.
// //
// See also Dial. // See also Dial.
//
// nolint: containedctx
type PrmDial struct { type PrmDial struct {
endpoint string endpoint string
@ -192,8 +186,6 @@ type PrmDial struct {
streamTimeoutSet bool streamTimeoutSet bool
streamTimeout time.Duration streamTimeout time.Duration
parentCtx context.Context
} }
// SetServerURI sets server URI in the FrostFS network. // SetServerURI sets server URI in the FrostFS network.
@ -234,11 +226,3 @@ func (x *PrmDial) SetStreamTimeout(timeout time.Duration) {
x.streamTimeoutSet = true x.streamTimeoutSet = true
x.streamTimeout = timeout x.streamTimeout = timeout
} }
// SetContext allows to specify optional base context within which connection
// should be established.
//
// Context SHOULD NOT be nil.
func (x *PrmDial) SetContext(ctx context.Context) {
x.parentCtx = ctx
}

View file

@ -47,11 +47,8 @@ func TestClient_DialContext(t *testing.T) {
prm.SetServerURI("localhost:8080") prm.SetServerURI("localhost:8080")
assert := func(ctx context.Context, errExpected error) { assert := func(ctx context.Context, errExpected error) {
// use the particular context
prm.SetContext(ctx)
// expect particular context error according to Dial docs // expect particular context error according to Dial docs
require.ErrorIs(t, c.Dial(prm), errExpected) require.ErrorIs(t, c.Dial(ctx, prm), errExpected)
} }
// create pre-abandoned context // create pre-abandoned context

View file

@ -304,9 +304,8 @@ func (c *clientWrapper) dial(ctx context.Context) error {
prmDial.SetServerURI(c.prm.address) prmDial.SetServerURI(c.prm.address)
prmDial.SetTimeout(c.prm.dialTimeout) prmDial.SetTimeout(c.prm.dialTimeout)
prmDial.SetStreamTimeout(c.prm.streamTimeout) prmDial.SetStreamTimeout(c.prm.streamTimeout)
prmDial.SetContext(ctx)
if err = cl.Dial(prmDial); err != nil { if err = cl.Dial(ctx, prmDial); err != nil {
c.setUnhealthy() c.setUnhealthy()
return err return err
} }
@ -335,9 +334,8 @@ func (c *clientWrapper) restartIfUnhealthy(ctx context.Context) (healthy, change
prmDial.SetServerURI(c.prm.address) prmDial.SetServerURI(c.prm.address)
prmDial.SetTimeout(c.prm.dialTimeout) prmDial.SetTimeout(c.prm.dialTimeout)
prmDial.SetStreamTimeout(c.prm.streamTimeout) prmDial.SetStreamTimeout(c.prm.streamTimeout)
prmDial.SetContext(ctx)
if err := cl.Dial(prmDial); err != nil { if err := cl.Dial(ctx, prmDial); err != nil {
c.setUnhealthy() c.setUnhealthy()
return false, wasHealthy return false, wasHealthy
} }