package tree

import (
	"context"
	"sync/atomic"

	"git.frostfs.info/TrueCloudLab/frostfs-node/internal/audit"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/util/logger"
)

// Compile-time check that *auditService satisfies TreeServiceServer.
var _ TreeServiceServer = (*auditService)(nil)

// auditService is a TreeServiceServer decorator: every RPC is forwarded to
// next, and, if enabled reports true at the time the call returns, an audit
// record describing the request and its success is written to log.
type auditService struct {
	// next is the wrapped server that actually handles each request.
	next    TreeServiceServer
	// log is the destination for audit records.
	log     *logger.Logger
	// enabled is a shared flag consulted (via Load) on every call, so whoever
	// owns the pointer can switch auditing on or off at runtime.
	enabled *atomic.Bool
}

// NewAuditService wraps next in a TreeServiceServer that forwards every call
// to next and, while enabled is set, writes an audit record for it to log.
//
// enabled is shared with the caller and read on each request, so auditing can
// be toggled at runtime by storing into it. A nil enabled is treated as a
// permanently disabled flag instead of causing a nil-pointer panic on the
// first request.
func NewAuditService(next TreeServiceServer, log *logger.Logger, enabled *atomic.Bool) TreeServiceServer {
	if enabled == nil {
		// Every handler calls enabled.Load(); substitute a private flag whose
		// zero value is false so a missing flag disables auditing rather than
		// crashing the server.
		enabled = new(atomic.Bool)
	}
	return &auditService{
		next:    next,
		log:     log,
		enabled: enabled,
	}
}

// Add implements TreeServiceServer.
func (a *auditService) Add(ctx context.Context, req *AddRequest) (*AddResponse, error) {
	// Add implements TreeServiceServer: the request is always served by the
	// wrapped server first; when auditing is on, the key from the request
	// signature, the container/tree target and the call's success are logged.
	resp, callErr := a.next.Add(ctx, req)
	if a.enabled.Load() {
		key := req.GetSignature().GetKey()
		body := req.GetBody()
		target := audit.TargetFromTreeID(body.GetContainerId(), body.GetTreeId())
		audit.LogRequestWithKey(a.log, TreeService_Add_FullMethodName, key, target, callErr == nil)
	}
	return resp, callErr
}

// AddByPath implements TreeServiceServer.
func (a *auditService) AddByPath(ctx context.Context, req *AddByPathRequest) (*AddByPathResponse, error) {
	// AddByPath implements TreeServiceServer: delegate to the wrapped server,
	// then, if auditing is enabled, log the signature key, the container/tree
	// target and whether the call succeeded.
	resp, callErr := a.next.AddByPath(ctx, req)
	if a.enabled.Load() {
		key := req.GetSignature().GetKey()
		body := req.GetBody()
		target := audit.TargetFromTreeID(body.GetContainerId(), body.GetTreeId())
		audit.LogRequestWithKey(a.log, TreeService_AddByPath_FullMethodName, key, target, callErr == nil)
	}
	return resp, callErr
}

// Apply implements TreeServiceServer.
func (a *auditService) Apply(ctx context.Context, req *ApplyRequest) (*ApplyResponse, error) {
	// Apply implements TreeServiceServer: delegate to the wrapped server,
	// then, if auditing is enabled, log the signature key, the container/tree
	// target and whether the call succeeded.
	resp, callErr := a.next.Apply(ctx, req)
	if a.enabled.Load() {
		key := req.GetSignature().GetKey()
		body := req.GetBody()
		target := audit.TargetFromTreeID(body.GetContainerId(), body.GetTreeId())
		audit.LogRequestWithKey(a.log, TreeService_Apply_FullMethodName, key, target, callErr == nil)
	}
	return resp, callErr
}

// GetNodeByPath implements TreeServiceServer.
func (a *auditService) GetNodeByPath(ctx context.Context, req *GetNodeByPathRequest) (*GetNodeByPathResponse, error) {
	// GetNodeByPath implements TreeServiceServer: delegate to the wrapped
	// server, then, if auditing is enabled, log the signature key, the
	// container/tree target and whether the call succeeded.
	resp, callErr := a.next.GetNodeByPath(ctx, req)
	if a.enabled.Load() {
		key := req.GetSignature().GetKey()
		body := req.GetBody()
		target := audit.TargetFromTreeID(body.GetContainerId(), body.GetTreeId())
		audit.LogRequestWithKey(a.log, TreeService_GetNodeByPath_FullMethodName, key, target, callErr == nil)
	}
	return resp, callErr
}

