package control

import (
	"context"
	"errors"
	"fmt"

	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/services/control"
	cid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/container/id"
	apechain "git.frostfs.info/TrueCloudLab/policy-engine/pkg/chain"
	engine "git.frostfs.info/TrueCloudLab/policy-engine/pkg/engine"
	nativeschema "git.frostfs.info/TrueCloudLab/policy-engine/schema/native"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
)

// AddChainLocalOverride decodes the container ID and the APE chain from the
// request, assigns the chain a locally generated ID and adds it as an ingress
// override (apechain.Ingress) for the resource built from
// nativeschema.ResourceFormatRootContainerObjects and the container ID.
//
// The signed response carries the ID assigned to the stored chain.
// Error codes: PermissionDenied if request validation fails, InvalidArgument
// if the container ID or chain cannot be decoded, NotFound if the local storage
// reports engine.ErrChainNotFound, and Internal for any other failure.
func (s *Server) AddChainLocalOverride(_ context.Context, req *control.AddChainLocalOverrideRequest) (*control.AddChainLocalOverrideResponse, error) {
	if err := s.isValidRequest(req); err != nil {
		return nil, status.Error(codes.PermissionDenied, err.Error())
	}

	// Named cnrID so the imported cid package is not shadowed.
	var cnrID cid.ID
	if err := cnrID.Decode(req.GetBody().GetContainerId()); err != nil {
		return nil, status.Error(codes.InvalidArgument, err.Error())
	}

	var chain apechain.Chain
	if err := chain.DecodeBytes(req.GetBody().GetChain()); err != nil {
		return nil, status.Error(codes.InvalidArgument, err.Error())
	}

	src, err := s.apeChainSrc.GetChainSource(cnrID)
	if err != nil {
		return nil, status.Error(codes.Internal, err.Error())
	}

	// Use the value returned by Add instead of a separate Load: with concurrent
	// requests another Add may land between the two calls, so both requests
	// would read the same counter value and generate colliding chain IDs.
	// TODO (aarifullin): this chain ID scheme is not well-designed yet.
	seq := s.apeChainCounter.Add(1)
	chain.ID = apechain.ID(fmt.Sprintf("%s:%d", apechain.Ingress, seq))

	resource := fmt.Sprintf(nativeschema.ResourceFormatRootContainerObjects, cnrID.EncodeToString())
	if _, err = src.LocalStorage().AddOverride(apechain.Ingress, resource, &chain); err != nil {
		return nil, status.Error(getCodeByLocalStorageErr(err), err.Error())
	}

	resp := &control.AddChainLocalOverrideResponse{
		Body: &control.AddChainLocalOverrideResponse_Body{
			ChainId: string(chain.ID),
		},
	}
	if err = SignMessage(s.key, resp); err != nil {
		return nil, status.Error(codes.Internal, err.Error())
	}
	return resp, nil
}

// GetChainLocalOverride fetches one ingress override chain, selected by the
// chain ID in the request, from the local overrides stored for the resource
// built from nativeschema.ResourceFormatRootContainerObjects and the container
// ID. The chain is returned in serialized form inside a signed response.
//
// Error codes: PermissionDenied if request validation fails, InvalidArgument
// if the container ID cannot be decoded, NotFound if the local storage reports
// engine.ErrChainNotFound, and Internal for any other failure.
func (s *Server) GetChainLocalOverride(_ context.Context, req *control.GetChainLocalOverrideRequest) (*control.GetChainLocalOverrideResponse, error) {
	if err := s.isValidRequest(req); err != nil {
		return nil, status.Error(codes.PermissionDenied, err.Error())
	}

	body := req.GetBody()

	var cnrID cid.ID
	if err := cnrID.Decode(body.GetContainerId()); err != nil {
		return nil, status.Error(codes.InvalidArgument, err.Error())
	}

	src, err := s.apeChainSrc.GetChainSource(cnrID)
	if err != nil {
		return nil, status.Error(codes.Internal, err.Error())
	}

	target := fmt.Sprintf(nativeschema.ResourceFormatRootContainerObjects, cnrID.EncodeToString())
	chainID := apechain.ID(body.GetChainId())

	chain, err := src.LocalStorage().GetOverride(apechain.Ingress, target, chainID)
	if err != nil {
		return nil, status.Error(getCodeByLocalStorageErr(err), err.Error())
	}

	resp := new(control.GetChainLocalOverrideResponse)
	resp.Body = &control.GetChainLocalOverrideResponse_Body{Chain: chain.Bytes()}

	if err = SignMessage(s.key, resp); err != nil {
		return nil, status.Error(codes.Internal, err.Error())
	}
	return resp, nil
}

