control: Fix codes for returing APE errors #901

Merged
fyrchik merged 2 commits from fix/898-ape_error_codes into master 2024-01-11 12:31:32 +00:00
2 changed files with 33 additions and 17 deletions

View file

@ -25,15 +25,16 @@ type boltLocalOverrideStorage struct {
var chainBucket = []byte{0}
var (
ErrChainBucketNotFound = logicerr.New("chain root bucket has not been found")
// ErrRootBucketNotFound signals the database has not been properly initialized.
ErrRootBucketNotFound = logicerr.New("root bucket not found")
ErrChainNotFound = logicerr.New("chain has not been found")
ErrGlobalNamespaceBucketNotFound = logicerr.New("global namespace bucket not found")
ErrGlobalNamespaceBucketNotFound = logicerr.New("global namespace bucket has not been found")
ErrTargetTypeBucketNotFound = logicerr.New("target type bucket not found")
ErrTargetTypeBucketNotFound = logicerr.New("target type bucket has not been found")
ErrTargetNameBucketNotFound = logicerr.New("target name bucket not found")
ErrTargetNameBucketNotFound = logicerr.New("target name bucket has not been found")
ErrBucketNotContainsChainID = logicerr.New("chain id not found in bucket")
)
// NewBoltLocalOverrideDatabase returns storage wrapper for storing access policy engine
@ -101,31 +102,30 @@ func (cs *boltLocalOverrideStorage) Close() error {
func getTargetBucket(tx *bbolt.Tx, name chain.Name, target policyengine.Target) (*bbolt.Bucket, error) {
cbucket := tx.Bucket(chainBucket)
if cbucket == nil {
return nil, ErrChainBucketNotFound
return nil, ErrRootBucketNotFound
}
nbucket := cbucket.Bucket([]byte(name))
if nbucket == nil {
return nil, fmt.Errorf("global namespace %s: %w", name, ErrGlobalNamespaceBucketNotFound)
return nil, fmt.Errorf("%w: %w: %s", policyengine.ErrChainNotFound, ErrGlobalNamespaceBucketNotFound, name)
}
typeBucket := nbucket.Bucket([]byte{byte(target.Type)})
if typeBucket == nil {
return nil, fmt.Errorf("type bucket '%c': %w", target.Type, ErrTargetTypeBucketNotFound)
return nil, fmt.Errorf("%w: %w: %c", policyengine.ErrChainNotFound, ErrTargetTypeBucketNotFound, target.Type)
}
rbucket := typeBucket.Bucket([]byte(target.Name))
if rbucket == nil {
return nil, fmt.Errorf("target name bucket %s: %w", target.Name, ErrTargetNameBucketNotFound)
return nil, fmt.Errorf("%w: %w: %s", policyengine.ErrChainNotFound, ErrTargetNameBucketNotFound, target.Name)
}
return rbucket, nil
}
func getTargetBucketCreateIfEmpty(tx *bbolt.Tx, name chain.Name, target policyengine.Target) (*bbolt.Bucket, error) {
cbucket := tx.Bucket(chainBucket)
if cbucket == nil {
return nil, ErrChainBucketNotFound
return nil, ErrRootBucketNotFound
}
nbucket := cbucket.Bucket([]byte(name))
@ -140,7 +140,7 @@ func getTargetBucketCreateIfEmpty(tx *bbolt.Tx, name chain.Name, target policyen
typeBucket := nbucket.Bucket([]byte{byte(target.Type)})
if typeBucket == nil {
var err error
typeBucket, err = cbucket.CreateBucket([]byte{byte(target.Type)})
typeBucket, err = nbucket.CreateBucket([]byte{byte(target.Type)})
if err != nil {
return nil, fmt.Errorf("could not create a bucket for the target type '%c': %w", target.Type, err)
}
@ -186,7 +186,7 @@ func (cs *boltLocalOverrideStorage) GetOverride(name chain.Name, target policyen
}
serializedChain = rbuck.Get([]byte(chainID))
if serializedChain == nil {
return ErrChainNotFound
return fmt.Errorf("%w: %w: %s", policyengine.ErrChainNotFound, ErrBucketNotContainsChainID, chainID)
}
serializedChain = slice.Copy(serializedChain)
return nil
@ -225,7 +225,7 @@ func (cs *boltLocalOverrideStorage) ListOverrides(name chain.Name, target policy
return nil
})
}); err != nil {
if errors.Is(err, ErrGlobalNamespaceBucketNotFound) || errors.Is(err, ErrTargetNameBucketNotFound) {
if errors.Is(err, policyengine.ErrChainNotFound) {
return []*chain.Chain{}, nil
}
return nil, err
@ -243,6 +243,16 @@ func (cs *boltLocalOverrideStorage) ListOverrides(name chain.Name, target policy
func (cs *boltLocalOverrideStorage) DropAllOverrides(name chain.Name) error {
return cs.db.Update(func(tx *bbolt.Tx) error {
cbucket := tx.Bucket(chainBucket)
if cbucket == nil {
return ErrRootBucketNotFound
}
nbucket := cbucket.Bucket([]byte(name))
if nbucket == nil {
return fmt.Errorf("%w: %w: global namespace %s", policyengine.ErrChainNotFound, ErrGlobalNamespaceBucketNotFound, name)
}
return tx.DeleteBucket([]byte(name))
})
}

View file

@ -128,12 +128,18 @@ func (s *Server) RemoveChainLocalOverride(_ context.Context, req *control.Remove
return nil, err
}
removed := true
if err = s.localOverrideStorage.LocalStorage().RemoveOverride(apechain.Ingress, target, apechain.ID(req.GetBody().GetChainId())); err != nil {
return nil, status.Error(getCodeByLocalStorageErr(err), err.Error())
code := getCodeByLocalStorageErr(err)
if code == codes.NotFound {
removed = false
} else {
return nil, status.Error(code, err.Error())
}
}
resp := &control.RemoveChainLocalOverrideResponse{
Body: &control.RemoveChainLocalOverrideResponse_Body{
Removed: true,
Removed: removed,
},
}
err = SignMessage(s.key, resp)
@ -144,7 +150,7 @@ func (s *Server) RemoveChainLocalOverride(_ context.Context, req *control.Remove
}
func getCodeByLocalStorageErr(err error) codes.Code {
if errors.Is(err, engine.ErrChainNotFound) {
if errors.Is(err, engine.ErrChainNotFound) || errors.Is(err, engine.ErrChainNameNotFound) {
return codes.NotFound
}
return codes.Internal