package proto_test

import (
	"math"
	"math/rand"
	"testing"

	"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/api/util/proto/test"
	generated "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/api/util/proto/test/custom"
	"github.com/stretchr/testify/require"
	"google.golang.org/protobuf/encoding/protojson"
	goproto "google.golang.org/protobuf/proto"
)

// protoInt constrains the test helpers to the Go integer types that back
// protobuf integer fields in these tests. The ~ terms also admit any defined
// type whose underlying type is one of them.
type protoInt interface {
	~int32 | ~int64 | ~uint32 | ~uint64
}

// nonZero returns a pseudo-random value of type T that is guaranteed to be
// non-zero. Values are drawn from the global math/rand source and converted
// to T (keeping only the low bits for 32-bit types); draws that come out as
// zero are discarded and retried.
func nonZero[T protoInt]() T {
	for {
		if v := T(rand.Uint64()); v != 0 {
			return v
		}
	}
}

// TestStableMarshalSingle checks the custom generated Primitives message
// against the reference google.golang.org/protobuf message, one singular field
// kind at a time. For each field kind it checks that:
//   - StableSize matches the length of the protobuf encoding;
//   - the binary and JSON encodings decode via the reference implementation;
//   - the custom decoders restore the original value.
//
// It also checks that a message with only default values encodes to nothing.
func TestStableMarshalSingle(t *testing.T) {
	t.Run("empty", func(t *testing.T) {
		t.Run("proto", func(t *testing.T) {
			input := &generated.Primitives{}
			require.Zero(t, input.StableSize())

			r := input.MarshalProtobuf(nil)
			require.Empty(t, r)
		})
		t.Run("json", func(t *testing.T) {
			requirePrimitivesJSONRoundtrip(t, &generated.Primitives{})
		})
	})

	// Subtest names must be unique: t.Run silently suffixes duplicates with
	// "#01", "#02", ..., which makes failure reports and -run filters ambiguous.
	// NOTE(review): the proto type of field I is not visible from this file
	// (it may be fixed64) — confirm and give that case a type-based name.
	marshalCases := []struct {
		name  string
		input *generated.Primitives
	}{
		{name: "bytes", input: &generated.Primitives{FieldA: []byte{1, 2, 3}}},
		{name: "string", input: &generated.Primitives{FieldB: "123"}},
		{name: "bool", input: &generated.Primitives{FieldC: true}},
		{name: "int32", input: &generated.Primitives{FieldD: -10}},
		{name: "uint32", input: &generated.Primitives{FieldE: nonZero[uint32]()}},
		{name: "int64", input: &generated.Primitives{FieldF: nonZero[int64]()}},
		{name: "uint64", input: &generated.Primitives{FieldG: nonZero[uint64]()}},
		{name: "uint64 (field I)", input: &generated.Primitives{FieldI: nonZero[uint64]()}},
		{name: "float64", input: &generated.Primitives{FieldJ: math.Float64frombits(12345677890)}},
		{name: "fixed32", input: &generated.Primitives{FieldK: nonZero[uint32]()}},
		{name: "enum, positive", input: &generated.Primitives{FieldH: generated.Primitives_POSITIVE}},
		{name: "enum, negative", input: &generated.Primitives{FieldH: generated.Primitives_NEGATIVE}},
		{name: "oneof, first", input: &generated.Primitives{FieldM: &generated.Primitives_FieldMa{FieldMa: []byte{4, 2}}}},
		{name: "oneof, second", input: &generated.Primitives{FieldM: &generated.Primitives_FieldMe{FieldMe: nonZero[uint32]()}}},
	}
	for _, tc := range marshalCases {
		t.Run(tc.name, func(t *testing.T) {
			t.Run("proto", func(t *testing.T) {
				r := tc.input.MarshalProtobuf(nil)
				require.Equal(t, len(r), tc.input.StableSize())
				require.NotEmpty(t, r)

				var actual test.Primitives
				require.NoError(t, goproto.Unmarshal(r, &actual))

				var actualFrostfs generated.Primitives
				require.NoError(t, actualFrostfs.UnmarshalProtobuf(r))
				require.Equal(t, tc.input, &actualFrostfs)

				primitivesEqual(t, tc.input, &actual)
			})
			t.Run("json", func(t *testing.T) {
				requirePrimitivesJSONRoundtrip(t, tc.input)
			})
		})
	}
}

