package v2

import (
	"context"
	"errors"
	"fmt"
	"strings"

	objectV2 "git.frostfs.info/TrueCloudLab/frostfs-api-go/v2/object"
	"git.frostfs.info/TrueCloudLab/frostfs-api-go/v2/session"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/core/container"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/core/netmap"
	objectCore "git.frostfs.info/TrueCloudLab/frostfs-node/pkg/core/object"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/services/object"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/util/logger"
	apistatus "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/client/status"
	cnrSDK "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/container"
	"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/container/acl"
	cid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/container/id"
	oid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object/id"
	sessionSDK "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/session"
	"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/user"
	"go.uber.org/zap"
)

// Service checks basic ACL rules.
//
// It embeds the shared configuration and keeps the sender classifier that
// findRequestInfo uses to resolve the role and key of a request sender.
type Service struct {
	*cfg

	// c classifies request senders; New builds it from the configured
	// inner ring fetcher, netmap source and logger.
	c objectCore.SenderClassifier
}

// putStreamBasicChecker decorates a PutObjectStream: its Send method runs
// access checks on the initial part of a streamed Put before forwarding
// messages to next.
type putStreamBasicChecker struct {
	// source provides findRequestInfo and the ACL checker used by Send.
	source *Service
	// next is the stream that messages are forwarded to.
	next object.PutObjectStream
}

// getStreamBasicChecker decorates a GetObjectStream: its Send method runs an
// eACL check on the initial part of a Get response before passing the
// response to the embedded stream.
type getStreamBasicChecker struct {
	// checker performs the eACL check on outgoing responses.
	checker ACLChecker

	// GetObjectStream is the underlying stream responses are forwarded to.
	object.GetObjectStream

	// info is the request info that responses are checked against.
	info RequestInfo
}

// rangeStreamBasicChecker decorates a GetObjectRangeStream: its Send method
// runs an eACL check on every GetRange response before passing it to the
// embedded stream.
type rangeStreamBasicChecker struct {
	// checker performs the eACL check on outgoing responses.
	checker ACLChecker

	// GetObjectRangeStream is the underlying stream responses are forwarded to.
	object.GetObjectRangeStream

	// info is the request info that responses are checked against.
	info RequestInfo
}

// searchStreamBasicChecker decorates a SearchStream: its Send method runs an
// eACL check on every Search response before passing it to the embedded
// stream.
type searchStreamBasicChecker struct {
	// checker performs the eACL check on outgoing responses.
	checker ACLChecker

	// SearchStream is the underlying stream responses are forwarded to.
	object.SearchStream

	// info is the request info that responses are checked against.
	info RequestInfo
}

// Option represents Service constructor option.
//
// Options receive the service configuration and are applied by New, in the
// order given, after the required dependencies have been set.
type Option func(*cfg)

// cfg holds the dependencies and settings of Service.
type cfg struct {
	// log is handed to the sender classifier created in New; New sets it
	// to a wrapper around the global zap logger before options run.
	log *logger.Logger

	// containers is used to fetch the container a request refers to.
	containers container.Source

	// checker evaluates basic ACL, sticky bit and eACL rules.
	checker ACLChecker

	// irFetcher is handed to the sender classifier created in New.
	irFetcher InnerRingFetcher

	// nm provides the current epoch for session token validation and is
	// handed to the sender classifier created in New.
	nm netmap.Source

	// next is the following handler in the ServiceServer pipeline.
	next object.ServiceServer
}

// New is a constructor for object ACL checking service.
//
// The logger starts as a wrapper around the global zap logger. opts are
// applied in the given order once the required dependencies are stored, and
// the sender classifier is then built from the resulting configuration.
func New(next object.ServiceServer,
	nm netmap.Source,
	irf InnerRingFetcher,
	acl ACLChecker,
	cs container.Source,
	opts ...Option,
) Service {
	conf := new(cfg)
	conf.log = &logger.Logger{Logger: zap.L()}
	conf.next = next
	conf.nm = nm
	conf.irFetcher = irf
	conf.checker = acl
	conf.containers = cs

	for _, apply := range opts {
		apply(conf)
	}

	svc := Service{cfg: conf}
	// The classifier is built last so that it sees values overridden by options.
	svc.c = objectCore.NewSenderClassifier(conf.irFetcher, conf.nm, conf.log)

	return svc
}

