package shard

import (
	"context"
	"sync"
	"time"

	"git.frostfs.info/TrueCloudLab/frostfs-node/internal/logs"
	meta "git.frostfs.info/TrueCloudLab/frostfs-node/pkg/local_object_storage/metabase"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/local_object_storage/shard/mode"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/util"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/util/logger"
	"git.frostfs.info/TrueCloudLab/frostfs-observability/tracing"
	objectSDK "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object"
	oid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object/id"
	"go.uber.org/zap"
	"golang.org/x/sync/errgroup"
)

// Lower bounds applied by getExpiredObjectsParameters to the configured
// expired-object collector settings.
const (
	// minExpiredWorkers is the smallest worker count used for
	// the expired objects/locks collection.
	minExpiredWorkers   = 2
	// minExpiredBatchSize is the smallest batch size used for
	// the expired objects/locks collection.
	minExpiredBatchSize = 1
)

// TombstoneSource is an interface that checks
// tombstone status in the FrostFS network.
type TombstoneSource interface {
	// IsTombstoneAvailable must report whether the tombstone with
	// the provided address is present in the FrostFS network as of
	// the passed epoch.
	IsTombstoneAvailable(ctx context.Context, addr oid.Address, epoch uint64) bool
}

// Event represents a class of external events delivered to the shard.
// The interface can only be implemented inside this package because
// its single method is unexported.
type Event interface {
	// typ returns the kind of the event; it is used as the key
	// for looking up the registered handlers.
	typ() eventType
}

// eventType enumerates the kinds of events known to the GC.
type eventType int

const (
	// The zero value is skipped so that it never denotes a valid event kind.
	_ eventType = iota
	// eventNewEpoch is the kind of the new-epoch event.
	eventNewEpoch
)

// newEpoch is the Event implementation carrying the number of a new epoch.
type newEpoch struct {
	// epoch is the epoch number passed to EventNewEpoch.
	epoch uint64
}

// typ implements Event: every newEpoch value is of the eventNewEpoch kind.
func (newEpoch) typ() eventType {
	return eventNewEpoch
}

// EventNewEpoch wraps the epoch number e into a new-epoch Event.
func EventNewEpoch(e uint64) Event {
	var ev newEpoch
	ev.epoch = e

	return ev
}

// eventHandler is a callback invoked by the GC for an incoming Event.
// The context it receives is cancelled before the next event of the same
// kind is dispatched.
type eventHandler func(context.Context, Event)

// eventHandlers groups the handlers registered for one event kind together
// with the state needed to cancel and await their previous run.
type eventHandlers struct {
	// prevGroup tracks the handlers submitted for the previous event;
	// it is waited on before the next event is dispatched.
	prevGroup sync.WaitGroup

	// cancelFunc cancels the context handed to the previously
	// submitted handlers.
	cancelFunc context.CancelFunc

	// handlers are the callbacks run for every event of this kind.
	handlers []eventHandler
}

// gcRunResult is the outcome of a single garbage removal run.
type gcRunResult struct {
	// success is true when the run finished without an error.
	success        bool
	// deleted is the number of objects removed during the run.
	deleted        uint64
	// failedToDelete is the number of collected objects that were not removed.
	failedToDelete uint64
}

// Object type labels passed to the GCMectrics methods.
const (
	objectTypeLock      = "lock"
	objectTypeTombstone = "tombstone"
	objectTypeRegular   = "regular"
)

// GCMectrics is the sink for the shard garbage collector metrics.
//
// NOTE(review): the name looks like a misspelling of "GCMetrics"; it is
// exported, so renaming it would break callers.
type GCMectrics interface {
	// SetShardID sets the shard identifier.
	// NOTE(review): presumably used to label the metrics; not called in this file.
	SetShardID(string)
	// AddRunDuration records how long one garbage removal run took
	// and whether it succeeded.
	AddRunDuration(d time.Duration, success bool)
	// AddDeletedCount records how many objects a removal run deleted
	// and how many it failed to delete.
	AddDeletedCount(deleted, failed uint64)
	// AddExpiredObjectCollectionDuration records how long the collection of
	// expired objects of the given type took and whether it succeeded.
	AddExpiredObjectCollectionDuration(d time.Duration, success bool, objectType string)
	// AddInhumedObjectCount records the number of objects of the given type
	// that were inhumed.
	AddInhumedObjectCount(count uint64, objectType string)
}

