package object

import (
	"context"
	"errors"
	"sync/atomic"

	"git.frostfs.info/TrueCloudLab/frostfs-node/internal/audit"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/services/util"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/util/logger"
	"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/api/object"
	objectGRPC "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/api/object/grpc"
	"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/api/refs"
	cid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/container/id"
	oid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object/id"
)

// Compile-time check that *auditService satisfies ServiceServer.
var _ ServiceServer = (*auditService)(nil)

// auditService is a ServiceServer decorator: every method forwards the call
// to next and, while enabled reports true, hands the request and its outcome
// to the audit package for logging.
type auditService struct {
	// next is the wrapped service that actually handles each request.
	next    ServiceServer
	// log is the logger passed to the audit helpers.
	log     *logger.Logger
	// enabled is read on every call; when it holds false the methods return
	// the result of next without producing an audit record.
	enabled *atomic.Bool
}

// NewAuditService returns a ServiceServer that forwards every call to next
// and audits it through log whenever enabled holds true at call time.
// The enabled pointer is stored as is, so later changes to it are observed
// by the returned service.
func NewAuditService(next ServiceServer, log *logger.Logger, enabled *atomic.Bool) ServiceServer {
	svc := new(auditService)
	svc.next = next
	svc.log = log
	svc.enabled = enabled
	return svc
}

// Delete implements ServiceServer.
//
// The request is forwarded to the wrapped service first. If auditing is
// enabled, the request is then reported via audit.LogRequest with the
// address from the request body as target and err == nil as the status.
// The wrapped service's response and error are returned unchanged.
func (a *auditService) Delete(ctx context.Context, req *object.DeleteRequest) (*object.DeleteResponse, error) {
	resp, err := a.next.Delete(ctx, req)
	if a.enabled.Load() {
		target := audit.TargetFromRef(req.GetBody().GetAddress(), &oid.Address{})
		audit.LogRequest(ctx, a.log, objectGRPC.ObjectService_Delete_FullMethodName, req, target, err == nil)
	}
	return resp, err
}

// Get implements ServiceServer.
//
// The call is forwarded to the wrapped service; once it returns and auditing
// is enabled, the request is reported via audit.LogRequest using the stream's
// context, the address from the request body as target and err == nil as
// the status. The wrapped service's error is returned unchanged.
func (a *auditService) Get(req *object.GetRequest, stream GetObjectStream) error {
	err := a.next.Get(req, stream)
	if a.enabled.Load() {
		ctx := stream.Context()
		target := audit.TargetFromRef(req.GetBody().GetAddress(), &oid.Address{})
		audit.LogRequest(ctx, a.log, objectGRPC.ObjectService_Get_FullMethodName, req, target, err == nil)
	}
	return err
}

// GetRange implements ServiceServer.
//
// The call is forwarded to the wrapped service; once it returns and auditing
// is enabled, the request is reported via audit.LogRequest using the stream's
// context, the address from the request body as target and err == nil as
// the status. The wrapped service's error is returned unchanged.
func (a *auditService) GetRange(req *object.GetRangeRequest, stream GetObjectRangeStream) error {
	err := a.next.GetRange(req, stream)
	if a.enabled.Load() {
		ctx := stream.Context()
		target := audit.TargetFromRef(req.GetBody().GetAddress(), &oid.Address{})
		audit.LogRequest(ctx, a.log, objectGRPC.ObjectService_GetRange_FullMethodName, req, target, err == nil)
	}
	return err
}

// GetRangeHash implements ServiceServer.
//
// The request is forwarded to the wrapped service first. If auditing is
// enabled, the request is then reported via audit.LogRequest with the
// address from the request body as target and err == nil as the status.
// The wrapped service's response and error are returned unchanged.
func (a *auditService) GetRangeHash(ctx context.Context, req *object.GetRangeHashRequest) (*object.GetRangeHashResponse, error) {
	res, err := a.next.GetRangeHash(ctx, req)
	if a.enabled.Load() {
		target := audit.TargetFromRef(req.GetBody().GetAddress(), &oid.Address{})
		audit.LogRequest(ctx, a.log, objectGRPC.ObjectService_GetRangeHash_FullMethodName, req, target, err == nil)
	}
	return res, err
}

// Head implements ServiceServer.
//
// The request is forwarded to the wrapped service first. If auditing is
// enabled, the request is then reported via audit.LogRequest with the
// address from the request body as target and err == nil as the status.
// The wrapped service's response and error are returned unchanged.
func (a *auditService) Head(ctx context.Context, req *object.HeadRequest) (*object.HeadResponse, error) {
	res, err := a.next.Head(ctx, req)
	if a.enabled.Load() {
		target := audit.TargetFromRef(req.GetBody().GetAddress(), &oid.Address{})
		audit.LogRequest(ctx, a.log, objectGRPC.ObjectService_Head_FullMethodName, req, target, err == nil)
	}
	return res, err
}

