diff --git a/pool/mock_test.go b/pool/mock_test.go index 9eb14615..69cf0504 100644 --- a/pool/mock_test.go +++ b/pool/mock_test.go @@ -23,6 +23,7 @@ type mockClient struct { key ecdsa.PrivateKey clientStatusMonitor + errorOnDial bool errorOnCreateSession bool errorOnEndpointInfo bool errorOnNetworkInfo bool @@ -52,6 +53,13 @@ func (m *mockClient) errOnNetworkInfo() { m.errorOnEndpointInfo = true } +func (m *mockClient) errOnDial() { + m.errorOnDial = true + m.errOnCreateSession() + m.errOnEndpointInfo() + m.errOnNetworkInfo() +} + func (m *mockClient) statusOnGetObject(st apistatus.Status) { m.stOnGetObject = st } @@ -160,3 +168,22 @@ func (m *mockClient) sessionCreate(context.Context, prmCreateSession) (resCreate sessionKey: v2tok.GetBody().GetSessionKey(), }, nil } + +func (m *mockClient) dial(context.Context) error { + if m.errorOnDial { + return errors.New("dial error") + } + return nil +} + +func (m *mockClient) restartIfUnhealthy(ctx context.Context) (healthy bool, changed bool) { + _, err := m.endpointInfo(ctx, prmEndpointInfo{}) + healthy = err == nil + changed = healthy != m.isHealthy() + if healthy { + m.setHealthy() + } else { + m.setUnhealthy() + } + return +} diff --git a/pool/pool.go b/pool/pool.go index e2bcc900..a17d2d3d 100644 --- a/pool/pool.go +++ b/pool/pool.go @@ -70,15 +70,19 @@ type client interface { sessionCreate(context.Context, prmCreateSession) (resCreateSession, error) clientStatus + + // see clientWrapper.dial. + dial(ctx context.Context) error + // see clientWrapper.restartIfUnhealthy. + restartIfUnhealthy(ctx context.Context) (bool, bool) } // clientStatus provide access to some metrics for connection. type clientStatus interface { // isHealthy checks if the connection can handle requests. isHealthy() bool - // setHealthy allows set healthy status for connection. - // It's used to update status during Pool.startRebalance routing. - setHealthy(bool) bool + // setUnhealthy marks client as unhealthy. + setUnhealthy() // address return address of endpoint. address() string // currentErrorRate returns current errors rate. @@ -91,6 +95,9 @@ type clientStatus interface { methodsStatus() []statusSnapshot } +// ErrPoolClientUnhealthy is an error to indicate that client in pool is unhealthy. +var ErrPoolClientUnhealthy = errors.New("pool client unhealthy") + // clientStatusMonitor count error rate and other statistics for connection. type clientStatusMonitor struct { addr string @@ -207,8 +214,10 @@ func (m *methodStatus) incRequests(elapsed time.Duration) { // clientWrapper is used by default, alternative implementations are intended for testing purposes only. type clientWrapper struct { - client sdkClient.Client - key ecdsa.PrivateKey + clientMutex sync.RWMutex + client *sdkClient.Client + prm wrapperPrm + clientStatusMonitor } @@ -219,7 +228,6 @@ type wrapperPrm struct { timeout time.Duration errorThreshold uint32 responseInfoCallback func(sdkClient.ResponseMetaInfo) error - dialCtx context.Context } // setAddress sets endpoint to connect in NeoFS network. @@ -248,44 +256,107 @@ func (x *wrapperPrm) setResponseInfoCallback(f func(sdkClient.ResponseMetaInfo) x.responseInfoCallback = f } -// setDialContext specifies context for client dial. -func (x *wrapperPrm) setDialContext(ctx context.Context) { - x.dialCtx = ctx -} - // newWrapper creates a clientWrapper that implements the client interface. -func newWrapper(prm wrapperPrm) (*clientWrapper, error) { +func newWrapper(prm wrapperPrm) *clientWrapper { + var cl sdkClient.Client var prmInit sdkClient.PrmInit prmInit.SetDefaultPrivateKey(prm.key) prmInit.SetResponseInfoCallback(prm.responseInfoCallback) + cl.Init(prmInit) + res := &clientWrapper{ - key: prm.key, + client: &cl, clientStatusMonitor: newClientStatusMonitor(prm.address, prm.errorThreshold), + prm: prm, } - res.client.Init(prmInit) + return res +} + +// dial establishes a connection to the server from the NeoFS network. +// Returns an error describing failure reason. If failed, the client +// SHOULD NOT be used. +func (c *clientWrapper) dial(ctx context.Context) error { + cl, err := c.getClient() + if err != nil { + return err + } var prmDial sdkClient.PrmDial - prmDial.SetServerURI(prm.address) - prmDial.SetTimeout(prm.timeout) - prmDial.SetContext(prm.dialCtx) + prmDial.SetServerURI(c.prm.address) + prmDial.SetTimeout(c.prm.timeout) + prmDial.SetContext(ctx) - err := res.client.Dial(prmDial) - if err != nil { - return nil, fmt.Errorf("client dial: %w", err) + if err = cl.Dial(prmDial); err != nil { + c.setUnhealthy() + return err } - return res, nil + return nil +} + +// restartIfUnhealthy checks healthy status of client and recreate it if status is unhealthy. +// Return current healthy status and indicating if status was changed by this function call. +func (c *clientWrapper) restartIfUnhealthy(ctx context.Context) (healthy, changed bool) { + var wasHealthy bool + if _, err := c.endpointInfo(ctx, prmEndpointInfo{}); err == nil { + return true, false + } else if !errors.Is(err, ErrPoolClientUnhealthy) { + wasHealthy = true + } + + var cl sdkClient.Client + var prmInit sdkClient.PrmInit + prmInit.SetDefaultPrivateKey(c.prm.key) + prmInit.SetResponseInfoCallback(c.prm.responseInfoCallback) + + cl.Init(prmInit) + + var prmDial sdkClient.PrmDial + prmDial.SetServerURI(c.prm.address) + prmDial.SetTimeout(c.prm.timeout) + prmDial.SetContext(ctx) + + if err := cl.Dial(prmDial); err != nil { + c.setUnhealthy() + return false, wasHealthy + } + + c.clientMutex.Lock() + c.client = &cl + c.clientMutex.Unlock() + + if _, err := cl.EndpointInfo(ctx, sdkClient.PrmEndpointInfo{}); err != nil { + c.setUnhealthy() + return false, wasHealthy + } + + c.setHealthy() + return true, !wasHealthy +} + +func (c *clientWrapper) getClient() (*sdkClient.Client, error) { + c.clientMutex.RLock() + defer c.clientMutex.RUnlock() + if c.isHealthy() { + return c.client, nil + } + return nil, ErrPoolClientUnhealthy } // balanceGet invokes sdkClient.BalanceGet parse response status to error and return result as is. func (c *clientWrapper) balanceGet(ctx context.Context, prm PrmBalanceGet) (accounting.Decimal, error) { + cl, err := c.getClient() + if err != nil { + return accounting.Decimal{}, err + } + var cliPrm sdkClient.PrmBalanceGet cliPrm.SetAccount(prm.account) start := time.Now() - res, err := c.client.BalanceGet(ctx, cliPrm) + res, err := cl.BalanceGet(ctx, cliPrm) c.incRequests(time.Since(start), methodBalanceGet) var st apistatus.Status if res != nil { @@ -301,8 +372,13 @@ func (c *clientWrapper) balanceGet(ctx context.Context, prm PrmBalanceGet) (acco // containerPut invokes sdkClient.ContainerPut parse response status to error and return result as is. // It also waits for the container to appear on the network. func (c *clientWrapper) containerPut(ctx context.Context, prm PrmContainerPut) (cid.ID, error) { + cl, err := c.getClient() + if err != nil { + return cid.ID{}, err + } + start := time.Now() - res, err := c.client.ContainerPut(ctx, prm.prmClient) + res, err := cl.ContainerPut(ctx, prm.prmClient) c.incRequests(time.Since(start), methodContainerPut) var st apistatus.Status if res != nil { @@ -328,11 +404,16 @@ func (c *clientWrapper) containerPut(ctx context.Context, prm PrmContainerPut) ( // containerGet invokes sdkClient.ContainerGet parse response status to error and return result as is. func (c *clientWrapper) containerGet(ctx context.Context, prm PrmContainerGet) (container.Container, error) { + cl, err := c.getClient() + if err != nil { + return container.Container{}, err + } + var cliPrm sdkClient.PrmContainerGet cliPrm.SetContainer(prm.cnrID) start := time.Now() - res, err := c.client.ContainerGet(ctx, cliPrm) + res, err := cl.ContainerGet(ctx, cliPrm) c.incRequests(time.Since(start), methodContainerGet) var st apistatus.Status if res != nil { @@ -347,11 +428,16 @@ func (c *clientWrapper) containerGet(ctx context.Context, prm PrmContainerGet) ( // containerList invokes sdkClient.ContainerList parse response status to error and return result as is. func (c *clientWrapper) containerList(ctx context.Context, prm PrmContainerList) ([]cid.ID, error) { + cl, err := c.getClient() + if err != nil { + return nil, err + } + var cliPrm sdkClient.PrmContainerList cliPrm.SetAccount(prm.ownerID) start := time.Now() - res, err := c.client.ContainerList(ctx, cliPrm) + res, err := cl.ContainerList(ctx, cliPrm) c.incRequests(time.Since(start), methodContainerList) var st apistatus.Status if res != nil { @@ -366,6 +452,11 @@ func (c *clientWrapper) containerList(ctx context.Context, prm PrmContainerList) // containerDelete invokes sdkClient.ContainerDelete parse response status to error. // It also waits for the container to be removed from the network. func (c *clientWrapper) containerDelete(ctx context.Context, prm PrmContainerDelete) error { + cl, err := c.getClient() + if err != nil { + return err + } + var cliPrm sdkClient.PrmContainerDelete cliPrm.SetContainer(prm.cnrID) if prm.stokenSet { @@ -373,7 +464,7 @@ func (c *clientWrapper) containerDelete(ctx context.Context, prm PrmContainerDel } start := time.Now() - res, err := c.client.ContainerDelete(ctx, cliPrm) + res, err := cl.ContainerDelete(ctx, cliPrm) c.incRequests(time.Since(start), methodContainerDelete) var st apistatus.Status if res != nil { @@ -392,11 +483,16 @@ func (c *clientWrapper) containerDelete(ctx context.Context, prm PrmContainerDel // containerEACL invokes sdkClient.ContainerEACL parse response status to error and return result as is. func (c *clientWrapper) containerEACL(ctx context.Context, prm PrmContainerEACL) (eacl.Table, error) { + cl, err := c.getClient() + if err != nil { + return eacl.Table{}, err + } + var cliPrm sdkClient.PrmContainerEACL cliPrm.SetContainer(prm.cnrID) start := time.Now() - res, err := c.client.ContainerEACL(ctx, cliPrm) + res, err := cl.ContainerEACL(ctx, cliPrm) c.incRequests(time.Since(start), methodContainerEACL) var st apistatus.Status if res != nil { @@ -412,6 +508,11 @@ func (c *clientWrapper) containerEACL(ctx context.Context, prm PrmContainerEACL) // containerSetEACL invokes sdkClient.ContainerSetEACL parse response status to error. // It also waits for the EACL to appear on the network. func (c *clientWrapper) containerSetEACL(ctx context.Context, prm PrmContainerSetEACL) error { + cl, err := c.getClient() + if err != nil { + return err + } + var cliPrm sdkClient.PrmContainerSetEACL cliPrm.SetTable(prm.table) @@ -420,7 +521,7 @@ func (c *clientWrapper) containerSetEACL(ctx context.Context, prm PrmContainerSe } start := time.Now() - res, err := c.client.ContainerSetEACL(ctx, cliPrm) + res, err := cl.ContainerSetEACL(ctx, cliPrm) c.incRequests(time.Since(start), methodContainerSetEACL) var st apistatus.Status if res != nil { @@ -449,8 +550,13 @@ func (c *clientWrapper) containerSetEACL(ctx context.Context, prm PrmContainerSe // endpointInfo invokes sdkClient.EndpointInfo parse response status to error and return result as is. func (c *clientWrapper) endpointInfo(ctx context.Context, _ prmEndpointInfo) (netmap.NodeInfo, error) { + cl, err := c.getClient() + if err != nil { + return netmap.NodeInfo{}, err + } + start := time.Now() - res, err := c.client.EndpointInfo(ctx, sdkClient.PrmEndpointInfo{}) + res, err := cl.EndpointInfo(ctx, sdkClient.PrmEndpointInfo{}) c.incRequests(time.Since(start), methodEndpointInfo) var st apistatus.Status if res != nil { @@ -465,8 +571,13 @@ func (c *clientWrapper) endpointInfo(ctx context.Context, _ prmEndpointInfo) (ne // networkInfo invokes sdkClient.NetworkInfo parse response status to error and return result as is. func (c *clientWrapper) networkInfo(ctx context.Context, _ prmNetworkInfo) (netmap.NetworkInfo, error) { + cl, err := c.getClient() + if err != nil { + return netmap.NetworkInfo{}, err + } + start := time.Now() - res, err := c.client.NetworkInfo(ctx, sdkClient.PrmNetworkInfo{}) + res, err := cl.NetworkInfo(ctx, sdkClient.PrmNetworkInfo{}) c.incRequests(time.Since(start), methodNetworkInfo) var st apistatus.Status if res != nil { @@ -481,6 +592,11 @@ func (c *clientWrapper) networkInfo(ctx context.Context, _ prmNetworkInfo) (netm // objectPut writes object to NeoFS. func (c *clientWrapper) objectPut(ctx context.Context, prm PrmObjectPut) (oid.ID, error) { + cl, err := c.getClient() + if err != nil { + return oid.ID{}, err + } + var cliPrm sdkClient.PrmObjectPutInit cliPrm.SetCopiesNumber(prm.copiesNumber) if prm.stoken != nil { @@ -494,7 +610,7 @@ func (c *clientWrapper) objectPut(ctx context.Context, prm PrmObjectPut) (oid.ID } start := time.Now() - wObj, err := c.client.ObjectPutInit(ctx, cliPrm) + wObj, err := cl.ObjectPutInit(ctx, cliPrm) c.incRequests(time.Since(start), methodObjectPut) if err = c.handleError(nil, err); err != nil { return oid.ID{}, fmt.Errorf("init writing on API client: %w", err) @@ -559,6 +675,11 @@ func (c *clientWrapper) objectPut(ctx context.Context, prm PrmObjectPut) (oid.ID // objectDelete invokes sdkClient.ObjectDelete parse response status to error. func (c *clientWrapper) objectDelete(ctx context.Context, prm PrmObjectDelete) error { + cl, err := c.getClient() + if err != nil { + return err + } + var cliPrm sdkClient.PrmObjectDelete cliPrm.FromContainer(prm.addr.Container()) cliPrm.ByID(prm.addr.Object()) @@ -576,7 +697,7 @@ func (c *clientWrapper) objectDelete(ctx context.Context, prm PrmObjectDelete) e } start := time.Now() - res, err := c.client.ObjectDelete(ctx, cliPrm) + res, err := cl.ObjectDelete(ctx, cliPrm) c.incRequests(time.Since(start), methodObjectDelete) var st apistatus.Status if res != nil { @@ -590,6 +711,11 @@ func (c *clientWrapper) objectDelete(ctx context.Context, prm PrmObjectDelete) e // objectGet returns reader for object. func (c *clientWrapper) objectGet(ctx context.Context, prm PrmObjectGet) (ResGetObject, error) { + cl, err := c.getClient() + if err != nil { + return ResGetObject{}, err + } + var cliPrm sdkClient.PrmObjectGet cliPrm.FromContainer(prm.addr.Container()) cliPrm.ByID(prm.addr.Object()) @@ -608,7 +734,7 @@ func (c *clientWrapper) objectGet(ctx context.Context, prm PrmObjectGet) (ResGet var res ResGetObject - rObj, err := c.client.ObjectGetInit(ctx, cliPrm) + rObj, err := cl.ObjectGetInit(ctx, cliPrm) if err = c.handleError(nil, err); err != nil { return ResGetObject{}, fmt.Errorf("init object reading on client: %w", err) } @@ -638,6 +764,11 @@ func (c *clientWrapper) objectGet(ctx context.Context, prm PrmObjectGet) (ResGet // objectHead invokes sdkClient.ObjectHead parse response status to error and return result as is. func (c *clientWrapper) objectHead(ctx context.Context, prm PrmObjectHead) (object.Object, error) { + cl, err := c.getClient() + if err != nil { + return object.Object{}, err + } + var cliPrm sdkClient.PrmObjectHead cliPrm.FromContainer(prm.addr.Container()) cliPrm.ByID(prm.addr.Object()) @@ -657,7 +788,7 @@ func (c *clientWrapper) objectHead(ctx context.Context, prm PrmObjectHead) (obje var obj object.Object start := time.Now() - res, err := c.client.ObjectHead(ctx, cliPrm) + res, err := cl.ObjectHead(ctx, cliPrm) c.incRequests(time.Since(start), methodObjectHead) var st apistatus.Status if res != nil { @@ -675,6 +806,11 @@ func (c *clientWrapper) objectHead(ctx context.Context, prm PrmObjectHead) (obje // objectRange returns object range reader. func (c *clientWrapper) objectRange(ctx context.Context, prm PrmObjectRange) (ResObjectRange, error) { + cl, err := c.getClient() + if err != nil { + return ResObjectRange{}, err + } + var cliPrm sdkClient.PrmObjectRange cliPrm.FromContainer(prm.addr.Container()) cliPrm.ByID(prm.addr.Object()) @@ -694,7 +830,7 @@ func (c *clientWrapper) objectRange(ctx context.Context, prm PrmObjectRange) (Re } start := time.Now() - res, err := c.client.ObjectRangeInit(ctx, cliPrm) + res, err := cl.ObjectRangeInit(ctx, cliPrm) c.incRequests(time.Since(start), methodObjectRange) if err = c.handleError(nil, err); err != nil { return ResObjectRange{}, fmt.Errorf("init payload range reading on client: %w", err) @@ -710,6 +846,11 @@ func (c *clientWrapper) objectRange(ctx context.Context, prm PrmObjectRange) (Re // objectSearch invokes sdkClient.ObjectSearchInit parse response status to error and return result as is. func (c *clientWrapper) objectSearch(ctx context.Context, prm PrmObjectSearch) (ResObjectSearch, error) { + cl, err := c.getClient() + if err != nil { + return ResObjectSearch{}, err + } + var cliPrm sdkClient.PrmObjectSearch cliPrm.InContainer(prm.cnrID) @@ -727,7 +868,7 @@ func (c *clientWrapper) objectSearch(ctx context.Context, prm PrmObjectSearch) ( cliPrm.UseKey(*prm.key) } - res, err := c.client.ObjectSearchInit(ctx, cliPrm) + res, err := cl.ObjectSearchInit(ctx, cliPrm) if err = c.handleError(nil, err); err != nil { return ResObjectSearch{}, fmt.Errorf("init object searching on client: %w", err) } @@ -737,12 +878,17 @@ func (c *clientWrapper) objectSearch(ctx context.Context, prm PrmObjectSearch) ( // sessionCreate invokes sdkClient.SessionCreate parse response status to error and return result as is. func (c *clientWrapper) sessionCreate(ctx context.Context, prm prmCreateSession) (resCreateSession, error) { + cl, err := c.getClient() + if err != nil { + return resCreateSession{}, err + } + var cliPrm sdkClient.PrmSessionCreate cliPrm.SetExp(prm.exp) cliPrm.UseKey(prm.key) start := time.Now() - res, err := c.client.SessionCreate(ctx, cliPrm) + res, err := cl.SessionCreate(ctx, cliPrm) c.incRequests(time.Since(start), methodSessionCreate) var st apistatus.Status if res != nil { @@ -762,8 +908,12 @@ func (c *clientStatusMonitor) isHealthy() bool { return c.healthy.Load() } -func (c *clientStatusMonitor) setHealthy(val bool) bool { - return c.healthy.Swap(val) != val +func (c *clientStatusMonitor) setHealthy() { + c.healthy.Store(true) +} + +func (c *clientStatusMonitor) setUnhealthy() { + c.healthy.Store(false) } func (c *clientStatusMonitor) address() string { @@ -776,7 +926,7 @@ func (c *clientStatusMonitor) incErrorRate() { c.currentErrorCount++ c.overallErrorCount++ if c.currentErrorCount >= c.errorThreshold { - c.setHealthy(false) + c.setUnhealthy() c.currentErrorCount = 0 } } @@ -827,11 +977,7 @@ func (c *clientStatusMonitor) handleError(st apistatus.Status, err error) error // clientBuilder is a type alias of client constructors which open connection // to the given endpoint. -type clientBuilder = func(endpoint string) (client, error) - -// clientBuilderContext is a type alias of client constructors which open -// connection to the given endpoint using provided context. -type clientBuilderContext = func(ctx context.Context, endpoint string) (client, error) +type clientBuilder = func(endpoint string) client // InitParameters contains values used to initialize connection Pool. type InitParameters struct { @@ -844,7 +990,7 @@ type InitParameters struct { errorThreshold uint32 nodeParams []NodeParam - clientBuilder clientBuilderContext + clientBuilder clientBuilder } // SetKey specifies default key to be used for the protocol communication by default. @@ -894,13 +1040,6 @@ func (x *InitParameters) AddNode(nodeParam NodeParam) { // setClientBuilder sets clientBuilder used for client construction. // Wraps setClientBuilderContext without a context. func (x *InitParameters) setClientBuilder(builder clientBuilder) { - x.setClientBuilderContext(func(_ context.Context, endpoint string) (client, error) { - return builder(endpoint) - }) -} - -// setClientBuilderContext sets clientBuilderContext used for client construction. -func (x *InitParameters) setClientBuilderContext(builder clientBuilderContext) { x.clientBuilder = builder } @@ -1336,7 +1475,7 @@ type Pool struct { cache *sessionCache stokenDuration uint64 rebalanceParams rebalanceParameters - clientBuilder clientBuilderContext + clientBuilder clientBuilder logger *zap.Logger } @@ -1404,22 +1543,26 @@ func (p *Pool) Dial(ctx context.Context) error { for i, params := range p.rebalanceParams.nodesParams { clients := make([]client, len(params.weights)) for j, addr := range params.addresses { - c, err := p.clientBuilder(ctx, addr) - if err != nil { - return err + c := p.clientBuilder(addr) + if err := c.dial(ctx); err != nil { + if p.logger != nil { + p.logger.Warn("failed to build client", zap.String("address", addr), zap.Error(err)) + } } - var healthy bool + var st session.Object - err = initSessionForDuration(ctx, &st, c, p.rebalanceParams.sessionExpirationDuration, *p.key) - if err != nil && p.logger != nil { - p.logger.Warn("failed to create neofs session token for client", - zap.String("Address", addr), - zap.Error(err)) - } else if err == nil { - healthy, atLeastOneHealthy = true, true + err := initSessionForDuration(ctx, &st, c, p.rebalanceParams.sessionExpirationDuration, *p.key) + if err != nil { + c.setUnhealthy() + if p.logger != nil { + p.logger.Warn("failed to create neofs session token for client", + zap.String("address", addr), zap.Error(err)) + } + } else { + atLeastOneHealthy = true _ = p.cache.Put(formCacheKey(addr, p.key), st) } - c.setHealthy(healthy) + clients[j] = c } source := rand.NewSource(time.Now().UnixNano()) @@ -1462,7 +1605,7 @@ func fillDefaultInitParams(params *InitParameters, cache *sessionCache) { } if params.isMissingClientBuilder() { - params.setClientBuilderContext(func(ctx context.Context, addr string) (client, error) { + params.setClientBuilder(func(addr string) client { var prm wrapperPrm prm.setAddress(addr) prm.setKey(*params.key) @@ -1472,7 +1615,6 @@ func fillDefaultInitParams(params *InitParameters, cache *sessionCache) { cache.updateEpoch(info.Epoch()) return nil }) - prm.setDialContext(ctx) return newWrapper(prm) }) } @@ -1551,29 +1693,23 @@ func (p *Pool) updateInnerNodesHealth(ctx context.Context, i int, bufferWeights healthyChanged := atomic.NewBool(false) wg := sync.WaitGroup{} - var prmEndpoint prmEndpointInfo - for j, cli := range pool.clients { wg.Add(1) go func(j int, cli client) { defer wg.Done() - ok := true + tctx, c := context.WithTimeout(ctx, options.nodeRequestTimeout) defer c() - // TODO (@kirillovdenis) : #283 consider reconnect to the node on failure - if _, err := cli.endpointInfo(tctx, prmEndpoint); err != nil { - ok = false - bufferWeights[j] = 0 - } - - if ok { + healthy, changed := cli.restartIfUnhealthy(tctx) + if healthy { bufferWeights[j] = options.nodesParams[i].weights[j] } else { + bufferWeights[j] = 0 p.cache.DeleteByPrefix(cli.address()) } - if cli.setHealthy(ok) { + if changed { healthyChanged.Store(true) } }(j, cli) @@ -1616,7 +1752,7 @@ func (p *Pool) connection() (client, error) { } func (p *innerPool) connection() (client, error) { - p.lock.RLock() // TODO(@kirillovdenis): #283 consider remove this lock because using client should be thread safe + p.lock.RLock() // need lock because of using p.sampler defer p.lock.RUnlock() if len(p.clients) == 1 { cp := p.clients[0] diff --git a/pool/pool_test.go b/pool/pool_test.go index 6a93989d..aedd6143 100644 --- a/pool/pool_test.go +++ b/pool/pool_test.go @@ -4,7 +4,6 @@ import ( "context" "crypto/ecdsa" "errors" - "fmt" "strconv" "testing" "time" @@ -22,15 +21,17 @@ import ( ) func TestBuildPoolClientFailed(t *testing.T) { - clientBuilder := func(string) (client, error) { - return nil, fmt.Errorf("error") + mockClientBuilder := func(addr string) client { + mockCli := newMockClient(addr, *newPrivateKey(t)) + mockCli.errOnDial() + return mockCli } opts := InitParameters{ key: newPrivateKey(t), nodeParams: []NodeParam{{1, "peer0", 1}}, } - opts.setClientBuilder(clientBuilder) + opts.setClientBuilder(mockClientBuilder) pool, err := NewPool(opts) require.NoError(t, err) @@ -39,17 +40,17 @@ func TestBuildPoolClientFailed(t *testing.T) { } func TestBuildPoolCreateSessionFailed(t *testing.T) { - clientBuilder := func(addr string) (client, error) { + clientMockBuilder := func(addr string) client { mockCli := newMockClient(addr, *newPrivateKey(t)) mockCli.errOnCreateSession() - return mockCli, nil + return mockCli } opts := InitParameters{ key: newPrivateKey(t), nodeParams: []NodeParam{{1, "peer0", 1}}, } - opts.setClientBuilder(clientBuilder) + opts.setClientBuilder(clientMockBuilder) pool, err := NewPool(opts) require.NoError(t, err) @@ -70,17 +71,17 @@ func TestBuildPoolOneNodeFailed(t *testing.T) { } var clientKeys []*ecdsa.PrivateKey - clientBuilder := func(addr string) (client, error) { + mockClientBuilder := func(addr string) client { key := newPrivateKey(t) clientKeys = append(clientKeys, key) if addr == nodes[0].address { mockCli := newMockClient(addr, *key) mockCli.errOnEndpointInfo() - return mockCli, nil + return mockCli } - return newMockClient(addr, *key), nil + return newMockClient(addr, *key) } log, err := zap.NewProduction() @@ -91,7 +92,7 @@ func TestBuildPoolOneNodeFailed(t *testing.T) { logger: log, nodeParams: nodes, } - opts.setClientBuilder(clientBuilder) + opts.setClientBuilder(mockClientBuilder) clientPool, err := NewPool(opts) require.NoError(t, err) @@ -122,15 +123,15 @@ func TestBuildPoolZeroNodes(t *testing.T) { func TestOneNode(t *testing.T) { key1 := newPrivateKey(t) - clientBuilder := func(addr string) (client, error) { - return newMockClient(addr, *key1), nil + mockClientBuilder := func(addr string) client { + return newMockClient(addr, *key1) } opts := InitParameters{ key: newPrivateKey(t), nodeParams: []NodeParam{{1, "peer0", 1}}, } - opts.setClientBuilder(clientBuilder) + opts.setClientBuilder(mockClientBuilder) pool, err := NewPool(opts) require.NoError(t, err) @@ -147,10 +148,10 @@ func TestOneNode(t *testing.T) { func TestTwoNodes(t *testing.T) { var clientKeys []*ecdsa.PrivateKey - clientBuilder := func(addr string) (client, error) { + mockClientBuilder := func(addr string) client { key := newPrivateKey(t) clientKeys = append(clientKeys, key) - return newMockClient(addr, *key), nil + return newMockClient(addr, *key) } opts := InitParameters{ @@ -160,7 +161,7 @@ func TestTwoNodes(t *testing.T) { {1, "peer1", 1}, }, } - opts.setClientBuilder(clientBuilder) + opts.setClientBuilder(mockClientBuilder) pool, err := NewPool(opts) require.NoError(t, err) @@ -191,18 +192,18 @@ func TestOneOfTwoFailed(t *testing.T) { } var clientKeys []*ecdsa.PrivateKey - clientBuilder := func(addr string) (client, error) { + mockClientBuilder := func(addr string) client { key := newPrivateKey(t) clientKeys = append(clientKeys, key) if addr == nodes[0].address { - return newMockClient(addr, *key), nil + return newMockClient(addr, *key) } mockCli := newMockClient(addr, *key) mockCli.errOnEndpointInfo() mockCli.errOnNetworkInfo() - return mockCli, nil + return mockCli } opts := InitParameters{ @@ -210,7 +211,7 @@ func TestOneOfTwoFailed(t *testing.T) { nodeParams: nodes, clientRebalanceInterval: 200 * time.Millisecond, } - opts.setClientBuilder(clientBuilder) + opts.setClientBuilder(mockClientBuilder) pool, err := NewPool(opts) require.NoError(t, err) @@ -232,12 +233,12 @@ func TestOneOfTwoFailed(t *testing.T) { func TestTwoFailed(t *testing.T) { var clientKeys []*ecdsa.PrivateKey - clientBuilder := func(addr string) (client, error) { + mockClientBuilder := func(addr string) client { key := newPrivateKey(t) clientKeys = append(clientKeys, key) mockCli := newMockClient(addr, *key) mockCli.errOnEndpointInfo() - return mockCli, nil + return mockCli } opts := InitParameters{ @@ -248,7 +249,7 @@ func TestTwoFailed(t *testing.T) { }, clientRebalanceInterval: 200 * time.Millisecond, } - opts.setClientBuilder(clientBuilder) + opts.setClientBuilder(mockClientBuilder) pool, err := NewPool(opts) require.NoError(t, err) @@ -268,10 +269,10 @@ func TestSessionCache(t *testing.T) { key := newPrivateKey(t) expectedAuthKey := neofsecdsa.PublicKey(key.PublicKey) - clientBuilder := func(addr string) (client, error) { + mockClientBuilder := func(addr string) client { mockCli := newMockClient(addr, *key) mockCli.statusOnGetObject(apistatus.SessionTokenNotFound{}) - return mockCli, nil + return mockCli } opts := InitParameters{ @@ -281,7 +282,7 @@ func TestSessionCache(t *testing.T) { }, clientRebalanceInterval: 30 * time.Second, } - opts.setClientBuilder(clientBuilder) + opts.setClientBuilder(mockClientBuilder) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -331,17 +332,17 @@ func TestPriority(t *testing.T) { } var clientKeys []*ecdsa.PrivateKey - clientBuilder := func(addr string) (client, error) { + mockClientBuilder := func(addr string) client { key := newPrivateKey(t) clientKeys = append(clientKeys, key) if addr == nodes[0].address { mockCli := newMockClient(addr, *key) mockCli.errOnEndpointInfo() - return mockCli, nil + return mockCli } - return newMockClient(addr, *key), nil + return newMockClient(addr, *key) } opts := InitParameters{ @@ -349,7 +350,7 @@ func TestPriority(t *testing.T) { nodeParams: nodes, clientRebalanceInterval: 1500 * time.Millisecond, } - opts.setClientBuilder(clientBuilder) + opts.setClientBuilder(mockClientBuilder) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -385,8 +386,8 @@ func TestSessionCacheWithKey(t *testing.T) { key := newPrivateKey(t) expectedAuthKey := neofsecdsa.PublicKey(key.PublicKey) - clientBuilder := func(addr string) (client, error) { - return newMockClient(addr, *key), nil + mockClientBuilder := func(addr string) client { + return newMockClient(addr, *key) } opts := InitParameters{ @@ -396,7 +397,7 @@ func TestSessionCacheWithKey(t *testing.T) { }, clientRebalanceInterval: 30 * time.Second, } - opts.setClientBuilder(clientBuilder) + opts.setClientBuilder(mockClientBuilder) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -424,9 +425,9 @@ func TestSessionCacheWithKey(t *testing.T) { } func TestSessionTokenOwner(t *testing.T) { - clientBuilder := func(addr string) (client, error) { + mockClientBuilder := func(addr string) client { key := newPrivateKey(t) - return newMockClient(addr, *key), nil + return newMockClient(addr, *key) } opts := InitParameters{ @@ -435,7 +436,7 @@ func TestSessionTokenOwner(t *testing.T) { {1, "peer0", 1}, }, } - opts.setClientBuilder(clientBuilder) + opts.setClientBuilder(mockClientBuilder) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -620,7 +621,7 @@ func TestSwitchAfterErrorThreshold(t *testing.T) { errorThreshold := 5 var clientKeys []*ecdsa.PrivateKey - clientBuilder := func(addr string) (client, error) { + mockClientBuilder := func(addr string) client { key := newPrivateKey(t) clientKeys = append(clientKeys, key) @@ -628,10 +629,10 @@ func TestSwitchAfterErrorThreshold(t *testing.T) { mockCli := newMockClient(addr, *key) mockCli.setThreshold(uint32(errorThreshold)) mockCli.statusOnGetObject(apistatus.ServerInternal{}) - return mockCli, nil + return mockCli } - return newMockClient(addr, *key), nil + return newMockClient(addr, *key) } opts := InitParameters{ @@ -639,7 +640,7 @@ func TestSwitchAfterErrorThreshold(t *testing.T) { nodeParams: nodes, clientRebalanceInterval: 30 * time.Second, } - opts.setClientBuilder(clientBuilder) + opts.setClientBuilder(mockClientBuilder) ctx, cancel := context.WithCancel(context.Background()) defer cancel() diff --git a/pool/sampler_test.go b/pool/sampler_test.go index c62ccde3..5ea23260 100644 --- a/pool/sampler_test.go +++ b/pool/sampler_test.go @@ -2,11 +2,9 @@ package pool import ( "context" - "fmt" "math/rand" "testing" - "github.com/nspcc-dev/neofs-sdk-go/netmap" "github.com/stretchr/testify/require" ) @@ -42,34 +40,6 @@ func TestSamplerStability(t *testing.T) { } } -type clientMock struct { - clientWrapper - name string - err error -} - -func (c *clientMock) endpointInfo(context.Context, prmEndpointInfo) (netmap.NodeInfo, error) { - return netmap.NodeInfo{}, nil -} - -func (c *clientMock) networkInfo(context.Context, prmNetworkInfo) (netmap.NetworkInfo, error) { - return netmap.NetworkInfo{}, nil -} - -func newNetmapMock(name string, needErr bool) *clientMock { - var err error - if needErr { - err = fmt.Errorf("not available") - } - return &clientMock{ - clientWrapper: clientWrapper{ - clientStatusMonitor: newClientStatusMonitor("", 10), - }, - name: name, - err: err, - } -} - func TestHealthyReweight(t *testing.T) { var ( weights = []float64{0.9, 0.1} @@ -80,12 +50,14 @@ func TestHealthyReweight(t *testing.T) { cache, err := newCache() require.NoError(t, err) + client1 := newMockClient(names[0], *newPrivateKey(t)) + client1.errOnDial() + + client2 := newMockClient(names[1], *newPrivateKey(t)) + inner := &innerPool{ sampler: newSampler(weights, rand.NewSource(0)), - clients: []client{ - newNetmapMock(names[0], true), - newNetmapMock(names[1], false), - }, + clients: []client{client1, client2}, } p := &Pool{ innerPools: []*innerPool{inner}, @@ -97,19 +69,19 @@ func TestHealthyReweight(t *testing.T) { // check getting first node connection before rebalance happened connection0, err := p.connection() require.NoError(t, err) - mock0 := connection0.(*clientMock) - require.Equal(t, names[0], mock0.name) + mock0 := connection0.(*mockClient) + require.Equal(t, names[0], mock0.address()) p.updateInnerNodesHealth(context.TODO(), 0, buffer) connection1, err := p.connection() require.NoError(t, err) - mock1 := connection1.(*clientMock) - require.Equal(t, names[1], mock1.name) + mock1 := connection1.(*mockClient) + require.Equal(t, names[1], mock1.address()) // enabled first node again inner.lock.Lock() - inner.clients[0] = newNetmapMock(names[0], false) + inner.clients[0] = newMockClient(names[0], *newPrivateKey(t)) inner.lock.Unlock() p.updateInnerNodesHealth(context.TODO(), 0, buffer) @@ -117,8 +89,8 @@ func TestHealthyReweight(t *testing.T) { connection0, err = p.connection() require.NoError(t, err) - mock0 = connection0.(*clientMock) - require.Equal(t, names[0], mock0.name) + mock0 = connection0.(*mockClient) + require.Equal(t, names[0], mock0.address()) } func TestHealthyNoReweight(t *testing.T) { @@ -132,8 +104,8 @@ func TestHealthyNoReweight(t *testing.T) { inner := &innerPool{ sampler: sampl, clients: []client{ - newNetmapMock(names[0], false), - newNetmapMock(names[1], false), + newMockClient(names[0], *newPrivateKey(t)), + newMockClient(names[1], *newPrivateKey(t)), }, } p := &Pool{