package handler

import (
	"context"
	"encoding/xml"
	"errors"
	"fmt"
	"io"
	"sort"
	"strconv"
	"strings"

	"git.frostfs.info/TrueCloudLab/frostfs-http-gw/internal/data"
	"git.frostfs.info/TrueCloudLab/frostfs-http-gw/internal/logs"
	"git.frostfs.info/TrueCloudLab/frostfs-http-gw/tokens"
	"git.frostfs.info/TrueCloudLab/frostfs-http-gw/utils"
	"git.frostfs.info/TrueCloudLab/frostfs-observability/tracing"
	qostagging "git.frostfs.info/TrueCloudLab/frostfs-qos/tagging"
	cid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/container/id"
	"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object"
	oid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object/id"
	"github.com/valyala/fasthttp"
	"go.uber.org/zap"
)

const (
	// internalIOTag is the QoS IO tag attached (via qostagging.ContextWithIOTag)
	// to the request contexts used for the CORS lookups in this file.
	internalIOTag        = "internal"
	// corsFilePathTemplate produces the FilePath attribute value under which
	// CORS objects are searched; the verb is filled with the target container ID.
	corsFilePathTemplate = "/%s.cors"
	// wildcard is the rule entry that matches any origin and, in
	// AllowedHeaders, any requested header.
	wildcard             = "*"
)

// errNoCORS is returned when the search finds no CORS object for a container;
// Preflight reports it to the client as 404 Not Found.
var errNoCORS = errors.New("no CORS objects found")

// Preflight answers a CORS preflight request for the container named by the
// "cid" route parameter.
//
// The request must carry Origin and Access-Control-Request-Method headers;
// otherwise 400 Bad Request is returned. When a gateway-wide CORS rule is
// configured (h.config.CORS()), its headers are written unconditionally.
// Otherwise the container's CORS configuration is loaded. A missing CORS
// object yields 404 and any other failure yields 500.
//
// The first rule that satisfies all three conditions below is reflected into
// the response:
//   - it allows the origin, either exactly or via "*";
//   - it allows the requested method;
//   - it allows every requested header.
//
// When no rule matches, 403 Forbidden is returned.
func (h *Handler) Preflight(c *fasthttp.RequestCtx) {
	ctx, span := tracing.StartSpanFromContext(utils.GetContextFromRequest(c), "handler.Preflight")
	defer span.End()

	ctx = qostagging.ContextWithIOTag(ctx, internalIOTag)
	cidParam, _ := c.UserValue("cid").(string)
	reqLog := utils.GetReqLogOrDefault(ctx, h.log)
	log := reqLog.With(zap.String("cid", cidParam))

	origin := c.Request.Header.Peek(fasthttp.HeaderOrigin)
	if len(origin) == 0 {
		log.Error(logs.EmptyOriginRequestHeader, logs.TagField(logs.TagDatapath))
		ResponseError(c, "Origin request header needed", fasthttp.StatusBadRequest)
		return
	}

	method := c.Request.Header.Peek(fasthttp.HeaderAccessControlRequestMethod)
	if len(method) == 0 {
		log.Error(logs.EmptyAccessControlRequestMethodHeader, logs.TagField(logs.TagDatapath))
		ResponseError(c, "Access-Control-Request-Method request header needed", fasthttp.StatusBadRequest)
		return
	}

	corsRule := h.config.CORS()
	if corsRule != nil {
		setCORSHeadersFromRule(c, corsRule)
		return
	}

	corsConfig, err := h.getCORSConfig(ctx, log, cidParam)
	if err != nil {
		log.Error(logs.CouldNotGetCORSConfiguration, zap.Error(err), logs.TagField(logs.TagDatapath))
		status := fasthttp.StatusInternalServerError
		if errors.Is(err, errNoCORS) {
			status = fasthttp.StatusNotFound
		}
		ResponseError(c, "could not get CORS configuration: "+err.Error(), status)
		return
	}

	requestHeaders := c.Request.Header.Peek(fasthttp.HeaderAccessControlRequestHeaders)
	// Browsers join the requested header names with a bare "," (Fetch
	// standard), and other clients may add optional whitespace. Splitting on
	// ", " would turn "content-type,x-amz-date" into one bogus name, so split
	// on "," and trim every entry instead.
	var headers []string
	for _, name := range strings.Split(string(requestHeaders), ",") {
		if name = strings.TrimSpace(name); name != "" {
			headers = append(headers, name)
		}
	}

	originStr, methodStr := string(origin), string(method)
	for _, rule := range corsConfig.CORSRules {
		for _, o := range rule.AllowedOrigins {
			if o != originStr && o != wildcard {
				continue
			}
			if !sliceContains(rule.AllowedMethods, methodStr) || !checkSubslice(rule.AllowedHeaders, headers) {
				continue
			}

			c.Response.Header.Set(fasthttp.HeaderAccessControlAllowOrigin, originStr)
			c.Response.Header.Set(fasthttp.HeaderAccessControlAllowMethods, strings.Join(rule.AllowedMethods, ", "))
			if len(headers) > 0 {
				c.Response.Header.Set(fasthttp.HeaderAccessControlAllowHeaders, string(requestHeaders))
			}
			if len(rule.ExposeHeaders) > 0 {
				c.Response.Header.Set(fasthttp.HeaderAccessControlExposeHeaders, strings.Join(rule.ExposeHeaders, ", "))
			}
			// Positive values and -1 are forwarded as configured; other values omit the header.
			if rule.MaxAgeSeconds > 0 || rule.MaxAgeSeconds == -1 {
				c.Response.Header.Set(fasthttp.HeaderAccessControlMaxAge, strconv.Itoa(rule.MaxAgeSeconds))
			}
			// Credentials are only advertised for rules naming the origin explicitly.
			if o != wildcard {
				c.Response.Header.Set(fasthttp.HeaderAccessControlAllowCredentials, "true")
			}
			return
		}
	}
	log.Error(logs.CORSRuleWasNotMatched, logs.TagField(logs.TagDatapath))
	ResponseError(c, "Forbidden", fasthttp.StatusForbidden)
}