// requirePrimitivesJSONRoundtrip marshals input with the custom JSON encoder and
// checks the result in four ways:
//   - protojson can decode it into the reference message;
//   - protojson re-encoding that message with EmitUnpopulated produces
//     equivalent JSON;
//   - the custom JSON decoder restores input exactly;
//   - the decoded reference message carries the same field values as input.
func requirePrimitivesJSONRoundtrip(t *testing.T, input *generated.Primitives) {
	r, err := input.MarshalJSON()
	require.NoError(t, err)
	require.NotEmpty(t, r)

	var actual test.Primitives
	require.NoError(t, protojson.Unmarshal(r, &actual))

	t.Run("protojson compatibility", func(t *testing.T) {
		// EmitUnpopulated makes protojson emit default-valued fields too, so
		// the comparison covers the custom encoder's full output.
		data, err := protojson.MarshalOptions{EmitUnpopulated: true}.Marshal(&actual)
		require.NoError(t, err)
		require.JSONEq(t, string(data), string(r))
	})

	var actualFrostfs generated.Primitives
	require.NoError(t, actualFrostfs.UnmarshalJSON(r))
	require.Equal(t, input, &actualFrostfs)

	primitivesEqual(t, input, &actual)
}

// primitivesEqual asserts that the custom-generated message a holds the same
// field values as the reference google.golang.org/protobuf message b.
// Fields are compared one by one because the reference type carries
// unexported internal state, so whole-struct equality cannot be used.
func primitivesEqual(t *testing.T, a *generated.Primitives, b *test.Primitives) {
	req := require.New(t)

	// Scalar fields.
	req.Equal(len(a.FieldA), len(b.FieldA))
	req.Equal(a.FieldA, b.FieldA)
	req.Equal(a.FieldB, b.FieldB)
	req.Equal(a.FieldC, b.FieldC)
	req.Equal(a.FieldD, b.FieldD)
	req.Equal(a.FieldE, b.FieldE)
	req.Equal(a.FieldF, b.FieldF)
	req.Equal(a.FieldG, b.FieldG)
	req.Equal(a.FieldI, b.FieldI)
	req.Equal(a.FieldJ, b.FieldJ)
	req.Equal(a.FieldK, b.FieldK)

	// The two enum types are distinct Go types, so compare by value only.
	req.EqualValues(a.FieldH, b.FieldH)

	// Oneof members are read through getters, which tolerate an unset oneof.
	req.Equal(a.GetFieldMa(), b.GetFieldMa())
	req.Equal(a.GetFieldMe(), b.GetFieldMe())
	req.Equal(a.GetFieldAux().GetInnerField(), b.GetFieldAux().GetInnerField())
}

// repPrimitivesEqual asserts that the custom-generated message a holds the
// same repeated field values as the reference google.golang.org/protobuf
// message b. Fields are compared individually because the reference type
// carries unexported internal state.
func repPrimitivesEqual(t *testing.T, a *generated.RepPrimitives, b *test.RepPrimitives) {
	req := require.New(t)

	req.Equal(a.FieldA, b.FieldA)
	req.Equal(a.FieldB, b.FieldB)
	req.Equal(a.FieldC, b.FieldC)
	req.Equal(a.FieldD, b.FieldD)
	req.Equal(a.FieldE, b.FieldE)
	req.Equal(a.FieldF, b.FieldF)
	req.Equal(a.FieldFu, b.FieldFu)

	// Nested messages are different Go types on each side, so compare them
	// element by element through InnerField once the lengths agree.
	wantAux, gotAux := a.GetFieldAux(), b.GetFieldAux()
	req.Equal(len(wantAux), len(gotAux))
	for i := range a.FieldAux {
		req.Equal(wantAux[i].GetInnerField(), gotAux[i].GetInnerField())
	}
}

// randIntSlice returns a slice of n pseudo-random values of type T drawn from
// the global math/rand source. If includeZero is set and the slice is
// non-empty, its first element is overwritten with zero so that zero values
// inside repeated fields get exercised. With includeZero unset, no element is
// forced, though a random draw may still happen to be zero.
func randIntSlice[T protoInt](n int, includeZero bool) []T {
	out := make([]T, n)
	for i := range out {
		out[i] = T(rand.Uint64())
	}
	if includeZero && n > 0 {
		out[0] = 0
	}
	return out
}

// uint32SliceToAux wraps each value of s into a RepPrimitives_Aux message
// (as its InnerField) and returns them in the same order. The result is
// never nil, even for an empty s.
func uint32SliceToAux(s []uint32) []generated.RepPrimitives_Aux {
	aux := make([]generated.RepPrimitives_Aux, 0, len(s))
	for _, v := range s {
		aux = append(aux, generated.RepPrimitives_Aux{InnerField: v})
	}
	return aux
}

