package meta

import (
	"bytes"
	"context"
	"encoding/binary"
	"errors"
	"fmt"
	"time"

	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/local_object_storage/internal/metaerr"
	"git.frostfs.info/TrueCloudLab/frostfs-observability/tracing"
	cid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/container/id"
	oid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object/id"
	"go.etcd.io/bbolt"
)

// objectPhyCounterKey is the shardInfoBucket key holding the shard-wide
// physical object counter, encoded as a little-endian uint64.
var objectPhyCounterKey = []byte("phy_counter")

// objectLogicCounterKey is the shardInfoBucket key holding the shard-wide
// logical object counter, encoded as a little-endian uint64.
var objectLogicCounterKey = []byte("logic_counter")

// objectType selects which shard-wide counter is modified. The zero value is
// deliberately left unused so that an uninitialized objectType does not match
// any counter.
type objectType uint8

const (
	// phy selects the physical object counter.
	phy objectType = iota + 1
	// logical selects the logical object counter.
	logical
)

// ObjectCounters holds a pair of object counters kept by the metabase:
// phy counts physically stored objects, and logic counts those of them that
// are still available (neither GC-marked nor covered by a tombstone), as
// computed by syncCounter.
type ObjectCounters struct {
	logic uint64
	phy   uint64
}

// Logic reports the logical object counter value.
func (c ObjectCounters) Logic() uint64 {
	return c.logic
}

// Phy reports the physical object counter value.
func (c ObjectCounters) Phy() uint64 {
	return c.phy
}

// ObjectCounters reports the shard-wide physical and logical object counters
// stored in the metabase.
//
// It fails with ErrDegradedMode when the current mode has no metabase.
// Otherwise only errors that prevent reading the Bolt database are returned;
// a missing shard info bucket or a missing/malformed counter value is
// reported as zero.
func (db *DB) ObjectCounters() (cc ObjectCounters, err error) {
	db.modeMtx.RLock()
	defer db.modeMtx.RUnlock()

	if db.mode.NoMetabase() {
		return ObjectCounters{}, ErrDegradedMode
	}

	err = db.boltDB.View(func(tx *bbolt.Tx) error {
		info := tx.Bucket(shardInfoBucket)
		if info == nil {
			// No shard info bucket: leave both counters at zero.
			return nil
		}

		// readCounter decodes an 8-byte little-endian counter into dst;
		// values of any other length leave dst untouched.
		readCounter := func(key []byte, dst *uint64) {
			if raw := info.Get(key); len(raw) == 8 {
				*dst = binary.LittleEndian.Uint64(raw)
			}
		}
		readCounter(objectPhyCounterKey, &cc.phy)
		readCounter(objectLogicCounterKey, &cc.logic)

		return nil
	})

	return cc, metaerr.Wrap(err)
}

// ContainerCounters holds per-container object counters keyed by container
// ID: Logical maps a container to its logical object counter and Physical to
// its physical object counter. When filled by DB.ContainerCounters, a
// container whose counter is zero is absent from the corresponding map.
type ContainerCounters struct {
	Logical  map[cid.ID]uint64
	Physical map[cid.ID]uint64
}

// ContainerCounters collects the physical and logical object counters of
// every container tracked by the metabase.
//
// The number of containers is unbounded, so entries are read in batches and
// ctx is checked before each batch. On cancellation or a read failure, the
// counters gathered so far are returned together with the error. Only errors
// that prevent reading the counters from the Bolt database are reported.
//
// Both maps in the returned ContainerCounters are always non-nil.
func (db *DB) ContainerCounters(ctx context.Context) (ContainerCounters, error) {
	startedAt := time.Now()
	success := false
	defer func() {
		db.metrics.AddMethodDuration("ContainerCounters", time.Since(startedAt), success)
	}()

	ctx, span := tracing.StartSpanFromContext(ctx, "metabase.ContainerCounters")
	defer span.End()

	result := ContainerCounters{
		Physical: make(map[cid.ID]uint64),
		Logical:  make(map[cid.ID]uint64),
	}

	// Cursor position shared across batches; each batch advances it.
	lastKey := make([]byte, cidSize)

	for done := false; !done; {
		// ctx.Err is non-nil exactly when ctx.Done() has been closed.
		if err := ctx.Err(); err != nil {
			return result, err
		}

		var err error
		done, err = db.containerCountersNextBatch(lastKey, &result)
		if err != nil {
			return result, err
		}
	}

	success = true
	return result, nil
}

