package meta

import (
	"bytes"
	"context"
	"encoding/binary"
	"errors"
	"fmt"
	"os"
	"strconv"
	"sync/atomic"
	"time"

	objectV2 "git.frostfs.info/TrueCloudLab/frostfs-api-go/v2/object"
	cid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/container/id"
	oid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object/id"
	"go.etcd.io/bbolt"
	"golang.org/x/sync/errgroup"
)

// Tuning parameters of the metabase upgrade.
const (
	// upgradeLogFrequency is the number of processed objects between two
	// progress log lines while filling expiration epoch buckets.
	upgradeLogFrequency = 50000

	// upgradeWorkersCount is the number of goroutines writing expiration
	// epoch index entries.
	upgradeWorkersCount = 1000

	// compactMaxTxSize is the transaction size limit (256 MiB) passed to
	// bbolt.Compact.
	compactMaxTxSize = 256 * 1024 * 1024

	// upgradeTimeout is used as bbolt's Options.Timeout when opening the
	// metabase for an upgrade.
	upgradeTimeout = time.Second
)

// updates maps the schema version read from a metabase to the routine that
// brings it up to date. A version 3 metabase needs no changes; its entry only
// logs that fact.
var updates = map[uint64]func(ctx context.Context, db *bbolt.DB, log func(a ...any)) error{
	3: func(_ context.Context, _ *bbolt.DB, logf func(a ...any)) error {
		logf("metabase already upgraded")
		return nil
	},
	2: upgradeFromV2ToV3,
}

// Upgrade migrates the metabase stored at path to the current schema version.
//
// The metabase is opened with upgradeTimeout as the bbolt open timeout. Its
// stored version selects a routine from updates. An unknown version is an
// error. Before that routine runs, upgradeKey is written into the shard info
// bucket, and it is deleted only after the routine succeeds. A failed upgrade
// therefore leaves the key in place. When compact is true, the database is
// then rewritten into a compacted copy that replaces the original file.
// Progress messages go to log.
//
// The database is closed before Upgrade returns, on success and on every
// error path. Close errors are joined with the returned error.
func Upgrade(ctx context.Context, path string, compact bool, log func(a ...any)) (err error) {
	if _, err = os.Stat(path); err != nil {
		return fmt.Errorf("check metabase existence: %w", err)
	}
	opts := bbolt.DefaultOptions
	opts.Timeout = upgradeTimeout
	db, err := bbolt.Open(path, os.ModePerm, opts)
	if err != nil {
		return fmt.Errorf("open metabase: %w", err)
	}
	// Release the database (and its file lock) on every return path.
	// compactDB may already have closed db. bbolt's Close returns nil for an
	// already closed database, so a second call here is harmless.
	defer func() {
		if closeErr := db.Close(); closeErr != nil {
			err = errors.Join(err, closeErr)
		}
	}()

	// Named storedVersion to avoid shadowing the package-level version.
	var storedVersion uint64
	if err = db.View(func(tx *bbolt.Tx) error {
		var e error
		storedVersion, e = currentVersion(tx)
		return e
	}); err != nil {
		return err
	}
	updater, found := updates[storedVersion]
	if !found {
		return fmt.Errorf("unsupported version %d: no update available", storedVersion)
	}
	if err = db.Update(func(tx *bbolt.Tx) error {
		b := tx.Bucket(shardInfoBucket)
		return b.Put(upgradeKey, zeroValue)
	}); err != nil {
		return fmt.Errorf("set upgrade key: %w", err)
	}
	if err = updater(ctx, db, log); err != nil {
		return fmt.Errorf("update metabase schema: %w", err)
	}
	if err = db.Update(func(tx *bbolt.Tx) error {
		b := tx.Bucket(shardInfoBucket)
		return b.Delete(upgradeKey)
	}); err != nil {
		return fmt.Errorf("delete upgrade key: %w", err)
	}
	if compact {
		log("compacting metabase...")
		if err = compactDB(db); err != nil {
			return fmt.Errorf("compact metabase: %w", err)
		}
		log("metabase compacted")
	}
	return nil
}