// noopGCMetrics is the metrics sink used when none is configured:
// every method accepts its arguments and does nothing.
type noopGCMetrics struct{}

// SetShardID ignores the shard identifier.
func (*noopGCMetrics) SetShardID(string) {}

// AddRunDuration ignores the removal run duration and status.
func (*noopGCMetrics) AddRunDuration(time.Duration, bool) {}

// AddDeletedCount ignores the deleted/failed counters.
func (*noopGCMetrics) AddDeletedCount(uint64, uint64) {}

// AddExpiredObjectCollectionDuration ignores the collection duration, status and object type.
func (*noopGCMetrics) AddExpiredObjectCollectionDuration(time.Duration, bool, string) {}

// AddInhumedObjectCount ignores the inhumed counter and object type.
func (*noopGCMetrics) AddInhumedObjectCount(uint64, string) {}

// gc is the shard garbage collector: it periodically runs the remover and
// dispatches incoming events to the registered handlers.
type gc struct {
	*gcCfg

	// onceStop makes sure stopChannel is closed only once.
	onceStop    sync.Once
	// stopChannel is closed by stop to terminate the background goroutines.
	stopChannel chan struct{}
	// wg tracks the goroutines started by init.
	wg          sync.WaitGroup

	// workerPool executes the event handlers; it stays nil when
	// no handlers are registered.
	workerPool util.WorkerPool

	// remover is called on every tick of the removal timer
	// (unless testHookRemover is set).
	remover func(context.Context) gcRunResult

	// eventChan is used only for listening for the new epoch event.
	// It is fine to leave it open: writers are expected to watch for
	// context cancellation while writing to it.
	eventChan     chan Event
	// mEventHandler maps an event kind to the handlers registered for it.
	mEventHandler map[eventType]*eventHandlers
}

// gcCfg holds the configurable part of the garbage collector.
type gcCfg struct {
	// removerInterval is the delay between the end of one remover run
	// and the start of the next one.
	removerInterval time.Duration

	// log is the logger used by the GC goroutines.
	log *logger.Logger

	// workerPoolInit builds the pool for event handlers; it receives
	// the total number of registered handlers.
	workerPoolInit func(int) util.WorkerPool

	// Requested worker count and batch size for collecting expired objects;
	// values below minExpiredWorkers / minExpiredBatchSize are raised to them.
	expiredCollectorWorkerCount int
	expiredCollectorBatchSize   int

	// metrics receives the GC measurements.
	metrics GCMectrics

	// testHookRemover, when non-nil, is called instead of the remover.
	testHookRemover func(ctx context.Context) gcRunResult
}

// defaultGCCfg returns the baseline GC configuration: a 10 second remover
// interval, the global zap logger, a worker pool constructor that yields
// no pool and a no-op metrics sink.
func defaultGCCfg() gcCfg {
	var cfg gcCfg

	cfg.removerInterval = 10 * time.Second
	cfg.log = logger.NewLoggerWrapper(zap.L())
	cfg.workerPoolInit = func(int) util.WorkerPool { return nil }
	cfg.metrics = &noopGCMetrics{}

	return cfg
}

// init prepares the worker pool for the registered event handlers and
// starts the two background goroutines: the periodic remover and the
// event listener. Both are tracked by gc.wg.
func (gc *gc) init(ctx context.Context) {
	var handlerCount int
	for _, group := range gc.mEventHandler {
		handlerCount += len(group.handlers)
	}

	// The pool is sized by the number of handlers, so there is
	// nothing to create when none are registered.
	if handlerCount > 0 {
		gc.workerPool = gc.workerPoolInit(handlerCount)
	}

	gc.wg.Add(2)

	go gc.tickRemover(ctx)
	go gc.listenEvents(ctx)
}