// GetOpLog implements TreeServiceServer.
func (a *auditService) GetOpLog(req *GetOpLogRequest, srv TreeService_GetOpLogServer) error {
	// GetOpLog implements TreeServiceServer. The audit record, if any, is
	// written only after the wrapped streaming handler has returned, and it
	// reports success exactly when that handler returned no error.
	streamErr := a.next.GetOpLog(req, srv)
	if a.enabled.Load() {
		key := req.GetSignature().GetKey()
		body := req.GetBody()
		target := audit.TargetFromTreeID(body.GetContainerId(), body.GetTreeId())
		audit.LogRequestWithKey(a.log, TreeService_GetOpLog_FullMethodName, key, target, streamErr == nil)
	}
	return streamErr
}

// GetSubTree implements TreeServiceServer.
func (a *auditService) GetSubTree(req *GetSubTreeRequest, srv TreeService_GetSubTreeServer) error {
	// GetSubTree implements TreeServiceServer. The audit record, if any, is
	// written only after the wrapped streaming handler has returned, and it
	// reports success exactly when that handler returned no error.
	streamErr := a.next.GetSubTree(req, srv)
	if a.enabled.Load() {
		key := req.GetSignature().GetKey()
		body := req.GetBody()
		target := audit.TargetFromTreeID(body.GetContainerId(), body.GetTreeId())
		audit.LogRequestWithKey(a.log, TreeService_GetSubTree_FullMethodName, key, target, streamErr == nil)
	}
	return streamErr
}

// Healthcheck implements TreeServiceServer.
func (a *auditService) Healthcheck(ctx context.Context, req *HealthcheckRequest) (*HealthcheckResponse, error) {
	// Healthcheck implements TreeServiceServer. The request names no
	// container or tree, so the audit record (when enabled) carries a nil
	// target alongside the signature key and the success flag.
	resp, callErr := a.next.Healthcheck(ctx, req)
	if a.enabled.Load() {
		key := req.GetSignature().GetKey()
		audit.LogRequestWithKey(a.log, TreeService_Healthcheck_FullMethodName, key, nil, callErr == nil)
	}
	return resp, callErr
}

// Move implements TreeServiceServer.
func (a *auditService) Move(ctx context.Context, req *MoveRequest) (*MoveResponse, error) {
	// Move implements TreeServiceServer: delegate to the wrapped server,
	// then, if auditing is enabled, log the signature key, the container/tree
	// target and whether the call succeeded.
	resp, callErr := a.next.Move(ctx, req)
	if a.enabled.Load() {
		key := req.GetSignature().GetKey()
		body := req.GetBody()
		target := audit.TargetFromTreeID(body.GetContainerId(), body.GetTreeId())
		audit.LogRequestWithKey(a.log, TreeService_Move_FullMethodName, key, target, callErr == nil)
	}
	return resp, callErr
}

// Remove implements TreeServiceServer.
func (a *auditService) Remove(ctx context.Context, req *RemoveRequest) (*RemoveResponse, error) {
	// Remove implements TreeServiceServer: delegate to the wrapped server,
	// then, if auditing is enabled, log the signature key, the container/tree
	// target and whether the call succeeded.
	resp, callErr := a.next.Remove(ctx, req)
	if a.enabled.Load() {
		key := req.GetSignature().GetKey()
		body := req.GetBody()
		target := audit.TargetFromTreeID(body.GetContainerId(), body.GetTreeId())
		audit.LogRequestWithKey(a.log, TreeService_Remove_FullMethodName, key, target, callErr == nil)
	}
	return resp, callErr
}

// TreeList implements TreeServiceServer.
func (a *auditService) TreeList(ctx context.Context, req *TreeListRequest) (*TreeListResponse, error) {
	// TreeList implements TreeServiceServer. Only a container ID is taken
	// from the request body, so the audit target is built with an empty
	// tree ID.
	resp, callErr := a.next.TreeList(ctx, req)
	if a.enabled.Load() {
		key := req.GetSignature().GetKey()
		target := audit.TargetFromTreeID(req.GetBody().GetContainerId(), "")
		audit.LogRequestWithKey(a.log, TreeService_TreeList_FullMethodName, key, target, callErr == nil)
	}
	return resp, callErr
}