// wrappedGetObjectStream propagates RequestContext into GetObjectStream's context.
// This allows the next handler invocation to retrieve already calculated,
// immutable request-specific values.
type wrappedGetObjectStream struct {
	object.GetObjectStream

	// requestInfo is the source of the values placed into the RequestContext.
	requestInfo RequestInfo
}

// Context returns the embedded stream's context extended with a
// RequestContext, stored under object.RequestContextKey, that is filled from
// the wrapped request info. A new value is built on every call.
func (w *wrappedGetObjectStream) Context() context.Context {
	parent := w.GetObjectStream.Context()

	reqCtx := new(object.RequestContext)
	reqCtx.Namespace = w.requestInfo.ContainerNamespace()
	reqCtx.ContainerOwner = w.requestInfo.ContainerOwner()
	reqCtx.SenderKey = w.requestInfo.SenderKey()
	reqCtx.Role = w.requestInfo.RequestRole()
	reqCtx.SoftAPECheck = w.requestInfo.IsSoftAPECheck()
	reqCtx.BearerToken = w.requestInfo.Bearer()

	return context.WithValue(parent, object.RequestContextKey, reqCtx)
}

// newWrappedGetObjectStreamStream wraps getObjectStream so that its Context
// carries a RequestContext derived from reqInfo.
func newWrappedGetObjectStreamStream(getObjectStream object.GetObjectStream, reqInfo RequestInfo) object.GetObjectStream {
	wrapped := new(wrappedGetObjectStream)
	wrapped.GetObjectStream = getObjectStream
	wrapped.requestInfo = reqInfo

	return wrapped
}

// wrappedRangeStream propagates RequestContext into GetObjectRangeStream's context.
// This allows the next handler invocation to retrieve already calculated,
// immutable request-specific values.
type wrappedRangeStream struct {
	object.GetObjectRangeStream

	// requestInfo is the source of the values placed into the RequestContext.
	requestInfo RequestInfo
}

// Context returns the embedded stream's context extended with a
// RequestContext, stored under object.RequestContextKey, that is filled from
// the wrapped request info. A new value is built on every call.
func (w *wrappedRangeStream) Context() context.Context {
	parent := w.GetObjectRangeStream.Context()

	reqCtx := new(object.RequestContext)
	reqCtx.Namespace = w.requestInfo.ContainerNamespace()
	reqCtx.ContainerOwner = w.requestInfo.ContainerOwner()
	reqCtx.SenderKey = w.requestInfo.SenderKey()
	reqCtx.Role = w.requestInfo.RequestRole()
	reqCtx.SoftAPECheck = w.requestInfo.IsSoftAPECheck()
	reqCtx.BearerToken = w.requestInfo.Bearer()

	return context.WithValue(parent, object.RequestContextKey, reqCtx)
}

// newWrappedRangeStream wraps rangeStream so that its Context carries a
// RequestContext derived from reqInfo.
func newWrappedRangeStream(rangeStream object.GetObjectRangeStream, reqInfo RequestInfo) object.GetObjectRangeStream {
	wrapped := new(wrappedRangeStream)
	wrapped.GetObjectRangeStream = rangeStream
	wrapped.requestInfo = reqInfo

	return wrapped
}

// wrappedSearchStream propagates RequestContext into SearchStream's context.
// This allows the next handler invocation to retrieve already calculated,
// immutable request-specific values.
type wrappedSearchStream struct {
	object.SearchStream

	// requestInfo is the source of the values placed into the RequestContext.
	requestInfo RequestInfo
}

