package client

import (
	"bytes"
	"context"
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rand"
	"testing"

	v2object "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/api/object"
	"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object"
	oidtest "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object/id/test"
	"github.com/stretchr/testify/require"
)

// mockPatchStream is a test double for the stream objectPatcher writes to.
// Instead of sending requests anywhere, it decodes the payload patch carried
// by every written request and keeps it for later inspection.
type mockPatchStream struct {
	// streamedPayloadPatches holds one decoded patch per Write call, in call order.
	streamedPayloadPatches []*object.PayloadPatch
}

// Write decodes the payload patch of r and appends it to streamedPayloadPatches.
// When the request carries a patch, its chunk bytes are copied so the recorded
// value does not share memory with the request body. It always returns nil.
func (m *mockPatchStream) Write(r *v2object.PatchRequest) error {
	v2Patch := r.GetBody().GetPatch()

	recorded := new(object.PayloadPatch)
	recorded.FromV2(v2Patch)

	if v2Patch != nil {
		// Detach from the request's buffer so later changes to it cannot
		// alter what was recorded.
		recorded.Chunk = append([]byte{}, v2Patch.Chunk...)
	}

	m.streamedPayloadPatches = append(m.streamedPayloadPatches, recorded)
	return nil
}

// Close is a no-op that always reports success.
func (m *mockPatchStream) Close() error {
	return nil
}

// TestObjectPatcher checks how objectPatcher.PatchPayload splits a payload into
// patches of at most maxChunkLen bytes: the first patch carries the requested
// range, and every following patch has zero length and starts at the end of
// that range.
func TestObjectPatcher(t *testing.T) {
	type part struct {
		offset int
		length int
		chunk  string
	}

	for _, test := range []struct {
		name         string
		patchPayload string
		rng          *object.Range
		maxChunkLen  int
		expectParts  []part
	}{
		{
			name:         "no split payload patch",
			patchPayload: "011111",
			rng:          newRange(0, 6),
			maxChunkLen:  defaultGRPCPayloadChunkLen,
			expectParts: []part{
				{offset: 0, length: 6, chunk: "011111"},
			},
		},
		{
			name:         "splitted payload patch",
			patchPayload: "012345",
			rng:          newRange(0, 6),
			maxChunkLen:  2,
			expectParts: []part{
				{offset: 0, length: 6, chunk: "01"},
				{offset: 6, length: 0, chunk: "23"},
				{offset: 6, length: 0, chunk: "45"},
			},
		},
		{
			name:         "splitted payload patch with zero-length subpatches",
			patchPayload: "0123456789!@",
			rng:          newRange(0, 4),
			maxChunkLen:  2,
			expectParts: []part{
				{offset: 0, length: 4, chunk: "01"},
				{offset: 4, length: 0, chunk: "23"},
				{offset: 4, length: 0, chunk: "45"},
				{offset: 4, length: 0, chunk: "67"},
				{offset: 4, length: 0, chunk: "89"},
				{offset: 4, length: 0, chunk: "!@"},
			},
		},
		{
			name:         "splitted payload patch with zero-length subpatches only",
			patchPayload: "0123456789!@",
			rng:          newRange(0, 0),
			maxChunkLen:  2,
			expectParts: []part{
				{offset: 0, length: 0, chunk: "01"},
				{offset: 0, length: 0, chunk: "23"},
				{offset: 0, length: 0, chunk: "45"},
				{offset: 0, length: 0, chunk: "67"},
				{offset: 0, length: 0, chunk: "89"},
				{offset: 0, length: 0, chunk: "!@"},
			},
		},
	} {
		t.Run(test.name, func(t *testing.T) {
			m := &mockPatchStream{}

			// A failed key generation must fail the test rather than hand a
			// nil key to the patcher.
			pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
			require.NoError(t, err)

			patcher := objectPatcher{
				client:      &Client{},
				stream:      m,
				addr:        oidtest.Address(),
				key:         pk,
				maxChunkLen: test.maxChunkLen,
			}

			success := patcher.PatchHeader(context.Background(), PatchHeaderPrm{})
			require.True(t, success)

			success = patcher.PatchPayload(context.Background(), test.rng, bytes.NewReader([]byte(test.patchPayload)))
			require.True(t, success)

			// The PatchHeader call above is recorded by the mock too, so expect
			// one extra entry and skip it (index 0) when comparing parts.
			require.Len(t, m.streamedPayloadPatches, len(test.expectParts)+1)
			for i, want := range test.expectParts {
				requireRangeChunk(t, m.streamedPayloadPatches[i+1], want.offset, want.length, want.chunk)
			}
		})
	}
}

