package meta

import (
	"bytes"
	"context"
	"encoding/binary"
	"errors"
	"fmt"
	"time"

	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/local_object_storage/internal/metaerr"
	"git.frostfs.info/TrueCloudLab/frostfs-observability/tracing"
	cid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/container/id"
	objectSDK "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object"
	oid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object/id"
	"go.etcd.io/bbolt"
	"go.opentelemetry.io/otel/attribute"
	"go.opentelemetry.io/otel/trace"
)

// Keys in shardInfoBucket that hold the shard-wide object counters.
// Each value is stored as an 8-byte little-endian uint64.
var (
	objectPhyCounterKey   = []byte("phy_counter")
	objectLogicCounterKey = []byte("logic_counter")
	objectUserCounterKey  = []byte("user_counter")
)

// Errors returned when a container counter entry has an unexpected size:
// the key must be exactly cidSize bytes and the value exactly 24 bytes.
// NOTE(review): "Lenght" is a typo in the identifiers; kept as-is because
// they are referenced from other code.
var (
	errInvalidKeyLenght   = errors.New("invalid key length")
	errInvalidValueLenght = errors.New("invalid value length")
)

// objectType selects which shard-wide counter is modified by
// updateShardObjectCounter / updateShardObjectCounterBucket.
type objectType uint8

const (
	// The zero value is deliberately unused; passing it to
	// updateShardObjectCounterBucket panics.
	_ objectType = iota
	// phy counts physically stored objects.
	phy
	// logical counts objects that are neither GC-marked nor covered by a tombstone.
	logical
	// user counts logically available objects that are user objects (see IsUserObject).
	user
)

// ObjectCounters groups the phy, logic and user object counters, either
// for the whole shard or for a single container.
type ObjectCounters struct {
	Logic uint64
	Phy   uint64
	User  uint64
}

// IsZero reports whether every counter is zero.
func (o ObjectCounters) IsZero() bool {
	return o == ObjectCounters{}
}

// ObjectCounters reads the shard-wide phy, logic and user object counters
// from the shard info bucket.
//
// A counter whose stored value is missing or not exactly 8 bytes long is
// reported as zero, and a missing shard info bucket yields all-zero counters.
// Returns ErrDegradedMode when the metabase is unavailable; otherwise only
// errors that prevent reading the Bolt database are returned.
func (db *DB) ObjectCounters() (cc ObjectCounters, err error) {
	db.modeMtx.RLock()
	defer db.modeMtx.RUnlock()

	if db.mode.NoMetabase() {
		return ObjectCounters{}, ErrDegradedMode
	}

	err = db.boltDB.View(func(tx *bbolt.Tx) error {
		info := tx.Bucket(shardInfoBucket)
		if info == nil {
			return nil
		}

		load := func(key []byte, dst *uint64) {
			if raw := info.Get(key); len(raw) == 8 {
				*dst = binary.LittleEndian.Uint64(raw)
			}
		}
		load(objectPhyCounterKey, &cc.Phy)
		load(objectLogicCounterKey, &cc.Logic)
		load(objectUserCounterKey, &cc.User)
		return nil
	})

	return cc, metaerr.Wrap(err)
}

// ContainerCounters holds object counters for many containers at once.
type ContainerCounters struct {
	// Counts maps a container ID to its phy/logic/user counters.
	Counts map[cid.ID]ObjectCounters
}

