package object

import (
	"context"

	"git.frostfs.info/TrueCloudLab/frostfs-node/internal/assert"
	"git.frostfs.info/TrueCloudLab/frostfs-qos/tagging"
	"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/api/object"
	"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/api/session"
)

// Compile-time check that *qosObjectService implements ServiceServer.
var _ ServiceServer = (*qosObjectService)(nil)

// AdjustIOTag produces the context a request should be handled with, based on
// the incoming context and the public key the request was signed with.
type AdjustIOTag interface {
	// AdjustIncomingTag returns the context to use instead of ctx for a request
	// whose signature carries requestSignPublicKey.
	// NOTE(review): callers in this file expect the returned context to carry
	// an IO tag; the exact adjustment rules live in the implementation.
	AdjustIncomingTag(ctx context.Context, requestSignPublicKey []byte) context.Context
}

// qosObjectService wraps the ServiceServer stored in next and uses adj to
// adjust the context that requests are forwarded with.
type qosObjectService struct {
	next ServiceServer
	adj  AdjustIOTag
}

// NewQoSObjectService returns a ServiceServer that forwards every call to next
// after adjusting the request context with adjIOTag.
func NewQoSObjectService(next ServiceServer, adjIOTag AdjustIOTag) ServiceServer {
	svc := new(qosObjectService)
	svc.next = next
	svc.adj = adjIOTag
	return svc
}

// Delete forwards req to the wrapped service using the context returned by
// AdjustIncomingTag for the key of the request's body signature.
func (q *qosObjectService) Delete(ctx context.Context, req *object.DeleteRequest) (*object.DeleteResponse, error) {
	signerKey := req.GetVerificationHeader().GetBodySignature().GetKey()
	return q.next.Delete(q.adj.AdjustIncomingTag(ctx, signerKey), req)
}

// Get forwards req to the wrapped service. The stream handed over delegates
// Send to s, while its Context returns the stream context after it has been
// adjusted for the key of the request's body signature.
func (q *qosObjectService) Get(req *object.GetRequest, s GetObjectStream) error {
	streamCtx := s.Context()
	signerKey := req.GetVerificationHeader().GetBodySignature().GetKey()
	adjusted := q.adj.AdjustIncomingTag(streamCtx, signerKey)

	wrapped := &qosReadStream[*object.GetResponse]{
		sender: s,
		ctxF:   func() context.Context { return adjusted },
	}
	return q.next.Get(req, wrapped)
}

// GetRange forwards req to the wrapped service. The stream handed over
// delegates Send to s, while its Context returns the stream context after it
// has been adjusted for the key of the request's body signature.
func (q *qosObjectService) GetRange(req *object.GetRangeRequest, s GetObjectRangeStream) error {
	streamCtx := s.Context()
	signerKey := req.GetVerificationHeader().GetBodySignature().GetKey()
	adjusted := q.adj.AdjustIncomingTag(streamCtx, signerKey)

	wrapped := &qosReadStream[*object.GetRangeResponse]{
		sender: s,
		ctxF:   func() context.Context { return adjusted },
	}
	return q.next.GetRange(req, wrapped)
}

// GetRangeHash forwards req to the wrapped service using the context returned
// by AdjustIncomingTag for the key of the request's body signature.
func (q *qosObjectService) GetRangeHash(ctx context.Context, req *object.GetRangeHashRequest) (*object.GetRangeHashResponse, error) {
	signerKey := req.GetVerificationHeader().GetBodySignature().GetKey()
	return q.next.GetRangeHash(q.adj.AdjustIncomingTag(ctx, signerKey), req)
}

// Head forwards req to the wrapped service using the context returned by
// AdjustIncomingTag for the key of the request's body signature.
func (q *qosObjectService) Head(ctx context.Context, req *object.HeadRequest) (*object.HeadResponse, error) {
	signerKey := req.GetVerificationHeader().GetBodySignature().GetKey()
	return q.next.Head(q.adj.AdjustIncomingTag(ctx, signerKey), req)
}

// Patch opens a patch stream on the wrapped service and returns it wrapped in
// a qosWriteStream. The context is not adjusted here: the wrapper does that
// when the first message is sent. An error from the wrapped service is
// returned as is, with a nil stream.
func (q *qosObjectService) Patch(ctx context.Context) (PatchObjectStream, error) {
	inner, err := q.next.Patch(ctx)
	if err != nil {
		return nil, err
	}

	wrapped := new(qosWriteStream[*object.PatchRequest, *object.PatchResponse])
	wrapped.s = inner
	wrapped.adj = q.adj
	return wrapped, nil
}

