From a8e52ef7aaf41995e644ac92925db22f1295ccd0 Mon Sep 17 00:00:00 2001 From: Airat Arifullin Date: Thu, 11 Jan 2024 11:51:29 +0300 Subject: [PATCH] [#898] control: Fix codes for returning APE errors Signed-off-by: Airat Arifullin --- pkg/ape/chainbase/boltdb.go | 36 +++++++++++++------- pkg/services/control/server/policy_engine.go | 12 +++++-- 2 files changed, 32 insertions(+), 16 deletions(-) diff --git a/pkg/ape/chainbase/boltdb.go b/pkg/ape/chainbase/boltdb.go index 34a5c7d2..1047ad42 100644 --- a/pkg/ape/chainbase/boltdb.go +++ b/pkg/ape/chainbase/boltdb.go @@ -25,15 +25,16 @@ type boltLocalOverrideStorage struct { var chainBucket = []byte{0} var ( - ErrChainBucketNotFound = logicerr.New("chain root bucket has not been found") + // ErrRootBucketNotFound signals the database has not been properly initialized. + ErrRootBucketNotFound = logicerr.New("root bucket not found") - ErrChainNotFound = logicerr.New("chain has not been found") + ErrGlobalNamespaceBucketNotFound = logicerr.New("global namespace bucket not found") - ErrGlobalNamespaceBucketNotFound = logicerr.New("global namespace bucket has not been found") + ErrTargetTypeBucketNotFound = logicerr.New("target type bucket not found") - ErrTargetTypeBucketNotFound = logicerr.New("target type bucket has not been found") + ErrTargetNameBucketNotFound = logicerr.New("target name bucket not found") - ErrTargetNameBucketNotFound = logicerr.New("target name bucket has not been found") + ErrBucketNotContainsChainID = logicerr.New("chain id not found in bucket") ) // NewBoltLocalOverrideDatabase returns storage wrapper for storing access policy engine @@ -101,31 +102,30 @@ func (cs *boltLocalOverrideStorage) Close() error { func getTargetBucket(tx *bbolt.Tx, name chain.Name, target policyengine.Target) (*bbolt.Bucket, error) { cbucket := tx.Bucket(chainBucket) if cbucket == nil { - return nil, ErrChainBucketNotFound + return nil, ErrRootBucketNotFound } nbucket := cbucket.Bucket([]byte(name)) if nbucket == nil { - return nil, fmt.Errorf("global namespace %s: %w", name, ErrGlobalNamespaceBucketNotFound) + return nil, fmt.Errorf("%w: %w: %s", policyengine.ErrChainNotFound, ErrGlobalNamespaceBucketNotFound, name) } typeBucket := nbucket.Bucket([]byte{byte(target.Type)}) if typeBucket == nil { - return nil, fmt.Errorf("type bucket '%c': %w", target.Type, ErrTargetTypeBucketNotFound) + return nil, fmt.Errorf("%w: %w: %c", policyengine.ErrChainNotFound, ErrTargetTypeBucketNotFound, target.Type) } rbucket := typeBucket.Bucket([]byte(target.Name)) if rbucket == nil { - return nil, fmt.Errorf("target name bucket %s: %w", target.Name, ErrTargetNameBucketNotFound) + return nil, fmt.Errorf("%w: %w: %s", policyengine.ErrChainNotFound, ErrTargetNameBucketNotFound, target.Name) } - return rbucket, nil } func getTargetBucketCreateIfEmpty(tx *bbolt.Tx, name chain.Name, target policyengine.Target) (*bbolt.Bucket, error) { cbucket := tx.Bucket(chainBucket) if cbucket == nil { - return nil, ErrChainBucketNotFound + return nil, ErrRootBucketNotFound } nbucket := cbucket.Bucket([]byte(name)) @@ -186,7 +186,7 @@ func (cs *boltLocalOverrideStorage) GetOverride(name chain.Name, target policyen } serializedChain = rbuck.Get([]byte(chainID)) if serializedChain == nil { - return ErrChainNotFound + return fmt.Errorf("%w: %w: %s", policyengine.ErrChainNotFound, ErrBucketNotContainsChainID, chainID) } serializedChain = slice.Copy(serializedChain) return nil @@ -225,7 +225,7 @@ func (cs *boltLocalOverrideStorage) ListOverrides(name chain.Name, target policy return nil }) }); err != nil { - if errors.Is(err, ErrGlobalNamespaceBucketNotFound) || errors.Is(err, ErrTargetNameBucketNotFound) { + if errors.Is(err, policyengine.ErrChainNotFound) { return []*chain.Chain{}, nil } return nil, err @@ -243,6 +243,16 @@ func (cs *boltLocalOverrideStorage) ListOverrides(name chain.Name, target policy func (cs *boltLocalOverrideStorage) DropAllOverrides(name chain.Name) error { return cs.db.Update(func(tx *bbolt.Tx) error { + cbucket := tx.Bucket(chainBucket) + if cbucket == nil { + return ErrRootBucketNotFound + } + + nbucket := cbucket.Bucket([]byte(name)) + if nbucket == nil { + return fmt.Errorf("%w: %w: global namespace %s", policyengine.ErrChainNotFound, ErrGlobalNamespaceBucketNotFound, name) + } + return tx.DeleteBucket([]byte(name)) }) } diff --git a/pkg/services/control/server/policy_engine.go b/pkg/services/control/server/policy_engine.go index 805c669a..519b2103 100644 --- a/pkg/services/control/server/policy_engine.go +++ b/pkg/services/control/server/policy_engine.go @@ -128,12 +128,18 @@ func (s *Server) RemoveChainLocalOverride(_ context.Context, req *control.Remove return nil, err } + removed := true if err = s.localOverrideStorage.LocalStorage().RemoveOverride(apechain.Ingress, target, apechain.ID(req.GetBody().GetChainId())); err != nil { - return nil, status.Error(getCodeByLocalStorageErr(err), err.Error()) + code := getCodeByLocalStorageErr(err) + if code == codes.NotFound { + removed = false + } else { + return nil, status.Error(code, err.Error()) + } } resp := &control.RemoveChainLocalOverrideResponse{ Body: &control.RemoveChainLocalOverrideResponse_Body{ - Removed: true, + Removed: removed, }, } err = SignMessage(s.key, resp) @@ -144,7 +150,7 @@ func (s *Server) RemoveChainLocalOverride(_ context.Context, req *control.Remove } func getCodeByLocalStorageErr(err error) codes.Code { - if errors.Is(err, engine.ErrChainNotFound) { + if errors.Is(err, engine.ErrChainNotFound) || errors.Is(err, engine.ErrChainNameNotFound) { return codes.NotFound } return codes.Internal