// ContainerCounters collects the stored object counters of every container
// tracked by the metabase.
//
// The containers are read in batches (the number of containers is
// unbounded), and ctx cancellation is checked before each batch. On
// cancellation or error the counters gathered so far are returned together
// with the error. Only errors that prevent reading the Bolt database (or
// ErrDegradedMode) are returned.
//
// It is guaranteed that the ContainerCounters fields are not nil.
func (db *DB) ContainerCounters(ctx context.Context) (ContainerCounters, error) {
	startedAt := time.Now()
	success := false
	defer func() {
		db.metrics.AddMethodDuration("ContainerCounters", time.Since(startedAt), success)
	}()

	ctx, span := tracing.StartSpanFromContext(ctx, "metabase.ContainerCounters")
	defer span.End()

	res := ContainerCounters{Counts: make(map[cid.ID]ObjectCounters)}
	collect := func(id cid.ID, entity ObjectCounters) {
		res.Counts[id] = entity
	}

	// cursor is advanced in place by containerCountersNextBatch.
	cursor := make([]byte, cidSize)
	for done := false; !done; {
		if ctxErr := ctx.Err(); ctxErr != nil {
			return res, ctxErr
		}

		var err error
		done, err = db.containerCountersNextBatch(cursor, collect)
		if err != nil {
			return res, err
		}
	}

	success = true
	return res, nil
}

// containerCountersNextBatch reads up to batchSize entries of the container
// counter bucket located after lastKey and passes each decoded entry to f.
// lastKey is overwritten in place with the last visited key, so the next
// call resumes where this one stopped.
//
// It returns true once the bucket is exhausted (or absent) and false when
// more entries may follow. The mode read lock is held only for this single
// batch. Returns ErrDegradedMode when the metabase is unavailable; storage
// and decoding errors are wrapped with metaerr.
func (db *DB) containerCountersNextBatch(lastKey []byte, f func(id cid.ID, entity ObjectCounters)) (bool, error) {
	db.modeMtx.RLock()
	defer db.modeMtx.RUnlock()

	if db.mode.NoMetabase() {
		return false, ErrDegradedMode
	}

	const batchSize = 1000
	visited := 0

	err := db.boltDB.View(func(tx *bbolt.Tx) error {
		bkt := tx.Bucket(containerCounterBucketName)
		if bkt == nil {
			return ErrInterruptIterator
		}

		cur := bkt.Cursor()
		for k, v := cur.Seek(lastKey); k != nil; k, v = cur.Next() {
			if bytes.Equal(lastKey, k) {
				// Seek landed on the key the previous batch stopped at.
				continue
			}
			copy(lastKey, k)

			id, err := parseContainerCounterKey(k)
			if err != nil {
				return err
			}
			counters, err := parseContainerCounterValue(v)
			if err != nil {
				return err
			}
			f(id, counters)

			if visited++; visited == batchSize {
				break
			}
		}

		if visited < batchSize {
			// A short batch means there is nothing left to read.
			return ErrInterruptIterator
		}
		return nil
	})

	switch {
	case err == nil:
		return false, nil
	case errors.Is(err, ErrInterruptIterator):
		return true, nil
	default:
		return false, metaerr.Wrap(err)
	}
}

// ContainerCount returns the object counters stored for container id.
//
// A container without a stored entry, or a metabase without the container
// counter bucket, yields zero counters and a nil error. Returns
// ErrDegradedMode when the metabase is unavailable; storage and decoding
// errors are wrapped with metaerr.
func (db *DB) ContainerCount(ctx context.Context, id cid.ID) (ObjectCounters, error) {
	var (
		startedAt = time.Now()
		success   = false
	)
	defer func() {
		db.metrics.AddMethodDuration("ContainerCount", time.Since(startedAt), success)
	}()

	_, span := tracing.StartSpanFromContext(ctx, "metabase.ContainerCount")
	defer span.End()

	db.modeMtx.RLock()
	defer db.modeMtx.RUnlock()

	if db.mode.NoMetabase() {
		return ObjectCounters{}, ErrDegradedMode
	}

	var result ObjectCounters

	err := db.boltDB.View(func(tx *bbolt.Tx) error {
		b := tx.Bucket(containerCounterBucketName)
		if b == nil {
			// Tx.Bucket returns nil for an absent bucket; using it would
			// dereference a nil *bbolt.Bucket. Nothing is tracked yet.
			return nil
		}
		key := make([]byte, cidSize)
		id.Encode(key)
		v := b.Get(key)
		if v == nil {
			return nil
		}
		var err error
		result, err = parseContainerCounterValue(v)
		return err
	})

	// Report the outcome to metrics; previously this was never set, so every
	// call was recorded as a failure.
	success = err == nil
	return result, metaerr.Wrap(err)
}

