package tree

import (
	"context"
	"errors"
	"path"
	"path/filepath"
	"slices"
	"testing"

	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/local_object_storage/pilorama"
	cidtest "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/container/id/test"
	"github.com/stretchr/testify/require"
	"google.golang.org/grpc"
)

// TestGetSubTree builds a small tree in an in-memory forest and checks that
// getSubTree streams nodes parent-first with metadata matching the storage,
// honours the requested depth (0 meaning unrestricted) and stops sending once
// a Send call fails.
func TestGetSubTree(t *testing.T) {
	d := pilorama.CIDDescriptor{CID: cidtest.ID(), Size: 1}
	treeID := "sometree"
	p := pilorama.NewMemoryForest()

	tree := []struct {
		path []string
		id   uint64
	}{
		{path: []string{"dir1"}},
		{path: []string{"dir2"}},
		{path: []string{"dir1", "sub1"}},
		{path: []string{"dir2", "sub1"}},
		{path: []string{"dir2", "sub2"}},
		{path: []string{"dir2", "sub1", "subsub1"}},
	}

	for i := range tree {
		path := tree[i].path
		meta := []pilorama.KeyValue{
			{Key: pilorama.AttributeFilename, Value: []byte(path[len(path)-1])},
		}

		// Every parent appears earlier in the list, so exactly one new node
		// (the leaf itself) is expected to be created.
		lm, err := p.TreeAddByPath(context.Background(), d, treeID, pilorama.AttributeFilename, path[:len(path)-1], meta)
		require.NoError(t, err)
		require.Len(t, lm, 1)

		tree[i].id = lm[0].Child
	}

	// testGetSubTree runs getSubTree from rootID with the given depth. errIndex
	// is the 0-based Send call that fails (-1 for none). It returns the IDs of
	// the streamed nodes in the order they were sent.
	testGetSubTree := func(t *testing.T, rootID uint64, depth uint32, errIndex int) []uint64 {
		acc := subTreeAcc{errIndex: errIndex}
		err := getSubTree(context.Background(), &acc, d.CID, &GetSubTreeRequest_Body{
			TreeId: treeID,
			RootId: []uint64{rootID},
			Depth:  depth,
		}, p)
		if errIndex == -1 {
			require.NoError(t, err)
		} else {
			require.ErrorIs(t, err, errSubTreeSend)
		}

		// GetSubTree must return a child only after it has returned the parent.
		require.NotEmpty(t, acc.seen)
		require.Equal(t, rootID, acc.seen[0].Body.NodeId[0])
	loop:
		for i := 1; i < len(acc.seen); i++ {
			parent := acc.seen[i].Body.ParentId
			for j := range i {
				if acc.seen[j].Body.NodeId[0] == parent[0] {
					continue loop
				}
			}
			require.Failf(t, "parent not seen", "node %d has parent %d, but it hasn't been seen",
				acc.seen[i].Body.NodeId[0], parent[0])
		}

		// GetSubTree must return valid meta.
		for i := range acc.seen {
			b := acc.seen[i].Body
			meta, node, err := p.TreeGetMeta(context.Background(), d.CID, treeID, b.NodeId[0])
			require.NoError(t, err)
			require.Equal(t, node, b.ParentId[0])
			require.Equal(t, meta.Time, b.Timestamp[0])
			require.Equal(t, metaToProto(meta.Items), b.Meta)
		}

		ordered := make([]uint64, len(acc.seen))
		for i := range acc.seen {
			ordered[i] = acc.seen[i].Body.NodeId[0]
		}
		return ordered
	}

	t.Run("depth = 1, only root", func(t *testing.T) {
		actual := testGetSubTree(t, 0, 1, -1)
		require.Equal(t, []uint64{0}, actual)

		t.Run("custom root", func(t *testing.T) {
			actual := testGetSubTree(t, tree[2].id, 1, -1)
			require.Equal(t, []uint64{tree[2].id}, actual)
		})
	})
	t.Run("depth = 2", func(t *testing.T) {
		actual := testGetSubTree(t, 0, 2, -1)
		require.Equal(t, []uint64{0, tree[0].id, tree[1].id}, actual)

		t.Run("error in the middle", func(t *testing.T) {
			actual := testGetSubTree(t, 0, 2, 0)
			require.Equal(t, []uint64{0}, actual)

			actual = testGetSubTree(t, 0, 2, 1)
			require.Equal(t, []uint64{0, tree[0].id}, actual)
		})
	})
	t.Run("depth = 0 (unrestricted)", func(t *testing.T) {
		actual := testGetSubTree(t, 0, 0, -1)
		expected := []uint64{
			0,
			tree[0].id, // dir1
			tree[2].id, // dir1/sub1
			tree[1].id, // dir2
			tree[3].id, // dir2/sub1
			tree[5].id, // dir2/sub1/subsub1
			tree[4].id, // dir2/sub2
		}
		require.Equal(t, expected, actual)
	})
}

