package middleware

import (
	"context"
	"crypto/elliptic"
	"encoding/xml"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"strings"

	"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/data"
	apiErr "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/errors"
	frostfsErrors "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/internal/frostfs/errors"
	"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/internal/logs"
	"git.frostfs.info/TrueCloudLab/policy-engine/pkg/chain"
	"git.frostfs.info/TrueCloudLab/policy-engine/pkg/engine"
	"git.frostfs.info/TrueCloudLab/policy-engine/pkg/resource/testutil"
	"git.frostfs.info/TrueCloudLab/policy-engine/schema/common"
	"git.frostfs.info/TrueCloudLab/policy-engine/schema/s3"
	"github.com/nspcc-dev/neo-go/pkg/crypto/keys"
	"github.com/nspcc-dev/neo-go/pkg/util"
	"go.uber.org/zap"
)

// Names of the query parameters and headers that are turned into policy
// request properties.
const (
	// QueryVersionID selects a specific object version.
	QueryVersionID = "versionId"
	// QueryPrefix, QueryDelimiter and QueryMaxKeys are bucket listing parameters.
	QueryPrefix    = "prefix"
	QueryDelimiter = "delimiter"
	QueryMaxKeys   = "max-keys"

	// amzTagging carries tags as a URL-encoded query string ("k1=v1&k2=v2").
	amzTagging = "x-amz-tagging"
)

// withoutResourceOps lists operations at whose start the target resource is
// not assumed to exist yet (it is being created, or a multipart upload is
// still in progress). determineResourceTags skips fetching resource tags for
// any operation whose name ends with one of these entries.
var withoutResourceOps = []string{
	CreateBucketOperation,
	CreateMultipartUploadOperation,
	AbortMultipartUploadOperation,
	CompleteMultipartUploadOperation,
	UploadPartOperation,
	UploadPartCopyOperation,
	ListPartsOperation,
	PutObjectOperation,
	CopyObjectOperation,
}

// PolicySettings exposes the gateway settings that decide the outcome of a
// request for which no policy rule matched.
type PolicySettings interface {
	// ACLEnabled reports whether ACLs are enabled. When no bucket is
	// resolved for a request, APE is considered in effect only if this
	// returns false.
	ACLEnabled() bool
	// PolicyDenyByDefault reports whether a request with no matching rule
	// is denied while APE is in effect.
	PolicyDenyByDefault() bool
}

// FrostFSIDInformer looks up FrostFS ID information about a user.
type FrostFSIDInformer interface {
	// GetUserGroupIDsAndClaims returns the IDs of the groups the user belongs
	// to and the user's claims. The policy check passes the script hash of the
	// bearer token signer's key as userHash. It turns the group IDs into group
	// targets and the claims into request properties.
	GetUserGroupIDsAndClaims(userHash util.Uint160) ([]string, map[string]string, error)
}

// XMLDecoder creates XML decoders for request bodies. The policy check uses it
// to parse the Tagging document of PutObjectTagging/PutBucketTagging requests.
type XMLDecoder interface {
	NewXMLDecoder(r io.Reader) *xml.Decoder
}

// ResourceTagging reads the tags currently attached to buckets and objects.
// The policy check exposes them as resource-tag request properties.
type ResourceTagging interface {
	// GetBucketTagging returns the tag set of the bucket.
	GetBucketTagging(ctx context.Context, bktInfo *data.BucketInfo) (map[string]string, error)
	// GetObjectTagging returns the tag set of the object version described by p.
	// NOTE(review): the first return value is ignored by the policy check;
	// presumably it is the resolved version ID — confirm.
	GetObjectTagging(ctx context.Context, p *data.GetObjectTaggingParams) (string, map[string]string, error)
}

// BucketResolveFunc resolves a bucket name to its bucket info. The policy check
// uses the result for the container target, the bucket's APE flag and for
// reading bucket/object tags.
type BucketResolveFunc func(ctx context.Context, bucket string) (*data.BucketInfo, error)