// Put opens a put stream on the wrapped service and returns it wrapped in a
// qosWriteStream. The context is not adjusted here: the wrapper does that when
// the first message is sent. An error from the wrapped service is returned as
// is, with a nil stream.
func (q *qosObjectService) Put(ctx context.Context) (PutObjectStream, error) {
	inner, err := q.next.Put(ctx)
	if err != nil {
		return nil, err
	}

	wrapped := new(qosWriteStream[*object.PutRequest, *object.PutResponse])
	wrapped.s = inner
	wrapped.adj = q.adj
	return wrapped, nil
}

// PutSingle forwards req to the wrapped service using the context returned by
// AdjustIncomingTag for the key of the request's body signature.
func (q *qosObjectService) PutSingle(ctx context.Context, req *object.PutSingleRequest) (*object.PutSingleResponse, error) {
	signerKey := req.GetVerificationHeader().GetBodySignature().GetKey()
	return q.next.PutSingle(q.adj.AdjustIncomingTag(ctx, signerKey), req)
}

// Search forwards req to the wrapped service. The stream handed over delegates
// Send to s, while its Context returns the stream context after it has been
// adjusted for the key of the request's body signature.
func (q *qosObjectService) Search(req *object.SearchRequest, s SearchStream) error {
	streamCtx := s.Context()
	signerKey := req.GetVerificationHeader().GetBodySignature().GetKey()
	adjusted := q.adj.AdjustIncomingTag(streamCtx, signerKey)

	wrapped := &qosReadStream[*object.SearchResponse]{
		sender: s,
		ctxF:   func() context.Context { return adjusted },
	}
	return q.next.Search(req, wrapped)
}

// qosSend is the part of a response stream that qosReadStream delegates to:
// anything able to send a message of type T.
type qosSend[T any] interface {
	Send(T) error
}

// qosReadStream is a response stream wrapper: messages are sent through
// sender, while the stream context is obtained from ctxF rather than from the
// wrapped stream.
type qosReadStream[T any] struct {
	sender qosSend[T]
	ctxF   func() context.Context
}

// Context returns whatever the stream's ctxF callback yields.
func (g *qosReadStream[T]) Context() context.Context {
	ctx := g.ctxF()
	return ctx
}

// Send passes resp to the wrapped sender and returns its error unchanged.
func (g *qosReadStream[T]) Send(resp T) error {
	err := g.sender.Send(resp)
	return err
}

// qosVerificationHeader is the constraint for request types handled by
// qosWriteStream: they must expose their verification header, from which the
// body signature key is read.
type qosVerificationHeader interface {
	GetVerificationHeader() *session.RequestVerificationHeader
}

// qosSendRecv is the stream that qosWriteStream wraps: it accepts request
// messages of type TReq through Send and returns a TResp from CloseAndRecv.
// Both methods take the context explicitly on every call.
type qosSendRecv[TReq qosVerificationHeader, TResp any] interface {
	Send(context.Context, TReq) error
	CloseAndRecv(context.Context) (TResp, error)
}

// qosWriteStream wraps a request stream and remembers the IO tag obtained
// while sending the first message, so that later calls on the wrapped stream
// are made with a context carrying the same tag.
type qosWriteStream[TReq qosVerificationHeader, TResp any] struct {
	// s is the wrapped stream; adj adjusts the context until an IO tag is known.
	s   qosSendRecv[TReq, TResp]
	adj AdjustIOTag

	// ioTag is the IO tag read from the adjusted context; ioTagDefined tells
	// whether it has been set.
	ioTag        string
	ioTagDefined bool
}

// CloseAndRecv delegates to the wrapped stream. If an IO tag has already been
// recorded by Send, it is attached to ctx first; otherwise ctx is passed
// through untouched.
func (q *qosWriteStream[TReq, TResp]) CloseAndRecv(ctx context.Context) (TResp, error) {
	if !q.ioTagDefined {
		return q.s.CloseAndRecv(ctx)
	}
	return q.s.CloseAndRecv(tagging.ContextWithIOTag(ctx, q.ioTag))
}

// Send forwards req to the wrapped stream with a context carrying the stream's
// IO tag. While no tag is recorded yet, ctx is adjusted using the key of the
// request's body signature and the tag found in the adjusted context is
// stored for subsequent calls. The assertion fires if the tag is still
// undefined after that adjustment.
func (q *qosWriteStream[TReq, TResp]) Send(ctx context.Context, req TReq) error {
	if !q.ioTagDefined {
		signerKey := req.GetVerificationHeader().GetBodySignature().GetKey()
		adjusted := q.adj.AdjustIncomingTag(ctx, signerKey)
		q.ioTag, q.ioTagDefined = tagging.IOTagFromContext(adjusted)
		ctx = adjusted
	}
	assert.True(q.ioTagDefined, "io tag undefined after incoming tag adjustment")

	return q.s.Send(tagging.ContextWithIOTag(ctx, q.ioTag), req)
}