package engine

import (
	"context"
	"errors"
	"fmt"
	"path/filepath"
	"strconv"
	"testing"
	"time"

	objectCore "git.frostfs.info/TrueCloudLab/frostfs-node/pkg/core/object"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/local_object_storage/blobstor"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/local_object_storage/blobstor/fstree"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/local_object_storage/internal/testutil"
	meta "git.frostfs.info/TrueCloudLab/frostfs-node/pkg/local_object_storage/metabase"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/local_object_storage/pilorama"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/local_object_storage/shard"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/local_object_storage/shard/mode"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/util/logger/test"
	cid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/container/id"
	cidtest "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/container/id/test"
	objectSDK "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object"
	oid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object/id"
	"github.com/stretchr/testify/require"
	"golang.org/x/sync/errgroup"
)

// newEngineEvacuate creates, opens and initializes a test engine with shardNum
// shards. Each shard gets an fstree blobstor (depth 1), a metabase and a
// pilorama, all placed under a per-test temporary directory.
//
// For every shard it generates objPerShard objects, each in a fresh random
// container, and puts them directly into that shard (bypassing engine
// placement). For each object it also adds the path "path/to/the/file" with
// version and filename attributes to the "version" tree of the object's
// container in the same shard.
//
// It returns the engine, the IDs of its shards and all stored objects in
// insertion order (grouped by shard, in ids order).
func newEngineEvacuate(t *testing.T, shardNum int, objPerShard int) (*StorageEngine, []*shard.ID, []*objectSDK.Object) {
	dir := t.TempDir()

	te := testNewEngine(t).
		setShardsNumOpts(t, shardNum, func(id int) []shard.Option {
			return []shard.Option{
				shard.WithLogger(test.NewLogger(t)),
				shard.WithBlobStorOptions(
					blobstor.WithStorages([]blobstor.SubStorage{{
						Storage: fstree.New(
							fstree.WithPath(filepath.Join(dir, strconv.Itoa(id))),
							fstree.WithDepth(1)),
					}})),
				shard.WithMetaBaseOptions(
					meta.WithPath(filepath.Join(dir, fmt.Sprintf("%d.metabase", id))),
					meta.WithPermissions(0o700),
					meta.WithEpochState(epochState{})),
				shard.WithPiloramaOptions(
					pilorama.WithPath(filepath.Join(dir, fmt.Sprintf("%d.pilorama", id))),
					pilorama.WithPerm(0o700),
				),
			}
		})
	e, ids := te.engine, te.shardIDs
	require.NoError(t, e.Open(context.Background()))
	require.NoError(t, e.Init(context.Background()))

	objects := make([]*objectSDK.Object, 0, objPerShard*len(ids))

	const treeID = "version"
	// Named treeMeta (not meta) so it does not shadow the metabase package,
	// which is imported under the alias "meta" and used above.
	treeMeta := []pilorama.KeyValue{
		{Key: pilorama.AttributeVersion, Value: []byte("XXX")},
		{Key: pilorama.AttributeFilename, Value: []byte("file.txt")},
	}

	for _, id := range ids {
		for i := 0; i < objPerShard; i++ {
			contID := cidtest.ID()
			obj := testutil.GenerateObjectWithCID(contID)
			objects = append(objects, obj)

			// Put straight into the shard so the object is guaranteed to live there.
			var putPrm shard.PutPrm
			putPrm.SetObject(obj)
			_, err := e.shards[id.String()].Put(context.Background(), putPrm)
			require.NoError(t, err)

			_, err = e.shards[id.String()].TreeAddByPath(context.Background(), pilorama.CIDDescriptor{CID: contID, Position: 0, Size: 1},
				treeID, pilorama.AttributeFilename, []string{"path", "to", "the", "file"}, treeMeta)
			require.NoError(t, err)
		}
	}
	return e, ids, objects
}

