package ape

import (
	"context"
	"encoding/hex"
	"errors"
	"fmt"

	objectV2 "git.frostfs.info/TrueCloudLab/frostfs-api-go/v2/object"
	"git.frostfs.info/TrueCloudLab/frostfs-api-go/v2/refs"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/local_object_storage/engine"
	objectSvc "git.frostfs.info/TrueCloudLab/frostfs-node/pkg/services/object"
	getsvc "git.frostfs.info/TrueCloudLab/frostfs-node/pkg/services/object/get"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/services/object/util"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/util/logger"
	"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/bearer"
	cid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/container/id"
	objectSDK "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object"
	oid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object/id"
	"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/user"
	nativeschema "git.frostfs.info/TrueCloudLab/policy-engine/schema/native"
)

// errFailedToCastToRequestContext is returned by requestContext when the value
// stored under objectSvc.RequestContextKey is not a *objectSvc.RequestContext.
var (
	errFailedToCastToRequestContext = errors.New("failed cast to RequestContext")
)

// Service is an objectSvc.ServiceServer that runs an APE check (through
// apeChecker) for each object request in addition to delegating the request
// to next.
type Service struct {
	log        *logger.Logger
	apeChecker Checker
	next       objectSvc.ServiceServer
}

// Compile-time check that *Service implements objectSvc.ServiceServer.
var _ objectSvc.ServiceServer = (*Service)(nil)

// HeaderProvider fetches the header of an object identified by container and
// object ID.
type HeaderProvider interface {
	// GetHeader returns the object whose header is stored at (cnr, objID).
	// The meaning of local is implementation-defined: the storage-engine
	// provider in this package reads the local storage engine when it is set
	// and goes through the get service otherwise.
	GetHeader(ctx context.Context, cnr cid.ID, objID oid.ID, local bool) (*objectSDK.Object, error)
}

// storageEngineHeaderProvider implements HeaderProvider on top of the local
// storage engine and the object get service.
type storageEngineHeaderProvider struct {
	// storageEngine is queried when GetHeader is called with local set.
	storageEngine *engine.StorageEngine

	// getSvc is queried when GetHeader is called with local unset.
	getSvc *getsvc.Service
}

// GetHeader returns the header of object objID in container cnr.
//
// With local set, the header is read via engine.Head from the storage engine.
// Otherwise a HEAD is issued through the get service with default common
// parameters and the object captured by its header writer is returned.
func (p storageEngineHeaderProvider) GetHeader(ctx context.Context, cnr cid.ID, objID oid.ID, local bool) (*objectSDK.Object, error) {
	var target oid.Address
	target.SetContainer(cnr)
	target.SetObject(objID)

	if !local {
		hdrWriter := getsvc.NewSimpleObjectWriter()

		var prm getsvc.HeadPrm
		prm.WithAddress(target)
		prm.SetHeaderWriter(hdrWriter)
		// Zero-valued common parameters are sufficient for this lookup.
		prm.SetCommonParameters(&util.CommonPrm{})

		if err := p.getSvc.Head(ctx, prm); err != nil {
			return nil, err
		}
		return hdrWriter.Object(), nil
	}

	return engine.Head(ctx, p.storageEngine, target)
}

// NewStorageEngineHeaderProvider returns a HeaderProvider that reads headers
// from e for local lookups and through the get service s otherwise.
func NewStorageEngineHeaderProvider(e *engine.StorageEngine, s *getsvc.Service) HeaderProvider {
	var provider storageEngineHeaderProvider
	provider.storageEngine = e
	provider.getSvc = s
	return provider
}

// NewService returns a Service that runs APE checks with apeChecker and
// delegates object requests to next. log is stored on the service as-is.
func NewService(log *logger.Logger, apeChecker Checker, next objectSvc.ServiceServer) *Service {
	svc := new(Service)
	svc.log = log
	svc.apeChecker = apeChecker
	svc.next = next
	return svc
}