// Context returns the embedded stream's context extended with a
// RequestContext, stored under object.RequestContextKey, that is filled from
// the wrapped request info. A new value is built on every call.
func (w *wrappedSearchStream) Context() context.Context {
	parent := w.SearchStream.Context()

	reqCtx := new(object.RequestContext)
	reqCtx.Namespace = w.requestInfo.ContainerNamespace()
	reqCtx.ContainerOwner = w.requestInfo.ContainerOwner()
	reqCtx.SenderKey = w.requestInfo.SenderKey()
	reqCtx.Role = w.requestInfo.RequestRole()
	reqCtx.SoftAPECheck = w.requestInfo.IsSoftAPECheck()
	reqCtx.BearerToken = w.requestInfo.Bearer()

	return context.WithValue(parent, object.RequestContextKey, reqCtx)
}

// newWrappedSearchStream wraps searchStream so that its Context carries a
// RequestContext derived from reqInfo.
func newWrappedSearchStream(searchStream object.SearchStream, reqInfo RequestInfo) object.SearchStream {
	wrapped := new(wrappedSearchStream)
	wrapped.SearchStream = searchStream
	wrapped.requestInfo = reqInfo

	return wrapped
}

// Get implements ServiceServer interface, makes ACL checks and calls
// next Get method in the ServiceServer pipeline.
//
// The container and object IDs are read from the request, the original
// session token (when present) is verified to relate to them, and the
// request info is resolved for the Get operation. Basic ACL and eACL are
// evaluated only when the request info reports a soft APE check. The stream
// handed to the next handler propagates the request context and checks
// responses before they are sent.
func (b Service) Get(request *objectV2.GetRequest, stream object.GetObjectStream) error {
	cnrID, err := getContainerIDFromRequest(request)
	if err != nil {
		return err
	}

	objID, err := getObjectIDFromRequestBody(request.GetBody())
	if err != nil {
		return err
	}

	sessionTok, err := originalSessionToken(request.GetMetaHeader())
	if err != nil {
		return err
	}

	if sessionTok != nil {
		if err := assertSessionRelation(*sessionTok, cnrID, objID); err != nil {
			return err
		}
	}

	bearerTok, err := originalBearerToken(request.GetMetaHeader())
	if err != nil {
		return err
	}

	info, err := b.findRequestInfo(MetaWithToken{
		vheader: request.GetVerificationHeader(),
		token:   sessionTok,
		bearer:  bearerTok,
		src:     request,
	}, cnrID, acl.OpObjectGet)
	if err != nil {
		return err
	}

	info.obj = objID

	if info.IsSoftAPECheck() {
		if !b.checker.CheckBasicACL(info) {
			return basicACLErr(info)
		}

		if err := b.checker.CheckEACL(request, info); err != nil {
			return eACLErr(info, err)
		}
	}

	checked := &getStreamBasicChecker{
		GetObjectStream: newWrappedGetObjectStreamStream(stream, info),
		info:            info,
		checker:         b.checker,
	}

	return b.next.Get(request, checked)
}

// Put opens a Put stream on the next handler and wraps it so that the
// initial part of the streamed object passes access checks before it is
// forwarded.
//
// If the next handler fails to open the stream, the error is returned with a
// nil stream, so callers never receive a wrapper around a missing stream.
func (b Service) Put() (object.PutObjectStream, error) {
	streamer, err := b.next.Put()
	if err != nil {
		return nil, err
	}

	return putStreamBasicChecker{
		// b is a copy of the receiver; the wrapper keeps a pointer to that copy.
		source: &b,
		next:   streamer,
	}, nil
}

