From 8c673ee4f4af56095b62d587bb8a70ed0d920efb Mon Sep 17 00:00:00 2001 From: Dmitrii Stepanov Date: Fri, 8 Dec 2023 14:07:39 +0300 Subject: [PATCH] [#21] chain: Allow to return first match result Signed-off-by: Dmitrii Stepanov --- iam/policy_test.go | 1 + pkg/chain/chain.go | 41 ++++++++++++++++++++++++++++++++++++++++ pkg/chain/chain_test.go | 42 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 84 insertions(+) diff --git a/iam/policy_test.go b/iam/policy_test.go index 046035d..c397f33 100644 --- a/iam/policy_test.go +++ b/iam/policy_test.go @@ -478,6 +478,7 @@ func TestProcessDenyFirst(t *testing.T) { identityNativePolicy, err := ConvertToNativeChain(identityPolicy, mockResolver) require.NoError(t, err) + identityNativePolicy.MatchType = chain.MatchTypeFirstMatch resourceNativePolicy, err := ConvertToNativeChain(resourcePolicy, mockResolver) require.NoError(t, err) diff --git a/pkg/chain/chain.go b/pkg/chain/chain.go index a98ef43..d241ca8 100644 --- a/pkg/chain/chain.go +++ b/pkg/chain/chain.go @@ -12,10 +12,22 @@ import ( // ID is the ID of rule chain. type ID string +// MatchType is the match type for chain rules. +type MatchType uint8 + +const ( + // MatchTypeDenyPriority rejects the request if any `Deny` is specified. + MatchTypeDenyPriority MatchType = 0 + // MatchTypeFirstMatch returns the first rule action matched to the request. + MatchTypeFirstMatch MatchType = 1 +) + type Chain struct { ID ID Rules []Rule + + MatchType MatchType } func (c *Chain) Bytes() []byte { @@ -214,6 +226,17 @@ func (r *Rule) matchAll(obj resource.Request) (status Status, matched bool) { } func (c *Chain) Match(req resource.Request) (status Status, matched bool) { + switch c.MatchType { + case MatchTypeDenyPriority: + return c.denyPriority(req) + case MatchTypeFirstMatch: + return c.firstMatch(req) + default: + panic(fmt.Sprintf("unknown MatchType %d", c.MatchType)) + } +} + +func (c *Chain) firstMatch(req resource.Request) (status Status, matched bool) { for i := range c.Rules { status, matched := c.Rules[i].Match(req) if matched { @@ -222,3 +245,21 @@ func (c *Chain) Match(req resource.Request) (status Status, matched bool) { } return NoRuleFound, false } + +func (c *Chain) denyPriority(req resource.Request) (status Status, matched bool) { + var allowFound bool + for i := range c.Rules { + status, matched := c.Rules[i].Match(req) + if !matched { + continue + } + if status != Allow { + return status, true + } + allowFound = true + } + if allowFound { + return Allow, true + } + return NoRuleFound, false +} diff --git a/pkg/chain/chain_test.go b/pkg/chain/chain_test.go index 2c5ae9c..911daa4 100644 --- a/pkg/chain/chain_test.go +++ b/pkg/chain/chain_test.go @@ -3,11 +3,14 @@ package chain import ( "testing" + "git.frostfs.info/TrueCloudLab/policy-engine/pkg/resource/testutil" + "git.frostfs.info/TrueCloudLab/policy-engine/schema/native" "github.com/stretchr/testify/require" ) func TestEncodeDecode(t *testing.T) { expected := Chain{ + MatchType: MatchTypeFirstMatch, Rules: []Rule{ { Status: Allow, @@ -31,3 +34,42 @@ func TestEncodeDecode(t *testing.T) { require.NoError(t, actual.DecodeBytes(data)) require.Equal(t, expected, actual) } + +func TestReturnFirstMatch(t *testing.T) { + ch := Chain{ + Rules: []Rule{ + { + Status: Allow, + Actions: Actions{Names: []string{ + native.MethodPutObject, + }}, + Resources: Resources{Names: []string{native.ResourceFormatRootContainers}}, + Condition: []Condition{}, + }, + { + Status: AccessDenied, + Actions: Actions{Names: []string{ + native.MethodPutObject, + }}, + Resources: Resources{Names: []string{native.ResourceFormatRootContainers}}, + Condition: []Condition{}, + }, + }, + } + + resource := testutil.NewResource(native.ResourceFormatRootContainers, nil) + request := testutil.NewRequest(native.MethodPutObject, resource, nil) + + t.Run("default match", func(t *testing.T) { + st, found := ch.Match(request) + require.True(t, found) + require.Equal(t, AccessDenied, st) + }) + + t.Run("return first match", func(t *testing.T) { + ch.MatchType = MatchTypeFirstMatch + st, found := ch.Match(request) + require.True(t, found) + require.Equal(t, Allow, st) + }) +}