package utils

import (
	"bytes"
	"context"
	"errors"
	"fmt"
	"math"
	"strconv"
	"strings"
	"time"
	"unicode"
	"unicode/utf8"
)

// EpochDurations holds the network timing needed to turn wall-clock expiration
// times into epochs. CurrentEpoch is the epoch that expirations are counted
// from, MsPerBlock is the block interval in milliseconds and BlockPerEpoch is
// the number of blocks in one epoch; the expiration conversion in this file
// uses MsPerBlock*BlockPerEpoch as the epoch length in milliseconds.
type EpochDurations struct {
	CurrentEpoch  uint64
	MsPerBlock    int64
	BlockPerEpoch uint64
}

// EpochInfoFetcher supplies the current EpochDurations. PrepareExpirationHeader
// calls it only when the headers actually contain expiration attributes.
type EpochInfoFetcher interface {
	GetEpochDurations(context.Context) (*EpochDurations, error)
}

const (
	// UserAttributeHeaderPrefix is the HTTP header prefix used for object
	// attributes. NOTE(review): its consumers are outside this file; presumably
	// the remainder of the header name is the attribute key — confirm.
	UserAttributeHeaderPrefix = "X-Attribute-"
)

const (
	// systemAttributePrefix prefixes system attribute keys, e.g.
	// "__SYSTEM__EXPIRATION_EPOCH".
	systemAttributePrefix = "__SYSTEM__"

	// systemAttributePrefixNeoFS is the legacy system attribute prefix. It is
	// still recognised by the transformers table for backward compatibility,
	// but expiration attributes must not use both prefixes at once.
	//
	// Deprecated: use systemAttributePrefix instead.
	systemAttributePrefixNeoFS = "__NEOFS__"
)

// systemTransformer translates attribute keys for one system prefix between the
// header form and the internal system form, e.g. "System-Expiration-Epoch" and
// "__SYSTEM__EXPIRATION_EPOCH".
//
//   - prefix is the internal prefix that system keys start with (e.g. "__SYSTEM__");
//   - backwardPrefix is prepended when a system key is converted back to header form;
//   - xAttrPrefixes lists the header-prefix spellings recognised on input.
type systemTransformer struct {
	prefix         string
	backwardPrefix string
	xAttrPrefixes  [][]byte
}

// transformers lists every supported system prefix: the current "__SYSTEM__"
// first, then the deprecated "__NEOFS__". TransformIfSystem and
// BackwardTransformIfSystem try them in this order and use the first match, so
// the order is significant.
var transformers = []systemTransformer{
	{
		prefix:         systemAttributePrefix,
		backwardPrefix: "System-",
		xAttrPrefixes:  [][]byte{[]byte("System-"), []byte("SYSTEM-"), []byte("system-")},
	},
	{
		prefix:         systemAttributePrefixNeoFS,
		backwardPrefix: "Neofs-",
		xAttrPrefixes:  [][]byte{[]byte("Neofs-"), []byte("NEOFS-"), []byte("neofs-")},
	},
}

// existsExpirationAttributes reports whether headers contains the key of at
// least one expiration attribute (epoch, duration, timestamp or RFC3339 time)
// built with t's prefix. Only the presence of the key matters, not its value.
func (t systemTransformer) existsExpirationAttributes(headers map[string]string) bool {
	keys := [...]string{
		t.expirationEpochAttr(),
		t.expirationDurationAttr(),
		t.expirationTimestampAttr(),
		t.expirationRFC3339Attr(),
	}
	for _, key := range keys {
		if _, present := headers[key]; present {
			return true
		}
	}
	return false
}

// expirationEpochAttr returns the key of the attribute holding an absolute
// expiration epoch, e.g. "__SYSTEM__EXPIRATION_EPOCH" for the "__SYSTEM__" prefix.
func (t systemTransformer) expirationEpochAttr() string {
	const suffix = "EXPIRATION_EPOCH"
	return t.prefix + suffix
}

// expirationDurationAttr returns the key of the attribute holding a relative
// expiration written as a Go duration string (e.g. "90s").
func (t systemTransformer) expirationDurationAttr() string {
	const suffix = "EXPIRATION_DURATION"
	return t.prefix + suffix
}

// expirationTimestampAttr returns the key of the attribute holding the
// expiration moment as a Unix timestamp in seconds.
func (t systemTransformer) expirationTimestampAttr() string {
	const suffix = "EXPIRATION_TIMESTAMP"
	return t.prefix + suffix
}

// expirationRFC3339Attr returns the key of the attribute holding the
// expiration moment as an RFC3339 time string.
func (t systemTransformer) expirationRFC3339Attr() string {
	const suffix = "EXPIRATION_RFC3339"
	return t.prefix + suffix
}

