package marshal

import (
	"encoding/binary"
	"math"
	"testing"

	"github.com/stretchr/testify/require"
)

// TestMarshalling exercises the round-trip and error paths of the marshal
// primitives: int64 slices, int64, string, byte/uint8 and bool.
//
// Every success case follows the same pattern:
//  1. size the buffer with the matching *Size helper;
//  2. marshal into it and check the returned offset with VerifyMarshal;
//  3. unmarshal and check the returned offset with VerifyUnmarshal;
//  4. compare the decoded value with the original.
//
// VerifyMarshal/VerifyUnmarshal reject an offset that stops short of the end
// of the buffer (see the bool "invalid buffer" case). Failure cases assert
// that the returned error is, or wraps, a *MarshallerError.
func TestMarshalling(t *testing.T) {
	t.Parallel()
	t.Run("slice", func(t *testing.T) {
		t.Parallel()
		t.Run("nil slice", func(t *testing.T) {
			t.Parallel()

			// A nil slice encodes to a single byte and must decode back to
			// nil rather than to an empty, non-nil slice.
			var int64s []int64
			expectedSize := SliceSize(int64s, Int64Size)
			require.Equal(t, 1, expectedSize)
			buf := make([]byte, expectedSize)
			offset, err := SliceMarshal(buf, 0, int64s, Int64Marshal)
			require.NoError(t, err)
			require.NoError(t, VerifyMarshal(buf, offset))

			result, offset, err := SliceUnmarshal(buf, 0, Int64Unmarshal)
			require.NoError(t, err)
			require.NoError(t, VerifyUnmarshal(buf, offset))
			require.Nil(t, result)
		})

		t.Run("empty slice", func(t *testing.T) {
			t.Parallel()

			// An empty but non-nil slice also encodes to one byte. It must
			// stay distinguishable from nil after decoding.
			int64s := make([]int64, 0)
			expectedSize := SliceSize(int64s, Int64Size)
			require.Equal(t, 1, expectedSize)
			buf := make([]byte, expectedSize)
			offset, err := SliceMarshal(buf, 0, int64s, Int64Marshal)
			require.NoError(t, err)
			require.NoError(t, VerifyMarshal(buf, offset))

			result, offset, err := SliceUnmarshal(buf, 0, Int64Unmarshal)
			require.NoError(t, err)
			require.NoError(t, VerifyUnmarshal(buf, offset))
			require.NotNil(t, result)
			require.Len(t, result, 0)
		})

		t.Run("non empty slice", func(t *testing.T) {
			t.Parallel()

			int64s := sequentialInt64s(100)
			expectedSize := SliceSize(int64s, Int64Size)
			buf := make([]byte, expectedSize)
			offset, err := SliceMarshal(buf, 0, int64s, Int64Marshal)
			require.NoError(t, err)
			require.NoError(t, VerifyMarshal(buf, offset))

			result, offset, err := SliceUnmarshal(buf, 0, Int64Unmarshal)
			require.NoError(t, err)
			require.NoError(t, VerifyUnmarshal(buf, offset))
			require.Equal(t, int64s, result)
		})

		t.Run("corrupted slice size", func(t *testing.T) {
			t.Parallel()

			int64s := sequentialInt64s(100)
			expectedSize := SliceSize(int64s, Int64Size)
			buf := make([]byte, expectedSize)
			offset, err := SliceMarshal(buf, 0, int64s, Int64Marshal)
			require.NoError(t, err)
			require.NoError(t, VerifyMarshal(buf, offset))

			// Ten bytes that all carry the varint continuation bit (0x81)
			// form a varint that overflows 64 bits.
			for i := 0; i < binary.MaxVarintLen64; i++ {
				buf[i] = 129
			}

			_, _, err = SliceUnmarshal(buf, 0, Int64Unmarshal)
			var mErr *MarshallerError
			require.ErrorAs(t, err, &mErr)

			// 0x7F has no continuation bit, so the length prefix is a
			// single byte. It no longer matches the payload that follows.
			for i := 0; i < binary.MaxVarintLen64; i++ {
				buf[i] = 127
			}
			_, _, err = SliceUnmarshal(buf, 0, Int64Unmarshal)
			require.ErrorAs(t, err, &mErr)
		})

		t.Run("corrupted slice item", func(t *testing.T) {
			t.Parallel()

			int64s := sequentialInt64s(100)
			expectedSize := SliceSize(int64s, Int64Size)
			buf := make([]byte, expectedSize)
			offset, err := SliceMarshal(buf, 0, int64s, Int64Marshal)
			require.NoError(t, err)
			require.NoError(t, VerifyMarshal(buf, offset))

			// Locate the first item instead of hard-coding the length-prefix
			// width. This assumes SliceSize is the prefix size plus the sum of
			// the item sizes.
			itemsSize := 0
			for _, v := range int64s {
				itemsSize += Int64Size(v)
			}
			itemsStart := expectedSize - itemsSize
			require.Greater(t, itemsStart, 0)
			require.LessOrEqual(t, itemsStart+binary.MaxVarintLen64, len(buf))

			// Overwrite the leading items with an overflowing varint.
			for i := itemsStart; i < itemsStart+binary.MaxVarintLen64; i++ {
				buf[i] = 129
			}

			_, _, err = SliceUnmarshal(buf, 0, Int64Unmarshal)
			var mErr *MarshallerError
			require.ErrorAs(t, err, &mErr)
		})

		t.Run("small buffer", func(t *testing.T) {
			t.Parallel()

			int64s := sequentialInt64s(100)

			// Too small even for the length prefix.
			buf := make([]byte, 1)
			_, err := SliceMarshal(buf, 0, int64s, Int64Marshal)
			var mErr *MarshallerError
			require.ErrorAs(t, err, &mErr)

			// Room for the prefix and a few items, but not all of them.
			buf = make([]byte, 10)
			_, err = SliceMarshal(buf, 0, int64s, Int64Marshal)
			require.ErrorAs(t, err, &mErr)
		})
	})

	t.Run("int64", func(t *testing.T) {
		t.Parallel()

		t.Run("success", func(t *testing.T) {
			t.Parallel()

			require.Equal(t, 1, Int64Size(0))
			require.Equal(t, binary.MaxVarintLen64, Int64Size(math.MaxInt64))
			require.Equal(t, binary.MaxVarintLen64, Int64Size(math.MinInt64))

			// Besides the extremes, cover small magnitudes of both signs.
			// Assuming a zigzag (binary.PutVarint-style) encoding, 63 and -64
			// are the last values that fit in one byte, and 64 and -65 are
			// the first that need two.
			for _, v := range []int64{
				0, 1, -1, 63, -64, 64, -65,
				math.MaxInt32, math.MinInt32,
				math.MinInt64, math.MaxInt64,
			} {
				size := Int64Size(v)
				buf := make([]byte, size)
				offset, err := Int64Marshal(buf, 0, v)
				require.NoError(t, err)
				require.NoError(t, VerifyMarshal(buf, offset))

				uv, offset, err := Int64Unmarshal(buf, 0)
				require.NoError(t, err)
				require.NoError(t, VerifyUnmarshal(buf, offset))
				require.Equal(t, v, uv)
			}
		})

		t.Run("invalid buffer", func(t *testing.T) {
			t.Parallel()

			var mErr *MarshallerError

			_, err := Int64Marshal([]byte{}, 0, 100500)
			require.ErrorAs(t, err, &mErr)

			_, _, err = Int64Unmarshal(nil, 0)
			require.ErrorAs(t, err, &mErr)
		})

		t.Run("overflow", func(t *testing.T) {
			t.Parallel()

			var mErr *MarshallerError

			var v int64 = math.MaxInt64
			buf := make([]byte, Int64Size(v))
			_, err := Int64Marshal(buf, 0, v)
			require.NoError(t, err)

			// In a maximum-length varint the final byte may only be 0 or 1.
			// Anything larger would need more than 64 bits.
			buf[len(buf)-1] = 2

			_, _, err = Int64Unmarshal(buf, 0)
			require.ErrorAs(t, err, &mErr)
		})
	})

	t.Run("string", func(t *testing.T) {
		t.Parallel()

		t.Run("success", func(t *testing.T) {
			t.Parallel()
			for _, v := range []string{
				"", "arn:aws:iam::namespace:group/some_group", "$Object:homomorphicHash",
				"native:container/ns/9LPLUFZpEmfidG4n44vi2cjXKXSqWT492tCvLJiJ8W1J",
				// Multi-byte UTF-8: catches sizing by rune count instead of
				// byte length.
				"привет, мир 🙂",
			} {
				size := StringSize(v)
				buf := make([]byte, size)
				offset, err := StringMarshal(buf, 0, v)
				require.NoError(t, err)
				require.NoError(t, VerifyMarshal(buf, offset))

				uv, offset, err := StringUnmarshal(buf, 0)
				require.NoError(t, err)
				require.NoError(t, VerifyUnmarshal(buf, offset))
				require.Equal(t, v, uv)
			}
		})

		t.Run("invalid buffer", func(t *testing.T) {
			t.Parallel()

			str := "avada kedavra"

			var mErr *MarshallerError
			_, err := StringMarshal(nil, 0, str)
			require.ErrorAs(t, err, &mErr)

			_, _, err = StringUnmarshal(nil, 0)
			require.ErrorAs(t, err, &mErr)

			// A correctly encoded string truncated by one byte must fail to
			// decode rather than return a shortened value.
			buf := make([]byte, StringSize(str))
			offset, err := StringMarshal(buf, 0, str)
			require.NoError(t, err)
			require.NoError(t, VerifyMarshal(buf, offset))
			buf = buf[:len(buf)-1]
			_, _, err = StringUnmarshal(buf, 0)
			require.ErrorAs(t, err, &mErr)
		})
	})

	t.Run("uint8, byte", func(t *testing.T) {
		t.Parallel()

		// Byte and UInt8 helpers are exercised with the same values.
		for _, v := range []byte{0, 8, 16, 32, 64, 128, 255} {
			buf := make([]byte, ByteSize)
			offset, err := ByteMarshal(buf, 0, v)
			require.NoError(t, err)
			require.NoError(t, VerifyMarshal(buf, offset))

			ub, offset, err := ByteUnmarshal(buf, 0)
			require.NoError(t, err)
			require.NoError(t, VerifyUnmarshal(buf, offset))
			require.Equal(t, v, ub)

			buf = make([]byte, UInt8Size)
			offset, err = UInt8Marshal(buf, 0, v)
			require.NoError(t, err)
			require.NoError(t, VerifyMarshal(buf, offset))

			uu, offset, err := UInt8Unmarshal(buf, 0)
			require.NoError(t, err)
			require.NoError(t, VerifyUnmarshal(buf, offset))
			require.Equal(t, v, uu)
		}
	})

	t.Run("bool", func(t *testing.T) {
		t.Parallel()

		t.Run("success", func(t *testing.T) {
			t.Parallel()
			for _, v := range []bool{false, true} {
				buf := make([]byte, BoolSize)
				offset, err := BoolMarshal(buf, 0, v)
				require.NoError(t, err)
				require.NoError(t, VerifyMarshal(buf, offset))

				ub, offset, err := BoolUnmarshal(buf, 0)
				require.NoError(t, err)
				require.NoError(t, VerifyUnmarshal(buf, offset))
				require.Equal(t, v, ub)
			}
		})

		t.Run("invalid value", func(t *testing.T) {
			t.Parallel()
			buf := make([]byte, BoolSize)
			offset, err := BoolMarshal(buf, 0, true)
			require.NoError(t, err)
			require.NoError(t, VerifyMarshal(buf, offset))

			// Only the encodings of false and true are accepted. A byte value
			// of 2 must be rejected.
			buf[0] = 2

			_, _, err = BoolUnmarshal(buf, 0)
			var mErr *MarshallerError
			require.ErrorAs(t, err, &mErr)
		})

		t.Run("invalid buffer", func(t *testing.T) {
			t.Parallel()
			var mErr *MarshallerError

			_, err := BoolMarshal(nil, 0, true)
			require.ErrorAs(t, err, &mErr)

			// One trailing byte beyond the encoded bool: marshal and unmarshal
			// themselves succeed, but the Verify* helpers must report the
			// unconsumed tail.
			buf := append(make([]byte, BoolSize), 100)
			offset, err := BoolMarshal(buf, 0, true)
			require.NoError(t, err)
			require.ErrorAs(t, VerifyMarshal(buf, offset), &mErr)

			v, offset, err := BoolUnmarshal(buf, 0)
			require.NoError(t, err)
			require.True(t, v)
			require.ErrorAs(t, VerifyUnmarshal(buf, offset), &mErr)

			_, _, err = BoolUnmarshal(nil, 0)
			require.ErrorAs(t, err, &mErr)
		})
	})
}

// sequentialInt64s returns a slice of n int64 values 0, 1, ..., n-1, used as
// a deterministic fixture by the slice subtests.
func sequentialInt64s(n int) []int64 {
	s := make([]int64, n)
	for i := range s {
		s[i] = int64(i)
	}
	return s
}