// compactDB copies db into a new, compacted bbolt file and puts that file in
// place of the original.
//
// The copy is written to "<path>.<RFC3339 timestamp>" with the source file's
// mode and NoSync enabled, then synced explicitly and closed. After that, db
// is closed and the copy is renamed over the source path. On any failure the
// temporary file is removed and the underlying error is joined with the
// cleanup errors. db is left open if the failure happens before the copy is
// closed; otherwise it is already closed.
func compactDB(db *bbolt.DB) error {
	sourcePath := db.Path()
	tmpFileName := sourcePath + "." + time.Now().Format(time.RFC3339)
	f, err := os.Stat(sourcePath)
	if err != nil {
		return err
	}
	dst, err := bbolt.Open(tmpFileName, f.Mode(), &bbolt.Options{
		Timeout: 100 * time.Millisecond,
		NoSync:  true,
	})
	if err != nil {
		return fmt.Errorf("can't open new metabase to compact: %w", err)
	}
	if err := bbolt.Compact(dst, db, compactMaxTxSize); err != nil {
		return fmt.Errorf("compact metabase: %w", errors.Join(err, dst.Close(), os.Remove(tmpFileName)))
	}
	if err := dst.Sync(); err != nil {
		// Close dst before removing its file so the handle is not leaked.
		// Arguments are evaluated left to right, so Close runs before Remove.
		return fmt.Errorf("sync compacted metabase: %w", errors.Join(err, dst.Close(), os.Remove(tmpFileName)))
	}
	if err := dst.Close(); err != nil {
		return fmt.Errorf("close compacted metabase: %w", errors.Join(err, os.Remove(tmpFileName)))
	}
	if err := db.Close(); err != nil {
		return fmt.Errorf("close source metabase: %w", errors.Join(err, os.Remove(tmpFileName)))
	}
	if err := os.Rename(tmpFileName, sourcePath); err != nil {
		return fmt.Errorf("replace source metabase with compacted: %w", errors.Join(err, os.Remove(tmpFileName)))
	}
	return nil
}

// upgradeFromV2ToV3 migrates a version 2 metabase. It runs these steps in order:
//   - fill the expiration epoch indexes;
//   - drop the user attribute, owner ID and payload checksum indexes;
//   - store the package's current version.
//
// The first failing step aborts the migration and its error is returned as is.
func upgradeFromV2ToV3(ctx context.Context, db *bbolt.DB, log func(a ...any)) error {
	steps := []func(context.Context, *bbolt.DB, func(a ...any)) error{
		createExpirationEpochBuckets,
		dropUserAttributes,
		dropOwnerIDIndex,
		dropPayloadChecksumIndex,
	}
	for _, step := range steps {
		if err := step(ctx, db, log); err != nil {
			return err
		}
	}
	return db.Update(func(tx *bbolt.Tx) error {
		return updateVersion(tx, version)
	})
}

// objectIDToExpEpoch is one unit of work passed from the expiration epoch
// scanner to the index writers. It holds the container and object IDs of an
// object together with the expiration epoch parsed from its attribute.
type objectIDToExpEpoch struct {
	containerID     cid.ID
	objectID        oid.ID
	expirationEpoch uint64
}

