package control

import (
	"context"
	"errors"
	"fmt"

	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/services/control"
	apechain "git.frostfs.info/TrueCloudLab/policy-engine/pkg/chain"
	engine "git.frostfs.info/TrueCloudLab/policy-engine/pkg/engine"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
)

// apeTarget converts a control-service chain target into a policy-engine
// target. Only container and namespace targets are supported; any other
// target type yields a gRPC InvalidArgument error.
func apeTarget(chainTarget *control.ChainTarget) (engine.Target, error) {
	switch chainTarget.GetType() {
	case control.ChainTarget_CONTAINER:
		return engine.ContainerTarget(chainTarget.GetName()), nil
	case control.ChainTarget_NAMESPACE:
		return engine.NamespaceTarget(chainTarget.GetName()), nil
	default:
		return engine.Target{}, status.Errorf(codes.InvalidArgument,
			"target type is not supported: %s", chainTarget.GetType().String())
	}
}

// AddChainLocalOverride decodes the chain carried in the request and adds it
// to the local override storage under the Ingress chain name for the
// requested target. If the decoded chain has no ID, one is generated from the
// server-wide chain counter. The response carries the resulting chain ID.
//
// Returned gRPC codes:
//   - PermissionDenied if s.isValidRequest rejects the request;
//   - InvalidArgument if the chain cannot be decoded or the target type is
//     not supported;
//   - the code chosen by getCodeByLocalStorageErr if the storage fails;
//   - Internal if the response cannot be signed.
func (s *Server) AddChainLocalOverride(_ context.Context, req *control.AddChainLocalOverrideRequest) (*control.AddChainLocalOverrideResponse, error) {
	if err := s.isValidRequest(req); err != nil {
		return nil, status.Error(codes.PermissionDenied, err.Error())
	}

	var chain apechain.Chain
	if err := chain.DecodeBytes(req.GetBody().GetChain()); err != nil {
		return nil, status.Error(codes.InvalidArgument, err.Error())
	}

	// Use the value returned by Add instead of a separate Load: with a
	// separate Load two concurrent requests could both observe the same
	// counter value and generate identical chain IDs.
	// NOTE(review): assumes apeChainCounter is an atomic integer whose Add
	// returns the updated value — confirm against the Server definition.
	seq := s.apeChainCounter.Add(1)

	// TODO (aarifullin): this chain ID scheme is not well-designed yet.
	if chain.ID == "" {
		chain.ID = apechain.ID(fmt.Sprintf("%s:%d", apechain.Ingress, seq))
	}

	target, err := apeTarget(req.GetBody().GetTarget())
	if err != nil {
		return nil, err
	}

	if _, err = s.localOverrideStorage.LocalStorage().AddOverride(apechain.Ingress, target, &chain); err != nil {
		return nil, status.Error(getCodeByLocalStorageErr(err), err.Error())
	}

	resp := &control.AddChainLocalOverrideResponse{
		Body: &control.AddChainLocalOverrideResponse_Body{
			ChainId: string(chain.ID),
		},
	}
	if err = SignMessage(s.key, resp); err != nil {
		return nil, status.Error(codes.Internal, err.Error())
	}
	return resp, nil
}

// GetChainLocalOverride looks up a single Ingress override chain by target
// and chain ID in the local override storage and returns it in serialized
// form.
//
// Returned gRPC codes:
//   - PermissionDenied if s.isValidRequest rejects the request;
//   - InvalidArgument if the target type is not supported;
//   - the code chosen by getCodeByLocalStorageErr if the storage fails;
//   - Internal if the response cannot be signed.
func (s *Server) GetChainLocalOverride(_ context.Context, req *control.GetChainLocalOverrideRequest) (*control.GetChainLocalOverrideResponse, error) {
	if err := s.isValidRequest(req); err != nil {
		return nil, status.Error(codes.PermissionDenied, err.Error())
	}

	target, err := apeTarget(req.GetBody().GetTarget())
	if err != nil {
		return nil, err
	}

	chain, err := s.localOverrideStorage.LocalStorage().GetOverride(apechain.Ingress, target, apechain.ID(req.GetBody().GetChainId()))
	if err != nil {
		return nil, status.Error(getCodeByLocalStorageErr(err), err.Error())
	}

	resp := &control.GetChainLocalOverrideResponse{
		Body: &control.GetChainLocalOverrideResponse_Body{
			Chain: chain.Bytes(),
		},
	}
	if err = SignMessage(s.key, resp); err != nil {
		return nil, status.Error(codes.Internal, err.Error())
	}
	return resp, nil
}

