package apemanager

import (
	"context"
	"sync/atomic"

	"git.frostfs.info/TrueCloudLab/frostfs-api-go/v2/apemanager"
	ape_grpc "git.frostfs.info/TrueCloudLab/frostfs-api-go/v2/apemanager/grpc"
	"git.frostfs.info/TrueCloudLab/frostfs-node/internal/audit"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/util/logger"
)

// Compile-time guarantee that *auditService satisfies the Server interface.
var _ Server = (*auditService)(nil)

// auditService wraps another Server: every call is forwarded to next, and
// when the shared enabled flag is set an audit record is emitted afterwards.
type auditService struct {
	// next handles every request; its response and error are returned as-is.
	next Server

	// log is the logger handed to audit.LogRequest.
	log *logger.Logger

	// enabled is loaded on each call to decide whether that call is audited.
	enabled *atomic.Bool
}

// NewAuditService returns a Server that forwards all calls to next and, while
// *enabled is true, logs each call through log. The enabled pointer is stored
// rather than copied, so flipping the flag later affects subsequent requests.
func NewAuditService(next Server, log *logger.Logger, enabled *atomic.Bool) Server {
	svc := new(auditService)
	svc.next = next
	svc.log = log
	svc.enabled = enabled
	return svc
}

// AddChain implements Server.
//
// The request is always passed to the next handler first. If auditing is
// enabled, the call is then recorded with the request's target type and name
// and the chain ID from the response body; success is reported as err == nil.
func (a *auditService) AddChain(ctx context.Context, req *apemanager.AddChainRequest) (*apemanager.AddChainResponse, error) {
	resp, err := a.next.AddChain(ctx, req)

	if a.enabled.Load() {
		target := req.GetBody().GetTarget()
		chainTarget := audit.TargetFromChainID(
			target.GetTargetType().String(),
			target.GetName(),
			resp.GetBody().GetChainID(),
		)
		audit.LogRequest(a.log, ape_grpc.APEManagerService_AddChain_FullMethodName, req, chainTarget, err == nil)
	}

	return resp, err
}

// ListChains implements Server.
//
// The request is always passed to the next handler first. If auditing is
// enabled, the call is then recorded with the request's target type and name;
// no specific chain ID applies, so nil is passed for it. Success is reported
// as err == nil.
func (a *auditService) ListChains(ctx context.Context, req *apemanager.ListChainsRequest) (*apemanager.ListChainsResponse, error) {
	resp, err := a.next.ListChains(ctx, req)

	if a.enabled.Load() {
		target := req.GetBody().GetTarget()
		chainTarget := audit.TargetFromChainID(
			target.GetTargetType().String(),
			target.GetName(),
			nil,
		)
		audit.LogRequest(a.log, ape_grpc.APEManagerService_ListChains_FullMethodName, req, chainTarget, err == nil)
	}

	return resp, err
}

// RemoveChain implements Server.
//
// The request is always passed to the next handler first. If auditing is
// enabled, the call is then recorded with the request's target type and name
// and the chain ID taken from the request body; success is reported as
// err == nil.
func (a *auditService) RemoveChain(ctx context.Context, req *apemanager.RemoveChainRequest) (*apemanager.RemoveChainResponse, error) {
	resp, err := a.next.RemoveChain(ctx, req)

	if a.enabled.Load() {
		body := req.GetBody()
		target := body.GetTarget()
		chainTarget := audit.TargetFromChainID(
			target.GetTargetType().String(),
			target.GetName(),
			body.GetChainID(),
		)
		audit.LogRequest(a.log, ape_grpc.APEManagerService_RemoveChain_FullMethodName, req, chainTarget, err == nil)
	}

	return resp, err
}