// listenEvents receives events from gc.eventChan and dispatches each of
// them to handleEvent. It returns when the stop channel is closed, the
// context is done or the event channel is closed.
func (gc *gc) listenEvents(ctx context.Context) {
	defer gc.wg.Done()

	for {
		var (
			ev   Event
			open bool
		)

		select {
		case <-gc.stopChannel:
			gc.log.Warn(ctx, logs.ShardStopEventListenerByClosedStopChannel)
			return
		case <-ctx.Done():
			gc.log.Warn(ctx, logs.ShardStopEventListenerByContext)
			return
		case ev, open = <-gc.eventChan:
		}

		if !open {
			gc.log.Warn(ctx, logs.ShardStopEventListenerByClosedEventChannel)
			return
		}

		gc.handleEvent(ctx, ev)
	}
}

// handleEvent dispatches event to every handler registered for its kind.
//
// Before submitting the handlers it cancels the context given to the
// previous run of the same kind and waits until that run has finished, so
// at most one generation of handlers per event kind is active at a time.
// Events without registered handlers are ignored.
//
// Handlers are executed on gc.workerPool with a context derived from ctx.
// If ctx is done while handlers are being submitted, the remaining ones
// are skipped.
//
// NOTE(review): assumes v.cancelFunc is never nil, i.e. it is presumably
// initialised where the handlers are registered — confirm.
func (gc *gc) handleEvent(ctx context.Context, event Event) {
	v, ok := gc.mEventHandler[event.typ()]
	if !ok {
		return
	}

	v.cancelFunc()
	v.prevGroup.Wait()

	var runCtx context.Context
	runCtx, v.cancelFunc = context.WithCancel(ctx)

	for i := range v.handlers {
		select {
		case <-ctx.Done():
			return
		default:
		}
		h := v.handlers[i]

		// Account for each handler right before it is submitted. Adding
		// len(v.handlers) up front would leave the counter above zero on
		// the early return above, and the next v.prevGroup.Wait() would
		// block forever.
		v.prevGroup.Add(1)

		err := gc.workerPool.Submit(func() {
			defer v.prevGroup.Done()
			h(runCtx, event)
		})
		if err != nil {
			gc.log.Warn(ctx, logs.ShardCouldNotSubmitGCJobToWorkerPool,
				zap.String("error", err.Error()),
			)

			// The job will never run, so release its slot here.
			v.prevGroup.Done()
		}
	}
}

// releaseResources releases the worker pool, if one was created, and
// logs that the GC has stopped.
func (gc *gc) releaseResources() {
	if pool := gc.workerPool; pool != nil {
		pool.Release()
	}

	// gc.eventChan is deliberately not closed here: it could be closed
	// before the writers have stopped sending to it.
	// Leaving it open is fine.

	gc.log.Debug(context.Background(), logs.ShardGCIsStopped)
}

// tickRemover runs the remover every gc.removerInterval until the context
// is done or the stop channel is closed; on exit it releases the GC
// resources. The interval is counted from the end of the previous run.
// Duration, status and deletion counters of every run go to gc.metrics.
func (gc *gc) tickRemover(ctx context.Context) {
	defer gc.wg.Done()

	timer := time.NewTimer(gc.removerInterval)
	defer timer.Stop()

	for {
		select {
		case <-ctx.Done():
			// The context may be cancelled before the shards start closing;
			// garbage collection should stop in that case as well.
			gc.releaseResources()
			return
		case <-gc.stopChannel:
			gc.releaseResources()
			return
		case <-timer.C:
		}

		startedAt := time.Now()

		// The test hook, when installed, takes precedence over the remover.
		run := gc.remover
		if gc.testHookRemover != nil {
			run = gc.testHookRemover
		}

		result := run(ctx)
		timer.Reset(gc.removerInterval)

		gc.metrics.AddRunDuration(time.Since(startedAt), result.success)
		gc.metrics.AddDeletedCount(result.deleted, result.failedToDelete)
	}
}

// stop signals the GC goroutines to terminate and blocks until they exit.
// It may be called more than once: the stop channel is closed only on the
// first call.
func (gc *gc) stop() {
	gc.onceStop.Do(func() { close(gc.stopChannel) })

	gc.log.Info(context.Background(), logs.ShardWaitingForGCWorkersToStop)

	gc.wg.Wait()
}