// PolicyConfig holds the dependencies of the PolicyCheck middleware:
//   - Storage is queried for the S3 chains that apply to the request target;
//   - FrostfsID provides the user's groups and claims;
//   - Settings decides the outcome when no rule matches;
//   - Domains are the gateway domains used to detect virtual-hosted-style
//     requests (bucket name in the Host header);
//   - Log is the logger used when the request context does not provide one
//     (NOTE(review): presumably the fallback of reqLogOrDefault — confirm);
//   - BucketResolver resolves bucket names to bucket info;
//   - Decoder decodes Tagging XML documents from request bodies;
//   - Tagging reads the current bucket/object tags.
type PolicyConfig struct {
	Storage        engine.ChainRouter
	FrostfsID      FrostFSIDInformer
	Settings       PolicySettings
	Domains        []string
	Log            *zap.Logger
	BucketResolver BucketResolveFunc
	Decoder        XMLDecoder
	Tagging        ResourceTagging
}

// PolicyCheck returns a middleware that authorizes every request with
// policyCheck before handing it to the next handler.
//
// On failure the error is logged, unwrapped with frostfsErrors.UnwrapErr and
// written to the client as an error response; the next handler is not invoked.
// A failure to write that response is logged as well.
func PolicyCheck(cfg PolicyConfig) Func {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			ctx := r.Context()

			checkErr := policyCheck(r, cfg)
			if checkErr == nil {
				next.ServeHTTP(w, r)
				return
			}

			reqLogOrDefault(ctx, cfg.Log).Error(logs.PolicyValidationFailed, zap.Error(checkErr))

			respErr := frostfsErrors.UnwrapErr(checkErr)
			if _, writeErr := WriteErrorResponse(w, GetReqInfo(ctx), respErr); writeErr != nil {
				reqLogOrDefault(ctx, cfg.Log).Error(logs.FailedToWriteResponse, zap.Error(writeErr))
			}
		})
	}
}

// policyCheck authorizes r against the S3 policy chains reachable through
// cfg.Storage.
//
// The request target always carries the namespace from the request info. It
// additionally carries the bucket's container (when a bucket is resolved), the
// user "<namespace>:<address>" (when a bearer token key is known) and one
// "<namespace>:<group>" target per FrostFS ID group.
//
// Result:
//   - nil when the engine returns Allow;
//   - ErrAccessDenied when a matching rule returns any other status;
//   - when no rule matches: ErrAccessDenied if APE is in effect and
//     PolicyDenyByDefault is set, nil otherwise. APE is in effect when the
//     resolved bucket has APEEnabled or, with no bucket, when ACLs are disabled.
//
// Errors from building the request, resolving the bucket or querying the
// engine are returned unchanged.
func policyCheck(r *http.Request, cfg PolicyConfig) error {
	reqType, bktName, objName := getBucketObject(r, cfg.Domains)
	req, userKey, userGroups, err := getPolicyRequest(r, cfg, reqType, bktName, objName)
	if err != nil {
		return err
	}

	ctx := r.Context()

	// General requests have no bucket, and CreateBucket addresses a bucket
	// that does not exist yet, so neither is resolved.
	var bktInfo *data.BucketInfo
	if reqType != noneType && !strings.HasSuffix(req.Operation(), CreateBucketOperation) {
		if bktInfo, err = cfg.BucketResolver(ctx, bktName); err != nil {
			return err
		}
	}

	ns := GetReqInfo(ctx).Namespace
	target := engine.NewRequestTargetWithNamespace(ns)
	if bktInfo != nil {
		container := engine.ContainerTarget(bktInfo.CID.EncodeToString())
		target.Container = &container
	}
	if userKey != nil {
		user := engine.UserTarget(ns + ":" + userKey.Address())
		target.User = &user
	}
	groupTargets := make([]engine.Target, 0, len(userGroups))
	for _, group := range userGroups {
		groupTargets = append(groupTargets, engine.GroupTarget(ns+":"+group))
	}
	target.Groups = groupTargets

	status, found, err := cfg.Storage.IsAllowed(chain.S3, target, req)
	if err != nil {
		return err
	}
	if !found {
		status = chain.NoRuleFound
	}

	if status == chain.Allow {
		return nil
	}
	if status != chain.NoRuleFound {
		return apiErr.GetAPIErrorWithError(apiErr.ErrAccessDenied, fmt.Errorf("policy check: %s", status.String()))
	}

	// No rule matched: apply the default policy.
	apeEnabled := !cfg.Settings.ACLEnabled()
	if bktInfo != nil {
		apeEnabled = bktInfo.APEEnabled
	}
	if !apeEnabled || !cfg.Settings.PolicyDenyByDefault() {
		return nil
	}

	return apiErr.GetAPIErrorWithError(apiErr.ErrAccessDenied, fmt.Errorf("policy check: %s", status.String()))
}

