diff --git a/pkg/engine/inmemory/inmemory_test.go b/pkg/engine/inmemory/inmemory_test.go index 1155488..628a81c 100644 --- a/pkg/engine/inmemory/inmemory_test.go +++ b/pkg/engine/inmemory/inmemory_test.go @@ -245,6 +245,48 @@ func TestInmemory(t *testing.T) { require.False(t, ok) }) }) + + t.Run("remove all", func(t *testing.T) { + s := NewInMemoryLocalOverrides() + _, _, err := s.MorphRuleChainStorage().AddMorphRuleChain(chain.Ingress, engine.NamespaceTarget(namespace), &chain.Chain{ + Rules: []chain.Rule{ + { + Status: chain.AccessDenied, + Actions: chain.Actions{Inverted: true, Names: []string{"native::object::get"}}, + Resources: chain.Resources{Inverted: true, Names: []string{object}}, + }, + }, + }) + require.NoError(t, err) + _, _, err = s.MorphRuleChainStorage().AddMorphRuleChain(chain.Ingress, engine.NamespaceTarget(namespace2), &chain.Chain{ + Rules: []chain.Rule{ + { + Status: chain.Allow, + Actions: chain.Actions{Inverted: true, Names: []string{"native::object::get"}}, + Resources: chain.Resources{Inverted: true, Names: []string{object}}, + }, + }, + }) + require.NoError(t, err) + _, _, err = s.MorphRuleChainStorage().AddMorphRuleChain(chain.Ingress, engine.NamespaceTarget(namespace2), &chain.Chain{ + Rules: []chain.Rule{ + { + Status: chain.AccessDenied, + Actions: chain.Actions{Inverted: true, Names: []string{"native::object::get"}}, + Resources: chain.Resources{Inverted: true, Names: []string{object}}, + }, + }, + }) + require.NoError(t, err) + _, _, err = s.MorphRuleChainStorage().RemoveMorphRuleChainsByTarget(chain.Ingress, engine.NamespaceTarget(namespace2)) + require.NoError(t, err) + chains, err := s.MorphRuleChainStorage().ListMorphRuleChains(chain.Ingress, engine.NamespaceTarget(namespace2)) + require.NoError(t, err) + require.Equal(t, 0, len(chains)) + chains, err = s.MorphRuleChainStorage().ListMorphRuleChains(chain.Ingress, engine.NamespaceTarget(namespace)) + require.NoError(t, err) + require.Equal(t, 1, len(chains)) + }) } func itemStacksEqual(t *testing.T, got []stackitem.Item, expected []stackitem.Item) { diff --git a/pkg/engine/inmemory/local_storage.go b/pkg/engine/inmemory/local_storage.go index 9cbeed9..276b495 100644 --- a/pkg/engine/inmemory/local_storage.go +++ b/pkg/engine/inmemory/local_storage.go @@ -115,6 +115,24 @@ func (s *inmemoryLocalStorage) RemoveOverride(name chain.Name, target engine.Tar return engine.ErrChainNotFound } +func (s *inmemoryLocalStorage) RemoveOverridesByTarget(name chain.Name, target engine.Target) error { + s.guard.Lock() + defer s.guard.Unlock() + + if _, ok := s.nameToResourceChains[name]; !ok { + return engine.ErrChainNameNotFound + } + if target.Name == "" { + target.Name = "root" + } + _, ok := s.nameToResourceChains[name][target] + if ok { + delete(s.nameToResourceChains[name], target) + return nil + } + return engine.ErrResourceNotFound +} + func (s *inmemoryLocalStorage) ListOverrides(name chain.Name, target engine.Target) ([]*chain.Chain, error) { s.guard.RLock() defer s.guard.RUnlock() diff --git a/pkg/engine/inmemory/local_storage_test.go b/pkg/engine/inmemory/local_storage_test.go index c6ad0c9..f21d94d 100644 --- a/pkg/engine/inmemory/local_storage_test.go +++ b/pkg/engine/inmemory/local_storage_test.go @@ -110,6 +110,52 @@ func TestRemoveOverride(t *testing.T) { require.True(t, ok) require.Len(t, resourceChains, 0) }) + + t.Run("remove by target", func(t *testing.T) { + inmem := testInmemLocalStorage() + t0 := engine.ContainerTarget("name0") + t1 := engine.ContainerTarget("name1") + inmem.AddOverride(chain.Ingress, t0, &chain.Chain{ + ID: chain.ID(chainID), + Rules: []chain.Rule{ + { + Status: chain.AccessDenied, + Actions: chain.Actions{Names: []string{"native::object::delete"}}, + Resources: chain.Resources{Names: []string{"native::object::*"}}, + }, + }, + }) + inmem.AddOverride(chain.Ingress, t0, &chain.Chain{ + ID: chain.ID(chainID), + Rules: []chain.Rule{ + { + Status: chain.Allow, + Actions: chain.Actions{Names: []string{"native::object::delete"}}, + Resources: chain.Resources{Names: []string{"native::object::*"}}, + }, + }, + }) + inmem.AddOverride(chain.Ingress, t1, &chain.Chain{ + ID: chain.ID(chainID), + Rules: []chain.Rule{ + { + Status: chain.Allow, + Actions: chain.Actions{Names: []string{"native::object::delete"}}, + Resources: chain.Resources{Names: []string{"native::object::*"}}, + }, + }, + }) + + err := inmem.RemoveOverridesByTarget(chain.Ingress, t0) + require.NoError(t, err) + + ingressChains, ok := inmem.nameToResourceChains[chain.Ingress] + require.True(t, ok) + require.Len(t, ingressChains, 1) + resourceChains, ok := ingressChains[t1] + require.True(t, ok) + require.Len(t, resourceChains, 1) + }) } func TestGetOverride(t *testing.T) { diff --git a/pkg/engine/inmemory/morph_storage.go b/pkg/engine/inmemory/morph_storage.go index e666e2e..53922a6 100644 --- a/pkg/engine/inmemory/morph_storage.go +++ b/pkg/engine/inmemory/morph_storage.go @@ -10,23 +10,19 @@ import ( ) type inmemoryMorphRuleChainStorage struct { - nameToNamespaceChains engine.LocalOverrideStorage - nameToContainerChains engine.LocalOverrideStorage + storage engine.LocalOverrideStorage } func NewInmemoryMorphRuleChainStorage() engine.MorphRuleChainStorage { return &inmemoryMorphRuleChainStorage{ - nameToNamespaceChains: NewInmemoryLocalStorage(), - nameToContainerChains: NewInmemoryLocalStorage(), + storage: NewInmemoryLocalStorage(), } } func (s *inmemoryMorphRuleChainStorage) AddMorphRuleChain(name chain.Name, target engine.Target, c *chain.Chain) (_ util.Uint256, _ uint32, err error) { switch target.Type { - case engine.Namespace: - _, err = s.nameToNamespaceChains.AddOverride(name, target, c) - case engine.Container: - _, err = s.nameToContainerChains.AddOverride(name, target, c) + case engine.Namespace, engine.Container: + _, err = s.storage.AddOverride(name, target, c) default: err = engine.ErrUnknownTarget } @@ -35,10 +31,18 @@ func (s *inmemoryMorphRuleChainStorage) AddMorphRuleChain(name chain.Name, targe func (s *inmemoryMorphRuleChainStorage) RemoveMorphRuleChain(name chain.Name, target engine.Target, chainID chain.ID) (_ util.Uint256, _ uint32, err error) { switch target.Type { - case engine.Namespace: - err = s.nameToNamespaceChains.RemoveOverride(name, target, chainID) - case engine.Container: - err = s.nameToContainerChains.RemoveOverride(name, target, chainID) + case engine.Namespace, engine.Container: + err = s.storage.RemoveOverride(name, target, chainID) + default: + err = engine.ErrUnknownTarget + } + return +} + +func (s *inmemoryMorphRuleChainStorage) RemoveMorphRuleChainsByTarget(name chain.Name, target engine.Target) (_ util.Uint256, _ uint32, err error) { + switch target.Type { + case engine.Namespace, engine.Container: + err = s.storage.RemoveOverridesByTarget(name, target) default: err = engine.ErrUnknownTarget } @@ -47,10 +51,8 @@ func (s *inmemoryMorphRuleChainStorage) RemoveMorphRuleChain(name chain.Name, ta func (s *inmemoryMorphRuleChainStorage) ListMorphRuleChains(name chain.Name, target engine.Target) ([]*chain.Chain, error) { switch target.Type { - case engine.Namespace: - return s.nameToNamespaceChains.ListOverrides(name, target) - case engine.Container: - return s.nameToContainerChains.ListOverrides(name, target) + case engine.Namespace, engine.Container: + return s.storage.ListOverrides(name, target) default: } return nil, engine.ErrUnknownTarget @@ -72,7 +74,7 @@ func (s *inmemoryMorphRuleChainStorage) ListTargetsIterator(targetType engine.Ta // Listing targets may look bizarre, because inmemory rule chain storage use inmemory local overrides where // targets are listed by chain names. var targets []engine.Target - targets, err = s.nameToNamespaceChains.ListOverrideDefinedTargets(chain.Ingress) + targets, err = s.storage.ListOverrideDefinedTargets(chain.Ingress) if err != nil { return } @@ -80,7 +82,7 @@ func (s *inmemoryMorphRuleChainStorage) ListTargetsIterator(targetType engine.Ta it.Values = append(it.Values, stackitem.NewByteArray([]byte(t.Name))) } - targets, err = s.nameToNamespaceChains.ListOverrideDefinedTargets(chain.S3) + targets, err = s.storage.ListOverrideDefinedTargets(chain.S3) if err != nil { return } @@ -89,7 +91,7 @@ func (s *inmemoryMorphRuleChainStorage) ListTargetsIterator(targetType engine.Ta } case engine.Container: var targets []engine.Target - targets, err = s.nameToContainerChains.ListOverrideDefinedTargets(chain.Ingress) + targets, err = s.storage.ListOverrideDefinedTargets(chain.Ingress) if err != nil { return } @@ -97,7 +99,7 @@ func (s *inmemoryMorphRuleChainStorage) ListTargetsIterator(targetType engine.Ta it.Values = append(it.Values, stackitem.NewByteArray([]byte(t.Name))) } - targets, err = s.nameToContainerChains.ListOverrideDefinedTargets(chain.S3) + targets, err = s.storage.ListOverrideDefinedTargets(chain.S3) if err != nil { return } diff --git a/pkg/engine/interface.go b/pkg/engine/interface.go index 2c34e44..71c89e4 100644 --- a/pkg/engine/interface.go +++ b/pkg/engine/interface.go @@ -23,6 +23,8 @@ type LocalOverrideStorage interface { RemoveOverride(name chain.Name, target Target, chainID chain.ID) error + RemoveOverridesByTarget(name chain.Name, target Target) error + ListOverrides(name chain.Name, target Target) ([]*chain.Chain, error) DropAllOverrides(name chain.Name) error @@ -118,6 +120,9 @@ type MorphRuleChainStorage interface { // RemoveMorphRuleChain removes a chain rule to the policy contract and returns transaction hash, VUB and error. RemoveMorphRuleChain(name chain.Name, target Target, chainID chain.ID) (util.Uint256, uint32, error) + // RemoveMorphRuleChainsByTarget removes all chains by target and returns transaction hash, VUB and error. + RemoveMorphRuleChainsByTarget(name chain.Name, target Target) (util.Uint256, uint32, error) + SetAdmin(addr util.Uint160) (util.Uint256, uint32, error) } diff --git a/pkg/morph/policy/policy_contract_storage.go b/pkg/morph/policy/policy_contract_storage.go index 65d36ac..3d18f56 100644 --- a/pkg/morph/policy/policy_contract_storage.go +++ b/pkg/morph/policy/policy_contract_storage.go @@ -86,6 +86,18 @@ func (s *ContractStorage) RemoveMorphRuleChain(name chain.Name, target engine.Ta return } +func (s *ContractStorage) RemoveMorphRuleChainsByTarget(name chain.Name, target engine.Target) (txHash util.Uint256, vub uint32, err error) { + var kind policy.Kind + kind, err = policyKind(target.Type) + if err != nil { + return + } + fullName := prefixedChainName(name, nil) + + txHash, vub, err = s.contractInterface.RemoveChainsByPrefix(big.NewInt(int64(kind)), target.Name, fullName) + return +} + func (s *ContractStorage) ListMorphRuleChains(name chain.Name, target engine.Target) ([]*chain.Chain, error) { kind, err := policyKind(target.Type) if err != nil {