// removeGarbage collects up to s.rmBatchSize GC-marked objects from the
// metabase and deletes them.
//
// It does nothing (and reports an unsuccessful run) if a mode change has
// been requested or the shard is not in "read-write" mode. The cancel
// function of the run context is stored in s.gcCancel before those checks.
func (s *Shard) removeGarbage(pctx context.Context) (result gcRunResult) {
	ctx, cancel := context.WithCancel(pctx)
	defer cancel()

	s.gcCancel.Store(cancel)
	if s.setModeRequested.Load() {
		return result
	}

	s.m.RLock()
	defer s.m.RUnlock()

	if s.info.Mode != mode.ReadWrite {
		return result
	}

	s.log.Debug(ctx, logs.ShardGCRemoveGarbageStarted)
	defer s.log.Debug(ctx, logs.ShardGCRemoveGarbageCompleted)

	garbage := make([]oid.Address, 0, s.rmBatchSize)

	// collect accumulates garbage addresses and interrupts the iteration
	// once a full batch has been gathered or the context is done.
	collect := func(g meta.GarbageObject) error {
		if err := ctx.Err(); err != nil {
			return err
		}

		garbage = append(garbage, g.Address())
		if len(garbage) == s.rmBatchSize {
			return meta.ErrInterruptIterator
		}

		return nil
	}

	var iterPrm meta.GarbageIterationPrm
	iterPrm.SetHandler(collect)

	// walk over the GC-marked objects of the metabase
	// (no more than s.rmBatchSize of them)
	if err := s.metaBase.IterateOverGarbage(ctx, iterPrm); err != nil {
		s.log.Warn(ctx, logs.ShardIteratorOverMetabaseGraveyardFailed,
			zap.String("error", err.Error()),
		)

		return result
	}

	if len(garbage) == 0 {
		result.success = true
		return result
	}

	var deletePrm DeletePrm
	deletePrm.SetAddresses(garbage...)

	// remove everything that has been accumulated
	res, err := s.delete(ctx, deletePrm, true)

	result.deleted = res.deleted
	result.failedToDelete = uint64(len(garbage)) - res.deleted
	result.success = err == nil

	if err != nil {
		s.log.Warn(ctx, logs.ShardCouldNotDeleteTheObjects,
			zap.String("error", err.Error()),
		)
	}

	return result
}

// getExpiredObjectsParameters returns the worker count and the batch size
// for collecting expired objects: the configured values, raised to
// minExpiredWorkers and minExpiredBatchSize when they are smaller.
func (s *Shard) getExpiredObjectsParameters() (workerCount, batchSize int) {
	cfg := s.gc.gcCfg

	return max(minExpiredWorkers, cfg.expiredCollectorWorkerCount),
		max(minExpiredBatchSize, cfg.expiredCollectorBatchSize)
}

// collectExpiredObjects iterates over the objects expired as of the event's
// epoch and hands all of them except tombstones and locks to
// handleExpiredObjects in batches, processed concurrently by a limited
// number of workers. The duration and the status of the run are reported
// to the GC metrics.
//
// e must be a newEpoch event.
func (s *Shard) collectExpiredObjects(ctx context.Context, e Event) {
	var err error
	startedAt := time.Now()

	defer func() {
		s.gc.metrics.AddExpiredObjectCollectionDuration(time.Since(startedAt), err == nil, objectTypeRegular)
	}()

	epoch := e.(newEpoch).epoch
	epochField := zap.Uint64("epoch", epoch)

	s.log.Debug(ctx, logs.ShardGCCollectingExpiredObjectsStarted, epochField)
	defer s.log.Debug(ctx, logs.ShardGCCollectingExpiredObjectsCompleted, epochField)

	workersCount, batchSize := s.getExpiredObjectsParameters()

	errGroup, egCtx := errgroup.WithContext(ctx)
	errGroup.SetLimit(workersCount)

	// submit schedules the processing of one complete batch.
	submit := func(expired []oid.Address) {
		errGroup.Go(func() error {
			s.handleExpiredObjects(egCtx, expired)
			return egCtx.Err()
		})
	}

	// The producer occupies one slot of the group and feeds the others.
	errGroup.Go(func() error {
		batch := make([]oid.Address, 0, batchSize)

		expErr := s.getExpiredObjects(egCtx, epoch, func(o *meta.ExpiredObject) {
			switch o.Type() {
			case objectSDK.TypeTombstone, objectSDK.TypeLock:
				// these types have dedicated collectors
				return
			}

			batch = append(batch, o.Address())
			if len(batch) != batchSize {
				return
			}

			submit(batch)
			batch = make([]oid.Address, 0, batchSize)
		})
		if expErr != nil {
			return expErr
		}

		// flush the incomplete tail
		if len(batch) > 0 {
			submit(batch)
		}

		return nil
	})

	if err = errGroup.Wait(); err != nil {
		s.log.Warn(ctx, logs.ShardIteratorOverExpiredObjectsFailed, zap.String("error", err.Error()))
	}
}

