package iam

import (
	"errors"
	"fmt"
	"net/netip"
	"strconv"
	"strings"
	"time"
	"unicode/utf8"

	"git.frostfs.info/TrueCloudLab/policy-engine/pkg/chain"
	"git.frostfs.info/TrueCloudLab/policy-engine/schema/common"
	"github.com/nspcc-dev/neo-go/pkg/encoding/fixedn"
)

// Names of S3 actions. Every value carries the "s3:" service prefix followed
// by the operation name.
const (
	s3ActionAbortMultipartUpload             = "s3:AbortMultipartUpload"
	s3ActionCreateBucket                     = "s3:CreateBucket"
	s3ActionDeleteBucket                     = "s3:DeleteBucket"
	s3ActionDeleteBucketPolicy               = "s3:DeleteBucketPolicy"
	s3ActionDeleteObject                     = "s3:DeleteObject"
	s3ActionDeleteObjectTagging              = "s3:DeleteObjectTagging"
	s3ActionDeleteObjectVersion              = "s3:DeleteObjectVersion"
	s3ActionDeleteObjectVersionTagging       = "s3:DeleteObjectVersionTagging"
	s3ActionGetBucketACL                     = "s3:GetBucketAcl"
	s3ActionGetBucketCORS                    = "s3:GetBucketCORS"
	s3ActionGetBucketLocation                = "s3:GetBucketLocation"
	s3ActionGetBucketNotification            = "s3:GetBucketNotification"
	s3ActionGetBucketObjectLockConfiguration = "s3:GetBucketObjectLockConfiguration"
	s3ActionGetBucketPolicy                  = "s3:GetBucketPolicy"
	s3ActionGetBucketPolicyStatus            = "s3:GetBucketPolicyStatus"
	s3ActionGetBucketTagging                 = "s3:GetBucketTagging"
	s3ActionGetBucketVersioning              = "s3:GetBucketVersioning"
	s3ActionGetLifecycleConfiguration        = "s3:GetLifecycleConfiguration"
	s3ActionGetObject                        = "s3:GetObject"
	s3ActionGetObjectACL                     = "s3:GetObjectAcl"
	s3ActionGetObjectAttributes              = "s3:GetObjectAttributes"
	s3ActionGetObjectLegalHold               = "s3:GetObjectLegalHold"
	s3ActionGetObjectRetention               = "s3:GetObjectRetention"
	s3ActionGetObjectTagging                 = "s3:GetObjectTagging"
	s3ActionGetObjectVersion                 = "s3:GetObjectVersion"
	s3ActionGetObjectVersionACL              = "s3:GetObjectVersionAcl"
	s3ActionGetObjectVersionAttributes       = "s3:GetObjectVersionAttributes"
	s3ActionGetObjectVersionTagging          = "s3:GetObjectVersionTagging"
	s3ActionListAllMyBuckets                 = "s3:ListAllMyBuckets"
	s3ActionListBucket                       = "s3:ListBucket"
	s3ActionListBucketMultipartUploads       = "s3:ListBucketMultipartUploads"
	s3ActionListBucketVersions               = "s3:ListBucketVersions"
	s3ActionListMultipartUploadParts         = "s3:ListMultipartUploadParts"
	s3ActionPutBucketACL                     = "s3:PutBucketAcl"
	s3ActionPutBucketCORS                    = "s3:PutBucketCORS"
	s3ActionPutBucketNotification            = "s3:PutBucketNotification"
	s3ActionPutBucketObjectLockConfiguration = "s3:PutBucketObjectLockConfiguration"
	s3ActionPutBucketPolicy                  = "s3:PutBucketPolicy"
	s3ActionPutBucketTagging                 = "s3:PutBucketTagging"
	s3ActionPutBucketVersioning              = "s3:PutBucketVersioning"
	s3ActionPutLifecycleConfiguration        = "s3:PutLifecycleConfiguration"
	s3ActionPutObject                        = "s3:PutObject"
	s3ActionPutObjectACL                     = "s3:PutObjectAcl"
	s3ActionPutObjectLegalHold               = "s3:PutObjectLegalHold"
	s3ActionPutObjectRetention               = "s3:PutObjectRetention"
	s3ActionPutObjectTagging                 = "s3:PutObjectTagging"
	s3ActionPutObjectVersionACL              = "s3:PutObjectVersionAcl"
	s3ActionPutObjectVersionTagging          = "s3:PutObjectVersionTagging"
)

