package policer

import (
	"bytes"
	"context"
	"crypto/rand"
	"errors"
	"fmt"
	"sync/atomic"
	"testing"

	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/core/container"
	objectcore "git.frostfs.info/TrueCloudLab/frostfs-node/pkg/core/object"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/services/object/util"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/services/replicator"
	apistatus "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/client/status"
	cid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/container/id"
	netmapSDK "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/netmap"
	objectSDK "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object"
	"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object/erasurecode"
	oid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object/id"
	oidtest "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object/id/test"
	"github.com/nspcc-dev/neo-go/pkg/crypto/keys"
	"github.com/stretchr/testify/require"
)

// TestECChunkHasValidPlacement checks that processObject succeeds for an EC
// chunk that sits where the placement of its parent expects it.
//
// Setup: the container uses the "EC 2.1" policy, and the parent placement is nodes
// 0..3. The local node is node 0, which holds chunk 0. Remote nodes 1 and 2
// report chunks 1 and 2 (every chunk carries Total 3). The remote HEAD stub
// fails the test if any node other than 1 or 2 is queried.
func TestECChunkHasValidPlacement(t *testing.T) {
	t.Parallel()
	chunkAddress := oidtest.Address()
	parentID := oidtest.ID()

	var policy netmapSDK.PlacementPolicy
	require.NoError(t, policy.DecodeString("EC 2.1"))

	cnr := &container.Container{}
	cnr.Value.Init()
	cnr.Value.SetPlacementPolicy(policy)
	containerSrc := containerSrc{
		get: func(id cid.ID) (*container.Container, error) {
			if id.Equals(chunkAddress.Container()) {
				return cnr, nil
			}
			return nil, new(apistatus.ContainerNotFound)
		},
	}

	// Node i is identified by the single-byte public key {i}.
	nodes := make([]netmapSDK.NodeInfo, 4)
	for i := range nodes {
		nodes[i].SetPublicKey([]byte{byte(i)})
	}

	// Placement is only defined for the parent object of the chunk.
	placementBuilder := func(cnr cid.ID, obj *oid.ID, p netmapSDK.PlacementPolicy) ([][]netmapSDK.NodeInfo, error) {
		if cnr.Equals(chunkAddress.Container()) && obj.Equals(parentID) {
			return [][]netmapSDK.NodeInfo{nodes}, nil
		}
		return nil, errors.New("unexpected placement build")
	}

	// Remote nodes 1 and 2 answer a raw HEAD of the parent with an EC info
	// error listing one chunk whose index equals the node number.
	remoteHeadFn := func(_ context.Context, ni netmapSDK.NodeInfo, a oid.Address, raw bool) (*objectSDK.Object, error) {
		require.True(t, raw, "remote header for parent object must be called with raw flag")
		index := int(ni.PublicKey()[0])
		require.True(t, index == 1 || index == 2, "invalid node to get parent header")
		require.True(t, a.Container() == chunkAddress.Container() && a.Object() == parentID, "invalid address to get remote header")
		ei := objectSDK.NewECInfo()
		var ch objectSDK.ECChunk
		ch.SetID(oidtest.ID())
		ch.Index = uint32(index)
		ch.Total = 3
		ei.AddChunk(ch)
		return nil, objectSDK.NewECInfoError(ei)
	}

	// The local node (node 0) reports chunk 0 of the parent.
	localHeadFn := func(_ context.Context, a oid.Address) (*objectSDK.Object, error) {
		require.True(t, a.Container() == chunkAddress.Container() && a.Object() == parentID, "invalid address to get local header")
		ei := objectSDK.NewECInfo()
		var ch objectSDK.ECChunk
		ch.SetID(oidtest.ID())
		ch.Index = uint32(0)
		ch.Total = 3
		ei.AddChunk(ch)
		return nil, objectSDK.NewECInfoError(ei)
	}

	p := New(
		WithContainerSource(containerSrc),
		WithPlacementBuilder(placementBuilderFunc(placementBuilder)),
		// Only node 0 is treated as the local (announced) node.
		WithNetmapKeys(announcedKeysFunc(func(k []byte) bool {
			return bytes.Equal(k, nodes[0].PublicKey())
		})),
		WithRemoteObjectHeaderFunc(remoteHeadFn),
		WithLocalObjectHeaderFunc(localHeadFn),
		WithPool(testPool(t)),
	)

	objInfo := objectcore.Info{
		Address: chunkAddress,
		Type:    objectSDK.TypeRegular,
		ECInfo: &objectcore.ECInfo{
			ParentID: parentID,
			Index:    0,
			Total:    3,
		},
	}
	err := p.processObject(context.Background(), objInfo)
	require.NoError(t, err)
}

