package iam

import (
	"errors"
	"fmt"
	"strconv"
	"strings"
	"time"

	"git.frostfs.info/TrueCloudLab/policy-engine/pkg/chain"
)

// condKeyAWSPrincipalARN is the name of the AWS global condition context key
// "aws:PrincipalArn".
const condKeyAWSPrincipalARN = "aws:PrincipalArn"

// Condition operators that may appear as keys of a policy statement's
// Condition element. getConditionTypeAndConverter maps each of them onto a
// chain.ConditionType.
const (
	// String condition operators.
	CondStringEquals              string = "StringEquals"
	CondStringNotEquals           string = "StringNotEquals"
	CondStringEqualsIgnoreCase    string = "StringEqualsIgnoreCase"
	CondStringNotEqualsIgnoreCase string = "StringNotEqualsIgnoreCase"
	CondStringLike                string = "StringLike"
	CondStringNotLike             string = "StringNotLike"

	// Numeric condition operators. They are declared for completeness but are
	// currently rejected by getConditionTypeAndConverter.
	CondNumericEquals            string = "NumericEquals"
	CondNumericNotEquals         string = "NumericNotEquals"
	CondNumericLessThan          string = "NumericLessThan"
	CondNumericLessThanEquals    string = "NumericLessThanEquals"
	CondNumericGreaterThan       string = "NumericGreaterThan"
	CondNumericGreaterThanEquals string = "NumericGreaterThanEquals"

	// Date condition operators. Their values are normalized by
	// dateConvertFunction before being stored in the chain.
	CondDateEquals            string = "DateEquals"
	CondDateNotEquals         string = "DateNotEquals"
	CondDateLessThan          string = "DateLessThan"
	CondDateLessThanEquals    string = "DateLessThanEquals"
	CondDateGreaterThan       string = "DateGreaterThan"
	CondDateGreaterThanEquals string = "DateGreaterThanEquals"

	// Boolean condition operators.
	CondBool string = "Bool"

	// IP address condition operators (currently matched as string patterns).
	CondIPAddress    string = "IpAddress"
	CondNotIPAddress string = "NotIpAddress"

	// ARN condition operators (compared as plain strings).
	CondArnEquals    string = "ArnEquals"
	CondArnLike      string = "ArnLike"
	CondArnNotEquals string = "ArnNotEquals"
	CondArnNotLike   string = "ArnNotLike"
)

// arnIAMPrefix starts every IAM ARN understood by parsePrincipalAsIAMUser.
const arnIAMPrefix = "arn:aws:iam::"

// s3ResourcePrefix starts every S3 resource ARN understood by parseResourceAsS3ARN.
const s3ResourcePrefix = "arn:aws:s3:::"

// s3ActionPrefix is the namespace prefix of S3 action names (e.g. "s3:GetObject").
const s3ActionPrefix = "s3:"

// ErrInvalidPrincipalFormat is returned when a principal string does not have a
// known/supported format.
var ErrInvalidPrincipalFormat = errors.New("invalid principal format")

// ErrInvalidResourceFormat is returned when a resource string does not have a
// known/supported format.
var ErrInvalidResourceFormat = errors.New("invalid resource format")

// formPrincipalConditionFunc presumably builds a chain.Condition from a single
// principal string. NOTE(review): not referenced in this part of the file —
// confirm its semantics against its callers elsewhere in the package.
type formPrincipalConditionFunc func(string) chain.Condition

// transformConditionFunc post-processes one group produced by
// convertToChainCondition. convertToChainConditions replaces every group with
// the transformer's result and aborts on the first error.
type transformConditionFunc func(gr GroupedConditions) (GroupedConditions, error)

