package iam

import (
	"encoding/json"
	"errors"
	"fmt"
)

// Policy is an IAM policy document. Field names and JSON tags follow the AWS
// policy grammar:
// https://docs.aws.amazon.com/IAM/latest/UserGuide/reference_policies_grammar.html
type Policy struct {
	Version   string     `json:"Version,omitempty"`
	ID        string     `json:"Id,omitempty"`
	Statement Statements `json:"Statement"`
}

// Statements is a policy's statement list. Its UnmarshalJSON accepts either a
// JSON array of statements or a single statement object.
type Statements []Statement

// Statement is one entry of a policy's Statement list.
type Statement struct {
	ID           string     `json:"Id,omitempty"`
	SID          string     `json:"Sid,omitempty"`
	Principal    Principal  `json:"Principal,omitempty"`
	NotPrincipal Principal  `json:"NotPrincipal,omitempty"`
	Effect       Effect     `json:"Effect"`
	Action       Action     `json:"Action,omitempty"`
	NotAction    Action     `json:"NotAction,omitempty"`
	Resource     Resource   `json:"Resource,omitempty"`
	NotResource  Resource   `json:"NotResource,omitempty"`
	Conditions   Conditions `json:"Condition,omitempty"`
}

// Principal maps a principal type (for example "AWS" or "Service") to its
// identifiers. The bare "*" principal is represented by the single key
// Wildcard with a nil value.
type Principal map[PrincipalType][]string

// Effect is the value of a statement's "Effect" element.
type Effect string

// Action holds the entries of an Action or NotAction element.
type Action []string

// Resource holds the entries of a Resource or NotResource element.
type Resource []string

// Conditions maps a condition operator name (for example "StringEquals") to
// its key/value block.
type Conditions map[string]Condition

// Condition maps a condition key to its list of values.
type Condition map[string][]string

// PolicyType selects which additional checks Policy.Validate applies.
type PolicyType int

// PrincipalType is a key of a Principal object, such as "AWS" or "Service".
type PrincipalType string

// Policy kinds accepted by Policy.Validate. Values are spelled out so they
// stay stable if the list is ever reordered.
const (
	// GeneralPolicyType runs only the checks shared by every policy.
	GeneralPolicyType PolicyType = 0
	// IdentityBasedPolicyType adds the identity-based policy checks.
	IdentityBasedPolicyType PolicyType = 1
	// ResourceBasedPolicyType adds the resource-based policy checks.
	ResourceBasedPolicyType PolicyType = 2
)

const (
	// Wildcard is the "*" value. It is deliberately an untyped constant so it
	// can serve directly as a PrincipalType map key.
	Wildcard = "*"
)

// Effect values accepted in a statement's "Effect" element.
const (
	// AllowEffect is the "Allow" effect.
	AllowEffect Effect = "Allow"
	// DenyEffect is the "Deny" effect.
	DenyEffect Effect = "Deny"
)

// IsValid reports whether e is exactly AllowEffect or DenyEffect. The
// comparison is case-sensitive, so "allow" is not valid.
func (e Effect) IsValid() bool {
	switch e {
	case AllowEffect, DenyEffect:
		return true
	default:
		return false
	}
}

// Principal types accepted as keys of a Principal object.
const (
	// AWSPrincipalType is the "AWS" principal key.
	AWSPrincipalType PrincipalType = "AWS"
	// FederatedPrincipalType is the "Federated" principal key.
	FederatedPrincipalType PrincipalType = "Federated"
	// ServicePrincipalType is the "Service" principal key.
	ServicePrincipalType PrincipalType = "Service"
	// CanonicalUserPrincipalType is the "CanonicalUser" principal key.
	CanonicalUserPrincipalType PrincipalType = "CanonicalUser"
)

// IsValid reports whether p is one of the four known principal types. The
// comparison is case-sensitive, and Wildcard is not considered valid here.
func (p PrincipalType) IsValid() bool {
	switch p {
	case AWSPrincipalType, FederatedPrincipalType, ServicePrincipalType, CanonicalUserPrincipalType:
		return true
	default:
		return false
	}
}

// UnmarshalJSON decodes a policy's Statement element. The element may be
// either an array of statement objects or a single statement object. JSON
// null yields a nil Statements, and [] yields an empty non-nil one.
//
// Array elements are decoded one at a time. A failure therefore reports the
// offending statement's index and its own error (for example an invalid
// Principal). It does not fall back to single-object decoding, which would
// hide that error behind "cannot unmarshal array into ... Statement".
func (s *Statements) UnmarshalJSON(data []byte) error {
	var raws []json.RawMessage
	if err := json.Unmarshal(data, &raws); err != nil {
		// Not an array (or null): accept a single statement object.
		var elem Statement
		if err := json.Unmarshal(data, &elem); err != nil {
			return err
		}
		*s = Statements{elem}
		return nil
	}

	if raws == nil {
		// JSON null.
		*s = nil
		return nil
	}

	list := make(Statements, len(raws))
	for i, raw := range raws {
		if err := json.Unmarshal(raw, &list[i]); err != nil {
			return fmt.Errorf("statement %d: %w", i, err)
		}
	}

	*s = list
	return nil
}

