[#898] control: Fix codes for returning APE errors
All checks were successful
DCO action / DCO (pull_request) Successful in 1m23s
Build / Build Components (1.21) (pull_request) Successful in 3m22s
Vulncheck / Vulncheck (pull_request) Successful in 2m49s
Build / Build Components (1.20) (pull_request) Successful in 3m44s
Tests and linters / Lint (pull_request) Successful in 4m57s
Tests and linters / Staticcheck (pull_request) Successful in 4m52s
Tests and linters / Tests (1.20) (pull_request) Successful in 6m59s
Tests and linters / Tests (1.21) (pull_request) Successful in 7m34s
Tests and linters / Tests with -race (pull_request) Successful in 8m12s

Signed-off-by: Airat Arifullin <a.arifullin@yadro.com>
This commit is contained in:
Airat Arifullin 2024-01-11 11:51:29 +03:00
parent 2d831bbe9f
commit 5f55096252
2 changed files with 32 additions and 16 deletions

View file

@ -25,15 +25,16 @@ type boltLocalOverrideStorage struct {
var chainBucket = []byte{0} var chainBucket = []byte{0}
var ( var (
ErrChainBucketNotFound = logicerr.New("chain root bucket has not been found") // ErrRootBucketNotFound signals the database has not been properly initialized.
ErrRootBucketNotFound = logicerr.New("root bucket not found")
ErrChainNotFound = logicerr.New("chain has not been found") ErrGlobalNamespaceBucketNotFound = logicerr.New("global namespace bucket not found")
ErrGlobalNamespaceBucketNotFound = logicerr.New("global namespace bucket has not been found") ErrTargetTypeBucketNotFound = logicerr.New("target type bucket not found")
ErrTargetTypeBucketNotFound = logicerr.New("target type bucket has not been found") ErrTargetNameBucketNotFound = logicerr.New("target name bucket not found")
ErrTargetNameBucketNotFound = logicerr.New("target name bucket has not been found") ErrBucketNotContainsChainID = logicerr.New("chain id not found in bucket")
) )
// NewBoltLocalOverrideDatabase returns storage wrapper for storing access policy engine // NewBoltLocalOverrideDatabase returns storage wrapper for storing access policy engine
@ -101,31 +102,30 @@ func (cs *boltLocalOverrideStorage) Close() error {
func getTargetBucket(tx *bbolt.Tx, name chain.Name, target policyengine.Target) (*bbolt.Bucket, error) { func getTargetBucket(tx *bbolt.Tx, name chain.Name, target policyengine.Target) (*bbolt.Bucket, error) {
cbucket := tx.Bucket(chainBucket) cbucket := tx.Bucket(chainBucket)
if cbucket == nil { if cbucket == nil {
return nil, ErrChainBucketNotFound return nil, ErrRootBucketNotFound
} }
nbucket := cbucket.Bucket([]byte(name)) nbucket := cbucket.Bucket([]byte(name))
if nbucket == nil { if nbucket == nil {
return nil, fmt.Errorf("global namespace %s: %w", name, ErrGlobalNamespaceBucketNotFound) return nil, fmt.Errorf("%w: %w: %s", policyengine.ErrChainNotFound, ErrGlobalNamespaceBucketNotFound, name)
} }
typeBucket := nbucket.Bucket([]byte{byte(target.Type)}) typeBucket := nbucket.Bucket([]byte{byte(target.Type)})
if typeBucket == nil { if typeBucket == nil {
return nil, fmt.Errorf("type bucket '%c': %w", target.Type, ErrTargetTypeBucketNotFound) return nil, fmt.Errorf("%w: %w: %c", policyengine.ErrChainNotFound, ErrTargetTypeBucketNotFound, target.Type)
} }
rbucket := typeBucket.Bucket([]byte(target.Name)) rbucket := typeBucket.Bucket([]byte(target.Name))
if rbucket == nil { if rbucket == nil {
return nil, fmt.Errorf("target name bucket %s: %w", target.Name, ErrTargetNameBucketNotFound) return nil, fmt.Errorf("%w: %w: %s", policyengine.ErrChainNotFound, ErrTargetNameBucketNotFound, target.Name)
} }
return rbucket, nil return rbucket, nil
} }
func getTargetBucketCreateIfEmpty(tx *bbolt.Tx, name chain.Name, target policyengine.Target) (*bbolt.Bucket, error) { func getTargetBucketCreateIfEmpty(tx *bbolt.Tx, name chain.Name, target policyengine.Target) (*bbolt.Bucket, error) {
cbucket := tx.Bucket(chainBucket) cbucket := tx.Bucket(chainBucket)
if cbucket == nil { if cbucket == nil {
return nil, ErrChainBucketNotFound return nil, ErrRootBucketNotFound
} }
nbucket := cbucket.Bucket([]byte(name)) nbucket := cbucket.Bucket([]byte(name))
@ -186,7 +186,7 @@ func (cs *boltLocalOverrideStorage) GetOverride(name chain.Name, target policyen
} }
serializedChain = rbuck.Get([]byte(chainID)) serializedChain = rbuck.Get([]byte(chainID))
if serializedChain == nil { if serializedChain == nil {
return ErrChainNotFound return fmt.Errorf("%w: %w: %s", policyengine.ErrChainNotFound, ErrBucketNotContainsChainID, chainID)
} }
serializedChain = slice.Copy(serializedChain) serializedChain = slice.Copy(serializedChain)
return nil return nil
@ -225,7 +225,7 @@ func (cs *boltLocalOverrideStorage) ListOverrides(name chain.Name, target policy
return nil return nil
}) })
}); err != nil { }); err != nil {
if errors.Is(err, ErrGlobalNamespaceBucketNotFound) || errors.Is(err, ErrTargetNameBucketNotFound) { if errors.Is(err, policyengine.ErrChainNotFound) {
return []*chain.Chain{}, nil return []*chain.Chain{}, nil
} }
return nil, err return nil, err
@ -243,6 +243,16 @@ func (cs *boltLocalOverrideStorage) ListOverrides(name chain.Name, target policy
func (cs *boltLocalOverrideStorage) DropAllOverrides(name chain.Name) error { func (cs *boltLocalOverrideStorage) DropAllOverrides(name chain.Name) error {
return cs.db.Update(func(tx *bbolt.Tx) error { return cs.db.Update(func(tx *bbolt.Tx) error {
cbucket := tx.Bucket(chainBucket)
if cbucket == nil {
return ErrRootBucketNotFound
}
nbucket := cbucket.Bucket([]byte(name))
if nbucket == nil {
return fmt.Errorf("%w: %w: global namespace %s", policyengine.ErrChainNotFound, ErrGlobalNamespaceBucketNotFound, name)
}
return tx.DeleteBucket([]byte(name)) return tx.DeleteBucket([]byte(name))
}) })
} }

View file

@ -128,12 +128,18 @@ func (s *Server) RemoveChainLocalOverride(_ context.Context, req *control.Remove
return nil, err return nil, err
} }
removed := true
if err = s.localOverrideStorage.LocalStorage().RemoveOverride(apechain.Ingress, target, apechain.ID(req.GetBody().GetChainId())); err != nil { if err = s.localOverrideStorage.LocalStorage().RemoveOverride(apechain.Ingress, target, apechain.ID(req.GetBody().GetChainId())); err != nil {
return nil, status.Error(getCodeByLocalStorageErr(err), err.Error()) code := getCodeByLocalStorageErr(err)
if code == codes.NotFound {
removed = false
} else {
return nil, status.Error(code, err.Error())
}
} }
resp := &control.RemoveChainLocalOverrideResponse{ resp := &control.RemoveChainLocalOverrideResponse{
Body: &control.RemoveChainLocalOverrideResponse_Body{ Body: &control.RemoveChainLocalOverrideResponse_Body{
Removed: true, Removed: removed,
}, },
} }
err = SignMessage(s.key, resp) err = SignMessage(s.key, resp)
@ -144,7 +150,7 @@ func (s *Server) RemoveChainLocalOverride(_ context.Context, req *control.Remove
} }
func getCodeByLocalStorageErr(err error) codes.Code { func getCodeByLocalStorageErr(err error) codes.Code {
if errors.Is(err, engine.ErrChainNotFound) { if errors.Is(err, engine.ErrChainNotFound) || errors.Is(err, engine.ErrChainNameNotFound) {
return codes.NotFound return codes.NotFound
} }
return codes.Internal return codes.Internal