// getPolicyRequest builds the policy-engine request for r. The request holds
// the "s3:<Operation>" action, the bucket or bucket/object resource and the
// condition properties.
//
// When the access box carries a bearer token, the signer's public key is
// parsed from it. Its address becomes the owner property, and the user's
// FrostFS ID groups and claims are fetched. The key and the groups are also
// returned so the caller can build user and group targets; both are nil when
// there is no bearer token. Failure to read the box data is treated the same
// as having no bearer token.
func getPolicyRequest(r *http.Request, cfg PolicyConfig, reqType ReqType, bktName string, objName string) (*testutil.Request, *keys.PublicKey, []string, error) {
	var (
		owner  string
		groups []string
		tags   map[string]string
		pk     *keys.PublicKey
	)

	ctx := r.Context()
	bd, err := GetBoxData(ctx)
	if err == nil && bd.Gate.BearerToken != nil {
		pk, err = keys.NewPublicKeyFromBytes(bd.Gate.BearerToken.SigningKeyBytes(), elliptic.P256())
		if err != nil {
			return nil, nil, nil, fmt.Errorf("parse public key from btoken: %w", err)
		}
		owner = pk.Address()

		groups, tags, err = cfg.FrostfsID.GetUserGroupIDsAndClaims(pk.GetScriptHash())
		if err != nil {
			return nil, nil, nil, fmt.Errorf("get group ids: %w", err)
		}
	}

	op := determineOperation(r, reqType)
	var res string
	switch reqType {
	case objectType:
		res = fmt.Sprintf(s3.ResourceFormatS3BucketObject, bktName, objName)
	default:
		res = fmt.Sprintf(s3.ResourceFormatS3Bucket, bktName)
	}

	properties, err := determineProperties(r, cfg.Decoder, cfg.BucketResolver, cfg.Tagging, reqType, op, bktName, objName, owner, groups, tags)
	if err != nil {
		return nil, nil, nil, fmt.Errorf("determine properties: %w", err)
	}

	reqLogOrDefault(ctx, cfg.Log).Debug(logs.PolicyRequest, zap.String("action", op),
		zap.String("resource", res), zap.Any("properties", properties))

	return testutil.NewRequest(op, testutil.NewResource(res, nil), properties), pk, groups, nil
}

// ReqType classifies a request by the kind of resource addressed in its URL.
type ReqType int

// Request kinds produced by getBucketObject.
const (
	noneType   ReqType = 0 // no bucket addressed (e.g. ListBuckets)
	bucketType ReqType = 1 // bucket-level request
	objectType ReqType = 2 // object-level request
)

// getBucketObject extracts the bucket and object names addressed by r.
//
// Virtual-hosted style is tried first, one domain at a time in the given
// order. If r.Host contains "."+domain, the host part before the first such
// occurrence is the bucket. The whole URL path (without its leading slash) is
// then the object, or the request is bucket-level when the path is empty.
//
// Otherwise the path is parsed as "/bucket/object". An empty path yields
// noneType, and a path with no object part yields bucketType with any trailing
// slash removed from the bucket name.
func getBucketObject(r *http.Request, domains []string) (reqType ReqType, bktName string, objName string) {
	path := strings.TrimPrefix(r.URL.Path, "/")

	for _, domain := range domains {
		pos := strings.Index(r.Host, "."+domain)
		if pos < 0 {
			continue
		}

		host := r.Host[:pos]
		if path == "" {
			return bucketType, host, ""
		}
		return objectType, host, path
	}

	if path == "" {
		return noneType, "", ""
	}

	if slash := strings.IndexByte(path, '/'); slash >= 0 {
		if obj := path[slash+1:]; obj != "" {
			return objectType, path[:slash], obj
		}
	}

	return bucketType, strings.TrimSuffix(path, "/"), ""
}