// SetCORSHeaders adds CORS headers to a regular (non-preflight) response.
//
// Nothing is done when the request has no Origin header. When a gateway-wide
// CORS rule is configured (h.config.CORS()), it is applied as-is. Otherwise
// the container's CORS configuration is consulted. The first rule that
// allows both the request origin (exactly or via "*") and the request method
// is reflected into the response, including its ExposeHeaders. Browsers read
// Access-Control-Expose-Headers only from actual responses, not preflight
// ones. Failure to load the configuration is logged and leaves the response
// untouched.
func (h *Handler) SetCORSHeaders(c *fasthttp.RequestCtx) {
	ctx, span := tracing.StartSpanFromContext(utils.GetContextFromRequest(c), "handler.SetCORSHeaders")
	defer span.End()

	origin := c.Request.Header.Peek(fasthttp.HeaderOrigin)
	if len(origin) == 0 {
		return
	}

	ctx = qostagging.ContextWithIOTag(ctx, internalIOTag)
	cidParam, _ := c.UserValue("cid").(string)
	reqLog := utils.GetReqLogOrDefault(ctx, h.log)
	log := reqLog.With(zap.String("cid", cidParam))

	corsRule := h.config.CORS()
	if corsRule != nil {
		setCORSHeadersFromRule(c, corsRule)
		return
	}

	corsConfig, err := h.getCORSConfig(ctx, log, cidParam)
	if err != nil {
		log.Error(logs.CouldNotGetCORSConfiguration, zap.Error(err), logs.TagField(logs.TagDatapath))
		return
	}

	var withCredentials bool
	if tkn, err := tokens.LoadBearerToken(ctx); err == nil && tkn != nil {
		withCredentials = true
	}

	originStr, method := string(origin), string(c.Method())
	respHdr := &c.Response.Header
	for _, rule := range corsConfig.CORSRules {
		for _, o := range rule.AllowedOrigins {
			if o != originStr && o != wildcard {
				continue
			}
			if !sliceContains(rule.AllowedMethods, method) {
				continue
			}

			if o == originStr || withCredentials {
				// Echo the concrete origin: a literal "*" cannot be combined
				// with credentials, and the response now depends on Origin.
				respHdr.Set(fasthttp.HeaderAccessControlAllowOrigin, originStr)
				respHdr.Set(fasthttp.HeaderAccessControlAllowCredentials, "true")
				respHdr.Set(fasthttp.HeaderVary, fasthttp.HeaderOrigin)
			} else {
				respHdr.Set(fasthttp.HeaderAccessControlAllowOrigin, wildcard)
			}
			respHdr.Set(fasthttp.HeaderAccessControlAllowMethods, strings.Join(rule.AllowedMethods, ", "))
			if len(rule.ExposeHeaders) > 0 {
				respHdr.Set(fasthttp.HeaderAccessControlExposeHeaders, strings.Join(rule.ExposeHeaders, ", "))
			}
			return
		}
	}
}