// Head makes ACL checks and calls next Head method in the ServiceServer
// pipeline.
//
// The container and object IDs are read from the request, the original
// session token (when present) is verified to relate to them, and the
// request info is resolved for the Head operation. Basic ACL and eACL are
// evaluated on the request only when the request info reports a soft APE
// check. The response of the next handler is always checked against eACL;
// if that check fails, the error is returned without the response, so the
// header is not handed back to the caller together with the denial.
func (b Service) Head(
	ctx context.Context,
	request *objectV2.HeadRequest,
) (*objectV2.HeadResponse, error) {
	cnr, err := getContainerIDFromRequest(request)
	if err != nil {
		return nil, err
	}

	obj, err := getObjectIDFromRequestBody(request.GetBody())
	if err != nil {
		return nil, err
	}

	sTok, err := originalSessionToken(request.GetMetaHeader())
	if err != nil {
		return nil, err
	}

	if sTok != nil {
		err = assertSessionRelation(*sTok, cnr, obj)
		if err != nil {
			return nil, err
		}
	}

	bTok, err := originalBearerToken(request.GetMetaHeader())
	if err != nil {
		return nil, err
	}

	req := MetaWithToken{
		vheader: request.GetVerificationHeader(),
		token:   sTok,
		bearer:  bTok,
		src:     request,
	}

	reqInfo, err := b.findRequestInfo(req, cnr, acl.OpObjectHead)
	if err != nil {
		return nil, err
	}

	reqInfo.obj = obj

	if reqInfo.IsSoftAPECheck() {
		if !b.checker.CheckBasicACL(reqInfo) {
			return nil, basicACLErr(reqInfo)
		}

		if err := b.checker.CheckEACL(request, reqInfo); err != nil {
			return nil, eACLErr(reqInfo, err)
		}
	}

	resp, err := b.next.Head(requestContext(ctx, reqInfo), request)
	if err != nil {
		// Pass through whatever the next handler returned.
		return resp, err
	}

	// The response carries the object header, so it is checked as well.
	if err := b.checker.CheckEACL(resp, reqInfo); err != nil {
		return nil, eACLErr(reqInfo, err)
	}

	return resp, nil
}

// Search makes ACL checks and calls next Search method in the ServiceServer
// pipeline.
//
// Only the container ID is read from the request; the original session
// token (when present) is verified to relate to that container without an
// object. Basic ACL and eACL are evaluated only when the request info
// reports a soft APE check. The stream handed to the next handler propagates
// the request context and checks responses before they are sent.
func (b Service) Search(request *objectV2.SearchRequest, stream object.SearchStream) error {
	cnrID, err := getContainerIDFromRequest(request)
	if err != nil {
		return err
	}

	sessionTok, err := originalSessionToken(request.GetMetaHeader())
	if err != nil {
		return err
	}

	if sessionTok != nil {
		if err := assertSessionRelation(*sessionTok, cnrID, nil); err != nil {
			return err
		}
	}

	bearerTok, err := originalBearerToken(request.GetMetaHeader())
	if err != nil {
		return err
	}

	info, err := b.findRequestInfo(MetaWithToken{
		vheader: request.GetVerificationHeader(),
		token:   sessionTok,
		bearer:  bearerTok,
		src:     request,
	}, cnrID, acl.OpObjectSearch)
	if err != nil {
		return err
	}

	if info.IsSoftAPECheck() {
		if !b.checker.CheckBasicACL(info) {
			return basicACLErr(info)
		}

		if err := b.checker.CheckEACL(request, info); err != nil {
			return eACLErr(info, err)
		}
	}

	checked := &searchStreamBasicChecker{
		checker:      b.checker,
		SearchStream: newWrappedSearchStream(stream, info),
		info:         info,
	}

	return b.next.Search(request, checked)
}

