From d9f523ee07ab131a89e0086c8571e7c84f215f00 Mon Sep 17 00:00:00 2001 From: Airat Arifullin Date: Fri, 9 Feb 2024 17:28:42 +0300 Subject: [PATCH] [#78] policy: Introduce ListTargets method for Policy contract * Introduce a new method ListTargets that lists targets by kind. * Slightly fix key mapping - also concatenate kind to prefix. * Write unit-tests. * Regenerate rpcclient. Signed-off-by: Airat Arifullin --- policy/config.yml | 1 + policy/policy_contract.go | 49 ++++++++++++++++++++++++++++---------- rpcclient/policy/client.go | 14 +++++++++++ tests/policy_test.go | 39 ++++++++++++++++++++++++++++++ 4 files changed, 91 insertions(+), 12 deletions(-) diff --git a/policy/config.yml b/policy/config.yml index cdde275..e0c28af 100644 --- a/policy/config.yml +++ b/policy/config.yml @@ -4,5 +4,6 @@ safemethods: - "listChains" - "getChain" - "listChainsByPrefix" + - "listTargets" - "iteratorChainsByPrefix" - "version" \ No newline at end of file diff --git a/policy/policy_contract.go b/policy/policy_contract.go index b73b743..64a131c 100644 --- a/policy/policy_contract.go +++ b/policy/policy_contract.go @@ -90,11 +90,15 @@ func storageKey(prefix Kind, counter int, name []byte) []byte { return append(key, name...) } +func mapKey(kind Kind, name []byte) []byte { + return append([]byte{mappingKeyPrefix, byte(kind)}, name...) +} + // mapToNumeric maps a name to a number. That allows to keep more space in // a storage key shortening long names. Short entity // names are also mapped to prevent collisions in the map. -func mapToNumeric(ctx storage.Context, name []byte) (mapped int, mappingExists bool) { - mKey := append([]byte{mappingKeyPrefix}, name...) +func mapToNumeric(ctx storage.Context, kind Kind, name []byte) (mapped int, mappingExists bool) { + mKey := mapKey(kind, name) numericID := storage.Get(ctx, mKey) if numericID == nil { return @@ -109,8 +113,8 @@ func mapToNumeric(ctx storage.Context, name []byte) (mapped int, mappingExists b // names are also mapped to prevent collisions in the map. // If a mapping cannot be found, then the method creates and returns it. // mapToNumericCreateIfNotExists is NOT applicable for a read-only context. -func mapToNumericCreateIfNotExists(ctx storage.Context, name []byte) int { - mKey := append([]byte{mappingKeyPrefix}, name...) +func mapToNumericCreateIfNotExists(ctx storage.Context, kind Kind, name []byte) int { + mKey := mapKey(kind, name) numericID := storage.Get(ctx, mKey) if numericID == nil { counter := storage.Get(ctx, counterKey).(int) @@ -126,7 +130,7 @@ func AddChain(entity Kind, entityName string, name []byte, chain []byte) { ctx := storage.GetContext() checkAuthorization(ctx) - entityNameBytes := mapToNumericCreateIfNotExists(ctx, []byte(entityName)) + entityNameBytes := mapToNumericCreateIfNotExists(ctx, entity, []byte(entityName)) key := storageKey(entity, entityNameBytes, name) storage.Put(ctx, key, chain) } @@ -134,7 +138,7 @@ func AddChain(entity Kind, entityName string, name []byte, chain []byte) { func GetChain(entity Kind, entityName string, name []byte) []byte { ctx := storage.GetReadOnlyContext() - entityNameBytes, exists := mapToNumeric(ctx, []byte(entityName)) + entityNameBytes, exists := mapToNumeric(ctx, entity, []byte(entityName)) if !exists { panic("not found") } @@ -152,29 +156,43 @@ func RemoveChain(entity Kind, entityName string, name []byte) { ctx := storage.GetContext() checkAuthorization(ctx) - entityNameBytes, exists := mapToNumeric(ctx, []byte(entityName)) + entityNameNum, exists := mapToNumeric(ctx, entity, []byte(entityName)) if !exists { return } - key := storageKey(entity, entityNameBytes, name) + key := storageKey(entity, entityNameNum, name) storage.Delete(ctx, key) + + // If no chains are left for the target, then remove the mapping. + prefix := append([]byte{byte(entity)}, common.ToFixedWidth64(entityNameNum)...) + it := storage.Find(ctx, prefix, storage.KeysOnly) + if !iterator.Next(it) { + storage.Delete(ctx, mapKey(entity, []byte(entityName))) + } } func RemoveChainsByPrefix(entity Kind, entityName string, name []byte) { ctx := storage.GetContext() checkAuthorization(ctx) - entityNameBytes, exists := mapToNumeric(ctx, []byte(entityName)) + entityNameNum, exists := mapToNumeric(ctx, entity, []byte(entityName)) if !exists { return } - key := storageKey(entity, entityNameBytes, name) + key := storageKey(entity, entityNameNum, name) it := storage.Find(ctx, key, storage.KeysOnly) for iterator.Next(it) { storage.Delete(ctx, iterator.Value(it).([]byte)) } + + // If no chains are left for the target, then remove the mapping. + prefix := append([]byte{byte(entity)}, common.ToFixedWidth64(entityNameNum)...) + it = storage.Find(ctx, prefix, storage.KeysOnly) + if !iterator.Next(it) { + storage.Delete(ctx, mapKey(entity, []byte(entityName))) + } } // ListChains lists all chains for the namespace by prefix. @@ -195,7 +213,7 @@ func ListChainsByPrefix(entity Kind, entityName string, prefix []byte) [][]byte result := [][]byte{} - entityNameBytes, exists := mapToNumeric(ctx, []byte(entityName)) + entityNameBytes, exists := mapToNumeric(ctx, entity, []byte(entityName)) if !exists { return result } @@ -211,7 +229,14 @@ func ListChainsByPrefix(entity Kind, entityName string, prefix []byte) [][]byte func IteratorChainsByPrefix(entity Kind, entityName string, prefix []byte) iterator.Iterator { ctx := storage.GetReadOnlyContext() - id, _ := mapToNumeric(ctx, []byte(entityName)) + id, _ := mapToNumeric(ctx, entity, []byte(entityName)) keyPrefix := storageKey(entity, id, prefix) return storage.Find(ctx, keyPrefix, storage.ValuesOnly) } + +// ListTargets iterates over targets for which rules are defined. +func ListTargets(entity Kind) iterator.Iterator { + ctx := storage.GetReadOnlyContext() + mKey := mapKey(entity, []byte{}) + return storage.Find(ctx, mKey, storage.KeysOnly|storage.RemovePrefix) +} diff --git a/rpcclient/policy/client.go b/rpcclient/policy/client.go index 2b99562..6a0d551 100644 --- a/rpcclient/policy/client.go +++ b/rpcclient/policy/client.go @@ -90,6 +90,20 @@ func (c *ContractReader) ListChainsByPrefix(entity *big.Int, entityName string, return unwrap.Array(c.invoker.Call(c.hash, "listChainsByPrefix", entity, entityName, prefix)) } +// ListTargets invokes `listTargets` method of contract. +func (c *ContractReader) ListTargets(entity *big.Int) (uuid.UUID, result.Iterator, error) { + return unwrap.SessionIterator(c.invoker.Call(c.hash, "listTargets", entity)) +} + +// ListTargetsExpanded is similar to ListTargets (uses the same contract +// method), but can be useful if the server used doesn't support sessions and +// doesn't expand iterators. It creates a script that will get the specified +// number of result items from the iterator right in the VM and return them to +// you. It's only limited by VM stack and GAS available for RPC invocations. +func (c *ContractReader) ListTargetsExpanded(entity *big.Int, _numOfIteratorItems int) ([]stackitem.Item, error) { + return unwrap.Array(c.invoker.CallAndExpandIterator(c.hash, "listTargets", _numOfIteratorItems, entity)) +} + // Version invokes `version` method of contract. func (c *ContractReader) Version() (*big.Int, error) { return unwrap.BigInt(c.invoker.Call(c.hash, "version")) diff --git a/tests/policy_test.go b/tests/policy_test.go index 8fa1891..781a308 100644 --- a/tests/policy_test.go +++ b/tests/policy_test.go @@ -39,19 +39,26 @@ func TestPolicy(t *testing.T) { p33 := []byte("chain33") e.Invoke(t, stackitem.Null{}, "addChain", policy.Namespace, "mynamespace", "ingress:123", p1) + checkTargets(t, e, policy.Namespace, [][]byte{[]byte("mynamespace")}) checkChains(t, e, "mynamespace", "", "ingress", [][]byte{p1}) checkChains(t, e, "mynamespace", "", "all", nil) e.Invoke(t, stackitem.Null{}, "addChain", policy.Container, "cnr1", "ingress:myrule2", p2) + checkTargets(t, e, policy.Namespace, [][]byte{[]byte("mynamespace")}) + checkTargets(t, e, policy.Container, [][]byte{[]byte("cnr1")}) checkChains(t, e, "mynamespace", "", "ingress", [][]byte{p1}) // Only namespace chains. checkChains(t, e, "mynamespace", "cnr1", "ingress", [][]byte{p1, p2}) checkChains(t, e, "mynamespace", "cnr1", "all", nil) // No chains attached to 'all'. checkChains(t, e, "mynamespace", "cnr2", "ingress", [][]byte{p1}) // Only namespace, no chains for the container. e.Invoke(t, stackitem.Null{}, "addChain", policy.Container, "cnr1", "ingress:myrule3", p3) + checkTargets(t, e, policy.Namespace, [][]byte{[]byte("mynamespace")}) + checkTargets(t, e, policy.Container, [][]byte{[]byte("cnr1")}) checkChains(t, e, "mynamespace", "cnr1", "ingress", [][]byte{p1, p2, p3}) e.Invoke(t, stackitem.Null{}, "addChain", policy.Container, "cnr1", "ingress:myrule3", p33) + checkTargets(t, e, policy.Namespace, [][]byte{[]byte("mynamespace")}) + checkTargets(t, e, policy.Container, [][]byte{[]byte("cnr1")}) checkChain(t, e, policy.Container, "cnr1", "ingress:myrule3", p33) checkChains(t, e, "mynamespace", "cnr1", "ingress", [][]byte{p1, p2, p33}) // Override chain. checkChainsByPrefix(t, e, policy.Container, "cnr1", "", [][]byte{p2, p33}) @@ -73,6 +80,21 @@ func TestPolicy(t *testing.T) { // Remove by prefix. e.Invoke(t, stackitem.Null{}, "removeChainsByPrefix", policy.Container, "cnr1", "ingress") checkChains(t, e, "mynamespace", "cnr1", "ingress", nil) + + // Remove by prefix. + e.Invoke(t, stackitem.Null{}, "removeChainsByPrefix", policy.Container, "cnr1", "ingress") + checkChains(t, e, "mynamespace", "cnr1", "ingress", nil) + + checkTargets(t, e, policy.Namespace, [][]byte{}) + checkTargets(t, e, policy.Container, [][]byte{}) + }) + + t.Run("add again after removal", func(t *testing.T) { + e.Invoke(t, stackitem.Null{}, "addChain", policy.Namespace, "mynamespace", "ingress:123", p1) + e.Invoke(t, stackitem.Null{}, "addChain", policy.Container, "cnr1", "ingress:myrule3", p33) + + checkTargets(t, e, policy.Namespace, [][]byte{[]byte("mynamespace")}) + checkTargets(t, e, policy.Container, [][]byte{[]byte("cnr1")}) }) } @@ -148,3 +170,20 @@ func checkChain(t *testing.T, e *neotest.ContractInvoker, kind byte, entityName, require.True(t, bytes.Equal(expected, s.Pop().Bytes())) } + +func checkTargets(t *testing.T, e *neotest.ContractInvoker, kind byte, expected [][]byte) { + s, err := e.TestInvoke(t, "listTargets", kind) + require.NoError(t, err) + + require.NotEqual(t, 0, s.Len(), "stack is empty") + + iteratorItem := s.Pop().Value().(*storage.Iterator) + targets := iteratorToArray(iteratorItem) + require.Equal(t, len(expected), len(targets)) + + for i := range expected { + bytesTargets, err := targets[i].TryBytes() + require.NoError(t, err) + require.Equal(t, expected[i], bytesTargets) + } +}