// TestEvacuateShardObjects evacuates all objects from one read-only shard of a
// three-shard engine to the remaining shards. It verifies that every object is
// readable through the engine at four points: before evacuation, after it,
// after a repeated evacuation (which must move nothing), and after the source
// shard is removed from the engine.
func TestEvacuateShardObjects(t *testing.T) {
	t.Parallel()

	const objPerShard = 3

	e, ids, objects := newEngineEvacuate(t, 3, objPerShard)
	defer func() {
		require.NoError(t, e.Close(context.Background()))
	}()

	sourceShard := ids[2].String()

	// requireAllAvailable fails t unless every generated object can be read via the engine.
	requireAllAvailable := func(t *testing.T) {
		for _, obj := range objects {
			var getPrm GetPrm
			getPrm.WithAddress(objectCore.AddressOf(obj))

			_, err := e.Get(context.Background(), getPrm)
			require.NoError(t, err)
		}
	}

	requireAllAvailable(t)

	var prm EvacuateShardPrm
	prm.ShardID = ids[2:3]
	prm.Scope = EvacuateScopeObjects

	t.Run("must be read-only", func(t *testing.T) {
		res, err := e.Evacuate(context.Background(), prm)
		require.ErrorIs(t, err, ErrMustBeReadOnly)
		require.Equal(t, uint64(0), res.ObjectsEvacuated())
	})

	require.NoError(t, e.shards[sourceShard].SetMode(mode.ReadOnly))

	first, err := e.Evacuate(context.Background(), prm)
	require.NoError(t, err)
	require.Equal(t, uint64(objPerShard), first.ObjectsEvacuated())

	// Availability is checked both while the source shard is still attached and
	// after it is detached. The former is the real-world case: it proves an
	// object can be put despite existing metabase checks/marks. The latter
	// proves every object was actually moved and stays reachable.
	requireAllAvailable(t)

	// A repeated evacuation succeeds, but everything has already been moved, so no new PUTs happen.
	second, err := e.Evacuate(context.Background(), prm)
	require.NoError(t, err)
	require.Equal(t, uint64(0), second.ObjectsEvacuated())

	requireAllAvailable(t)

	func() {
		e.mtx.Lock()
		defer e.mtx.Unlock()
		delete(e.shards, sourceShard)
		delete(e.shardPools, sourceShard)
	}()

	requireAllAvailable(t)
}