// getStreamBasicChecker wraps a GetObjectStream and runs the APE check for
// GET when the init part (the one carrying the object header) is sent.
type getStreamBasicChecker struct {
	objectSvc.GetObjectStream

	apeChecker Checker

	namespace string

	senderKey []byte

	containerOwner user.ID

	role string

	softAPECheck bool

	bearerToken *bearer.Token

	// setRequestXHeaders stores the X-headers of the original GET request in
	// the check parameters. X-headers are request properties (every other
	// handler in this service takes them from the request), so they must not
	// be read from the response being streamed back. May be nil, in which
	// case no X-headers are passed to the check.
	setRequestXHeaders func(*Prm)
}

// Send runs the APE check for the init part of a GET response and then passes
// resp to the wrapped stream. Non-init parts are passed on without a check.
// Check and decoding failures are converted with toStatusErr.
func (g *getStreamBasicChecker) Send(resp *objectV2.GetResponse) error {
	if partInit, ok := resp.GetBody().GetObjectPart().(*objectV2.GetObjectPartInit); ok {
		cnrID, objID, err := getAddressParamsSDK(partInit.GetHeader().GetContainerID(), partInit.GetObjectID())
		if err != nil {
			return toStatusErr(err)
		}

		prm := Prm{
			Namespace:      g.namespace,
			Container:      cnrID,
			Object:         objID,
			Header:         partInit.GetHeader(),
			Method:         nativeschema.MethodGetObject,
			SenderKey:      hex.EncodeToString(g.senderKey),
			ContainerOwner: g.containerOwner,
			Role:           g.role,
			SoftAPECheck:   g.softAPECheck,
			BearerToken:    g.bearerToken,
		}
		if g.setRequestXHeaders != nil {
			g.setRequestXHeaders(&prm)
		}

		if err := g.apeChecker.CheckAPE(g.Context(), prm); err != nil {
			return toStatusErr(err)
		}
	}
	return g.GetObjectStream.Send(resp)
}

// requestContext returns the *objectSvc.RequestContext stored in ctx under
// objectSvc.RequestContextKey. It fails if the key is absent or holds a value
// of a different type.
func requestContext(ctx context.Context) (*objectSvc.RequestContext, error) {
	untyped := ctx.Value(objectSvc.RequestContextKey)
	if untyped == nil {
		return nil, fmt.Errorf("no key %s in context", objectSvc.RequestContextKey)
	}
	rc, ok := untyped.(*objectSvc.RequestContext)
	if !ok {
		return nil, errFailedToCastToRequestContext
	}
	return rc, nil
}

// Get delegates the GET request to the next service through a stream wrapper
// that runs the APE check once the object header is available.
func (c *Service) Get(request *objectV2.GetRequest, stream objectSvc.GetObjectStream) error {
	reqCtx, err := requestContext(stream.Context())
	if err != nil {
		return toStatusErr(err)
	}

	// Capture the X-headers before delegating; the check itself runs later,
	// while the response is streamed. NOTE(review): downstream services
	// (e.g. request forwarding) may replace the request's meta header in the
	// meantime — capturing here keeps the client's original X-headers.
	xHeaders := request.GetMetaHeader().GetXHeaders()

	return c.next.Get(request, &getStreamBasicChecker{
		GetObjectStream:    stream,
		apeChecker:         c.apeChecker,
		namespace:          reqCtx.Namespace,
		senderKey:          reqCtx.SenderKey,
		containerOwner:     reqCtx.ContainerOwner,
		role:               nativeSchemaRole(reqCtx.Role),
		softAPECheck:       reqCtx.SoftAPECheck,
		bearerToken:        reqCtx.BearerToken,
		setRequestXHeaders: func(prm *Prm) { prm.XHeaders = xHeaders },
	})
}

// putStreamBasicChecker wraps a PutObjectStream and runs the APE check on the
// init part of an upload before passing it on.
type putStreamBasicChecker struct {
	apeChecker Checker
	next       objectSvc.PutObjectStream
}