// convertToChainConditions converts the policy conditions c into grouped chain
// conditions and then passes every group through transformer, storing the
// transformed group in place of the original one.
//
// Errors from the conversion step are returned as is; the first transformer
// error aborts the whole conversion and is returned wrapped.
func convertToChainConditions(c Conditions, transformer transformConditionFunc) ([]GroupedConditions, error) {
	groups, err := convertToChainCondition(c)
	if err != nil {
		return nil, err
	}

	for idx, group := range groups {
		transformed, transformErr := transformer(group)
		if transformErr != nil {
			return nil, fmt.Errorf("transform condition: %w", transformErr)
		}
		groups[idx] = transformed
	}

	return groups, nil
}

// GroupedConditions is a set of chain conditions that are evaluated together.
// convertToChainCondition builds one group per (operator, condition key) pair,
// holding one condition per value listed for that key.
//
// Any describes how the member conditions combine: when true they are
// alternatives and it is enough for one of them to hold; when false all of
// them are required (see splitGroupedConditions).
type GroupedConditions struct {
	Conditions []chain.Condition
	Any        bool
}

// convertToChainCondition translates the policy Conditions into groups of chain
// request conditions. One group is produced for every (operator, condition key)
// pair; it holds one chain.Condition per listed value, each value first passed
// through the operator's converter (see getConditionTypeAndConverter).
//
// Following AWS IAM semantics, multiple values listed for one key are
// alternatives (logical OR, Any=true). The exception is negated operators
// (StringNotEquals, StringNotLike, ArnNotEquals, NotIpAddress, DateNotEquals,
// ...): there every value must be excluded, so the values are combined with
// logical AND (NOR overall) and Any stays false. ORing negations would make
// the condition trivially true as soon as two distinct values are listed.
//
// The order of the returned groups follows map iteration order and is not
// stable between calls.
func convertToChainCondition(c Conditions) ([]GroupedConditions, error) {
	var grouped []GroupedConditions

	for op, KVs := range c {
		condType, convertValue, err := getConditionTypeAndConverter(op)
		if err != nil {
			return nil, err
		}

		// All negated operators are mapped onto these three chain types.
		negated := false
		switch condType {
		case chain.CondStringNotEquals, chain.CondStringNotEqualsIgnoreCase, chain.CondStringNotLike:
			negated = true
		}

		for key, values := range KVs {
			group := GroupedConditions{
				Conditions: make([]chain.Condition, len(values)),
				Any:        len(values) > 1 && !negated,
			}

			for i, val := range values {
				converted, err := convertValue(val)
				if err != nil {
					return nil, fmt.Errorf("%s condition on key %q: %w", op, key, err)
				}

				group.Conditions[i] = chain.Condition{
					Op:     condType,
					Object: chain.ObjectRequest,
					Key:    key,
					Value:  converted,
				}
			}
			grouped = append(grouped, group)
		}
	}

	return grouped, nil
}