// TestECChunkHasInvalidPlacement checks how the policer handles an EC chunk
// whose position does not match the placement of its parent.
//
// The chunk being processed is chunk 1 (Total 3) of parentID, and the local
// node is node 0. The subtests below assume that, in the parent placement
// (nodes 0..3), chunk 1 belongs on node 1. Each subtest varies which chunks
// nodes 0 and 1 report. It then asserts which pull, replication and
// redundant-copy (drop) actions the policer takes.
func TestECChunkHasInvalidPlacement(t *testing.T) {
	t.Parallel()
	chunkAddress := oidtest.Address()
	parentID := oidtest.ID()
	chunkObject := objectSDK.New()
	chunkObject.SetContainerID(chunkAddress.Container())
	chunkObject.SetID(chunkAddress.Object())
	chunkObject.SetPayload([]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9})
	chunkObject.SetPayloadSize(uint64(10))
	chunkObject.SetECHeader(objectSDK.NewECHeader(objectSDK.ECParentInfo{ID: parentID}, 0, 3, []byte{}, 0))

	var policy netmapSDK.PlacementPolicy
	require.NoError(t, policy.DecodeString("EC 2.1"))

	cnr := &container.Container{}
	cnr.Value.Init()
	cnr.Value.SetPlacementPolicy(policy)
	containerSrc := containerSrc{
		get: func(id cid.ID) (*container.Container, error) {
			if id.Equals(chunkAddress.Container()) {
				return cnr, nil
			}
			return nil, new(apistatus.ContainerNotFound)
		},
	}

	// Node i is identified by the single-byte public key {i}.
	nodes := make([]netmapSDK.NodeInfo, 4)
	for i := range nodes {
		nodes[i].SetPublicKey([]byte{byte(i)})
	}

	// Placement is only defined for the parent object of the chunk.
	placementBuilder := func(cnr cid.ID, obj *oid.ID, p netmapSDK.PlacementPolicy) ([][]netmapSDK.NodeInfo, error) {
		if cnr.Equals(chunkAddress.Container()) && obj.Equals(parentID) {
			return [][]netmapSDK.NodeInfo{nodes}, nil
		}
		return nil, errors.New("unexpected placement build")
	}

	objInfo := objectcore.Info{
		Address: chunkAddress,
		Type:    objectSDK.TypeRegular,
		ECInfo: &objectcore.ECInfo{
			ParentID: parentID,
			Index:    1,
			Total:    3,
		},
	}

	t.Run("node0 has chunk1, node1 has chunk0 and chunk1", func(t *testing.T) {
		// policer should pull chunk0 on first run and drop chunk1 on second run
		var allowDrop bool
		requiredChunkID := oidtest.ID()
		headFn := func(_ context.Context, ni netmapSDK.NodeInfo, a oid.Address, raw bool) (*objectSDK.Object, error) {
			if bytes.Equal(ni.PublicKey(), nodes[1].PublicKey()) && a == chunkAddress && !raw {
				return chunkObject, nil
			}
			if bytes.Equal(ni.PublicKey(), nodes[1].PublicKey()) && a.Container() == chunkAddress.Container() &&
				a.Object() == parentID && raw {
				ei := objectSDK.NewECInfo()
				var ch objectSDK.ECChunk
				ch.SetID(oidtest.ID())
				ch.Index = 1
				ch.Total = 3
				ei.AddChunk(ch)
				ch.Index = 0
				ch.SetID(requiredChunkID)
				ei.AddChunk(ch)
				return nil, objectSDK.NewECInfoError(ei)
			}
			if bytes.Equal(ni.PublicKey(), nodes[2].PublicKey()) && a.Container() == chunkAddress.Container() &&
				a.Object() == parentID && raw {
				ei := objectSDK.NewECInfo()
				var ch objectSDK.ECChunk
				ch.SetID(oidtest.ID())
				ch.Index = 2
				ch.Total = 3
				ei.AddChunk(ch)
				return nil, objectSDK.NewECInfoError(ei)
			}
			if bytes.Equal(ni.PublicKey(), nodes[3].PublicKey()) && a.Container() == chunkAddress.Container() &&
				a.Object() == parentID && raw {
				return nil, new(apistatus.ObjectNotFound)
			}
			require.Fail(t, "unexpected remote HEAD")
			return nil, fmt.Errorf("unexpected remote HEAD")
		}

		// Before the pull, node 0 reports only chunk 1. After allowDrop is set
		// (simulating a completed pull), it also reports chunk 0 (requiredChunkID).
		localHeadF := func(_ context.Context, addr oid.Address) (*objectSDK.Object, error) {
			require.True(t, addr.Container() == chunkAddress.Container() && addr.Object() == parentID, "unexpected local HEAD")
			if allowDrop {
				ei := objectSDK.NewECInfo()
				var ch objectSDK.ECChunk
				ch.SetID(oidtest.ID())
				ch.Index = 1
				ch.Total = 3
				ei.AddChunk(ch)
				ch.SetID(requiredChunkID)
				ch.Index = 0
				ei.AddChunk(ch)
				return nil, objectSDK.NewECInfoError(ei)
			}
			ei := objectSDK.NewECInfo()
			var ch objectSDK.ECChunk
			ch.SetID(oidtest.ID())
			ch.Index = 1
			ch.Total = 3
			ei.AddChunk(ch)
			return nil, objectSDK.NewECInfoError(ei)
		}

		var pullCounter atomic.Int64
		var dropped []oid.Address
		p := New(
			WithContainerSource(containerSrc),
			WithPlacementBuilder(placementBuilderFunc(placementBuilder)),
			WithNetmapKeys(announcedKeysFunc(func(k []byte) bool {
				return bytes.Equal(k, nodes[0].PublicKey())
			})),
			WithRemoteObjectHeaderFunc(headFn),
			WithLocalObjectHeaderFunc(localHeadF),
			WithReplicator(&testReplicator{
				// The only acceptable pull is chunk 0 from node 1.
				handlePullTask: func(ctx context.Context, r replicator.Task) {
					require.True(t, r.Addr.Container() == chunkAddress.Container() && r.Addr.Object() == requiredChunkID &&
						len(r.Nodes) == 1 && bytes.Equal(r.Nodes[0].PublicKey(), nodes[1].PublicKey()), "invalid pull task")
					pullCounter.Add(1)
				},
			}),
			WithRedundantCopyCallback(func(ctx context.Context, a oid.Address) {
				require.True(t, allowDrop, "invalid redundant copy call")
				dropped = append(dropped, a)
			}),
			WithPool(testPool(t)),
		)

		err := p.processObject(context.Background(), objInfo)
		require.NoError(t, err)
		require.Equal(t, int64(1), pullCounter.Load(), "invalid pull count")
		require.Empty(t, dropped, "invalid dropped count")
		allowDrop = true
		err = p.processObject(context.Background(), objInfo)
		require.NoError(t, err)
		require.Equal(t, int64(1), pullCounter.Load(), "invalid pull count")
		require.Len(t, dropped, 1, "invalid dropped count")
		require.True(t, chunkAddress.Equals(dropped[0]), "invalid dropped object")
	})

	t.Run("node0 has chunk0 and chunk1, node1 has chunk1", func(t *testing.T) {
		// policer should drop chunk1
		headFn := func(_ context.Context, ni netmapSDK.NodeInfo, a oid.Address, raw bool) (*objectSDK.Object, error) {
			if bytes.Equal(ni.PublicKey(), nodes[1].PublicKey()) && a == chunkAddress && !raw {
				return chunkObject, nil
			}
			if bytes.Equal(ni.PublicKey(), nodes[1].PublicKey()) && a.Container() == chunkAddress.Container() &&
				a.Object() == parentID && raw {
				ei := objectSDK.NewECInfo()
				var ch objectSDK.ECChunk
				ch.SetID(chunkAddress.Object())
				ch.Index = 1
				ch.Total = 3
				ei.AddChunk(ch)
				return nil, objectSDK.NewECInfoError(ei)
			}
			if bytes.Equal(ni.PublicKey(), nodes[2].PublicKey()) && a.Container() == chunkAddress.Container() &&
				a.Object() == parentID && raw {
				ei := objectSDK.NewECInfo()
				var ch objectSDK.ECChunk
				ch.SetID(oidtest.ID())
				ch.Index = 2
				ch.Total = 3
				ei.AddChunk(ch)
				return nil, objectSDK.NewECInfoError(ei)
			}
			if bytes.Equal(ni.PublicKey(), nodes[3].PublicKey()) && a.Container() == chunkAddress.Container() &&
				a.Object() == parentID && raw {
				return nil, new(apistatus.ObjectNotFound)
			}
			require.Fail(t, "unexpected remote HEAD")
			return nil, fmt.Errorf("unexpected remote HEAD")
		}

		localHeadF := func(_ context.Context, addr oid.Address) (*objectSDK.Object, error) {
			require.True(t, addr.Container() == chunkAddress.Container() && addr.Object() == parentID, "unexpected local HEAD")
			ei := objectSDK.NewECInfo()
			var ch objectSDK.ECChunk
			ch.SetID(chunkAddress.Object())
			ch.Index = 1
			ch.Total = 3
			ei.AddChunk(ch)
			ch.SetID(oidtest.ID())
			ch.Index = 0
			ei.AddChunk(ch)
			return nil, objectSDK.NewECInfoError(ei)
		}

		var dropped []oid.Address
		p := New(
			WithContainerSource(containerSrc),
			WithPlacementBuilder(placementBuilderFunc(placementBuilder)),
			WithNetmapKeys(announcedKeysFunc(func(k []byte) bool {
				return bytes.Equal(k, nodes[0].PublicKey())
			})),
			WithRemoteObjectHeaderFunc(headFn),
			WithLocalObjectHeaderFunc(localHeadF),
			WithRedundantCopyCallback(func(ctx context.Context, a oid.Address) {
				dropped = append(dropped, a)
			}),
			WithPool(testPool(t)),
		)

		err := p.processObject(context.Background(), objInfo)
		require.NoError(t, err)
		require.Len(t, dropped, 1, "invalid dropped count")
		require.True(t, chunkAddress.Equals(dropped[0]), "invalid dropped object")
	})

	t.Run("node0 has chunk0 and chunk1, node1 has no chunks", func(t *testing.T) {
		// policer should replicate chunk1 to node1 on first run and drop chunk1 on node0 on second run
		var secondRun bool
		headFn := func(_ context.Context, ni netmapSDK.NodeInfo, a oid.Address, raw bool) (*objectSDK.Object, error) {
			if bytes.Equal(ni.PublicKey(), nodes[1].PublicKey()) && a == chunkAddress && !raw {
				// Node 1 only has the chunk after the first run's replication.
				if !secondRun {
					return nil, new(apistatus.ObjectNotFound)
				}
				return chunkObject, nil
			}
			if bytes.Equal(ni.PublicKey(), nodes[1].PublicKey()) && a.Container() == chunkAddress.Container() &&
				a.Object() == parentID && raw {
				ei := objectSDK.NewECInfo()
				var ch objectSDK.ECChunk
				ch.SetID(chunkAddress.Object())
				ch.Index = 1
				ch.Total = 3
				ei.AddChunk(ch)
				return nil, objectSDK.NewECInfoError(ei)
			}
			if bytes.Equal(ni.PublicKey(), nodes[2].PublicKey()) && a.Container() == chunkAddress.Container() &&
				a.Object() == parentID && raw {
				ei := objectSDK.NewECInfo()
				var ch objectSDK.ECChunk
				ch.SetID(oidtest.ID())
				ch.Index = 2
				ch.Total = 3
				ei.AddChunk(ch)
				return nil, objectSDK.NewECInfoError(ei)
			}
			if bytes.Equal(ni.PublicKey(), nodes[3].PublicKey()) && a.Container() == chunkAddress.Container() &&
				a.Object() == parentID && raw {
				return nil, new(apistatus.ObjectNotFound)
			}
			require.Fail(t, "unexpected remote HEAD")
			return nil, fmt.Errorf("unexpected remote HEAD")
		}

		localHeadF := func(_ context.Context, addr oid.Address) (*objectSDK.Object, error) {
			require.True(t, addr.Container() == chunkAddress.Container() && addr.Object() == parentID, "unexpected local HEAD")
			ei := objectSDK.NewECInfo()
			var ch objectSDK.ECChunk
			ch.SetID(chunkAddress.Object())
			ch.Index = 1
			ch.Total = 3
			ei.AddChunk(ch)
			ch.SetID(oidtest.ID())
			ch.Index = 0
			ei.AddChunk(ch)
			return nil, objectSDK.NewECInfoError(ei)
		}

		var dropped []oid.Address
		var replicated []replicator.Task
		p := New(
			WithContainerSource(containerSrc),
			WithPlacementBuilder(placementBuilderFunc(placementBuilder)),
			WithNetmapKeys(announcedKeysFunc(func(k []byte) bool {
				return bytes.Equal(k, nodes[0].PublicKey())
			})),
			WithRemoteObjectHeaderFunc(headFn),
			WithLocalObjectHeaderFunc(localHeadF),
			WithRedundantCopyCallback(func(ctx context.Context, a oid.Address) {
				dropped = append(dropped, a)
			}),
			WithReplicator(&testReplicator{
				// Named task (not t) so the subtest's *testing.T is not shadowed.
				handleReplicationTask: func(ctx context.Context, task replicator.Task, tr replicator.TaskResult) {
					replicated = append(replicated, task)
				},
			}),
			WithPool(testPool(t)),
		)

		err := p.processObject(context.Background(), objInfo)
		require.NoError(t, err)
		require.Empty(t, dropped, "invalid dropped count")
		require.Len(t, replicated, 1, "invalid replicated count")
		require.Equal(t, chunkAddress, replicated[0].Addr, "invalid replicated object")
		require.NotEmpty(t, replicated[0].Nodes, "invalid replicate target")
		require.True(t, bytes.Equal(replicated[0].Nodes[0].PublicKey(), nodes[1].PublicKey()), "invalid replicate target")

		secondRun = true
		err = p.processObject(context.Background(), objInfo)
		require.NoError(t, err)
		require.Len(t, replicated, 1, "invalid replicated count")
		require.Equal(t, chunkAddress, replicated[0].Addr, "invalid replicated object")
		require.True(t, bytes.Equal(replicated[0].Nodes[0].PublicKey(), nodes[1].PublicKey()), "invalid replicate target")
		require.Len(t, dropped, 1, "invalid dropped count")
		require.True(t, chunkAddress.Equals(dropped[0]), "invalid dropped object")
	})
}

