package tree

import (
	"context"

	"google.golang.org/grpc"
)

// AdjustIOTag derives the context used to serve an incoming tree service
// request. AdjustIncomingTag receives the original request context together
// with the public key taken from the request signature and returns the
// context that the wrapped server should observe.
type AdjustIOTag interface {
	AdjustIncomingTag(ctx context.Context, requestSignPublicKey []byte) context.Context
}

// ioTagAdjust decorates a TreeServiceServer: every RPC first passes its
// context through a, then delegates the call to s.
type ioTagAdjust struct {
	// s is the wrapped server that actually handles each RPC.
	s TreeServiceServer
	// a is consulted for every request before delegation to s.
	a AdjustIOTag
}

// Compile-time assertion that ioTagAdjust implements TreeServiceServer.
var _ TreeServiceServer = (*ioTagAdjust)(nil)

// NewIOTagAdjustServer wraps s so that, for every RPC, the incoming context is
// first passed through a.AdjustIncomingTag (together with the request
// signature's key) and the call is then delegated to s. For streaming RPCs,
// s receives a stream whose Context method returns the adjusted context.
//
// Neither argument is validated here; a nil s or a causes a panic on the
// first request.
func NewIOTagAdjustServer(s TreeServiceServer, a AdjustIOTag) TreeServiceServer {
	srv := new(ioTagAdjust)
	srv.s = s
	srv.a = a
	return srv
}

// Add passes ctx through AdjustIncomingTag with the request signature's key,
// then forwards the request to the wrapped server under the resulting context.
func (i *ioTagAdjust) Add(ctx context.Context, req *AddRequest) (*AddResponse, error) {
	signerKey := req.GetSignature().GetKey()
	tagged := i.a.AdjustIncomingTag(ctx, signerKey)
	return i.s.Add(tagged, req)
}

// AddByPath passes ctx through AdjustIncomingTag with the request signature's
// key, then forwards the request to the wrapped server under the resulting
// context.
func (i *ioTagAdjust) AddByPath(ctx context.Context, req *AddByPathRequest) (*AddByPathResponse, error) {
	signerKey := req.GetSignature().GetKey()
	tagged := i.a.AdjustIncomingTag(ctx, signerKey)
	return i.s.AddByPath(tagged, req)
}

// Apply passes ctx through AdjustIncomingTag with the request signature's key,
// then forwards the request to the wrapped server under the resulting context.
func (i *ioTagAdjust) Apply(ctx context.Context, req *ApplyRequest) (*ApplyResponse, error) {
	signerKey := req.GetSignature().GetKey()
	tagged := i.a.AdjustIncomingTag(ctx, signerKey)
	return i.s.Apply(tagged, req)
}

// GetNodeByPath passes ctx through AdjustIncomingTag with the request
// signature's key, then forwards the request to the wrapped server under the
// resulting context.
func (i *ioTagAdjust) GetNodeByPath(ctx context.Context, req *GetNodeByPathRequest) (*GetNodeByPathResponse, error) {
	signerKey := req.GetSignature().GetKey()
	tagged := i.a.AdjustIncomingTag(ctx, signerKey)
	return i.s.GetNodeByPath(tagged, req)
}

// GetOpLog derives an adjusted context from the stream's own context and the
// request signature's key. It then delegates to the wrapped server with a
// stream whose Context method returns that adjusted context. Send and all
// other stream methods still go to srv.
func (i *ioTagAdjust) GetOpLog(req *GetOpLogRequest, srv TreeService_GetOpLogServer) error {
	tagged := i.a.AdjustIncomingTag(srv.Context(), req.GetSignature().GetKey())
	wrapped := &qosServerWrapper[*GetOpLogResponse]{
		ServerStream: srv,
		sender:       srv,
		ctxF:         func() context.Context { return tagged },
	}
	return i.s.GetOpLog(req, wrapped)
}

// GetSubTree derives an adjusted context from the stream's own context and the
// request signature's key. It then delegates to the wrapped server with a
// stream whose Context method returns that adjusted context. Send and all
// other stream methods still go to srv.
func (i *ioTagAdjust) GetSubTree(req *GetSubTreeRequest, srv TreeService_GetSubTreeServer) error {
	tagged := i.a.AdjustIncomingTag(srv.Context(), req.GetSignature().GetKey())
	wrapped := &qosServerWrapper[*GetSubTreeResponse]{
		ServerStream: srv,
		sender:       srv,
		ctxF:         func() context.Context { return tagged },
	}
	return i.s.GetSubTree(req, wrapped)
}

// Healthcheck passes ctx through AdjustIncomingTag with the request
// signature's key, then forwards the request to the wrapped server under the
// resulting context.
func (i *ioTagAdjust) Healthcheck(ctx context.Context, req *HealthcheckRequest) (*HealthcheckResponse, error) {
	signerKey := req.GetSignature().GetKey()
	tagged := i.a.AdjustIncomingTag(ctx, signerKey)
	return i.s.Healthcheck(tagged, req)
}

// Move passes ctx through AdjustIncomingTag with the request signature's key,
// then forwards the request to the wrapped server under the resulting context.
func (i *ioTagAdjust) Move(ctx context.Context, req *MoveRequest) (*MoveResponse, error) {
	signerKey := req.GetSignature().GetKey()
	tagged := i.a.AdjustIncomingTag(ctx, signerKey)
	return i.s.Move(tagged, req)
}

// Remove passes ctx through AdjustIncomingTag with the request signature's
// key, then forwards the request to the wrapped server under the resulting
// context.
func (i *ioTagAdjust) Remove(ctx context.Context, req *RemoveRequest) (*RemoveResponse, error) {
	signerKey := req.GetSignature().GetKey()
	tagged := i.a.AdjustIncomingTag(ctx, signerKey)
	return i.s.Remove(tagged, req)
}

// TreeList passes ctx through AdjustIncomingTag with the request signature's
// key, then forwards the request to the wrapped server under the resulting
// context.
func (i *ioTagAdjust) TreeList(ctx context.Context, req *TreeListRequest) (*TreeListResponse, error) {
	signerKey := req.GetSignature().GetKey()
	tagged := i.a.AdjustIncomingTag(ctx, signerKey)
	return i.s.TreeList(tagged, req)
}

// qosSend captures the typed Send method of a server stream (for example
// TreeService_GetOpLogServer). grpc.ServerStream itself only offers the
// untyped SendMsg.
type qosSend[Msg any] interface {
	Send(msg Msg) error
}

// qosServerWrapper decorates a server stream. Context is answered by ctxF, so
// the handler observes a substituted context. Send is forwarded to sender.
// Every other grpc.ServerStream method comes from the embedded stream.
type qosServerWrapper[Msg any] struct {
	// ServerStream supplies all stream methods the wrapper does not override.
	grpc.ServerStream

	// sender receives typed Send calls; callers in this file pass the same
	// stream value as ServerStream.
	sender qosSend[Msg]

	// ctxF produces the context returned by Context.
	ctxF func() context.Context
}

// Send hands resp to the typed sender and returns its error unchanged.
func (w *qosServerWrapper[Msg]) Send(resp Msg) error {
	target := w.sender
	return target.Send(resp)
}

// Context overrides the embedded grpc.ServerStream's Context. It returns
// whatever ctxF yields instead of the original stream context.
func (w *qosServerWrapper[Msg]) Context() context.Context {
	produce := w.ctxF
	return produce()
}