// createExpirationEpochBuckets fills the expiration epoch indexes for every
// object that has the expiration epoch user attribute.
//
// It first makes sure expEpochToObjectBucketName exists. It then runs one
// producer goroutine (selectObjectsWithExpirationEpoch) that feeds
// upgradeWorkersCount writer goroutines over an unbuffered channel. For each
// object, a writer stores two entries in one db.Batch call:
//   - expirationEpochKey(epoch, container, object) -> zeroValue in
//     expEpochToObjectBucketName;
//   - objectKey(object) -> epoch (little-endian uint64) in the per-container
//     objectToExpirationEpochBucketName bucket.
//
// The first error returned by any goroutine cancels the others and is
// returned. Progress is logged every upgradeLogFrequency objects.
func createExpirationEpochBuckets(ctx context.Context, db *bbolt.DB, log func(a ...any)) error {
	log("filling expiration epoch buckets...")
	if err := db.Update(func(tx *bbolt.Tx) error {
		_, err := tx.CreateBucketIfNotExists(expEpochToObjectBucketName)
		return err
	}); err != nil {
		return err
	}
	objects := make(chan objectIDToExpEpoch)
	eg, ctx := errgroup.WithContext(ctx)
	eg.Go(func() error {
		return selectObjectsWithExpirationEpoch(ctx, db, objects)
	})
	var count atomic.Uint64
	for range upgradeWorkersCount {
		eg.Go(func() error {
			for {
				select {
				case <-ctx.Done():
					return ctx.Err()
				case obj, ok := <-objects:
					if !ok {
						// The producer closed the channel: nothing left to index.
						return nil
					}
					// NOTE(review): bbolt may invoke a Batch function more than
					// once, so these writes must be idempotent. putUniqueIndexItem
					// is assumed to be a plain put; confirm.
					if err := db.Batch(func(tx *bbolt.Tx) error {
						if err := putUniqueIndexItem(tx, namedBucketItem{
							name: expEpochToObjectBucketName,
							key:  expirationEpochKey(obj.expirationEpoch, obj.containerID, obj.objectID),
							val:  zeroValue,
						}); err != nil {
							return err
						}
						val := make([]byte, epochSize)
						binary.LittleEndian.PutUint64(val, obj.expirationEpoch)
						return putUniqueIndexItem(tx, namedBucketItem{
							name: objectToExpirationEpochBucketName(obj.containerID, make([]byte, bucketKeySize)),
							key:  objectKey(obj.objectID, make([]byte, objectKeySize)),
							val:  val,
						})
					}); err != nil {
						return err
					}
					if c := count.Add(1); c%upgradeLogFrequency == 0 {
						log("expiration epoch filled for", c, "objects...")
					}
				}
			}
		})
	}
	if err := eg.Wait(); err != nil {
		log("filling expiration epoch buckets completed with error:", err)
		return err
	}
	log("filling expiration epoch buckets completed successfully, total", count.Load(), "objects")
	return nil
}

// selectObjectsWithExpirationEpoch sends every object that carries the
// expiration epoch attribute to objects.
//
// The metabase is read in batches of up to 1000 objects, one read transaction
// per batch. A batch shorter than that ends the scan. objects is always closed
// on return, which tells the consumers there is nothing more to read. The
// scan stops with ctx.Err() if ctx is cancelled while sending.
func selectObjectsWithExpirationEpoch(ctx context.Context, db *bbolt.DB, objects chan objectIDToExpEpoch) error {
	defer close(objects)

	const batchSize = 1000
	iter := &objectsWithExpirationEpochBatchIterator{lastAttributeKey: usrAttrPrefix}

	for more := true; more; {
		if err := getNextObjectsWithExpirationEpochBatch(ctx, db, iter, batchSize); err != nil {
			return err
		}
		for i := range iter.items {
			select {
			case objects <- iter.items[i]:
			case <-ctx.Done():
				return ctx.Err()
			}
		}
		// Only a full batch can be followed by more items.
		more = len(iter.items) >= batchSize
		iter.items = nil
	}
	return nil
}

// usrAttrPrefix is the one-byte key prefix shared by all user attribute
// buckets.
var usrAttrPrefix = []byte{userAttributePrefix}

// errBatchSizeLimit aborts a read transaction once the current batch is full.
// getNextObjectsWithExpirationEpochBatch swallows it.
var errBatchSizeLimit = errors.New("batch size limit")

// objectsWithExpirationEpochBatchIterator holds the state of a batched scan
// over expiration epoch attribute buckets between read transactions:
//   - lastAttributeKey: top-level bucket name of the last collected object.
//     It starts as usrAttrPrefix so the first scan begins at the first user
//     attribute bucket.
//   - lastAttributeValue: name of the attribute value sub-bucket of that object.
//   - lastAttrKeyValueItem: key (object ID) of that object.
//   - items: objects collected in the current batch.
//
// All last* slices are copies, so they stay valid after the transaction ends.
type objectsWithExpirationEpochBatchIterator struct {
	lastAttributeKey     []byte
	lastAttributeValue   []byte
	lastAttrKeyValueItem []byte
	items                []objectIDToExpEpoch
}