// incCounters accounts for one newly stored object. It increments the
// shard-wide phy and logic counters (and the user counter when isUserObject
// is set), then the counters of container cnrID. The shard-wide step is
// skipped when the shard info bucket does not exist.
func (db *DB) incCounters(tx *bbolt.Tx, cnrID cid.ID, isUserObject bool) error {
	if info := tx.Bucket(shardInfoBucket); info != nil {
		if err := db.updateShardObjectCounterBucket(info, phy, 1, true); err != nil {
			return fmt.Errorf("could not increase phy object counter: %w", err)
		}
		if err := db.updateShardObjectCounterBucket(info, logical, 1, true); err != nil {
			return fmt.Errorf("could not increase logical object counter: %w", err)
		}
		if isUserObject {
			if err := db.updateShardObjectCounterBucket(info, user, 1, true); err != nil {
				return fmt.Errorf("could not increase user object counter: %w", err)
			}
		}
	}

	return db.incContainerObjectCounter(tx, cnrID, isUserObject)
}

// updateShardObjectCounter changes the shard-wide counter selected by typ by
// delta (increment when inc is set, saturating decrement otherwise). It is a
// no-op when the shard info bucket does not exist.
func (db *DB) updateShardObjectCounter(tx *bbolt.Tx, typ objectType, delta uint64, inc bool) error {
	if info := tx.Bucket(shardInfoBucket); info != nil {
		return db.updateShardObjectCounterBucket(info, typ, delta, inc)
	}
	return nil
}

// updateShardObjectCounterBucket adjusts the counter selected by typ in
// bucket b by delta. When inc is true the counter grows by delta; otherwise
// it shrinks by delta and is clamped at zero. A stored value that is not
// exactly 8 bytes long is treated as zero. An unknown typ is a programming
// error and panics.
func (*DB) updateShardObjectCounterBucket(b *bbolt.Bucket, typ objectType, delta uint64, inc bool) error {
	var key []byte
	switch typ {
	case phy:
		key = objectPhyCounterKey
	case logical:
		key = objectLogicCounterKey
	case user:
		key = objectUserCounterKey
	default:
		panic("unknown object type counter")
	}

	var current uint64
	if raw := b.Get(key); len(raw) == 8 {
		current = binary.LittleEndian.Uint64(raw)
	}

	switch {
	case inc:
		current += delta
	case delta >= current:
		current = 0
	default:
		current -= delta
	}

	encoded := make([]byte, 8)
	binary.LittleEndian.PutUint64(encoded, current)
	return b.Put(key, encoded)
}

// updateContainerCounter applies the per-container deltas to the container
// counter bucket. It increments when inc is true and otherwise decrements,
// clamping each counter at zero. It is a no-op when the bucket does not exist.
func (db *DB) updateContainerCounter(tx *bbolt.Tx, delta map[cid.ID]ObjectCounters, inc bool) error {
	bkt := tx.Bucket(containerCounterBucketName)
	if bkt == nil {
		return nil
	}

	// A single key buffer is re-encoded for every container.
	buf := make([]byte, cidSize)
	for id, d := range delta {
		id.Encode(buf)
		if err := db.editContainerCounterValue(bkt, buf, d, inc); err != nil {
			return err
		}
	}
	return nil
}

