package engine

import (
	"context"
	"fmt"
	"os"
	"path/filepath"
	"strconv"
	"testing"
	"time"

	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/core/object"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/local_object_storage/blobstor"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/local_object_storage/blobstor/common"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/local_object_storage/blobstor/teststore"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/local_object_storage/internal/testutil"
	meta "git.frostfs.info/TrueCloudLab/frostfs-node/pkg/local_object_storage/metabase"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/local_object_storage/pilorama"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/local_object_storage/shard"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/local_object_storage/shard/mode"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/util/logger/test"
	cidtest "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/container/id/test"
	objectSDK "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object"
	"github.com/stretchr/testify/require"
)

// errSmallSize is the size threshold handed to newTestStorages when building
// each shard's blobstor (256 bytes). Tests also use it to size object payloads.
const errSmallSize = 1 << 8

// testEngine bundles a StorageEngine built for error-reporting tests with the
// directory holding its on-disk data and per-shard test handles.
type testEngine struct {
	// ng is the engine under test.
	ng     *StorageEngine
	// dir is the root directory for the shards' blobstor, metabase and pilorama files.
	dir    string
	// shards holds, for each configured shard, its ID and its test stores.
	shards [2]*testShard
}

// testShard pairs a shard's ID with the teststore instances backing its
// blobstor, so tests can inject failures into either sub-storage via SetOption.
type testShard struct {
	// id is the shard identifier reported by the engine after preparation.
	id               *shard.ID
	// smallFileStorage is the teststore returned as the small-object storage
	// by newTestStorages.
	smallFileStorage *teststore.TestStore
	// largeFileStorage is the teststore returned as the large-object storage
	// by newTestStorages.
	largeFileStorage *teststore.TestStore
}

// newEngineWithErrorThreshold builds and prepares a test engine with two
// shards. Each shard's blobstor is made of teststore instances, so tests can
// inject failures into the small- and large-object storages.
//
// If dir is empty, a fresh temporary directory is used. errThreshold is passed
// to the engine via WithErrorThreshold. The returned testEngine exposes the
// shard IDs together with each shard's test stores.
func newEngineWithErrorThreshold(t testing.TB, dir string, errThreshold uint32) *testEngine {
	t.Helper()

	if dir == "" {
		dir = t.TempDir()
	}

	var testShards [2]*testShard

	// The shard count is taken from the array length (a compile-time constant)
	// so the number of configured shards can never exceed the slots that the
	// option callback below writes into.
	te := testNewEngine(t,
		WithShardPoolSize(1),
		WithErrorThreshold(errThreshold),
	).
		setShardsNumOpts(t, len(testShards), func(id int) []shard.Option {
			storages, smallFileStorage, largeFileStorage := newTestStorages(filepath.Join(dir, strconv.Itoa(id)), errSmallSize)
			testShards[id] = &testShard{
				smallFileStorage: smallFileStorage,
				largeFileStorage: largeFileStorage,
			}
			return []shard.Option{
				shard.WithLogger(test.NewLogger(t)),
				shard.WithBlobStorOptions(blobstor.WithStorages(storages)),
				shard.WithMetaBaseOptions(
					meta.WithPath(filepath.Join(dir, fmt.Sprintf("%d.metabase", id))),
					meta.WithPermissions(0o700),
					meta.WithEpochState(epochState{}),
				),
				shard.WithPiloramaOptions(
					pilorama.WithPath(filepath.Join(dir, fmt.Sprintf("%d.pilorama", id))),
					pilorama.WithPerm(0o700)),
			}
		}).prepare(t)

	// NOTE(review): assumes te.shardIDs[i] belongs to the shard whose options
	// were built with id == i — confirm against setShardsNumOpts/prepare.
	for i, id := range te.shardIDs {
		testShards[i].id = id
	}

	return &testEngine{
		ng:     te.engine,
		dir:    dir,
		shards: testShards,
	}
}