// determineOperation returns the policy-engine action for r in the form
// "s3:<Operation>". The operation name comes from the object, bucket or
// general classifier, depending on reqType.
func determineOperation(r *http.Request, reqType ReqType) (operation string) {
	var classify func(*http.Request) string
	switch reqType {
	case objectType:
		classify = determineObjectOperation
	case bucketType:
		classify = determineBucketOperation
	default:
		classify = determineGeneralOperation
	}

	operation = "s3:" + classify(r)
	return operation
}

// determineBucketOperation maps a bucket-level request (bucket but no object
// key in the URL) to an operation name. The choice depends on the HTTP method
// and on which sub-resource query parameter is present.
//
// Within each method the cases are checked in order and the first match wins,
// so the order of the cases is significant. Without a recognised sub-resource:
//   - GET is ListObjectsV1 (or ListObjectsV2/ListObjectsV2M for list-type=2);
//   - PUT is CreateBucket;
//   - POST is PostObject;
//   - DELETE is DeleteBucket.
// Methods not handled here yield "UnmatchedBucketOperation".
func determineBucketOperation(r *http.Request) string {
	query := r.URL.Query()
	switch r.Method {
	case http.MethodOptions:
		return OptionsOperation
	case http.MethodHead:
		return HeadBucketOperation
	case http.MethodGet:
		switch {
		case query.Has(UploadsQuery):
			return ListMultipartUploadsOperation
		case query.Has(LocationQuery):
			return GetBucketLocationOperation
		case query.Has(PolicyQuery):
			return GetBucketPolicyOperation
		case query.Has(LifecycleQuery):
			return GetBucketLifecycleOperation
		case query.Has(EncryptionQuery):
			return GetBucketEncryptionOperation
		case query.Has(CorsQuery):
			return GetBucketCorsOperation
		case query.Has(ACLQuery):
			return GetBucketACLOperation
		case query.Has(WebsiteQuery):
			return GetBucketWebsiteOperation
		case query.Has(AccelerateQuery):
			return GetBucketAccelerateOperation
		case query.Has(RequestPaymentQuery):
			return GetBucketRequestPaymentOperation
		case query.Has(LoggingQuery):
			return GetBucketLoggingOperation
		case query.Has(ReplicationQuery):
			return GetBucketReplicationOperation
		case query.Has(TaggingQuery):
			return GetBucketTaggingOperation
		case query.Has(ObjectLockQuery):
			return GetBucketObjectLockConfigOperation
		case query.Has(VersioningQuery):
			return GetBucketVersioningOperation
		case query.Has(NotificationQuery):
			return GetBucketNotificationOperation
		case query.Has(EventsQuery):
			return ListenBucketNotificationOperation
		case query.Has(VersionsQuery):
			return ListBucketObjectVersionsOperation
		// The metadata variant must be tested before plain list-type=2.
		case query.Get(ListTypeQuery) == "2" && query.Get(MetadataQuery) == "true":
			return ListObjectsV2MOperation
		case query.Get(ListTypeQuery) == "2":
			return ListObjectsV2Operation
		default:
			return ListObjectsV1Operation
		}
	case http.MethodPut:
		switch {
		case query.Has(CorsQuery):
			return PutBucketCorsOperation
		case query.Has(ACLQuery):
			return PutBucketACLOperation
		case query.Has(LifecycleQuery):
			return PutBucketLifecycleOperation
		case query.Has(EncryptionQuery):
			return PutBucketEncryptionOperation
		case query.Has(PolicyQuery):
			return PutBucketPolicyOperation
		case query.Has(ObjectLockQuery):
			return PutBucketObjectLockConfigOperation
		case query.Has(TaggingQuery):
			return PutBucketTaggingOperation
		case query.Has(VersioningQuery):
			return PutBucketVersioningOperation
		case query.Has(NotificationQuery):
			return PutBucketNotificationOperation
		default:
			// A PUT without a recognised sub-resource creates the bucket.
			return CreateBucketOperation
		}
	case http.MethodPost:
		switch {
		case query.Has(DeleteQuery):
			return DeleteMultipleObjectsOperation
		default:
			return PostObjectOperation
		}
	case http.MethodDelete:
		switch {
		case query.Has(CorsQuery):
			return DeleteBucketCorsOperation
		case query.Has(WebsiteQuery):
			return DeleteBucketWebsiteOperation
		case query.Has(TaggingQuery):
			return DeleteBucketTaggingOperation
		case query.Has(PolicyQuery):
			return DeleteBucketPolicyOperation
		case query.Has(LifecycleQuery):
			return DeleteBucketLifecycleOperation
		case query.Has(EncryptionQuery):
			return DeleteBucketEncryptionOperation
		default:
			return DeleteBucketOperation
		}
	}

	return "UnmatchedBucketOperation"
}

