package signature

import (
	"errors"
	"fmt"

	"git.frostfs.info/TrueCloudLab/frostfs-api-go/v2/refs"
	"git.frostfs.info/TrueCloudLab/frostfs-api-go/v2/session"
	"git.frostfs.info/TrueCloudLab/frostfs-api-go/v2/util/signature"
	"golang.org/x/sync/errgroup"
)

// signatureProvider gives read access to the three signatures that
// verifyMessageParts checks for one level of a service message. In this file
// it is satisfied by the request and response verification headers that are
// passed to verifyMessageParts.
type signatureProvider interface {
	// GetBodySignature returns the signature checked against the message body.
	// verifyMessageParts treats a nil result as "no body signature".
	GetBodySignature() *refs.Signature
	// GetMetaSignature returns the signature checked against the meta header.
	GetMetaSignature() *refs.Signature
	// GetOriginSignature returns the signature checked against the origin
	// verification header.
	GetOriginSignature() *refs.Signature
}

// VerifyServiceMessage verifies the signatures of a service message.
//
// A nil msg is accepted and yields a nil error. A value implementing
// serviceRequest is verified as a request; otherwise a value implementing
// serviceResponse is verified as a response. Any other value causes a panic
// that names the dynamic type of msg.
func VerifyServiceMessage(msg any) error {
	if msg == nil {
		return nil
	}

	// The request check comes first, so it wins if a value satisfies both
	// interfaces.
	if req, ok := msg.(serviceRequest); ok {
		return verifyServiceRequest(req)
	}

	if resp, ok := msg.(serviceResponse); ok {
		return verifyServiceResponse(resp)
	}

	panic(fmt.Sprintf("unsupported session message %T", msg))
}

// verifyServiceRequest verifies request v, starting from its own meta header
// and verification header together with the body returned by
// serviceMessageBody.
func verifyServiceRequest(v serviceRequest) error {
	var (
		metaHdr   = v.GetMetaHeader()
		verifyHdr = v.GetVerificationHeader()
	)

	return verifyServiceRequestRecursive(serviceMessageBody(v), metaHdr, verifyHdr)
}

// verifyServiceRequestRecursive verifies a request level by level, starting
// from the given meta and verification headers and then following their
// origins.
//
// Every level is checked with verifyMessageParts. The walk ends with that
// function's error, or with nil as soon as it reports that no further level
// has to be visited. The chain is traversed with a loop rather than by a
// self-call.
func verifyServiceRequestRecursive(body stableMarshaler, meta *session.RequestMetaHeader, verify *session.RequestVerificationHeader) error {
	for {
		// Resolve both origins before the current level is checked.
		nextVerify := verify.GetOrigin()
		nextMeta := meta.GetOrigin()

		done, err := verifyMessageParts(body, meta, nextVerify, nextVerify != nil, verify)
		switch {
		case err != nil:
			return err
		case done:
			return nil
		}

		// Descend one level: the origins become the current headers.
		meta, verify = nextMeta, nextVerify
	}
}

// verifyMessageParts checks the signatures of a single level of a service
// message.
//
// The meta header is always verified against sigProvider's meta signature and
// originHeader against its origin signature. The body is verified against the
// body signature only when hasOriginHeader is false, i.e. at the level that
// has no origin. The checks run concurrently in an errgroup; if any of them
// fails, the error returned by the group is propagated, wrapped with the name
// of the part that could not be verified.
//
// hasOriginHeader is passed explicitly because the callers in this file hand
// over a typed pointer as originHeader, so the interface value is non-nil even
// when the pointer itself is nil.
//
// Results:
//   - (true, nil): all checks passed and there is no origin header, so the
//     caller has nothing further to descend into;
//   - (false, nil): all checks passed and an origin header exists;
//   - (false, err): a signature check failed, or a body signature is set at a
//     level that has an origin header.
func verifyMessageParts(body, meta, originHeader stableMarshaler, hasOriginHeader bool, sigProvider signatureProvider) (stop bool, err error) {
	eg := &errgroup.Group{}

	eg.Go(func() error {
		if err := verifyServiceMessagePart(meta, sigProvider.GetMetaSignature); err != nil {
			return fmt.Errorf("could not verify meta header: %w", err)
		}
		return nil
	})

	eg.Go(func() error {
		if err := verifyServiceMessagePart(originHeader, sigProvider.GetOriginSignature); err != nil {
			return fmt.Errorf("could not verify origin of verification header: %w", err)
		}
		return nil
	})

	if !hasOriginHeader {
		eg.Go(func() error {
			if err := verifyServiceMessagePart(body, sigProvider.GetBodySignature); err != nil {
				return fmt.Errorf("could not verify body: %w", err)
			}
			return nil
		})
	}

	if err := eg.Wait(); err != nil {
		return false, err
	}

	if !hasOriginHeader {
		return true, nil
	}

	// The body is verified only at the level without an origin (see above), so
	// a body signature found here would never be checked; reject it. The
	// message states that the signature is present, which is what the
	// condition actually detects.
	if sigProvider.GetBodySignature() != nil {
		return false, errors.New("unexpected body signature at the matryoshka upper level")
	}

	return false, nil
}

// verifyServiceResponse verifies response v, starting from its own meta
// header and verification header together with the body returned by
// serviceMessageBody.
func verifyServiceResponse(v serviceResponse) error {
	var (
		metaHdr   = v.GetMetaHeader()
		verifyHdr = v.GetVerificationHeader()
	)

	return verifyServiceResponseRecursive(serviceMessageBody(v), metaHdr, verifyHdr)
}

// verifyServiceResponseRecursive verifies a response level by level, starting
// from the given meta and verification headers and then following their
// origins.
//
// Every level is checked with verifyMessageParts. The walk ends with that
// function's error, or with nil as soon as it reports that no further level
// has to be visited. The chain is traversed with a loop rather than by a
// self-call.
func verifyServiceResponseRecursive(body stableMarshaler, meta *session.ResponseMetaHeader, verify *session.ResponseVerificationHeader) error {
	for {
		// Resolve both origins before the current level is checked.
		nextVerify := verify.GetOrigin()
		nextMeta := meta.GetOrigin()

		done, err := verifyMessageParts(body, meta, nextVerify, nextVerify != nil, verify)
		switch {
		case err != nil:
			return err
		case done:
			return nil
		}

		// Descend one level: the origins become the current headers.
		meta, verify = nextMeta, nextVerify
	}
}

// verifyServiceMessagePart verifies one message part against the signature
// supplied by sigRdr. The part is wrapped in a StableMarshalerWrapper and the
// check itself is delegated to signature.VerifyDataWithSource, whose error is
// returned unchanged.
func verifyServiceMessagePart(part stableMarshaler, sigRdr func() *refs.Signature) error {
	return signature.VerifyDataWithSource(StableMarshalerWrapper{SM: part}, sigRdr)
}