// TestStableMarshalRep checks the custom generated RepPrimitives message for
// each repeated field kind:
//   - Empty repeated fields (nil or zero-length) encode to nothing.
//   - For non-empty fields, StableSize matches the encoded length.
//   - Both the protobuf binary and JSON encodings decode, via the reference
//     google.golang.org/protobuf implementation, to the same values.
//
// Subtest names are unique on purpose: t.Run silently suffixes duplicates with
// "#01", "#02", ..., which makes failure reports and -run filters ambiguous.
// Name suffixes describe the input: "single zero" is a one-element slice
// holding zero, "leading zero" is a random slice whose first element is zero,
// and "random" is a slice with no forced zero.
func TestStableMarshalRep(t *testing.T) {
	t.Run("empty", func(t *testing.T) {
		marshalCases := []struct {
			name  string
			input *generated.RepPrimitives
		}{
			{name: "default", input: &generated.RepPrimitives{}},
			{name: "bytes", input: &generated.RepPrimitives{FieldA: [][]byte{}}},
			{name: "string", input: &generated.RepPrimitives{FieldB: []string{}}},
			{name: "int32", input: &generated.RepPrimitives{FieldC: []int32{}}},
			{name: "uint32", input: &generated.RepPrimitives{FieldD: []uint32{}}},
			{name: "int64", input: &generated.RepPrimitives{FieldE: []int64{}}},
			{name: "uint64", input: &generated.RepPrimitives{FieldF: []uint64{}}},
			{name: "uint64 (field Fu)", input: &generated.RepPrimitives{FieldFu: []uint64{}}},
		}

		for _, tc := range marshalCases {
			t.Run(tc.name, func(t *testing.T) {
				require.Zero(t, tc.input.StableSize())

				r := tc.input.MarshalProtobuf(nil)
				require.Empty(t, r)
			})
		}
	})

	marshalCases := []struct {
		name  string
		input *generated.RepPrimitives
	}{
		{name: "bytes", input: &generated.RepPrimitives{FieldA: [][]byte{{1, 2, 3}}}},
		{name: "string", input: &generated.RepPrimitives{FieldB: []string{"123"}}},
		{name: "int32, single zero", input: &generated.RepPrimitives{FieldC: randIntSlice[int32](1, true)}},
		{name: "int32, leading zero", input: &generated.RepPrimitives{FieldC: randIntSlice[int32](2, true)}},
		{name: "int32, random", input: &generated.RepPrimitives{FieldC: randIntSlice[int32](2, false)}},
		{name: "uint32, single zero", input: &generated.RepPrimitives{FieldD: randIntSlice[uint32](1, true)}},
		{name: "uint32, leading zero", input: &generated.RepPrimitives{FieldD: randIntSlice[uint32](2, true)}},
		{name: "uint32, random", input: &generated.RepPrimitives{FieldD: randIntSlice[uint32](2, false)}},
		{name: "int64, single zero", input: &generated.RepPrimitives{FieldE: randIntSlice[int64](1, true)}},
		{name: "int64, leading zero", input: &generated.RepPrimitives{FieldE: randIntSlice[int64](2, true)}},
		{name: "int64, random", input: &generated.RepPrimitives{FieldE: randIntSlice[int64](2, false)}},
		{name: "uint64, single zero", input: &generated.RepPrimitives{FieldF: randIntSlice[uint64](1, true)}},
		{name: "uint64, leading zero", input: &generated.RepPrimitives{FieldF: randIntSlice[uint64](2, true)}},
		{name: "uint64, random", input: &generated.RepPrimitives{FieldF: randIntSlice[uint64](2, false)}},
		{name: "uint64 (field Fu), single zero", input: &generated.RepPrimitives{FieldFu: randIntSlice[uint64](1, true)}},
		{name: "uint64 (field Fu), leading zero", input: &generated.RepPrimitives{FieldFu: randIntSlice[uint64](2, true)}},
		{name: "uint64 (field Fu), random", input: &generated.RepPrimitives{FieldFu: randIntSlice[uint64](2, false)}},
		{name: "message, single zero", input: &generated.RepPrimitives{FieldAux: uint32SliceToAux(randIntSlice[uint32](1, true))}},
		{name: "message, leading zero", input: &generated.RepPrimitives{FieldAux: uint32SliceToAux(randIntSlice[uint32](2, true))}},
		{name: "message, random", input: &generated.RepPrimitives{FieldAux: uint32SliceToAux(randIntSlice[uint32](2, false))}},
	}
	for _, tc := range marshalCases {
		t.Run(tc.name, func(t *testing.T) {
			t.Run("proto", func(t *testing.T) {
				r := tc.input.MarshalProtobuf(nil)
				require.Equal(t, len(r), tc.input.StableSize())
				require.NotEmpty(t, r)

				var actual test.RepPrimitives
				require.NoError(t, goproto.Unmarshal(r, &actual))
				repPrimitivesEqual(t, tc.input, &actual)
			})
			t.Run("json", func(t *testing.T) {
				r, err := tc.input.MarshalJSON()
				require.NoError(t, err)
				require.NotEmpty(t, r)

				var actual test.RepPrimitives
				require.NoError(t, protojson.Unmarshal(r, &actual))
				repPrimitivesEqual(t, tc.input, &actual)
			})
		})
	}
}