package transformer

import (
	"context"
	"crypto/rand"
	"crypto/sha256"
	"testing"

	cid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/container/id"
	cidtest "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/container/id/test"
	objectSDK "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object"
	"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/user"
	"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/version"
	"github.com/nspcc-dev/neo-go/pkg/crypto/keys"
	"github.com/stretchr/testify/require"
)

// TestTransformer writes a payload of 2.5*maxSize bytes through the payload
// size limiter and checks the objects handed to the next target: three part
// objects followed by a linking object, each with the expected container,
// type, owner, payload checksum and a valid ID signature. It also checks that
// the concatenated part payloads reproduce the input and that the parent
// header carries the checksum of the whole payload.
func TestTransformer(t *testing.T) {
	const maxSize = 100

	tt := new(testTarget)

	target, pk := newPayloadSizeLimiter(maxSize, 0, func() ObjectWriter { return tt })

	cnr := cidtest.ID()
	hdr := newObject(cnr)

	var owner user.ID
	user.IDFromKey(&owner, pk.PrivateKey.PublicKey)
	hdr.SetOwnerID(owner)

	// Two full-size parts plus a half-size tail.
	expectedPayload := make([]byte, maxSize*2+maxSize/2)
	_, err := rand.Read(expectedPayload)
	require.NoError(t, err)

	ids := writeObject(t, context.Background(), target, hdr, expectedPayload)
	require.Len(t, tt.objects, 4) // 3 parts + linking object

	var actualPayload []byte
	for i := range tt.objects {
		obj := tt.objects[i]

		childCnr, ok := obj.ContainerID()
		require.True(t, ok)
		require.Equal(t, cnr, childCnr)
		require.Equal(t, objectSDK.TypeRegular, obj.Type())
		require.Equal(t, owner, obj.OwnerID())

		payload := obj.Payload()
		require.EqualValues(t, obj.PayloadSize(), len(payload))
		actualPayload = append(actualPayload, payload...)

		// The checksum is only verified for objects that carry payload.
		if len(payload) != 0 {
			cs, ok := obj.PayloadChecksum()
			require.True(t, ok)

			h := sha256.Sum256(payload)
			require.Equal(t, h[:], cs.Value())
		}

		require.True(t, obj.VerifyIDSignature())
		switch i {
		case 0, 1:
			// Full-size parts without a parent header.
			require.EqualValues(t, maxSize, len(payload))
			require.Nil(t, obj.Parent())
		case 2:
			// The last part holds the remainder and the signed parent header.
			require.EqualValues(t, maxSize/2, len(payload))
			parent := obj.Parent()
			require.NotNil(t, parent)
			require.Nil(t, parent.SplitID())
			require.True(t, parent.VerifyIDSignature())
		case 3:
			// The linking object points at the parent and lists the parts in order.
			parID, ok := obj.ParentID()
			require.True(t, ok)
			require.Equal(t, ids.ParentID, &parID)

			children := obj.Children()
			// Guard the indexing below: a short list must fail the test, not panic.
			require.Len(t, children, i)
			for j := range i {
				id, ok := tt.objects[j].ID()
				require.True(t, ok)
				require.Equal(t, id, children[j])
			}
		}
	}
	require.Equal(t, expectedPayload, actualPayload)

	t.Run("parent checksum", func(t *testing.T) {
		cs, ok := ids.ParentHeader.PayloadChecksum()
		require.True(t, ok)

		h := sha256.Sum256(expectedPayload)
		require.Equal(t, h[:], cs.Value())
	})
}

// newObject returns a fresh regular-type object header bound to container cnr
// and stamped with the current SDK version.
func newObject(cnr cid.ID) *objectSDK.Object {
	obj := objectSDK.New()
	obj.SetType(objectSDK.TypeRegular)
	obj.SetContainerID(cnr)

	// SetVersion takes a pointer, so the version needs an addressable local.
	current := version.Current()
	obj.SetVersion(&current)

	return obj
}

// writeObject pushes header and then the whole payload through target in a
// single Write call, closes the target and returns the identifiers reported
// by Close. Any error from the three steps fails the test immediately.
func writeObject(t *testing.T, ctx context.Context, target ChunkedObjectWriter, header *objectSDK.Object, payload []byte) *AccessIdentifiers {
	// Attribute failures to the calling test rather than to this helper.
	t.Helper()

	require.NoError(t, target.WriteHeader(ctx, header))

	_, err := target.Write(ctx, payload)
	require.NoError(t, err)

	ids, err := target.Close(ctx)
	require.NoError(t, err)

	return ids
}

