package qos_test import ( "context" "errors" "fmt" "testing" "git.frostfs.info/TrueCloudLab/frostfs-node/internal/qos" "git.frostfs.info/TrueCloudLab/frostfs-qos/limiting" "git.frostfs.info/TrueCloudLab/frostfs-qos/tagging" apistatus "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/client/status" "github.com/stretchr/testify/require" "google.golang.org/grpc" ) const ( okKey = "ok" ) var ( errTest = errors.New("mock") errWrongTag = errors.New("wrong tag") errNoTag = errors.New("failed to get tag from context") errResExhausted = new(apistatus.ResourceExhausted) tags = []qos.IOTag{qos.IOTagBackground, qos.IOTagWritecache, qos.IOTagPolicer, qos.IOTagTreeSync} ) type mockGRPCServerStream struct { grpc.ServerStream ctx context.Context } func (m *mockGRPCServerStream) Context() context.Context { return m.ctx } type limiter struct { released bool } func (l *limiter) Acquire(key string) (limiting.ReleaseFunc, bool) { if key != okKey { return nil, false } return func() { l.released = true }, true } func unaryMaxActiveRPCLimiter(ctx context.Context, lim *limiter, methodName string) (bool, error) { interceptor := qos.NewMaxActiveRPCLimiterUnaryServerInterceptor(func() limiting.Limiter { return lim }) called := false handler := func(ctx context.Context, req any) (any, error) { called = true return nil, errTest } _, err := interceptor(ctx, nil, &grpc.UnaryServerInfo{FullMethod: methodName}, handler) return called, err } func streamMaxActiveRPCLimiter(ctx context.Context, lim *limiter, methodName string) (bool, error) { interceptor := qos.NewMaxActiveRPCLimiterStreamServerInterceptor(func() limiting.Limiter { return lim }) called := false handler := func(srv any, stream grpc.ServerStream) error { called = true return errTest } err := interceptor(nil, &mockGRPCServerStream{ctx: ctx}, &grpc.StreamServerInfo{ FullMethod: methodName, }, handler) return called, err } func Test_MaxActiveRPCLimiter(t *testing.T) { // UnaryServerInterceptor t.Run("unary fail", func(t *testing.T) { var lim limiter called, err := unaryMaxActiveRPCLimiter(context.Background(), &lim, "") require.EqualError(t, err, errResExhausted.Error()) require.False(t, called) }) t.Run("unary pass critical", func(t *testing.T) { var lim limiter ctx := tagging.ContextWithIOTag(context.Background(), qos.IOTagCritical.String()) called, err := unaryMaxActiveRPCLimiter(ctx, &lim, "") require.EqualError(t, err, errTest.Error()) require.True(t, called) require.False(t, lim.released) }) t.Run("unary pass", func(t *testing.T) { var lim limiter called, err := unaryMaxActiveRPCLimiter(context.Background(), &lim, okKey) require.EqualError(t, err, errTest.Error()) require.True(t, called && lim.released) }) // StreamServerInterceptor t.Run("stream fail", func(t *testing.T) { var lim limiter called, err := streamMaxActiveRPCLimiter(context.Background(), &lim, "") require.EqualError(t, err, errResExhausted.Error()) require.False(t, called) }) t.Run("stream pass critical", func(t *testing.T) { var lim limiter ctx := tagging.ContextWithIOTag(context.Background(), qos.IOTagCritical.String()) called, err := streamMaxActiveRPCLimiter(ctx, &lim, "") require.EqualError(t, err, errTest.Error()) require.True(t, called) require.False(t, lim.released) }) t.Run("stream pass", func(t *testing.T) { var lim limiter called, err := streamMaxActiveRPCLimiter(context.Background(), &lim, okKey) require.EqualError(t, err, errTest.Error()) require.True(t, called && lim.released) }) } func TestSetCriticalIOTagUnaryServerInterceptor_Pass(t *testing.T) { interceptor := qos.NewSetCriticalIOTagUnaryServerInterceptor() handler := func(ctx context.Context, req any) (any, error) { if tag, ok := tagging.IOTagFromContext(ctx); ok && tag == qos.IOTagCritical.String() { return nil, nil } return nil, errWrongTag } _, err := interceptor(context.Background(), nil, nil, handler) require.NoError(t, err) } func TestAdjustOutgoingIOTagUnaryClientInterceptor(t *testing.T) { interceptor := qos.NewAdjustOutgoingIOTagUnaryClientInterceptor() // check context with no value called := false invoker := func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, opts ...grpc.CallOption) error { called = true if _, ok := tagging.IOTagFromContext(ctx); ok { return fmt.Errorf("%v: expected no IO tags", errWrongTag) } return nil } require.NoError(t, interceptor(context.Background(), "", nil, nil, nil, invoker, nil)) require.True(t, called) // check context for internal tag targetTag := qos.IOTagInternal.String() invoker = func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, opts ...grpc.CallOption) error { raw, ok := tagging.IOTagFromContext(ctx) if !ok { return errNoTag } if raw != targetTag { return errWrongTag } return nil } for _, tag := range tags { ctx := tagging.ContextWithIOTag(context.Background(), tag.String()) require.NoError(t, interceptor(ctx, "", nil, nil, nil, invoker, nil)) } // check context for client tag ctx := tagging.ContextWithIOTag(context.Background(), "") targetTag = qos.IOTagClient.String() require.NoError(t, interceptor(ctx, "", nil, nil, nil, invoker, nil)) } func TestAdjustOutgoingIOTagStreamClientInterceptor(t *testing.T) { interceptor := qos.NewAdjustOutgoingIOTagStreamClientInterceptor() // check context with no value called := false streamer := func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { called = true if _, ok := tagging.IOTagFromContext(ctx); ok { return nil, fmt.Errorf("%v: expected no IO tags", errWrongTag) } return nil, nil } _, err := interceptor(context.Background(), nil, nil, "", streamer, nil) require.True(t, called) require.NoError(t, err) // check context for internal tag targetTag := qos.IOTagInternal.String() streamer = func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { raw, ok := tagging.IOTagFromContext(ctx) if !ok { return nil, errNoTag } if raw != targetTag { return nil, errWrongTag } return nil, nil } for _, tag := range tags { ctx := tagging.ContextWithIOTag(context.Background(), tag.String()) _, err := interceptor(ctx, nil, nil, "", streamer, nil) require.NoError(t, err) } // check context for client tag ctx := tagging.ContextWithIOTag(context.Background(), "") targetTag = qos.IOTagClient.String() _, err = interceptor(ctx, nil, nil, "", streamer, nil) require.NoError(t, err) }