package qos_test

import (
	"context"
	"errors"
	"fmt"
	"testing"

	"git.frostfs.info/TrueCloudLab/frostfs-node/internal/qos"
	"git.frostfs.info/TrueCloudLab/frostfs-qos/limiting"
	"git.frostfs.info/TrueCloudLab/frostfs-qos/tagging"
	apistatus "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/client/status"
	"github.com/stretchr/testify/require"
	"google.golang.org/grpc"
)

// okKey is the only key the mock limiter grants; every other key is refused.
const okKey = "ok"

var (
	// errTest is returned by the mock handlers; seeing it proves the handler ran.
	errTest = errors.New("mock")

	// errWrongTag is returned by mock handlers/invokers when the context
	// carries an unexpected IO tag.
	errWrongTag = errors.New("wrong tag")

	// errNoTag is returned by mock invokers/streamers when the context carries
	// no IO tag at all.
	errNoTag = errors.New("failed to get tag from context")

	// errResExhausted is the errors.As target used to detect the
	// ResourceExhausted status produced by the limiter interceptors.
	errResExhausted *apistatus.ResourceExhausted

	// tags lists IO tags that the outgoing interceptors are expected to
	// replace with the internal tag.
	tags = []qos.IOTag{
		qos.IOTagBackground,
		qos.IOTagWritecache,
		qos.IOTagPolicer,
		qos.IOTagTreeSync,
	}
)

// mockGRPCServerStream is a grpc.ServerStream stub that only overrides Context.
// The embedded grpc.ServerStream is left nil unless set explicitly, so calling
// any other method on the stub panics.
type mockGRPCServerStream struct {
	grpc.ServerStream

	// ctx is the context reported by Context.
	ctx context.Context
}

// Context returns the context the stub was constructed with.
func (m *mockGRPCServerStream) Context() context.Context { return m.ctx }

// limiter is a stub limiting.Limiter (via *limiter) that records whether a
// slot was requested and whether the granted slot was given back.
type limiter struct {
	acquired bool // set on every Acquire call, granted or not
	released bool // set when the ReleaseFunc returned by Acquire is invoked
}

// Acquire records the attempt and grants a slot only for okKey. The returned
// ReleaseFunc marks the limiter as released; any other key yields (nil, false).
func (l *limiter) Acquire(key string) (limiting.ReleaseFunc, bool) {
	l.acquired = true
	if key == okKey {
		return func() { l.released = true }, true
	}
	return nil, false
}

// unaryMaxActiveRPCLimiter runs the unary max-active-RPC limiter interceptor,
// backed by lim, for a call to methodName. The handler always fails with
// errTest. It returns whatever error the interceptor produces.
func unaryMaxActiveRPCLimiter(ctx context.Context, lim *limiter, methodName string) error {
	getLimiter := func() limiting.Limiter { return lim }
	failingHandler := func(context.Context, any) (any, error) { return nil, errTest }
	info := &grpc.UnaryServerInfo{FullMethod: methodName}

	_, err := qos.NewMaxActiveRPCLimiterUnaryServerInterceptor(getLimiter)(ctx, nil, info, failingHandler)
	return err
}

// streamMaxActiveRPCLimiter runs the stream max-active-RPC limiter interceptor,
// backed by lim, for a stream on methodName whose context is ctx. The handler
// always fails with errTest. It returns the interceptor's error.
func streamMaxActiveRPCLimiter(ctx context.Context, lim *limiter, methodName string) error {
	getLimiter := func() limiting.Limiter { return lim }
	failingHandler := func(any, grpc.ServerStream) error { return errTest }
	stream := &mockGRPCServerStream{ctx: ctx}
	info := &grpc.StreamServerInfo{FullMethod: methodName}

	return qos.NewMaxActiveRPCLimiterStreamServerInterceptor(getLimiter)(nil, stream, info, failingHandler)
}

// Test_MaxActiveRPCLimiter runs the same three scenarios against both the
// unary and the stream max-active-RPC limiter interceptors:
//   - "fail": the limiter refuses the key, so the call is rejected with a
//     ResourceExhausted status and nothing is released;
//   - "pass critical": a context tagged as critical never touches the limiter
//     and the handler's error is returned;
//   - "pass": the limiter grants the key, the handler runs, and the slot is
//     released afterwards.
func Test_MaxActiveRPCLimiter(t *testing.T) {
	interceptors := []struct {
		kind string
		call func(context.Context, *limiter, string) error
	}{
		{kind: "unary", call: unaryMaxActiveRPCLimiter},
		{kind: "stream", call: streamMaxActiveRPCLimiter},
	}

	for _, ic := range interceptors {
		t.Run(ic.kind+" fail", func(t *testing.T) {
			lim := new(limiter)

			err := ic.call(context.Background(), lim, "")
			require.ErrorAs(t, err, &errResExhausted)
			require.True(t, lim.acquired)
			require.False(t, lim.released)
		})
		t.Run(ic.kind+" pass critical", func(t *testing.T) {
			lim := new(limiter)
			criticalCtx := tagging.ContextWithIOTag(context.Background(), qos.IOTagCritical.String())

			err := ic.call(criticalCtx, lim, "")
			require.ErrorIs(t, err, errTest)
			require.False(t, lim.acquired)
			require.False(t, lim.released)
		})
		t.Run(ic.kind+" pass", func(t *testing.T) {
			lim := new(limiter)

			err := ic.call(context.Background(), lim, okKey)
			require.ErrorIs(t, err, errTest)
			require.True(t, lim.acquired)
			require.True(t, lim.released)
		})
	}
}

