package meta_test

import (
	"bytes"
	"context"
	"errors"
	"fmt"
	"os"
	"runtime"
	"testing"
	"time"

	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/core/object"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/local_object_storage/internal/testutil"
	meta "git.frostfs.info/TrueCloudLab/frostfs-node/pkg/local_object_storage/metabase"
	"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/client"
	cidtest "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/container/id/test"
	objectSDK "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object"
	"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object/erasurecode"
	oid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object/id"
	oidtest "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object/id/test"
	"github.com/nspcc-dev/neo-go/pkg/crypto/keys"
	"github.com/stretchr/testify/require"
)

// TestDB_Get exercises metaGet against a test DB for missing, regular,
// tombstone, lock, virtual (split), erasure-coded, removed and expired objects.
func TestDB_Get(t *testing.T) {
	db := newDB(t, meta.WithEpochState(epochState{currEpoch}))
	defer func() { require.NoError(t, db.Close()) }()

	raw := testutil.GenerateObject()

	// require.Equal treats <nil> attributes and empty <{}> attributes as
	// different, so give the object a non-empty attribute slice.
	testutil.AddAttribute(raw, "foo", "bar")

	// checkPutGet stores raw in its current state and compares the value
	// fetched back with raw.CutPayload().
	checkPutGet := func(t *testing.T) {
		require.NoError(t, putBig(db, raw))

		got, err := metaGet(db, object.AddressOf(raw), false)
		require.NoError(t, err)
		require.Equal(t, raw.CutPayload(), got)
	}

	t.Run("object not found", func(t *testing.T) {
		_, err := metaGet(db, object.AddressOf(raw), false)
		require.Error(t, err)
	})

	t.Run("put regular object", func(t *testing.T) {
		checkPutGet(t)
	})

	t.Run("put tombstone object", func(t *testing.T) {
		raw.SetType(objectSDK.TypeTombstone)
		raw.SetID(oidtest.ID())

		checkPutGet(t)
	})

	t.Run("put lock object", func(t *testing.T) {
		raw.SetType(objectSDK.TypeLock)
		raw.SetID(oidtest.ID())

		checkPutGet(t)
	})

	t.Run("put virtual object", func(t *testing.T) {
		cnr := cidtest.ID()
		splitID := objectSDK.NewSplitID()

		parent := testutil.GenerateObjectWithCID(cnr)
		testutil.AddAttribute(parent, "foo", "bar")

		child := testutil.GenerateObjectWithCID(cnr)
		child.SetParent(parent)
		idParent, _ := parent.ID()
		child.SetParentID(idParent)
		child.SetSplitID(splitID)

		// Only the child is stored; the parent is reachable through it.
		require.NoError(t, putBig(db, child))

		t.Run("raw is true", func(t *testing.T) {
			// A raw request for the parent must fail with split info that
			// names the child as the last part and carries no link.
			_, err := metaGet(db, object.AddressOf(parent), true)
			require.Error(t, err)

			var siErr *objectSDK.SplitInfoError
			require.ErrorAs(t, err, &siErr)
			require.Equal(t, splitID, siErr.SplitInfo().SplitID())

			id1, _ := child.ID()
			id2, _ := siErr.SplitInfo().LastPart()
			require.Equal(t, id1, id2)

			_, ok := siErr.SplitInfo().Link()
			require.False(t, ok)
		})

		newParent, err := metaGet(db, object.AddressOf(parent), false)
		require.NoError(t, err)
		require.True(t, binaryEqual(parent.CutPayload(), newParent))

		newChild, err := metaGet(db, object.AddressOf(child), true)
		require.NoError(t, err)
		require.True(t, binaryEqual(child.CutPayload(), newChild))
	})

	t.Run("put erasure-coded object", func(t *testing.T) {
		cnr := cidtest.ID()
		virtual := testutil.GenerateObjectWithCID(cnr)
		c, err := erasurecode.NewConstructor(3, 1)
		require.NoError(t, err)
		pk, err := keys.NewPrivateKey()
		require.NoError(t, err)
		parts, err := c.Split(virtual, &pk.PrivateKey)
		require.NoError(t, err)
		for _, part := range parts {
			err = putBig(db, part)
			// An ECInfoError from the put is tolerated; any other error
			// fails the test.
			var putECErr *objectSDK.ECInfoError
			if err != nil && !errors.As(err, &putECErr) {
				require.NoError(t, err)
			}
		}

		// A raw request for the virtual object must fail with EC info that
		// lists exactly the stored parts.
		_, err = metaGet(db, object.AddressOf(virtual), true)
		var eiError *objectSDK.ECInfoError
		require.ErrorAs(t, err, &eiError)

		chunks := eiError.ECInfo().Chunks
		require.Len(t, chunks, len(parts))
		for _, chunk := range chunks {
			// Decode the chunk ID once; it does not depend on the part.
			var chunkID oid.ID
			require.NoError(t, chunkID.ReadFromV2(chunk.ID))

			found := false
			for _, part := range parts {
				partID, _ := part.ID()
				if chunkID.Equals(partID) {
					found = true
					break
				}
			}
			require.True(t, found, "chunk not found")
		}
	})

	t.Run("get removed object", func(t *testing.T) {
		obj := oidtest.Address()

		// Inhumed through metaInhume: Get must report "already removed".
		require.NoError(t, metaInhume(db, obj, oidtest.ID()))
		_, err := metaGet(db, obj, false)
		require.True(t, client.IsErrObjectAlreadyRemoved(err))

		obj = oidtest.Address()

		// Inhumed with only the address set in InhumePrm: Get must report
		// "not found".
		var prm meta.InhumePrm
		prm.SetAddresses(obj)

		_, err = db.Inhume(context.Background(), prm)
		require.NoError(t, err)
		_, err = metaGet(db, obj, false)
		require.True(t, client.IsErrObjectNotFound(err))
	})

	t.Run("expired object", func(t *testing.T) {
		checkExpiredObjects(t, db, func(exp, nonExp *objectSDK.Object) {
			gotExp, err := metaGet(db, object.AddressOf(exp), false)
			require.Nil(t, gotExp)
			require.ErrorIs(t, err, meta.ErrObjectIsExpired)

			gotNonExp, err := metaGet(db, object.AddressOf(nonExp), false)
			require.NoError(t, err)
			require.True(t, binaryEqual(gotNonExp, nonExp.CutPayload()))
		})
	})
}

