package shard

import (
	"context"
	"path/filepath"
	"sync"
	"testing"
	"time"

	objectcore "git.frostfs.info/TrueCloudLab/frostfs-node/pkg/core/object"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/local_object_storage/blobstor"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/local_object_storage/blobstor/fstree"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/local_object_storage/internal/testutil"
	meta "git.frostfs.info/TrueCloudLab/frostfs-node/pkg/local_object_storage/metabase"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/local_object_storage/pilorama"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/local_object_storage/shard/mode"
	cid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/container/id"
	objectSDK "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object"
	oid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object/id"
	"github.com/stretchr/testify/require"
)

// metricsStore is an in-memory metrics writer used by the tests in this file
// to observe what the shard reports. Its methods access the fields while
// holding mtx.
type metricsStore struct {
	mtx sync.Mutex

	// Object counters keyed by object type.
	objCounters map[string]uint64
	// Payload size per container, keyed by the container ID string.
	cnrSize map[string]int64
	// Object count per container, keyed by container ID concatenated with object type.
	cnrCount map[string]uint64

	pldSize    int64     // accumulated payload size
	mode       mode.Mode // last mode passed to SetMode
	errCounter int64     // number of IncErrorCounter calls since the last reset
}

// SetShardID is a no-op: the shard ID is not recorded by this test writer.
func (*metricsStore) SetShardID(string) {}

// SetObjectCounter overwrites the counter for objectType with value.
func (m *metricsStore) SetObjectCounter(objectType string, value uint64) {
	m.mtx.Lock()
	defer m.mtx.Unlock()

	m.objCounters[objectType] = value
}

// getObjectCounter returns the current counter for objectType (zero if unset).
func (m *metricsStore) getObjectCounter(objectType string) uint64 {
	m.mtx.Lock()
	defer m.mtx.Unlock()

	counter := m.objCounters[objectType]
	return counter
}

// containerSizes returns a copy of the per-container size map, so callers can
// inspect it without holding the lock.
func (m *metricsStore) containerSizes() map[string]int64 {
	m.mtx.Lock()
	defer m.mtx.Unlock()

	snapshot := make(map[string]int64, len(m.cnrSize))
	for cnr := range m.cnrSize {
		snapshot[cnr] = m.cnrSize[cnr]
	}
	return snapshot
}

// payloadSize returns the accumulated payload size.
func (m *metricsStore) payloadSize() int64 {
	m.mtx.Lock()
	defer m.mtx.Unlock()

	total := m.pldSize
	return total
}

// AddToObjectCounter adds delta (which may be negative) to the counter for
// objectType. A negative delta larger than the current value clamps the
// counter to zero instead of wrapping the unsigned value. A zero delta
// leaves the map untouched.
func (m *metricsStore) AddToObjectCounter(objectType string, delta int) {
	m.mtx.Lock()
	defer m.mtx.Unlock()

	if delta == 0 {
		return
	}
	if delta > 0 {
		m.objCounters[objectType] += uint64(delta)
		return
	}

	dec := uint64(-delta)
	cur := m.objCounters[objectType]
	if cur < dec {
		cur = 0
	} else {
		cur -= dec
	}
	m.objCounters[objectType] = cur
}

// IncObjectCounter increments the counter for objectType by one.
func (m *metricsStore) IncObjectCounter(objectType string) {
	m.mtx.Lock()
	defer m.mtx.Unlock()

	m.objCounters[objectType]++
}

// DecObjectCounter decrements the counter for objectType by one, clamping at
// zero.
//
// It must not take m.mtx itself: AddToObjectCounter acquires the same
// sync.Mutex, which is not reentrant. Locking here as well would deadlock on
// every call.
func (m *metricsStore) DecObjectCounter(objectType string) {
	m.AddToObjectCounter(objectType, -1)
}

// SetMode records newMode as the shard's current mode.
func (m *metricsStore) SetMode(newMode mode.Mode) {
	m.mtx.Lock()
	defer m.mtx.Unlock()

	m.mode = newMode
}