// TestErrorReporting checks how the engine accounts for per-shard errors.
//
// With a zero error threshold, failed reads increment the shard's error
// counter but the shard stays in read-write mode. With a positive threshold,
// the shard switches to read-only once the counter reaches the threshold.
// Changing the mode back keeps the counter unless SetShardMode is called with
// its last argument set to true.
func TestErrorReporting(t *testing.T) {
	t.Run("ignore errors by default", func(t *testing.T) {
		te := newEngineWithErrorThreshold(t, "", 0)

		// NOTE(review): a payload of errSmallSize bytes presumably makes the whole
		// object exceed the small-size limit, so it lands in largeFileStorage.
		obj := testutil.GenerateObjectWithCID(cidtest.ID())
		obj.SetPayload(make([]byte, errSmallSize))

		// Write straight into the first shard so only that shard holds the object.
		var prm shard.PutPrm
		prm.SetObject(obj)
		te.ng.mtx.RLock()
		_, err := te.ng.shards[te.shards[0].id.String()].Shard.Put(context.Background(), prm)
		te.ng.mtx.RUnlock()
		require.NoError(t, err)

		_, err = te.ng.Get(context.Background(), GetPrm{addr: object.AddressOf(obj)})
		require.NoError(t, err)

		checkShardState(t, te.ng, te.shards[0].id, 0, mode.ReadWrite)
		checkShardState(t, te.ng, te.shards[1].id, 0, mode.ReadWrite)

		// Make every read from the large-object storage of each shard fail.
		for _, sh := range te.shards {
			sh.largeFileStorage.SetOption(teststore.WithGet(func(common.GetPrm) (common.GetRes, error) {
				return common.GetRes{}, teststore.ErrDiskExploded
			}))
		}

		// Each failed read bumps the first shard's counter, but with no threshold
		// configured the shard stays read-write.
		for i := uint32(1); i < 3; i++ {
			_, err = te.ng.Get(context.Background(), GetPrm{addr: object.AddressOf(obj)})
			require.Error(t, err)
			checkShardState(t, te.ng, te.shards[0].id, i, mode.ReadWrite)
			checkShardState(t, te.ng, te.shards[1].id, 0, mode.ReadWrite)
		}
		require.NoError(t, te.ng.Close(context.Background()))
	})
	t.Run("with error threshold", func(t *testing.T) {
		const errThreshold = 3

		te := newEngineWithErrorThreshold(t, "", errThreshold)

		obj := testutil.GenerateObjectWithCID(cidtest.ID())
		obj.SetPayload(make([]byte, errSmallSize))

		// Write straight into the first shard so only that shard holds the object.
		var prm shard.PutPrm
		prm.SetObject(obj)
		te.ng.mtx.RLock()
		_, err := te.ng.shards[te.shards[0].id.String()].Shard.Put(context.Background(), prm)
		te.ng.mtx.RUnlock()
		require.NoError(t, err)

		_, err = te.ng.Get(context.Background(), GetPrm{addr: object.AddressOf(obj)})
		require.NoError(t, err)

		checkShardState(t, te.ng, te.shards[0].id, 0, mode.ReadWrite)
		checkShardState(t, te.ng, te.shards[1].id, 0, mode.ReadWrite)

		// Make every read from the large-object storage of each shard fail.
		for _, sh := range te.shards {
			sh.largeFileStorage.SetOption(teststore.WithGet(func(common.GetPrm) (common.GetRes, error) {
				return common.GetRes{}, teststore.ErrDiskExploded
			}))
		}

		// Below the threshold the counter grows while the shard stays read-write.
		for i := uint32(1); i < errThreshold; i++ {
			_, err = te.ng.Get(context.Background(), GetPrm{addr: object.AddressOf(obj)})
			require.Error(t, err)
			checkShardState(t, te.ng, te.shards[0].id, i, mode.ReadWrite)
			checkShardState(t, te.ng, te.shards[1].id, 0, mode.ReadWrite)
		}

		// Reaching the threshold moves the shard to read-only; further failures
		// keep incrementing the counter.
		for i := range uint32(2) {
			_, err = te.ng.Get(context.Background(), GetPrm{addr: object.AddressOf(obj)})
			require.Error(t, err)
			checkShardState(t, te.ng, te.shards[0].id, errThreshold+i, mode.ReadOnly)
			checkShardState(t, te.ng, te.shards[1].id, 0, mode.ReadWrite)
		}

		// Restoring read-write mode without the reset flag keeps the counter...
		require.NoError(t, te.ng.SetShardMode(context.Background(), te.shards[0].id, mode.ReadWrite, false))
		checkShardState(t, te.ng, te.shards[0].id, errThreshold+1, mode.ReadWrite)

		// ...while passing true as the last argument also zeroes it.
		require.NoError(t, te.ng.SetShardMode(context.Background(), te.shards[0].id, mode.ReadWrite, true))
		checkShardState(t, te.ng, te.shards[0].id, 0, mode.ReadWrite)
		require.NoError(t, te.ng.Close(context.Background()))
	})
}

