package erasurecode_test

import (
	"context"
	"crypto/rand"
	"math"
	"testing"

	cidtest "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/container/id/test"
	objectSDK "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object"
	"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object/erasurecode"
	"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object/transformer"
	"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/user"
	"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/version"
	"github.com/nspcc-dev/neo-go/pkg/crypto/keys"
	"github.com/stretchr/testify/require"
)

// TestErasureCodeReconstruct splits one object into dataCount data parts and
// parityCount parity parts, then checks that the header, the whole object and
// individual parts can be reconstructed from sufficient subsets of the parts,
// and that malformed or insufficient part sets fail with ErrMalformedSlice.
func TestErasureCodeReconstruct(t *testing.T) {
	const payloadSize = 100
	const dataCount = 3
	const parityCount = 2

	// We would also like to test padding behaviour, so ensure padding is done.
	// The payload is split into dataCount equally sized data shards (parity
	// shards carry no payload bytes), so the last data shard is zero-padded
	// only when payloadSize is not a multiple of dataCount.
	require.NotZero(t, payloadSize%dataCount)

	pk, err := keys.NewPrivateKey()
	require.NoError(t, err)

	original := newObject(t, payloadSize, pk)

	c, err := erasurecode.NewConstructor(dataCount, parityCount)
	require.NoError(t, err)

	parts, err := c.Split(original, &pk.PrivateKey)
	require.NoError(t, err)

	t.Run("reconstruct header", func(t *testing.T) {
		// Header reconstruction must work on payload-less parts.
		original := original.CutPayload()
		parts := cloneSlice(parts)
		for i := range parts {
			parts[i] = parts[i].CutPayload()
		}
		t.Run("from data", func(t *testing.T) {
			parts := cloneSlice(parts)
			for i := dataCount; i < dataCount+parityCount; i++ {
				parts[i] = nil
			}
			reconstructed, err := c.ReconstructHeader(parts)
			require.NoError(t, err)
			verifyReconstruction(t, original, reconstructed)
		})
		t.Run("from parity", func(t *testing.T) {
			parts := cloneSlice(parts)
			for i := 0; i < parityCount; i++ {
				parts[i] = nil
			}
			reconstructed, err := c.ReconstructHeader(parts)
			require.NoError(t, err)
			verifyReconstruction(t, original, reconstructed)

			t.Run("not enough shards", func(t *testing.T) {
				// Drop one more part: fewer than dataCount parts remain.
				parts[parityCount] = nil
				_, err := c.ReconstructHeader(parts)
				require.ErrorIs(t, err, erasurecode.ErrMalformedSlice)
			})
		})
		t.Run("only nil parts", func(t *testing.T) {
			parts := make([]*objectSDK.Object, len(parts))
			_, err := c.ReconstructHeader(parts)
			require.ErrorIs(t, err, erasurecode.ErrMalformedSlice)
		})
		t.Run("missing EC header", func(t *testing.T) {
			parts := cloneSlice(parts)
			// Deep-copy before mutating so sibling subtests keep intact parts.
			parts[0] = deepCopy(t, parts[0])
			parts[0].SetECHeader(nil)

			_, err := c.ReconstructHeader(parts)
			require.ErrorIs(t, err, erasurecode.ErrMalformedSlice)
		})
		t.Run("invalid index", func(t *testing.T) {
			parts := cloneSlice(parts)
			parts[0] = deepCopy(t, parts[0])

			ec := parts[0].GetECHeader()
			ec.SetIndex(1)
			parts[0].SetECHeader(ec)

			_, err := c.ReconstructHeader(parts)
			require.ErrorIs(t, err, erasurecode.ErrMalformedSlice)
		})
		t.Run("invalid total", func(t *testing.T) {
			parts := cloneSlice(parts)
			parts[0] = deepCopy(t, parts[0])

			ec := parts[0].GetECHeader()
			ec.SetTotal(uint32(len(parts) + 1))
			parts[0].SetECHeader(ec)

			_, err := c.ReconstructHeader(parts)
			require.ErrorIs(t, err, erasurecode.ErrMalformedSlice)
		})
		t.Run("inconsistent header length", func(t *testing.T) {
			// Only one part disagrees with the others about the header length.
			parts := cloneSlice(parts)
			parts[0] = deepCopy(t, parts[0])

			ec := parts[0].GetECHeader()
			ec.SetHeaderLength(ec.HeaderLength() - 1)
			parts[0].SetECHeader(ec)

			_, err := c.ReconstructHeader(parts)
			require.ErrorIs(t, err, erasurecode.ErrMalformedSlice)
		})
		t.Run("invalid header length", func(t *testing.T) {
			// Every part consistently claims an impossible header length, so
			// this exercises the length validation itself rather than the
			// cross-part consistency check covered above.
			parts := cloneSlice(parts)
			for i := range parts {
				parts[i] = deepCopy(t, parts[i])

				ec := parts[i].GetECHeader()
				ec.SetHeaderLength(math.MaxUint32)
				parts[i].SetECHeader(ec)
			}

			_, err := c.ReconstructHeader(parts)
			require.ErrorIs(t, err, erasurecode.ErrMalformedSlice)
		})
	})
	t.Run("reconstruct data", func(t *testing.T) {
		t.Run("from data", func(t *testing.T) {
			parts := cloneSlice(parts)
			for i := dataCount; i < dataCount+parityCount; i++ {
				parts[i] = nil
			}
			reconstructed, err := c.Reconstruct(parts)
			require.NoError(t, err)
			verifyReconstruction(t, original, reconstructed)
		})
		t.Run("from parity", func(t *testing.T) {
			parts := cloneSlice(parts)
			for i := 0; i < parityCount; i++ {
				parts[i] = nil
			}
			reconstructed, err := c.Reconstruct(parts)
			require.NoError(t, err)
			verifyReconstruction(t, original, reconstructed)

			t.Run("not enough shards", func(t *testing.T) {
				parts[parityCount] = nil
				_, err := c.Reconstruct(parts)
				require.ErrorIs(t, err, erasurecode.ErrMalformedSlice)
			})
		})
	})
	t.Run("reconstruct parts", func(t *testing.T) {
		// We would like to also test that ReconstructParts doesn't perform
		// excessive work, so ensure this test makes sense.
		require.GreaterOrEqual(t, parityCount, 2)

		t.Run("from data", func(t *testing.T) {
			oldParts := parts
			parts := cloneSlice(parts)
			for i := dataCount; i < dataCount+parityCount; i++ {
				parts[i] = nil
			}

			required := make([]bool, len(parts))
			required[dataCount] = true

			// A nil key is passed, so the rebuilt part is compared against
			// the original with its signature cleared.
			require.NoError(t, c.ReconstructParts(parts, required, nil))

			old := deepCopy(t, oldParts[dataCount])
			old.SetSignature(nil)
			require.Equal(t, old, parts[dataCount])

			// Parts that were not requested must stay missing.
			for i := dataCount + 1; i < dataCount+parityCount; i++ {
				require.Nil(t, parts[i])
			}
		})
		t.Run("from parity", func(t *testing.T) {
			oldParts := parts
			parts := cloneSlice(parts)
			for i := 0; i < parityCount; i++ {
				parts[i] = nil
			}

			required := make([]bool, len(parts))
			required[0] = true

			require.NoError(t, c.ReconstructParts(parts, required, nil))

			old := deepCopy(t, oldParts[0])
			old.SetSignature(nil)
			require.Equal(t, old, parts[0])

			for i := 1; i < parityCount; i++ {
				require.Nil(t, parts[i])
			}
		})
	})
}