// Send runs the PUT APE check when request carries the init part (the object
// header) and then passes request to the next stream. Chunk parts are passed
// on without a check. Failures are converted with toStatusErr.
func (p *putStreamBasicChecker) Send(ctx context.Context, request *objectV2.PutRequest) error {
	hdrPart, isInit := request.GetBody().GetObjectPart().(*objectV2.PutObjectPartInit)
	if !isInit {
		return p.next.Send(ctx, request)
	}

	reqCtx, err := requestContext(ctx)
	if err != nil {
		return toStatusErr(err)
	}

	hdr := hdrPart.GetHeader()
	cnrID, objID, err := getAddressParamsSDK(hdr.GetContainerID(), hdrPart.GetObjectID())
	if err != nil {
		return toStatusErr(err)
	}

	checkPrm := Prm{
		Namespace:      reqCtx.Namespace,
		Container:      cnrID,
		Object:         objID,
		Header:         hdr,
		Method:         nativeschema.MethodPutObject,
		SenderKey:      hex.EncodeToString(reqCtx.SenderKey),
		ContainerOwner: reqCtx.ContainerOwner,
		Role:           nativeSchemaRole(reqCtx.Role),
		SoftAPECheck:   reqCtx.SoftAPECheck,
		BearerToken:    reqCtx.BearerToken,
		XHeaders:       request.GetMetaHeader().GetXHeaders(),
	}
	if err = p.apeChecker.CheckAPE(ctx, checkPrm); err != nil {
		return toStatusErr(err)
	}

	return p.next.Send(ctx, request)
}

// CloseAndRecv closes the wrapped upload stream and returns its response.
// It uses a pointer receiver, like Send, so the checker is never copied.
func (p *putStreamBasicChecker) CloseAndRecv(ctx context.Context) (*objectV2.PutResponse, error) {
	return p.next.CloseAndRecv(ctx)
}

// Put opens an upload stream on the next service and wraps it so the APE
// check runs on the init part. If the next service fails to open a stream,
// its error is returned together with a nil stream rather than a wrapper
// around a nil stream.
func (c *Service) Put() (objectSvc.PutObjectStream, error) {
	streamer, err := c.next.Put()
	if err != nil {
		return nil, err
	}

	return &putStreamBasicChecker{
		apeChecker: c.apeChecker,
		next:       streamer,
	}, nil
}

// patchStreamBasicChecker wraps a PatchObjectStream and runs the APE check for
// PATCH once per stream.
type patchStreamBasicChecker struct {
	apeChecker Checker
	next       objectSvc.PatchObjectStream

	// nonFirstSend is set by Send to skip the APE check on subsequent
	// requests of the stream.
	nonFirstSend bool
}

// Send runs the PATCH APE check for the first request of the stream, using the
// object address from the request body, and passes every request to the next
// stream.
//
// The stream is marked as checked only after the check succeeds. If the check
// (or request decoding) fails, the next Send on this stream checks again
// instead of being let through unchecked.
func (p *patchStreamBasicChecker) Send(ctx context.Context, request *objectV2.PatchRequest) error {
	if !p.nonFirstSend {
		reqCtx, err := requestContext(ctx)
		if err != nil {
			return toStatusErr(err)
		}

		cnrID, objID, err := getAddressParamsSDK(request.GetBody().GetAddress().GetContainerID(), request.GetBody().GetAddress().GetObjectID())
		if err != nil {
			return toStatusErr(err)
		}

		prm := Prm{
			Namespace:      reqCtx.Namespace,
			Container:      cnrID,
			Object:         objID,
			Method:         nativeschema.MethodPatchObject,
			SenderKey:      hex.EncodeToString(reqCtx.SenderKey),
			ContainerOwner: reqCtx.ContainerOwner,
			Role:           nativeSchemaRole(reqCtx.Role),
			SoftAPECheck:   reqCtx.SoftAPECheck,
			BearerToken:    reqCtx.BearerToken,
			XHeaders:       request.GetMetaHeader().GetXHeaders(),
		}

		if err := p.apeChecker.CheckAPE(ctx, prm); err != nil {
			return toStatusErr(err)
		}

		// Only a successful check authorizes the rest of the stream.
		p.nonFirstSend = true
	}

	return p.next.Send(ctx, request)
}

