package placement

import (
	"context"
	"slices"
	"strconv"
	"testing"

	netmapcore "git.frostfs.info/TrueCloudLab/frostfs-node/pkg/core/netmap"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/network"
	"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/container"
	cid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/container/id"
	"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/netmap"
	oid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object/id"
	"github.com/stretchr/testify/require"
)

// testBuilder is a stub placement builder: it ignores every argument it is
// given and always hands back the same pre-built node vectors.
type testBuilder struct {
	// vectors is returned verbatim from every BuildPlacement call.
	vectors [][]netmap.NodeInfo
}

// BuildPlacement returns the stored vectors and a nil error, regardless of the
// container, object or placement policy it is asked about.
func (tb testBuilder) BuildPlacement(_ context.Context, _ cid.ID, _ *oid.ID, _ netmap.PlacementPolicy) ([][]netmap.NodeInfo, error) {
	vecs := tb.vectors
	return vecs, nil
}

// testNode returns a node whose only network endpoint is
// "/ip4/0.0.0.0/tcp/<v>" and whose public key is the bytes of that same
// string, so tests can identify nodes by their key.
func testNode(v uint32) (n netmap.NodeInfo) {
	// FormatUint formats the uint32 exactly; the previous int(v) conversion
	// would overflow for v > math.MaxInt32 on 32-bit platforms.
	ip := "/ip4/0.0.0.0/tcp/" + strconv.FormatUint(uint64(v), 10)
	n.SetNetworkEndpoints(ip)
	n.SetPublicKey([]byte(ip))

	return n
}

// copyVectors returns a new outer slice in which every inner slice is cloned
// with slices.Clone, so the result can be re-sliced or reordered without
// touching v. Elements are copied by assignment.
func copyVectors(v [][]netmap.NodeInfo) [][]netmap.NodeInfo {
	out := make([][]netmap.NodeInfo, len(v))
	for idx, vec := range v {
		out[idx] = slices.Clone(vec)
	}
	return out
}

// testPlacement builds a REP-based placement via placement: ss[i] is the size
// of vector i and rs[i] the REP count of its replica descriptor.
func testPlacement(ss []int, rs []int) ([][]netmap.NodeInfo, container.Container) {
	var noEC [][]int
	return placement(ss, rs, noEC)
}

// testECPlacement builds an EC-based placement via placement: ss[i] is the
// size of vector i and ec[i] holds its {data, parity} counts.
func testECPlacement(ss []int, ec [][]int) ([][]netmap.NodeInfo, container.Container) {
	var noReplicas []int
	return placement(ss, noReplicas, ec)
}

// placement builds node vectors and a container whose placement policy has
// one replica descriptor per vector.
//
// Vector i holds ss[i] nodes created by testNode, numbered consecutively from
// 0 across all vectors. If rs is non-empty, replica i is REP rs[i]; otherwise
// replica i is EC with ec[i][0] data and ec[i][1] parity parts. Whichever of
// rs/ec is used must have an entry for every element of ss (and each ec[i]
// at least two elements), otherwise the function panics on an index out of
// range.
func placement(ss []int, rs []int, ec [][]int) ([][]netmap.NodeInfo, container.Container) {
	// Exactly one vector and one replica descriptor are produced per selector,
	// so size by ss: rs is empty for EC placements.
	nodes := make([][]netmap.NodeInfo, 0, len(ss))
	replicas := make([]netmap.ReplicaDescriptor, 0, len(ss))
	useREP := len(rs) > 0
	num := uint32(0)

	for i, size := range ss {
		ns := make([]netmap.NodeInfo, 0, size)

		for range size {
			ns = append(ns, testNode(num))
			num++
		}

		nodes = append(nodes, ns)

		var rd netmap.ReplicaDescriptor
		if useREP {
			rd.SetNumberOfObjects(uint32(rs[i]))
		} else {
			rd.SetECDataCount(uint32(ec[i][0]))
			rd.SetECParityCount(uint32(ec[i][1]))
		}

		replicas = append(replicas, rd)
	}

	var policy netmap.PlacementPolicy
	policy.AddReplicas(replicas...)

	var cnr container.Container
	cnr.SetPlacementPolicy(policy)

	return nodes, cnr
}