// Delete makes ACL checks and calls next Delete method in the ServiceServer
// pipeline.
//
// The container and object IDs are read from the request, the original
// session token (when present) is verified to relate to them, and the
// request info is resolved for the Delete operation. Basic ACL and eACL are
// evaluated only when the request info reports a soft APE check. The next
// handler receives a context carrying the request context.
func (b Service) Delete(
	ctx context.Context,
	request *objectV2.DeleteRequest,
) (*objectV2.DeleteResponse, error) {
	cnrID, err := getContainerIDFromRequest(request)
	if err != nil {
		return nil, err
	}

	objID, err := getObjectIDFromRequestBody(request.GetBody())
	if err != nil {
		return nil, err
	}

	sessionTok, err := originalSessionToken(request.GetMetaHeader())
	if err != nil {
		return nil, err
	}

	if sessionTok != nil {
		if err := assertSessionRelation(*sessionTok, cnrID, objID); err != nil {
			return nil, err
		}
	}

	bearerTok, err := originalBearerToken(request.GetMetaHeader())
	if err != nil {
		return nil, err
	}

	info, err := b.findRequestInfo(MetaWithToken{
		vheader: request.GetVerificationHeader(),
		token:   sessionTok,
		bearer:  bearerTok,
		src:     request,
	}, cnrID, acl.OpObjectDelete)
	if err != nil {
		return nil, err
	}

	info.obj = objID

	if info.IsSoftAPECheck() {
		if !b.checker.CheckBasicACL(info) {
			return nil, basicACLErr(info)
		}

		if err := b.checker.CheckEACL(request, info); err != nil {
			return nil, eACLErr(info, err)
		}
	}

	return b.next.Delete(requestContext(ctx, info), request)
}

// GetRange makes ACL checks and calls next GetRange method in the
// ServiceServer pipeline.
//
// The container and object IDs are read from the request, the original
// session token (when present) is verified to relate to them, and the
// request info is resolved for the Range operation. Basic ACL and eACL are
// evaluated only when the request info reports a soft APE check. The stream
// handed to the next handler propagates the request context and checks
// responses before they are sent.
func (b Service) GetRange(request *objectV2.GetRangeRequest, stream object.GetObjectRangeStream) error {
	cnrID, err := getContainerIDFromRequest(request)
	if err != nil {
		return err
	}

	objID, err := getObjectIDFromRequestBody(request.GetBody())
	if err != nil {
		return err
	}

	sessionTok, err := originalSessionToken(request.GetMetaHeader())
	if err != nil {
		return err
	}

	if sessionTok != nil {
		if err := assertSessionRelation(*sessionTok, cnrID, objID); err != nil {
			return err
		}
	}

	bearerTok, err := originalBearerToken(request.GetMetaHeader())
	if err != nil {
		return err
	}

	info, err := b.findRequestInfo(MetaWithToken{
		vheader: request.GetVerificationHeader(),
		token:   sessionTok,
		bearer:  bearerTok,
		src:     request,
	}, cnrID, acl.OpObjectRange)
	if err != nil {
		return err
	}

	info.obj = objID

	if info.IsSoftAPECheck() {
		if !b.checker.CheckBasicACL(info) {
			return basicACLErr(info)
		}

		if err := b.checker.CheckEACL(request, info); err != nil {
			return eACLErr(info, err)
		}
	}

	checked := &rangeStreamBasicChecker{
		checker:              b.checker,
		GetObjectRangeStream: newWrappedRangeStream(stream, info),
		info:                 info,
	}

	return b.next.GetRange(request, checked)
}

// requestContext returns a child of ctx that carries a RequestContext,
// stored under object.RequestContextKey and filled from reqInfo, so that
// following handlers can reuse the already resolved request values.
func requestContext(ctx context.Context, reqInfo RequestInfo) context.Context {
	reqCtx := new(object.RequestContext)
	reqCtx.Namespace = reqInfo.ContainerNamespace()
	reqCtx.ContainerOwner = reqInfo.ContainerOwner()
	reqCtx.SenderKey = reqInfo.SenderKey()
	reqCtx.Role = reqInfo.RequestRole()
	reqCtx.SoftAPECheck = reqInfo.IsSoftAPECheck()
	reqCtx.BearerToken = reqInfo.Bearer()

	return context.WithValue(ctx, object.RequestContextKey, reqCtx)
}