// binaryEqual reports whether a and b marshal to the same bytes. It is used
// when an object contains empty lists in its structure, because require.Equal
// fails on comparing <nil> and []{} lists. A marshalling error on either
// object yields false.
func binaryEqual(a, b *objectSDK.Object) bool {
	var encoded [2][]byte

	// Marshal a first, then b, stopping at the first failure.
	for i, obj := range [...]*objectSDK.Object{a, b} {
		data, err := obj.Marshal()
		if err != nil {
			return false
		}

		encoded[i] = data
	}

	return bytes.Equal(encoded[0], encoded[1])
}

// BenchmarkGet runs benchmarkGet for several object counts, one
// sub-benchmark per count.
func BenchmarkGet(b *testing.B) {
	// Best-effort removal of the path named after the benchmark; the error
	// is deliberately ignored.
	defer func() {
		_ = os.RemoveAll(b.Name())
	}()

	for _, objCount := range []int{1, 10, 100} {
		name := fmt.Sprintf("%d_objects", objCount)

		b.Run(name, func(b *testing.B) {
			benchmarkGet(b, objCount)
		})
	}
}

// benchmarkGet measures meta.DB.Get over numOfObj stored objects in two
// sub-benchmarks: "parallel" (DB created with a max batch size of
// runtime.NumCPU(), reads issued via b.RunParallel) and "serial" (max batch
// size of 1, reads issued one by one). Each sub-benchmark uses its own DB.
func benchmarkGet(b *testing.B, numOfObj int) {
	// prepareDb creates a DB via newDB, stores numOfObj generated objects in
	// it and returns the still-open DB with the addresses of those objects.
	// The caller is responsible for closing the returned DB.
	prepareDb := func(batchSize int) (*meta.DB, []oid.Address) {
		db := newDB(b,
			meta.WithMaxBatchSize(batchSize),
			meta.WithMaxBatchDelay(10*time.Millisecond),
		)
		addrs := make([]oid.Address, 0, numOfObj)

		for range numOfObj {
			raw := testutil.GenerateObject()
			addrs = append(addrs, object.AddressOf(raw))

			if err := putBig(db, raw); err != nil {
				// Best-effort close; the put error is the one worth reporting.
				_ = db.Close()
				require.NoError(b, err)
			}
		}

		return db, addrs
	}

	db, addrs := prepareDb(runtime.NumCPU())

	b.Run("parallel", func(b *testing.B) {
		b.ReportAllocs()
		b.RunParallel(func(pb *testing.PB) {
			var counter int

			for pb.Next() {
				var getPrm meta.GetPrm
				getPrm.SetAddress(addrs[counter%len(addrs)])
				counter++

				if _, err := db.Get(context.Background(), getPrm); err != nil {
					// This body runs on RunParallel worker goroutines, where
					// FailNow-style calls such as b.Fatal must not be used.
					b.Error(err)
					return
				}
			}
		})
	})

	// Release the first DB before the second one is prepared.
	require.NoError(b, db.Close())
	require.NoError(b, os.RemoveAll(b.Name()))

	db, addrs = prepareDb(1)
	defer func() { require.NoError(b, db.Close()) }()

	b.Run("serial", func(b *testing.B) {
		b.ReportAllocs()
		for i := range b.N {
			var getPrm meta.GetPrm
			getPrm.SetAddress(addrs[i%len(addrs)])

			_, err := db.Get(context.Background(), getPrm)
			if err != nil {
				b.Fatal(err)
			}
		}
	})
}

// metaGet fetches the object stored at addr through db.Get, passing the raw
// flag along, and returns the header taken from the result together with the
// error from Get. The header is read from the result even when err is non-nil.
func metaGet(db *meta.DB, addr oid.Address, raw bool) (*objectSDK.Object, error) {
	var getPrm meta.GetPrm
	getPrm.SetRaw(raw)
	getPrm.SetAddress(addr)

	getRes, err := db.Get(context.Background(), getPrm)

	return getRes.Header(), err
}