// getConditionTypeAndConverter maps an IAM policy condition operator to the
// chain condition type that implements it and to the function that must be
// applied to every condition value before it is stored in the chain.
//
// Arn*, Bool, IpAddress/NotIpAddress and Date* operators have no dedicated
// chain counterpart and are expressed with chain string comparisons.
// Numeric* operators are not supported yet. Any operator not listed here yields
// an "unsupported condition operator" error.
func getConditionTypeAndConverter(op string) (chain.ConditionType, convertFunction, error) {
	switch op {
	// String* operators map one-to-one; Arn* operators compare ARNs as plain strings.
	case CondStringEquals, CondArnEquals:
		return chain.CondStringEquals, noConvertFunction, nil
	case CondStringNotEquals, CondArnNotEquals:
		return chain.CondStringNotEquals, noConvertFunction, nil
	case CondStringEqualsIgnoreCase:
		return chain.CondStringEqualsIgnoreCase, noConvertFunction, nil
	case CondStringNotEqualsIgnoreCase:
		return chain.CondStringNotEqualsIgnoreCase, noConvertFunction, nil
	case CondStringLike, CondArnLike:
		return chain.CondStringLike, noConvertFunction, nil
	case CondStringNotLike, CondArnNotLike:
		return chain.CondStringNotLike, noConvertFunction, nil

	// Bool is compared as a case-insensitive string equality.
	case CondBool:
		return chain.CondStringEqualsIgnoreCase, noConvertFunction, nil

	// todo consider using converters
	//  "203.0.113.0/24" -> "203.0.113.*",
	//  "2001:DB8:1234:5678::/64" -> "2001:DB8:1234:5678:*"
	//  or having specific condition type for IP
	case CondIPAddress:
		return chain.CondStringLike, noConvertFunction, nil
	case CondNotIPAddress:
		return chain.CondStringNotLike, noConvertFunction, nil

	// Date values are normalized to Unix seconds and then compared as strings.
	// NOTE(review): ordered string comparison of decimal timestamps is only
	// correct while all compared values have the same number of digits.
	case CondDateEquals:
		return chain.CondStringEquals, dateConvertFunction, nil
	case CondDateNotEquals:
		return chain.CondStringNotEquals, dateConvertFunction, nil
	case CondDateLessThan:
		return chain.CondStringLessThan, dateConvertFunction, nil
	case CondDateLessThanEquals:
		return chain.CondStringLessThanEquals, dateConvertFunction, nil
	case CondDateGreaterThan:
		return chain.CondStringGreaterThan, dateConvertFunction, nil
	case CondDateGreaterThanEquals:
		return chain.CondStringGreaterThanEquals, dateConvertFunction, nil
	}

	if strings.HasPrefix(op, "Numeric") {
		// TODO: numeric operators need numeric (not string) comparison in the chain.
		return 0, nil, fmt.Errorf("currently numeric conditions unsupported: '%s'", op)
	}

	return 0, nil, fmt.Errorf("unsupported condition operator: '%s'", op)
}

// convertFunction turns a raw policy condition value into the string stored in
// a chain.Condition, or reports why the value cannot be used.
type convertFunction func(string) (string, error)

// noConvertFunction is the identity convertFunction: it hands back its input
// untouched and never fails.
func noConvertFunction(raw string) (string, error) { return raw, nil }

// dateConvertFunction normalizes a Date* condition value to a decimal Unix
// timestamp in seconds.
//
// Accepted inputs are an integer epoch value or an RFC 3339 timestamp. Integer
// input is re-formatted canonically, so "+1700000000" and "01700000000" yield
// the same string as "1700000000". Date operators are mapped onto chain string
// comparisons, so a non-canonical spelling would never match an equal
// timestamp. Any other input returns a wrapped time.Parse error.
func dateConvertFunction(val string) (string, error) {
	if epoch, err := strconv.ParseInt(val, 10, 64); err == nil {
		return strconv.FormatInt(epoch, 10), nil
	}

	parsed, err := time.Parse(time.RFC3339, val)
	if err != nil {
		return "", fmt.Errorf("parse date %q: %w", val, err)
	}

	// Unix() is independent of the parsed value's time zone offset.
	return strconv.FormatInt(parsed.Unix(), 10), nil
}

// parsePrincipalAsIAMUser extracts the account and the bare user name from an
// IAM user ARN of the form arn:aws:iam::<account>:user/<user-name-with-path>.
//
// Any IAM path in front of the user name is dropped; only the last "/"-separated
// segment is returned as the user. ErrInvalidPrincipalFormat is returned when
// the prefix or the ":user/" marker is missing, or when the user name is empty.
func parsePrincipalAsIAMUser(principal string) (account string, user string, err error) {
	const userMarker = ":user/"

	if !strings.HasPrefix(principal, arnIAMPrefix) {
		return "", "", ErrInvalidPrincipalFormat
	}
	rest := principal[len(arnIAMPrefix):]

	markerAt := strings.Index(rest, userMarker)
	if markerAt == -1 {
		return "", "", ErrInvalidPrincipalFormat
	}

	name := rest[markerAt+len(userMarker):]
	if slash := strings.LastIndexByte(name, '/'); slash >= 0 {
		name = name[slash+1:]
	}
	// Covers both an empty user part and a user part that ends with "/".
	if name == "" {
		return "", "", ErrInvalidPrincipalFormat
	}

	return rest[:markerAt], name, nil
}