// editContainerCounterValue reads the counters stored under key (absent or
// empty means all-zero), applies delta to each of phy/logic/user via
// nextValue, and writes the result back. A malformed stored value is
// returned as an error without writing.
func (*DB) editContainerCounterValue(b *bbolt.Bucket, key []byte, delta ObjectCounters, inc bool) error {
	var current ObjectCounters
	if raw := b.Get(key); len(raw) > 0 {
		parsed, err := parseContainerCounterValue(raw)
		if err != nil {
			return err
		}
		current = parsed
	}

	updated := ObjectCounters{
		Phy:   nextValue(current.Phy, delta.Phy, inc),
		Logic: nextValue(current.Logic, delta.Logic, inc),
		User:  nextValue(current.User, delta.User, inc),
	}
	return b.Put(key, containerCounterValue(updated))
}

// nextValue returns existed+delta when inc is true. Otherwise it returns
// existed-delta, clamped at zero so the counter never wraps below zero.
func nextValue(existed, delta uint64, inc bool) uint64 {
	switch {
	case inc:
		return existed + delta
	case existed > delta:
		return existed - delta
	default:
		return 0
	}
}

// incContainerObjectCounter adds one phy and one logic object (plus one user
// object when isUserObject is set) to the counters of container cnrID. It is
// a no-op when the container counter bucket does not exist.
func (db *DB) incContainerObjectCounter(tx *bbolt.Tx, cnrID cid.ID, isUserObject bool) error {
	bkt := tx.Bucket(containerCounterBucketName)
	if bkt == nil {
		return nil
	}

	delta := ObjectCounters{Phy: 1, Logic: 1}
	if isUserObject {
		delta.User = 1
	}

	key := make([]byte, cidSize)
	cnrID.Encode(key)
	return db.editContainerCounterValue(bkt, key, delta, true)
}

// syncCounter recomputes the object counters from the metabase indexes and
// stores them. For every physically stored object it increments its
// container's phy counter. It also increments the logic counter when the
// object is neither GC-marked nor covered by a tombstone, and the user
// counter when such an available object is also a user object (see
// IsUserObject). Tx MUST be writable.
//
// If force is false and both the shard-wide and the per-container counters
// already look initialized, nothing is done. With force set, the counters
// are recomputed unconditionally.
//
// NOTE(review): only containers that still have physical objects are
// rewritten by setObjectCounters; existing entries for other containers
// appear to be left untouched. Confirm this is intended for forced resyncs.
func syncCounter(tx *bbolt.Tx, force bool) error {
	shardInfoB, err := createBucketLikelyExists(tx, shardInfoBucket)
	if err != nil {
		return fmt.Errorf("could not get shard info bucket: %w", err)
	}

	shardReady := len(shardInfoB.Get(objectPhyCounterKey)) == 8 &&
		len(shardInfoB.Get(objectLogicCounterKey)) == 8 &&
		len(shardInfoB.Get(objectUserCounterKey)) == 8
	containersReady := containerObjectCounterInitialized(tx)
	if shardReady && containersReady && !force {
		// Nothing to do: both counter sets are already present.
		return nil
	}

	containerCounterB, err := createBucketLikelyExists(tx, containerCounterBucketName)
	if err != nil {
		return fmt.Errorf("could not get container counter bucket: %w", err)
	}

	graveyardBKT := tx.Bucket(graveyardBucketName)
	garbageBKT := tx.Bucket(garbageBucketName)

	var (
		addr     oid.Address
		addrBuf  = make([]byte, addressKeySize)
		counters = make(map[cid.ID]ObjectCounters)
	)

	err = iteratePhyObjects(tx, func(cnr cid.ID, objID oid.ID, obj *objectSDK.Object) error {
		cnt := counters[cnr]
		cnt.Phy++

		addr.SetContainer(cnr)
		addr.SetObject(objID)

		// Available: not marked for GC and not covered by a tombstone.
		if inGraveyardWithKey(addressKey(addr, addrBuf), graveyardBKT, garbageBKT) == 0 {
			cnt.Logic++
			if IsUserObject(obj) {
				cnt.User++
			}
		}

		counters[cnr] = cnt
		return nil
	})
	if err != nil {
		return fmt.Errorf("could not iterate objects: %w", err)
	}

	return setObjectCounters(counters, shardInfoB, containerCounterB)
}

