From 552219b8e130f07bb2e5faf1865a0c08d572fe70 Mon Sep 17 00:00:00 2001 From: Denis Kirillov Date: Wed, 29 Mar 2023 13:20:37 +0300 Subject: [PATCH] [#16] pool: Fix counting context canceled error Signed-off-by: Denis Kirillov --- pool/mock_test.go | 16 ++++++------- pool/pool.go | 58 +++++++++++++++++++++++------------------------ pool/pool_test.go | 24 +++++++++++++++++++- 3 files changed, 60 insertions(+), 38 deletions(-) diff --git a/pool/mock_test.go b/pool/mock_test.go index a74faa22..9df107dc 100644 --- a/pool/mock_test.go +++ b/pool/mock_test.go @@ -103,22 +103,22 @@ func (m *mockClient) containerSetEACL(context.Context, PrmContainerSetEACL) erro return nil } -func (m *mockClient) endpointInfo(context.Context, prmEndpointInfo) (netmap.NodeInfo, error) { +func (m *mockClient) endpointInfo(ctx context.Context, _ prmEndpointInfo) (netmap.NodeInfo, error) { var ni netmap.NodeInfo if m.errorOnEndpointInfo { - return ni, m.handleError(nil, errors.New("error")) + return ni, m.handleError(ctx, nil, errors.New("error")) } ni.SetNetworkEndpoints(m.addr) return ni, nil } -func (m *mockClient) networkInfo(context.Context, prmNetworkInfo) (netmap.NetworkInfo, error) { +func (m *mockClient) networkInfo(ctx context.Context, _ prmNetworkInfo) (netmap.NetworkInfo, error) { var ni netmap.NetworkInfo if m.errorOnNetworkInfo { - return ni, m.handleError(nil, errors.New("error")) + return ni, m.handleError(ctx, nil, errors.New("error")) } return ni, nil @@ -132,7 +132,7 @@ func (m *mockClient) objectDelete(context.Context, PrmObjectDelete) error { return nil } -func (m *mockClient) objectGet(context.Context, PrmObjectGet) (ResGetObject, error) { +func (m *mockClient) objectGet(ctx context.Context, _ PrmObjectGet) (ResGetObject, error) { var res ResGetObject if m.stOnGetObject == nil { @@ -140,7 +140,7 @@ func (m *mockClient) objectGet(context.Context, PrmObjectGet) (ResGetObject, err } status := apistatus.ErrFromStatus(m.stOnGetObject) - return res, m.handleError(status, nil) + return res, m.handleError(ctx, status, nil) } func (m *mockClient) objectHead(context.Context, PrmObjectHead) (object.Object, error) { @@ -155,9 +155,9 @@ func (m *mockClient) objectSearch(context.Context, PrmObjectSearch) (ResObjectSe return ResObjectSearch{}, nil } -func (m *mockClient) sessionCreate(context.Context, prmCreateSession) (resCreateSession, error) { +func (m *mockClient) sessionCreate(ctx context.Context, _ prmCreateSession) (resCreateSession, error) { if m.errorOnCreateSession { - return resCreateSession{}, m.handleError(nil, errors.New("error")) + return resCreateSession{}, m.handleError(ctx, nil, errors.New("error")) } tok := newToken(m.key) diff --git a/pool/pool.go b/pool/pool.go index ef9b3888..772afc15 100644 --- a/pool/pool.go +++ b/pool/pool.go @@ -32,8 +32,6 @@ import ( "go.uber.org/atomic" "go.uber.org/zap" "go.uber.org/zap/zapcore" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" ) // client represents virtual connection to the single FrostFS network endpoint from which Pool is formed. @@ -383,7 +381,7 @@ func (c *clientWrapper) balanceGet(ctx context.Context, prm PrmBalanceGet) (acco if res != nil { st = res.Status() } - if err = c.handleError(st, err); err != nil { + if err = c.handleError(ctx, st, err); err != nil { return accounting.Decimal{}, fmt.Errorf("balance get on client: %w", err) } @@ -405,7 +403,7 @@ func (c *clientWrapper) containerPut(ctx context.Context, prm PrmContainerPut) ( if res != nil { st = res.Status() } - if err = c.handleError(st, err); err != nil { + if err = c.handleError(ctx, st, err); err != nil { return cid.ID{}, fmt.Errorf("container put on client: %w", err) } @@ -416,7 +414,7 @@ func (c *clientWrapper) containerPut(ctx context.Context, prm PrmContainerPut) ( idCnr := res.ID() err = waitForContainerPresence(ctx, c, idCnr, &prm.waitParams) - if err = c.handleError(nil, err); err != nil { + if err = c.handleError(ctx, nil, err); err != nil { return cid.ID{}, fmt.Errorf("wait container presence on client: %w", err) } @@ -440,7 +438,7 @@ func (c *clientWrapper) containerGet(ctx context.Context, prm PrmContainerGet) ( if res != nil { st = res.Status() } - if err = c.handleError(st, err); err != nil { + if err = c.handleError(ctx, st, err); err != nil { return container.Container{}, fmt.Errorf("container get on client: %w", err) } @@ -464,7 +462,7 @@ func (c *clientWrapper) containerList(ctx context.Context, prm PrmContainerList) if res != nil { st = res.Status() } - if err = c.handleError(st, err); err != nil { + if err = c.handleError(ctx, st, err); err != nil { return nil, fmt.Errorf("container list on client: %w", err) } return res.Containers(), nil @@ -491,7 +489,7 @@ func (c *clientWrapper) containerDelete(ctx context.Context, prm PrmContainerDel if res != nil { st = res.Status() } - if err = c.handleError(st, err); err != nil { + if err = c.handleError(ctx, st, err); err != nil { return fmt.Errorf("container delete on client: %w", err) } @@ -519,7 +517,7 @@ func (c *clientWrapper) containerEACL(ctx context.Context, prm PrmContainerEACL) if res != nil { st = res.Status() } - if err = c.handleError(st, err); err != nil { + if err = c.handleError(ctx, st, err); err != nil { return eacl.Table{}, fmt.Errorf("get eacl on client: %w", err) } @@ -548,7 +546,7 @@ func (c *clientWrapper) containerSetEACL(ctx context.Context, prm PrmContainerSe if res != nil { st = res.Status() } - if err = c.handleError(st, err); err != nil { + if err = c.handleError(ctx, st, err); err != nil { return fmt.Errorf("set eacl on client: %w", err) } @@ -562,7 +560,7 @@ func (c *clientWrapper) containerSetEACL(ctx context.Context, prm PrmContainerSe } err = waitForEACLPresence(ctx, c, cIDp, &prm.table, &prm.waitParams) - if err = c.handleError(nil, err); err != nil { + if err = c.handleError(ctx, nil, err); err != nil { return fmt.Errorf("wait eacl presence on client: %w", err) } @@ -583,7 +581,7 @@ func (c *clientWrapper) endpointInfo(ctx context.Context, _ prmEndpointInfo) (ne if res != nil { st = res.Status() } - if err = c.handleError(st, err); err != nil { + if err = c.handleError(ctx, st, err); err != nil { return netmap.NodeInfo{}, fmt.Errorf("endpoint info on client: %w", err) } @@ -604,7 +602,7 @@ func (c *clientWrapper) networkInfo(ctx context.Context, _ prmNetworkInfo) (netm if res != nil { st = res.Status() } - if err = c.handleError(st, err); err != nil { + if err = c.handleError(ctx, st, err); err != nil { return netmap.NetworkInfo{}, fmt.Errorf("network info on client: %w", err) } @@ -633,7 +631,7 @@ func (c *clientWrapper) objectPut(ctx context.Context, prm PrmObjectPut) (oid.ID start := time.Now() wObj, err := cl.ObjectPutInit(ctx, cliPrm) c.incRequests(time.Since(start), methodObjectPut) - if err = c.handleError(nil, err); err != nil { + if err = c.handleError(ctx, nil, err); err != nil { return oid.ID{}, fmt.Errorf("init writing on API client: %w", err) } @@ -677,7 +675,7 @@ func (c *clientWrapper) objectPut(ctx context.Context, prm PrmObjectPut) (oid.ID break } - return oid.ID{}, fmt.Errorf("read payload: %w", c.handleError(nil, err)) + return oid.ID{}, fmt.Errorf("read payload: %w", c.handleError(ctx, nil, err)) } } } @@ -687,7 +685,7 @@ func (c *clientWrapper) objectPut(ctx context.Context, prm PrmObjectPut) (oid.ID if res != nil { st = res.Status() } - if err = c.handleError(st, err); err != nil { // here err already carries both status and client errors + if err = c.handleError(ctx, st, err); err != nil { // here err already carries both status and client errors return oid.ID{}, fmt.Errorf("client failure: %w", err) } @@ -724,7 +722,7 @@ func (c *clientWrapper) objectDelete(ctx context.Context, prm PrmObjectDelete) e if res != nil { st = res.Status() } - if err = c.handleError(st, err); err != nil { + if err = c.handleError(ctx, st, err); err != nil { return fmt.Errorf("delete object on client: %w", err) } return nil @@ -756,7 +754,7 @@ func (c *clientWrapper) objectGet(ctx context.Context, prm PrmObjectGet) (ResGet var res ResGetObject rObj, err := cl.ObjectGetInit(ctx, cliPrm) - if err = c.handleError(nil, err); err != nil { + if err = c.handleError(ctx, nil, err); err != nil { return ResGetObject{}, fmt.Errorf("init object reading on client: %w", err) } @@ -769,7 +767,7 @@ func (c *clientWrapper) objectGet(ctx context.Context, prm PrmObjectGet) (ResGet if rObjRes != nil { st = rObjRes.Status() } - err = c.handleError(st, err) + err = c.handleError(ctx, st, err) return res, fmt.Errorf("read header: %w", err) } @@ -818,7 +816,7 @@ func (c *clientWrapper) objectHead(ctx context.Context, prm PrmObjectHead) (obje if res != nil { st = res.Status() } - if err = c.handleError(st, err); err != nil { + if err = c.handleError(ctx, st, err); err != nil { return obj, fmt.Errorf("read object header via client: %w", err) } if !res.ReadHeader(&obj) { @@ -856,7 +854,7 @@ func (c *clientWrapper) objectRange(ctx context.Context, prm PrmObjectRange) (Re start := time.Now() res, err := cl.ObjectRangeInit(ctx, cliPrm) c.incRequests(time.Since(start), methodObjectRange) - if err = c.handleError(nil, err); err != nil { + if err = c.handleError(ctx, nil, err); err != nil { return ResObjectRange{}, fmt.Errorf("init payload range reading on client: %w", err) } @@ -893,7 +891,7 @@ func (c *clientWrapper) objectSearch(ctx context.Context, prm PrmObjectSearch) ( } res, err := cl.ObjectSearchInit(ctx, cliPrm) - if err = c.handleError(nil, err); err != nil { + if err = c.handleError(ctx, nil, err); err != nil { return ResObjectSearch{}, fmt.Errorf("init object searching on client: %w", err) } @@ -918,7 +916,7 @@ func (c *clientWrapper) sessionCreate(ctx context.Context, prm prmCreateSession) if res != nil { st = res.Status() } - if err = c.handleError(st, err); err != nil { + if err = c.handleError(ctx, st, err); err != nil { return resCreateSession{}, fmt.Errorf("session creation on client: %w", err) } @@ -995,9 +993,9 @@ func (c *clientWrapper) incRequests(elapsed time.Duration, method MethodIndex) { } } -func (c *clientStatusMonitor) handleError(st apistatus.Status, err error) error { +func (c *clientStatusMonitor) handleError(ctx context.Context, st apistatus.Status, err error) error { if err != nil { - if needCountError(err) { + if needCountError(ctx, err) { c.incErrorRate() } @@ -1016,7 +1014,7 @@ func (c *clientStatusMonitor) handleError(st apistatus.Status, err error) error return err } -func needCountError(err error) bool { +func needCountError(ctx context.Context, err error) bool { // non-status logic error that could be returned // from the SDK client; should not be considered // as a connection error @@ -1025,9 +1023,11 @@ func needCountError(err error) bool { return false } - // we can't use errors.Is(err, context.Canceled) - // https://github.com/grpc/grpc-go/issues/4375 - return status.Code(err) != codes.Canceled + if errors.Is(ctx.Err(), context.Canceled) { + return false + } + + return true } // clientBuilder is a type alias of client constructors which open connection diff --git a/pool/pool_test.go b/pool/pool_test.go index c0789c9a..84f664dd 100644 --- a/pool/pool_test.go +++ b/pool/pool_test.go @@ -527,78 +527,100 @@ func TestStatusMonitor(t *testing.T) { } func TestHandleError(t *testing.T) { + ctx := context.Background() monitor := newClientStatusMonitor(zap.NewExample(), "", 10) + canceledCtx, cancel := context.WithCancel(context.Background()) + cancel() + for i, tc := range []struct { + ctx context.Context status apistatus.Status err error expectedError bool countError bool }{ { + ctx: ctx, status: nil, err: nil, expectedError: false, countError: false, }, { + ctx: ctx, status: apistatus.SuccessDefaultV2{}, err: nil, expectedError: false, countError: false, }, { + ctx: ctx, status: apistatus.SuccessDefaultV2{}, err: errors.New("error"), expectedError: true, countError: true, }, { + ctx: ctx, status: nil, err: errors.New("error"), expectedError: true, countError: true, }, { + ctx: ctx, status: apistatus.ObjectNotFound{}, err: nil, expectedError: true, countError: false, }, { + ctx: ctx, status: apistatus.ServerInternal{}, err: nil, expectedError: true, countError: true, }, { + ctx: ctx, status: apistatus.WrongMagicNumber{}, err: nil, expectedError: true, countError: true, }, { + ctx: ctx, status: apistatus.SignatureVerification{}, err: nil, expectedError: true, countError: true, }, { + ctx: ctx, status: &apistatus.SignatureVerification{}, err: nil, expectedError: true, countError: true, }, { + ctx: ctx, status: apistatus.NodeUnderMaintenance{}, err: nil, expectedError: true, countError: true, }, + { + ctx: canceledCtx, + status: nil, + err: errors.New("error"), + expectedError: true, + countError: false, + }, } { t.Run(strconv.Itoa(i), func(t *testing.T) { errCount := monitor.currentErrorRate() - err := monitor.handleError(tc.status, tc.err) + err := monitor.handleError(tc.ctx, tc.status, tc.err) if tc.expectedError { require.Error(t, err) } else {