// parseResourceAsS3ARN splits an S3 resource ARN of the form
// arn:aws:s3:::<bucket-name>[/<object-name>] into its bucket and object parts.
//
// A missing or empty object part is reported as Wildcard, i.e. the ARN covers
// every object of the bucket. ErrInvalidResourceFormat is returned when:
//   - the ARN is not an S3 resource ARN,
//   - the bucket name is empty,
//   - a wildcard bucket is combined with a specific object.
func parseResourceAsS3ARN(resource string) (bucket string, object string, err error) {
	if !strings.HasPrefix(resource, s3ResourcePrefix) {
		return "", "", ErrInvalidResourceFormat
	}

	// s3 arn format arn:aws:s3:::<bucket-name>/<object-name>
	s3Resource := strings.TrimPrefix(resource, s3ResourcePrefix)

	bucket, object = s3Resource, Wildcard
	if sepIndex := strings.IndexByte(s3Resource, '/'); sepIndex >= 0 {
		bucket = s3Resource[:sepIndex]
		if rest := s3Resource[sepIndex+1:]; rest != "" {
			object = rest
		}
	}

	// "arn:aws:s3:::" or "arn:aws:s3:::/obj" does not name any bucket.
	if bucket == "" {
		return "", "", ErrInvalidResourceFormat
	}

	if bucket == Wildcard && object != Wildcard {
		return "", "", ErrInvalidResourceFormat
	}

	return bucket, object, nil
}

// splitGroupedConditions converts grouped conditions into disjunctive normal
// form. Each returned slice is one conjunction (all of its conditions must
// hold), and the policy matches when any returned slice matches.
//
// Groups with Any=false contribute all of their conditions to every
// conjunction. Groups with Any=true are alternatives: every conjunction takes
// exactly one condition from each such group. So (a1 OR a2) AND (b1 OR b2)
// expands to a1&b1, a1&b2, a2&b1, a2&b2. Previously all alternatives were
// flattened into one list, which let a single matching key satisfy the policy.
// Any groups with no conditions are ignored.
//
// The number of conjunctions is the product of the alternative-group sizes.
// With only AND groups a single conjunction is returned, possibly empty.
func splitGroupedConditions(groupedConditions []GroupedConditions) [][]chain.Condition {
	commonConditions := make([]chain.Condition, 0, len(groupedConditions))
	var alternativeSets [][]chain.Condition
	for _, grouped := range groupedConditions {
		switch {
		case !grouped.Any:
			commonConditions = append(commonConditions, grouped.Conditions...)
		case len(grouped.Conditions) > 0:
			alternativeSets = append(alternativeSets, grouped.Conditions)
		}
	}

	if len(alternativeSets) == 0 {
		return [][]chain.Condition{commonConditions}
	}

	// Cartesian product of the alternative sets, starting from one empty combination.
	combinations := [][]chain.Condition{nil}
	for _, alternatives := range alternativeSets {
		next := make([][]chain.Condition, 0, len(combinations)*len(alternatives))
		for _, prefix := range combinations {
			for _, alternative := range alternatives {
				// Fresh backing array so sibling combinations never alias each other.
				combination := make([]chain.Condition, len(prefix), len(prefix)+1)
				copy(combination, prefix)
				next = append(next, append(combination, alternative))
			}
		}
		combinations = next
	}

	res := make([][]chain.Condition, len(combinations))
	for i, combination := range combinations {
		res[i] = append(combination, commonConditions...)
	}

	return res
}

// formStatus returns the chain status for a policy statement: chain.Allow when
// the statement's Effect is AllowEffect, chain.AccessDenied for any other
// effect.
func formStatus(statement Statement) chain.Status {
	if statement.Effect != AllowEffect {
		return chain.AccessDenied
	}
	return chain.Allow
}