// TestEvacuateObjectsNetwork checks evacuation through a custom ObjectsHandler
// (the "network" fallback), which is used when objects cannot be placed on
// other local shards or in addition to them.
func TestEvacuateObjectsNetwork(t *testing.T) {
	t.Parallel()

	errReplication := errors.New("handler error")

	// acceptOneOf returns an objects handler that accepts the first limit calls
	// and fails every later call with errReplication. Each accepted call must be
	// for one of the given objects, and the passed object must equal the
	// original one.
	//
	// t must be the *testing.T of the (sub)test that runs the evacuation, so
	// failures are attributed to the right test. NOTE(review): require calls
	// t.FailNow, which is only valid on the test's own goroutine — confirm that
	// Evacuate invokes the handler synchronously.
	acceptOneOf := func(t *testing.T, objects []*objectSDK.Object, limit uint64) func(context.Context, oid.Address, *objectSDK.Object) error {
		var n uint64
		return func(_ context.Context, addr oid.Address, obj *objectSDK.Object) error {
			if n == limit {
				return errReplication
			}

			n++
			for i := range objects {
				if addr == objectCore.AddressOf(objects[i]) {
					require.Equal(t, objects[i], obj)
					return nil
				}
			}
			// require.FailNow takes a plain message (not a format string), so format it here.
			require.FailNow(t, fmt.Sprintf("handler was called with an unexpected object: %s", addr))
			panic("unreachable")
		}
	}

	t.Run("single shard", func(t *testing.T) {
		t.Parallel()
		e, ids, objects := newEngineEvacuate(t, 1, 3)
		defer func() {
			require.NoError(t, e.Close(context.Background()))
		}()

		evacuateShardID := ids[0].String()

		require.NoError(t, e.shards[evacuateShardID].SetMode(mode.ReadOnly))

		var prm EvacuateShardPrm
		prm.ShardID = ids[0:1]
		prm.Scope = EvacuateScopeObjects

		// Without a handler there is nowhere to move objects from the only shard.
		res, err := e.Evacuate(context.Background(), prm)
		require.ErrorIs(t, err, errMustHaveTwoShards)
		require.Equal(t, uint64(0), res.ObjectsEvacuated())

		prm.ObjectsHandler = acceptOneOf(t, objects, 2)

		res, err = e.Evacuate(context.Background(), prm)
		require.ErrorIs(t, err, errReplication)
		require.Equal(t, uint64(2), res.ObjectsEvacuated())
	})
	t.Run("multiple shards, evacuate one", func(t *testing.T) {
		t.Parallel()
		e, ids, objects := newEngineEvacuate(t, 2, 3)
		defer func() {
			require.NoError(t, e.Close(context.Background()))
		}()

		require.NoError(t, e.shards[ids[0].String()].SetMode(mode.ReadOnly))
		require.NoError(t, e.shards[ids[1].String()].SetMode(mode.ReadOnly))

		var prm EvacuateShardPrm
		prm.ShardID = ids[1:2]
		prm.ObjectsHandler = acceptOneOf(t, objects, 2)
		prm.Scope = EvacuateScopeObjects

		res, err := e.Evacuate(context.Background(), prm)
		require.ErrorIs(t, err, errReplication)
		require.Equal(t, uint64(2), res.ObjectsEvacuated())

		t.Run("no errors", func(t *testing.T) {
			prm.ObjectsHandler = acceptOneOf(t, objects, 3)

			res, err := e.Evacuate(context.Background(), prm)
			require.NoError(t, err)
			require.Equal(t, uint64(3), res.ObjectsEvacuated())
		})
	})
	t.Run("multiple shards, evacuate many", func(t *testing.T) {
		t.Parallel()
		e, ids, objects := newEngineEvacuate(t, 4, 5)
		defer func() {
			require.NoError(t, e.Close(context.Background()))
		}()

		evacuateIDs := ids[0:3]

		// Count the objects actually stored on the shards being evacuated.
		var totalCount uint64
		for _, id := range evacuateIDs {
			res, err := e.shards[id.String()].List(context.Background())
			require.NoError(t, err)

			totalCount += uint64(len(res.AddressList()))
		}

		for _, id := range ids {
			require.NoError(t, e.shards[id.String()].SetMode(mode.ReadOnly))
		}

		var prm EvacuateShardPrm
		prm.ShardID = evacuateIDs
		prm.ObjectsHandler = acceptOneOf(t, objects, totalCount-1)
		prm.Scope = EvacuateScopeObjects

		res, err := e.Evacuate(context.Background(), prm)
		require.ErrorIs(t, err, errReplication)
		require.Equal(t, totalCount-1, res.ObjectsEvacuated())

		t.Run("no errors", func(t *testing.T) {
			prm.ObjectsHandler = acceptOneOf(t, objects, totalCount)

			res, err := e.Evacuate(context.Background(), prm)
			require.NoError(t, err)
			require.Equal(t, totalCount, res.ObjectsEvacuated())
		})
	})
}

// TestEvacuateCancellation checks that Evacuate called with an already
// canceled context reports "context canceled" and evacuates nothing.
func TestEvacuateCancellation(t *testing.T) {
	t.Parallel()
	e, ids, _ := newEngineEvacuate(t, 2, 3)
	defer func() {
		require.NoError(t, e.Close(context.Background()))
	}()

	for _, id := range ids[:2] {
		require.NoError(t, e.shards[id.String()].SetMode(mode.ReadOnly))
	}

	var prm EvacuateShardPrm
	prm.ShardID = ids[1:2]
	prm.Scope = EvacuateScopeObjects
	prm.ObjectsHandler = func(ctx context.Context, _ oid.Address, _ *objectSDK.Object) error {
		// Surface cancellation if it has already happened; otherwise accept the object.
		select {
		case <-ctx.Done():
			return ctx.Err()
		default:
			return nil
		}
	}

	canceledCtx, cancel := context.WithCancel(context.Background())
	cancel()

	res, err := e.Evacuate(canceledCtx, prm)
	require.ErrorContains(t, err, "context canceled")
	require.Equal(t, uint64(0), res.ObjectsEvacuated())
}