// determineObjectOperation maps an object-level request (bucket and object key
// in the URL) to an operation name. The choice depends on the HTTP method, the
// sub-resource query parameters and the X-Amz-Copy-Source header.
//
// Within each method the cases are checked in order and the first match wins,
// so the order is significant. For example, a PUT with partNumber, uploadId
// and a copy source is UploadPartCopy, not UploadPart or CopyObject. Methods
// without a case here, including OPTIONS, yield "UnmatchedObjectOperation".
//
// NOTE(review): the GET branch tests LegalQuery while the PUT branch tests
// LegalHoldQuery. Confirm both constants match the query parameter the router
// uses for the legal-hold handlers.
func determineObjectOperation(r *http.Request) string {
	query := r.URL.Query()
	switch r.Method {
	case http.MethodHead:
		return HeadObjectOperation
	case http.MethodGet:
		switch {
		case query.Has(UploadIDQuery):
			return ListPartsOperation
		case query.Has(ACLQuery):
			return GetObjectACLOperation
		case query.Has(TaggingQuery):
			return GetObjectTaggingOperation
		case query.Has(RetentionQuery):
			return GetObjectRetentionOperation
		case query.Has(LegalQuery):
			return GetObjectLegalHoldOperation
		case query.Has(AttributesQuery):
			return GetObjectAttributesOperation
		default:
			return GetObjectOperation
		}
	case http.MethodPut:
		switch {
		case query.Has(PartNumberQuery) && query.Has(UploadIDQuery) && r.Header.Get("X-Amz-Copy-Source") != "":
			return UploadPartCopyOperation
		case query.Has(PartNumberQuery) && query.Has(UploadIDQuery):
			return UploadPartOperation
		case query.Has(ACLQuery):
			return PutObjectACLOperation
		case query.Has(TaggingQuery):
			return PutObjectTaggingOperation
		// A copy source takes precedence over the retention/legal-hold cases below.
		case r.Header.Get("X-Amz-Copy-Source") != "":
			return CopyObjectOperation
		case query.Has(RetentionQuery):
			return PutObjectRetentionOperation
		case query.Has(LegalHoldQuery):
			return PutObjectLegalHoldOperation
		default:
			return PutObjectOperation
		}
	case http.MethodPost:
		switch {
		case query.Has(UploadIDQuery):
			return CompleteMultipartUploadOperation
		case query.Has(UploadsQuery):
			return CreateMultipartUploadOperation
		default:
			// A POST with neither uploadId nor uploads is classified as SelectObjectContent.
			return SelectObjectContentOperation
		}
	case http.MethodDelete:
		switch {
		case query.Has(UploadIDQuery):
			return AbortMultipartUploadOperation
		case query.Has(TaggingQuery):
			return DeleteObjectTaggingOperation
		default:
			return DeleteObjectOperation
		}
	}

	return "UnmatchedObjectOperation"
}

// determineGeneralOperation classifies a request that addresses no bucket.
// Only GET is recognised (ListBuckets); every other method yields
// "UnmatchedOperation".
func determineGeneralOperation(r *http.Request) string {
	switch r.Method {
	case http.MethodGet:
		return ListBucketsOperation
	default:
		return "UnmatchedOperation"
	}
}

