package qos

import (
	"context"
	"errors"
	"testing"

	"git.frostfs.info/TrueCloudLab/frostfs-qos/limiting"
	"git.frostfs.info/TrueCloudLab/frostfs-qos/tagging"
	apistatus "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/client/status"
	"github.com/stretchr/testify/require"
	"google.golang.org/grpc"
)

const (
	// okKey is the only key for which the mock limiter grants a slot.
	okKey = "ok"
)

var (
	// errTest is the sentinel error returned by every mock handler.
	errTest = errors.New("mock")
	// errResExhausted supplies the error text expected when the limiter rejects a call.
	errResExhausted = new(apistatus.ResourceExhausted)
	// releaseFunc is the release callback handed out by the mock limiter.
	// Tests install it through setReleaseTracker, which also restores it.
	releaseFunc func()
	// ctxNoTag is a context key used to build contexts that were not created
	// through tagging.ContextWithIOTag.
	ctxNoTag = ctxKey("")
)

// ctxKey is a private context key type, avoiding collisions with other packages.
type ctxKey string

// mockGRPCServerStream is a grpc.ServerStream stub that only provides a context.
// Any other ServerStream method goes to the embedded nil interface.
type mockGRPCServerStream struct {
	grpc.ServerStream

	ctx context.Context
}

// Context returns the context the stub was constructed with.
func (m *mockGRPCServerStream) Context() context.Context {
	return m.ctx
}

// limiter is a mock limiter that admits only okKey.
type limiter struct{}

// Acquire grants a slot for okKey and returns the package-level releaseFunc as
// the release callback. Any other key is rejected with (nil, false).
func (l *limiter) Acquire(key string) (limiting.ReleaseFunc, bool) {
	if key != okKey {
		return nil, false
	}
	return releaseFunc, true
}

// newMockLimiter is the limiter provider passed to the interceptor constructors.
func newMockLimiter() limiting.Limiter {
	return &limiter{}
}

// setReleaseTracker installs a release callback into the shared releaseFunc
// variable and returns a flag that becomes true once the callback has run.
// The previous callback is restored when the test finishes, so a closure over
// one test's state cannot leak into another test.
func setReleaseTracker(t *testing.T) *bool {
	t.Helper()

	released := new(bool)
	prev := releaseFunc
	releaseFunc = func() { *released = true }
	t.Cleanup(func() { releaseFunc = prev })

	return released
}

// TestUnaryServerInterceptor_MaxActiveRPCLimiter_Fail checks that a unary call
// with an untagged context and a FullMethod the mock limiter rejects fails with
// the ResourceExhausted error text and never reaches the handler.
func TestUnaryServerInterceptor_MaxActiveRPCLimiter_Fail(t *testing.T) {
	interceptor := NewMaxActiveRPCLimiterUnaryServerInterceptor(newMockLimiter)
	pCtx := context.WithValue(context.Background(), ctxNoTag, true)

	called := false
	handler := func(ctx context.Context, req any) (any, error) {
		called = true
		return nil, errTest
	}

	// fail: get apistatus.ResourceExhausted
	_, err := interceptor(pCtx, nil, &grpc.UnaryServerInfo{FullMethod: ""}, handler)
	require.EqualError(t, err, errResExhausted.Error())
	require.False(t, called)
}

// TestUnaryServerInterceptor_MaxActiveRPCLimiter_PassCritical checks that a
// unary call tagged as critical reaches the handler even though the mock
// limiter would reject its FullMethod, and that the release callback is not run.
func TestUnaryServerInterceptor_MaxActiveRPCLimiter_PassCritical(t *testing.T) {
	interceptor := NewMaxActiveRPCLimiterUnaryServerInterceptor(newMockLimiter)
	ctx := tagging.ContextWithIOTag(context.Background(), IOTagCritical.String())

	called := false
	handler := func(ctx context.Context, req any) (any, error) {
		called = true
		return nil, errTest
	}
	released := setReleaseTracker(t)

	_, err := interceptor(ctx, nil, &grpc.UnaryServerInfo{FullMethod: ""}, handler)
	require.EqualError(t, err, errTest.Error())
	require.True(t, called)
	require.False(t, *released)
}