// systemTranslator converts a header-style key into its system form. The first
// occurrence of prefix is replaced by t.prefix, every '-' becomes '_', and the
// result is upper-cased. The input slice is not modified; a new slice is returned.
// Within this file it is only called with a prefix that key starts with.
func (t systemTransformer) systemTranslator(key, prefix []byte) []byte {
	withSystemPrefix := bytes.Replace(key, prefix, []byte(t.prefix), 1)
	underscored := bytes.ReplaceAll(withSystemPrefix, []byte{'-'}, []byte{'_'})
	return bytes.ToUpper(underscored)
}

// transformIfSystem translates key with systemTranslator when it starts with one
// of t.xAttrPrefixes; the prefixes are checked in order and the first match is used.
// It returns the translated key and true, or key unchanged and false.
func (t systemTransformer) transformIfSystem(key []byte) ([]byte, bool) {
	for i := range t.xAttrPrefixes {
		candidate := t.xAttrPrefixes[i]
		if !bytes.HasPrefix(key, candidate) {
			continue
		}
		return t.systemTranslator(key, candidate), true
	}
	return key, false
}

// systemBackwardTranslator converts a system key such as '__PREFIX__ATTR_NAME'
// into the header form 'Prefix-Attr-Name'. It strips t.prefix (if present) and
// splits the rest on '_'. Each word is lower-cased with a title-cased first
// letter, the words are joined with '-', and t.backwardPrefix goes in front.
func (t systemTransformer) systemBackwardTranslator(key string) string {
	words := strings.Split(strings.TrimPrefix(key, t.prefix), "_")
	for i, w := range words {
		words[i] = title(strings.ToLower(w))
	}
	return t.backwardPrefix + strings.Join(words, "-")
}

// backwardTransformIfSystem applies systemBackwardTranslator when key starts
// with t.prefix and reports whether it did; other keys are returned unchanged.
func (t systemTransformer) backwardTransformIfSystem(key string) (string, bool) {
	if !strings.HasPrefix(key, t.prefix) {
		return key, false
	}
	return t.systemBackwardTranslator(key), true
}

// TransformIfSystem converts a header-style attribute key into its internal
// system form when it carries a known system prefix, for example
// "System-Expiration-Epoch" becomes "__SYSTEM__EXPIRATION_EPOCH". The
// transformers are tried in order and the first match wins. Keys without a
// system prefix are returned unchanged.
func TransformIfSystem(key []byte) []byte {
	for i := range transformers {
		if converted, ok := transformers[i].transformIfSystem(key); ok {
			return converted
		}
	}
	return key
}

// BackwardTransformIfSystem performs the reverse conversion of
// TransformIfSystem. It turns an internal system key such as
// "__SYSTEM__EXPIRATION_EPOCH" into its header form "System-Expiration-Epoch".
// Keys without a known system prefix are returned unchanged.
func BackwardTransformIfSystem(key string) string {
	for i := range transformers {
		if converted, ok := transformers[i].backwardTransformIfSystem(key); ok {
			return converted
		}
	}
	return key
}

// title returns str with its first rune mapped through unicode.ToTitle; the
// remainder is left untouched. An invalid leading UTF-8 byte decodes as
// utf8.RuneError and is therefore replaced by U+FFFD. The empty string is
// returned as is.
func title(str string) string {
	first, width := utf8.DecodeRuneInString(str)
	if width == 0 {
		// Only the empty string decodes with zero width.
		return str
	}
	return string(unicode.ToTitle(first)) + str[width:]
}

// PrepareExpirationHeader normalises the expiration attributes in headers into
// a single expiration-epoch attribute, modifying headers in place.
//
// The attributes may use either the current "__SYSTEM__" or the deprecated
// "__NEOFS__" prefix, but not both. Its behavior:
//   - no expiration attributes: returns nil without calling epochFetcher;
//   - both prefixes used: returns an error;
//   - otherwise: fetches the epoch timing (wrapping any fetch error) and delegates
//     the conversion to the transformer that matched.
func PrepareExpirationHeader(ctx context.Context, epochFetcher EpochInfoFetcher, headers map[string]string, now time.Time) error {
	selected := -1
	for i := range transformers {
		if !transformers[i].existsExpirationAttributes(headers) {
			continue
		}
		if selected >= 0 {
			return errors.New("both deprecated and new system attributes formats are used, please use only one")
		}
		selected = i
	}

	if selected < 0 {
		return nil
	}

	epochDuration, err := epochFetcher.GetEpochDurations(ctx)
	if err != nil {
		return fmt.Errorf("couldn't get epoch durations from network info: %w", err)
	}

	return transformers[selected].prepareExpirationHeader(headers, epochDuration, now)
}