// assertSameAddress fails the test if the network addresses of ni cannot be
// parsed or do not intersect with addr.
func assertSameAddress(t *testing.T, ni netmap.NodeInfo, addr network.AddressGroup) {
	// Report failures at the caller's line rather than inside this helper.
	t.Helper()

	var netAddr network.AddressGroup

	err := netAddr.FromIterator(netmapcore.Node(ni))
	require.NoError(t, err)
	require.True(t, netAddr.Intersects(addr))
}

// TestTraverserObjectScenarios checks the node batches produced by Traverser
// for typical object operations: search (no success tracking), read (success
// after one node), put (REP-based success tracking) and an operation served
// by a single local node.
func TestTraverserObjectScenarios(t *testing.T) {
	t.Run("search scenario", func(t *testing.T) {
		selectors := []int{2, 3}
		replicas := []int{1, 2}

		nodes, cnr := testPlacement(selectors, replicas)

		nodesCopy := copyVectors(nodes)

		tr, err := NewTraverser(context.Background(),
			ForContainer(cnr),
			UseBuilder(&testBuilder{vectors: nodesCopy}),
			WithoutSuccessTracking(),
		)
		require.NoError(t, err)

		// Without success tracking each Next call is expected to return a
		// whole vector, in placement order.
		for i := range selectors {
			addrs := tr.Next()

			require.Len(t, addrs, len(nodes[i]))

			for j, n := range nodes[i] {
				assertSameAddress(t, n, addrs[j].Addresses())
			}
		}

		require.Empty(t, tr.Next())
		require.True(t, tr.Success())
	})

	t.Run("read scenario", func(t *testing.T) {
		selectors := []int{5, 3}
		replicas := []int{2, 2}

		nodes, cnr := testPlacement(selectors, replicas)

		nodesCopy := copyVectors(nodes)

		tr, err := NewTraverser(context.Background(),
			ForContainer(cnr),
			UseBuilder(&testBuilder{
				vectors: nodesCopy,
			}),
			SuccessAfter(1),
		)
		require.NoError(t, err)

		// With SuccessAfter(1) the first vector is expected to be drained in
		// len(nodes[0]) calls, after which the first node of the second vector
		// follows.
		for range len(nodes[0]) {
			require.NotNil(t, tr.Next())
		}

		var n network.AddressGroup

		err = n.FromIterator(netmapcore.Node(nodes[1][0]))
		require.NoError(t, err)

		// Take the key from the node itself instead of hard-coding the
		// endpoint string testNode happens to generate.
		require.Equal(t, []Node{{addresses: n, key: nodes[1][0].PublicKey()}}, tr.Next())
	})

	t.Run("put scenario", func(t *testing.T) {
		selectors := []int{5, 3}
		replicas := []int{2, 2}

		nodes, cnr := testPlacement(selectors, replicas)

		nodesCopy := copyVectors(nodes)

		tr, err := NewTraverser(context.Background(),
			ForContainer(cnr),
			UseBuilder(&testBuilder{vectors: nodesCopy}),
		)
		require.NoError(t, err)

		// checkVector expects vector curVector to be handed out in batches of
		// replicas[curVector] nodes while a full batch remains, then an empty
		// result, and finally reports the required number of successes.
		checkVector := func(curVector int) {
			rep := replicas[curVector]

			// "<=" so that a final batch which exactly exhausts the vector is
			// consumed too; with "<" that batch would stay pending whenever
			// selectors[curVector] is a multiple of rep, and the Empty check
			// below would fail.
			for i := 0; i+rep <= selectors[curVector]; i += rep {
				addrs := tr.Next()
				require.Len(t, addrs, rep)

				for j := range addrs {
					assertSameAddress(t, nodes[curVector][i+j], addrs[j].Addresses())
				}
			}

			require.Empty(t, tr.Next())
			require.False(t, tr.Success())

			for range rep {
				tr.SubmitSuccess()
			}
		}

		for i := range selectors {
			checkVector(i)

			// Overall success is expected only after the last vector.
			if i < len(selectors)-1 {
				require.False(t, tr.Success())
			} else {
				require.True(t, tr.Success())
			}
		}
	})

	t.Run("local operation scenario", func(t *testing.T) {
		selectors := []int{2, 3}
		replicas := []int{1, 2}

		nodes, cnr := testPlacement(selectors, replicas)

		tr, err := NewTraverser(context.Background(),
			ForContainer(cnr),
			UseBuilder(&testBuilder{
				vectors: [][]netmap.NodeInfo{{nodes[1][1]}}, // single node (local)
			}),
			SuccessAfter(1),
		)
		require.NoError(t, err)

		require.NotEmpty(t, tr.Next())
		require.False(t, tr.Success())

		// add 1 OK
		tr.SubmitSuccess()

		// nothing more to do
		require.Empty(t, tr.Next())

		// common success
		require.True(t, tr.Success())
	})
}