func TestSetCriticalIOTagUnaryServerInterceptor_Pass(t *testing.T) {
	interceptor := qos.NewSetCriticalIOTagUnaryServerInterceptor()
	called := false
	handler := func(ctx context.Context, req any) (any, error) {
		called = true
		if tag, ok := tagging.IOTagFromContext(ctx); ok && tag == qos.IOTagCritical.String() {
			return nil, nil
		}
		return nil, errWrongTag
	}
	_, err := interceptor(context.Background(), nil, nil, handler)
	require.NoError(t, err)
	require.True(t, called)
}

// TestAdjustOutgoingIOTagUnaryClientInterceptor checks how the unary client
// interceptor rewrites the IO tag of an outgoing call:
//   - a context without an IO tag reaches the invoker without one;
//   - every tag from tags is replaced with the internal tag;
//   - an empty tag is replaced with the client tag.
//
// Each phase also asserts that the invoker actually ran, so an interceptor
// that returns nil without forwarding the call cannot pass.
func TestAdjustOutgoingIOTagUnaryClientInterceptor(t *testing.T) {
	interceptor := qos.NewAdjustOutgoingIOTagUnaryClientInterceptor()

	// check context with no value
	called := false
	invoker := func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, opts ...grpc.CallOption) error {
		called = true
		if _, ok := tagging.IOTagFromContext(ctx); ok {
			return fmt.Errorf("%w: expected no IO tags", errWrongTag)
		}
		return nil
	}
	require.NoError(t, interceptor(context.Background(), "", nil, nil, nil, invoker, nil))
	require.True(t, called)

	// check context for internal tag
	targetTag := qos.IOTagInternal.String()
	invoker = func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, opts ...grpc.CallOption) error {
		called = true
		raw, ok := tagging.IOTagFromContext(ctx)
		if !ok {
			return errNoTag
		}
		if raw != targetTag {
			return fmt.Errorf("%w: got %q, want %q", errWrongTag, raw, targetTag)
		}
		return nil
	}
	for _, tag := range tags {
		called = false
		ctx := tagging.ContextWithIOTag(context.Background(), tag.String())
		require.NoError(t, interceptor(ctx, "", nil, nil, nil, invoker, nil), "tag %s", tag.String())
		require.True(t, called, "invoker was not called for tag %s", tag.String())
	}

	// check context for client tag; targetTag is read by the invoker closure
	// above, so reassigning it retargets the same invoker.
	called = false
	ctx := tagging.ContextWithIOTag(context.Background(), "")
	targetTag = qos.IOTagClient.String()
	require.NoError(t, interceptor(ctx, "", nil, nil, nil, invoker, nil))
	require.True(t, called)
}

// TestAdjustOutgoingIOTagStreamClientInterceptor checks how the stream client
// interceptor rewrites the IO tag of an outgoing stream:
//   - a context without an IO tag reaches the streamer without one;
//   - every tag from tags is replaced with the internal tag;
//   - an empty tag is replaced with the client tag.
//
// Each phase also asserts that the streamer actually ran, so an interceptor
// that returns (nil, nil) without opening the stream cannot pass.
func TestAdjustOutgoingIOTagStreamClientInterceptor(t *testing.T) {
	interceptor := qos.NewAdjustOutgoingIOTagStreamClientInterceptor()

	// check context with no value
	called := false
	streamer := func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) {
		called = true
		if _, ok := tagging.IOTagFromContext(ctx); ok {
			return nil, fmt.Errorf("%w: expected no IO tags", errWrongTag)
		}
		return nil, nil
	}
	_, err := interceptor(context.Background(), nil, nil, "", streamer, nil)
	require.True(t, called)
	require.NoError(t, err)

	// check context for internal tag
	targetTag := qos.IOTagInternal.String()
	streamer = func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) {
		called = true
		raw, ok := tagging.IOTagFromContext(ctx)
		if !ok {
			return nil, errNoTag
		}
		if raw != targetTag {
			return nil, fmt.Errorf("%w: got %q, want %q", errWrongTag, raw, targetTag)
		}
		return nil, nil
	}
	for _, tag := range tags {
		called = false
		ctx := tagging.ContextWithIOTag(context.Background(), tag.String())
		_, err := interceptor(ctx, nil, nil, "", streamer, nil)
		require.NoError(t, err, "tag %s", tag.String())
		require.True(t, called, "streamer was not called for tag %s", tag.String())
	}

	// check context for client tag; targetTag is read by the streamer closure
	// above, so reassigning it retargets the same streamer.
	called = false
	ctx := tagging.ContextWithIOTag(context.Background(), "")
	targetTag = qos.IOTagClient.String()
	_, err = interceptor(ctx, nil, nil, "", streamer, nil)
	require.NoError(t, err)
	require.True(t, called)
}