// AWS condition keys and the helpers used when translating them.
const (
	// condKeyAWSPrincipalARN is the condition key naming the principal ARN.
	condKeyAWSPrincipalARN = "aws:PrincipalArn"
	// condKeyAWSSourceIP is rewritten by transformKey to
	// common.PropertyKeyFrostFSSourceIP.
	condKeyAWSSourceIP = "aws:SourceIp"
	// condKeyAWSPrincipalTagPrefix marks principal-tag keys; transformKey takes
	// the remainder of the key after this prefix as the tag name.
	condKeyAWSPrincipalTagPrefix = "aws:PrincipalTag/"
	// userClaimTagPrefix is prepended to the tag name by transformKey when it
	// builds the user-claim property key.
	userClaimTagPrefix = "tag-"
)

// Condition operator names accepted in policy conditions. Each of them is
// resolved to a chain condition type and a value converter by
// getConditionTypeAndConverter.
const (
	// String condition operators.
	CondStringEquals              string = "StringEquals"
	CondStringNotEquals           string = "StringNotEquals"
	CondStringEqualsIgnoreCase    string = "StringEqualsIgnoreCase"
	CondStringNotEqualsIgnoreCase string = "StringNotEqualsIgnoreCase"
	CondStringLike                string = "StringLike"
	CondStringNotLike             string = "StringNotLike"

	// Numeric condition operators.
	CondNumericEquals            string = "NumericEquals"
	CondNumericNotEquals         string = "NumericNotEquals"
	CondNumericLessThan          string = "NumericLessThan"
	CondNumericLessThanEquals    string = "NumericLessThanEquals"
	CondNumericGreaterThan       string = "NumericGreaterThan"
	CondNumericGreaterThanEquals string = "NumericGreaterThanEquals"

	// Date condition operators.
	CondDateEquals            string = "DateEquals"
	CondDateNotEquals         string = "DateNotEquals"
	CondDateLessThan          string = "DateLessThan"
	CondDateLessThanEquals    string = "DateLessThanEquals"
	CondDateGreaterThan       string = "DateGreaterThan"
	CondDateGreaterThanEquals string = "DateGreaterThanEquals"

	// Boolean condition operators.
	CondBool string = "Bool"

	// IP address condition operators.
	CondIPAddress    string = "IpAddress"
	CondNotIPAddress string = "NotIpAddress"

	// ARN condition operators.
	CondArnEquals    string = "ArnEquals"
	CondArnLike      string = "ArnLike"
	CondArnNotEquals string = "ArnNotEquals"
	CondArnNotLike   string = "ArnNotLike"

	// Custom condition operators.
	CondSliceContains string = "SliceContains"
)

// Prefixes used to recognize principals, resources and actions.
const (
	// arnIAMPrefix starts every IAM ARN; it is required by
	// parsePrincipalAsIAMUser and accepted by validateResource.
	arnIAMPrefix = "arn:aws:iam::"
	// s3ResourcePrefix starts every S3 resource ARN accepted by validateResource.
	s3ResourcePrefix = "arn:aws:s3:::"
	// s3ActionPrefix starts every S3 action accepted by validateAction.
	s3ActionPrefix = "s3:"
	// iamActionPrefix starts every IAM action accepted by validateAction.
	iamActionPrefix = "iam:"
)

// Sentinel errors of the converter. They are plain errors.New values, so
// callers can match them with errors.Is.
var (
	// ErrInvalidPrincipalFormat occurs when principal has unknown/unsupported format.
	ErrInvalidPrincipalFormat = errors.New("invalid principal format")

	// ErrInvalidResourceFormat occurs when resource has unknown/unsupported format.
	ErrInvalidResourceFormat = errors.New("invalid resource format")

	// ErrInvalidActionFormat occurs when action has unknown/unsupported format.
	ErrInvalidActionFormat = errors.New("invalid action format")

	// ErrActionsNotApplicable occurs when failed to convert any actions.
	ErrActionsNotApplicable = errors.New("actions not applicable")
)

// formPrincipalConditionFunc builds a chain.Condition from a single string
// argument (presumably a principal identifier; its callers are outside this
// file, so confirm there).
type formPrincipalConditionFunc func(string) chain.Condition

// transformConditionFunc rewrites one GroupedConditions value.
// convertToChainConditions applies it to every converted group and aborts the
// whole conversion on the first non-nil error.
type transformConditionFunc func(gr GroupedConditions) (GroupedConditions, error)