// ListChainLocalOverrides returns every ingress override chain stored locally
// for the resource built from nativeschema.ResourceFormatRootContainerObjects
// and the container ID. Each chain is serialized with Bytes, in the order the
// local storage returned them, and the response is signed.
//
// Error codes: PermissionDenied if request validation fails, InvalidArgument
// if the container ID cannot be decoded, NotFound if the local storage reports
// engine.ErrChainNotFound, and Internal for any other failure.
func (s *Server) ListChainLocalOverrides(_ context.Context, req *control.ListChainLocalOverridesRequest) (*control.ListChainLocalOverridesResponse, error) {
	if err := s.isValidRequest(req); err != nil {
		return nil, status.Error(codes.PermissionDenied, err.Error())
	}

	var cnrID cid.ID
	if err := cnrID.Decode(req.GetBody().GetContainerId()); err != nil {
		return nil, status.Error(codes.InvalidArgument, err.Error())
	}

	src, err := s.apeChainSrc.GetChainSource(cnrID)
	if err != nil {
		return nil, status.Error(codes.Internal, err.Error())
	}

	target := fmt.Sprintf(nativeschema.ResourceFormatRootContainerObjects, cnrID.EncodeToString())
	chains, err := src.LocalStorage().ListOverrides(apechain.Ingress, target)
	if err != nil {
		return nil, status.Error(getCodeByLocalStorageErr(err), err.Error())
	}

	// One serialized entry per chain, in the same order.
	encoded := make([][]byte, len(chains))
	for i := range chains {
		encoded[i] = chains[i].Bytes()
	}

	resp := new(control.ListChainLocalOverridesResponse)
	resp.Body = &control.ListChainLocalOverridesResponse_Body{Chains: encoded}

	if err = SignMessage(s.key, resp); err != nil {
		return nil, status.Error(codes.Internal, err.Error())
	}
	return resp, nil
}

// RemoveChainLocalOverride deletes the ingress override chain, selected by the
// chain ID in the request, from the local overrides stored for the resource
// built from nativeschema.ResourceFormatRootContainerObjects and the container
// ID. On success the signed response reports Removed = true.
//
// Error codes: PermissionDenied if request validation fails, InvalidArgument
// if the container ID cannot be decoded, NotFound if the local storage reports
// engine.ErrChainNotFound, and Internal for any other failure.
func (s *Server) RemoveChainLocalOverride(_ context.Context, req *control.RemoveChainLocalOverrideRequest) (*control.RemoveChainLocalOverrideResponse, error) {
	if err := s.isValidRequest(req); err != nil {
		return nil, status.Error(codes.PermissionDenied, err.Error())
	}

	body := req.GetBody()

	var cnrID cid.ID
	if err := cnrID.Decode(body.GetContainerId()); err != nil {
		return nil, status.Error(codes.InvalidArgument, err.Error())
	}

	src, err := s.apeChainSrc.GetChainSource(cnrID)
	if err != nil {
		return nil, status.Error(codes.Internal, err.Error())
	}

	target := fmt.Sprintf(nativeschema.ResourceFormatRootContainerObjects, cnrID.EncodeToString())
	storage := src.LocalStorage()
	if err = storage.RemoveOverride(apechain.Ingress, target, apechain.ID(body.GetChainId())); err != nil {
		return nil, status.Error(getCodeByLocalStorageErr(err), err.Error())
	}

	resp := new(control.RemoveChainLocalOverrideResponse)
	resp.Body = &control.RemoveChainLocalOverrideResponse_Body{Removed: true}

	if err = SignMessage(s.key, resp); err != nil {
		return nil, status.Error(codes.Internal, err.Error())
	}
	return resp, nil
}

// getCodeByLocalStorageErr maps an error returned by the local override storage
// to a gRPC status code: codes.NotFound when err wraps engine.ErrChainNotFound,
// codes.Internal otherwise.
func getCodeByLocalStorageErr(err error) codes.Code {
	switch {
	case errors.Is(err, engine.ErrChainNotFound):
		return codes.NotFound
	default:
		return codes.Internal
	}
}