From e5b7e1f763945c194ea5b04acf8724bf7cbd4d3c Mon Sep 17 00:00:00 2001 From: Marina Biryukova Date: Mon, 1 Apr 2024 17:27:45 +0300 Subject: [PATCH] [#60] chain: Support numeric conditions Signed-off-by: Marina Biryukova --- iam/converter.go | 30 +++++++++- iam/converter_test.go | 60 +++++++++++++++++++ pkg/chain/chain.go | 115 ++++++++++++++++++++++++++----------- pkg/chain/chain_test.go | 87 +++++++++++++++++++++++++++- pkg/engine/chain_router.go | 18 +++--- schema/s3/consts.go | 1 + 6 files changed, 265 insertions(+), 46 deletions(-) diff --git a/iam/converter.go b/iam/converter.go index bd3483e..b39e044 100644 --- a/iam/converter.go +++ b/iam/converter.go @@ -219,8 +219,7 @@ func getConditionTypeAndConverter(op string) (chain.ConditionType, convertFuncti return 0, nil, fmt.Errorf("unsupported condition operator: '%s'", op) } case strings.HasPrefix(op, "Numeric"): - // TODO - return 0, nil, fmt.Errorf("currently nummeric conditions unsupported: '%s'", op) + return numericConditionTypeAndConverter(op) case strings.HasPrefix(op, "Date"): switch op { case CondDateEquals: @@ -255,12 +254,39 @@ func getConditionTypeAndConverter(op string) (chain.ConditionType, convertFuncti } } +func numericConditionTypeAndConverter(op string) (chain.ConditionType, convertFunction, error) { + switch op { + case CondNumericEquals: + return chain.CondNumericEquals, numericConvertFunction, nil + case CondNumericNotEquals: + return chain.CondNumericNotEquals, numericConvertFunction, nil + case CondNumericLessThan: + return chain.CondNumericLessThan, numericConvertFunction, nil + case CondNumericLessThanEquals: + return chain.CondNumericLessThanEquals, numericConvertFunction, nil + case CondNumericGreaterThan: + return chain.CondNumericGreaterThan, numericConvertFunction, nil + case CondNumericGreaterThanEquals: + return chain.CondNumericGreaterThanEquals, numericConvertFunction, nil + default: + return 0, nil, fmt.Errorf("unsupported condition operator: '%s'", op) + } +} + type convertFunction func(string) (string, error) func noConvertFunction(val string) (string, error) { return val, nil } +func numericConvertFunction(val string) (string, error) { + if _, err := strconv.ParseInt(val, 10, 64); err == nil { + return val, nil + } + + return "", fmt.Errorf("invalid numeric value: '%s'", val) +} + func dateConvertFunction(val string) (string, error) { if _, err := strconv.ParseInt(val, 10, 64); err == nil { return val, nil diff --git a/iam/converter_test.go b/iam/converter_test.go index 97e1cd7..60723cc 100644 --- a/iam/converter_test.go +++ b/iam/converter_test.go @@ -385,6 +385,12 @@ func TestConvertToChainCondition(t *testing.T) { CondArnLike: {condKeyAWSPrincipalARN: {principal}}, CondArnNotEquals: {"key18": {"val18"}}, CondArnNotLike: {"key19": {"val19"}}, + CondNumericEquals: {"key20": {"20"}}, + CondNumericNotEquals: {"key21": {"21"}}, + CondNumericLessThan: {"key22": {"22"}}, + CondNumericLessThanEquals: {"key23": {"23"}}, + CondNumericGreaterThan: {"key24": {"24"}}, + CondNumericGreaterThanEquals: {"key25": {"25"}}, } expectedCondition := []GroupedConditions{ @@ -549,11 +555,65 @@ func TestConvertToChainCondition(t *testing.T) { Value: "val19", }}, }, + { + Conditions: []chain.Condition{{ + Op: chain.CondNumericEquals, + Object: chain.ObjectRequest, + Key: "key20", + Value: "20", + }}, + }, + { + Conditions: []chain.Condition{{ + Op: chain.CondNumericNotEquals, + Object: chain.ObjectRequest, + Key: "key21", + Value: "21", + }}, + }, + { + Conditions: []chain.Condition{{ + Op: chain.CondNumericLessThan, + Object: chain.ObjectRequest, + Key: "key22", + Value: "22", + }}, + }, + { + Conditions: []chain.Condition{{ + Op: chain.CondNumericLessThanEquals, + Object: chain.ObjectRequest, + Key: "key23", + Value: "23", + }}, + }, + { + Conditions: []chain.Condition{{ + Op: chain.CondNumericGreaterThan, + Object: chain.ObjectRequest, + Key: "key24", + Value: "24", + }}, + }, + { + Conditions: []chain.Condition{{ + Op: chain.CondNumericGreaterThanEquals, + Object: chain.ObjectRequest, + Key: "key25", + Value: "25", + }}, + }, } actualCondition, err := convertToChainCondition(conditions) require.NoError(t, err) require.ElementsMatch(t, expectedCondition, actualCondition) + + invalidCondition := Conditions{ + CondNumericEquals: {"key": {"invalid"}}, + } + _, err = convertToChainCondition(invalidCondition) + require.Error(t, err) } func TestParsePrincipalARN(t *testing.T) { diff --git a/pkg/chain/chain.go b/pkg/chain/chain.go index aa857c9..35d9da2 100644 --- a/pkg/chain/chain.go +++ b/pkg/chain/chain.go @@ -2,6 +2,7 @@ package chain import ( "fmt" + "strconv" "strings" "git.frostfs.info/TrueCloudLab/policy-engine/pkg/resource" @@ -150,7 +151,7 @@ func FormCondSliceContainsValue(values []string) string { return strings.Join(values, condSliceContainsDelimiter) } -func (c *Condition) Match(req resource.Request) bool { +func (c *Condition) Match(req resource.Request) (bool, error) { var val string switch c.Object { case ObjectResource: @@ -165,31 +166,63 @@ func (c *Condition) Match(req resource.Request) bool { default: panic(fmt.Sprintf("unimplemented: %d", c.Op)) case CondStringEquals: - return val == c.Value + return val == c.Value, nil case CondStringNotEquals: - return val != c.Value + return val != c.Value, nil case CondStringEqualsIgnoreCase: - return strings.EqualFold(val, c.Value) + return strings.EqualFold(val, c.Value), nil case CondStringNotEqualsIgnoreCase: - return !strings.EqualFold(val, c.Value) + return !strings.EqualFold(val, c.Value), nil case CondStringLike: - return util.GlobMatch(val, c.Value) + return util.GlobMatch(val, c.Value), nil case CondStringNotLike: - return !util.GlobMatch(val, c.Value) + return !util.GlobMatch(val, c.Value), nil case CondStringLessThan: - return val < c.Value + return val < c.Value, nil case CondStringLessThanEquals: - return val <= c.Value + return val <= c.Value, nil case CondStringGreaterThan: - return val > c.Value + return val > c.Value, nil case CondStringGreaterThanEquals: - return val >= c.Value + return val >= c.Value, nil case CondSliceContains: - return slices.Contains(strings.Split(val, condSliceContainsDelimiter), c.Value) + return slices.Contains(strings.Split(val, condSliceContainsDelimiter), c.Value), nil + case CondNumericEquals, CondNumericNotEquals, CondNumericLessThan, CondNumericLessThanEquals, CondNumericGreaterThan, + CondNumericGreaterThanEquals: + return c.matchNumeric(val) } } -func (r *Rule) Match(req resource.Request) (status Status, matched bool) { +func (c *Condition) matchNumeric(val string) (bool, error) { + valInt, err := strconv.ParseInt(val, 10, 64) + if err != nil { + return false, fmt.Errorf("invalid integer value: %s", val) + } + + condVal, err := strconv.ParseInt(c.Value, 10, 64) + if err != nil { + return false, fmt.Errorf("invalid integer value: %s", c.Value) + } + + switch c.Op { + default: + panic(fmt.Sprintf("unimplemented: %d", c.Op)) + case CondNumericEquals: + return valInt == condVal, nil + case CondNumericNotEquals: + return valInt != condVal, nil + case CondNumericLessThan: + return valInt < condVal, nil + case CondNumericLessThanEquals: + return valInt <= condVal, nil + case CondNumericGreaterThan: + return valInt > condVal, nil + case CondNumericGreaterThanEquals: + return valInt >= condVal, nil + } +} + +func (r *Rule) Match(req resource.Request) (status Status, matched bool, err error) { found := len(r.Resources.Names) == 0 for i := range r.Resources.Names { if util.GlobMatch(req.Resource().Name(), r.Resources.Names[i]) != r.Resources.Inverted { @@ -198,42 +231,50 @@ func (r *Rule) Match(req resource.Request) (status Status, matched bool) { } } if !found { - return NoRuleFound, false + return NoRuleFound, false, nil } for i := range r.Actions.Names { if util.GlobMatch(req.Operation(), r.Actions.Names[i]) != r.Actions.Inverted { return r.matchCondition(req) } } - return NoRuleFound, false + return NoRuleFound, false, nil } -func (r *Rule) matchCondition(obj resource.Request) (status Status, matched bool) { +func (r *Rule) matchCondition(obj resource.Request) (status Status, matched bool, err error) { if r.Any { return r.matchAny(obj) } return r.matchAll(obj) } -func (r *Rule) matchAny(obj resource.Request) (status Status, matched bool) { +func (r *Rule) matchAny(obj resource.Request) (status Status, matched bool, err error) { for i := range r.Condition { - if r.Condition[i].Match(obj) { - return r.Status, true + matched, err = r.Condition[i].Match(obj) + if err != nil { + return NoRuleFound, false, err + } + if matched { + return r.Status, true, nil } } - return NoRuleFound, false + return NoRuleFound, false, nil } -func (r *Rule) matchAll(obj resource.Request) (status Status, matched bool) { +func (r *Rule) matchAll(obj resource.Request) (status Status, matched bool, err error) { for i := range r.Condition { - if !r.Condition[i].Match(obj) { - return NoRuleFound, false + matched, err = r.Condition[i].Match(obj) + if err != nil { + return NoRuleFound, false, err + } + if !matched { + return NoRuleFound, false, nil } } - return r.Status, true + return r.Status, true, nil } -func (c *Chain) Match(req resource.Request) (status Status, matched bool) { +func (c *Chain) Match(req resource.Request) (status Status, matched bool, err error) { switch c.MatchType { case MatchTypeDenyPriority: return c.denyPriority(req) @@ -244,30 +285,36 @@ func (c *Chain) Match(req resource.Request) (status Status, matched bool) { } } -func (c *Chain) firstMatch(req resource.Request) (status Status, matched bool) { +func (c *Chain) firstMatch(req resource.Request) (status Status, matched bool, err error) { for i := range c.Rules { - status, matched := c.Rules[i].Match(req) + status, matched, err = c.Rules[i].Match(req) + if err != nil { + return NoRuleFound, false, err + } if matched { - return status, true + return status, true, nil } } - return NoRuleFound, false + return NoRuleFound, false, nil } -func (c *Chain) denyPriority(req resource.Request) (status Status, matched bool) { +func (c *Chain) denyPriority(req resource.Request) (status Status, matched bool, err error) { var allowFound bool for i := range c.Rules { - status, matched := c.Rules[i].Match(req) + status, matched, err = c.Rules[i].Match(req) + if err != nil { + return NoRuleFound, false, err + } if !matched { continue } if status != Allow { - return status, true + return status, true, nil } allowFound = true } if allowFound { - return Allow, true + return Allow, true, nil } - return NoRuleFound, false + return NoRuleFound, false, nil } diff --git a/pkg/chain/chain_test.go b/pkg/chain/chain_test.go index f050574..f4bf6f8 100644 --- a/pkg/chain/chain_test.go +++ b/pkg/chain/chain_test.go @@ -6,6 +6,7 @@ import ( "git.frostfs.info/TrueCloudLab/policy-engine/pkg/resource/testutil" "git.frostfs.info/TrueCloudLab/policy-engine/schema/common" "git.frostfs.info/TrueCloudLab/policy-engine/schema/native" + "git.frostfs.info/TrueCloudLab/policy-engine/schema/s3" "github.com/stretchr/testify/require" ) @@ -75,14 +76,16 @@ func TestReturnFirstMatch(t *testing.T) { request := testutil.NewRequest(native.MethodPutObject, resource, nil) t.Run("default match", func(t *testing.T) { - st, found := ch.Match(request) + st, found, err := ch.Match(request) + require.NoError(t, err) require.True(t, found) require.Equal(t, AccessDenied, st) }) t.Run("return first match", func(t *testing.T) { ch.MatchType = MatchTypeFirstMatch - st, found := ch.Match(request) + st, found, err := ch.Match(request) + require.NoError(t, err) require.True(t, found) require.Equal(t, Allow, st) }) @@ -144,7 +147,85 @@ func TestCondSliceContainsMatch(t *testing.T) { resource := testutil.NewResource(native.ResourceFormatRootContainers, nil) request := testutil.NewRequest(native.MethodPutObject, resource, map[string]string{propKey: tc.value}) - st, _ := ch.Match(request) + st, _, err := ch.Match(request) + require.NoError(t, err) + require.Equal(t, tc.status.String(), st.String()) + }) + } +} + +func TestNumericConditionsMatch(t *testing.T) { + propKey := s3.PropertyKeyMaxKeys + + for _, tc := range []struct { + name string + conditions []Condition + value string + status Status + }{ + { + name: "value from interval", + conditions: []Condition{ + { + Op: CondNumericLessThan, + Object: ObjectRequest, + Key: propKey, + Value: "100", + }, + { + Op: CondNumericGreaterThan, + Object: ObjectRequest, + Key: propKey, + Value: "80", + }, + { + Op: CondNumericNotEquals, + Object: ObjectRequest, + Key: propKey, + Value: "91", + }, + }, + value: "90", + status: Allow, + }, + { + name: "border value", + conditions: []Condition{ + { + Op: CondNumericEquals, + Object: ObjectRequest, + Key: propKey, + Value: "50", + }, + { + Op: CondNumericLessThanEquals, + Object: ObjectRequest, + Key: propKey, + Value: "50", + }, + { + Op: CondNumericGreaterThanEquals, + Object: ObjectRequest, + Key: propKey, + Value: "50", + }, + }, + value: "50", + status: Allow, + }, + } { + t.Run(tc.name, func(t *testing.T) { + resource := testutil.NewResource(native.ResourceFormatRootContainers, nil) + request := testutil.NewRequest(native.MethodPutObject, resource, map[string]string{propKey: tc.value}) + + ch := Chain{Rules: []Rule{{ + Status: Allow, + Actions: Actions{Names: []string{native.MethodPutObject}}, + Resources: Resources{Names: []string{native.ResourceFormatRootContainers}}, + Condition: tc.conditions, + }}} + st, _, err := ch.Match(request) + require.NoError(t, err) require.Equal(t, tc.status.String(), st.String()) }) } diff --git a/pkg/engine/chain_router.go b/pkg/engine/chain_router.go index 830919f..873cd81 100644 --- a/pkg/engine/chain_router.go +++ b/pkg/engine/chain_router.go @@ -86,7 +86,7 @@ func (dr *defaultChainRouter) matchLocalOverrides(name chain.Name, target Target if err != nil { return } - status, ruleFound = dr.getStatusFromChains(localOverrides, r) + status, ruleFound, err = dr.getStatusFromChains(localOverrides, r) return } @@ -95,22 +95,26 @@ func (dr *defaultChainRouter) matchMorphRuleChains(name chain.Name, target Targe if err != nil { return chain.NoRuleFound, false, err } - status, ruleFound = dr.getStatusFromChains(namespaceChains, r) + status, ruleFound, err = dr.getStatusFromChains(namespaceChains, r) return } -func (dr *defaultChainRouter) getStatusFromChains(chains []*chain.Chain, r resource.Request) (chain.Status, bool) { +func (dr *defaultChainRouter) getStatusFromChains(chains []*chain.Chain, r resource.Request) (chain.Status, bool, error) { var allow bool for _, c := range chains { - if status, found := c.Match(r); found { + status, found, err := c.Match(r) + if err != nil { + return chain.NoRuleFound, false, err + } + if found { if status != chain.Allow { - return status, true + return status, true, nil } allow = true } } if allow { - return chain.Allow, true + return chain.Allow, true, nil } - return chain.NoRuleFound, false + return chain.NoRuleFound, false, nil } diff --git a/schema/s3/consts.go b/schema/s3/consts.go index 7159a32..ab8021e 100644 --- a/schema/s3/consts.go +++ b/schema/s3/consts.go @@ -6,6 +6,7 @@ const ( PropertyKeyDelimiter = "s3:delimiter" PropertyKeyPrefix = "s3:prefix" PropertyKeyVersionID = "s3:VersionId" + PropertyKeyMaxKeys = "s3:max-keys" ResourceFormatS3All = "arn:aws:s3:::*" ResourceFormatS3Bucket = "arn:aws:s3:::%s"