// AddToContainerSize adds delta (which may be negative) to the size tracked
// for container cnr.
func (m *metricsStore) AddToContainerSize(cnr string, delta int64) {
	m.mtx.Lock()
	defer m.mtx.Unlock()

	m.cnrSize[cnr] = m.cnrSize[cnr] + delta
}

// AddToPayloadSize adds delta (which may be negative) to the total payload size.
func (m *metricsStore) AddToPayloadSize(delta int64) {
	m.mtx.Lock()
	defer m.mtx.Unlock()

	m.pldSize = m.pldSize + delta
}

// IncErrorCounter bumps the error counter by one.
func (m *metricsStore) IncErrorCounter() {
	m.mtx.Lock()
	defer m.mtx.Unlock()

	m.errCounter++
}

// ClearErrorCounter resets the error counter to zero.
func (m *metricsStore) ClearErrorCounter() {
	m.mtx.Lock()
	defer m.mtx.Unlock()

	var zero int64
	m.errCounter = zero
}

// DeleteShardMetrics resets only the error counter. Object counters, sizes
// and the recorded mode are left as they are.
func (m *metricsStore) DeleteShardMetrics() {
	m.mtx.Lock()
	defer m.mtx.Unlock()

	var zero int64
	m.errCounter = zero
}

// SetContainerObjectsCount overwrites the object count stored for the
// (cnrID, objectType) pair. The pair is keyed by simple concatenation.
func (m *metricsStore) SetContainerObjectsCount(cnrID string, objectType string, value uint64) {
	key := cnrID + objectType

	m.mtx.Lock()
	defer m.mtx.Unlock()

	m.cnrCount[key] = value
}

// IncContainerObjectsCount increments the object count for the
// (cnrID, objectType) pair by one.
func (m *metricsStore) IncContainerObjectsCount(cnrID string, objectType string) {
	key := cnrID + objectType

	m.mtx.Lock()
	defer m.mtx.Unlock()

	m.cnrCount[key] += 1
}

// SubContainerObjectsCount subtracts value from the object count for the
// (cnrID, objectType) pair. When the count reaches exactly zero, the entry is
// removed so that absence can be checked with the comma-ok form. It panics
// if the stored count is smaller than value, which signals a bookkeeping bug
// in the caller.
func (m *metricsStore) SubContainerObjectsCount(cnrID string, objectType string, value uint64) {
	m.mtx.Lock()
	defer m.mtx.Unlock()

	key := cnrID + objectType
	current := m.cnrCount[key]

	switch {
	case current < value:
		panic("existed value smaller than value to sustract")
	case current == value:
		delete(m.cnrCount, key)
	default:
		m.cnrCount[key] = current - value
	}
}

// getContainerCount returns the object count for the (cnrID, objectType) pair
// and whether an entry exists for it.
func (m *metricsStore) getContainerCount(cnrID, objectType string) (uint64, bool) {
	key := cnrID + objectType

	m.mtx.Lock()
	defer m.mtx.Unlock()

	count, found := m.cnrCount[key]
	return count, found
}