// TestEvacuateSingleProcess checks that only one evacuation can run at a time.
// While a first evacuation is blocked inside its objects handler, a second
// Evacuate call for the same shard must fail without evacuating anything.
// After the second call returns, the first one is released and must evacuate
// all 3 objects.
func TestEvacuateSingleProcess(t *testing.T) {
	e, ids, _ := newEngineEvacuate(t, 2, 3)
	defer func() {
		require.NoError(t, e.Close(context.Background()))
	}()

	require.NoError(t, e.shards[ids[0].String()].SetMode(mode.ReadOnly))
	require.NoError(t, e.shards[ids[1].String()].SetMode(mode.ReadOnly))

	// running is closed when the handler is entered for the first time;
	// blocker keeps the handler (and so the first evacuation) parked until closed.
	blocker := make(chan struct{})
	running := make(chan struct{})

	var prm EvacuateShardPrm
	prm.ShardID = ids[1:2]
	prm.Scope = EvacuateScopeObjects
	prm.ObjectsHandler = func(ctx context.Context, a oid.Address, o *objectSDK.Object) error {
		select {
		case <-running:
		default:
			close(running)
		}
		<-blocker
		return nil
	}

	// require calls t.FailNow, which must only run on the test goroutine.
	// The goroutines below therefore report errors or results, and the
	// assertions happen after eg.Wait.
	var (
		secondErr   error
		secondCount uint64
	)

	eg, egCtx := errgroup.WithContext(context.Background())
	eg.Go(func() error {
		res, err := e.Evacuate(egCtx, prm)
		if err != nil {
			return fmt.Errorf("first evacuation failed: %w", err)
		}
		if n := res.ObjectsEvacuated(); n != 3 {
			return fmt.Errorf("first evacuation: expected 3 evacuated objects, got %d", n)
		}
		return nil
	})
	eg.Go(func() error {
		// Always release the first evacuation, even if this goroutine exits early.
		defer close(blocker)

		select {
		case <-running:
		case <-egCtx.Done():
			// The first evacuation failed before it reached the handler.
			return egCtx.Err()
		}

		res, err := e.Evacuate(egCtx, prm)
		secondErr = err
		secondCount = res.ObjectsEvacuated()
		return nil
	})
	require.NoError(t, eg.Wait())

	require.ErrorContains(t, secondErr, "evacuate is already running for shard ids", "second evacuation not failed")
	require.Equal(t, uint64(0), secondCount)
}