// convertToChainConditions converts policy conditions into grouped chain
// conditions and then passes every group through transformer.
//
// Conversion errors are returned as is; an error from transformer is wrapped
// with the "transform condition" prefix. On any error the result is nil.
func convertToChainConditions(c Conditions, transformer transformConditionFunc) ([]GroupedConditions, error) {
	groups, err := convertToChainCondition(c)
	if err != nil {
		return nil, err
	}

	for idx, group := range groups {
		transformed, err := transformer(group)
		if err != nil {
			return nil, fmt.Errorf("transform condition: %w", err)
		}

		// Replace the group in place with its transformed version.
		groups[idx] = transformed
	}

	return groups, nil
}

// GroupedConditions holds the chain conditions produced from one condition key
// under one operator.
type GroupedConditions struct {
	// Conditions contains one chain condition per value listed for the key.
	Conditions []chain.Condition
	// Any is set by convertToChainCondition when the key had more than one
	// value; splitGroupedConditions then treats the conditions of the group as
	// alternatives instead of requiring all of them together.
	Any bool
}

// convertToChainCondition turns every (operator, key, values) entry of c into
// one GroupedConditions value holding a chain condition per value.
//
// The operator selects the chain condition type and the value converter (see
// getConditionTypeAndConverter); the key is mapped with transformKey. A group
// is flagged Any when its key lists more than one value. The first unsupported
// operator or unconvertible value aborts the conversion with that error.
//
// NOTE(review): c is ranged over like a map, so the order of the returned
// groups is presumably not deterministic; confirm that callers do not rely on it.
func convertToChainCondition(c Conditions) ([]GroupedConditions, error) {
	var result []GroupedConditions

	for operator, keyed := range c {
		condType, convert, err := getConditionTypeAndConverter(operator)
		if err != nil {
			return nil, err
		}

		for key, values := range keyed {
			conds := make([]chain.Condition, 0, len(values))

			for _, raw := range values {
				value, err := convert(raw)
				if err != nil {
					return nil, err
				}

				conds = append(conds, chain.Condition{
					Op:     condType,
					Object: chain.ObjectRequest,
					Key:    transformKey(key),
					Value:  value,
				})
			}

			result = append(result, GroupedConditions{
				Conditions: conds,
				Any:        len(values) > 1,
			})
		}
	}

	return result, nil
}

func transformKey(key string) string {
	tagName, isTag := strings.CutPrefix(key, condKeyAWSPrincipalTagPrefix)
	if isTag {
		return fmt.Sprintf(common.PropertyKeyFormatFrostFSIDUserClaim, userClaimTagPrefix+tagName)
	}

	switch key {
	case condKeyAWSSourceIP:
		return common.PropertyKeyFrostFSSourceIP
	}

	return key
}

// getConditionTypeAndConverter resolves a policy condition operator name to
// the chain condition type that implements it and to the function used to
// validate/normalize the operator's values.
//
// String and ARN operators map to the chain string comparisons and keep values
// as is; Date operators map to string comparisons over values normalized by
// dateConvertFunction; Bool maps to a case-insensitive string equality;
// Numeric operators are resolved by numericConditionTypeAndConverter.
// Any other operator yields an "unsupported condition operator" error.
func getConditionTypeAndConverter(op string) (chain.ConditionType, convertFunction, error) {
	// Numeric operators have their own resolver, including its error path.
	if strings.HasPrefix(op, "Numeric") {
		return numericConditionTypeAndConverter(op)
	}

	switch op {
	// String and ARN operators share the same chain string comparisons.
	case CondStringEquals, CondArnEquals:
		return chain.CondStringEquals, noConvertFunction, nil
	case CondStringNotEquals, CondArnNotEquals:
		return chain.CondStringNotEquals, noConvertFunction, nil
	case CondStringLike, CondArnLike:
		return chain.CondStringLike, noConvertFunction, nil
	case CondStringNotLike, CondArnNotLike:
		return chain.CondStringNotLike, noConvertFunction, nil
	case CondStringEqualsIgnoreCase:
		return chain.CondStringEqualsIgnoreCase, noConvertFunction, nil
	case CondStringNotEqualsIgnoreCase:
		return chain.CondStringNotEqualsIgnoreCase, noConvertFunction, nil

	// Date operators compare the values produced by dateConvertFunction.
	case CondDateEquals:
		return chain.CondStringEquals, dateConvertFunction, nil
	case CondDateNotEquals:
		return chain.CondStringNotEquals, dateConvertFunction, nil
	case CondDateLessThan:
		return chain.CondStringLessThan, dateConvertFunction, nil
	case CondDateLessThanEquals:
		return chain.CondStringLessThanEquals, dateConvertFunction, nil
	case CondDateGreaterThan:
		return chain.CondStringGreaterThan, dateConvertFunction, nil
	case CondDateGreaterThanEquals:
		return chain.CondStringGreaterThanEquals, dateConvertFunction, nil

	case CondBool:
		return chain.CondStringEqualsIgnoreCase, noConvertFunction, nil
	case CondIPAddress:
		return chain.CondIPAddress, ipConvertFunction, nil
	case CondNotIPAddress:
		return chain.CondNotIPAddress, ipConvertFunction, nil
	case CondSliceContains:
		return chain.CondSliceContains, noConvertFunction, nil
	}

	return 0, nil, fmt.Errorf("unsupported condition operator: '%s'", op)
}