// TestTraverserRemValues checks how WithCopyNumbers shapes the remaining
// copies counters (Traverser.rem) for a REP 2/3/4 policy, and that a
// copy-numbers vector of unsupported length is rejected with
// errCopiesNumberLen.
func TestTraverserRemValues(t *testing.T) {
	selectors := []int{3, 4, 5}
	replicas := []int{2, 3, 4}

	nodes, cnr := testPlacement(selectors, replicas)

	testCases := [...]struct {
		name        string
		copyNumbers []uint32
		expectedRem []int
		expectedErr error
	}{
		{
			name:        "zero copy numbers",
			copyNumbers: []uint32{},
			expectedRem: replicas,
		},
		{
			name:        "compatible zero copy numbers, len 1",
			copyNumbers: []uint32{0},
			expectedRem: replicas,
		},
		{
			name:        "compatible zero copy numbers, len 3",
			copyNumbers: []uint32{0, 0, 0},
			expectedRem: replicas,
		},
		{
			name:        "copy numbers for all replicas",
			copyNumbers: []uint32{1, 1, 1},
			expectedRem: []int{1, 1, 1},
		},
		{
			name:        "single copy numbers for multiple replicas",
			copyNumbers: []uint32{1},
			expectedRem: []int{1}, // may be a bit unexpected
		},
		{
			name:        "multiple copy numbers for multiple replicas",
			copyNumbers: []uint32{1, 1, 4},
			expectedRem: []int{1, 1, 4},
		},
		{
			name:        "incompatible copies number vector",
			copyNumbers: []uint32{1, 1},
			expectedErr: errCopiesNumberLen,
		},
	}

	for _, testCase := range testCases {
		t.Run(testCase.name, func(t *testing.T) {
			tr, err := NewTraverser(context.Background(),
				ForContainer(cnr),
				// Fresh copy per subtest so traversers never share vectors.
				UseBuilder(&testBuilder{vectors: copyVectors(nodes)}),
				WithCopyNumbers(testCase.copyNumbers),
			)
			if testCase.expectedErr != nil {
				// require.Error would treat expectedErr as a mere failure
				// message; ErrorIs actually checks which error was returned.
				require.ErrorIs(t, err, testCase.expectedErr, testCase.name)
				return
			}

			require.NoError(t, err, testCase.name)
			require.Equal(t, testCase.expectedRem, tr.rem, testCase.name)
		})
	}
}

// nodeState is a fixed local-node source for traverser tests: it always
// reports the node pointer it was created with.
type nodeState struct {
	// node is returned as-is by LocalNodeInfo.
	node *netmap.NodeInfo
}

// LocalNodeInfo returns the configured node pointer without copying it.
func (s *nodeState) LocalNodeInfo() *netmap.NodeInfo {
	local := s.node
	return local
}