// determineProperties assembles the condition properties attached to the
// policy request.
//
// The result always contains:
//   - the owner address;
//   - the FrostFS ID groups (formatted by chain.FormCondSliceContainsValue);
//   - the client's source IP from the request info;
//   - the access box MFA attribute, which defaults to "false".
//
// When applicable it also contains:
//   - one property per user claim in tags;
//   - the version ID for object requests;
//   - prefix, delimiter and max-keys for bucket listing operations;
//   - request and resource tags (see determineTags);
//   - every attribute stored in the access box, if it can be read.
//
// owner, groups and tags describe the bearer token user. They are empty when
// the request carries no bearer token (see getPolicyRequest). An error is
// returned only when the tags cannot be determined.
func determineProperties(r *http.Request, decoder XMLDecoder, resolver BucketResolveFunc, tagging ResourceTagging, reqType ReqType,
	op, bktName, objName, owner string, groups []string, tags map[string]string) (map[string]string, error) {
	reqInfo := GetReqInfo(r.Context())
	res := map[string]string{
		s3.PropertyKeyOwner:                owner,
		common.PropertyKeyFrostFSIDGroupID: chain.FormCondSliceContainsValue(groups),
		common.PropertyKeyFrostFSSourceIP:  reqInfo.RemoteHost,
	}
	queries := reqInfo.URL.Query()

	// tags holds the user's FrostFS ID claims.
	for k, v := range tags {
		res[fmt.Sprintf(common.PropertyKeyFormatFrostFSIDUserClaim, k)] = v
	}

	if reqType == objectType {
		if versionID := queries.Get(QueryVersionID); len(versionID) > 0 {
			res[s3.PropertyKeyVersionID] = versionID
		}
	}

	// ListObjectsV2M is ListObjectsV2 with metadata=true (see
	// determineBucketOperation). It takes the same listing parameters, so it
	// must expose them as well; otherwise prefix-scoped conditions can never
	// match it. HasSuffix(op, ListObjectsV2Operation) does not cover it.
	isListing := strings.HasSuffix(op, ListObjectsV1Operation) ||
		strings.HasSuffix(op, ListObjectsV2Operation) ||
		strings.HasSuffix(op, ListObjectsV2MOperation) ||
		strings.HasSuffix(op, ListBucketObjectVersionsOperation) ||
		strings.HasSuffix(op, ListMultipartUploadsOperation)
	if reqType == bucketType && isListing {
		if prefix := queries.Get(QueryPrefix); len(prefix) > 0 {
			res[s3.PropertyKeyPrefix] = prefix
		}
		if delimiter := queries.Get(QueryDelimiter); len(delimiter) > 0 {
			res[s3.PropertyKeyDelimiter] = delimiter
		}
		if maxKeys := queries.Get(QueryMaxKeys); len(maxKeys) > 0 {
			res[s3.PropertyKeyMaxKeys] = maxKeys
		}
	}

	// Use a separate name here: reusing the tags parameter would mix user
	// claims with request/resource tags.
	tagProps, err := determineTags(r, decoder, resolver, tagging, reqType, op, bktName, objName, queries.Get(QueryVersionID))
	if err != nil {
		return nil, fmt.Errorf("determine tags: %w", err)
	}
	for k, v := range tagProps {
		res[k] = v
	}

	res[s3.PropertyKeyAccessBoxAttrMFA] = "false"
	// Access box attributes are best-effort. If they cannot be read, only the
	// MFA default set above remains.
	if attrs, attrErr := GetAccessBoxAttrs(r.Context()); attrErr == nil {
		for _, attr := range attrs {
			res[fmt.Sprintf(s3.PropertyKeyFormatAccessBoxAttr, attr.Key())] = attr.Value()
		}
	}

	return res, nil
}

// determineTags returns the tag properties of the request. These are the tags
// supplied by the request itself (see determineRequestTags) merged with the
// tags currently attached to the addressed bucket or object (see
// determineResourceTags). On a key collision the resource tag wins.
// Resource tags are not looked up if the request tags cannot be determined.
func determineTags(r *http.Request, decoder XMLDecoder, resolver BucketResolveFunc, tagging ResourceTagging, reqType ReqType,
	op, bktName, objName, versionID string) (map[string]string, error) {
	requestTags, err := determineRequestTags(r, decoder, op)
	if err != nil {
		return nil, fmt.Errorf("determine request tags: %w", err)
	}

	resourceTags, err := determineResourceTags(r.Context(), reqType, op, bktName, objName, versionID, resolver, tagging)
	if err != nil {
		return nil, fmt.Errorf("determine resource tags: %w", err)
	}

	merged := make(map[string]string, len(requestTags)+len(resourceTags))
	for _, src := range []map[string]string{requestTags, resourceTags} {
		for key, val := range src {
			merged[key] = val
		}
	}

	return merged, nil
}