// GetRangeHash makes ACL checks and calls next GetRangeHash method in the
// ServiceServer pipeline.
//
// The container and object IDs are read from the request, the original
// session token (when present) is verified to relate to them, and the
// request info is resolved for the Hash operation. Basic ACL and eACL are
// evaluated only when the request info reports a soft APE check. The next
// handler receives a context carrying the request context.
func (b Service) GetRangeHash(
	ctx context.Context,
	request *objectV2.GetRangeHashRequest,
) (*objectV2.GetRangeHashResponse, error) {
	cnrID, err := getContainerIDFromRequest(request)
	if err != nil {
		return nil, err
	}

	objID, err := getObjectIDFromRequestBody(request.GetBody())
	if err != nil {
		return nil, err
	}

	sessionTok, err := originalSessionToken(request.GetMetaHeader())
	if err != nil {
		return nil, err
	}

	if sessionTok != nil {
		if err := assertSessionRelation(*sessionTok, cnrID, objID); err != nil {
			return nil, err
		}
	}

	bearerTok, err := originalBearerToken(request.GetMetaHeader())
	if err != nil {
		return nil, err
	}

	info, err := b.findRequestInfo(MetaWithToken{
		vheader: request.GetVerificationHeader(),
		token:   sessionTok,
		bearer:  bearerTok,
		src:     request,
	}, cnrID, acl.OpObjectHash)
	if err != nil {
		return nil, err
	}

	info.obj = objID

	if info.IsSoftAPECheck() {
		if !b.checker.CheckBasicACL(info) {
			return nil, basicACLErr(info)
		}

		if err := b.checker.CheckEACL(request, info); err != nil {
			return nil, eACLErr(info, err)
		}
	}

	return b.next.GetRangeHash(requestContext(ctx, info), request)
}

// PutSingle makes ACL checks and calls next PutSingle method in the
// ServiceServer pipeline.
//
// The object owner is required and decoded from the object header. The
// session token attached to the request (when present) is decoded and
// verified by readSessionToken. Basic ACL, the sticky bit and eACL are
// evaluated only when the request info reports a soft APE check. The next
// handler receives a context carrying the request context.
func (b Service) PutSingle(ctx context.Context, request *objectV2.PutSingleRequest) (*objectV2.PutSingleResponse, error) {
	cnrID, err := getContainerIDFromRequest(request)
	if err != nil {
		return nil, err
	}

	ownerV2 := request.GetBody().GetObject().GetHeader().GetOwnerID()
	if ownerV2 == nil {
		return nil, errors.New("missing object owner")
	}

	var owner user.ID
	if err := owner.ReadFromV2(*ownerV2); err != nil {
		return nil, fmt.Errorf("invalid object owner: %w", err)
	}

	objID, err := getObjectIDFromRefObjectID(request.GetBody().GetObject().GetObjectID())
	if err != nil {
		return nil, err
	}

	sessionTok, err := readSessionToken(cnrID, objID, request.GetMetaHeader().GetSessionToken())
	if err != nil {
		return nil, err
	}

	bearerTok, err := originalBearerToken(request.GetMetaHeader())
	if err != nil {
		return nil, err
	}

	info, err := b.findRequestInfo(MetaWithToken{
		vheader: request.GetVerificationHeader(),
		token:   sessionTok,
		bearer:  bearerTok,
		src:     request,
	}, cnrID, acl.OpObjectPut)
	if err != nil {
		return nil, err
	}

	info.obj = objID

	if info.IsSoftAPECheck() {
		// The sticky bit is consulted only after basic ACL has allowed the request.
		allowed := b.checker.CheckBasicACL(info) && b.checker.StickyBitCheck(info, owner)
		if !allowed {
			return nil, basicACLErr(info)
		}

		if err := b.checker.CheckEACL(request, info); err != nil {
			return nil, eACLErr(info, err)
		}
	}

	return b.next.PutSingle(requestContext(ctx, info), request)
}