// TestCounters checks the values the shard reports while objects are Put,
// inhumed (marked as garbage and covered by a tombstone) and Deleted:
//   - the physical and logical object counters;
//   - per-container and total payload sizes;
//   - the per-container object counters returned by the metabase.
//
// The subtests run sequentially and share state: inhume_GC and inhume_TS each
// drop the objects they processed from the front of oo.
func TestCounters(t *testing.T) {
	t.Parallel()

	dir := t.TempDir()
	sh, mm := shardWithMetrics(t, dir)

	// modeOf reads the mode recorded by the metrics writer under its lock,
	// like the other metricsStore getters do.
	modeOf := func() mode.Mode {
		mm.mtx.Lock()
		defer mm.mtx.Unlock()
		return mm.mode
	}

	// decOrDelete decrements m[id] and removes the key once it reaches zero.
	// This keeps the expected maps free of zero-valued entries, so they can be
	// compared directly with the metabase container counters.
	decOrDelete := func(m map[cid.ID]uint64, id cid.ID) {
		m[id]--
		if m[id] == 0 {
			delete(m, id)
		}
	}

	sh.SetMode(mode.ReadOnly)
	require.Equal(t, mode.ReadOnly, modeOf())
	sh.SetMode(mode.ReadWrite)
	require.Equal(t, mode.ReadWrite, modeOf())

	const objNumber = 10
	oo := make([]*objectSDK.Object, objNumber)
	for i := range oo {
		oo[i] = testutil.GenerateObject()
	}

	t.Run("defaults", func(t *testing.T) {
		require.Zero(t, mm.getObjectCounter(physical))
		require.Zero(t, mm.getObjectCounter(logical))
		require.Empty(t, mm.containerSizes())
		require.Zero(t, mm.payloadSize())

		for _, obj := range oo {
			contID, _ := obj.ContainerID()
			v, ok := mm.getContainerCount(contID.EncodeToString(), physical)
			require.Zero(t, v)
			require.False(t, ok)
			v, ok = mm.getContainerCount(contID.EncodeToString(), logical)
			require.Zero(t, v)
			require.False(t, ok)
		}
	})

	var totalPayload int64

	expectedLogicalSizes := make(map[string]int64)
	expectedLogCC := make(map[cid.ID]uint64)
	expectedPhyCC := make(map[cid.ID]uint64)
	for _, obj := range oo {
		cnr, _ := obj.ContainerID()
		oSize := int64(obj.PayloadSize())
		expectedLogicalSizes[cnr.EncodeToString()] += oSize
		totalPayload += oSize
		expectedLogCC[cnr]++
		expectedPhyCC[cnr]++
	}

	var prm PutPrm
	for _, obj := range oo {
		prm.SetObject(obj)

		_, err := sh.Put(context.Background(), prm)
		require.NoError(t, err)
	}

	require.Equal(t, uint64(objNumber), mm.getObjectCounter(physical))
	require.Equal(t, uint64(objNumber), mm.getObjectCounter(logical))
	require.Equal(t, expectedLogicalSizes, mm.containerSizes())
	require.Equal(t, totalPayload, mm.payloadSize())

	cc, err := sh.metaBase.ContainerCounters(context.Background())
	require.NoError(t, err)
	require.Equal(t, expectedLogCC, cc.Logical)
	require.Equal(t, expectedPhyCC, cc.Physical)

	t.Run("inhume_GC", func(t *testing.T) {
		var prm InhumePrm
		inhumedNumber := objNumber / 4

		for i := 0; i < inhumedNumber; i++ {
			prm.MarkAsGarbage(objectcore.AddressOf(oo[i]))

			_, err := sh.Inhume(context.Background(), prm)
			require.NoError(t, err)

			cnr, ok := oo[i].ContainerID()
			require.True(t, ok)
			expectedLogicalSizes[cnr.EncodeToString()] -= int64(oo[i].PayloadSize())
			decOrDelete(expectedLogCC, cnr)
		}

		// Marking objects as garbage hides them logically but keeps them stored.
		require.Equal(t, uint64(objNumber), mm.getObjectCounter(physical))
		require.Equal(t, uint64(objNumber-inhumedNumber), mm.getObjectCounter(logical))
		require.Equal(t, expectedLogicalSizes, mm.containerSizes())
		require.Equal(t, totalPayload, mm.payloadSize())

		cc, err := sh.metaBase.ContainerCounters(context.Background())
		require.NoError(t, err)
		require.Equal(t, expectedLogCC, cc.Logical)
		require.Equal(t, expectedPhyCC, cc.Physical)

		oo = oo[inhumedNumber:]
	})

	t.Run("inhume_TS", func(t *testing.T) {
		var prm InhumePrm
		ts := objectcore.AddressOf(testutil.GenerateObject())

		phy := mm.getObjectCounter(physical)
		logic := mm.getObjectCounter(logical)

		inhumedNumber := int(phy / 4)
		prm.SetTarget(ts, addrFromObjs(oo[:inhumedNumber])...)

		_, err := sh.Inhume(context.Background(), prm)
		require.NoError(t, err)

		for _, obj := range oo[:inhumedNumber] {
			cnr, ok := obj.ContainerID()
			require.True(t, ok)
			expectedLogicalSizes[cnr.EncodeToString()] -= int64(obj.PayloadSize())
			decOrDelete(expectedLogCC, cnr)
		}

		require.Equal(t, phy, mm.getObjectCounter(physical))
		require.Equal(t, logic-uint64(inhumedNumber), mm.getObjectCounter(logical))
		require.Equal(t, expectedLogicalSizes, mm.containerSizes())
		require.Equal(t, totalPayload, mm.payloadSize())

		cc, err := sh.metaBase.ContainerCounters(context.Background())
		require.NoError(t, err)
		require.Equal(t, expectedLogCC, cc.Logical)
		require.Equal(t, expectedPhyCC, cc.Physical)

		oo = oo[inhumedNumber:]
	})

	t.Run("Delete", func(t *testing.T) {
		var prm DeletePrm

		phy := mm.getObjectCounter(physical)
		logic := mm.getObjectCounter(logical)

		deletedNumber := int(phy / 4)
		prm.SetAddresses(addrFromObjs(oo[:deletedNumber])...)

		_, err := sh.Delete(context.Background(), prm)
		require.NoError(t, err)

		require.Equal(t, phy-uint64(deletedNumber), mm.getObjectCounter(physical))
		require.Equal(t, logic-uint64(deletedNumber), mm.getObjectCounter(logical))

		var totalRemovedPayload uint64
		for _, obj := range oo[:deletedNumber] {
			removedPayload := obj.PayloadSize()
			totalRemovedPayload += removedPayload

			cnr, _ := obj.ContainerID()
			expectedLogicalSizes[cnr.EncodeToString()] -= int64(removedPayload)

			// Deletion removes the object both logically and physically.
			decOrDelete(expectedLogCC, cnr)
			decOrDelete(expectedPhyCC, cnr)
		}
		require.Equal(t, expectedLogicalSizes, mm.containerSizes())
		require.Equal(t, totalPayload-int64(totalRemovedPayload), mm.payloadSize())

		cc, err := sh.metaBase.ContainerCounters(context.Background())
		require.NoError(t, err)
		require.Equal(t, expectedLogCC, cc.Logical)
		require.Equal(t, expectedPhyCC, cc.Physical)
	})
}