// handleExpiredObjects marks the given expired objects, together with the
// children the metabase returns for them, as garbage and updates the shard
// counters, container sizes and GC metrics accordingly.
//
// It is a no-op when ctx is already done or the shard mode has no metabase.
// Failures are logged and not returned.
func (s *Shard) handleExpiredObjects(ctx context.Context, expired []oid.Address) {
	if ctx.Err() != nil {
		return
	}

	s.m.RLock()
	defer s.m.RUnlock()

	if s.info.Mode.NoMetabase() {
		return
	}

	withLinked, err := s.getExpiredWithLinked(ctx, expired)
	if err != nil {
		s.log.Warn(ctx, logs.ShardGCFailedToGetExpiredWithLinked, zap.Error(err))
		return
	}

	var inhumePrm meta.InhumePrm

	inhumePrm.SetAddresses(withLinked...)
	inhumePrm.SetGCMark()

	// put the GC mark on everything collected
	res, err := s.metaBase.Inhume(ctx, inhumePrm)
	if err != nil {
		s.log.Warn(ctx, logs.ShardCouldNotInhumeTheObjects,
			zap.String("error", err.Error()),
		)

		return
	}

	logicInhumed := res.LogicInhumed()

	s.gc.metrics.AddInhumedObjectCount(logicInhumed, objectTypeRegular)
	s.decObjectCounterBy(logical, logicInhumed)
	s.decObjectCounterBy(user, res.UserInhumed())
	s.decContainerObjectCounter(res.InhumedByCnrID())

	// shrink the container sizes by the sizes of the inhumed objects
	for i := 0; i < res.GetDeletionInfoLength(); i++ {
		delInfo := res.GetDeletionInfoByIndex(i)
		s.addToContainerSize(delInfo.CID.EncodeToString(), -int64(delInfo.Size))
	}
}

// getExpiredWithLinked asks the metabase for the children of the source
// addresses and returns a flat list with every parent followed by its
// children. The order of the parents is not defined.
func (s *Shard) getExpiredWithLinked(ctx context.Context, source []oid.Address) ([]oid.Address, error) {
	parentToChildren, err := s.metaBase.GetChildren(ctx, source)
	if err != nil {
		return nil, err
	}

	result := make([]oid.Address, 0, len(source))
	for parent, children := range parentToChildren {
		result = append(append(result, parent), children...)
	}

	return result, nil
}

