package messagetest

import (
	"encoding/json"
	"errors"
	"fmt"
	"testing"

	"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/api/rpc/message"
	"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/api/util/proto/encoding"
	"github.com/stretchr/testify/require"
)

// jsonMessage is satisfied by messages that can encode themselves to JSON and
// decode themselves from it. The suite uses it for JSON round-trip checks and
// for comparing against the JSON encoding of the gRPC counterpart.
type jsonMessage interface {
	// UnmarshalJSON decodes a JSON document into the receiver.
	json.Unmarshaler
	// MarshalJSON produces the message's JSON encoding.
	json.Marshaler
}

// binaryMessage is satisfied by messages that provide their own protobuf
// binary encoder/decoder. testCompatibility compares StableMarshal's output
// with MarshalProtobuf of the corresponding gRPC message.
type binaryMessage interface {
	// Unmarshal decodes data into the receiver. The suite decodes into a
	// message obtained from msgGen(true).
	Unmarshal(data []byte) error
	// StableMarshal returns the binary encoding of the message. The suite
	// always passes nil; presumably a non-nil argument is a buffer that may
	// be reused — TODO confirm.
	StableMarshal(buf []byte) []byte
	// StableSize presumably reports the length of the encoding; the suite
	// only asserts that calling it does not allocate.
	StableSize() int
}

// TestRPCMessage runs a shared conformance suite over the messages built by
// msgGens. Each generator is called with empty=true to obtain a message that
// serves as a decoding/conversion target, and with empty=false to obtain a
// populated sample.
//
// For every generator the suite checks that:
//   - FromGRPCMessage rejects a non-gRPC value with
//     message.ErrUnexpectedMessageType;
//   - a ToGRPCMessage/FromGRPCMessage round trip yields an equal message;
//   - if the message implements jsonMessage, a MarshalJSON/UnmarshalJSON
//     round trip yields an equal message;
//   - if the message implements binaryMessage, StableSize does not allocate
//     and a StableMarshal/Unmarshal round trip yields an equal message;
//   - its encodings match those of its gRPC counterpart (testCompatibility).
func TestRPCMessage(t *testing.T, msgGens ...func(empty bool) message.Message) {
	for _, msgGen := range msgGens {
		msg := msgGen(false)

		t.Run(fmt.Sprintf("convert_%T", msg), func(t *testing.T) {
			// A fresh populated message, separate from the sample that the
			// encoding subtests below use.
			msg := msgGen(false)

			// 100 is not a gRPC message, so conversion must be rejected.
			err := msg.FromGRPCMessage(100)
			require.True(t, errors.As(err, new(message.ErrUnexpectedMessageType)),
				"expected message.ErrUnexpectedMessageType, got %v", err)

			msg2 := msgGen(true)

			err = msg2.FromGRPCMessage(msg.ToGRPCMessage())
			require.NoError(t, err)

			require.Equal(t, msg, msg2)
		})

		// Include the message type in the name (as convert_%T does) so each
		// message gets a distinct, filterable subtest. A constant name would
		// only be disambiguated as encoding#01, encoding#02, ...
		t.Run(fmt.Sprintf("encoding_%T", msg), func(t *testing.T) {
			if jm, ok := msg.(jsonMessage); ok {
				t.Run(fmt.Sprintf("JSON_%T", msg), func(t *testing.T) {
					data, err := jm.MarshalJSON()
					require.NoError(t, err)

					// Assumes msgGen returns the same concrete type for both
					// values of empty, so this assertion cannot fail.
					jm2 := msgGen(true).(jsonMessage)
					require.NoError(t, jm2.UnmarshalJSON(data))

					require.Equal(t, jm, jm2)
				})
			}

			if bm, ok := msg.(binaryMessage); ok {
				t.Run(fmt.Sprintf("%T.StableSize() does no allocations", bm), func(t *testing.T) {
					require.Zero(t, testing.AllocsPerRun(1000, func() {
						_ = bm.StableSize()
					}))
				})
				t.Run(fmt.Sprintf("Binary_%T", msg), func(t *testing.T) {
					data := bm.StableMarshal(nil)

					// Same-concrete-type assumption as in the JSON subtest.
					bm2 := msgGen(true).(binaryMessage)
					require.NoError(t, bm2.Unmarshal(data))

					require.Equal(t, bm, bm2)
				})
			}
			t.Run("compatibility", func(t *testing.T) {
				testCompatibility(t, msgGen)
			})
		})
	}
}

// testCompatibility checks that a message's own encoders agree with those of
// its gRPC counterpart (msg.ToGRPCMessage()). It checks both the empty
// (msgGen(true)) and the populated (msgGen(false)) variant:
//   - Binary: StableMarshal(nil) must equal MarshalProtobuf(nil) of the gRPC
//     message. Two empty outputs count as equal even if only one is nil.
//   - JSON: MarshalJSON must be JSON-equivalent to json.Marshal of the gRPC
//     message.
//
// A check is skipped when the message does not implement binaryMessage or
// jsonMessage respectively.
func testCompatibility(t *testing.T, msgGen func(empty bool) message.Message) {
	compareBinary := func(t *testing.T, msg message.Message) {
		am, ok := msg.(binaryMessage)
		if !ok {
			t.Skipf("%T does not implement binaryMessage", msg)
		}

		a := am.StableMarshal(nil)

		// Checked assertion: an unchecked one would panic on mismatch and
		// abort the whole test binary instead of failing only this subtest.
		grpcMsg := msg.ToGRPCMessage()
		pm, ok := grpcMsg.(encoding.ProtoMarshaler)
		require.True(t, ok, "gRPC message %T does not implement encoding.ProtoMarshaler", grpcMsg)
		b := pm.MarshalProtobuf(nil)

		// require.Equal distinguishes a nil slice from an empty non-nil one,
		// but both mean "no bytes" here.
		if len(a) == 0 {
			require.Empty(t, b)
		} else {
			require.Equal(t, a, b)
		}
	}
	compareJSON := func(t *testing.T, msg message.Message) {
		am, ok := msg.(jsonMessage)
		if !ok {
			t.Skipf("%T does not implement jsonMessage", msg)
		}

		a, err := am.MarshalJSON()
		require.NoError(t, err)

		b, err := json.Marshal(msg.ToGRPCMessage())
		require.NoError(t, err)

		// Compare semantically so that field order and whitespace do not matter.
		require.JSONEq(t, string(a), string(b))
	}

	for _, tc := range []struct {
		name  string
		empty bool
	}{
		{name: "empty", empty: true},
		{name: "not empty", empty: false},
	} {
		t.Run(tc.name, func(t *testing.T) {
			msg := msgGen(tc.empty)
			t.Run(fmt.Sprintf("Binary_%T", msg), func(t *testing.T) {
				compareBinary(t, msg)
			})
			t.Run(fmt.Sprintf("JSON_%T", msg), func(t *testing.T) {
				compareJSON(t, msg)
			})
		})
	}
}