// newObject builds a single regular object with a random payload of the given
// size, a random container ID, an owner derived from pk and one attribute.
// The object is produced by the transformer's payload size limiter, with pk
// as the transformer key and a size limit above size, so exactly one object
// is emitted. That object is returned.
func newObject(t *testing.T, size uint64, pk *keys.PrivateKey) *objectSDK.Object {
	t.Helper()

	// Use transformer to form object to avoid potential bugs with yet another helper object creation in tests.
	tt := &testTarget{}
	p := transformer.NewPayloadSizeLimiter(transformer.Params{
		Key:                    &pk.PrivateKey,
		NextTargetInit:         func() transformer.ObjectWriter { return tt },
		NetworkState:           dummyEpochSource(123),
		MaxSize:                size + 1,
		WithoutHomomorphicHash: true,
	})
	cnr := cidtest.ID()
	ver := version.Current()
	hdr := objectSDK.New()
	hdr.SetContainerID(cnr)
	hdr.SetType(objectSDK.TypeRegular)
	hdr.SetVersion(&ver)

	var owner user.ID
	user.IDFromKey(&owner, pk.PrivateKey.PublicKey)
	hdr.SetOwnerID(owner)

	var attr objectSDK.Attribute
	attr.SetKey("somekey")
	attr.SetValue("somevalue")
	hdr.SetAttributes(attr)

	expectedPayload := make([]byte, size)
	// A failed random read would leave a zeroed payload; fail loudly instead.
	_, err := rand.Read(expectedPayload)
	require.NoError(t, err)
	writeObject(t, context.Background(), p, hdr, expectedPayload)
	require.Len(t, tt.objects, 1)
	return tt.objects[0]
}