// collectExpiredTombstones walks the metabase graveyard in batches of
// tssDeleteBatch records and passes the records whose tombstone is no
// longer available (according to s.tsSource) to s.expiredTombstonesCallback.
//
// The shard lock is held for reading only while a batch is read from the
// metabase; the availability checks and the callback run without it.
// The walk ends when the graveyard is exhausted, when the iteration fails
// or when the shard mode has no metabase. On exit the duration of the run
// and whether the last iteration error was nil are reported to the GC
// metrics.
//
// e must be a newEpoch event; the type assertion panics otherwise.
func (s *Shard) collectExpiredTombstones(ctx context.Context, e Event) {
	var err error
	startedAt := time.Now()

	defer func() {
		s.gc.metrics.AddExpiredObjectCollectionDuration(time.Since(startedAt), err == nil, objectTypeTombstone)
	}()

	epoch := e.(newEpoch).epoch
	log := s.log.With(zap.Uint64("epoch", epoch))

	log.Debug(ctx, logs.ShardStartedExpiredTombstonesHandling)
	defer log.Debug(ctx, logs.ShardFinishedExpiredTombstonesHandling)

	const tssDeleteBatch = 50
	// tss holds the current batch of graveyard records,
	// tssExp the part of it whose tombstones are gone.
	// Both buffers are reused between iterations.
	tss := make([]meta.TombstonedObject, 0, tssDeleteBatch)
	tssExp := make([]meta.TombstonedObject, 0, tssDeleteBatch)

	var iterPrm meta.GraveyardIterationPrm
	iterPrm.SetHandler(func(deletedObject meta.TombstonedObject) error {
		tss = append(tss, deletedObject)

		// stop the iteration as soon as the batch is full
		if len(tss) == tssDeleteBatch {
			return meta.ErrInterruptIterator
		}

		return nil
	})

	for {
		log.Debug(ctx, logs.ShardIteratingTombstones)

		s.m.RLock()

		if s.info.Mode.NoMetabase() {
			s.log.Debug(ctx, logs.ShardShardIsInADegradedModeSkipCollectingExpiredTombstones)
			s.m.RUnlock()

			return
		}

		err = s.metaBase.IterateOverGraveyard(ctx, iterPrm)
		if err != nil {
			log.Error(ctx, logs.ShardIteratorOverGraveyardFailed, zap.Error(err))
			s.m.RUnlock()

			return
		}

		s.m.RUnlock()

		// an empty batch means the graveyard has been read to the end
		tssLen := len(tss)
		if tssLen == 0 {
			break
		}

		for _, ts := range tss {
			if !s.tsSource.IsTombstoneAvailable(ctx, ts.Tombstone(), epoch) {
				tssExp = append(tssExp, ts)
			}
		}

		log.Debug(ctx, logs.ShardHandlingExpiredTombstonesBatch, zap.Int("number", len(tssExp)))
		if len(tssExp) > 0 {
			s.expiredTombstonesCallback(ctx, tssExp)
		}

		// continue from the last record of this batch on the next iteration
		iterPrm.SetOffset(tss[tssLen-1].Address())
		tss = tss[:0]
		tssExp = tssExp[:0]
	}
}

// collectExpiredLocks iterates over the objects expired as of the event's
// epoch and hands the lock objects among them to s.expiredLocksCallback in
// batches, processed concurrently by a limited number of workers. The
// duration and the status of the run are reported to the GC metrics.
//
// e must be a newEpoch event.
func (s *Shard) collectExpiredLocks(ctx context.Context, e Event) {
	var err error
	startedAt := time.Now()

	defer func() {
		s.gc.metrics.AddExpiredObjectCollectionDuration(time.Since(startedAt), err == nil, objectTypeLock)
	}()

	epoch := e.(newEpoch).epoch
	epochField := zap.Uint64("epoch", epoch)

	s.log.Debug(ctx, logs.ShardGCCollectingExpiredLocksStarted, epochField)
	defer s.log.Debug(ctx, logs.ShardGCCollectingExpiredLocksCompleted, epochField)

	workersCount, batchSize := s.getExpiredObjectsParameters()

	errGroup, egCtx := errgroup.WithContext(ctx)
	errGroup.SetLimit(workersCount)

	// submit schedules the processing of one complete batch of lockers.
	submit := func(expired []oid.Address) {
		errGroup.Go(func() error {
			s.expiredLocksCallback(egCtx, epoch, expired)
			return egCtx.Err()
		})
	}

	// The producer occupies one slot of the group and feeds the others.
	errGroup.Go(func() error {
		batch := make([]oid.Address, 0, batchSize)

		expErr := s.getExpiredObjects(egCtx, epoch, func(o *meta.ExpiredObject) {
			if o.Type() != objectSDK.TypeLock {
				return
			}

			batch = append(batch, o.Address())
			if len(batch) != batchSize {
				return
			}

			submit(batch)
			batch = make([]oid.Address, 0, batchSize)
		})
		if expErr != nil {
			return expErr
		}

		// flush the incomplete tail
		if len(batch) > 0 {
			submit(batch)
		}

		return nil
	})

	if err = errGroup.Wait(); err != nil {
		s.log.Warn(ctx, logs.ShardIteratorOverExpiredLocksFailed, zap.String("error", err.Error()))
	}
}