// containerCountersNextBatch reads up to batchSize entries of the container
// counter bucket, starting at the first key strictly greater than lastKey,
// and stores their non-zero phy/logic values into cc.Physical/cc.Logical.
//
// lastKey is both input and output: on return it holds the last key visited,
// so the caller passes the same slice again to continue. Callers pass a
// cidSize-byte slice.
//
// The mode lock is held only for the duration of a single batch.
//
// Returns true when iteration is complete (the bucket does not exist or fewer
// than batchSize entries were read) and false when more entries may remain.
// Returns ErrDegradedMode if the current mode has no metabase; other errors,
// including malformed keys or values, are wrapped with metaerr.Wrap.
//
// NOTE(review): since a key equal to lastKey is skipped and the first call
// starts from an all-zero lastKey, a container whose ID encodes to all zero
// bytes would never be reported — presumably unreachable in practice; confirm.
func (db *DB) containerCountersNextBatch(lastKey []byte, cc *ContainerCounters) (bool, error) {
	db.modeMtx.RLock()
	defer db.modeMtx.RUnlock()

	if db.mode.NoMetabase() {
		return false, ErrDegradedMode
	}

	counter := 0
	const batchSize = 1000

	err := db.boltDB.View(func(tx *bbolt.Tx) error {
		b := tx.Bucket(containerCounterBucketName)
		if b == nil {
			// No counters bucket: nothing to read, signal completion.
			return ErrInterruptIterator
		}
		c := b.Cursor()
		var key, value []byte
		// Seek positions the cursor at the first key >= lastKey.
		for key, value = c.Seek(lastKey); key != nil; key, value = c.Next() {
			// Skip the entry already processed by the previous batch.
			if bytes.Equal(lastKey, key) {
				continue
			}
			// Record progress so the next batch resumes after this key.
			copy(lastKey, key)

			cnrID, err := parseContainerCounterKey(key)
			if err != nil {
				return err
			}
			phy, logic, err := parseContainerCounterValue(value)
			if err != nil {
				return err
			}
			if phy > 0 {
				cc.Physical[cnrID] = phy
			}
			if logic > 0 {
				cc.Logical[cnrID] = logic
			}

			counter++
			// Batch is full; remaining entries are left for the next call.
			if counter == batchSize {
				break
			}
		}

		// Fewer than batchSize entries means the cursor hit the end of the
		// bucket, i.e. this was the last batch.
		if counter < batchSize { // last batch
			return ErrInterruptIterator
		}
		return nil
	})
	if err != nil {
		// ErrInterruptIterator is the internal "done" signal, not a failure.
		if errors.Is(err, ErrInterruptIterator) {
			return true, nil
		}
		return false, metaerr.Wrap(err)
	}
	return false, nil
}

// incCounters accounts for one newly stored object: it bumps the shard-wide
// physical and logical counters by one and then the counters of container
// cnrID. The first failing step aborts the sequence.
func (db *DB) incCounters(tx *bbolt.Tx, cnrID cid.ID) error {
	err := db.updateShardObjectCounter(tx, phy, 1, true)
	if err != nil {
		return fmt.Errorf("could not increase phy object counter: %w", err)
	}

	err = db.updateShardObjectCounter(tx, logical, 1, true)
	if err != nil {
		return fmt.Errorf("could not increase logical object counter: %w", err)
	}

	return db.incContainerObjectCounter(tx, cnrID)
}

// updateShardObjectCounter adjusts the shard-wide counter selected by typ by
// delta. The counter is increased when inc is true. Otherwise it is decreased,
// saturating at zero instead of wrapping around.
//
// It is a no-op when the shard info bucket does not exist. A missing or
// malformed stored value is treated as zero. It panics if typ is neither phy
// nor logical, which indicates a programming error.
func (db *DB) updateShardObjectCounter(tx *bbolt.Tx, typ objectType, delta uint64, inc bool) error {
	b := tx.Bucket(shardInfoBucket)
	if b == nil {
		return nil
	}

	var counterKey []byte
	switch typ {
	case phy:
		counterKey = objectPhyCounterKey
	case logical:
		counterKey = objectLogicCounterKey
	default:
		panic("unknown object type counter")
	}

	var counter uint64
	if data := b.Get(counterKey); len(data) == 8 {
		counter = binary.LittleEndian.Uint64(data)
	}

	// Share the saturating arithmetic with the per-container counters so both
	// kinds of counters behave identically on underflow.
	counter = nextValue(counter, delta, inc)

	newCounter := make([]byte, 8)
	binary.LittleEndian.PutUint64(newCounter, counter)

	return b.Put(counterKey, newCounter)
}