func TestTraverserPriorityMetrics(t *testing.T) {
	t.Run("one rep one metric", func(t *testing.T) {
		selectors := []int{4}
		replicas := []int{3}

		nodes, cnr := testPlacement(selectors, replicas)

		// Node_0, PK - ip4/0.0.0.0/tcp/0
		nodes[0][0].SetAttribute("ClusterName", "A")
		// Node_1, PK - ip4/0.0.0.0/tcp/1
		nodes[0][1].SetAttribute("ClusterName", "A")
		// Node_2, PK - ip4/0.0.0.0/tcp/2
		nodes[0][2].SetAttribute("ClusterName", "B")
		// Node_3, PK - ip4/0.0.0.0/tcp/3
		nodes[0][3].SetAttribute("ClusterName", "B")

		sdkNode := testNode(5)
		sdkNode.SetAttribute("ClusterName", "B")

		nodesCopy := copyVectors(nodes)

		m := []Metric{NewAttributeMetric("ClusterName")}

		tr, err := NewTraverser(context.Background(),
			ForContainer(cnr),
			UseBuilder(&testBuilder{
				vectors: nodesCopy,
			}),
			WithoutSuccessTracking(),
			WithPriorityMetrics(m),
			WithNodeState(&nodeState{
				node: &sdkNode,
			}),
		)
		require.NoError(t, err)

		// Without priority metric `ClusterName` the order will be:
		// [ {Node_0 A}, {Node_1 A}, {Node_2 B}, {Node_3 B}]
		// With priority metric `ClusterName` and current node in cluster B
		// the order should be:
		// [ {Node_2 B}, {Node_0 A}, {Node_1 A}, {Node_3 B}]
		next := tr.Next()
		require.NotNil(t, next)
		require.Equal(t, 3, len(next))
		require.Equal(t, "/ip4/0.0.0.0/tcp/2", string(next[0].PublicKey()))
		require.Equal(t, "/ip4/0.0.0.0/tcp/0", string(next[1].PublicKey()))
		require.Equal(t, "/ip4/0.0.0.0/tcp/1", string(next[2].PublicKey()))

		next = tr.Next()
		// The last node is
		require.Equal(t, 1, len(next))
		require.Equal(t, "/ip4/0.0.0.0/tcp/3", string(next[0].PublicKey()))

		next = tr.Next()
		require.Nil(t, next)
	})

	t.Run("one rep one metric fewer nodes", func(t *testing.T) {
		selectors := []int{2}
		replicas := []int{3}

		nodes, cnr := testPlacement(selectors, replicas)

		// Node_0, PK - ip4/0.0.0.0/tcp/0
		nodes[0][0].SetAttribute("ClusterName", "A")
		// Node_1, PK - ip4/0.0.0.0/tcp/1
		nodes[0][1].SetAttribute("ClusterName", "B")

		sdkNode := testNode(5)
		sdkNode.SetAttribute("ClusterName", "B")

		nodesCopy := copyVectors(nodes)

		m := []Metric{NewAttributeMetric("ClusterName")}

		tr, err := NewTraverser(context.Background(),
			ForContainer(cnr),
			UseBuilder(&testBuilder{
				vectors: nodesCopy,
			}),
			WithoutSuccessTracking(),
			WithPriorityMetrics(m),
			WithNodeState(&nodeState{
				node: &sdkNode,
			}),
		)
		require.NoError(t, err)

		// Without priority metric `ClusterName` the order will be:
		// [ {Node_0 A}, {Node_1 A} ]
		// With priority metric `ClusterName` and current node in cluster B
		// the order should be:
		// [ {Node_1 B}, {Node_0 A} ]
		next := tr.Next()
		require.NotNil(t, next)
		require.Equal(t, 2, len(next))
		require.Equal(t, "/ip4/0.0.0.0/tcp/1", string(next[0].PublicKey()))
		require.Equal(t, "/ip4/0.0.0.0/tcp/0", string(next[1].PublicKey()))

		next = tr.Next()
		require.Nil(t, next)
	})

	t.Run("two reps two metrics", func(t *testing.T) {
		selectors := []int{3, 3}
		replicas := []int{2, 2}

		nodes, cnr := testPlacement(selectors, replicas)

		// REPLICA #1
		// Node_0, PK - ip4/0.0.0.0/tcp/0
		nodes[0][0].SetAttribute("ClusterName", "A")
		nodes[0][0].SetAttribute("UN-LOCODE", "RU LED")

		// Node_1, PK - ip4/0.0.0.0/tcp/1
		nodes[0][1].SetAttribute("ClusterName", "A")
		nodes[0][1].SetAttribute("UN-LOCODE", "FI HEL")

		// Node_2, PK - ip4/0.0.0.0/tcp/2
		nodes[0][2].SetAttribute("ClusterName", "A")
		nodes[0][2].SetAttribute("UN-LOCODE", "RU LED")

		// REPLICA #2
		// Node_3 ip4/0.0.0.0/tcp/3
		nodes[1][0].SetAttribute("ClusterName", "B")
		nodes[1][0].SetAttribute("UN-LOCODE", "RU MOW")

		// Node_4, PK - ip4/0.0.0.0/tcp/4
		nodes[1][1].SetAttribute("ClusterName", "B")
		nodes[1][1].SetAttribute("UN-LOCODE", "RU DME")

		// Node_5, PK - ip4/0.0.0.0/tcp/5
		nodes[1][2].SetAttribute("ClusterName", "B")
		nodes[1][2].SetAttribute("UN-LOCODE", "RU MOW")

		sdkNode := testNode(9)
		sdkNode.SetAttribute("ClusterName", "B")
		sdkNode.SetAttribute("UN-LOCODE", "RU DME")

		nodesCopy := copyVectors(nodes)

		m := []Metric{
			NewAttributeMetric("ClusterName"),
			NewAttributeMetric("UN-LOCODE"),
		}

		tr, err := NewTraverser(context.Background(),
			ForContainer(cnr),
			UseBuilder(&testBuilder{
				vectors: nodesCopy,
			}),
			WithoutSuccessTracking(),
			WithPriorityMetrics(m),
			WithNodeState(&nodeState{
				node: &sdkNode,
			}),
		)
		require.NoError(t, err)

		// Check that nodes in the same cluster and
		// in the same location should be the first in slice.
		// Nodes which are follow criteria but stay outside the replica
		// should be in the next slice.

		next := tr.Next()
		require.Equal(t, 4, len(next))
		require.Equal(t, "/ip4/0.0.0.0/tcp/4", string(next[0].PublicKey()))
		require.Equal(t, "/ip4/0.0.0.0/tcp/3", string(next[1].PublicKey()))
		require.Equal(t, "/ip4/0.0.0.0/tcp/0", string(next[2].PublicKey()))
		require.Equal(t, "/ip4/0.0.0.0/tcp/1", string(next[3].PublicKey()))

		next = tr.Next()
		require.Equal(t, 2, len(next))
		require.Equal(t, "/ip4/0.0.0.0/tcp/2", string(next[0].PublicKey()))
		require.Equal(t, "/ip4/0.0.0.0/tcp/5", string(next[1].PublicKey()))

		next = tr.Next()
		require.Nil(t, next)

		sdkNode.SetAttribute("ClusterName", "B")
		sdkNode.SetAttribute("UN-LOCODE", "RU MOW")

		nodesCopy = copyVectors(nodes)

		tr, err = NewTraverser(context.Background(),
			ForContainer(cnr),
			UseBuilder(&testBuilder{
				vectors: nodesCopy,
			}),
			WithoutSuccessTracking(),
			WithPriorityMetrics(m),
			WithNodeState(&nodeState{
				node: &sdkNode,
			}),
		)
		require.NoError(t, err)

		next = tr.Next()
		require.Equal(t, 4, len(next))
		require.Equal(t, "/ip4/0.0.0.0/tcp/3", string(next[0].PublicKey()))
		require.Equal(t, "/ip4/0.0.0.0/tcp/4", string(next[1].PublicKey()))
		require.Equal(t, "/ip4/0.0.0.0/tcp/0", string(next[2].PublicKey()))
		require.Equal(t, "/ip4/0.0.0.0/tcp/1", string(next[3].PublicKey()))

		next = tr.Next()
		require.Equal(t, 2, len(next))
		require.Equal(t, "/ip4/0.0.0.0/tcp/2", string(next[0].PublicKey()))
		require.Equal(t, "/ip4/0.0.0.0/tcp/5", string(next[1].PublicKey()))

		next = tr.Next()
		require.Nil(t, next)

		sdkNode.SetAttribute("ClusterName", "A")
		sdkNode.SetAttribute("UN-LOCODE", "RU LED")

		nodesCopy = copyVectors(nodes)

		tr, err = NewTraverser(context.Background(),
			ForContainer(cnr),
			UseBuilder(&testBuilder{
				vectors: nodesCopy,
			}),
			WithoutSuccessTracking(),
			WithPriorityMetrics(m),
			WithNodeState(&nodeState{
				node: &sdkNode,
			}),
		)
		require.NoError(t, err)

		next = tr.Next()
		require.Equal(t, 4, len(next))
		require.Equal(t, "/ip4/0.0.0.0/tcp/0", string(next[0].PublicKey()))
		require.Equal(t, "/ip4/0.0.0.0/tcp/1", string(next[1].PublicKey()))
		require.Equal(t, "/ip4/0.0.0.0/tcp/3", string(next[2].PublicKey()))
		require.Equal(t, "/ip4/0.0.0.0/tcp/4", string(next[3].PublicKey()))

		next = tr.Next()
		require.Equal(t, 2, len(next))
		require.Equal(t, "/ip4/0.0.0.0/tcp/2", string(next[0].PublicKey()))
		require.Equal(t, "/ip4/0.0.0.0/tcp/5", string(next[1].PublicKey()))

		next = tr.Next()
		require.Nil(t, next)
	})

	t.Run("ec container", func(t *testing.T) {
		selectors := []int{4}
		ec := [][]int{{2, 1}}

		nodes, cnr := testECPlacement(selectors, ec)

		// Node_0, PK - ip4/0.0.0.0/tcp/0
		nodes[0][0].SetAttribute("ClusterName", "A")
		// Node_1, PK - ip4/0.0.0.0/tcp/1
		nodes[0][1].SetAttribute("ClusterName", "A")
		// Node_2, PK - ip4/0.0.0.0/tcp/2
		nodes[0][2].SetAttribute("ClusterName", "B")
		// Node_3, PK - ip4/0.0.0.0/tcp/3
		nodes[0][3].SetAttribute("ClusterName", "B")

		sdkNode := testNode(5)
		sdkNode.SetAttribute("ClusterName", "B")

		nodesCopy := copyVectors(nodes)

		m := []Metric{NewAttributeMetric("ClusterName")}

		tr, err := NewTraverser(context.Background(),
			ForContainer(cnr),
			UseBuilder(&testBuilder{
				vectors: nodesCopy,
			}),
			WithoutSuccessTracking(),
			WithPriorityMetrics(m),
			WithNodeState(&nodeState{
				node: &sdkNode,
			}),
		)
		require.NoError(t, err)

		// Without priority metric `ClusterName` the order will be:
		// [ {Node_0 A}, {Node_1 A}, {Node_2 B}, {Node_3 B}]
		// With priority metric `ClusterName` and current node in cluster B
		// the order should be:
		// [ {Node_2 B}, {Node_0 A}, {Node_1 A}, {Node_3 B}]
		next := tr.Next()
		require.NotNil(t, next)
		require.Equal(t, 3, len(next))
		require.Equal(t, "/ip4/0.0.0.0/tcp/2", string(next[0].PublicKey()))
		require.Equal(t, "/ip4/0.0.0.0/tcp/0", string(next[1].PublicKey()))
		require.Equal(t, "/ip4/0.0.0.0/tcp/1", string(next[2].PublicKey()))

		next = tr.Next()
		// The last node is
		require.Equal(t, 1, len(next))
		require.Equal(t, "/ip4/0.0.0.0/tcp/3", string(next[0].PublicKey()))

		next = tr.Next()
		require.Nil(t, next)
	})
}