package blobstor

import (
	"context"
	"fmt"
	"testing"

	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/local_object_storage/blobstor/blobovniczatree"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/local_object_storage/blobstor/common"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/local_object_storage/blobstor/fstree"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/local_object_storage/blobstor/memstore"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/local_object_storage/internal/testutil"
	"github.com/stretchr/testify/require"
	"golang.org/x/sync/errgroup"
)

// storage is one substorage benchmark target.
type storage struct {
	// desc is the human-readable name used when composing sub-benchmark names.
	desc string
	// create builds the substorage for the directory dir; the result is
	// opened and initialized afterwards by open.
	create func(dir string) common.Storage
}

// open builds the substorage inside a fresh b.TempDir, then calls Open(false)
// and Init on it, failing the benchmark on any error. Close is registered with
// b.Cleanup, so callers never release the returned storage themselves.
func (s storage) open(b *testing.B) common.Storage {
	dir := b.TempDir()
	sub := s.create(dir)

	require.NoError(b, sub.Open(false))
	require.NoError(b, sub.Init())
	b.Cleanup(func() { require.NoError(b, sub.Close()) })

	return sub
}

// storages lists the substorages to benchmark. Each entry has a description,
// used in sub-benchmark names, and a create function that builds a not yet
// opened storage rooted at the given directory (memstore ignores the
// directory). Opening, initialization and closing are done by storage.open,
// which registers Close via b.Cleanup.
var storages = []storage{
	{
		desc: "memstore",
		create: func(string) common.Storage {
			return memstore.New()
		},
	},
	{
		desc: "fstree_nosync",
		create: func(dir string) common.Storage {
			return fstree.New(
				fstree.WithPath(dir),
				fstree.WithDepth(2),
				fstree.WithDirNameLen(2),
				fstree.WithNoSync(true),
			)
		},
	},
	{
		desc: "fstree_without_object_counter",
		create: func(dir string) common.Storage {
			return fstree.New(
				fstree.WithPath(dir),
				fstree.WithDepth(2),
				fstree.WithDirNameLen(2),
			)
		},
	},
	{
		desc: "fstree_with_object_counter",
		create: func(dir string) common.Storage {
			return fstree.New(
				fstree.WithPath(dir),
				fstree.WithDepth(2),
				fstree.WithDirNameLen(2),
				fstree.WithFileCounter(fstree.NewSimpleCounter()),
			)
		},
	},
	{
		desc: "blobovniczatree",
		create: func(dir string) common.Storage {
			return blobovniczatree.NewBlobovniczaTree(
				blobovniczatree.WithRootPath(dir),
			)
		},
	},
}

// BenchmarkSubstorageReadPerf measures Get throughput of every substorage in
// storages. Each sub-benchmark first stores tt.size objects outside the timed
// region (marshaling and Put run concurrently in an errgroup). It then reads
// addresses produced by tt.addrGen from the parallel goroutines started by
// b.RunParallel.
func BenchmarkSubstorageReadPerf(b *testing.B) {
	readTests := []struct {
		desc    string
		size    int
		objGen  func() testutil.ObjectGenerator
		addrGen func() testutil.AddressGenerator
	}{
		{
			desc:    "seq100",
			size:    10000,
			objGen:  func() testutil.ObjectGenerator { return &testutil.SeqObjGenerator{ObjSize: 100} },
			addrGen: func() testutil.AddressGenerator { return &testutil.SeqAddrGenerator{MaxID: 100} },
		},
		{
			desc:    "rand100",
			size:    10000,
			objGen:  func() testutil.ObjectGenerator { return &testutil.SeqObjGenerator{ObjSize: 100} },
			addrGen: func() testutil.AddressGenerator { return testutil.RandAddrGenerator(10000) },
		},
	}
	for _, tt := range readTests {
		for _, stEntry := range storages {
			b.Run(fmt.Sprintf("%s-%s", stEntry.desc, tt.desc), func(b *testing.B) {
				objGen := tt.objGen()
				st := stEntry.open(b)

				// Fill database. Objects and their addresses are produced on the
				// benchmark goroutine (AddressFromObject receives b, presumably so
				// it can fail the benchmark). Only marshaling and Put run concurrently.
				var errG errgroup.Group
				for i := 0; i < tt.size; i++ {
					obj := objGen.Next()
					addr := testutil.AddressFromObject(b, obj)
					errG.Go(func() error {
						raw, err := obj.Marshal()
						if err != nil {
							// %w keeps the original error inspectable with errors.Is/As.
							return fmt.Errorf("marshal: %w", err)
						}
						_, err = st.Put(context.Background(), common.PutPrm{
							Address: addr,
							RawData: raw,
						})
						return err
					})
				}
				require.NoError(b, errG.Wait())

				// Benchmark reading. A single addrGen is shared by all
				// RunParallel goroutines.
				// NOTE(review): this assumes the generators' Next is safe for
				// concurrent use; confirm in testutil.
				addrGen := tt.addrGen()
				b.ResetTimer()
				b.RunParallel(func(pb *testing.PB) {
					for pb.Next() {
						_, err := st.Get(context.Background(), common.GetPrm{Address: addrGen.Next()})
						require.NoError(b, err)
					}
				})
			})
		}
	}
}