// setObjectCounters writes the counters of every container in counters to
// containerCounterB. It then stores their sums as the shard-wide phy, logic
// and user counters in shardInfoB, each as an 8-byte little-endian uint64.
//
// Only containers present in counters are written; other entries of
// containerCounterB are not modified here.
func setObjectCounters(counters map[cid.ID]ObjectCounters, shardInfoB, containerCounterB *bbolt.Bucket) error {
	var total ObjectCounters
	key := make([]byte, cidSize)
	for cnrID, count := range counters {
		total.Phy += count.Phy
		total.Logic += count.Logic
		total.User += count.User

		cnrID.Encode(key)
		if err := containerCounterB.Put(key, containerCounterValue(count)); err != nil {
			// The value holds phy, logic and user counters together, so the
			// message must not single out "phy".
			return fmt.Errorf("could not update container object counter: %w", err)
		}
	}

	shardCounters := []struct {
		name  string
		key   []byte
		value uint64
	}{
		{name: "phy", key: objectPhyCounterKey, value: total.Phy},
		{name: "logic", key: objectLogicCounterKey, value: total.Logic},
		{name: "user", key: objectUserCounterKey, value: total.User},
	}
	for _, sc := range shardCounters {
		data := make([]byte, 8)
		binary.LittleEndian.PutUint64(data, sc.value)
		if err := shardInfoB.Put(sc.key, data); err != nil {
			return fmt.Errorf("could not update %s object counter: %w", sc.name, err)
		}
	}

	return nil
}

// containerCounterValue encodes entity as a 24-byte value: phy, logic and
// user counters in that order, each a little-endian uint64.
func containerCounterValue(entity ObjectCounters) []byte {
	le := binary.LittleEndian
	buf := make([]byte, 24)
	le.PutUint64(buf[0:8], entity.Phy)
	le.PutUint64(buf[8:16], entity.Logic)
	le.PutUint64(buf[16:24], entity.User)
	return buf
}

// parseContainerCounterKey decodes a container counter bucket key into a
// container ID. Keys that are not exactly cidSize bytes long are rejected
// with errInvalidKeyLenght.
func parseContainerCounterKey(buf []byte) (cid.ID, error) {
	var cnrID cid.ID
	if len(buf) != cidSize {
		return cnrID, errInvalidKeyLenght
	}

	if err := cnrID.Decode(buf); err != nil {
		return cid.ID{}, fmt.Errorf("failed to decode container ID: %w", err)
	}
	return cnrID, nil
}

// parseContainerCounterValue return phy, logic values.
func parseContainerCounterValue(buf []byte) (ObjectCounters, error) {
	if len(buf) != 24 {
		return ObjectCounters{}, errInvalidValueLenght
	}
	return ObjectCounters{
		Phy:   binary.LittleEndian.Uint64(buf),
		Logic: binary.LittleEndian.Uint64(buf[8:16]),
		User:  binary.LittleEndian.Uint64(buf[16:]),
	}, nil
}

// containerObjectCounterInitialized reports whether the container counter
// bucket exists and is in the current format. An empty bucket counts as
// initialized; otherwise only the first entry is checked for a valid key and
// a valid 24-byte value.
func containerObjectCounterInitialized(tx *bbolt.Tx) bool {
	bkt := tx.Bucket(containerCounterBucketName)
	if bkt == nil {
		return false
	}

	firstKey, firstValue := bkt.Cursor().First()
	if firstKey == nil && firstValue == nil {
		return true
	}

	if _, keyErr := parseContainerCounterKey(firstKey); keyErr != nil {
		return false
	}
	_, valueErr := parseContainerCounterValue(firstValue)
	return valueErr == nil
}