// TestUnaryServerInterceptor_MaxActiveRPCLimiter_Pass checks that a unary call
// whose FullMethod the mock limiter admits reaches the handler, propagates the
// handler's error, and runs the release callback.
func TestUnaryServerInterceptor_MaxActiveRPCLimiter_Pass(t *testing.T) {
	interceptor := NewMaxActiveRPCLimiterUnaryServerInterceptor(newMockLimiter)
	pCtx := context.WithValue(context.Background(), ctxNoTag, true)

	called := false
	handler := func(ctx context.Context, req any) (any, error) {
		called = true
		return nil, errTest
	}
	released := setReleaseTracker(t)

	_, err := interceptor(pCtx, nil, &grpc.UnaryServerInfo{FullMethod: okKey}, handler)
	require.EqualError(t, err, errTest.Error())
	// Asserted separately so a failure shows which condition was violated.
	require.True(t, called)
	require.True(t, *released)
}

// TestStreamServerInterceptor_MaxActiveRPCLimiter_Fail checks that a stream
// call with an untagged context and a FullMethod the mock limiter rejects fails
// with the ResourceExhausted error text and never reaches the handler.
func TestStreamServerInterceptor_MaxActiveRPCLimiter_Fail(t *testing.T) {
	ctx := context.WithValue(context.Background(), ctxNoTag, true)
	interceptor := NewMaxActiveRPCLimiterStreamServerInterceptor(newMockLimiter)

	called := false
	handler := func(srv any, stream grpc.ServerStream) error {
		called = true
		return errTest
	}

	// fail: get apistatus.ResourceExhausted
	err := interceptor(nil, &mockGRPCServerStream{ctx: ctx}, &grpc.StreamServerInfo{
		FullMethod: "",
	}, handler)
	require.EqualError(t, err, errResExhausted.Error())
	require.False(t, called)
}

// TestStreamServerInterceptor_MaxActiveRPCLimiter_PassCritical checks that a
// stream call tagged as critical reaches the handler and that the release
// callback is not run.
func TestStreamServerInterceptor_MaxActiveRPCLimiter_PassCritical(t *testing.T) {
	interceptor := NewMaxActiveRPCLimiterStreamServerInterceptor(newMockLimiter)
	ctx := tagging.ContextWithIOTag(context.Background(), IOTagCritical.String())

	called := false
	handler := func(srv any, stream grpc.ServerStream) error {
		called = true
		return errTest
	}
	released := setReleaseTracker(t)

	err := interceptor(nil, &mockGRPCServerStream{ctx: ctx}, &grpc.StreamServerInfo{FullMethod: okKey}, handler)
	require.EqualError(t, err, errTest.Error())
	require.True(t, called)
	require.False(t, *released)
}

// TestStreamServerInterceptor_MaxActiveRPCLimiter_Pass checks that a stream
// call whose FullMethod the mock limiter admits reaches the handler, propagates
// the handler's error, and runs the release callback.
func TestStreamServerInterceptor_MaxActiveRPCLimiter_Pass(t *testing.T) {
	ctx := context.WithValue(context.Background(), ctxNoTag, true)
	interceptor := NewMaxActiveRPCLimiterStreamServerInterceptor(newMockLimiter)

	called := false
	handler := func(srv any, stream grpc.ServerStream) error {
		called = true
		return errTest
	}
	released := setReleaseTracker(t)

	err := interceptor(nil, &mockGRPCServerStream{ctx: ctx}, &grpc.StreamServerInfo{FullMethod: okKey}, handler)
	require.EqualError(t, err, errTest.Error())
	// Asserted separately so a failure shows which condition was violated.
	require.True(t, called)
	require.True(t, *released)
}

// TestSetCriticalIOTagUnaryServerInterceptor_Pass checks that the handler sees
// the critical IO tag in its context; the mock handler returns errTest otherwise.
func TestSetCriticalIOTagUnaryServerInterceptor_Pass(t *testing.T) {
	interceptor := NewSetCriticalIOTagUnaryServerInterceptor()

	handler := func(ctx context.Context, req any) (any, error) {
		if tag, ok := tagging.IOTagFromContext(ctx); ok && tag == IOTagCritical.String() {
			return nil, nil
		}
		return nil, errTest
	}

	_, err := interceptor(context.Background(), nil, nil, handler)
	require.NoError(t, err)
}