control: Fix codes for returing APE errors #901

Merged
fyrchik merged 2 commits from fix/898-ape_error_codes into master 2024-01-11 12:31:32 +00:00
2 changed files with 33 additions and 17 deletions

View file

@ -25,15 +25,16 @@ type boltLocalOverrideStorage struct {
var chainBucket = []byte{0} var chainBucket = []byte{0}
var ( var (
ErrChainBucketNotFound = logicerr.New("chain root bucket has not been found") // ErrRootBucketNotFound signals the database has not been properly initialized.
ErrRootBucketNotFound = logicerr.New("root bucket not found")
ErrChainNotFound = logicerr.New("chain has not been found") ErrGlobalNamespaceBucketNotFound = logicerr.New("global namespace bucket not found")
ErrGlobalNamespaceBucketNotFound = logicerr.New("global namespace bucket has not been found") ErrTargetTypeBucketNotFound = logicerr.New("target type bucket not found")
ErrTargetTypeBucketNotFound = logicerr.New("target type bucket has not been found") ErrTargetNameBucketNotFound = logicerr.New("target name bucket not found")
ErrTargetNameBucketNotFound = logicerr.New("target name bucket has not been found") ErrBucketNotContainsChainID = logicerr.New("chain id not found in bucket")
) )
// NewBoltLocalOverrideDatabase returns storage wrapper for storing access policy engine // NewBoltLocalOverrideDatabase returns storage wrapper for storing access policy engine
@ -101,31 +102,30 @@ func (cs *boltLocalOverrideStorage) Close() error {
func getTargetBucket(tx *bbolt.Tx, name chain.Name, target policyengine.Target) (*bbolt.Bucket, error) { func getTargetBucket(tx *bbolt.Tx, name chain.Name, target policyengine.Target) (*bbolt.Bucket, error) {
cbucket := tx.Bucket(chainBucket) cbucket := tx.Bucket(chainBucket)
if cbucket == nil { if cbucket == nil {
return nil, ErrChainBucketNotFound return nil, ErrRootBucketNotFound
} }
nbucket := cbucket.Bucket([]byte(name)) nbucket := cbucket.Bucket([]byte(name))
if nbucket == nil { if nbucket == nil {
return nil, fmt.Errorf("global namespace %s: %w", name, ErrGlobalNamespaceBucketNotFound) return nil, fmt.Errorf("%w: %w: %s", policyengine.ErrChainNotFound, ErrGlobalNamespaceBucketNotFound, name)
} }
typeBucket := nbucket.Bucket([]byte{byte(target.Type)}) typeBucket := nbucket.Bucket([]byte{byte(target.Type)})
if typeBucket == nil { if typeBucket == nil {
return nil, fmt.Errorf("type bucket '%c': %w", target.Type, ErrTargetTypeBucketNotFound) return nil, fmt.Errorf("%w: %w: %c", policyengine.ErrChainNotFound, ErrTargetTypeBucketNotFound, target.Type)
} }
rbucket := typeBucket.Bucket([]byte(target.Name)) rbucket := typeBucket.Bucket([]byte(target.Name))
if rbucket == nil { if rbucket == nil {
return nil, fmt.Errorf("target name bucket %s: %w", target.Name, ErrTargetNameBucketNotFound) return nil, fmt.Errorf("%w: %w: %s", policyengine.ErrChainNotFound, ErrTargetNameBucketNotFound, target.Name)
} }
return rbucket, nil return rbucket, nil
} }
func getTargetBucketCreateIfEmpty(tx *bbolt.Tx, name chain.Name, target policyengine.Target) (*bbolt.Bucket, error) { func getTargetBucketCreateIfEmpty(tx *bbolt.Tx, name chain.Name, target policyengine.Target) (*bbolt.Bucket, error) {
cbucket := tx.Bucket(chainBucket) cbucket := tx.Bucket(chainBucket)
if cbucket == nil { if cbucket == nil {
return nil, ErrChainBucketNotFound return nil, ErrRootBucketNotFound
} }
nbucket := cbucket.Bucket([]byte(name)) nbucket := cbucket.Bucket([]byte(name))
@ -140,7 +140,7 @@ func getTargetBucketCreateIfEmpty(tx *bbolt.Tx, name chain.Name, target policyen
typeBucket := nbucket.Bucket([]byte{byte(target.Type)}) typeBucket := nbucket.Bucket([]byte{byte(target.Type)})
if typeBucket == nil { if typeBucket == nil {
var err error var err error
typeBucket, err = cbucket.CreateBucket([]byte{byte(target.Type)}) typeBucket, err = nbucket.CreateBucket([]byte{byte(target.Type)})
if err != nil { if err != nil {
return nil, fmt.Errorf("could not create a bucket for the target type '%c': %w", target.Type, err) return nil, fmt.Errorf("could not create a bucket for the target type '%c': %w", target.Type, err)
} }
@ -186,7 +186,7 @@ func (cs *boltLocalOverrideStorage) GetOverride(name chain.Name, target policyen
} }
serializedChain = rbuck.Get([]byte(chainID)) serializedChain = rbuck.Get([]byte(chainID))
if serializedChain == nil { if serializedChain == nil {
return ErrChainNotFound return fmt.Errorf("%w: %w: %s", policyengine.ErrChainNotFound, ErrBucketNotContainsChainID, chainID)
} }
serializedChain = slice.Copy(serializedChain) serializedChain = slice.Copy(serializedChain)
return nil return nil
@ -225,7 +225,7 @@ func (cs *boltLocalOverrideStorage) ListOverrides(name chain.Name, target policy
return nil return nil
}) })
}); err != nil { }); err != nil {
if errors.Is(err, ErrGlobalNamespaceBucketNotFound) || errors.Is(err, ErrTargetNameBucketNotFound) { if errors.Is(err, policyengine.ErrChainNotFound) {
return []*chain.Chain{}, nil return []*chain.Chain{}, nil
} }
return nil, err return nil, err
@ -243,6 +243,16 @@ func (cs *boltLocalOverrideStorage) ListOverrides(name chain.Name, target policy
func (cs *boltLocalOverrideStorage) DropAllOverrides(name chain.Name) error { func (cs *boltLocalOverrideStorage) DropAllOverrides(name chain.Name) error {
return cs.db.Update(func(tx *bbolt.Tx) error { return cs.db.Update(func(tx *bbolt.Tx) error {
cbucket := tx.Bucket(chainBucket)
if cbucket == nil {
return ErrRootBucketNotFound
}
nbucket := cbucket.Bucket([]byte(name))
if nbucket == nil {
return fmt.Errorf("%w: %w: global namespace %s", policyengine.ErrChainNotFound, ErrGlobalNamespaceBucketNotFound, name)
}
return tx.DeleteBucket([]byte(name)) return tx.DeleteBucket([]byte(name))
}) })
} }

View file

@ -128,12 +128,18 @@ func (s *Server) RemoveChainLocalOverride(_ context.Context, req *control.Remove
return nil, err return nil, err
} }
removed := true
if err = s.localOverrideStorage.LocalStorage().RemoveOverride(apechain.Ingress, target, apechain.ID(req.GetBody().GetChainId())); err != nil { if err = s.localOverrideStorage.LocalStorage().RemoveOverride(apechain.Ingress, target, apechain.ID(req.GetBody().GetChainId())); err != nil {
return nil, status.Error(getCodeByLocalStorageErr(err), err.Error()) code := getCodeByLocalStorageErr(err)
if code == codes.NotFound {
removed = false
} else {
return nil, status.Error(code, err.Error())
}
} }
resp := &control.RemoveChainLocalOverrideResponse{ resp := &control.RemoveChainLocalOverrideResponse{
Body: &control.RemoveChainLocalOverrideResponse_Body{ Body: &control.RemoveChainLocalOverrideResponse_Body{
Removed: true, Removed: removed,
}, },
} }
err = SignMessage(s.key, resp) err = SignMessage(s.key, resp)
@ -144,7 +150,7 @@ func (s *Server) RemoveChainLocalOverride(_ context.Context, req *control.Remove
} }
func getCodeByLocalStorageErr(err error) codes.Code { func getCodeByLocalStorageErr(err error) codes.Code {
if errors.Is(err, engine.ErrChainNotFound) { if errors.Is(err, engine.ErrChainNotFound) || errors.Is(err, engine.ErrChainNameNotFound) {
return codes.NotFound return codes.NotFound
} }
return codes.Internal return codes.Internal