diff --git a/pool/pool.go b/pool/pool.go index 36d63408..950c2f7a 100644 --- a/pool/pool.go +++ b/pool/pool.go @@ -1159,6 +1159,16 @@ func (c *clientStatusMonitor) incErrorRate() { } } +func (c *clientStatusMonitor) incErrorRateToUnhealthy(err error) { + c.mu.Lock() + c.currentErrorCount = 0 + c.overallErrorCount++ + c.setUnhealthy() + c.mu.Unlock() + + c.log(zapcore.WarnLevel, "explicitly mark node unhealthy", zap.String("address", c.addr), zap.Error(err)) +} + func (c *clientStatusMonitor) log(level zapcore.Level, msg string, fields ...zap.Field) { if c.logger == nil { return @@ -1225,9 +1235,10 @@ func (c *clientStatusMonitor) handleError(ctx context.Context, st apistatus.Stat switch stErr.(type) { case *apistatus.ServerInternal, *apistatus.WrongMagicNumber, - *apistatus.SignatureVerification, - *apistatus.NodeUnderMaintenance: + *apistatus.SignatureVerification: c.incErrorRate() + case *apistatus.NodeUnderMaintenance: + c.incErrorRateToUnhealthy(stErr) } if err == nil { @@ -1239,7 +1250,11 @@ func (c *clientStatusMonitor) handleError(ctx context.Context, st apistatus.Stat if err != nil { if needCountError(ctx, err) { - c.incErrorRate() + if sdkClient.IsErrNodeUnderMaintenance(err) { + c.incErrorRateToUnhealthy(err) + } else { + c.incErrorRate() + } } return err diff --git a/pool/pool_test.go b/pool/pool_test.go index 0c611151..9270faf0 100644 --- a/pool/pool_test.go +++ b/pool/pool_test.go @@ -4,7 +4,6 @@ import ( "context" "crypto/ecdsa" "errors" - "strconv" "testing" "time" @@ -17,6 +16,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/crypto/keys" "github.com/stretchr/testify/require" "go.uber.org/zap" + "go.uber.org/zap/zaptest" ) func TestBuildPoolClientFailed(t *testing.T) { @@ -562,19 +562,22 @@ func TestStatusMonitor(t *testing.T) { func TestHandleError(t *testing.T) { ctx := context.Background() - monitor := newClientStatusMonitor(zap.NewExample(), "", 10) + log := zaptest.NewLogger(t) canceledCtx, cancel := context.WithCancel(context.Background()) cancel() - for i, tc := range []struct { - ctx context.Context - status apistatus.Status - err error - expectedError bool - countError bool + for _, tc := range []struct { + name string + ctx context.Context + status apistatus.Status + err error + expectedError bool + countError bool + markedUnhealthy bool }{ { + name: "no error, no status", ctx: ctx, status: nil, err: nil, @@ -582,6 +585,7 @@ func TestHandleError(t *testing.T) { countError: false, }, { + name: "no error, success status", ctx: ctx, status: new(apistatus.SuccessDefaultV2), err: nil, @@ -589,6 +593,7 @@ func TestHandleError(t *testing.T) { countError: false, }, { + name: "error, success status", ctx: ctx, status: new(apistatus.SuccessDefaultV2), err: errors.New("error"), @@ -596,6 +601,7 @@ func TestHandleError(t *testing.T) { countError: true, }, { + name: "error, no status", ctx: ctx, status: nil, err: errors.New("error"), @@ -603,6 +609,7 @@ func TestHandleError(t *testing.T) { countError: true, }, { + name: "no error, object not found status", ctx: ctx, status: new(apistatus.ObjectNotFound), err: nil, @@ -610,6 +617,7 @@ func TestHandleError(t *testing.T) { countError: false, }, { + name: "object not found error, object not found status", ctx: ctx, status: new(apistatus.ObjectNotFound), err: &apistatus.ObjectNotFound{}, @@ -617,6 +625,7 @@ func TestHandleError(t *testing.T) { countError: false, }, { + name: "eacl not found error, no status", ctx: ctx, status: nil, err: &apistatus.EACLNotFound{}, @@ -627,6 +636,7 @@ func TestHandleError(t *testing.T) { countError: true, }, { + name: "no error, internal status", ctx: ctx, status: new(apistatus.ServerInternal), err: nil, @@ -634,6 +644,7 @@ func TestHandleError(t *testing.T) { countError: true, }, { + name: "no error, wrong magic status", ctx: ctx, status: new(apistatus.WrongMagicNumber), err: nil, @@ -641,6 +652,7 @@ func TestHandleError(t *testing.T) { countError: true, }, { + name: "no error, signature verification status", ctx: ctx, status: new(apistatus.SignatureVerification), err: nil, @@ -648,13 +660,25 @@ func TestHandleError(t *testing.T) { countError: true, }, { - ctx: ctx, - status: new(apistatus.NodeUnderMaintenance), - err: nil, - expectedError: true, - countError: true, + name: "no error, maintenance status", + ctx: ctx, + status: new(apistatus.NodeUnderMaintenance), + err: nil, + expectedError: true, + countError: true, + markedUnhealthy: true, }, { + name: "maintenance error, no status", + ctx: ctx, + status: nil, + err: &apistatus.NodeUnderMaintenance{}, + expectedError: true, + countError: true, + markedUnhealthy: true, + }, + { + name: "no error, invalid argument status", ctx: ctx, status: new(apistatus.InvalidArgument), err: nil, @@ -662,6 +686,7 @@ func TestHandleError(t *testing.T) { countError: false, }, { + name: "context canceled error, no status", ctx: canceledCtx, status: nil, err: errors.New("error"), @@ -669,8 +694,9 @@ func TestHandleError(t *testing.T) { countError: false, }, } { - t.Run(strconv.Itoa(i), func(t *testing.T) { - errCount := monitor.currentErrorRate() + t.Run(tc.name, func(t *testing.T) { + monitor := newClientStatusMonitor(log, "", 10) + errCount := monitor.overallErrorRate() err := monitor.handleError(tc.ctx, tc.status, tc.err) if tc.expectedError { require.Error(t, err) @@ -680,7 +706,10 @@ func TestHandleError(t *testing.T) { if tc.countError { errCount++ } - require.Equal(t, errCount, monitor.currentErrorRate()) + require.Equal(t, errCount, monitor.overallErrorRate()) + if tc.markedUnhealthy { + require.False(t, monitor.isHealthy()) + } }) } }