// CloseAndRecv closes the wrapped patch stream and returns its response.
// It uses a pointer receiver, like Send (which mutates the checker), so the
// checker is never copied.
func (p *patchStreamBasicChecker) CloseAndRecv(ctx context.Context) (*objectV2.PatchResponse, error) {
	return p.next.CloseAndRecv(ctx)
}

// Patch opens a patch stream on the next service and wraps it so the APE
// check runs on the first request. If the next service fails to open a
// stream, its error is returned together with a nil stream rather than a
// wrapper around a nil stream.
func (c *Service) Patch() (objectSvc.PatchObjectStream, error) {
	streamer, err := c.next.Patch()
	if err != nil {
		return nil, err
	}

	return &patchStreamBasicChecker{
		apeChecker: c.apeChecker,
		next:       streamer,
	}, nil
}

// Head delegates the request to the next service and then runs the HEAD APE
// check against the header found in the response. The check uses that
// header, which is why it runs after delegation. If the check fails, the
// response is discarded and a status error is returned.
func (c *Service) Head(ctx context.Context, request *objectV2.HeadRequest) (*objectV2.HeadResponse, error) {
	cnrID, objID, err := getAddressParamsSDK(request.GetBody().GetAddress().GetContainerID(), request.GetBody().GetAddress().GetObjectID())
	if err != nil {
		return nil, err
	}

	reqCtx, err := requestContext(ctx)
	if err != nil {
		return nil, err
	}

	resp, err := c.next.Head(ctx, request)
	if err != nil {
		return nil, err
	}

	header := new(objectV2.Header)
	switch headerPart := resp.GetBody().GetHeaderPart().(type) {
	case *objectV2.ShortHeader:
		// Rebuild a full header: the container ID comes from the request
		// address, the remaining fields from the short header.
		cidV2 := new(refs.ContainerID)
		cnrID.WriteToV2(cidV2)
		header.SetContainerID(cidV2)
		header.SetVersion(headerPart.GetVersion())
		header.SetCreationEpoch(headerPart.GetCreationEpoch())
		header.SetOwnerID(headerPart.GetOwnerID())
		header.SetObjectType(headerPart.GetObjectType())
		// Take the hash from the short header; reading it back from the
		// freshly allocated header would always yield an empty value.
		header.SetHomomorphicHash(headerPart.GetHomomorphicHash())
		header.SetPayloadLength(headerPart.GetPayloadLength())
		header.SetPayloadHash(headerPart.GetPayloadHash())
	case *objectV2.HeaderWithSignature:
		header = headerPart.GetHeader()
	default:
		// NOTE(review): responses whose header part is neither form above are
		// returned without an APE check — confirm this is intended.
		return resp, nil
	}

	err = c.apeChecker.CheckAPE(ctx, Prm{
		Namespace:      reqCtx.Namespace,
		Container:      cnrID,
		Object:         objID,
		Header:         header,
		Method:         nativeschema.MethodHeadObject,
		Role:           nativeSchemaRole(reqCtx.Role),
		SenderKey:      hex.EncodeToString(reqCtx.SenderKey),
		ContainerOwner: reqCtx.ContainerOwner,
		SoftAPECheck:   reqCtx.SoftAPECheck,
		BearerToken:    reqCtx.BearerToken,
		XHeaders:       request.GetMetaHeader().GetXHeaders(),
	})
	if err != nil {
		return nil, toStatusErr(err)
	}
	return resp, nil
}

