package client

import (
	"bytes"
	"context"
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rand"
	"testing"

	v2object "git.frostfs.info/TrueCloudLab/frostfs-api-go/v2/object"
	"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object"
	oidtest "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object/id/test"
	"github.com/stretchr/testify/require"
)

// mockPatchStream is an in-memory stream used in place of a real patch stream
// by objectPatcher in tests. Every request written to it is converted into an
// object.PayloadPatch and recorded so the test can inspect what was sent.
type mockPatchStream struct {
	// streamedPayloadPatches holds one entry per Write call, in call order.
	streamedPayloadPatches []*object.PayloadPatch
}

// Write records the patch carried by r. The chunk bytes are copied into a
// freshly allocated slice so the recorded patch owns its data (presumably to
// guard against the patcher reusing its chunk buffer — confirm against the
// patcher implementation). Write never fails.
func (m *mockPatchStream) Write(r *v2object.PatchRequest) error {
	patch := r.GetBody().GetPatch()

	recorded := new(object.PayloadPatch)
	recorded.FromV2(patch)

	if patch != nil {
		src := patch.Chunk
		dst := make([]byte, len(src))
		copy(dst, src)
		recorded.Chunk = dst
	}

	m.streamedPayloadPatches = append(m.streamedPayloadPatches, recorded)
	return nil
}

// Close is a no-op; the mock holds no resources.
func (m *mockPatchStream) Close() error {
	return nil
}

// TestObjectPatcher checks how objectPatcher splits a payload patch into
// stream messages when the payload is larger than maxChunkLen.
//
// Each case first sends an (empty) attribute patch, then a payload patch for
// rng. The expected parts encode the splitting contract under test:
//   - the first message carries the caller's range and the first chunk;
//   - every following message carries a zero-length range positioned at
//     rng offset + rng length, with the next chunk of the payload.
func TestObjectPatcher(t *testing.T) {
	type part struct {
		offset int
		length int
		chunk  string
	}

	for _, test := range []struct {
		name         string
		patchPayload string
		rng          *object.Range
		maxChunkLen  int
		expectParts  []part
	}{
		{
			name:         "no split payload patch",
			patchPayload: "011111",
			rng:          newRange(0, 6),
			maxChunkLen:  defaultGRPCPayloadChunkLen,
			expectParts: []part{
				{
					offset: 0,
					length: 6,
					chunk:  "011111",
				},
			},
		},
		{
			name:         "splitted payload patch",
			patchPayload: "012345",
			rng:          newRange(0, 6),
			maxChunkLen:  2,
			expectParts: []part{
				{
					offset: 0,
					length: 6,
					chunk:  "01",
				},
				{
					offset: 6,
					length: 0,
					chunk:  "23",
				},
				{
					offset: 6,
					length: 0,
					chunk:  "45",
				},
			},
		},
		{
			name:         "splitted payload patch with zero-length subpatches",
			patchPayload: "0123456789!@",
			rng:          newRange(0, 4),
			maxChunkLen:  2,
			expectParts: []part{
				{
					offset: 0,
					length: 4,
					chunk:  "01",
				},
				{
					offset: 4,
					length: 0,
					chunk:  "23",
				},
				{
					offset: 4,
					length: 0,
					chunk:  "45",
				},
				{
					offset: 4,
					length: 0,
					chunk:  "67",
				},
				{
					offset: 4,
					length: 0,
					chunk:  "89",
				},
				{
					offset: 4,
					length: 0,
					chunk:  "!@",
				},
			},
		},
		{
			name:         "splitted payload patch with zero-length subpatches only",
			patchPayload: "0123456789!@",
			rng:          newRange(0, 0),
			maxChunkLen:  2,
			expectParts: []part{
				{
					offset: 0,
					length: 0,
					chunk:  "01",
				},
				{
					offset: 0,
					length: 0,
					chunk:  "23",
				},
				{
					offset: 0,
					length: 0,
					chunk:  "45",
				},
				{
					offset: 0,
					length: 0,
					chunk:  "67",
				},
				{
					offset: 0,
					length: 0,
					chunk:  "89",
				},
				{
					offset: 0,
					length: 0,
					chunk:  "!@",
				},
			},
		},
	} {
		t.Run(test.name, func(t *testing.T) {
			m := &mockPatchStream{}

			// A key-generation failure must stop the subtest here rather than
			// surface later as a confusing failure with a nil key.
			pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
			require.NoError(t, err)

			patcher := objectPatcher{
				client:            &Client{},
				stream:            m,
				addr:              oidtest.Address(),
				key:               pk,
				maxChunkLen:       test.maxChunkLen,
				firstPatchPayload: true,
			}

			success := patcher.PatchAttributes(context.Background(), nil, false)
			require.True(t, success, "attribute patch must succeed")

			success = patcher.PatchPayload(context.Background(), test.rng, bytes.NewReader([]byte(test.patchPayload)))
			require.True(t, success, "payload patch must succeed")

			// One attribute message followed by the expected payload messages.
			require.Len(t, m.streamedPayloadPatches, len(test.expectParts)+1)

			// m.streamedPayloadPatches[0] is the attribute patch, so skip it.
			for i, part := range test.expectParts {
				requireRangeChunk(t, m.streamedPayloadPatches[i+1], part.offset, part.length, part.chunk)
			}
		})
	}
}

// requireRangeChunk asserts that pp is non-nil and that its range offset, range
// length and chunk bytes equal offset, length and chunk respectively. It stops
// the calling test on the first mismatch.
func requireRangeChunk(t *testing.T, pp *object.PayloadPatch, offset, length int, chunk string) {
	// Report failures at the caller's line rather than inside this helper.
	t.Helper()

	require.NotNil(t, pp)
	require.Equal(t, uint64(offset), pp.Range.GetOffset(), "unexpected range offset")
	require.Equal(t, uint64(length), pp.Range.GetLength(), "unexpected range length")
	require.Equal(t, []byte(chunk), pp.Chunk, "unexpected chunk")
}

// newRange returns a new object.Range with the given offset and length set.
func newRange(offset, length uint64) *object.Range {
	rng := new(object.Range)
	rng.SetOffset(offset)
	rng.SetLength(length)
	return rng
}