// Send forwards a Put stream message to the next stream.
//
// Messages whose body is not the initial part are forwarded untouched. For
// the initial part, the object owner is required and decoded, the optional
// object ID is decoded, the session token attached to the request (when
// present) is decoded and verified by readSessionToken, and the request info
// is resolved for the Put operation. Basic ACL and the sticky bit are
// evaluated only when the request info reports a soft APE check. The initial
// part is forwarded with a context carrying the request context.
//
// NOTE(review): unlike PutSingle, no eACL check is made here — confirm this
// is intentional.
func (p putStreamBasicChecker) Send(ctx context.Context, request *objectV2.PutRequest) error {
	body := request.GetBody()
	if body == nil {
		return errEmptyBody
	}

	initPart, isInit := body.GetObjectPart().(*objectV2.PutObjectPartInit)
	if !isInit {
		return p.next.Send(ctx, request)
	}

	cnrID, err := getContainerIDFromRequest(request)
	if err != nil {
		return err
	}

	ownerV2 := initPart.GetHeader().GetOwnerID()
	if ownerV2 == nil {
		return errors.New("missing object owner")
	}

	var owner user.ID
	if err := owner.ReadFromV2(*ownerV2); err != nil {
		return fmt.Errorf("invalid object owner: %w", err)
	}

	// The object ID is optional in the initial part; objID stays nil without it.
	var objID *oid.ID
	if objV2 := initPart.GetObjectID(); objV2 != nil {
		objID = new(oid.ID)

		if err := objID.ReadFromV2(*objV2); err != nil {
			return err
		}
	}

	sessionTok, err := readSessionToken(cnrID, objID, request.GetMetaHeader().GetSessionToken())
	if err != nil {
		return err
	}

	bearerTok, err := originalBearerToken(request.GetMetaHeader())
	if err != nil {
		return err
	}

	info, err := p.source.findRequestInfo(MetaWithToken{
		vheader: request.GetVerificationHeader(),
		token:   sessionTok,
		bearer:  bearerTok,
		src:     request,
	}, cnrID, acl.OpObjectPut)
	if err != nil {
		return err
	}

	info.obj = objID

	if info.IsSoftAPECheck() {
		// The sticky bit is consulted only after basic ACL has allowed the request.
		allowed := p.source.checker.CheckBasicACL(info) && p.source.checker.StickyBitCheck(info, owner)
		if !allowed {
			return basicACLErr(info)
		}
	}

	return p.next.Send(requestContext(ctx, info), request)
}

// readSessionToken decodes tokV2 and verifies that the session relates to
// the container cnr and, unless the session carries the delete verb, to the
// object obj.
//
// It returns nil without an error when tokV2 is nil. A decoding failure is
// wrapped as "invalid session token"; a failed relation check is returned
// as is.
func readSessionToken(cnr cid.ID, obj *oid.ID, tokV2 *session.Token) (*sessionSDK.Object, error) {
	if tokV2 == nil {
		return nil, nil
	}

	tok := new(sessionSDK.Object)
	if err := tok.ReadFromV2(*tokV2); err != nil {
		return nil, fmt.Errorf("invalid session token: %w", err)
	}

	relatedObj := obj
	if tok.AssertVerb(sessionSDK.VerbObjectDelete) {
		// if session relates to object's removal, we don't check
		// relation of the tombstone to the session here since user
		// can't predict tomb's ID.
		relatedObj = nil
	}

	if err := assertSessionRelation(*tok, cnr, relatedObj); err != nil {
		return nil, err
	}

	return tok, nil
}

// CloseAndRecv delegates to the next stream without additional checks and
// returns its response and error unchanged.
func (p putStreamBasicChecker) CloseAndRecv(ctx context.Context) (*objectV2.PutResponse, error) {
	resp, err := p.next.CloseAndRecv(ctx)

	return resp, err
}