// TestEvacuateObjectsAsync checks the evacuation state reported by
// GetEvacuationState at four points: before an evacuation, while it is running
// (blocked in the handler), after it completes, and after
// ResetEvacuationStatus. It also checks that the status cannot be reset while
// an evacuation is running.
func TestEvacuateObjectsAsync(t *testing.T) {
	e, ids, _ := newEngineEvacuate(t, 2, 3)
	defer func() {
		require.NoError(t, e.Close(context.Background()))
	}()

	require.NoError(t, e.shards[ids[0].String()].SetMode(mode.ReadOnly))
	require.NoError(t, e.shards[ids[1].String()].SetMode(mode.ReadOnly))

	// running is closed when the handler is entered for the first time;
	// blocker keeps the evacuation parked in the handler until closed.
	blocker := make(chan struct{})
	running := make(chan struct{})

	var prm EvacuateShardPrm
	prm.ShardID = ids[1:2]
	prm.Scope = EvacuateScopeObjects
	prm.ObjectsHandler = func(ctx context.Context, a oid.Address, o *objectSDK.Object) error {
		select {
		case <-running:
		default:
			close(running)
		}
		<-blocker
		return nil
	}

	st, err := e.GetEvacuationState(context.Background())
	require.NoError(t, err, "get init state failed")
	require.Equal(t, EvacuateProcessStateUndefined, st.ProcessingStatus(), "invalid init state")
	require.Equal(t, uint64(0), st.ObjectsEvacuated(), "invalid init count")
	require.Nil(t, st.StartedAt(), "invalid init started at")
	require.Nil(t, st.FinishedAt(), "invalid init finished at")
	require.ElementsMatch(t, []string{}, st.ShardIDs(), "invalid init shard ids")
	require.Equal(t, "", st.ErrorMessage(), "invalid init error message")

	eg, egCtx := errgroup.WithContext(context.Background())
	eg.Go(func() error {
		// No require here: t.FailNow must only run on the test goroutine.
		// Failures are surfaced through eg.Wait instead.
		res, err := e.Evacuate(egCtx, prm)
		if err != nil {
			return fmt.Errorf("first evacuation failed: %w", err)
		}
		if n := res.ObjectsEvacuated(); n != 3 {
			return fmt.Errorf("first evacuation: expected 3 evacuated objects, got %d", n)
		}
		return nil
	})

	select {
	case <-running:
	case <-egCtx.Done():
		// The evacuation failed before reaching the handler; report why instead of hanging.
		require.NoError(t, eg.Wait(), "evacuation finished before the handler was called")
	}

	st, err = e.GetEvacuationState(context.Background())
	require.NoError(t, err, "get running state failed")
	require.Equal(t, EvacuateProcessStateRunning, st.ProcessingStatus(), "invalid running state")
	require.Equal(t, uint64(0), st.ObjectsEvacuated(), "invalid running count")
	require.NotNil(t, st.StartedAt(), "invalid running started at")
	require.Nil(t, st.FinishedAt(), "invalid running finished at")
	expectedShardIDs := make([]string, 0, len(ids[1:2]))
	for _, id := range ids[1:2] {
		expectedShardIDs = append(expectedShardIDs, id.String())
	}
	require.ElementsMatch(t, expectedShardIDs, st.ShardIDs(), "invalid running shard ids")
	require.Equal(t, "", st.ErrorMessage(), "invalid running error message")

	require.Error(t, e.ResetEvacuationStatus(context.Background()))

	close(blocker)

	require.Eventually(t, func() bool {
		st, err = e.GetEvacuationState(context.Background())
		// st may be unusable when err is set, so check err first.
		return err == nil && st.ProcessingStatus() == EvacuateProcessStateCompleted
	}, 3*time.Second, 10*time.Millisecond, "invalid final state")

	require.NoError(t, err, "get final state failed")
	require.Equal(t, uint64(3), st.ObjectsEvacuated(), "invalid final count")
	require.NotNil(t, st.StartedAt(), "invalid final started at")
	require.NotNil(t, st.FinishedAt(), "invalid final finished at")
	require.ElementsMatch(t, expectedShardIDs, st.ShardIDs(), "invalid final shard ids")
	require.Equal(t, "", st.ErrorMessage(), "invalid final error message")

	require.NoError(t, eg.Wait())

	require.NoError(t, e.ResetEvacuationStatus(context.Background()))
	st, err = e.GetEvacuationState(context.Background())
	require.NoError(t, err, "get state after reset failed")
	require.Equal(t, EvacuateProcessStateUndefined, st.ProcessingStatus(), "invalid state after reset")
	require.Equal(t, uint64(0), st.ObjectsEvacuated(), "invalid count after reset")
	require.Nil(t, st.StartedAt(), "invalid started at after reset")
	require.Nil(t, st.FinishedAt(), "invalid finished at after reset")
	require.ElementsMatch(t, []string{}, st.ShardIDs(), "invalid shard ids after reset")
	require.Equal(t, "", st.ErrorMessage(), "invalid error message after reset")
}

