From 1d4deffb7d0ac44e4526959b3bdceb1483bc9c79 Mon Sep 17 00:00:00 2001 From: Airat Arifullin Date: Fri, 9 Feb 2024 17:28:42 +0300 Subject: [PATCH] [#78] policy: Introduce ListTargets method for Policy contract * Introduce a new method ListTargets that lists targets by kind. * Slightly fix key mapping - also concatenate kind to prefix. * Write unit-tests. * Regenerate rpcclient. Signed-off-by: Airat Arifullin --- policy/config.yml | 1 + policy/doc.go | 1 + policy/policy_contract.go | 89 +++++++++++++++++++++++++++++++++----- rpcclient/policy/client.go | 5 +++ tests/policy_test.go | 28 ++++++++++++ 5 files changed, 112 insertions(+), 12 deletions(-) diff --git a/policy/config.yml b/policy/config.yml index cdde275..e0c28af 100644 --- a/policy/config.yml +++ b/policy/config.yml @@ -4,5 +4,6 @@ safemethods: - "listChains" - "getChain" - "listChainsByPrefix" + - "listTargets" - "iteratorChainsByPrefix" - "version" \ No newline at end of file diff --git a/policy/doc.go b/policy/doc.go index bfde740..6083d48 100644 --- a/policy/doc.go +++ b/policy/doc.go @@ -8,6 +8,7 @@ | 'n' + uint16(len(namespace)) + namespace | []byte | Container chain | | 'm' + entity name (namespace/container) | []byte | Mapped name to an encoded number | | 'counter' | uint64 | Integer counter used for mapping | + | 'rc' | uint64 | Actvie rules count for a target | */ diff --git a/policy/policy_contract.go b/policy/policy_contract.go index b73b743..e6321f1 100644 --- a/policy/policy_contract.go +++ b/policy/policy_contract.go @@ -23,8 +23,9 @@ const ( ) const ( - mappingKeyPrefix = 'm' - counterKey = "counter" + activeRulesCounterKeyPrefix = "rc" + mappingKeyPrefix = 'm' + counterKey = "counter" ) const ( @@ -90,11 +91,15 @@ func storageKey(prefix Kind, counter int, name []byte) []byte { return append(key, name...) } +func numericMapKeyPrefix(kind Kind) []byte { + return []byte{mappingKeyPrefix, byte(kind)} +} + // mapToNumeric maps a name to a number. That allows to keep more space in // a storage key shortening long names. Short entity // names are also mapped to prevent collisions in the map. -func mapToNumeric(ctx storage.Context, name []byte) (mapped int, mappingExists bool) { - mKey := append([]byte{mappingKeyPrefix}, name...) +func mapToNumeric(ctx storage.Context, kind Kind, name []byte) (mapped int, mappingExists bool) { + mKey := append(numericMapKeyPrefix(kind), name...) numericID := storage.Get(ctx, mKey) if numericID == nil { return @@ -109,8 +114,8 @@ func mapToNumeric(ctx storage.Context, name []byte) (mapped int, mappingExists b // names are also mapped to prevent collisions in the map. // If a mapping cannot be found, then the method creates and returns it. // mapToNumericCreateIfNotExists is NOT applicable for a read-only context. -func mapToNumericCreateIfNotExists(ctx storage.Context, name []byte) int { - mKey := append([]byte{mappingKeyPrefix}, name...) +func mapToNumericCreateIfNotExists(ctx storage.Context, kind Kind, name []byte) int { + mKey := append(numericMapKeyPrefix(kind), name...) numericID := storage.Get(ctx, mKey) if numericID == nil { counter := storage.Get(ctx, counterKey).(int) @@ -126,15 +131,17 @@ func AddChain(entity Kind, entityName string, name []byte, chain []byte) { ctx := storage.GetContext() checkAuthorization(ctx) - entityNameBytes := mapToNumericCreateIfNotExists(ctx, []byte(entityName)) + entityNameBytes := mapToNumericCreateIfNotExists(ctx, entity, []byte(entityName)) key := storageKey(entity, entityNameBytes, name) storage.Put(ctx, key, chain) + + incTargetRulesCount(ctx, append(numericMapKeyPrefix(entity), entityName...)) } func GetChain(entity Kind, entityName string, name []byte) []byte { ctx := storage.GetReadOnlyContext() - entityNameBytes, exists := mapToNumeric(ctx, []byte(entityName)) + entityNameBytes, exists := mapToNumeric(ctx, entity, []byte(entityName)) if !exists { panic("not found") } @@ -152,20 +159,21 @@ func RemoveChain(entity Kind, entityName string, name []byte) { ctx := storage.GetContext() checkAuthorization(ctx) - entityNameBytes, exists := mapToNumeric(ctx, []byte(entityName)) + entityNameBytes, exists := mapToNumeric(ctx, entity, []byte(entityName)) if !exists { return } key := storageKey(entity, entityNameBytes, name) storage.Delete(ctx, key) + decTargetRulesCount(ctx, append(numericMapKeyPrefix(entity), entityName...)) } func RemoveChainsByPrefix(entity Kind, entityName string, name []byte) { ctx := storage.GetContext() checkAuthorization(ctx) - entityNameBytes, exists := mapToNumeric(ctx, []byte(entityName)) + entityNameBytes, exists := mapToNumeric(ctx, entity, []byte(entityName)) if !exists { return } @@ -174,6 +182,7 @@ func RemoveChainsByPrefix(entity Kind, entityName string, name []byte) { it := storage.Find(ctx, key, storage.KeysOnly) for iterator.Next(it) { storage.Delete(ctx, iterator.Value(it).([]byte)) + decTargetRulesCount(ctx, key) } } @@ -195,7 +204,7 @@ func ListChainsByPrefix(entity Kind, entityName string, prefix []byte) [][]byte result := [][]byte{} - entityNameBytes, exists := mapToNumeric(ctx, []byte(entityName)) + entityNameBytes, exists := mapToNumeric(ctx, entity, []byte(entityName)) if !exists { return result } @@ -211,7 +220,63 @@ func ListChainsByPrefix(entity Kind, entityName string, prefix []byte) [][]byte func IteratorChainsByPrefix(entity Kind, entityName string, prefix []byte) iterator.Iterator { ctx := storage.GetReadOnlyContext() - id, _ := mapToNumeric(ctx, []byte(entityName)) + id, _ := mapToNumeric(ctx, entity, []byte(entityName)) keyPrefix := storageKey(entity, id, prefix) return storage.Find(ctx, keyPrefix, storage.ValuesOnly) } + +// ListTargets lists targets for which rules are defined by kind. +func ListTargets(entity Kind) [][]byte { + ctx := storage.GetReadOnlyContext() + + targetNames := [][]byte{} + + keyPrefix := numericMapKeyPrefix(entity) + it := storage.Find(ctx, keyPrefix, storage.KeysOnly) + + for iterator.Next(it) { + key := iterator.Value(it).([]byte) + entityName := key[len(keyPrefix):] + + if getTargetRulesCount(ctx, key) == 0 { + continue + } + targetNames = append(targetNames, entityName) + } + + return targetNames +} + +func incTargetRulesCount(ctx storage.Context, key []byte) { + ruleCounterKey := append([]byte(activeRulesCounterKeyPrefix), key...) + counter := storage.Get(ctx, ruleCounterKey) + if counter == nil { + storage.Put(ctx, ruleCounterKey, 1) + } else { + val := counter.(int) + val++ + storage.Put(ctx, ruleCounterKey, val) + } +} + +func decTargetRulesCount(ctx storage.Context, key []byte) { + ruleCounterKey := append([]byte(activeRulesCounterKeyPrefix), key...) + counter := storage.Get(ctx, ruleCounterKey) + if counter != nil { + val := counter.(int) + if val == 0 { + return + } + val-- + storage.Put(ctx, ruleCounterKey, val) + } +} + +func getTargetRulesCount(ctx storage.Context, key []byte) int { + ruleCounterKey := append([]byte(activeRulesCounterKeyPrefix), key...) + counter := storage.Get(ctx, ruleCounterKey) + if counter != nil { + return counter.(int) + } + return 0 +} diff --git a/rpcclient/policy/client.go b/rpcclient/policy/client.go index 2b99562..ba677dd 100644 --- a/rpcclient/policy/client.go +++ b/rpcclient/policy/client.go @@ -90,6 +90,11 @@ func (c *ContractReader) ListChainsByPrefix(entity *big.Int, entityName string, return unwrap.Array(c.invoker.Call(c.hash, "listChainsByPrefix", entity, entityName, prefix)) } +// ListTargets invokes `listTargets` method of contract. +func (c *ContractReader) ListTargets(entity *big.Int) ([]stackitem.Item, error) { + return unwrap.Array(c.invoker.Call(c.hash, "listTargets", entity)) +} + // Version invokes `version` method of contract. func (c *ContractReader) Version() (*big.Int, error) { return unwrap.BigInt(c.invoker.Call(c.hash, "version")) diff --git a/tests/policy_test.go b/tests/policy_test.go index 8fa1891..5d1aa6f 100644 --- a/tests/policy_test.go +++ b/tests/policy_test.go @@ -39,19 +39,26 @@ func TestPolicy(t *testing.T) { p33 := []byte("chain33") e.Invoke(t, stackitem.Null{}, "addChain", policy.Namespace, "mynamespace", "ingress:123", p1) + checkTargets(t, e, policy.Namespace, [][]byte{[]byte("mynamespace")}) checkChains(t, e, "mynamespace", "", "ingress", [][]byte{p1}) checkChains(t, e, "mynamespace", "", "all", nil) e.Invoke(t, stackitem.Null{}, "addChain", policy.Container, "cnr1", "ingress:myrule2", p2) + checkTargets(t, e, policy.Namespace, [][]byte{[]byte("mynamespace")}) + checkTargets(t, e, policy.Container, [][]byte{[]byte("cnr1")}) checkChains(t, e, "mynamespace", "", "ingress", [][]byte{p1}) // Only namespace chains. checkChains(t, e, "mynamespace", "cnr1", "ingress", [][]byte{p1, p2}) checkChains(t, e, "mynamespace", "cnr1", "all", nil) // No chains attached to 'all'. checkChains(t, e, "mynamespace", "cnr2", "ingress", [][]byte{p1}) // Only namespace, no chains for the container. e.Invoke(t, stackitem.Null{}, "addChain", policy.Container, "cnr1", "ingress:myrule3", p3) + checkTargets(t, e, policy.Namespace, [][]byte{[]byte("mynamespace")}) + checkTargets(t, e, policy.Container, [][]byte{[]byte("cnr1")}) checkChains(t, e, "mynamespace", "cnr1", "ingress", [][]byte{p1, p2, p3}) e.Invoke(t, stackitem.Null{}, "addChain", policy.Container, "cnr1", "ingress:myrule3", p33) + checkTargets(t, e, policy.Namespace, [][]byte{[]byte("mynamespace")}) + checkTargets(t, e, policy.Container, [][]byte{[]byte("cnr1")}) checkChain(t, e, policy.Container, "cnr1", "ingress:myrule3", p33) checkChains(t, e, "mynamespace", "cnr1", "ingress", [][]byte{p1, p2, p33}) // Override chain. checkChainsByPrefix(t, e, policy.Container, "cnr1", "", [][]byte{p2, p33}) @@ -73,6 +80,9 @@ func TestPolicy(t *testing.T) { // Remove by prefix. e.Invoke(t, stackitem.Null{}, "removeChainsByPrefix", policy.Container, "cnr1", "ingress") checkChains(t, e, "mynamespace", "cnr1", "ingress", nil) + + checkTargets(t, e, policy.Namespace, [][]byte{}) + checkTargets(t, e, policy.Container, [][]byte{}) }) } @@ -148,3 +158,21 @@ func checkChain(t *testing.T, e *neotest.ContractInvoker, kind byte, entityName, require.True(t, bytes.Equal(expected, s.Pop().Bytes())) } + +func checkTargets(t *testing.T, e *neotest.ContractInvoker, kind byte, expected [][]byte) { + s, err := e.TestInvoke(t, "listTargets", kind) + require.NoError(t, err) + + var targets [][]byte + for _, item := range s.Pop().Array() { + target, err := item.TryBytes() + require.NoError(t, err) + targets = append(targets, target) + } + + require.Len(t, targets, len(expected)) + + for _, exp := range expected { + require.Contains(t, targets, exp) + } +}