package qos

import (
	"context"
	"errors"
	"fmt"
	"sync"
	"sync/atomic"
	"time"

	"git.frostfs.info/TrueCloudLab/frostfs-node/cmd/frostfs-node/config/engine/shard/limits"
	"git.frostfs.info/TrueCloudLab/frostfs-qos/scheduling"
	"git.frostfs.info/TrueCloudLab/frostfs-qos/tagging"
	apistatus "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/client/status"
)

// Package-level tuning constants used when building limiters.
const (
	// defaultIdleTimeout is a zero-length idle timeout.
	defaultIdleTimeout = time.Duration(0)
	// defaultShare is the mClock share every known IO tag starts with before
	// per-tag configuration is applied.
	defaultShare = float64(1.0)
	// minusOne is the largest uint64 value (all bits set), i.e. -1 viewed as unsigned.
	minusOne = ^uint64(0)

	// defaultMetricsCollectTimeout is the period at which per-tag counters are
	// pushed to the metrics sink.
	defaultMetricsCollectTimeout = 5 * time.Second
)

// ReleaseFunc is handed out by a successful Limiter request; the caller
// invokes it once the admitted operation has finished.
type ReleaseFunc scheduling.ReleaseFunc

// Limiter admits read and write operations according to a QoS policy.
type Limiter interface {
	// ReadRequest asks for permission to run a read operation described by ctx.
	// On success the returned ReleaseFunc must be called when the operation ends.
	ReadRequest(ctx context.Context) (ReleaseFunc, error)
	// WriteRequest asks for permission to run a write operation described by ctx.
	// On success the returned ReleaseFunc must be called when the operation ends.
	WriteRequest(ctx context.Context) (ReleaseFunc, error)
	// SetParentID sets the identifier of the owner (e.g. a shard) of the limiter.
	SetParentID(parentID string)
	// SetMetrics installs the sink that receives the limiter's metrics.
	SetMetrics(m Metrics)
	// Close releases the limiter's resources.
	Close()
}

// scheduler is the contract shared by the mClock and semaphore schedulers
// that createScheduler can return.
type scheduler interface {
	// RequestArrival registers a request carrying the given IO tag and, on
	// success, returns the function that frees the admitted slot.
	RequestArrival(ctx context.Context, tag string) (scheduling.ReleaseFunc, error)
	// Close shuts the scheduler down.
	Close()
}

// NewLimiter validates c and builds an mClock-based Limiter with independent
// schedulers for read and write operations. It also starts a background
// goroutine that periodically publishes per-tag counters; the returned
// Limiter's Close stops that goroutine and both schedulers.
//
// On any error no Limiter is returned, and every scheduler created so far
// has already been closed.
func NewLimiter(c *limits.Config) (Limiter, error) {
	if err := validateConfig(c); err != nil {
		return nil, err
	}
	readScheduler, err := createScheduler(c.Read())
	if err != nil {
		return nil, fmt.Errorf("create read scheduler: %w", err)
	}
	writeScheduler, err := createScheduler(c.Write())
	if err != nil {
		// No Limiter will be returned, so nobody else can close the read
		// scheduler; release it here instead of leaking it.
		readScheduler.Close()
		return nil, fmt.Errorf("create write scheduler: %w", err)
	}
	l := &mClockLimiter{
		readScheduler:  readScheduler,
		writeScheduler: writeScheduler,
		closeCh:        make(chan struct{}),
		wg:             &sync.WaitGroup{},
		readStats:      createStats(),
		writeStats:     createStats(),
	}
	// Seed both atomics before the collector goroutine starts: it dereferences
	// the loaded pointers, so they must never be nil.
	l.shardID.Store(&shardID{})
	l.metrics.Store(&metricsHolder{metrics: &noopMetrics{}})
	l.startMetricsCollect()
	return l, nil
}

// createScheduler picks the scheduler implementation for one operation kind.
// With no per-tag settings and an unlimited waiting queue, a plain semaphore
// bounded by MaxRunningOps suffices. Otherwise an mClock scheduler is built
// from the running/waiting limits, the per-tag settings and the idle timeout.
func createScheduler(config limits.OpConfig) (scheduler, error) {
	needsMClock := len(config.Tags) != 0 || config.MaxWaitingOps != limits.NoLimit
	if !needsMClock {
		return newSemaphoreScheduler(config.MaxRunningOps), nil
	}
	tagInfo := converToSchedulingTags(config.Tags)
	return scheduling.NewMClock(
		uint64(config.MaxRunningOps),
		uint64(config.MaxWaitingOps),
		tagInfo,
		config.IdleTimeout,
	)
}

