diff --git a/pool/mock_test.go b/pool/mock_test.go index 82ecd03..758b8a9 100644 --- a/pool/mock_test.go +++ b/pool/mock_test.go @@ -8,6 +8,7 @@ import ( "github.com/google/uuid" sessionv2 "github.com/nspcc-dev/neofs-api-go/v2/session" "github.com/nspcc-dev/neofs-sdk-go/accounting" + apistatus "github.com/nspcc-dev/neofs-sdk-go/client/status" "github.com/nspcc-dev/neofs-sdk-go/container" cid "github.com/nspcc-dev/neofs-sdk-go/container/id" neofsecdsa "github.com/nspcc-dev/neofs-sdk-go/crypto/ecdsa" @@ -16,30 +17,29 @@ import ( "github.com/nspcc-dev/neofs-sdk-go/object" oid "github.com/nspcc-dev/neofs-sdk-go/object/id" "github.com/nspcc-dev/neofs-sdk-go/session" - "go.uber.org/atomic" ) type mockClient struct { - key ecdsa.PrivateKey - addr string - healthy *atomic.Bool - errorCount *atomic.Uint32 + key ecdsa.PrivateKey + *clientStatusMonitor errorOnCreateSession bool errorOnEndpointInfo bool errorOnNetworkInfo bool - errorOnGetObject error + stOnGetObject apistatus.Status } func newMockClient(addr string, key ecdsa.PrivateKey) *mockClient { return &mockClient{ - key: key, - addr: addr, - healthy: atomic.NewBool(true), - errorCount: atomic.NewUint32(0), + key: key, + clientStatusMonitor: newTestStatusMonitor(addr), } } +func (m *mockClient) setThreshold(threshold uint32) { + m.errorThreshold = threshold +} + func (m *mockClient) errOnCreateSession() { m.errorOnCreateSession = true } @@ -52,8 +52,8 @@ func (m *mockClient) errOnNetworkInfo() { m.errorOnEndpointInfo = true } -func (m *mockClient) errOnGetObject(err error) { - m.errorOnGetObject = err +func (m *mockClient) statusOnGetObject(st apistatus.Status) { + m.stOnGetObject = st } func newToken(key ecdsa.PrivateKey) *session.Object { @@ -95,7 +95,7 @@ func (m *mockClient) containerSetEACL(context.Context, PrmContainerSetEACL) erro func (m *mockClient) endpointInfo(context.Context, prmEndpointInfo) (*netmap.NodeInfo, error) { if m.errorOnEndpointInfo { - return nil, errors.New("error") + return nil, m.handleError(nil, errors.New("error")) } var ni netmap.NodeInfo @@ -105,7 +105,7 @@ func (m *mockClient) endpointInfo(context.Context, prmEndpointInfo) (*netmap.Nod func (m *mockClient) networkInfo(context.Context, prmNetworkInfo) (*netmap.NetworkInfo, error) { if m.errorOnNetworkInfo { - return nil, errors.New("error") + return nil, m.handleError(nil, errors.New("error")) } var ni netmap.NetworkInfo @@ -121,7 +121,12 @@ func (m *mockClient) objectDelete(context.Context, PrmObjectDelete) error { } func (m *mockClient) objectGet(context.Context, PrmObjectGet) (*ResGetObject, error) { - return &ResGetObject{}, m.errorOnGetObject + if m.stOnGetObject == nil { + return &ResGetObject{}, nil + } + + status := apistatus.ErrFromStatus(m.stOnGetObject) + return &ResGetObject{}, m.handleError(status, nil) } func (m *mockClient) objectHead(context.Context, PrmObjectHead) (*object.Object, error) { @@ -138,7 +143,7 @@ func (m *mockClient) objectSearch(context.Context, PrmObjectSearch) (*ResObjectS func (m *mockClient) sessionCreate(context.Context, prmCreateSession) (*resCreateSession, error) { if m.errorOnCreateSession { - return nil, errors.New("error") + return nil, m.handleError(nil, errors.New("error")) } tok := newToken(m.key) @@ -151,23 +156,3 @@ func (m *mockClient) sessionCreate(context.Context, prmCreateSession) (*resCreat sessionKey: v2tok.GetBody().GetSessionKey(), }, nil } - -func (m *mockClient) isHealthy() bool { - return m.healthy.Load() -} - -func (m *mockClient) setHealthy(b bool) bool { - return m.healthy.Swap(b) != b -} - -func (m *mockClient) address() string { - return m.addr -} - -func (m *mockClient) errorRate() uint32 { - return m.errorCount.Load() -} - -func (m *mockClient) resetErrorCounter() { - m.errorCount.Store(0) -} diff --git a/pool/pool.go b/pool/pool.go index 03aaccc..be763cd 100644 --- a/pool/pool.go +++ b/pool/pool.go @@ -51,6 +51,10 @@ type client interface { objectSearch(context.Context, PrmObjectSearch) (*ResObjectSearch, error) sessionCreate(context.Context, prmCreateSession) (*resCreateSession, error) + clientStatus +} + +type clientStatus interface { isHealthy() bool setHealthy(bool) bool address() string @@ -58,19 +62,25 @@ type client interface { resetErrorCounter() } +type clientStatusMonitor struct { + addr string + healthy *atomic.Bool + errorCount *atomic.Uint32 + errorThreshold uint32 +} + // clientWrapper is used by default, alternative implementations are intended for testing purposes only. type clientWrapper struct { - client sdkClient.Client - key ecdsa.PrivateKey - addr string - healthy *atomic.Bool - errorCount *atomic.Uint32 + client sdkClient.Client + key ecdsa.PrivateKey + *clientStatusMonitor } type wrapperPrm struct { address string key ecdsa.PrivateKey timeout time.Duration + errorThreshold uint32 responseInfoCallback func(sdkClient.ResponseMetaInfo) error } @@ -86,6 +96,10 @@ func (x *wrapperPrm) setTimeout(timeout time.Duration) { x.timeout = timeout } +func (x *wrapperPrm) setErrorThreshold(threshold uint32) { + x.errorThreshold = threshold +} + func (x *wrapperPrm) setResponseInfoCallback(f func(sdkClient.ResponseMetaInfo) error) { x.responseInfoCallback = f } @@ -97,10 +111,13 @@ func newWrapper(prm wrapperPrm) (*clientWrapper, error) { prmInit.SetResponseInfoCallback(prm.responseInfoCallback) res := &clientWrapper{ - addr: prm.address, - key: prm.key, - healthy: atomic.NewBool(true), - errorCount: atomic.NewUint32(0), + key: prm.key, + clientStatusMonitor: &clientStatusMonitor{ + addr: prm.address, + healthy: atomic.NewBool(true), + errorCount: atomic.NewUint32(0), + errorThreshold: prm.errorThreshold, + }, } res.client.Init(prmInit) @@ -476,27 +493,27 @@ func (c *clientWrapper) sessionCreate(ctx context.Context, prm prmCreateSession) }, nil } -func (c *clientWrapper) isHealthy() bool { +func (c *clientStatusMonitor) isHealthy() bool { return c.healthy.Load() } -func (c *clientWrapper) setHealthy(val bool) bool { +func (c *clientStatusMonitor) setHealthy(val bool) bool { return c.healthy.Swap(val) != val } -func (c *clientWrapper) address() string { +func (c *clientStatusMonitor) address() string { return c.addr } -func (c *clientWrapper) errorRate() uint32 { +func (c *clientStatusMonitor) errorRate() uint32 { return c.errorCount.Load() } -func (c *clientWrapper) resetErrorCounter() { +func (c *clientStatusMonitor) resetErrorCounter() { c.errorCount.Store(0) } -func (c *clientWrapper) handleError(st apistatus.Status, err error) error { +func (c *clientStatusMonitor) handleError(st apistatus.Status, err error) error { if err != nil { c.errorCount.Inc() return err @@ -504,10 +521,14 @@ func (c *clientWrapper) handleError(st apistatus.Status, err error) error { err = apistatus.ErrFromStatus(st) switch err.(type) { - case apistatus.ServerInternal, - apistatus.WrongMagicNumber, - apistatus.SignatureVerification: + case apistatus.ServerInternal, *apistatus.ServerInternal, + apistatus.WrongMagicNumber, *apistatus.WrongMagicNumber, + apistatus.SignatureVerification, *apistatus.SignatureVerification: c.errorCount.Inc() + if c.errorCount.Load() >= c.errorThreshold { + c.setHealthy(false) + c.resetErrorCounter() + } } return err @@ -521,6 +542,7 @@ type InitParameters struct { healthcheckTimeout time.Duration clientRebalanceInterval time.Duration sessionExpirationDuration uint64 + errorThreshold uint32 nodeParams []NodeParam clientBuilder func(endpoint string) (client, error) @@ -560,6 +582,11 @@ func (x *InitParameters) SetSessionExpirationDuration(expirationDuration uint64) x.sessionExpirationDuration = expirationDuration } +// SetErrorThreshold specifies the number of errors on connection after which node is considered as unhealthy. +func (x *InitParameters) SetErrorThreshold(threshold uint32) { + x.errorThreshold = threshold +} + // AddNode append information about the node to which you want to connect. func (x *InitParameters) AddNode(nodeParam NodeParam) { x.nodeParams = append(x.nodeParams, nodeParam) @@ -996,6 +1023,7 @@ type innerPool struct { const ( defaultSessionTokenExpirationDuration = 100 // in blocks + defaultErrorThreshold = 100 defaultRebalanceInterval = 25 * time.Second defaultRequestTimeout = 4 * time.Second @@ -1096,6 +1124,10 @@ func fillDefaultInitParams(params *InitParameters, cache *sessionCache) { params.sessionExpirationDuration = defaultSessionTokenExpirationDuration } + if params.errorThreshold == 0 { + params.errorThreshold = defaultErrorThreshold + } + if params.clientRebalanceInterval <= 0 { params.clientRebalanceInterval = defaultRebalanceInterval } @@ -1110,6 +1142,7 @@ func fillDefaultInitParams(params *InitParameters, cache *sessionCache) { prm.setAddress(addr) prm.setKey(*params.key) prm.setTimeout(params.nodeDialTimeout) + prm.setErrorThreshold(params.errorThreshold) prm.setResponseInfoCallback(func(info sdkClient.ResponseMetaInfo) error { cache.updateEpoch(info.Epoch()) return nil diff --git a/pool/pool_test.go b/pool/pool_test.go index 72bbeae..3b5a507 100644 --- a/pool/pool_test.go +++ b/pool/pool_test.go @@ -270,7 +270,7 @@ func TestSessionCache(t *testing.T) { clientBuilder := func(addr string) (client, error) { mockCli := newMockClient(addr, *key) - mockCli.errOnGetObject(apistatus.SessionTokenNotFound{}) + mockCli.statusOnGetObject(apistatus.SessionTokenNotFound{}) return mockCli, nil } @@ -508,16 +508,17 @@ func TestWaitPresence(t *testing.T) { }) } -func newTestWrapper(addr string) *clientWrapper { - return &clientWrapper{ - addr: addr, - healthy: atomic.NewBool(true), - errorCount: atomic.NewUint32(0), +func newTestStatusMonitor(addr string) *clientStatusMonitor { + return &clientStatusMonitor{ + addr: addr, + healthy: atomic.NewBool(true), + errorCount: atomic.NewUint32(0), + errorThreshold: 10, } } func TestHandleError(t *testing.T) { - wrapper := newTestWrapper("") + monitor := newTestStatusMonitor("") for i, tc := range []struct { status apistatus.Status @@ -573,10 +574,16 @@ func TestHandleError(t *testing.T) { expectedError: true, countError: true, }, + { + status: &apistatus.SignatureVerification{}, + err: nil, + expectedError: true, + countError: true, + }, } { t.Run(strconv.Itoa(i), func(t *testing.T) { - errCount := wrapper.errorCount.Load() - err := wrapper.handleError(tc.status, tc.err) + errCount := monitor.errorRate() + err := monitor.handleError(tc.status, tc.err) if tc.expectedError { require.Error(t, err) } else { @@ -585,7 +592,61 @@ func TestHandleError(t *testing.T) { if tc.countError { errCount++ } - require.Equal(t, errCount, wrapper.errorCount.Load()) + require.Equal(t, errCount, monitor.errorRate()) }) } } + +func TestSwitchAfterErrorThreshold(t *testing.T) { + nodes := []NodeParam{ + {1, "peer0", 1}, + {2, "peer1", 100}, + } + + errorThreshold := 5 + + var clientKeys []*ecdsa.PrivateKey + clientBuilder := func(addr string) (client, error) { + key := newPrivateKey(t) + clientKeys = append(clientKeys, key) + + if addr == nodes[0].address { + mockCli := newMockClient(addr, *key) + mockCli.setThreshold(uint32(errorThreshold)) + mockCli.statusOnGetObject(apistatus.ServerInternal{}) + return mockCli, nil + } + + return newMockClient(addr, *key), nil + } + + opts := InitParameters{ + key: newPrivateKey(t), + nodeParams: nodes, + clientRebalanceInterval: 30 * time.Second, + clientBuilder: clientBuilder, + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + pool, err := NewPool(opts) + require.NoError(t, err) + err = pool.Dial(ctx) + require.NoError(t, err) + t.Cleanup(pool.Close) + + for i := 0; i < errorThreshold; i++ { + conn, err := pool.connection() + require.NoError(t, err) + require.Equal(t, nodes[0].address, conn.address()) + _, err = conn.objectGet(ctx, PrmObjectGet{}) + require.Error(t, err) + } + + conn, err := pool.connection() + require.NoError(t, err) + require.Equal(t, nodes[1].address, conn.address()) + _, err = conn.objectGet(ctx, PrmObjectGet{}) + require.NoError(t, err) +} diff --git a/pool/sampler_test.go b/pool/sampler_test.go index 09ffbec..ddd3721 100644 --- a/pool/sampler_test.go +++ b/pool/sampler_test.go @@ -8,7 +8,6 @@ import ( "github.com/nspcc-dev/neofs-sdk-go/netmap" "github.com/stretchr/testify/require" - "go.uber.org/atomic" ) func TestSamplerStability(t *testing.T) { @@ -64,9 +63,7 @@ func newNetmapMock(name string, needErr bool) *clientMock { } return &clientMock{ clientWrapper: clientWrapper{ - addr: "", - healthy: atomic.NewBool(true), - errorCount: atomic.NewUint32(0), + clientStatusMonitor: newTestStatusMonitor(""), }, name: name, err: err,