// TestRepeatPayloadPatch checks that several consecutive PatchPayload calls on
// the same patcher each emit their own payload patches, both when every payload
// fits into one chunk and when payloads are split into maxChunkLen-sized chunks.
func TestRepeatPayloadPatch(t *testing.T) {
	type part struct {
		offset int
		length int
		chunk  string
	}

	// Payload patches applied, in this order, to one patcher per subtest.
	// Ranges are built inside each subtest so subtests never share *object.Range values.
	inputs := []struct {
		patchPayload string
		offset       uint64
		length       uint64
	}{
		{patchPayload: "xxxxxxxxxx", offset: 1, length: 6},
		{patchPayload: "yyyyyyyyyy", offset: 5, length: 9},
		{patchPayload: "zzzzzzzzzz", offset: 10, length: 0},
	}

	for _, test := range []struct {
		name        string
		maxChunkLen int
		expectParts []part
	}{
		{
			name:        "no payload patch partioning",
			maxChunkLen: 20,
			expectParts: []part{
				{offset: 1, length: 6, chunk: "xxxxxxxxxx"},
				{offset: 5, length: 9, chunk: "yyyyyyyyyy"},
				{offset: 10, length: 0, chunk: "zzzzzzzzzz"},
			},
		},
		{
			name:        "payload patch partioning",
			maxChunkLen: 5,
			expectParts: []part{
				// Only the first chunk of each payload keeps the requested range;
				// the continuation chunk has zero length at offset+length.
				{offset: 1, length: 6, chunk: "xxxxx"},
				{offset: 7, length: 0, chunk: "xxxxx"},
				{offset: 5, length: 9, chunk: "yyyyy"},
				{offset: 14, length: 0, chunk: "yyyyy"},
				{offset: 10, length: 0, chunk: "zzzzz"},
				{offset: 10, length: 0, chunk: "zzzzz"},
			},
		},
	} {
		t.Run(test.name, func(t *testing.T) {
			m := &mockPatchStream{}

			pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
			require.NoError(t, err)

			patcher := objectPatcher{
				client:      &Client{},
				stream:      m,
				addr:        oidtest.Address(),
				key:         pk,
				maxChunkLen: test.maxChunkLen,
			}

			for _, in := range inputs {
				success := patcher.PatchPayload(context.Background(), newRange(in.offset, in.length), bytes.NewReader([]byte(in.patchPayload)))
				require.True(t, success)
			}

			// Check the count first: extra patches must fail the test, and missing
			// ones must fail cleanly instead of panicking on an index.
			require.Len(t, m.streamedPayloadPatches, len(test.expectParts))
			for i, want := range test.expectParts {
				requireRangeChunk(t, m.streamedPayloadPatches[i], want.offset, want.length, want.chunk)
			}
		})
	}
}

// requireRangeChunk asserts that pp is non-nil, that its range has the given
// offset and length, and that its chunk equals chunk byte for byte.
// It is a test helper, so failures are reported at the caller's line.
func requireRangeChunk(t *testing.T, pp *object.PayloadPatch, offset, length int, chunk string) {
	t.Helper()

	require.NotNil(t, pp)
	require.Equal(t, uint64(offset), pp.Range.GetOffset())
	require.Equal(t, uint64(length), pp.Range.GetLength())
	require.Equal(t, []byte(chunk), pp.Chunk)
}

// newRange returns a new object.Range with the given offset and length set.
func newRange(offset, length uint64) *object.Range {
	rng := new(object.Range)
	rng.SetOffset(offset)
	rng.SetLength(length)
	return rng
}