// converToSchedulingTags builds the per-tag mClock settings for a scheduler.
//
// Each of the client, background, internal, policer and writecache tags first
// gets defaultShare. Then every entry of cfgs overrides the weight, IOPS limit
// and reserved IOPS of its tag, but only for values that are set and non-zero.
// An entry whose tag is not in the default list gets its own map entry, with
// a zero share unless a weight is configured.
// NOTE(review): presumably validateConfig rejects unknown tags before this
// runs; confirm.
func converToSchedulingTags(cfgs []limits.IOTagConfig) map[string]scheduling.TagInfo {
	defaultTags := [...]IOTag{IOTagClient, IOTagBackground, IOTagInternal, IOTagPolicer, IOTagWritecache}

	result := make(map[string]scheduling.TagInfo)
	for i := range defaultTags {
		result[defaultTags[i].String()] = scheduling.TagInfo{Share: defaultShare}
	}

	for _, cfg := range cfgs {
		info := result[cfg.Tag]
		if weight := cfg.Weight; weight != nil && *weight != 0 {
			info.Share = *weight
		}
		if limit := cfg.LimitOps; limit != nil && *limit != 0 {
			info.LimitIOPS = limit
		}
		if reserved := cfg.ReservedOps; reserved != nil && *reserved != 0 {
			info.ReservedIOPS = reserved
		}
		result[cfg.Tag] = info
	}
	return result
}

var (
	_ Limiter = (*noopLimiter)(nil)

	// releaseStub is the shared do-nothing ReleaseFunc handed out by noopLimiter.
	releaseStub ReleaseFunc = func() {}

	// noopLimiterInstance is the value returned by every NewNoopLimiter call;
	// noopLimiter carries no state, so a single instance is enough.
	noopLimiterInstance = new(noopLimiter)
)

// NewNoopLimiter returns a Limiter that admits every request immediately and
// ignores parent IDs, metrics and Close.
func NewNoopLimiter() Limiter {
	return noopLimiterInstance
}

// noopLimiter implements Limiter without any scheduling or accounting.
type noopLimiter struct{}

// ReadRequest admits the request at once; the returned release func does nothing.
func (*noopLimiter) ReadRequest(context.Context) (ReleaseFunc, error) { return releaseStub, nil }

// WriteRequest admits the request at once; the returned release func does nothing.
func (*noopLimiter) WriteRequest(context.Context) (ReleaseFunc, error) { return releaseStub, nil }

// SetParentID is a no-op.
func (*noopLimiter) SetParentID(string) {}

// SetMetrics is a no-op.
func (*noopLimiter) SetMetrics(Metrics) {}

// Close is a no-op.
func (*noopLimiter) Close() {}

var _ Limiter = (*mClockLimiter)(nil)

// shardID holds the identifier of the shard a limiter belongs to. It lives
// behind an atomic.Pointer so SetParentID can swap it while the metrics
// goroutine reads it.
type shardID struct {
	id string
}

// mClockLimiter is the scheduling Limiter built by NewLimiter. Reads and
// writes pass through independent schedulers and keep independent per-tag
// counters, which a background goroutine periodically publishes.
type mClockLimiter struct {
	// Admission control for read and write operations, respectively.
	readScheduler, writeScheduler scheduler

	// Per-tag counters for read and write operations, respectively.
	readStats, writeStats map[string]*stat

	// Shard label for published metrics; an empty id suppresses publishing.
	shardID atomic.Pointer[shardID]
	// Current metrics sink, replaced wholesale by SetMetrics.
	metrics atomic.Pointer[metricsHolder]
	// Closed by Close to stop the metrics goroutine.
	closeCh chan struct{}
	// Tracks the metrics goroutine so Close can wait for it to exit.
	wg      *sync.WaitGroup
}

// ReadRequest admits a read operation through the read scheduler, classified
// by the IO tag carried in ctx, and records it in the read counters.
func (n *mClockLimiter) ReadRequest(ctx context.Context) (ReleaseFunc, error) {
	sched, stats := n.readScheduler, n.readStats
	return requestArrival(ctx, sched, stats)
}