// BenchmarkSubstorageWritePerf measures Put throughput of every substorage in
// storages for several object generators. Inside b.RunParallel, each goroutine
// takes objects from one shared generator, marshals them and writes them with
// Put; any failure aborts the sub-benchmark.
func BenchmarkSubstorageWritePerf(b *testing.B) {
	type genCase struct {
		desc   string
		create func() testutil.ObjectGenerator
	}

	cases := []genCase{
		{"rand10", func() testutil.ObjectGenerator {
			return &testutil.RandObjGenerator{ObjSize: 10}
		}},
		{"rand100", func() testutil.ObjectGenerator {
			return &testutil.RandObjGenerator{ObjSize: 100}
		}},
		{"rand1000", func() testutil.ObjectGenerator {
			return &testutil.RandObjGenerator{ObjSize: 1000}
		}},
		{"overwrite10", func() testutil.ObjectGenerator {
			return &testutil.OverwriteObjGenerator{ObjSize: 10, MaxObjects: 100}
		}},
		{"overwrite100", func() testutil.ObjectGenerator {
			return &testutil.OverwriteObjGenerator{ObjSize: 100, MaxObjects: 100}
		}},
		{"overwrite1000", func() testutil.ObjectGenerator {
			return &testutil.OverwriteObjGenerator{ObjSize: 1000, MaxObjects: 100}
		}},
	}

	for _, gcase := range cases {
		for _, target := range storages {
			name := fmt.Sprintf("%s-%s", target.desc, gcase.desc)
			b.Run(name, func(b *testing.B) {
				gen := gcase.create()
				st := target.open(b)
				ctx := context.Background()

				b.ResetTimer()
				b.RunParallel(func(pb *testing.PB) {
					for pb.Next() {
						obj := gen.Next()
						prm := common.PutPrm{Address: testutil.AddressFromObject(b, obj)}

						var err error
						prm.RawData, err = obj.Marshal()
						require.NoError(b, err)

						if _, err = st.Put(ctx, prm); err != nil {
							b.Fatalf("writing entry: %v", err)
						}
					}
				})
			})
		}
	}
}

// BenchmarkSubstorageIteratePerf measures the cost of one full Iterate pass
// over every substorage in storages. Each sub-benchmark first writes tt.size
// objects outside the timed region. It then performs b.N complete passes and
// checks after each pass that exactly tt.size objects were visited, so ns/op
// reports the time of a single full scan.
func BenchmarkSubstorageIteratePerf(b *testing.B) {
	iterateTests := []struct {
		desc   string
		size   int
		objGen func() testutil.ObjectGenerator
	}{
		{
			desc:   "rand100",
			size:   10000,
			objGen: func() testutil.ObjectGenerator { return &testutil.RandObjGenerator{ObjSize: 100} },
		},
	}
	for _, tt := range iterateTests {
		for _, stEntry := range storages {
			b.Run(fmt.Sprintf("%s-%s", stEntry.desc, tt.desc), func(b *testing.B) {
				objGen := tt.objGen()
				st := stEntry.open(b)

				// Fill database (not timed).
				for i := 0; i < tt.size; i++ {
					obj := objGen.Next()
					addr := testutil.AddressFromObject(b, obj)
					raw, err := obj.Marshal()
					require.NoError(b, err)
					if _, err := st.Put(context.Background(), common.PutPrm{
						Address: addr,
						RawData: raw,
					}); err != nil {
						b.Fatalf("writing entry: %v", err)
					}
				}

				// Benchmark iterate. The work must scale with b.N: a single pass
				// regardless of b.N makes the framework keep raising N and the
				// reported ns/op becomes meaningless.
				b.ResetTimer()
				for i := 0; i < b.N; i++ {
					cnt := 0
					_, err := st.Iterate(context.Background(), common.IteratePrm{
						Handler: func(common.IterationElement) error {
							cnt++
							return nil
						},
					})
					require.NoError(b, err)
					require.Equal(b, tt.size, cnt)
				}
			})
		}
	}
}