// Search runs the SEARCH APE check for the request's container before passing
// the request to the next service. When the request carries no container ID,
// the check is made with the zero cid.ID. Failures are converted with
// toStatusErr.
func (c *Service) Search(request *objectV2.SearchRequest, stream objectSvc.SearchStream) error {
	var cnrID cid.ID
	if v2ID := request.GetBody().GetContainerID(); v2ID != nil {
		if err := cnrID.ReadFromV2(*v2ID); err != nil {
			return toStatusErr(err)
		}
	}

	ctx := stream.Context()

	reqCtx, err := requestContext(ctx)
	if err != nil {
		return toStatusErr(err)
	}

	checkPrm := Prm{
		Namespace:      reqCtx.Namespace,
		Container:      cnrID,
		Method:         nativeschema.MethodSearchObject,
		Role:           nativeSchemaRole(reqCtx.Role),
		SenderKey:      hex.EncodeToString(reqCtx.SenderKey),
		ContainerOwner: reqCtx.ContainerOwner,
		SoftAPECheck:   reqCtx.SoftAPECheck,
		BearerToken:    reqCtx.BearerToken,
		XHeaders:       request.GetMetaHeader().GetXHeaders(),
	}
	if err = c.apeChecker.CheckAPE(ctx, checkPrm); err != nil {
		return toStatusErr(err)
	}

	return c.next.Search(request, stream)
}

// Delete runs the DELETE APE check for the addressed object and, if it
// passes, delegates the request to the next service. Address and request-context
// errors are returned as-is; check failures are converted with toStatusErr.
func (c *Service) Delete(ctx context.Context, request *objectV2.DeleteRequest) (*objectV2.DeleteResponse, error) {
	addr := request.GetBody().GetAddress()
	cnrID, objID, err := getAddressParamsSDK(addr.GetContainerID(), addr.GetObjectID())
	if err != nil {
		return nil, err
	}

	reqCtx, err := requestContext(ctx)
	if err != nil {
		return nil, err
	}

	checkPrm := Prm{
		Namespace:      reqCtx.Namespace,
		Container:      cnrID,
		Object:         objID,
		Method:         nativeschema.MethodDeleteObject,
		Role:           nativeSchemaRole(reqCtx.Role),
		SenderKey:      hex.EncodeToString(reqCtx.SenderKey),
		ContainerOwner: reqCtx.ContainerOwner,
		SoftAPECheck:   reqCtx.SoftAPECheck,
		BearerToken:    reqCtx.BearerToken,
		XHeaders:       request.GetMetaHeader().GetXHeaders(),
	}
	if err = c.apeChecker.CheckAPE(ctx, checkPrm); err != nil {
		return nil, toStatusErr(err)
	}

	resp, err := c.next.Delete(ctx, request)
	if err != nil {
		return nil, err
	}
	return resp, nil
}

// GetRange runs the RANGE APE check for the addressed object before passing
// the request to the next service. Failures are converted with toStatusErr.
func (c *Service) GetRange(request *objectV2.GetRangeRequest, stream objectSvc.GetObjectRangeStream) error {
	addr := request.GetBody().GetAddress()
	cnrID, objID, err := getAddressParamsSDK(addr.GetContainerID(), addr.GetObjectID())
	if err != nil {
		return toStatusErr(err)
	}

	ctx := stream.Context()

	reqCtx, err := requestContext(ctx)
	if err != nil {
		return toStatusErr(err)
	}

	checkPrm := Prm{
		Namespace:      reqCtx.Namespace,
		Container:      cnrID,
		Object:         objID,
		Method:         nativeschema.MethodRangeObject,
		Role:           nativeSchemaRole(reqCtx.Role),
		SenderKey:      hex.EncodeToString(reqCtx.SenderKey),
		ContainerOwner: reqCtx.ContainerOwner,
		SoftAPECheck:   reqCtx.SoftAPECheck,
		BearerToken:    reqCtx.BearerToken,
		XHeaders:       request.GetMetaHeader().GetXHeaders(),
	}
	if err = c.apeChecker.CheckAPE(ctx, checkPrm); err != nil {
		return toStatusErr(err)
	}

	return c.next.GetRange(request, stream)
}