// updateContainerCounter applies the per-container deltas to the container
// counter bucket. Deltas are added when inc is true and subtracted (saturating
// at zero) otherwise; an entry whose counters both reach zero is removed.
// It is a no-op when the container counter bucket does not exist, and it
// stops at the first failing container.
func (db *DB) updateContainerCounter(tx *bbolt.Tx, delta map[cid.ID]ObjectCounters, inc bool) error {
	bkt := tx.Bucket(containerCounterBucketName)
	if bkt == nil {
		return nil
	}

	// A single key buffer is re-encoded for every container.
	cnrKey := make([]byte, cidSize)
	for id, counters := range delta {
		id.Encode(cnrKey)
		err := db.editContainerCounterValue(bkt, cnrKey, counters, inc)
		if err != nil {
			return err
		}
	}
	return nil
}

// editContainerCounterValue applies delta to the counters stored under key in
// bucket b, adding when inc is true and subtracting with saturation at zero
// otherwise. A missing entry starts from zero. When both resulting counters
// are zero the entry is deleted; otherwise the new value is stored. A stored
// value that cannot be parsed aborts the update with an error.
func (*DB) editContainerCounterValue(b *bbolt.Bucket, key []byte, delta ObjectCounters, inc bool) error {
	var current ObjectCounters
	if data := b.Get(key); len(data) > 0 {
		storedPhy, storedLogic, err := parseContainerCounterValue(data)
		if err != nil {
			return err
		}
		current = ObjectCounters{phy: storedPhy, logic: storedLogic}
	}

	next := ObjectCounters{
		phy:   nextValue(current.phy, delta.phy, inc),
		logic: nextValue(current.logic, delta.logic, inc),
	}

	if next.phy == 0 && next.logic == 0 {
		return b.Delete(key)
	}
	return b.Put(key, containerCounterValue(next.phy, next.logic))
}

// nextValue returns existed adjusted by delta. When inc is true it returns
// existed+delta, which wraps on uint64 overflow like ordinary Go arithmetic.
// Otherwise it returns existed-delta clamped at zero, so a counter never
// underflows.
func nextValue(existed, delta uint64, inc bool) uint64 {
	switch {
	case inc:
		return existed + delta
	case delta >= existed:
		return 0
	default:
		return existed - delta
	}
}

// incContainerObjectCounter bumps both the physical and the logical counters
// of container cnrID by one. It is a no-op when the container counter bucket
// does not exist.
func (db *DB) incContainerObjectCounter(tx *bbolt.Tx, cnrID cid.ID) error {
	bkt := tx.Bucket(containerCounterBucketName)
	if bkt == nil {
		return nil
	}

	one := ObjectCounters{phy: 1, logic: 1}
	cnrKey := make([]byte, cidSize)
	cnrID.Encode(cnrKey)

	return db.editContainerCounterValue(bkt, cnrKey, one, true)
}