// ListChainLocalOverrides returns every Ingress override chain stored locally
// for the requested target, each in serialized form.
//
// Returned gRPC codes:
//   - PermissionDenied if s.isValidRequest rejects the request;
//   - InvalidArgument if the target type is not supported;
//   - the code chosen by getCodeByLocalStorageErr if the storage fails;
//   - Internal if the response cannot be signed.
func (s *Server) ListChainLocalOverrides(_ context.Context, req *control.ListChainLocalOverridesRequest) (*control.ListChainLocalOverridesResponse, error) {
	if err := s.isValidRequest(req); err != nil {
		return nil, status.Error(codes.PermissionDenied, err.Error())
	}

	target, err := apeTarget(req.GetBody().GetTarget())
	if err != nil {
		return nil, err
	}

	chains, err := s.localOverrideStorage.LocalStorage().ListOverrides(apechain.Ingress, target)
	if err != nil {
		return nil, status.Error(getCodeByLocalStorageErr(err), err.Error())
	}

	serializedChains := make([][]byte, 0, len(chains))
	for _, chain := range chains {
		serializedChains = append(serializedChains, chain.Bytes())
	}

	resp := &control.ListChainLocalOverridesResponse{
		Body: &control.ListChainLocalOverridesResponse_Body{
			Chains: serializedChains,
		},
	}
	if err = SignMessage(s.key, resp); err != nil {
		return nil, status.Error(codes.Internal, err.Error())
	}
	return resp, nil
}

// RemoveChainLocalOverride removes the Ingress override chain identified by
// target and chain ID from the local override storage. On success the
// response body has Removed set to true.
//
// Returned gRPC codes:
//   - PermissionDenied if s.isValidRequest rejects the request;
//   - InvalidArgument if the target type is not supported;
//   - the code chosen by getCodeByLocalStorageErr if the storage fails;
//   - Internal if the response cannot be signed.
func (s *Server) RemoveChainLocalOverride(_ context.Context, req *control.RemoveChainLocalOverrideRequest) (*control.RemoveChainLocalOverrideResponse, error) {
	if err := s.isValidRequest(req); err != nil {
		return nil, status.Error(codes.PermissionDenied, err.Error())
	}

	target, err := apeTarget(req.GetBody().GetTarget())
	if err != nil {
		return nil, err
	}

	if err = s.localOverrideStorage.LocalStorage().RemoveOverride(apechain.Ingress, target, apechain.ID(req.GetBody().GetChainId())); err != nil {
		return nil, status.Error(getCodeByLocalStorageErr(err), err.Error())
	}

	resp := &control.RemoveChainLocalOverrideResponse{
		Body: &control.RemoveChainLocalOverrideResponse_Body{
			Removed: true,
		},
	}
	if err = SignMessage(s.key, resp); err != nil {
		return nil, status.Error(codes.Internal, err.Error())
	}
	return resp, nil
}

// getCodeByLocalStorageErr maps a local override storage error to a gRPC
// status code: NotFound when the error is (or wraps) engine.ErrChainNotFound,
// Internal for everything else.
func getCodeByLocalStorageErr(err error) codes.Code {
	if errors.Is(err, engine.ErrChainNotFound) {
		return codes.NotFound
	}
	return codes.Internal
}