// IsUserObject reports whether obj is counted as a user object. This is true
// for a regular object that either has no split ID, or has a parent ID and no
// children.
func IsUserObject(obj *objectSDK.Object) bool {
	if obj.Type() != objectSDK.TypeRegular {
		return false
	}
	if obj.SplitID() == nil {
		return true
	}
	_, hasParentID := obj.ParentID()
	return hasParentID && len(obj.Children()) == 0
}

// ZeroSizeContainers returns containers with size = 0.
//
// Containers are read in batches, and ctx cancellation is checked between
// batches. The mode read lock is taken per batch inside
// containerSizesNextBatch and must NOT also be held here. sync.RWMutex
// forbids recursive read locking: a writer waiting between the outer and
// inner RLock would deadlock both. Returns ErrDegradedMode if the metabase
// is (or becomes) unavailable.
func (db *DB) ZeroSizeContainers(ctx context.Context) ([]cid.ID, error) {
	var (
		startedAt = time.Now()
		success   = false
	)
	defer func() {
		db.metrics.AddMethodDuration("ZeroSizeContainers", time.Since(startedAt), success)
	}()

	ctx, span := tracing.StartSpanFromContext(ctx, "metabase.ZeroSizeContainers")
	defer span.End()

	var result []cid.ID
	lastKey := make([]byte, cidSize)

	for {
		select {
		case <-ctx.Done():
			return nil, ctx.Err()
		default:
		}

		completed, err := db.containerSizesNextBatch(lastKey, func(contID cid.ID, size uint64) {
			if size == 0 {
				result = append(result, contID)
			}
		})
		if err != nil {
			return nil, err
		}
		if completed {
			break
		}
	}

	success = true
	return result, nil
}

// containerSizesNextBatch reads up to batchSize entries of the container
// volume bucket located after lastKey and passes each (container ID, size)
// pair to f. lastKey is overwritten in place with the last visited key, so
// the next call resumes after it.
//
// It returns true when the bucket is exhausted or absent, and false when
// more entries may follow. The mode read lock is held only for this batch.
// Returns ErrDegradedMode when the metabase is unavailable; storage and
// decoding errors are wrapped with metaerr.
func (db *DB) containerSizesNextBatch(lastKey []byte, f func(cid.ID, uint64)) (bool, error) {
	db.modeMtx.RLock()
	defer db.modeMtx.RUnlock()

	if db.mode.NoMetabase() {
		return false, ErrDegradedMode
	}

	counter := 0
	const batchSize = 1000

	err := db.boltDB.View(func(tx *bbolt.Tx) error {
		b := tx.Bucket(containerVolumeBucketName)
		if b == nil {
			// Tx.Bucket returns nil for an absent bucket; calling Cursor on it
			// would dereference nil. Treat it as an empty bucket, as
			// containerCountersNextBatch does.
			return ErrInterruptIterator
		}
		c := b.Cursor()
		var key, value []byte
		for key, value = c.Seek(lastKey); key != nil; key, value = c.Next() {
			if bytes.Equal(lastKey, key) {
				// Seek landed on the key the previous batch stopped at.
				continue
			}
			copy(lastKey, key)

			size := parseContainerSize(value)
			var id cid.ID
			if err := id.Decode(key); err != nil {
				return err
			}
			f(id, size)

			counter++
			if counter == batchSize {
				break
			}
		}

		if counter < batchSize { // last batch
			return ErrInterruptIterator
		}
		return nil
	})
	if err != nil {
		if errors.Is(err, ErrInterruptIterator) {
			return true, nil
		}
		return false, metaerr.Wrap(err)
	}
	return false, nil
}