// BenchmarkTransformer runs benchmarkTransformer for a small and a big
// payload. Each payload is benchmarked three ways: written in one call,
// written in bufferSize chunks, and written in bufferSize chunks with the
// payload size passed as the size hint.
func BenchmarkTransformer(b *testing.B) {
	const (
		// bufferSize is taken from https://git.frostfs.info/TrueCloudLab/frostfs-sdk-go/src/commit/670619d2426fee233a37efe21a0471989b16a4fc/pool/pool.go#L1825
		bufferSize = 3 * 1024 * 1024
		smallSize  = 8 * 1024
		bigSize    = 64 * 1024 * 1024 * 9 / 2 // 4.5 parts
	)

	hdr := newObject(cidtest.ID())

	payloads := []struct {
		name string
		size int
	}{
		{name: "small", size: smallSize},
		{name: "big", size: bigSize},
	}

	for _, p := range payloads {
		b.Run(p.name, func(b *testing.B) {
			modes := []struct {
				name     string
				sizeHint int
				buffer   int
			}{
				{name: "no size hint", sizeHint: 0, buffer: 0},
				{name: "no size hint, with buffer", sizeHint: 0, buffer: bufferSize},
				{name: "with size hint, with buffer", sizeHint: p.size, buffer: bufferSize},
			}

			for _, m := range modes {
				b.Run(m.name, func(b *testing.B) {
					benchmarkTransformer(b, hdr, p.size, m.sizeHint, m.buffer)
				})
			}
		})
	}
}

// benchmarkTransformer measures one full write cycle (WriteHeader, Write(s),
// Close) of a freshly created payload size limiter per iteration.
//
// payloadSize is the number of zero bytes written. sizeHint is forwarded to
// the limiter. A bufferSize of 0 writes the payload in a single call;
// otherwise it is written in bufferSize chunks, with the final call carrying
// whatever is left (at most bufferSize bytes).
func benchmarkTransformer(b *testing.B, header *objectSDK.Object, payloadSize, sizeHint, bufferSize int) {
	const maxSize = 64 * 1024 * 1024

	ctx := context.Background()
	payload := make([]byte, payloadSize)

	b.ReportAllocs()
	b.ResetTimer()
	for range b.N {
		w, _ := newPayloadSizeLimiter(maxSize, uint64(sizeHint), func() ObjectWriter { return benchTarget{} })

		err := w.WriteHeader(ctx, header)
		if err != nil {
			b.Fatalf("write header: %v", err)
		}

		// Emit full chunks while more than one chunk remains; the tail is
		// always sent by the trailing Write below.
		rest := payload
		if bufferSize != 0 {
			for len(rest) > bufferSize {
				if _, err = w.Write(ctx, rest[:bufferSize]); err != nil {
					b.Fatalf("write: %v", err)
				}
				rest = rest[bufferSize:]
			}
		}
		if _, err = w.Write(ctx, rest); err != nil {
			b.Fatalf("write: %v", err)
		}

		if _, err = w.Close(ctx); err != nil {
			b.Fatalf("close: %v", err)
		}
	}
}

// newPayloadSizeLimiter generates a new private key and builds a payload size
// limiter around it with the given maxSize, sizeHint and next-target
// initializer, a fixed epoch source of 123 and homomorphic hashing disabled.
// It returns the limiter together with the generated key and panics if key
// generation fails.
func newPayloadSizeLimiter(maxSize uint64, sizeHint uint64, nextTarget TargetInitializer) (ChunkedObjectWriter, *keys.PrivateKey) {
	signer, err := keys.NewPrivateKey()
	if err != nil {
		panic(err)
	}

	prm := Params{
		Key:                    &signer.PrivateKey,
		NextTargetInit:         nextTarget,
		NetworkState:           dummyEpochSource(123),
		MaxSize:                maxSize,
		SizeHint:               sizeHint,
		WithoutHomomorphicHash: true,
	}

	return NewPayloadSizeLimiter(prm), signer
}

// dummyEpochSource is a fixed epoch number that can report itself as the
// current epoch.
type dummyEpochSource uint64

// CurrentEpoch returns the epoch value the source was created with.
func (epoch dummyEpochSource) CurrentEpoch() uint64 {
	fixed := uint64(epoch)
	return fixed
}

// benchTarget is a stateless object sink used by the benchmarks.
type benchTarget struct{}

// WriteObject ignores both arguments and always reports success.
func (benchTarget) WriteObject(_ context.Context, _ *objectSDK.Object) error {
	return nil
}

// testTarget collects every object written to it, in call order, so tests
// can inspect them afterwards.
type testTarget struct {
	objects []*objectSDK.Object
}

// WriteObject appends obj to the recorded objects and never fails.
func (tgt *testTarget) WriteObject(_ context.Context, obj *objectSDK.Object) error {
	recorded := append(tgt.objects, obj)
	tgt.objects = recorded
	return nil // AccessIdentifiers should not be used.
}