// determineRequestTags collects the tags supplied by the request itself. Each
// tag is keyed with s3.PropertyKeyFormatRequestTag.
//
// Two sources are considered:
//   - for PutObjectTagging/PutBucketTagging, the XML Tagging document in the
//     request body. The body is consumed here, so the decoded document is
//     stored in the request info (presumably for the handler to reuse);
//   - the x-amz-tagging header, parsed as a URL query string ("k1=v1&k2=v2").
//     For a repeated key only its first value is used.
//
// A header tag overwrites a body tag with the same key. Malformed XML yields
// ErrMalformedXML and a malformed header yields ErrInvalidArgument. In both
// cases the parser error is kept as context.
func determineRequestTags(r *http.Request, decoder XMLDecoder, op string) (map[string]string, error) {
	tags := make(map[string]string)

	if strings.HasSuffix(op, PutObjectTaggingOperation) || strings.HasSuffix(op, PutBucketTaggingOperation) {
		tagging := new(data.Tagging)
		if err := decoder.NewXMLDecoder(r.Body).Decode(tagging); err != nil {
			return nil, fmt.Errorf("%w: %s", apiErr.GetAPIError(apiErr.ErrMalformedXML), err.Error())
		}
		GetReqInfo(r.Context()).Tagging = tagging

		for _, tag := range tagging.TagSet {
			tags[fmt.Sprintf(s3.PropertyKeyFormatRequestTag, tag.Key)] = tag.Value
		}
	}

	if header := r.Header.Get(amzTagging); len(header) > 0 {
		queries, err := url.ParseQuery(header)
		if err != nil {
			// Wrap like the XML branch so the parse failure reason is not lost.
			return nil, fmt.Errorf("%w: %s", apiErr.GetAPIError(apiErr.ErrInvalidArgument), err.Error())
		}
		for key := range queries {
			tags[fmt.Sprintf(s3.PropertyKeyFormatRequestTag, key)] = queries.Get(key)
		}
	}

	return tags, nil
}

// determineResourceTags returns the tags currently attached to the addressed
// bucket (bucket requests) or object version (object requests). Each tag is
// keyed with s3.PropertyKeyFormatResourceTag.
//
// An empty map is returned for these cases:
//   - requests that address neither a bucket nor an object;
//   - operations whose name ends with an entry of withoutResourceOps, since
//     their resource is not assumed to exist yet.
//
// Otherwise the bucket is resolved and its tags (or the object's) are fetched.
// Failures are wrapped and returned.
func determineResourceTags(ctx context.Context, reqType ReqType, op, bktName, objName, versionID string, resolver BucketResolveFunc,
	tagging ResourceTagging) (map[string]string, error) {
	if reqType != bucketType && reqType != objectType {
		return make(map[string]string), nil
	}

	for _, pending := range withoutResourceOps {
		if strings.HasSuffix(op, pending) {
			return make(map[string]string), nil
		}
	}

	bktInfo, err := resolver(ctx, bktName)
	if err != nil {
		return nil, fmt.Errorf("get bucket info: %w", err)
	}

	var raw map[string]string
	switch reqType {
	case bucketType:
		if raw, err = tagging.GetBucketTagging(ctx, bktInfo); err != nil {
			return nil, fmt.Errorf("get bucket tagging: %w", err)
		}
	case objectType:
		prm := &data.GetObjectTaggingParams{
			ObjectVersion: &data.ObjectVersion{
				BktInfo:    bktInfo,
				ObjectName: objName,
				VersionID:  versionID,
			},
		}
		if _, raw, err = tagging.GetObjectTagging(ctx, prm); err != nil {
			return nil, fmt.Errorf("get object tagging: %w", err)
		}
	}

	props := make(map[string]string, len(raw))
	for key, val := range raw {
		props[fmt.Sprintf(s3.PropertyKeyFormatResourceTag, key)] = val
	}

	return props, nil
}