// shardWithMetrics builds a shard rooted at path, then opens and initializes
// it. The shard uses:
//   - an FSTree blobstor under "blob";
//   - pilorama under "pilorama";
//   - a metabase under "meta";
//   - a new metricsStore as its metrics writer.
//
// The shard is closed via t.Cleanup, and the result of Close is not checked.
func shardWithMetrics(t *testing.T, path string) (*Shard, *metricsStore) {
	ctx := context.Background()

	fsTree := fstree.New(
		fstree.WithDirNameLen(2),
		fstree.WithPath(filepath.Join(path, "blob")),
		fstree.WithDepth(1),
	)
	blobOpts := []blobstor.Option{
		blobstor.WithStorages([]blobstor.SubStorage{{Storage: fsTree}}),
	}

	metrics := &metricsStore{
		objCounters: map[string]uint64{"phy": 0, "logic": 0},
		cnrSize:     make(map[string]int64),
		cnrCount:    make(map[string]uint64),
	}

	sh := New(
		WithID(NewIDFromBytes([]byte{})),
		WithBlobStorOptions(blobOpts...),
		WithPiloramaOptions(pilorama.WithPath(filepath.Join(path, "pilorama"))),
		WithMetaBaseOptions(
			meta.WithPath(filepath.Join(path, "meta")),
			meta.WithEpochState(epochState{}),
		),
		WithMetricsWriter(metrics),
		// A long remover interval keeps GC from running during the test.
		WithGCRemoverSleepInterval(time.Hour),
	)
	require.NoError(t, sh.Open(ctx))
	require.NoError(t, sh.Init(ctx))

	t.Cleanup(func() { sh.Close() })

	return sh, metrics
}

// addrFromObjs returns the address of each object in oo, preserving order.
// The result is non-nil even for empty input.
func addrFromObjs(oo []*objectSDK.Object) []oid.Address {
	addrs := make([]oid.Address, 0, len(oo))
	for _, obj := range oo {
		addrs = append(addrs, objectcore.AddressOf(obj))
	}
	return addrs
}