func TestGetSubTreeOrderAsc(t *testing.T) {
	t.Run("memory forest", func(t *testing.T) {
		testGetSubTreeOrderAsc(t, pilorama.NewMemoryForest())
	})

	t.Run("boltdb forest", func(t *testing.T) {
		p := pilorama.NewBoltForest(pilorama.WithPath(filepath.Join(t.TempDir(), "pilorama")))
		require.NoError(t, p.Open(context.Background(), 0o644))
		require.NoError(t, p.Init())
		testGetSubTreeOrderAsc(t, p)
	})
}

// testGetSubTreeOrderAsc populates p with a small tree and checks getSubTree
// requests that ask for ascending order, both unrestricted and depth-limited.
func testGetSubTreeOrderAsc(t *testing.T, p pilorama.ForestStorage) {
	d := pilorama.CIDDescriptor{CID: cidtest.ID(), Size: 1}
	treeID := "sometree"

	tree := []struct {
		path []string
		id   uint64
	}{
		{path: []string{"dir1"}},
		{path: []string{"dir2"}},
		{path: []string{"dir1", "sub1"}},
		{path: []string{"dir2", "sub1"}},
		{path: []string{"dir2", "sub2"}},
		{path: []string{"dir2", "sub1", "subsub1"}},
	}

	for i := range tree {
		path := tree[i].path
		meta := []pilorama.KeyValue{
			{Key: pilorama.AttributeFilename, Value: []byte(path[len(path)-1])},
		}

		// Every parent appears earlier in the list, so exactly one new node
		// (the leaf itself) is expected to be created.
		lm, err := p.TreeAddByPath(context.Background(), d, treeID, pilorama.AttributeFilename, path[:len(path)-1], meta)
		require.NoError(t, err)
		require.Len(t, lm, 1)
		tree[i].id = lm[0].Child
	}

	t.Run("total", func(t *testing.T) {
		// NOTE(review): skipped without a recorded reason; confirm the expected
		// full-traversal ordering before re-enabling.
		t.Skip()
		acc := subTreeAcc{errIndex: -1}
		err := getSubTree(context.Background(), &acc, d.CID, &GetSubTreeRequest_Body{
			TreeId: treeID,
			OrderBy: &GetSubTreeRequest_Body_Order{
				Direction: GetSubTreeRequest_Body_Order_Asc,
			},
		}, p)
		require.NoError(t, err)
		require.NotEmpty(t, acc.seen)
		// The root (ID 0) must be returned before any of its descendants.
		require.Equal(t, uint64(0), acc.seen[0].Body.NodeId[0])

		// Map every non-root node back to its path so the ordering can be checked.
		paths := make([]string, 0, len(acc.seen))
		for i := 1; i < len(acc.seen); i++ {
			nodeID := acc.seen[i].Body.NodeId[0]
			found := false
			for j := range tree {
				if nodeID == tree[j].id {
					found = true
					paths = append(paths, path.Join(tree[j].path...))
					break
				}
			}
			require.True(t, found, "unknown node %d %v", i, acc.seen[i].GetBody().GetNodeId())
		}

		require.True(t, slices.IsSorted(paths), "paths are not sorted: %v", paths)
	})
	t.Run("depth=1", func(t *testing.T) {
		acc := subTreeAcc{errIndex: -1}
		err := getSubTree(context.Background(), &acc, d.CID, &GetSubTreeRequest_Body{
			TreeId: treeID,
			Depth:  1,
			OrderBy: &GetSubTreeRequest_Body_Order{
				Direction: GetSubTreeRequest_Body_Order_Asc,
			},
		}, p)
		require.NoError(t, err)
		require.Len(t, acc.seen, 1)
		require.Equal(t, uint64(0), acc.seen[0].Body.NodeId[0])
	})
	t.Run("depth=2", func(t *testing.T) {
		acc := subTreeAcc{errIndex: -1}
		err := getSubTree(context.Background(), &acc, d.CID, &GetSubTreeRequest_Body{
			TreeId: treeID,
			Depth:  2,
			OrderBy: &GetSubTreeRequest_Body_Order{
				Direction: GetSubTreeRequest_Body_Order_Asc,
			},
		}, p)
		require.NoError(t, err)
		// Root plus its two direct children (dir1, dir2).
		require.Len(t, acc.seen, 3)
		require.Equal(t, uint64(0), acc.seen[0].Body.NodeId[0])
		require.Equal(t, uint64(0), acc.seen[1].GetBody().GetParentId()[0])
		require.Equal(t, uint64(0), acc.seen[2].GetBody().GetParentId()[0])
	})
}