// prepareExpirationHeader rewrites the expiration attributes that use t's
// prefix into a single absolute expiration-epoch attribute, modifying headers
// in place.
//
// Recognised attributes, processed in this order:
//   - RFC3339 time: must not be before now;
//   - Unix timestamp in seconds: must not be before now;
//   - Go duration string: must be positive;
//   - explicit epoch: must not be less than the current epoch.
//
// Each time-based attribute is converted to an epoch, rounding up to whole
// epochs. The result is written to the epoch attribute and the source
// attribute is deleted. A later attribute overwrites the epoch computed from an
// earlier one, and an epoch supplied by the caller always takes precedence.
//
// It returns an error when epochDurations is nil, when an attribute cannot be
// parsed or lies in the past, or when the epoch timing cannot be used to
// convert a duration (non-positive epoch length).
func (t systemTransformer) prepareExpirationHeader(headers map[string]string, epochDurations *EpochDurations, now time.Time) error {
	if epochDurations == nil {
		return errors.New("epoch durations are not provided")
	}

	// Remember the caller-supplied epoch before the conversions below overwrite it.
	expirationInEpoch := headers[t.expirationEpochAttr()]

	if timeRFC3339, ok := headers[t.expirationRFC3339Attr()]; ok {
		expTime, err := time.Parse(time.RFC3339, timeRFC3339)
		if err != nil {
			return fmt.Errorf("couldn't parse value %s of header %s: %w", timeRFC3339, t.expirationRFC3339Attr(), err)
		}
		if expTime.Before(now) {
			return fmt.Errorf("value %s of header %s must be in the future", timeRFC3339, t.expirationRFC3339Attr())
		}
		if err = t.updateExpirationHeader(headers, epochDurations, expTime.Sub(now)); err != nil {
			return err
		}
		delete(headers, t.expirationRFC3339Attr())
	}

	if timestamp, ok := headers[t.expirationTimestampAttr()]; ok {
		value, err := strconv.ParseInt(timestamp, 10, 64)
		if err != nil {
			return fmt.Errorf("couldn't parse value %s of header %s: %w", timestamp, t.expirationTimestampAttr(), err)
		}
		expTime := time.Unix(value, 0)
		if expTime.Before(now) {
			return fmt.Errorf("value %s of header %s must be in the future", timestamp, t.expirationTimestampAttr())
		}
		if err = t.updateExpirationHeader(headers, epochDurations, expTime.Sub(now)); err != nil {
			return err
		}
		delete(headers, t.expirationTimestampAttr())
	}

	if duration, ok := headers[t.expirationDurationAttr()]; ok {
		expDuration, err := time.ParseDuration(duration)
		if err != nil {
			return fmt.Errorf("couldn't parse value %s of header %s: %w", duration, t.expirationDurationAttr(), err)
		}
		if expDuration <= 0 {
			return fmt.Errorf("value %s of header %s must be positive", expDuration, t.expirationDurationAttr())
		}
		if err = t.updateExpirationHeader(headers, epochDurations, expDuration); err != nil {
			return err
		}
		delete(headers, t.expirationDurationAttr())
	}

	if expirationInEpoch != "" {
		expEpoch, err := strconv.ParseUint(expirationInEpoch, 10, 64)
		if err != nil {
			return fmt.Errorf("parse expiration epoch '%s': %w", expirationInEpoch, err)
		}
		if expEpoch < epochDurations.CurrentEpoch {
			return fmt.Errorf("expiration epoch '%d' must be greater than or equal to current epoch '%d'", expEpoch, epochDurations.CurrentEpoch)
		}
		// An explicit epoch overrides any value converted from the
		// time-based attributes above.
		headers[t.expirationEpochAttr()] = expirationInEpoch
	}

	return nil
}

// updateExpirationHeader converts expDuration (measured from now) into an
// absolute epoch and stores it in t's epoch attribute.
//
// The duration is taken in whole milliseconds (sub-millisecond parts are
// truncated) and rounded up to whole epochs. The resulting epoch saturates at
// math.MaxUint64 instead of wrapping around. Callers pass a non-negative
// expDuration.
//
// It returns an error when durations does not describe a positive epoch
// length; previously such input caused an integer divide-by-zero panic.
func (t systemTransformer) updateExpirationHeader(headers map[string]string, durations *EpochDurations, expDuration time.Duration) error {
	if durations.MsPerBlock <= 0 || durations.BlockPerEpoch == 0 {
		return fmt.Errorf("invalid epoch durations: %d ms per block, %d blocks per epoch", durations.MsPerBlock, durations.BlockPerEpoch)
	}

	// Epoch length in milliseconds, saturated instead of overflowing. A
	// saturated length still yields the correct ceiling below: any duration
	// representable in ms is shorter than the true epoch length, so the
	// result is 1 epoch for any non-zero duration.
	msPerBlock := uint64(durations.MsPerBlock)
	epochDuration := uint64(math.MaxUint64)
	if durations.BlockPerEpoch <= math.MaxUint64/msPerBlock {
		epochDuration = msPerBlock * durations.BlockPerEpoch
	}

	expMs := uint64(expDuration.Milliseconds())
	numEpoch := expMs / epochDuration
	if expMs%epochDuration != 0 {
		numEpoch++
	}

	currentEpoch := durations.CurrentEpoch
	expirationEpoch := uint64(math.MaxUint64)
	if numEpoch < math.MaxUint64-currentEpoch {
		expirationEpoch = currentEpoch + numEpoch
	}

	headers[t.expirationEpochAttr()] = strconv.FormatUint(expirationEpoch, 10)
	return nil
}