// numericConditionTypeAndConverter resolves a Numeric* operator name to the
// matching chain numeric condition type. Every supported operator uses
// numericConvertFunction for its values; an unknown operator yields an
// "unsupported condition operator" error.
func numericConditionTypeAndConverter(op string) (chain.ConditionType, convertFunction, error) {
	var condType chain.ConditionType

	switch op {
	case CondNumericEquals:
		condType = chain.CondNumericEquals
	case CondNumericNotEquals:
		condType = chain.CondNumericNotEquals
	case CondNumericLessThan:
		condType = chain.CondNumericLessThan
	case CondNumericLessThanEquals:
		condType = chain.CondNumericLessThanEquals
	case CondNumericGreaterThan:
		condType = chain.CondNumericGreaterThan
	case CondNumericGreaterThanEquals:
		condType = chain.CondNumericGreaterThanEquals
	default:
		return 0, nil, fmt.Errorf("unsupported condition operator: '%s'", op)
	}

	return condType, numericConvertFunction, nil
}

// convertFunction validates and/or normalizes one raw condition value. The
// returned string is what convertToChainCondition stores as the chain
// condition value; a non-nil error rejects the value.
type convertFunction func(string) (string, error)

// noConvertFunction is the identity convertFunction: it returns val unchanged
// and never fails.
func noConvertFunction(val string) (string, error) {
	return val, nil
}

// numericConvertFunction checks that val is a number by parsing it with
// fixedn.Fixed8FromString. A valid value is returned unchanged; otherwise an
// "invalid numeric value" error is returned.
func numericConvertFunction(val string) (string, error) {
	_, err := fixedn.Fixed8FromString(val)
	if err != nil {
		return "", fmt.Errorf("invalid numeric value: '%s'", val)
	}

	return val, nil
}

// ipConvertFunction validates an IP condition value and returns it in CIDR
// notation.
//
// val may be a CIDR prefix or a bare IP address. A prefix is returned
// unchanged. A bare address gets a full-length mask appended ("/32" for IPv4,
// "/128" for IPv6) so that the resulting prefix matches exactly that host.
//
// An error is returned when val parses as neither a prefix nor an address, or
// when the address is private according to netip.Addr.IsPrivate.
func ipConvertFunction(val string) (string, error) {
	var ipAddr netip.Addr

	prefix, err := netip.ParsePrefix(val)
	if err == nil {
		ipAddr = prefix.Addr()
	} else {
		if ipAddr, err = netip.ParseAddr(val); err != nil {
			return "", err
		}

		// The host mask depends on the address family: a fixed "/32" would turn
		// a single IPv6 address into a very large network.
		val += "/" + strconv.Itoa(ipAddr.BitLen())
	}

	if ipAddr.IsPrivate() {
		return "", fmt.Errorf("invalid ip value '%s': must be public", val)
	}

	return val, nil
}

// dateConvertFunction normalizes a date condition value to a decimal integer
// string.
//
// A value that already parses as a base-10 int64 is returned unchanged.
// Otherwise the value must be an RFC 3339 timestamp, which is converted to its
// Unix time in seconds. Any other value yields the time.Parse error.
func dateConvertFunction(val string) (string, error) {
	_, intErr := strconv.ParseInt(val, 10, 64)
	if intErr != nil {
		// Not an integer: fall back to RFC 3339.
		ts, err := time.Parse(time.RFC3339, val)
		if err != nil {
			return "", err
		}

		// Unix time does not depend on the timestamp's zone offset.
		return strconv.FormatInt(ts.Unix(), 10), nil
	}

	return val, nil
}