// TestBlobstorFailback stores a small and a large object in the first shard.
// It then swaps the second blobstor sub-storage directory (SubStorages[1])
// between the two shards and reopens the engine. The test checks that both
// objects are still readable in full and by range, and that no shard records
// any errors, even with an error threshold of 1.
func TestBlobstorFailback(t *testing.T) {
	dir := t.TempDir()

	te := newEngineWithErrorThreshold(t, dir, 1)

	objs := make([]*objectSDK.Object, 0, 2)
	for _, size := range []int{15, errSmallSize + 1} {
		obj := testutil.GenerateObjectWithCID(cidtest.ID())
		obj.SetPayload(make([]byte, size))

		var prm shard.PutPrm
		prm.SetObject(obj)
		te.ng.mtx.RLock()
		_, err := te.ng.shards[te.shards[0].id.String()].Shard.Put(context.Background(), prm)
		te.ng.mtx.RUnlock()
		require.NoError(t, err)
		objs = append(objs, obj)
	}

	for i := range objs {
		addr := object.AddressOf(objs[i])
		_, err := te.ng.Get(context.Background(), GetPrm{addr: addr})
		require.NoError(t, err)
		_, err = te.ng.GetRange(context.Background(), RngPrm{addr: addr})
		require.NoError(t, err)
	}

	checkShardState(t, te.ng, te.shards[0].id, 0, mode.ReadWrite)
	require.NoError(t, te.ng.Close(context.Background()))

	// Read the sub-storage paths under the engine lock, like every other access
	// to the shards map in this file.
	// NOTE(review): SubStorages[1] is presumably the large-object storage — confirm
	// against newTestStorages.
	te.ng.mtx.RLock()
	p1 := te.ng.shards[te.shards[0].id.String()].Shard.DumpInfo().BlobStorInfo.SubStorages[1].Path
	p2 := te.ng.shards[te.shards[1].id.String()].Shard.DumpInfo().BlobStorInfo.SubStorages[1].Path
	te.ng.mtx.RUnlock()

	// Swap the two directories through a temporary name.
	tmp := filepath.Join(dir, "tmp")
	require.NoError(t, os.Rename(p1, tmp))
	require.NoError(t, os.Rename(p2, p1))
	require.NoError(t, os.Rename(tmp, p2))

	te = newEngineWithErrorThreshold(t, dir, 1)

	for i := range objs {
		addr := object.AddressOf(objs[i])
		getRes, err := te.ng.Get(context.Background(), GetPrm{addr: addr})
		require.NoError(t, err)
		require.Equal(t, objs[i], getRes.Object())

		rngRes, err := te.ng.GetRange(context.Background(), RngPrm{addr: addr, off: 1, ln: 10})
		require.NoError(t, err)
		require.Equal(t, objs[i].Payload()[1:11], rngRes.Object().Payload())

		// The offset lies past the end of both payloads (15 and errSmallSize+1 bytes).
		_, err = te.ng.GetRange(context.Background(), RngPrm{addr: addr, off: errSmallSize + 10, ln: 1})
		require.True(t, shard.IsErrOutOfRange(err), "expected out-of-range error, got: %v", err)
	}

	checkShardState(t, te.ng, te.shards[0].id, 0, mode.ReadWrite)
	checkShardState(t, te.ng, te.shards[1].id, 0, mode.ReadWrite)
	require.NoError(t, te.ng.Close(context.Background()))
}

// checkShardState waits up to 10 seconds, polling every 10ms, until the shard
// with the given id has exactly errCount recorded errors and is in
// expectedMode. It fails the test if the shard is unknown to the engine or
// never reaches that state.
func checkShardState(t *testing.T, e *StorageEngine, id *shard.ID, errCount uint32, expectedMode mode.Mode) {
	t.Helper()

	e.mtx.RLock()
	sh, ok := e.shards[id.String()]
	e.mtx.RUnlock()
	// Fail fast on an unknown ID. Otherwise an empty map entry would be
	// dereferenced inside the polling goroutine.
	require.True(t, ok, "shard %s is not registered in the engine", id)

	require.Eventually(t, func() bool {
		return errCount == sh.errorCount.Load() &&
			expectedMode == sh.GetMode()
	}, 10*time.Second, 10*time.Millisecond,
		"shard %s didn't reach the expected state (errors: %d, mode: %v) in 10 seconds", id, errCount, expectedMode)
}