// DeleteContainerSize removes the stored size entry of container id.
// A metabase without the container volume bucket is treated as having
// nothing to delete. Returns ErrDegradedMode or ErrReadOnlyMode when the
// current mode does not allow writes; storage errors are wrapped with metaerr.
func (db *DB) DeleteContainerSize(ctx context.Context, id cid.ID) error {
	var (
		startedAt = time.Now()
		success   = false
	)
	defer func() {
		db.metrics.AddMethodDuration("DeleteContainerSize", time.Since(startedAt), success)
	}()

	_, span := tracing.StartSpanFromContext(ctx, "metabase.DeleteContainerSize",
		trace.WithAttributes(
			attribute.Stringer("container_id", id),
		))
	defer span.End()

	db.modeMtx.RLock()
	defer db.modeMtx.RUnlock()

	if db.mode.NoMetabase() {
		return ErrDegradedMode
	}

	if db.mode.ReadOnly() {
		return ErrReadOnlyMode
	}

	err := db.boltDB.Update(func(tx *bbolt.Tx) error {
		b := tx.Bucket(containerVolumeBucketName)
		if b == nil {
			// Tx.Bucket returns nil for an absent bucket; calling Delete on
			// it would dereference nil. There is nothing to remove.
			return nil
		}

		key := make([]byte, cidSize)
		id.Encode(key)
		return b.Delete(key)
	})
	success = err == nil
	return metaerr.Wrap(err)
}

// ZeroCountContainers returns containers with objects count = 0 in metabase.
//
// Containers are read in batches, and ctx cancellation is checked between
// batches. The mode read lock is taken per batch inside
// containerCountersNextBatch and must NOT also be held here. sync.RWMutex
// forbids recursive read locking: a writer waiting between the outer and
// inner RLock would deadlock both. Returns ErrDegradedMode if the metabase
// is (or becomes) unavailable.
func (db *DB) ZeroCountContainers(ctx context.Context) ([]cid.ID, error) {
	var (
		startedAt = time.Now()
		success   = false
	)
	defer func() {
		db.metrics.AddMethodDuration("ZeroCountContainers", time.Since(startedAt), success)
	}()

	ctx, span := tracing.StartSpanFromContext(ctx, "metabase.ZeroCountContainers")
	defer span.End()

	var result []cid.ID

	lastKey := make([]byte, cidSize)
	for {
		select {
		case <-ctx.Done():
			return nil, ctx.Err()
		default:
		}

		completed, err := db.containerCountersNextBatch(lastKey, func(id cid.ID, entity ObjectCounters) {
			if entity.IsZero() {
				result = append(result, id)
			}
		})
		if err != nil {
			// Storage errors are already metaerr-wrapped by the batch
			// function. ErrDegradedMode is passed through unwrapped, as the
			// previous up-front mode check returned it.
			return nil, err
		}
		if completed {
			break
		}
	}
	success = true
	return result, nil
}

// DeleteContainerCount removes the stored object counters of container id.
// A metabase without the container counter bucket is treated as having
// nothing to delete. Returns ErrDegradedMode or ErrReadOnlyMode when the
// current mode does not allow writes; storage errors are wrapped with metaerr.
func (db *DB) DeleteContainerCount(ctx context.Context, id cid.ID) error {
	var (
		startedAt = time.Now()
		success   = false
	)
	defer func() {
		db.metrics.AddMethodDuration("DeleteContainerCount", time.Since(startedAt), success)
	}()

	_, span := tracing.StartSpanFromContext(ctx, "metabase.DeleteContainerCount",
		trace.WithAttributes(
			attribute.Stringer("container_id", id),
		))
	defer span.End()

	db.modeMtx.RLock()
	defer db.modeMtx.RUnlock()

	if db.mode.NoMetabase() {
		return ErrDegradedMode
	}

	if db.mode.ReadOnly() {
		return ErrReadOnlyMode
	}

	err := db.boltDB.Update(func(tx *bbolt.Tx) error {
		b := tx.Bucket(containerCounterBucketName)
		if b == nil {
			// Tx.Bucket returns nil for an absent bucket; calling Delete on
			// it would dereference nil. There is nothing to remove.
			return nil
		}

		key := make([]byte, cidSize)
		id.Encode(key)
		return b.Delete(key)
	})
	success = err == nil
	return metaerr.Wrap(err)
}