// GetRangeHash runs the HASH APE check for the addressed object and, if it
// passes, delegates the request to the next service.
//
// The check parameters do not depend on the response, so the check runs
// before delegation. This matches GetRange and Delete, avoids computing hashes
// for unauthorized requests, and keeps downstream errors (such as a missing
// object) from being reported ahead of an access denial.
// Address and request-context errors are returned as-is; check failures are
// converted with toStatusErr.
func (c *Service) GetRangeHash(ctx context.Context, request *objectV2.GetRangeHashRequest) (*objectV2.GetRangeHashResponse, error) {
	cnrID, objID, err := getAddressParamsSDK(request.GetBody().GetAddress().GetContainerID(), request.GetBody().GetAddress().GetObjectID())
	if err != nil {
		return nil, err
	}

	reqCtx, err := requestContext(ctx)
	if err != nil {
		return nil, err
	}

	prm := Prm{
		Namespace:      reqCtx.Namespace,
		Container:      cnrID,
		Object:         objID,
		Method:         nativeschema.MethodHashObject,
		Role:           nativeSchemaRole(reqCtx.Role),
		SenderKey:      hex.EncodeToString(reqCtx.SenderKey),
		ContainerOwner: reqCtx.ContainerOwner,
		SoftAPECheck:   reqCtx.SoftAPECheck,
		BearerToken:    reqCtx.BearerToken,
		XHeaders:       request.GetMetaHeader().GetXHeaders(),
	}

	if err = c.apeChecker.CheckAPE(ctx, prm); err != nil {
		return nil, toStatusErr(err)
	}

	resp, err := c.next.GetRangeHash(ctx, request)
	if err != nil {
		return nil, err
	}
	return resp, nil
}

// PutSingle runs the PUT APE check against the header of the object carried
// in the request and, if it passes, delegates the request to the next service.
// Address and request-context errors are returned as-is; check failures are
// converted with toStatusErr.
func (c *Service) PutSingle(ctx context.Context, request *objectV2.PutSingleRequest) (*objectV2.PutSingleResponse, error) {
	obj := request.GetBody().GetObject()
	hdr := obj.GetHeader()

	cnrID, objID, err := getAddressParamsSDK(hdr.GetContainerID(), obj.GetObjectID())
	if err != nil {
		return nil, err
	}

	reqCtx, err := requestContext(ctx)
	if err != nil {
		return nil, err
	}

	checkPrm := Prm{
		Namespace:      reqCtx.Namespace,
		Container:      cnrID,
		Object:         objID,
		Header:         hdr,
		Method:         nativeschema.MethodPutObject,
		Role:           nativeSchemaRole(reqCtx.Role),
		SenderKey:      hex.EncodeToString(reqCtx.SenderKey),
		ContainerOwner: reqCtx.ContainerOwner,
		SoftAPECheck:   reqCtx.SoftAPECheck,
		BearerToken:    reqCtx.BearerToken,
		XHeaders:       request.GetMetaHeader().GetXHeaders(),
	}
	if err = c.apeChecker.CheckAPE(ctx, checkPrm); err != nil {
		return nil, toStatusErr(err)
	}

	return c.next.PutSingle(ctx, request)
}

// getAddressParamsSDK converts optional V2 container and object IDs into
// their SDK forms.
//
// A nil cidV2 leaves cnrID at its zero value; a nil objV2 yields a nil objID.
// If decoding fails, the results decoded so far are returned together with
// the error.
func getAddressParamsSDK(cidV2 *refs.ContainerID, objV2 *refs.ObjectID) (cnrID cid.ID, objID *oid.ID, err error) {
	if cidV2 != nil {
		if err = cnrID.ReadFromV2(*cidV2); err != nil {
			return cnrID, nil, err
		}
	}

	if objV2 == nil {
		return cnrID, nil, nil
	}

	objID = new(oid.ID)
	err = objID.ReadFromV2(*objV2)
	return cnrID, objID, err
}