diff --git a/pool/mock_test.go b/pool/mock_test.go index d5a635a8..583b1a03 100644 --- a/pool/mock_test.go +++ b/pool/mock_test.go @@ -26,11 +26,15 @@ type mockClient struct { errorOnDial bool errorOnCreateSession bool - errorOnEndpointInfo bool + errorOnEndpointInfo error + resOnEndpointInfo netmap.NodeInfo + healthcheckFn func() errorOnNetworkInfo bool stOnGetObject apistatus.Status } +var _ client = (*mockClient)(nil) + func newMockClient(addr string, key ecdsa.PrivateKey) *mockClient { return &mockClient{ key: key, @@ -38,6 +42,16 @@ func newMockClient(addr string, key ecdsa.PrivateKey) *mockClient { } } +func newMockClientHealthy(addr string, key ecdsa.PrivateKey, healthy bool) *mockClient { + m := newMockClient(addr, key) + if healthy { + m.setHealthy() + } else { + m.setUnhealthy() + } + return m +} + func (m *mockClient) setThreshold(threshold uint32) { m.errorThreshold = threshold } @@ -47,11 +61,11 @@ func (m *mockClient) errOnCreateSession() { } func (m *mockClient) errOnEndpointInfo() { - m.errorOnEndpointInfo = true + m.errorOnEndpointInfo = errors.New("error") } func (m *mockClient) errOnNetworkInfo() { - m.errorOnEndpointInfo = true + m.errorOnEndpointInfo = errors.New("error") } func (m *mockClient) errOnDial() { @@ -94,27 +108,32 @@ func (m *mockClient) containerDelete(context.Context, PrmContainerDelete) error return nil } -func (c *mockClient) apeManagerAddChain(ctx context.Context, prm PrmAddAPEChain) error { +func (m *mockClient) apeManagerAddChain(context.Context, PrmAddAPEChain) error { return nil } -func (c *mockClient) apeManagerRemoveChain(ctx context.Context, prm PrmRemoveAPEChain) error { +func (m *mockClient) apeManagerRemoveChain(context.Context, PrmRemoveAPEChain) error { return nil } -func (c *mockClient) apeManagerListChains(ctx context.Context, prm PrmListAPEChains) ([]ape.Chain, error) { +func (m *mockClient) apeManagerListChains(context.Context, PrmListAPEChains) ([]ape.Chain, error) { return []ape.Chain{}, nil } func (m *mockClient) endpointInfo(ctx context.Context, _ prmEndpointInfo) (netmap.NodeInfo, error) { - var ni netmap.NodeInfo - - if m.errorOnEndpointInfo { - return ni, m.handleError(ctx, nil, errors.New("error")) + if m.errorOnEndpointInfo != nil { + return netmap.NodeInfo{}, m.handleError(ctx, nil, m.errorOnEndpointInfo) } - ni.SetNetworkEndpoints(m.addr) - return ni, nil + m.resOnEndpointInfo.SetNetworkEndpoints(m.addr) + return m.resOnEndpointInfo, nil +} + +func (m *mockClient) healthcheck(ctx context.Context) (netmap.NodeInfo, error) { + if m.healthcheckFn != nil { + m.healthcheckFn() + } + return m.endpointInfo(ctx, prmEndpointInfo{}) } func (m *mockClient) networkInfo(ctx context.Context, _ prmNetworkInfo) (netmap.NetworkInfo, error) { @@ -190,16 +209,12 @@ func (m *mockClient) dial(context.Context) error { return nil } -func (m *mockClient) restartIfUnhealthy(ctx context.Context) (changed bool, err error) { - _, err = m.endpointInfo(ctx, prmEndpointInfo{}) - healthy := err == nil - changed = healthy != m.isHealthy() - if healthy { - m.setHealthy() - } else { - m.setUnhealthy() +func (m *mockClient) restart(context.Context) error { + if m.errorOnDial { + return errors.New("restart dial error") } - return + + return nil } func (m *mockClient) close() error { diff --git a/pool/pool.go b/pool/pool.go index 950c2f7a..4e76978c 100644 --- a/pool/pool.go +++ b/pool/pool.go @@ -57,6 +57,8 @@ type client interface { apeManagerListChains(context.Context, PrmListAPEChains) ([]ape.Chain, error) // see clientWrapper.endpointInfo. endpointInfo(context.Context, prmEndpointInfo) (netmap.NodeInfo, error) + // see clientWrapper.healthcheck. + healthcheck(ctx context.Context) (netmap.NodeInfo, error) // see clientWrapper.networkInfo. networkInfo(context.Context, prmNetworkInfo) (netmap.NetworkInfo, error) // see clientWrapper.netMapSnapshot @@ -82,8 +84,8 @@ type client interface { // see clientWrapper.dial. dial(ctx context.Context) error - // see clientWrapper.restartIfUnhealthy. - restartIfUnhealthy(ctx context.Context) (bool, error) + // see clientWrapper.restart. + restart(ctx context.Context) error // see clientWrapper.close. close() error } @@ -92,10 +94,10 @@ type client interface { type clientStatus interface { // isHealthy checks if the connection can handle requests. isHealthy() bool - // isDialed checks if the connection was created. - isDialed() bool // setUnhealthy marks client as unhealthy. setUnhealthy() + // setHealthy marks client as healthy. + setHealthy() // address return address of endpoint. address() string // currentErrorRate returns current errors rate. @@ -126,15 +128,10 @@ type clientStatusMonitor struct { // values for healthy status of clientStatusMonitor. const ( - // statusUnhealthyOnDial is set when dialing to the endpoint is failed, - // so there is no connection to the endpoint, and pool should not close it - // before re-establishing connection once again. - statusUnhealthyOnDial = iota - // statusUnhealthyOnRequest is set when communication after dialing to the // endpoint is failed due to immediate or accumulated errors, connection is // available and pool should close it before re-establishing connection once again. - statusUnhealthyOnRequest + statusUnhealthyOnRequest = iota // statusHealthy is set when connection is ready to be used by the pool. statusHealthy @@ -233,6 +230,7 @@ func newClientStatusMonitor(logger *zap.Logger, addr string, errorThreshold uint type clientWrapper struct { clientMutex sync.RWMutex client *sdkClient.Client + dialed bool prm wrapperPrm clientStatusMonitor @@ -342,30 +340,17 @@ func (c *clientWrapper) dial(ctx context.Context) error { GRPCDialOptions: c.prm.dialOptions, } - if err = cl.Dial(ctx, prmDial); err != nil { - c.setUnhealthyOnDial() + err = cl.Dial(ctx, prmDial) + c.setDialed(err == nil) + if err != nil { return err } return nil } -// restartIfUnhealthy checks healthy status of client and recreate it if status is unhealthy. -// Indicating if status was changed by this function call and returns error that caused unhealthy status. -func (c *clientWrapper) restartIfUnhealthy(ctx context.Context) (changed bool, err error) { - var wasHealthy bool - if _, err = c.endpointInfo(ctx, prmEndpointInfo{}); err == nil { - return false, nil - } else if !errors.Is(err, errPoolClientUnhealthy) { - wasHealthy = true - } - - // if connection is dialed before, to avoid routine / connection leak, - // pool has to close it and then initialize once again. - if c.isDialed() { - c.scheduleGracefulClose() - } - +// restart recreates and redial inner sdk client. +func (c *clientWrapper) restart(ctx context.Context) error { var cl sdkClient.Client prmInit := sdkClient.PrmInit{ Key: c.prm.key, @@ -381,22 +366,35 @@ func (c *clientWrapper) restartIfUnhealthy(ctx context.Context) (changed bool, e GRPCDialOptions: c.prm.dialOptions, } - if err = cl.Dial(ctx, prmDial); err != nil { - c.setUnhealthyOnDial() - return wasHealthy, err + // if connection is dialed before, to avoid routine / connection leak, + // pool has to close it and then initialize once again. + if c.isDialed() { + c.scheduleGracefulClose() + } + + err := cl.Dial(ctx, prmDial) + c.setDialed(err == nil) + if err != nil { + return err } c.clientMutex.Lock() c.client = &cl c.clientMutex.Unlock() - if _, err = cl.EndpointInfo(ctx, sdkClient.PrmEndpointInfo{}); err != nil { - c.setUnhealthy() - return wasHealthy, err - } + return nil +} - c.setHealthy() - return !wasHealthy, nil +func (c *clientWrapper) isDialed() bool { + c.mu.RLock() + defer c.mu.RUnlock() + return c.dialed +} + +func (c *clientWrapper) setDialed(dialed bool) { + c.mu.Lock() + c.dialed = dialed + c.mu.Unlock() } func (c *clientWrapper) getClient() (*sdkClient.Client, error) { @@ -654,6 +652,15 @@ func (c *clientWrapper) endpointInfo(ctx context.Context, _ prmEndpointInfo) (ne return netmap.NodeInfo{}, err } + return c.endpointInfoRaw(ctx, cl) +} + +func (c *clientWrapper) healthcheck(ctx context.Context) (netmap.NodeInfo, error) { + cl := c.getClientRaw() + return c.endpointInfoRaw(ctx, cl) +} + +func (c *clientWrapper) endpointInfoRaw(ctx context.Context, cl *sdkClient.Client) (netmap.NodeInfo, error) { start := time.Now() res, err := cl.EndpointInfo(ctx, sdkClient.PrmEndpointInfo{}) c.incRequests(time.Since(start), methodEndpointInfo) @@ -1121,10 +1128,6 @@ func (c *clientStatusMonitor) isHealthy() bool { return c.healthy.Load() == statusHealthy } -func (c *clientStatusMonitor) isDialed() bool { - return c.healthy.Load() != statusUnhealthyOnDial -} - func (c *clientStatusMonitor) setHealthy() { c.healthy.Store(statusHealthy) } @@ -1133,10 +1136,6 @@ func (c *clientStatusMonitor) setUnhealthy() { c.healthy.Store(statusUnhealthyOnRequest) } -func (c *clientStatusMonitor) setUnhealthyOnDial() { - c.healthy.Store(statusUnhealthyOnDial) -} - func (c *clientStatusMonitor) address() string { return c.addr } @@ -1211,6 +1210,9 @@ func (c *clientWrapper) incRequests(elapsed time.Duration, method MethodIndex) { } func (c *clientWrapper) close() error { + if !c.isDialed() { + return nil + } if cl := c.getClientRaw(); cl != nil { return cl.Close() } @@ -2153,7 +2155,9 @@ func adjustNodeParams(nodeParams []NodeParam) ([]*nodesParam, error) { // startRebalance runs loop to monitor connection healthy status. func (p *Pool) startRebalance(ctx context.Context) { - ticker := time.NewTimer(p.rebalanceParams.clientRebalanceInterval) + ticker := time.NewTicker(p.rebalanceParams.clientRebalanceInterval) + defer ticker.Stop() + buffers := make([][]float64, len(p.rebalanceParams.nodesParams)) for i, params := range p.rebalanceParams.nodesParams { buffers[i] = make([]float64, len(params.weights)) @@ -2203,7 +2207,7 @@ func (p *Pool) updateInnerNodesHealth(ctx context.Context, i int, bufferWeights tctx, c := context.WithTimeout(ctx, options.nodeRequestTimeout) defer c() - changed, err := cli.restartIfUnhealthy(tctx) + changed, err := restartIfUnhealthy(tctx, cli) healthy := err == nil if healthy { bufferWeights[j] = options.nodesParams[i].weights[j] @@ -2234,6 +2238,43 @@ func (p *Pool) updateInnerNodesHealth(ctx context.Context, i int, bufferWeights } } +// restartIfUnhealthy checks healthy status of client and recreate it if status is unhealthy. +// Indicating if status was changed by this function call and returns error that caused unhealthy status. +func restartIfUnhealthy(ctx context.Context, c client) (changed bool, err error) { + defer func() { + if err != nil { + c.setUnhealthy() + } else { + c.setHealthy() + } + }() + + wasHealthy := c.isHealthy() + + if res, err := c.healthcheck(ctx); err == nil { + if res.Status().IsMaintenance() { + return wasHealthy, new(apistatus.NodeUnderMaintenance) + } + + return !wasHealthy, nil + } + + if err = c.restart(ctx); err != nil { + return wasHealthy, err + } + + res, err := c.healthcheck(ctx) + if err != nil { + return wasHealthy, err + } + + if res.Status().IsMaintenance() { + return wasHealthy, new(apistatus.NodeUnderMaintenance) + } + + return !wasHealthy, nil +} + func adjustWeights(weights []float64) []float64 { adjusted := make([]float64, len(weights)) sum := 0.0 @@ -3032,9 +3073,7 @@ func (p *Pool) Close() { // close all clients for _, pools := range p.innerPools { for _, cli := range pools.clients { - if cli.isDialed() { - _ = cli.close() - } + _ = cli.close() } } } diff --git a/pool/pool_test.go b/pool/pool_test.go index 9270faf0..1362654b 100644 --- a/pool/pool_test.go +++ b/pool/pool_test.go @@ -4,11 +4,13 @@ import ( "context" "crypto/ecdsa" "errors" + "math/rand" "testing" "time" apistatus "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/client/status" frostfsecdsa "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/crypto/ecdsa" + "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/netmap" "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object" oid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object/id" "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/session" @@ -17,6 +19,7 @@ import ( "github.com/stretchr/testify/require" "go.uber.org/zap" "go.uber.org/zap/zaptest" + "go.uber.org/zap/zaptest/observer" ) func TestBuildPoolClientFailed(t *testing.T) { @@ -230,6 +233,179 @@ func TestOneOfTwoFailed(t *testing.T) { } } +func TestUpdateNodesHealth(t *testing.T) { + ctx := context.Background() + key := newPrivateKey(t) + + for _, tc := range []struct { + name string + wasHealthy bool + willHealthy bool + prepareCli func(*mockClient) + }{ + { + name: "healthy, maintenance, unhealthy", + wasHealthy: true, + willHealthy: false, + prepareCli: func(c *mockClient) { c.resOnEndpointInfo.SetStatus(netmap.Maintenance) }, + }, + { + name: "unhealthy, maintenance, unhealthy", + wasHealthy: false, + willHealthy: false, + prepareCli: func(c *mockClient) { c.resOnEndpointInfo.SetStatus(netmap.Maintenance) }, + }, + { + name: "healthy, no error, healthy", + wasHealthy: true, + willHealthy: true, + prepareCli: func(c *mockClient) { c.resOnEndpointInfo.SetStatus(netmap.Online) }, + }, + { + name: "unhealthy, no error, healthy", + wasHealthy: false, + willHealthy: true, + prepareCli: func(c *mockClient) { c.resOnEndpointInfo.SetStatus(netmap.Online) }, + }, + { + name: "healthy, error, failed restart, unhealthy", + wasHealthy: true, + willHealthy: false, + prepareCli: func(c *mockClient) { + c.errOnEndpointInfo() + c.errorOnDial = true + }, + }, + { + name: "unhealthy, error, failed restart, unhealthy", + wasHealthy: false, + willHealthy: false, + prepareCli: func(c *mockClient) { + c.errOnEndpointInfo() + c.errorOnDial = true + }, + }, + { + name: "healthy, error, restart, error, unhealthy", + wasHealthy: true, + willHealthy: false, + prepareCli: func(c *mockClient) { c.errOnEndpointInfo() }, + }, + { + name: "unhealthy, error, restart, error, unhealthy", + wasHealthy: false, + willHealthy: false, + prepareCli: func(c *mockClient) { c.errOnEndpointInfo() }, + }, + { + name: "healthy, error, restart, maintenance, unhealthy", + wasHealthy: true, + willHealthy: false, + prepareCli: func(c *mockClient) { + healthError := true + c.healthcheckFn = func() { + if healthError { + c.errorOnEndpointInfo = errors.New("error") + } else { + c.errorOnEndpointInfo = nil + c.resOnEndpointInfo.SetStatus(netmap.Maintenance) + } + healthError = !healthError + } + }, + }, + { + name: "unhealthy, error, restart, maintenance, unhealthy", + wasHealthy: false, + willHealthy: false, + prepareCli: func(c *mockClient) { + healthError := true + c.healthcheckFn = func() { + if healthError { + c.errorOnEndpointInfo = errors.New("error") + } else { + c.errorOnEndpointInfo = nil + c.resOnEndpointInfo.SetStatus(netmap.Maintenance) + } + healthError = !healthError + } + }, + }, + { + name: "healthy, error, restart, healthy", + wasHealthy: true, + willHealthy: true, + prepareCli: func(c *mockClient) { + healthError := true + c.healthcheckFn = func() { + if healthError { + c.errorOnEndpointInfo = errors.New("error") + } else { + c.errorOnEndpointInfo = nil + } + healthError = !healthError + } + }, + }, + { + name: "unhealthy, error, restart, healthy", + wasHealthy: false, + willHealthy: true, + prepareCli: func(c *mockClient) { + healthError := true + c.healthcheckFn = func() { + if healthError { + c.errorOnEndpointInfo = errors.New("error") + } else { + c.errorOnEndpointInfo = nil + } + healthError = !healthError + } + }, + }, + } { + t.Run(tc.name, func(t *testing.T) { + cli := newMockClientHealthy("peer0", *key, tc.wasHealthy) + tc.prepareCli(cli) + p, log := newPool(t, cli) + + p.updateNodesHealth(ctx, [][]float64{{1}}) + + changed := tc.wasHealthy != tc.willHealthy + require.Equalf(t, tc.willHealthy, cli.isHealthy(), "healthy status should be: %v", tc.willHealthy) + require.Equalf(t, changed, 1 == log.Len(), "healthy status should be changed: %v", changed) + }) + } +} + +func newPool(t *testing.T, cli *mockClient) (*Pool, *observer.ObservedLogs) { + log, observedLog := getObservedLogger() + + cache, err := newCache(0) + require.NoError(t, err) + + return &Pool{ + innerPools: []*innerPool{{ + sampler: newSampler([]float64{1}, rand.NewSource(0)), + clients: []client{cli}, + }}, + cache: cache, + key: newPrivateKey(t), + closedCh: make(chan struct{}), + rebalanceParams: rebalanceParameters{ + nodesParams: []*nodesParam{{1, []string{"peer0"}, []float64{1}}}, + nodeRequestTimeout: time.Second, + clientRebalanceInterval: 200 * time.Millisecond, + }, + logger: log, + }, observedLog +} + +func getObservedLogger() (*zap.Logger, *observer.ObservedLogs) { + loggerCore, observedLog := observer.New(zap.DebugLevel) + return zap.New(loggerCore), observedLog +} + func TestTwoFailed(t *testing.T) { var clientKeys []*ecdsa.PrivateKey mockClientBuilder := func(addr string) client { @@ -529,13 +705,6 @@ func TestStatusMonitor(t *testing.T) { isHealthy bool description string }{ - { - action: func(m *clientStatusMonitor) { m.setUnhealthyOnDial() }, - status: statusUnhealthyOnDial, - isDialed: false, - isHealthy: false, - description: "set unhealthy on dial", - }, { action: func(m *clientStatusMonitor) { m.setUnhealthy() }, status: statusUnhealthyOnRequest, @@ -554,7 +723,6 @@ func TestStatusMonitor(t *testing.T) { for _, tc := range cases { tc.action(&monitor) require.Equal(t, tc.status, monitor.healthy.Load()) - require.Equal(t, tc.isDialed, monitor.isDialed()) require.Equal(t, tc.isHealthy, monitor.isHealthy()) } })