// Put implements ServiceServer.
//
// It opens a put stream on the wrapped service. With auditing disabled the
// result is returned untouched. With auditing enabled, a failure to open the
// stream is reported immediately (no request and no target are available
// yet); on success the stream is wrapped in an auditPutStream so the
// operation is reported from the stream's own methods.
func (a *auditService) Put(ctx context.Context) (PutObjectStream, error) {
	inner, err := a.next.Put(ctx)
	switch {
	case !a.enabled.Load():
		return inner, err
	case err != nil:
		audit.LogRequest(ctx, a.log, objectGRPC.ObjectService_Put_FullMethodName, nil, nil, false)
		return inner, err
	}
	wrapped := &auditPutStream{stream: inner, log: a.log}
	return wrapped, nil
}

// PutSingle implements ServiceServer.
//
// The request is forwarded to the wrapped service first. If auditing is
// enabled, the request is then reported via audit.LogRequest; the target is
// built from the container ID in the object's header and the object's ID,
// both read from the request body, and err == nil is passed as the status.
// The wrapped service's response and error are returned unchanged.
func (a *auditService) PutSingle(ctx context.Context, req *object.PutSingleRequest) (*object.PutSingleResponse, error) {
	res, err := a.next.PutSingle(ctx, req)
	if a.enabled.Load() {
		obj := req.GetBody().GetObject()
		target := audit.TargetFromContainerIDObjectID(obj.GetHeader().GetContainerID(), obj.GetObjectID())
		audit.LogRequest(ctx, a.log, objectGRPC.ObjectService_PutSingle_FullMethodName, req, target, err == nil)
	}
	return res, err
}

// Search implements ServiceServer.
//
// The call is forwarded to the wrapped service; once it returns and auditing
// is enabled, the request is reported via audit.LogRequest using the stream's
// context, the container ID from the request body as target and err == nil
// as the status. The wrapped service's error is returned unchanged.
func (a *auditService) Search(req *object.SearchRequest, stream SearchStream) error {
	err := a.next.Search(req, stream)
	if a.enabled.Load() {
		ctx := stream.Context()
		target := audit.TargetFromRef(req.GetBody().GetContainerID(), &cid.ID{})
		audit.LogRequest(ctx, a.log, objectGRPC.ObjectService_Search_FullMethodName, req, target, err == nil)
	}
	return err
}

// Compile-time check that *auditPutStream satisfies PutObjectStream.
var _ PutObjectStream = (*auditPutStream)(nil)

// auditPutStream wraps a PutObjectStream and accumulates, across Send and
// CloseAndRecv calls, the data needed to write the audit record for a
// streamed Put operation.
type auditPutStream struct {
	// stream is the wrapped stream all calls are delegated to.
	stream PutObjectStream
	// log is the logger passed to the audit helpers.
	log    *logger.Logger

	// failed is set to true once Send or CloseAndRecv of the wrapped stream
	// returns an error; it is never reset.
	failed      bool
	// key is taken from the body signature of the verification header of
	// the request that carries the init part.
	key         []byte
	// containerID is taken from the header of the init part.
	containerID *refs.ContainerID
	// objectID is first taken from the init part and later overwritten in
	// CloseAndRecv with the ID from the response body.
	objectID    *refs.ObjectID
}

// CloseAndRecv implements PutObjectStream.
//
// It closes the wrapped stream, marks the operation as failed if that
// returned an error, replaces the remembered object ID with the one from the
// response body, and writes the audit record for the whole Put operation.
// The wrapped stream's response and error are returned unchanged.
func (a *auditPutStream) CloseAndRecv(ctx context.Context) (*object.PutResponse, error) {
	res, err := a.stream.CloseAndRecv(ctx)
	a.failed = a.failed || err != nil
	a.objectID = res.GetBody().GetObjectID()

	target := audit.TargetFromContainerIDObjectID(a.containerID, a.objectID)
	audit.LogRequestWithKey(ctx, a.log, objectGRPC.ObjectService_Put_FullMethodName, a.key, target, !a.failed)
	return res, err
}

// Send implements PutObjectStream.
//
// When req carries the init part, the container ID, object ID and signer key
// are remembered for the audit record. The request is then forwarded to the
// wrapped stream and its error is returned unchanged.
//
// An audit record is written here only when the wrapped stream returns an
// error other than util.ErrAbortStream. A successful Send writes nothing:
// the record for the whole operation is produced by CloseAndRecv.
func (a *auditPutStream) Send(ctx context.Context, req *object.PutRequest) error {
	if partInit, ok := req.GetBody().GetObjectPart().(*object.PutObjectPartInit); ok {
		a.containerID = partInit.GetHeader().GetContainerID()
		a.objectID = partInit.GetObjectID()
		a.key = req.GetVerificationHeader().GetBodySignature().GetKey()
	}

	err := a.stream.Send(ctx, req)
	if err != nil {
		a.failed = true
	}
	// errors.Is(nil, target) is false, so the explicit err != nil guard is
	// required: without it every successful chunk would be logged as a
	// separate successful operation, in addition to the CloseAndRecv record.
	// NOTE(review): ErrAbortStream is excluded on the assumption that the
	// caller still invokes CloseAndRecv in that case — confirm against caller.
	if err != nil && !errors.Is(err, util.ErrAbortStream) { // CloseAndRecv will not be called, so log here
		audit.LogRequestWithKey(ctx, a.log, objectGRPC.ObjectService_Put_FullMethodName, a.key,
			audit.TargetFromContainerIDObjectID(a.containerID, a.objectID),
			!a.failed)
	}
	return err
}