// TestEvacuateTreesLocal evacuates the pilorama trees of a read-only shard to
// the other local shard of a two-shard engine. It checks the reported
// evacuation state, then verifies that every source tree exists on the target
// shard with an identical operation log.
func TestEvacuateTreesLocal(t *testing.T) {
	e, ids, _ := newEngineEvacuate(t, 2, 3)
	defer func() {
		require.NoError(t, e.Close(context.Background()))
	}()

	sourceID, targetID := ids[0].String(), ids[1].String()

	require.NoError(t, e.shards[sourceID].SetMode(mode.ReadOnly))

	var prm EvacuateShardPrm
	prm.ShardID = ids[0:1]
	prm.Scope = EvacuateScopeTrees

	expectedShardIDs := []string{sourceID}

	st, err := e.GetEvacuationState(context.Background())
	require.NoError(t, err, "get init state failed")
	require.Equal(t, EvacuateProcessStateUndefined, st.ProcessingStatus(), "invalid init state")
	require.Equal(t, uint64(0), st.TreesEvacuated(), "invalid init count")
	require.Nil(t, st.StartedAt(), "invalid init started at")
	require.Nil(t, st.FinishedAt(), "invalid init finished at")
	require.ElementsMatch(t, []string{}, st.ShardIDs(), "invalid init shard ids")
	require.Equal(t, "", st.ErrorMessage(), "invalid init error message")

	// Check the error first so a failed evacuation reports its actual cause.
	res, err := e.Evacuate(context.Background(), prm)
	require.NoError(t, err, "evacuation failed")
	require.NotNil(t, res, "sync evacuation result must be not nil")

	st, err = e.GetEvacuationState(context.Background())
	require.NoError(t, err, "get evacuation state failed")
	require.Equal(t, EvacuateProcessStateCompleted, st.ProcessingStatus())

	require.Equal(t, uint64(3), st.TreesTotal(), "invalid trees total count")
	require.Equal(t, uint64(3), st.TreesEvacuated(), "invalid trees evacuated count")
	require.Equal(t, uint64(0), st.TreesFailed(), "invalid trees failed count")
	require.NotNil(t, st.StartedAt(), "invalid final started at")
	require.NotNil(t, st.FinishedAt(), "invalid final finished at")
	require.ElementsMatch(t, expectedShardIDs, st.ShardIDs(), "invalid final shard ids")
	require.Equal(t, "", st.ErrorMessage(), "invalid final error message")

	// readOpLog collects the operation log of a tree on the given shard by
	// walking TreeGetOpLog from height 0. It stops at the first op with Time == 0.
	readOpLog := func(shardID string, contID cid.ID, treeID string) []pilorama.Move {
		var (
			ops    []pilorama.Move
			height uint64
		)
		for {
			op, err := e.shards[shardID].TreeGetOpLog(context.Background(), contID, treeID, height)
			require.NoError(t, err)
			if op.Time == 0 {
				return ops
			}
			ops = append(ops, op)
			height = op.Time + 1
		}
	}

	sourceTrees, err := pilorama.TreeListAll(context.Background(), e.shards[sourceID])
	require.NoError(t, err, "list source trees failed")
	require.Len(t, sourceTrees, 3)

	for _, tr := range sourceTrees {
		exists, err := e.shards[targetID].TreeExists(context.Background(), tr.CID, tr.TreeID)
		require.NoError(t, err, "failed to check tree existance")
		require.True(t, exists, "tree doesn't exists on target shard")

		require.Equal(t,
			readOpLog(sourceID, tr.CID, tr.TreeID),
			readOpLog(targetID, tr.CID, tr.TreeID))
	}
}