// Layout of the user attribute index walked by this code:
//
//	{prefix}{containerID}{attributeKey}   <- bucket
//	  {attributeValue}                    <- bucket, expirationEpoch
//	    {objectID}: zeroValue             <- record
//
// getNextObjectsWithExpirationEpochBatch appends up to batchSize objects
// carrying the expiration epoch attribute to it.items. It resumes from the
// position stored in it. Reaching the batch limit is not reported as an
// error; any other failure, including context cancellation, is returned.
func getNextObjectsWithExpirationEpochBatch(ctx context.Context, db *bbolt.DB, it *objectsWithExpirationEpochBatchIterator, batchSize int) error {
	// The saved value/item positions apply only to the first expiration epoch
	// bucket visited in this call. Later buckets are scanned from the start.
	valueFrom, itemFrom := it.lastAttributeValue, it.lastAttrKeyValueItem

	err := db.View(func(tx *bbolt.Tx) error {
		cur := tx.Cursor()
		key, _ := cur.Seek(it.lastAttributeKey)
		for ; key != nil && bytes.HasPrefix(key, usrAttrPrefix); key, _ = cur.Next() {
			if ctxErr := ctx.Err(); ctxErr != nil {
				return ctxErr
			}
			if len(key) <= 1+cidSize || string(key[1+cidSize:]) != objectV2.SysAttributeExpEpoch {
				continue
			}
			var cnr cid.ID
			if decodeErr := cnr.Decode(key[1 : 1+cidSize]); decodeErr != nil {
				return fmt.Errorf("decode container id from user attribute bucket: %w", decodeErr)
			}
			if iterErr := iterateExpirationAttributeKeyBucket(ctx, tx.Bucket(key), it, batchSize, cnr, key, valueFrom, itemFrom); iterErr != nil {
				return iterErr
			}
			valueFrom, itemFrom = nil, nil
		}
		return nil
	})
	if errors.Is(err, errBatchSizeLimit) {
		return nil
	}
	return err
}

// iterateExpirationAttributeKeyBucket scans one expiration epoch attribute
// bucket b (named attrKey) of container containerID. Each object found is
// appended to it.items together with the epoch parsed from the name of its
// {attributeValue} sub-bucket.
//
// seekAttrValue is where the value cursor starts. seekAttrKVItem is where the
// item cursor starts, but only in the first sub-bucket whose items are
// iterated; it is cleared after that sub-bucket. The item equal to the
// iterator's recorded last position (attrKey, value, item) is skipped, so a
// resumed scan does not emit it twice. it.last* is updated after every
// collected object. Once it.items reaches batchSize, errBatchSizeLimit is
// returned to stop the enclosing read transaction.
//
// An attribute value that is not a decimal uint64, or an item key that does
// not decode as an object ID, aborts the scan with an error.
func iterateExpirationAttributeKeyBucket(ctx context.Context, b *bbolt.Bucket, it *objectsWithExpirationEpochBatchIterator, batchSize int, containerID cid.ID, attrKey, seekAttrValue, seekAttrKVItem []byte) error {
	attrValueC := b.Cursor()
	for attrValue, v := attrValueC.Seek(seekAttrValue); attrValue != nil; attrValue, v = attrValueC.Next() {
		select {
		case <-ctx.Done():
			return ctx.Err()
		default:
		}
		// bbolt cursors return a nil value for nested buckets.
		if v != nil {
			continue // need to iterate over buckets, not records
		}
		expirationEpoch, err := strconv.ParseUint(string(attrValue), 10, 64)
		if err != nil {
			return fmt.Errorf("could not parse expiration epoch: %w", err)
		}
		expirationEpochBucket := b.Bucket(attrValue)
		attrKeyValueC := expirationEpochBucket.Cursor()
		for attrKeyValueItem, v := attrKeyValueC.Seek(seekAttrKVItem); attrKeyValueItem != nil; attrKeyValueItem, v = attrKeyValueC.Next() {
			select {
			case <-ctx.Done():
				return ctx.Err()
			default:
			}
			if v == nil {
				continue // need to iterate over records, not buckets
			}
			// Skip the object already returned at the end of the previous batch.
			if bytes.Equal(it.lastAttributeKey, attrKey) && bytes.Equal(it.lastAttributeValue, attrValue) && bytes.Equal(it.lastAttrKeyValueItem, attrKeyValueItem) {
				continue
			}
			var objectID oid.ID
			if err := objectID.Decode(attrKeyValueItem); err != nil {
				return fmt.Errorf("decode object id from container '%s' expiration epoch %d: %w", containerID, expirationEpoch, err)
			}
			// Copy the cursor keys: bbolt only guarantees them within the
			// current transaction.
			it.lastAttributeKey = bytes.Clone(attrKey)
			it.lastAttributeValue = bytes.Clone(attrValue)
			it.lastAttrKeyValueItem = bytes.Clone(attrKeyValueItem)
			it.items = append(it.items, objectIDToExpEpoch{
				containerID:     containerID,
				objectID:        objectID,
				expirationEpoch: expirationEpoch,
			})
			if len(it.items) == batchSize {
				return errBatchSizeLimit
			}
		}
		// The saved item position only applies to the first iterated sub-bucket.
		seekAttrKVItem = nil
	}
	return nil
}