// UnmarshalJSON decodes a statement's Principal or NotPrincipal element.
//
// Accepted forms:
//   - the bare string "*", stored as {Wildcard: nil};
//   - an object whose values are a single string or an array of strings,
//     e.g. {"AWS": "arn:..."} or {"Service": ["a", "b"]}. Single strings are
//     stored as one-element slices.
//
// JSON null is a no-op, following the encoding/json Unmarshaler convention.
// On error *p is left unchanged rather than partially populated.
//
// NOTE(review): json.Marshal renders the wildcard form {Wildcard: nil} as
// {"*":null}, which this method rejects. Re-marshalled policies with a "*"
// principal do not round-trip; a matching MarshalJSON may be wanted.
func (p *Principal) UnmarshalJSON(data []byte) error {
	var raw any
	if err := json.Unmarshal(data, &raw); err != nil {
		return err
	}

	switch v := raw.(type) {
	case nil:
		return nil
	case string:
		if v != Wildcard {
			return errors.New("invalid IAM string principal")
		}
		*p = Principal{Wildcard: nil}
		return nil
	case map[string]any:
		res := make(Principal, len(v))
		for key, val := range v {
			switch x := val.(type) {
			case string:
				res[PrincipalType(key)] = []string{x}
			case []any:
				ids := make([]string, len(x))
				for i, item := range x {
					id, ok := item.(string)
					if !ok {
						return errors.New("invalid principal format")
					}
					ids[i] = id
				}
				res[PrincipalType(key)] = ids
			default:
				return errors.New("invalid principal format")
			}
		}
		*p = res
		return nil
	default:
		// Arrays, numbers and booleans are not valid principal forms.
		return errors.New("invalid principal format")
	}
}

// UnmarshalJSON decodes an Action or NotAction element. The element may be a
// single string or an array of strings. JSON null leaves a nil Action.
//
// When neither form matches, the error from the array attempt is returned.
// For array input it pinpoints the bad element (e.g. a number). The
// single-string attempt would only report "cannot unmarshal array into string".
func (a *Action) UnmarshalJSON(data []byte) error {
	var list []string
	listErr := json.Unmarshal(data, &list)
	if listErr == nil {
		*a = list
		return nil
	}

	var elem string
	if err := json.Unmarshal(data, &elem); err != nil {
		return listErr
	}

	*a = Action{elem}
	return nil
}

// UnmarshalJSON decodes a Resource or NotResource element. The element may be
// a single string or an array of strings. JSON null leaves a nil Resource, and
// [] yields an empty non-nil one.
//
// When neither form matches, the error from the array attempt is returned.
// For array input it pinpoints the bad element (e.g. a number). The
// single-string attempt would only report "cannot unmarshal array into string".
func (r *Resource) UnmarshalJSON(data []byte) error {
	var list []string
	listErr := json.Unmarshal(data, &list)
	if listErr == nil {
		*r = list
		return nil
	}

	var elem string
	if err := json.Unmarshal(data, &elem); err != nil {
		return listErr
	}

	*r = Resource{elem}
	return nil
}

func (c *Condition) UnmarshalJSON(data []byte) error {
	*c = make(Condition)

	m := make(map[string]any)
	if err := json.Unmarshal(data, &m); err != nil {
		return err
	}

	for key, val := range m {
		element, ok := val.(string)
		if ok {
			(*c)[key] = []string{element}
			continue
		}

		list, ok := val.([]any)
		if !ok {
			return errors.New("invalid principal format")
		}

		resList := make([]string, len(list))
		for i := range list {
			val, ok := list[i].(string)
			if !ok {
				return errors.New("invalid principal format")
			}
			resList[i] = val
		}

		(*c)[key] = resList
	}

	return nil
}

// Validate checks p against the structural rules shared by every policy.
// For IdentityBasedPolicyType or ResourceBasedPolicyType it also applies that
// kind's rules. GeneralPolicyType, or any other value, runs only the shared
// checks. The first violation found is returned.
func (p Policy) Validate(typ PolicyType) error {
	if err := p.validate(); err != nil {
		return err
	}

	if typ == IdentityBasedPolicyType {
		return p.validateIdentityBased()
	}
	if typ == ResourceBasedPolicyType {
		return p.validateResourceBased()
	}

	// No kind-specific rules apply.
	return nil
}