// getExpiredObjects calls onExpiredFound for every object the metabase
// reports as expired at the given epoch, holding the shard lock for reading
// for the whole iteration.
//
// It returns ErrDegradedMode when the shard mode has no metabase, the
// iteration error if there is one, and otherwise ctx.Err(), so an
// iteration interrupted by cancellation is reported to the caller.
func (s *Shard) getExpiredObjects(ctx context.Context, epoch uint64, onExpiredFound func(*meta.ExpiredObject)) error {
	s.m.RLock()
	defer s.m.RUnlock()

	if s.info.Mode.NoMetabase() {
		return ErrDegradedMode
	}

	visit := func(expiredObject *meta.ExpiredObject) error {
		// stop visiting as soon as the context is done
		if ctx.Err() != nil {
			return meta.ErrInterruptIterator
		}

		onExpiredFound(expiredObject)

		return nil
	}

	if err := s.metaBase.IterateExpired(ctx, epoch, visit); err != nil {
		return err
	}

	return ctx.Err()
}

// selectExpired returns the result of the metabase expiration filter for
// the given addresses and epoch. It holds the shard lock for reading and
// returns ErrDegradedMode when the shard mode has no metabase.
func (s *Shard) selectExpired(ctx context.Context, epoch uint64, addresses []oid.Address) ([]oid.Address, error) {
	s.m.RLock()
	defer s.m.RUnlock()

	if !s.info.Mode.NoMetabase() {
		return s.metaBase.FilterExpired(ctx, epoch, addresses)
	}

	return nil, ErrDegradedMode
}

// HandleExpiredTombstones marks the tombstones of the given graveyard
// records as garbage, updates the shard counters and metrics, and then
// drops the records from the graveyard.
//
// It is a no-op when the shard mode has no metabase. Failures are logged;
// if marking fails, the graveyard records are left in place.
//
// Does not modify tss.
func (s *Shard) HandleExpiredTombstones(ctx context.Context, tss []meta.TombstonedObject) {
	if s.GetMode().NoMetabase() {
		return
	}

	// Gather the addresses of the tombstones themselves.
	tsAddrs := make([]oid.Address, len(tss))
	for i := range tss {
		tsAddrs[i] = tss[i].Tombstone()
	}

	var pInhume meta.InhumePrm

	pInhume.SetGCMark()
	pInhume.SetAddresses(tsAddrs...)

	// put the GC mark on the tombstones
	res, err := s.metaBase.Inhume(ctx, pInhume)
	if err != nil {
		s.log.Warn(ctx, logs.ShardCouldNotMarkTombstonesAsGarbage,
			zap.String("error", err.Error()),
		)

		return
	}

	logicInhumed := res.LogicInhumed()

	s.gc.metrics.AddInhumedObjectCount(logicInhumed, objectTypeTombstone)
	s.decObjectCounterBy(logical, logicInhumed)
	s.decObjectCounterBy(user, res.UserInhumed())
	s.decContainerObjectCounter(res.InhumedByCnrID())

	// shrink the container sizes by the sizes of the inhumed objects
	for i := 0; i < res.GetDeletionInfoLength(); i++ {
		delInfo := res.GetDeletionInfoByIndex(i)
		s.addToContainerSize(delInfo.CID.EncodeToString(), -int64(delInfo.Size))
	}

	// the processed records are no longer needed in the graveyard
	if err = s.metaBase.DropGraves(ctx, tss); err != nil {
		s.log.Warn(ctx, logs.ShardCouldNotDropExpiredGraveRecords, zap.Error(err))
	}
}

// HandleExpiredLocks unlocks all objects which were locked by lockers.
// If that succeeds, it marks the lockers themselves as garbage, updates
// the shard counters and metrics, and finally inhumes those of the
// unlocked objects that are expired at the given epoch.
//
// It is a no-op when the shard mode has no metabase. Failures are logged
// and abort the remaining steps.
func (s *Shard) HandleExpiredLocks(ctx context.Context, epoch uint64, lockers []oid.Address) {
	if s.GetMode().NoMetabase() {
		return
	}

	unlocked, err := s.metaBase.FreeLockedBy(lockers)
	if err != nil {
		s.log.Warn(ctx, logs.ShardFailureToUnlockObjects,
			zap.String("error", err.Error()),
		)

		return
	}

	var pInhume meta.InhumePrm

	pInhume.SetAddresses(lockers...)
	pInhume.SetForceGCMark()

	res, err := s.metaBase.Inhume(ctx, pInhume)
	if err != nil {
		s.log.Warn(ctx, logs.ShardFailureToMarkLockersAsGarbage,
			zap.String("error", err.Error()),
		)

		return
	}

	logicInhumed := res.LogicInhumed()

	s.gc.metrics.AddInhumedObjectCount(logicInhumed, objectTypeLock)
	s.decObjectCounterBy(logical, logicInhumed)
	s.decObjectCounterBy(user, res.UserInhumed())
	s.decContainerObjectCounter(res.InhumedByCnrID())

	// shrink the container sizes by the sizes of the inhumed objects
	for i := 0; i < res.GetDeletionInfoLength(); i++ {
		delInfo := res.GetDeletionInfoByIndex(i)
		s.addToContainerSize(delInfo.CID.EncodeToString(), -int64(delInfo.Size))
	}

	s.inhumeUnlockedIfExpired(ctx, epoch, unlocked)
}