// dropUserAttributes deletes every user attribute bucket. Each log line it
// produces is prefixed with "user attributes:".
func dropUserAttributes(ctx context.Context, db *bbolt.DB, log func(a ...any)) error {
	prefixed := func(args ...any) {
		line := make([]any, 1, len(args)+1)
		line[0] = "user attributes:"
		log(append(line, args...)...)
	}
	bucketPrefix := []byte{userAttributePrefix}
	return dropBucketsByPrefix(ctx, db, bucketPrefix, prefixed)
}

// dropOwnerIDIndex deletes every owner ID index bucket. Each log line it
// produces is prefixed with "owner ID index:".
func dropOwnerIDIndex(ctx context.Context, db *bbolt.DB, log func(a ...any)) error {
	prefixed := func(args ...any) {
		line := make([]any, 1, len(args)+1)
		line[0] = "owner ID index:"
		log(append(line, args...)...)
	}
	bucketPrefix := []byte{ownerPrefix}
	return dropBucketsByPrefix(ctx, db, bucketPrefix, prefixed)
}

// dropPayloadChecksumIndex deletes every payload checksum index bucket. Each
// log line it produces is prefixed with "payload checksum:".
func dropPayloadChecksumIndex(ctx context.Context, db *bbolt.DB, log func(a ...any)) error {
	prefixed := func(args ...any) {
		line := make([]any, 1, len(args)+1)
		line[0] = "payload checksum:"
		log(append(line, args...)...)
	}
	bucketPrefix := []byte{payloadHashPrefix}
	return dropBucketsByPrefix(ctx, db, bucketPrefix, prefixed)
}

// dropBucketsByPrefix deletes every top-level bucket whose name starts with
// prefix. Each round reads up to 1000 matching names in a read transaction
// and deletes them in a separate write transaction. Rounds repeat until no
// matching bucket remains. ctx is checked before each round, and progress and
// failures are reported through log.
func dropBucketsByPrefix(ctx context.Context, db *bbolt.DB, prefix []byte, log func(a ...any)) error {
	log("deleting buckets...")
	const batch = 1000

	// collect returns copies of up to batch matching bucket names.
	collect := func() ([][]byte, error) {
		var names [][]byte
		err := db.View(func(tx *bbolt.Tx) error {
			cur := tx.Cursor()
			for name, _ := cur.Seek(prefix); len(names) < batch && name != nil && bytes.HasPrefix(name, prefix); name, _ = cur.Next() {
				names = append(names, bytes.Clone(name))
			}
			return nil
		})
		return names, err
	}
	// remove deletes the given buckets in one write transaction.
	remove := func(names [][]byte) error {
		return db.Update(func(tx *bbolt.Tx) error {
			for _, name := range names {
				if err := tx.DeleteBucket(name); err != nil {
					return err
				}
			}
			return nil
		})
	}
	fail := func(err error) error {
		log("deleting buckets completed with an error:", err)
		return err
	}

	var total uint64
	for {
		if ctxErr := ctx.Err(); ctxErr != nil {
			return ctxErr
		}
		names, err := collect()
		if err != nil {
			return fail(err)
		}
		if len(names) == 0 {
			log("deleting buckets completed successfully, deleted", total, "buckets")
			return nil
		}
		if err := remove(names); err != nil {
			return fail(err)
		}
		total += uint64(len(names))
		log("deleted", total, "buckets")
	}
}