// Compile-time check that *auditPatchStream satisfies PatchObjectStream,
// matching the checks declared for auditService and auditPutStream.
var _ PatchObjectStream = (*auditPatchStream)(nil)

// auditPatchStream wraps a PatchObjectStream and accumulates, across Send and
// CloseAndRecv calls, the data needed to write the audit record for a
// streamed Patch operation.
type auditPatchStream struct {
	// stream is the wrapped stream all calls are delegated to.
	stream PatchObjectStream
	// log is the logger passed to the audit helpers.
	log *logger.Logger

	// failed is set to true once Send or CloseAndRecv of the wrapped stream
	// returns an error; it is never reset.
	failed bool
	// key is taken from the body signature of the verification header of
	// the first request sent.
	key []byte
	// containerID is taken from the address in the body of the first request.
	containerID *refs.ContainerID
	// objectID is first taken from the address in the body of the first
	// request and later overwritten in CloseAndRecv with the ID from the
	// response body.
	objectID *refs.ObjectID

	// nonFirstSend becomes true after the first Send has captured the
	// fields above, so later requests do not overwrite them.
	nonFirstSend bool
}

// Patch implements ServiceServer.
//
// It opens a patch stream on the wrapped service. With auditing disabled the
// result is returned untouched. With auditing enabled, a failure to open the
// stream is reported immediately (no request and no target are available
// yet); on success the stream is wrapped in an auditPatchStream so the
// operation is reported from the stream's own methods.
func (a *auditService) Patch(ctx context.Context) (PatchObjectStream, error) {
	inner, err := a.next.Patch(ctx)
	switch {
	case !a.enabled.Load():
		return inner, err
	case err != nil:
		audit.LogRequest(ctx, a.log, objectGRPC.ObjectService_Patch_FullMethodName, nil, nil, false)
		return inner, err
	}
	wrapped := &auditPatchStream{stream: inner, log: a.log}
	return wrapped, nil
}

// CloseAndRecv implements PatchObjectStream.
//
// It closes the wrapped stream, marks the operation as failed if that
// returned an error, replaces the remembered object ID with the one from the
// response body, and writes the audit record for the whole Patch operation.
// The wrapped stream's response and error are returned unchanged.
func (a *auditPatchStream) CloseAndRecv(ctx context.Context) (*object.PatchResponse, error) {
	res, err := a.stream.CloseAndRecv(ctx)
	a.failed = a.failed || err != nil
	a.objectID = res.GetBody().GetObjectID()

	target := audit.TargetFromContainerIDObjectID(a.containerID, a.objectID)
	audit.LogRequestWithKey(ctx, a.log, objectGRPC.ObjectService_Patch_FullMethodName, a.key, target, !a.failed)
	return res, err
}

// Send implements PatchObjectStream.
//
// On the first call the container ID and object ID from the request's
// address and the signer key are remembered for the audit record. The request
// is then forwarded to the wrapped stream and its error is returned unchanged.
//
// An audit record is written here only when the wrapped stream returns an
// error other than util.ErrAbortStream. A successful Send writes nothing:
// the record for the whole operation is produced by CloseAndRecv.
func (a *auditPatchStream) Send(ctx context.Context, req *object.PatchRequest) error {
	if !a.nonFirstSend {
		a.containerID = req.GetBody().GetAddress().GetContainerID()
		a.objectID = req.GetBody().GetAddress().GetObjectID()
		a.key = req.GetVerificationHeader().GetBodySignature().GetKey()
		a.nonFirstSend = true
	}

	err := a.stream.Send(ctx, req)
	if err != nil {
		a.failed = true
	}
	// errors.Is(nil, target) is false, so the explicit err != nil guard is
	// required: without it every successful chunk would be logged as a
	// separate successful operation, in addition to the CloseAndRecv record.
	// NOTE(review): ErrAbortStream is excluded on the assumption that the
	// caller still invokes CloseAndRecv in that case — confirm against caller.
	if err != nil && !errors.Is(err, util.ErrAbortStream) { // CloseAndRecv will not be called, so log here
		audit.LogRequestWithKey(ctx, a.log, objectGRPC.ObjectService_Patch_FullMethodName, a.key,
			audit.TargetFromContainerIDObjectID(a.containerID, a.objectID),
			!a.failed)
	}
	return err
}