// writeObject writes header and then the whole payload (in a single Write
// call) to target, closes it, and returns the access identifiers reported by
// Close. Any error fails the test at the caller's line.
func writeObject(t *testing.T, ctx context.Context, target transformer.ChunkedObjectWriter, header *objectSDK.Object, payload []byte) *transformer.AccessIdentifiers {
	t.Helper()

	require.NoError(t, target.WriteHeader(ctx, header))

	_, err := target.Write(ctx, payload)
	require.NoError(t, err)

	ids, err := target.Close(ctx)
	require.NoError(t, err)

	return ids
}

// verifyReconstruction asserts that reconstructed carries a valid ID signature
// and equals original.
//
// Note: it mutates BOTH arguments by clearing their cached marshal data
// (presumably so that differing cached encodings do not make require.Equal
// report spurious differences).
func verifyReconstruction(t *testing.T, original, reconstructed *objectSDK.Object) {
	t.Helper()

	require.True(t, reconstructed.VerifyIDSignature())
	reconstructed.ToV2().SetMarshalData(nil)
	original.ToV2().SetMarshalData(nil)

	require.Equal(t, original, reconstructed)
}

// deepCopy returns an independent copy of obj by round-tripping it through
// Marshal/Unmarshal, so that mutating the copy cannot affect obj or other
// holders of obj. Marshalling errors fail the test at the caller's line.
func deepCopy(t *testing.T, obj *objectSDK.Object) *objectSDK.Object {
	t.Helper()

	data, err := obj.Marshal()
	require.NoError(t, err)

	res := objectSDK.New()
	require.NoError(t, res.Unmarshal(data))
	return res
}

// cloneSlice returns a shallow copy of src backed by a fresh array whose
// length and capacity equal len(src). The result is never nil, even for a nil
// src. Elements themselves (e.g. pointers) are shared, not copied.
func cloneSlice[T any](src []T) []T {
	out := make([]T, 0, len(src))
	return append(out, src...)
}

// dummyEpochSource is a network-state stub that reports a fixed epoch; it is
// passed as transformer.Params.NetworkState when building test objects.
type dummyEpochSource uint64

// CurrentEpoch returns the epoch stored in the receiver, unchanged.
func (d dummyEpochSource) CurrentEpoch() uint64 {
	epoch := uint64(d)
	return epoch
}

// testTarget is an in-memory transformer.ObjectWriter that records every
// object handed to WriteObject.
type testTarget struct {
	// objects holds the written objects in the order they were received.
	objects []*objectSDK.Object
}

// WriteObject records obj and always succeeds. The context is ignored.
func (w *testTarget) WriteObject(_ context.Context, obj *objectSDK.Object) error {
	w.objects = append(w.objects, obj)
	return nil
}