// getCORSConfig returns the CORS configuration of the container identified by
// cidStr, which is resolved via h.resolveContainer.
//
// A cached configuration is returned when present. Otherwise the newest CORS
// object for the container is located in the CORS container, fetched,
// XML-decoded and stored in the cache. A cache-store failure is only logged.
// Errors from resolution, lookup, fetching and decoding are returned wrapped,
// so errNoCORS from the lookup stays detectable with errors.Is.
func (h *Handler) getCORSConfig(ctx context.Context, log *zap.Logger, cidStr string) (*data.CORSConfiguration, error) {
	cnrID, err := h.resolveContainer(ctx, cidStr)
	if err != nil {
		return nil, fmt.Errorf("resolve container '%s': %w", cidStr, err)
	}

	if cors := h.corsCache.Get(*cnrID); cors != nil {
		return cors, nil
	}

	objID, err := h.getLastCORSObject(ctx, *cnrID)
	if err != nil {
		return nil, fmt.Errorf("get last cors object: %w", err)
	}

	var addr oid.Address
	addr.SetContainer(h.corsCnrID)
	addr.SetObject(objID)
	corsObj, err := h.frostfs.GetObject(ctx, PrmObjectGet{
		PrmAuth: PrmAuth{
			BearerToken: bearerToken(ctx),
		},
		Address: addr,
	})
	if err != nil {
		return nil, fmt.Errorf("get cors object '%s': %w", addr.EncodeToString(), err)
	}
	// Release the payload stream once decoding is finished so the underlying
	// object reader is not leaked. A close error is irrelevant at that point.
	// NOTE(review): if Payload is declared as io.ReadCloser this can become a
	// plain `defer corsObj.Payload.Close()`.
	if closer, ok := corsObj.Payload.(io.Closer); ok {
		defer func() { _ = closer.Close() }()
	}

	corsConfig := &data.CORSConfiguration{}
	if err = xml.NewDecoder(corsObj.Payload).Decode(corsConfig); err != nil {
		return nil, fmt.Errorf("decode cors object: %w", err)
	}

	if err = h.corsCache.Put(*cnrID, corsConfig); err != nil {
		log.Warn(logs.CouldntCacheCors, zap.Error(err), logs.TagField(logs.TagDatapath))
	}

	return corsConfig, nil
}