// Sentinel errors produced by subTreeAcc.Send.
var (
	// errInvalidResponse reports a response with more than one node ID,
	// parent ID or timestamp.
	errInvalidResponse = errors.New("send got invalid response")
	// errSubTreeSend is the injected failure returned at the configured call.
	errSubTreeSend = errors.New("send finished with error")
	// errSubTreeSendAfterError is returned for any call made after the
	// injected failure.
	errSubTreeSendAfterError = errors.New("send was invoked after an error occurred")
)

// subTreeAcc is a test double for TreeService_GetSubTreeServer. It records
// every response passed to Send and can inject a failure at a chosen call.
type subTreeAcc struct {
	// ServerStream is embedded only to satisfy the interface; it is left nil
	// in these tests, so only Send may be called.
	grpc.ServerStream

	// seen holds the responses passed to Send, in call order. Responses sent
	// at or after the injected failure are recorded too.
	seen []*GetSubTreeResponse

	// errIndex is the 0-based Send call that returns errSubTreeSend; later
	// calls return errSubTreeSendAfterError. A negative value disables
	// failure injection.
	errIndex int
}

// Compile-time check that subTreeAcc implements the server stream interface.
var _ TreeService_GetSubTreeServer = (*subTreeAcc)(nil)

// Send implements TreeService_GetSubTreeServer.
//
// A response that carries more than one node ID, parent ID or timestamp is
// rejected with errInvalidResponse and is not recorded. Every other response
// is appended to s.seen. When s.errIndex is non-negative, call number
// errIndex (0-based) returns errSubTreeSend and every later call returns
// errSubTreeSendAfterError; such responses are still recorded.
func (s *subTreeAcc) Send(r *GetSubTreeResponse) error {
	body := r.GetBody()
	if len(body.GetNodeId()) > 1 || len(body.GetParentId()) > 1 || len(body.GetTimestamp()) > 1 {
		return errInvalidResponse
	}

	s.seen = append(s.seen, r)
	if s.errIndex < 0 {
		return nil
	}

	switch n := len(s.seen); {
	case n == s.errIndex+1:
		return errSubTreeSend
	case n > s.errIndex:
		return errSubTreeSendAfterError
	default:
		return nil
	}
}