// TestEvacuateTreesRemote evacuates the trees of all (read-only) shards through
// a custom TreeHandler. The handler records every tree's operation log instead
// of sending it anywhere. The test checks the reported evacuation state, then
// verifies that the recorded logs match the logs stored on the shards.
func TestEvacuateTreesRemote(t *testing.T) {
	e, ids, _ := newEngineEvacuate(t, 2, 3)
	defer func() {
		require.NoError(t, e.Close(context.Background()))
	}()

	require.NoError(t, e.shards[ids[0].String()].SetMode(mode.ReadOnly))
	require.NoError(t, e.shards[ids[1].String()].SetMode(mode.ReadOnly))

	// Keyed by container ID string concatenated with tree ID.
	evacuatedTreeOps := make(map[string][]*pilorama.Move)
	var prm EvacuateShardPrm
	prm.ShardID = ids
	prm.Scope = EvacuateScopeTrees
	prm.TreeHandler = func(ctx context.Context, contID cid.ID, treeID string, f pilorama.Forest) (string, error) {
		key := contID.String() + treeID
		var height uint64
		for {
			op, err := f.TreeGetOpLog(ctx, contID, treeID, height)
			if err != nil {
				// Report through the handler contract rather than require: the handler
				// may not run on the test goroutine. A failure shows up in the
				// evacuation result/state asserted below.
				return "", fmt.Errorf("read op log of tree %s: %w", key, err)
			}

			if op.Time == 0 {
				return "", nil
			}
			evacuatedTreeOps[key] = append(evacuatedTreeOps[key], &op)
			height = op.Time + 1
		}
	}

	expectedShardIDs := make([]string, 0, len(ids))
	for _, id := range ids {
		expectedShardIDs = append(expectedShardIDs, id.String())
	}

	st, err := e.GetEvacuationState(context.Background())
	require.NoError(t, err, "get init state failed")
	require.Equal(t, EvacuateProcessStateUndefined, st.ProcessingStatus(), "invalid init state")
	require.Equal(t, uint64(0), st.TreesEvacuated(), "invalid init count")
	require.Nil(t, st.StartedAt(), "invalid init started at")
	require.Nil(t, st.FinishedAt(), "invalid init finished at")
	require.ElementsMatch(t, []string{}, st.ShardIDs(), "invalid init shard ids")
	require.Equal(t, "", st.ErrorMessage(), "invalid init error message")

	// Check the error first so a failed evacuation reports its actual cause.
	res, err := e.Evacuate(context.Background(), prm)
	require.NoError(t, err, "evacuation failed")
	require.NotNil(t, res, "sync evacuation must return not nil")

	st, err = e.GetEvacuationState(context.Background())
	require.NoError(t, err, "get final state failed")
	require.Equal(t, EvacuateProcessStateCompleted, st.ProcessingStatus())

	require.Equal(t, uint64(6), st.TreesTotal(), "invalid trees total count")
	require.Equal(t, uint64(6), st.TreesEvacuated(), "invalid trees evacuated count")
	require.Equal(t, uint64(0), st.TreesFailed(), "invalid trees failed count")
	require.NotNil(t, st.StartedAt(), "invalid final started at")
	require.NotNil(t, st.FinishedAt(), "invalid final finished at")
	require.ElementsMatch(t, expectedShardIDs, st.ShardIDs(), "invalid final shard ids")
	require.Equal(t, "", st.ErrorMessage(), "invalid final error message")

	expectedTreeOps := make(map[string][]*pilorama.Move)
	for _, id := range ids {
		sh := id.String()
		sourceTrees, err := pilorama.TreeListAll(context.Background(), e.shards[sh])
		require.NoError(t, err, "list source trees failed")
		require.Len(t, sourceTrees, 3)

		for _, tr := range sourceTrees {
			key := tr.CID.String() + tr.TreeID
			var height uint64
			for {
				op, err := e.shards[sh].TreeGetOpLog(context.Background(), tr.CID, tr.TreeID, height)
				require.NoError(t, err)

				if op.Time == 0 {
					break
				}
				expectedTreeOps[key] = append(expectedTreeOps[key], &op)
				height = op.Time + 1
			}
		}
	}

	require.Equal(t, expectedTreeOps, evacuatedTreeOps)
}