// getLastCORSObject finds the most recent CORS object stored for cnrID in the
// gateway's CORS container (h.corsCnrID) and returns its ID.
//
// Candidates are root objects whose FilePath attribute equals
// "/<cnrID>.cors". Each candidate is HEADed. A HEAD failure makes the
// iteration callback return true and is then reported to the caller. The
// newest candidate is ordered as follows:
//   - greatest creation epoch first;
//   - on equal epochs, by the Timestamp attribute, when both objects carry
//     distinct, parsable values;
//   - otherwise, by the encoded object ID.
//
// errNoCORS is returned when the search yields nothing.
func (h *Handler) getLastCORSObject(ctx context.Context, cnrID cid.ID) (oid.ID, error) {
	searchFilters := object.NewSearchFilters()
	searchFilters.AddRootFilter()
	searchFilters.AddFilter(object.AttributeFilePath, fmt.Sprintf(corsFilePathTemplate, cnrID), object.MatchStringEqual)

	auth := PrmAuth{BearerToken: bearerToken(ctx)}

	searchRes, err := h.frostfs.SearchObjects(ctx, PrmObjectSearch{
		PrmAuth:   auth,
		Container: h.corsCnrID,
		Filters:   searchFilters,
	})
	if err != nil {
		return oid.ID{}, fmt.Errorf("search cors versions: %w", err)
	}
	defer searchRes.Close()

	var (
		target   oid.Address
		versions []*object.Object
		headFail error
	)
	target.SetContainer(h.corsCnrID)

	iterErr := searchRes.Iterate(func(id oid.ID) bool {
		target.SetObject(id)
		var header *object.Object
		if header, headFail = h.frostfs.HeadObject(ctx, PrmObjectHead{PrmAuth: auth, Address: target}); headFail != nil {
			headFail = fmt.Errorf("head cors object '%s': %w", target.EncodeToString(), headFail)
			return true
		}
		versions = append(versions, header)
		return false
	})
	switch {
	case iterErr != nil:
		return oid.ID{}, fmt.Errorf("iterate cors objects: %w", iterErr)
	case headFail != nil:
		return oid.ID{}, headFail
	case len(versions) == 0:
		return oid.ID{}, errNoCORS
	}

	// byEncodedID is the final tie-breaker: lexical order of the encoded IDs.
	byEncodedID := func(a, b *object.Object) bool {
		idA, _ := a.ID()
		idB, _ := b.ID()
		return idA.EncodeToString() < idB.EncodeToString()
	}
	// isOlder reports whether version a should be ordered before version b.
	isOlder := func(a, b *object.Object) bool {
		if epochA, epochB := a.CreationEpoch(), b.CreationEpoch(); epochA != epochB {
			return epochA < epochB
		}

		tsA := utils.GetAttributeValue(a.Attributes(), object.AttributeTimestamp)
		tsB := utils.GetAttributeValue(b.Attributes(), object.AttributeTimestamp)
		if tsA == "" || tsB == "" || tsA == tsB {
			return byEncodedID(a, b)
		}

		unixA, errA := strconv.ParseInt(tsA, 10, 64)
		if errA != nil {
			return byEncodedID(a, b)
		}
		unixB, errB := strconv.ParseInt(tsB, 10, 64)
		if errB != nil {
			return byEncodedID(a, b)
		}
		return unixA < unixB
	}
	sort.Slice(versions, func(i, j int) bool { return isOlder(versions[i], versions[j]) })

	newestID, _ := versions[len(versions)-1].ID()
	return newestID, nil
}

// setCORSHeadersFromRule writes the headers described by a statically
// configured CORS rule into the response.
//
// Access-Control-Max-Age is always written. Only the first allowed origin is
// used. The list-valued headers (methods, allowed headers, exposed headers)
// are written comma-joined and only when non-empty. Allow-Credentials is
// written only when the rule enables it.
func setCORSHeadersFromRule(c *fasthttp.RequestCtx, cors *data.CORSRule) {
	respHdr := &c.Response.Header
	respHdr.Set(fasthttp.HeaderAccessControlMaxAge, strconv.Itoa(cors.MaxAgeSeconds))

	if origins := cors.AllowedOrigins; len(origins) > 0 {
		respHdr.Set(fasthttp.HeaderAccessControlAllowOrigin, origins[0])
	}

	listHeaders := []struct {
		name   string
		values []string
	}{
		{fasthttp.HeaderAccessControlAllowMethods, cors.AllowedMethods},
		{fasthttp.HeaderAccessControlAllowHeaders, cors.AllowedHeaders},
		{fasthttp.HeaderAccessControlExposeHeaders, cors.ExposeHeaders},
	}
	for _, lh := range listHeaders {
		if len(lh.values) > 0 {
			respHdr.Set(lh.name, strings.Join(lh.values, ", "))
		}
	}

	if cors.AllowedCredentials {
		respHdr.Set(fasthttp.HeaderAccessControlAllowCredentials, "true")
	}
}

// checkSubslice reports whether every requested value in subSlice is allowed
// by slice.
//
// A "*" entry in slice allows everything. Otherwise each requested value must
// match some entry of slice case-insensitively, because HTTP header field
// names are case-insensitive (RFC 9110 §5.1). Surrounding whitespace is
// ignored and empty entries are skipped. Duplicated requests are fine. An
// empty subSlice is always allowed.
func checkSubslice(slice []string, subSlice []string) bool {
	if sliceContains(slice, wildcard) {
		return true
	}
	for _, requested := range subSlice {
		requested = strings.TrimSpace(requested)
		if requested == "" {
			continue
		}
		allowed := false
		for _, candidate := range slice {
			if strings.EqualFold(candidate, requested) {
				allowed = true
				break
			}
		}
		if !allowed {
			return false
		}
	}
	return true
}

// sliceContains reports whether str occurs in slice, using exact
// (case-sensitive) string comparison. A nil or empty slice contains nothing.
func sliceContains(slice []string, str string) bool {
	found := false
	for i := 0; i < len(slice) && !found; i++ {
		found = slice[i] == str
	}
	return found
}