// inhumeUnlockedIfExpired selects those of the unlocked objects that are
// expired at the given epoch and passes them to handleExpiredObjects.
// A selection failure is logged; nothing is done when no object is expired.
func (s *Shard) inhumeUnlockedIfExpired(ctx context.Context, epoch uint64, unlocked []oid.Address) {
	expiredUnlocked, err := s.selectExpired(ctx, epoch, unlocked)

	switch {
	case err != nil:
		s.log.Warn(ctx, logs.ShardFailureToGetExpiredUnlockedObjects, zap.Error(err))
	case len(expiredUnlocked) > 0:
		s.handleExpiredObjects(ctx, expiredUnlocked)
	}
}

// HandleDeletedLocks unlocks all objects which were locked by lockers.
// It is a no-op when the shard mode has no metabase; an unlock failure
// is only logged.
func (s *Shard) HandleDeletedLocks(lockers []oid.Address) {
	if s.GetMode().NoMetabase() {
		return
	}

	if _, err := s.metaBase.FreeLockedBy(lockers); err != nil {
		s.log.Warn(context.Background(), logs.ShardFailureToUnlockObjects,
			zap.String("error", err.Error()),
		)
	}
}

// NotificationChannel returns the send-only side of the channel
// through which events are delivered to the shard's GC.
func (s *Shard) NotificationChannel() chan<- Event {
	var events chan<- Event = s.gc.eventChan

	return events
}

// collectExpiredMetrics runs the collection of the zero-size and zero-count
// container metrics inside a tracing span named after the function.
//
// e must be a newEpoch event.
func (s *Shard) collectExpiredMetrics(ctx context.Context, e Event) {
	ctx, span := tracing.StartSpanFromContext(ctx, "shard.collectExpiredMetrics")
	defer span.End()

	epoch := e.(newEpoch).epoch
	epochField := zap.Uint64("epoch", epoch)

	s.log.Debug(ctx, logs.ShardGCCollectingExpiredMetricsStarted, epochField)
	defer s.log.Debug(ctx, logs.ShardGCCollectingExpiredMetricsCompleted, epochField)

	s.collectExpiredContainerSizeMetrics(ctx, epoch)
	s.collectExpiredContainerCountMetrics(ctx, epoch)
}

// collectExpiredContainerSizeMetrics fetches the containers the metabase
// reports as zero-size and, if there are any, passes them to
// s.zeroSizeContainersCallback. A fetch failure is logged with the epoch.
func (s *Shard) collectExpiredContainerSizeMetrics(ctx context.Context, epoch uint64) {
	ids, err := s.metaBase.ZeroSizeContainers(ctx)

	switch {
	case err != nil:
		s.log.Warn(ctx, logs.ShardGCFailedToCollectZeroSizeContainers, zap.Uint64("epoch", epoch), zap.Error(err))
	case len(ids) > 0:
		s.zeroSizeContainersCallback(ctx, ids)
	}
}

// collectExpiredContainerCountMetrics fetches the containers the metabase
// reports as zero-count and, if there are any, passes them to
// s.zeroCountContainersCallback. A fetch failure is logged with the epoch.
func (s *Shard) collectExpiredContainerCountMetrics(ctx context.Context, epoch uint64) {
	ids, err := s.metaBase.ZeroCountContainers(ctx)
	if err != nil {
		s.log.Warn(ctx, logs.ShardGCFailedToCollectZeroCountContainers, zap.Uint64("epoch", epoch), zap.Error(err))
		return
	}

	if len(ids) != 0 {
		s.zeroCountContainersCallback(ctx, ids)
	}
}