func TestECChunkRestore(t *testing.T) {
	// node0 has chunk0, node1 has chunk1
	// policer should replicate chunk0 to node2 on the first run
	// then restore EC object and replicate chunk2 to node2 on the second run
	t.Parallel()

	payload := make([]byte, 64)
	rand.Read(payload)
	parentAddress := oidtest.Address()
	parentObject := objectSDK.New()
	parentObject.SetContainerID(parentAddress.Container())
	parentObject.SetPayload(payload)
	parentObject.SetPayloadSize(64)
	objectSDK.CalculateAndSetPayloadChecksum(parentObject)
	err := objectSDK.CalculateAndSetID(parentObject)
	require.NoError(t, err)
	id, _ := parentObject.ID()
	parentAddress.SetObject(id)

	chunkIDs := make([]oid.ID, 3)
	c, err := erasurecode.NewConstructor(2, 1)
	require.NoError(t, err)
	key, err := keys.NewPrivateKey()
	require.NoError(t, err)
	chunks, err := c.Split(parentObject, &key.PrivateKey)
	require.NoError(t, err)
	for i, ch := range chunks {
		chunkIDs[i], _ = ch.ID()
	}

	var policy netmapSDK.PlacementPolicy
	require.NoError(t, policy.DecodeString("EC 2.1"))

	cnr := &container.Container{}
	cnr.Value.Init()
	cnr.Value.SetPlacementPolicy(policy)
	containerSrc := containerSrc{
		get: func(id cid.ID) (*container.Container, error) {
			if id.Equals(parentAddress.Container()) {
				return cnr, nil
			}
			return nil, new(apistatus.ContainerNotFound)
		},
	}

	nodes := make([]netmapSDK.NodeInfo, 4)
	for i := range nodes {
		nodes[i].SetPublicKey([]byte{byte(i)})
	}

	placementBuilder := func(cnr cid.ID, obj *oid.ID, p netmapSDK.PlacementPolicy) ([][]netmapSDK.NodeInfo, error) {
		if cnr.Equals(parentAddress.Container()) && obj.Equals(parentAddress.Object()) {
			return [][]netmapSDK.NodeInfo{nodes}, nil
		}
		return nil, errors.New("unexpected placement build")
	}
	var secondRun bool
	remoteHeadFn := func(_ context.Context, ni netmapSDK.NodeInfo, a oid.Address, raw bool) (*objectSDK.Object, error) {
		require.True(t, raw, "remote header for parent object must be called with raw flag")
		index := int(ni.PublicKey()[0])
		require.True(t, index == 1 || index == 2 || index == 3, "invalid node to get parent header")
		require.True(t, a == parentAddress, "invalid address to get remote header")
		if index == 1 {
			ei := objectSDK.NewECInfo()
			var ch objectSDK.ECChunk
			ch.SetID(chunkIDs[1])
			ch.Index = uint32(1)
			ch.Total = 3
			ei.AddChunk(ch)
			return nil, objectSDK.NewECInfoError(ei)
		}
		if index == 2 && secondRun {
			ei := objectSDK.NewECInfo()
			var ch objectSDK.ECChunk
			ch.SetID(chunkIDs[0])
			ch.Index = uint32(0)
			ch.Total = 3
			ei.AddChunk(ch)
			return nil, objectSDK.NewECInfoError(ei)
		}
		return nil, new(apistatus.ObjectNotFound)
	}

	localHeadFn := func(_ context.Context, a oid.Address) (*objectSDK.Object, error) {
		require.True(t, a == parentAddress, "invalid address to get remote header")
		ei := objectSDK.NewECInfo()
		var ch objectSDK.ECChunk
		ch.SetID(chunkIDs[0])
		ch.Index = uint32(0)
		ch.Total = 3
		ei.AddChunk(ch)
		return nil, objectSDK.NewECInfoError(ei)
	}

	var replicatedObj []*objectSDK.Object
	p := New(
		WithContainerSource(containerSrc),
		WithPlacementBuilder(placementBuilderFunc(placementBuilder)),
		WithNetmapKeys(announcedKeysFunc(func(k []byte) bool {
			return bytes.Equal(k, nodes[0].PublicKey())
		})),
		WithRemoteObjectHeaderFunc(remoteHeadFn),
		WithLocalObjectHeaderFunc(localHeadFn),
		WithReplicator(&testReplicator{
			handleReplicationTask: func(ctx context.Context, t replicator.Task, tr replicator.TaskResult) {
				if t.Obj != nil {
					replicatedObj = append(replicatedObj, t.Obj)
				}
			},
		}),
		WithLocalObjectGetFunc(func(ctx context.Context, a oid.Address) (*objectSDK.Object, error) {
			require.True(t, a.Container() == parentAddress.Container() && a.Object() == chunkIDs[0], "invalid local object request")
			return chunks[0], nil
		}),
		WithRemoteObjectGetFunc(func(ctx context.Context, ni netmapSDK.NodeInfo, a oid.Address) (*objectSDK.Object, error) {
			index := ni.PublicKey()[0]
			if index == 2 {
				return nil, new(apistatus.ObjectNotFound)
			}
			return chunks[index], nil
		}),
		WithPool(testPool(t)),
		WithKeyStorage(util.NewKeyStorage(&key.PrivateKey, nil, nil)),
	)

	var chunkAddress oid.Address
	chunkAddress.SetContainer(parentAddress.Container())
	chunkAddress.SetObject(chunkIDs[0])
	objInfo := objectcore.Info{
		Address: chunkAddress,
		Type:    objectSDK.TypeRegular,
		ECInfo: &objectcore.ECInfo{
			ParentID: parentAddress.Object(),
			Index:    0,
			Total:    3,
		},
	}
	err = p.processObject(context.Background(), objInfo)
	require.NoError(t, err)
	secondRun = true
	err = p.processObject(context.Background(), objInfo)
	require.NoError(t, err)

	require.Equal(t, 1, len(replicatedObj), "invalid replicated objects count")
	chunks[2].SetSignature(nil)
	expectedData, err := chunks[2].MarshalJSON()
	require.NoError(t, err)
	replicatedObj[0].SetSignature(nil)
	actualData, err := replicatedObj[0].MarshalJSON()
	require.NoError(t, err)
	require.EqualValues(t, string(expectedData), string(actualData), "invalid restored objects")
}