// validate applies the checks shared by every policy type. It requires a
// non-empty Statement list. For each statement it also requires:
//   - a known Effect;
//   - mutually exclusive Action/NotAction, Resource/NotResource and
//     Principal/NotPrincipal;
//   - at least one of Resource/NotResource;
//   - NotPrincipal only with Deny;
//   - known principal types.
//
// The first violation found is returned.
//
// NOTE(review): Action/NotAction presence is not required here, unlike
// Resource/NotResource. Confirm that is intentional.
func (p Policy) validate() error {
	if len(p.Statement) == 0 {
		return errors.New("'Statement' is missing")
	}

	for _, statement := range p.Statement {
		if !statement.Effect.IsValid() {
			return fmt.Errorf("unknown effect: '%s'", statement.Effect)
		}
		if len(statement.Action) != 0 && len(statement.NotAction) != 0 {
			return errors.New("'Action' and 'NotAction' are mutually exclusive")
		}
		// NOTE(review): unlike the neighbouring len-based checks, this tests for
		// presence (non-nil). An empty but non-nil Resource next to a
		// NotResource is therefore rejected as well. Confirm the asymmetry is
		// intended.
		if statement.Resource != nil && statement.NotResource != nil {
			return errors.New("'Resource' and 'NotResource' are mutually exclusive")
		}
		if len(statement.Resource) == 0 && len(statement.NotResource) == 0 {
			return errors.New("one of 'Resource'/'NotResource' must be provided")
		}
		if len(statement.Principal) != 0 && len(statement.NotPrincipal) != 0 {
			return errors.New("'Principal' and 'NotPrincipal' are mutually exclusive")
		}
		// Effect was validated above, so any non-Deny effect here is Allow.
		if len(statement.NotPrincipal) != 0 && statement.Effect != DenyEffect {
			return errors.New("using 'NotPrincipal' with effect 'Allow' is not supported")
		}

		// At most one of Principal/NotPrincipal is set (checked above). Validate
		// whichever one principal() selects; the Not- flag is irrelevant here.
		principal, _ := statement.principal()
		if err := principal.validate(); err != nil {
			return err
		}
	}

	return nil
}

// validateIdentityBased enforces the identity-based policy rules. The policy
// must have no policy-level Id. No statement may set Principal or NotPrincipal.
func (p Policy) validateIdentityBased() error {
	if p.ID != "" {
		return errors.New("'Id' is not allowed for identity-based policy")
	}

	for i := range p.Statement {
		st := &p.Statement[i]
		if len(st.Principal) > 0 || len(st.NotPrincipal) > 0 {
			return errors.New("'Principal' and 'NotPrincipal' are not allowed for identity-based policy")
		}
	}

	return nil
}

// validateResourceBased enforces the resource-based policy rule. Every
// statement must set a non-empty Principal or NotPrincipal.
func (p Policy) validateResourceBased() error {
	for i := range p.Statement {
		st := &p.Statement[i]
		hasPrincipal := len(st.Principal) > 0 || len(st.NotPrincipal) > 0
		if !hasPrincipal {
			return errors.New("'Principal' or 'NotPrincipal' must be provided for resource-based policy")
		}
	}

	return nil
}

// principal returns the statement's principal element and a flag that is
// true when the element came from NotPrincipal. NotPrincipal wins whenever it
// is non-empty. Otherwise Principal, possibly empty, is returned.
func (s Statement) principal() (Principal, bool) {
	if len(s.NotPrincipal) == 0 {
		return s.Principal, false
	}

	return s.NotPrincipal, true
}

// action returns the statement's action element and a flag that is true when
// the element came from NotAction. NotAction wins whenever it is non-empty.
// Otherwise Action, possibly empty, is returned.
func (s Statement) action() (Action, bool) {
	if len(s.NotAction) == 0 {
		return s.Action, false
	}

	return s.NotAction, true
}

// resource returns the statement's resource element and a flag that is true
// when the element came from NotResource. NotResource wins whenever it is
// non-empty. Otherwise Resource, possibly empty, is returned.
func (s Statement) resource() (Resource, bool) {
	if len(s.NotResource) == 0 {
		return s.Resource, false
	}

	return s.NotResource, true
}

// validate reports an error when p contains a key that is not a known
// PrincipalType. A Principal whose only key is Wildcard is accepted. That is
// the decoded form of the bare "*" principal. An empty Principal is also
// valid. Wildcard mixed with other keys is rejected, because "*" is not a
// PrincipalType.
func (p Principal) validate() error {
	if len(p) == 1 {
		if _, onlyWildcard := p[Wildcard]; onlyWildcard {
			return nil
		}
	}

	for typ := range p {
		if typ.IsValid() {
			continue
		}
		return fmt.Errorf("unknown principal type: '%s'", typ)
	}

	return nil
}