// parsePrincipalAsIAMUser extracts the account and the user name from an IAM
// user ARN of the form
//
//	arn:aws:iam::<account>:user/<user-name-with-path>
//
// Only the last path segment of the user part is returned as the user name.
// ErrInvalidPrincipalFormat is returned when the ARN prefix or the ":user/"
// separator is missing, or when the resulting user name is empty.
func parsePrincipalAsIAMUser(principal string) (account string, user string, err error) {
	iamResource, hasPrefix := strings.CutPrefix(principal, arnIAMPrefix)
	if !hasPrefix {
		return "", "", ErrInvalidPrincipalFormat
	}

	// Split at the first ":user/": account on the left, user path on the right.
	accountPart, userPath, found := strings.Cut(iamResource, ":user/")
	if !found {
		return "", "", ErrInvalidPrincipalFormat
	}

	// Drop the optional path, keeping only what follows the last slash.
	name := userPath
	if slash := strings.LastIndexByte(userPath, '/'); slash >= 0 {
		name = userPath[slash+1:]
	}

	if name == "" {
		return "", "", ErrInvalidPrincipalFormat
	}

	return accountPart, name, nil
}

// validateResource checks that resource has a supported format.
//
// A resource is valid when it equals Wildcard, or when it starts with the S3
// or IAM ARN prefix and contains the wildcard character either not at all or
// only as its very last character. Otherwise ErrInvalidResourceFormat is
// returned.
func validateResource(resource string) error {
	if resource == Wildcard {
		return nil
	}

	if !strings.HasPrefix(resource, s3ResourcePrefix) && !strings.HasPrefix(resource, arnIAMPrefix) {
		return ErrInvalidResourceFormat
	}

	// IndexByte yields a byte offset, so it must not be compared with a rune
	// count of the whole string: names with multibyte characters would be
	// misjudged. Instead require that the tail starting at the first wildcard
	// is exactly one rune long, i.e. the wildcard is the final character.
	index := strings.IndexByte(resource, Wildcard[0])
	if index != -1 && utf8.RuneCountInString(resource[index:]) != 1 {
		return ErrInvalidResourceFormat
	}

	return nil
}

// validateAction checks that action has a supported format and reports whether
// it is an IAM action.
//
// An action is valid when it starts with the "s3:" or "iam:" prefix and
// contains the wildcard character either not at all or only as its very last
// character. Otherwise ErrInvalidActionFormat is returned together with false.
func validateAction(action string) (bool, error) {
	isIAM := strings.HasPrefix(action, iamActionPrefix)
	if !isIAM && !strings.HasPrefix(action, s3ActionPrefix) {
		return false, ErrInvalidActionFormat
	}

	// IndexByte yields a byte offset, so it must not be compared with a rune
	// count of the whole string. Instead require that the tail starting at the
	// first wildcard is exactly one rune long, i.e. the wildcard is the final
	// character.
	index := strings.IndexByte(action, Wildcard[0])
	if index != -1 && utf8.RuneCountInString(action[index:]) != 1 {
		return false, ErrInvalidActionFormat
	}

	return isIAM, nil
}

// splitGroupedConditions expands grouped conditions into one or more condition
// sets.
//
// Conditions of groups without the Any flag are shared by every resulting set.
// Each condition of an Any group produces its own set: that condition followed
// by all shared conditions. When there are no Any groups, a single set with the
// shared conditions is returned.
//
// NOTE(review): conditions of all Any groups are collected into one flat list,
// so two multi-valued groups yield one set per value rather than one set per
// combination of values; confirm this is the intended semantics.
func splitGroupedConditions(groupedConditions []GroupedConditions) [][]chain.Condition {
	var alternatives []chain.Condition
	shared := make([]chain.Condition, 0, len(groupedConditions))

	for _, group := range groupedConditions {
		if !group.Any {
			shared = append(shared, group.Conditions...)
			continue
		}

		alternatives = append(alternatives, group.Conditions...)
	}

	if len(alternatives) == 0 {
		return [][]chain.Condition{shared}
	}

	sets := make([][]chain.Condition, 0, len(alternatives))
	for _, alt := range alternatives {
		// Each set gets its own backing array: the alternative first, then the
		// shared conditions.
		set := make([]chain.Condition, 0, len(shared)+1)
		set = append(set, alt)
		set = append(set, shared...)
		sets = append(sets, set)
	}

	return sets
}

// formStatus maps the effect of a policy statement to a chain status:
// chain.Allow for AllowEffect and chain.AccessDenied for any other effect.
func formStatus(statement Statement) chain.Status {
	if statement.Effect == AllowEffect {
		return chain.Allow
	}

	return chain.AccessDenied
}