// WriteRequest admits a write operation through the write scheduler,
// classified by the IO tag carried in ctx, and records it in the write counters.
func (n *mClockLimiter) WriteRequest(ctx context.Context) (ReleaseFunc, error) {
	sched, stats := n.writeScheduler, n.writeStats
	return requestArrival(ctx, sched, stats)
}

// requestArrival admits one request through s and keeps the per-tag counters
// in stats up to date.
//
// The request's tag comes from ctx, falling back to the client tag. Counters:
//   - pending: incremented for every request that reaches this function;
//   - inProgress: incremented once the scheduler has answered (success or error);
//   - completed: incremented when an admitted request is released, or when the
//     scheduler fails with an error other than a limit violation;
//   - resourceExhausted: incremented when the scheduler's limit is exceeded.
//
// Critical-tagged requests bypass the scheduler entirely. Limit-exceeded
// errors from either scheduler kind are reported as apistatus.ResourceExhausted.
// NOTE(review): a rejected request bumps inProgress and resourceExhausted but
// not completed; presumably the metrics consumer accounts for that; confirm.
func requestArrival(ctx context.Context, s scheduler, stats map[string]*stat) (ReleaseFunc, error) {
	tag, found := tagging.IOTagFromContext(ctx)
	if !found {
		tag = IOTagClient.String()
	}

	st := getStat(tag, stats)
	st.pending.Add(1)

	if tag == IOTagCritical.String() {
		// Critical work is never queued or throttled.
		st.inProgress.Add(1)
		return func() { st.completed.Add(1) }, nil
	}

	release, err := s.RequestArrival(ctx, tag)
	st.inProgress.Add(1)

	switch {
	case err == nil:
		return func() {
			release()
			st.completed.Add(1)
		}, nil
	case errors.Is(err, scheduling.ErrMClockSchedulerRequestLimitExceeded),
		errors.Is(err, errSemaphoreLimitExceeded):
		st.resourceExhausted.Add(1)
		return nil, &apistatus.ResourceExhausted{}
	default:
		st.completed.Add(1)
		return nil, err
	}
}

// Close shuts the limiter down. It closes both schedulers, signals the
// metrics goroutine and waits for it to exit, and finally calls Close on the
// current metrics sink with this limiter's shard id.
//
// Close must be called at most once: closeCh is closed unconditionally, so a
// second call panics.
func (n *mClockLimiter) Close() {
	for _, sched := range [...]scheduler{n.readScheduler, n.writeScheduler} {
		sched.Close()
	}

	close(n.closeCh)
	// After Wait returns, no collector tick can race with the final sink call.
	n.wg.Wait()

	sink := n.metrics.Load().metrics
	sink.Close(n.shardID.Load().id)
}

// SetParentID sets the shard identifier used to label this limiter's metrics.
// While the identifier is empty, the metrics goroutine publishes nothing.
func (n *mClockLimiter) SetParentID(parentID string) {
	id := shardID{id: parentID}
	n.shardID.Store(&id)
}

// SetMetrics replaces the metrics sink. The metrics goroutine picks it up on
// its next tick, and Close uses whichever sink is current at that time.
func (n *mClockLimiter) SetMetrics(m Metrics) {
	holder := &metricsHolder{metrics: m}
	n.metrics.Store(holder)
}

// startMetricsCollect launches the goroutine that, every
// defaultMetricsCollectTimeout, pushes the read and write per-tag counters to
// the current metrics sink. Ticks are skipped while no shard id is set. The
// goroutine exits when closeCh is closed and is tracked by wg.
func (n *mClockLimiter) startMetricsCollect() {
	n.wg.Add(1)
	go func() {
		defer n.wg.Done()

		ticker := time.NewTicker(defaultMetricsCollectTimeout)
		defer ticker.Stop()

		for {
			select {
			case <-n.closeCh:
				return
			case <-ticker.C:
			}

			id := n.shardID.Load().id
			if id == "" {
				// Nothing to label the metrics with yet.
				continue
			}
			sink := n.metrics.Load().metrics

			ops := [...]struct {
				name  string
				stats map[string]*stat
			}{
				{name: "read", stats: n.readStats},
				{name: "write", stats: n.writeStats},
			}
			for _, op := range ops {
				for tag, s := range op.stats {
					sink.SetOperationTagCounters(id, op.name, tag,
						s.pending.Load(), s.inProgress.Load(), s.completed.Load(), s.resourceExhausted.Load())
				}
			}
		}
	}()
}