// syncCounter updates object counters according to metabase state:
// it counts all the physically/logically stored objects using internal
// indexes. Tx MUST be writable.
//
// Does nothing if counters are not empty and force is false. If force is
// true, updates the counters anyway.
//
// When the counters are recomputed, previously stored per-container counters
// are discarded first. Otherwise containers that no longer hold any object
// would keep stale entries.
func syncCounter(tx *bbolt.Tx, force bool) error {
	shardInfoB, err := tx.CreateBucketIfNotExists(shardInfoBucket)
	if err != nil {
		return fmt.Errorf("could not get shard info bucket: %w", err)
	}
	shardObjectCounterInitialized := len(shardInfoB.Get(objectPhyCounterKey)) == 8 && len(shardInfoB.Get(objectLogicCounterKey)) == 8
	containerCounterInitialized := tx.Bucket(containerCounterBucketName) != nil
	if !force && shardObjectCounterInitialized && containerCounterInitialized {
		// the counters are already inited
		return nil
	}

	// Only containers that currently have objects are written below, so drop
	// the old per-container entries to avoid stale counters for containers
	// that became empty since they were last stored.
	if containerCounterInitialized {
		if err := tx.DeleteBucket(containerCounterBucketName); err != nil {
			return fmt.Errorf("could not reset container counter bucket: %w", err)
		}
	}

	containerCounterB, err := tx.CreateBucketIfNotExists(containerCounterBucketName)
	if err != nil {
		return fmt.Errorf("could not get container counter bucket: %w", err)
	}

	var addr oid.Address
	counters := make(map[cid.ID]ObjectCounters)

	graveyardBKT := tx.Bucket(graveyardBucketName)
	garbageBKT := tx.Bucket(garbageBucketName)
	key := make([]byte, addressKeySize)

	err = iteratePhyObjects(tx, func(cnr cid.ID, obj oid.ID) error {
		// A container seen for the first time starts from the zero value.
		v := counters[cnr]
		v.phy++

		addr.SetContainer(cnr)
		addr.SetObject(obj)

		// check if an object is available: not with GCMark
		// and not covered with a tombstone
		if inGraveyardWithKey(addressKey(addr, key), graveyardBKT, garbageBKT) == 0 {
			v.logic++
		}

		counters[cnr] = v
		return nil
	})
	if err != nil {
		return fmt.Errorf("could not iterate objects: %w", err)
	}

	return setObjectCounters(counters, shardInfoB, containerCounterB)
}

// setObjectCounters stores every per-container entry of counters into
// containerCounterB. It then stores the sums of their physical and logical
// counts as the shard-wide counters in shardInfoB. Existing entries for
// containers absent from counters are left untouched. The first failing
// write aborts with an error.
func setObjectCounters(counters map[cid.ID]ObjectCounters, shardInfoB, containerCounterB *bbolt.Bucket) error {
	var phyTotal, logicTotal uint64
	key := make([]byte, cidSize)
	for cnrID, count := range counters {
		phyTotal += count.phy
		logicTotal += count.logic

		cnrID.Encode(key)
		value := containerCounterValue(count.phy, count.logic)
		if err := containerCounterB.Put(key, value); err != nil {
			// The stored value holds both the phy and the logic counter.
			return fmt.Errorf("could not update container object counter: %w", err)
		}
	}

	phyData := make([]byte, 8)
	binary.LittleEndian.PutUint64(phyData, phyTotal)
	if err := shardInfoB.Put(objectPhyCounterKey, phyData); err != nil {
		return fmt.Errorf("could not update phy object counter: %w", err)
	}

	logicData := make([]byte, 8)
	binary.LittleEndian.PutUint64(logicData, logicTotal)
	if err := shardInfoB.Put(objectLogicCounterKey, logicData); err != nil {
		return fmt.Errorf("could not update logic object counter: %w", err)
	}

	return nil
}

func containerCounterValue(phy, logic uint64) []byte {
	res := make([]byte, 16)
	binary.LittleEndian.PutUint64(res, phy)
	binary.LittleEndian.PutUint64(res[8:], logic)
	return res
}

// parseContainerCounterKey decodes a container counter bucket key into a
// container ID. It fails if the key is not exactly cidSize bytes long or
// cannot be decoded.
func parseContainerCounterKey(buf []byte) (cid.ID, error) {
	var cnrID cid.ID
	if len(buf) != cidSize {
		return cid.ID{}, fmt.Errorf("invalid key length")
	}

	err := cnrID.Decode(buf)
	if err != nil {
		return cid.ID{}, fmt.Errorf("failed to decode container ID: %w", err)
	}
	return cnrID, nil
}

// parseContainerCounterValue decodes a 16-byte container counter value and
// returns its physical and logical counters, in that order. Any other length
// yields an error and zero counters.
func parseContainerCounterValue(buf []byte) (uint64, uint64, error) {
	if len(buf) == 16 {
		phy := binary.LittleEndian.Uint64(buf[:8])
		logic := binary.LittleEndian.Uint64(buf[8:])
		return phy, logic, nil
	}
	return 0, 0, fmt.Errorf("invalid value length")
}