package engine

import (
	"context"
	"sync"
	"testing"

	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/core/object"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/local_object_storage/blobstor/common"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/local_object_storage/blobstor/teststore"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/local_object_storage/internal/testutil"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/local_object_storage/shard"
	cidtest "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/container/id/test"
	objectSDK "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object"
	oid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object/id"
	"github.com/stretchr/testify/require"
)

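// TestRebalance stores extra copies of some objects on both shards and checks
// that RemoveDuplicates deletes exactly those duplicates, and only from the
// less-preferred (worst) shard for each object's address.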
func TestRebalance(t *testing.T) {
	t.Parallel()

	te := newEngineWithErrorThreshold(t, "", 0)
	defer func() {
		require.NoError(t, te.ng.Close(context.Background()))
	}()

	const (
		objCount  = 20
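		// Objects with index i%3 == 0 are stored on both shards below, so
		// ceil(objCount/3) duplicate copies are expected to be removed.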
		copyCount = (objCount + 2) / 3
	)

	type objectWithShard struct {
		bestShard  shard.ID // shard ranked first for the object's address
		worstShard shard.ID // shard ranked second; the duplicate must be removed from here
		object     *objectSDK.Object
	}

	objects := make([]objectWithShard, objCount)
	for i := range objects {
		obj := testutil.GenerateObjectWithCID(cidtest.ID())
		obj.SetPayload(make([]byte, errSmallSize))
		objects[i].object = obj

		// Rank the shards for this address: shards[0] is the preferred (best) shard,
		// shards[1] is the least preferred one.
		shards := te.ng.sortShards(object.AddressOf(obj))
		objects[i].bestShard = *shards[0].Shard.ID()
		objects[i].worstShard = *shards[1].Shard.ID()
	}

	for i := range objects {
		var prm shard.PutPrm
		prm.SetObject(objects[i].object)

		var err1, err2 error
		te.ng.mtx.RLock()
		// Objects with i%3 == 0 are put to both shards and become duplicates;
		// the remaining objects are put to exactly one of the two shards.
		if i%3 != 1 {
			_, err1 = te.ng.shards[te.shards[0].id.String()].Shard.Put(context.Background(), prm)
		}
		if i%3 != 2 {
			_, err2 = te.ng.shards[te.shards[1].id.String()].Shard.Put(context.Background(), prm)
		}
		te.ng.mtx.RUnlock()

		require.NoError(t, err1)
		require.NoError(t, err2)
	}

	// Intercept blobstor deletions to record which shard each duplicate is removed from.
	var removedMtx sync.Mutex
	var removed []deleteEvent
	for _, sh := range te.shards {
		id := *sh.id
		sh.largeFileStorage.SetOption(teststore.WithDelete(func(prm common.DeletePrm) (common.DeleteRes, error) {
			removedMtx.Lock()
			removed = append(removed, deleteEvent{shardID: id, addr: prm.Address})
			removedMtx.Unlock()
			return common.DeleteRes{}, nil
		}))
	}

	err := te.ng.RemoveDuplicates(context.Background(), RemoveDuplicatesPrm{})
	require.NoError(t, err)

	require.Equal(t, copyCount, len(removed))

	removedMask := make([]bool, len(objects))
loop:
	for i := range removed {
		for j := range objects {
			if removed[i].addr == object.AddressOf(objects[j].object) {
				require.Equal(t, objects[j].worstShard, removed[i].shardID,
					"object %d was expected to be removed from another shard", j)
				removedMask[j] = true
				continue loop
			}
		}
		require.FailNow(t, "unexpected object was removed", removed[i].addr)
	}

	// Exactly the duplicated objects (i%3 == 0) must have been removed.
	for i := range objects {
		if i%3 == 0 {
			require.True(t, removedMask[i], "object %d was expected to be removed", i)
		} else {
			require.False(t, removedMask[i], "object %d was not expected to be removed", i)
		}
	}
}

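// TestRebalanceSingleThread checks that only one RemoveDuplicates call can run
// at a time: a concurrent call must fail with errRemoveDuplicatesInProgress.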
func TestRebalanceSingleThread(t *testing.T) {
	t.Parallel()

	te := newEngineWithErrorThreshold(t, "", 0)
	defer func() {
		require.NoError(t, te.ng.Close(context.Background()))
	}()

	obj := testutil.GenerateObjectWithCID(cidtest.ID())
	obj.SetPayload(make([]byte, errSmallSize))

	var prm shard.PutPrm
	prm.SetObject(obj)
	te.ng.mtx.RLock()
	_, err1 := te.ng.shards[te.shards[0].id.String()].Shard.Put(context.Background(), prm)
	_, err2 := te.ng.shards[te.shards[1].id.String()].Shard.Put(context.Background(), prm)
	te.ng.mtx.RUnlock()
	require.NoError(t, err1)
	require.NoError(t, err2)

	signal := make(chan struct{})  // unblock rebalance
	started := make(chan struct{}) // make sure rebalance is started
	// Only one duplicate exists, so the hook fires exactly once: it reports that
	// rebalance has started and then blocks until the test releases it.
	for _, sh := range te.shards {
		sh.largeFileStorage.SetOption(teststore.WithDelete(func(common.DeletePrm) (common.DeleteRes, error) {
			close(started)
			<-signal
			return common.DeleteRes{}, nil
		}))
	}

	var firstErr error
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		firstErr = te.ng.RemoveDuplicates(context.Background(), RemoveDuplicatesPrm{})
	}()

	<-started
	// While the first call is still in flight, a concurrent call must be rejected.
	secondErr := te.ng.RemoveDuplicates(context.Background(), RemoveDuplicatesPrm{})
	require.ErrorIs(t, secondErr, errRemoveDuplicatesInProgress)

	close(signal)
	wg.Wait()
	require.NoError(t, firstErr)
}

// deleteEvent records a single deletion observed by the blobstor test hooks:
// the shard it happened on and the address of the removed object.
type deleteEvent struct {
	shardID shard.ID
	addr    oid.Address
}

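// TestRebalanceExitByContext checks that RemoveDuplicates stops with
// context.Canceled once its context is cancelled, leaving the remaining
// duplicates in place.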
func TestRebalanceExitByContext(t *testing.T) {
	te := newEngineWithErrorThreshold(t, "", 0)
	defer func() {
		require.NoError(t, te.ng.Close(context.Background()))
	}()

	objects := make([]*objectSDK.Object, 4)
	for i := range objects {
		obj := testutil.GenerateObjectWithCID(cidtest.ID())
		obj.SetPayload(make([]byte, errSmallSize))
		objects[i] = obj
	}

	for i := range objects {
		var prm shard.PutPrm
		prm.SetObject(objects[i])

		te.ng.mtx.RLock()
		_, err1 := te.ng.shards[te.shards[0].id.String()].Shard.Put(context.Background(), prm)
		_, err2 := te.ng.shards[te.shards[1].id.String()].Shard.Put(context.Background(), prm)
		te.ng.mtx.RUnlock()

		require.NoError(t, err1)
		require.NoError(t, err2)
	}

	var removed []deleteEvent
	deleteCh := make(chan struct{})
	signal := make(chan struct{})
	// Each deletion first notifies the test via deleteCh and then blocks on signal,
	// so the test controls exactly when the context gets cancelled.
	for _, sh := range te.shards {
		id := *sh.id
		sh.largeFileStorage.SetOption(teststore.WithDelete(func(prm common.DeletePrm) (common.DeleteRes, error) {
			deleteCh <- struct{}{}
			<-signal
			removed = append(removed, deleteEvent{shardID: id, addr: prm.Address})
			return common.DeleteRes{}, nil
		}))
	}

	ctx, cancel := context.WithCancel(context.Background())

	var rebalanceErr error
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		rebalanceErr = te.ng.RemoveDuplicates(ctx, RemoveDuplicatesPrm{Concurrency: 1})
	}()

	const removeCount = 3
	// Let removeCount-1 deletions finish, then cancel the context while the next
	// one is blocked inside the hook. Closing signal releases that in-flight
	// deletion, so exactly removeCount copies are removed before rebalance stops.
	for range removeCount - 1 {
		<-deleteCh
		signal <- struct{}{}
	}
	<-deleteCh
	cancel()
	close(signal)

	wg.Wait()
	require.ErrorIs(t, rebalanceErr, context.Canceled)
	require.Equal(t, removeCount, len(removed))
}