// Send passes resp to the embedded stream. When resp carries the initial
// object part, it is checked against eACL first and a failed check is
// returned as an eACL error instead of sending the response.
func (g *getStreamBasicChecker) Send(resp *objectV2.GetResponse) error {
	_, isInit := resp.GetBody().GetObjectPart().(*objectV2.GetObjectPartInit)
	if isInit {
		err := g.checker.CheckEACL(resp, g.info)
		if err != nil {
			return eACLErr(g.info, err)
		}
	}

	return g.GetObjectStream.Send(resp)
}

// Send checks resp against eACL and, if the check passes, hands it to the
// embedded stream. A failed check is returned as an eACL error instead of
// sending the response.
func (g *rangeStreamBasicChecker) Send(resp *objectV2.GetRangeResponse) error {
	checkErr := g.checker.CheckEACL(resp, g.info)
	if checkErr != nil {
		return eACLErr(g.info, checkErr)
	}

	return g.GetObjectRangeStream.Send(resp)
}

// Send checks resp against eACL and, if the check passes, hands it to the
// embedded stream. A failed check is returned as an eACL error instead of
// sending the response.
func (g *searchStreamBasicChecker) Send(resp *objectV2.SearchResponse) error {
	checkErr := g.checker.CheckEACL(resp, g.info)
	if checkErr != nil {
		return eACLErr(g.info, checkErr)
	}

	return g.SearchStream.Send(resp)
}

// findRequestInfo resolves the request info used by the access checks for
// operation op on container idCnr.
//
// It fetches the container, validates the session token of req (when
// present) against the current epoch and against op, determines the request
// owner and classifies it to obtain the role and key of the sender. The
// returned info also carries the container's basic ACL, owner and ID, the
// namespace when the container domain zone ends with ".ns", the bearer
// token and the source request.
//
// On any failure the zero RequestInfo is returned together with the error.
// A failure to fetch the current epoch is wrapped so its cause is preserved.
func (b Service) findRequestInfo(req MetaWithToken, idCnr cid.ID, op acl.Op) (info RequestInfo, err error) {
	cnr, err := b.containers.Get(idCnr) // fetch actual container
	if err != nil {
		return info, err
	}

	if req.token != nil {
		currentEpoch, err := b.nm.Epoch()
		if err != nil {
			return info, fmt.Errorf("can't fetch current epoch: %w", err)
		}
		if req.token.ExpiredAt(currentEpoch) {
			return info, new(apistatus.SessionTokenExpired)
		}
		if req.token.InvalidAt(currentEpoch) {
			return info, fmt.Errorf("%s: token is invalid at %d epoch",
				invalidRequestMessage, currentEpoch)
		}

		if !assertVerb(*req.token, op) {
			return info, errInvalidVerb
		}
	}

	// find request role and key
	ownerID, ownerKey, err := req.RequestOwner()
	if err != nil {
		return info, err
	}
	res, err := b.c.Classify(ownerID, ownerKey, idCnr, cnr.Value)
	if err != nil {
		return info, err
	}

	info.basicACL = cnr.Value.BasicACL()
	info.requestRole = res.Role
	info.operation = op
	info.cnrOwner = cnr.Value.Owner()
	info.idCnr = idCnr

	// the namespace is the domain zone with its ".ns" suffix removed;
	// zones without that suffix leave the namespace empty
	cnrNamespace, hasNamespace := strings.CutSuffix(cnrSDK.ReadDomain(cnr.Value).Zone(), ".ns")
	if hasNamespace {
		info.cnrNamespace = cnrNamespace
	}

	// it is assumed that at the moment the key will be valid,
	// otherwise the request would not pass validation
	info.senderKey = res.Key

	// add bearer token if it is present in request
	info.bearer = req.bearer

	info.srcRequest = req.src

	return info, nil
}