[#25] engine: Refactor LocalOverrideStorage

* Make LocalOverrideStorage methods to receive Target type
  instead resource
* Refactor unit-tests and dependencies
* Make default chain router check local overrides not
  only for container but also for namespaces

Signed-off-by: Airat Arifullin <aarifullin@yadro.com>
This commit is contained in:
aarifullin 2023-12-01 16:08:52 +03:00 committed by Evgenii Stratonikov
parent a0a35bf4bf
commit 4d8242584a
6 changed files with 67 additions and 56 deletions

View file

@ -25,19 +25,36 @@ func NewDefaultChainRouterWithLocalOverrides(morph MorphRuleChainStorage, local
} }
func (dr *defaultChainRouter) IsAllowed(name chain.Name, namespace string, r resource.Request) (status chain.Status, ruleFound bool, err error) { func (dr *defaultChainRouter) IsAllowed(name chain.Name, namespace string, r resource.Request) (status chain.Status, ruleFound bool, err error) {
if dr.local != nil { status, ruleFound, err = dr.checkLocal(name, namespace, r)
var localRuleFound bool
status, localRuleFound, err = dr.checkLocalOverrides(name, r)
if err != nil { if err != nil {
return chain.NoRuleFound, false, err return chain.NoRuleFound, false, err
} else if localRuleFound { } else if ruleFound {
ruleFound = true
return return
} }
status, ruleFound, err = dr.checkMorph(name, namespace, r)
return
}
func (dr *defaultChainRouter) checkLocal(name chain.Name, namespace string, r resource.Request) (status chain.Status, ruleFound bool, err error) {
if dr.local == nil {
return
} }
status, ruleFound, err = dr.matchLocalOverrides(name, ContainerTarget(r.Resource().Name()), r)
if err != nil {
return chain.NoRuleFound, false, err
} else if ruleFound {
return
}
status, ruleFound, err = dr.matchLocalOverrides(name, NamespaceTarget(namespace), r)
return
}
func (dr *defaultChainRouter) checkMorph(name chain.Name, namespace string, r resource.Request) (status chain.Status, ruleFound bool, err error) {
var namespaceRuleFound bool var namespaceRuleFound bool
status, namespaceRuleFound, err = dr.checkNamespaceChains(name, namespace, r) status, namespaceRuleFound, err = dr.matchMorphRuleChains(name, NamespaceTarget(namespace), r)
if err != nil { if err != nil {
return return
} else if namespaceRuleFound && status != chain.Allow { } else if namespaceRuleFound && status != chain.Allow {
@ -46,7 +63,7 @@ func (dr *defaultChainRouter) IsAllowed(name chain.Name, namespace string, r res
} }
var cnrRuleFound bool var cnrRuleFound bool
status, cnrRuleFound, err = dr.checkContainerChains(name, r.Resource().Name(), r) status, cnrRuleFound, err = dr.matchMorphRuleChains(name, ContainerTarget(r.Resource().Name()), r)
if err != nil { if err != nil {
return return
} else if cnrRuleFound && status != chain.Allow { } else if cnrRuleFound && status != chain.Allow {
@ -61,8 +78,8 @@ func (dr *defaultChainRouter) IsAllowed(name chain.Name, namespace string, r res
return return
} }
func (dr *defaultChainRouter) checkLocalOverrides(name chain.Name, r resource.Request) (status chain.Status, ruleFound bool, err error) { func (dr *defaultChainRouter) matchLocalOverrides(name chain.Name, target Target, r resource.Request) (status chain.Status, ruleFound bool, err error) {
localOverrides, err := dr.local.ListOverrides(name, r.Resource().Name()) localOverrides, err := dr.local.ListOverrides(name, target)
if err != nil { if err != nil {
return return
} }
@ -74,8 +91,8 @@ func (dr *defaultChainRouter) checkLocalOverrides(name chain.Name, r resource.Re
return return
} }
func (dr *defaultChainRouter) checkNamespaceChains(name chain.Name, namespace string, r resource.Request) (status chain.Status, ruleFound bool, err error) { func (dr *defaultChainRouter) matchMorphRuleChains(name chain.Name, target Target, r resource.Request) (status chain.Status, ruleFound bool, err error) {
namespaceChains, err := dr.morph.ListMorphRuleChains(name, NamespaceTarget(namespace)) namespaceChains, err := dr.morph.ListMorphRuleChains(name, target)
if err != nil { if err != nil {
return return
} }
@ -86,16 +103,3 @@ func (dr *defaultChainRouter) checkNamespaceChains(name chain.Name, namespace st
} }
return return
} }
func (dr *defaultChainRouter) checkContainerChains(name chain.Name, container string, r resource.Request) (status chain.Status, ruleFound bool, err error) {
containerChains, err := dr.morph.ListMorphRuleChains(name, ContainerTarget(container))
if err != nil {
return
}
for _, c := range containerChains {
if status, ruleFound = c.Match(r); ruleFound {
return
}
}
return
}

View file

@ -167,7 +167,7 @@ func TestInmemory(t *testing.T) {
require.False(t, ok) require.False(t, ok)
t.Run("quota on a different container", func(t *testing.T) { t.Run("quota on a different container", func(t *testing.T) {
s.LocalStorage().AddOverride(chain.Ingress, container, &chain.Chain{ s.LocalStorage().AddOverride(chain.Ingress, engine.ContainerTarget(container), &chain.Chain{
Rules: []chain.Rule{{ Rules: []chain.Rule{{
Status: chain.QuotaLimitReached, Status: chain.QuotaLimitReached,
Actions: chain.Actions{Names: []string{"native::object::put"}}, Actions: chain.Actions{Names: []string{"native::object::put"}},
@ -182,7 +182,7 @@ func TestInmemory(t *testing.T) {
var quotaRuleChainID chain.ID var quotaRuleChainID chain.ID
t.Run("quota on the request container", func(t *testing.T) { t.Run("quota on the request container", func(t *testing.T) {
quotaRuleChainID, _ = s.LocalStorage().AddOverride(chain.Ingress, container, &chain.Chain{ quotaRuleChainID, _ = s.LocalStorage().AddOverride(chain.Ingress, engine.ContainerTarget(container), &chain.Chain{
Rules: []chain.Rule{{ Rules: []chain.Rule{{
Status: chain.QuotaLimitReached, Status: chain.QuotaLimitReached,
Actions: chain.Actions{Names: []string{"native::object::put"}}, Actions: chain.Actions{Names: []string{"native::object::put"}},
@ -195,7 +195,7 @@ func TestInmemory(t *testing.T) {
require.True(t, ok) require.True(t, ok)
}) })
t.Run("removed quota on the request container", func(t *testing.T) { t.Run("removed quota on the request container", func(t *testing.T) {
err := s.LocalStorage().RemoveOverride(chain.Ingress, container, quotaRuleChainID) err := s.LocalStorage().RemoveOverride(chain.Ingress, engine.ContainerTarget(container), quotaRuleChainID)
require.NoError(t, err) require.NoError(t, err)
status, ok, _ = s.IsAllowed(chain.Ingress, namespace, reqGood) status, ok, _ = s.IsAllowed(chain.Ingress, namespace, reqGood)

View file

@ -10,7 +10,7 @@ import (
"git.frostfs.info/TrueCloudLab/policy-engine/util" "git.frostfs.info/TrueCloudLab/policy-engine/util"
) )
type targetToChain map[string][]*chain.Chain type targetToChain map[engine.Target][]*chain.Chain
type inmemoryLocalStorage struct { type inmemoryLocalStorage struct {
usedChainID map[chain.ID]struct{} usedChainID map[chain.ID]struct{}
@ -24,11 +24,11 @@ func NewInmemoryLocalStorage() engine.LocalOverrideStorage {
} }
} }
func (s *inmemoryLocalStorage) generateChainID(name chain.Name, resource string) chain.ID { func (s *inmemoryLocalStorage) generateChainID(name chain.Name, target engine.Target) chain.ID {
var id chain.ID var id chain.ID
for { for {
suffix := rand.Uint32() % 100 suffix := rand.Uint32() % 100
sid := fmt.Sprintf("%s:%s/%d", name, resource, suffix) sid := fmt.Sprintf("%s:%s/%d", name, target.Name, suffix)
sid = strings.ReplaceAll(sid, "*", "") sid = strings.ReplaceAll(sid, "*", "")
sid = strings.ReplaceAll(sid, "/", ":") sid = strings.ReplaceAll(sid, "/", ":")
sid = strings.ReplaceAll(sid, "::", ":") sid = strings.ReplaceAll(sid, "::", ":")
@ -43,24 +43,24 @@ func (s *inmemoryLocalStorage) generateChainID(name chain.Name, resource string)
return id return id
} }
func (s *inmemoryLocalStorage) AddOverride(name chain.Name, resource string, c *chain.Chain) (chain.ID, error) { func (s *inmemoryLocalStorage) AddOverride(name chain.Name, target engine.Target, c *chain.Chain) (chain.ID, error) {
// AddOverride assigns generated chain ID if it has not been assigned. // AddOverride assigns generated chain ID if it has not been assigned.
if c.ID == "" { if c.ID == "" {
c.ID = s.generateChainID(name, resource) c.ID = s.generateChainID(name, target)
} }
if s.nameToResourceChains[name] == nil { if s.nameToResourceChains[name] == nil {
s.nameToResourceChains[name] = make(targetToChain) s.nameToResourceChains[name] = make(targetToChain)
} }
rc := s.nameToResourceChains[name] rc := s.nameToResourceChains[name]
rc[resource] = append(rc[resource], c) rc[target] = append(rc[target], c)
return c.ID, nil return c.ID, nil
} }
func (s *inmemoryLocalStorage) GetOverride(name chain.Name, resource string, chainID chain.ID) (*chain.Chain, error) { func (s *inmemoryLocalStorage) GetOverride(name chain.Name, target engine.Target, chainID chain.ID) (*chain.Chain, error) {
if _, ok := s.nameToResourceChains[name]; !ok { if _, ok := s.nameToResourceChains[name]; !ok {
return nil, engine.ErrChainNameNotFound return nil, engine.ErrChainNameNotFound
} }
chains, ok := s.nameToResourceChains[name][resource] chains, ok := s.nameToResourceChains[name][target]
if !ok { if !ok {
return nil, engine.ErrResourceNotFound return nil, engine.ErrResourceNotFound
} }
@ -72,30 +72,33 @@ func (s *inmemoryLocalStorage) GetOverride(name chain.Name, resource string, cha
return nil, engine.ErrChainNotFound return nil, engine.ErrChainNotFound
} }
func (s *inmemoryLocalStorage) RemoveOverride(name chain.Name, resource string, chainID chain.ID) error { func (s *inmemoryLocalStorage) RemoveOverride(name chain.Name, target engine.Target, chainID chain.ID) error {
if _, ok := s.nameToResourceChains[name]; !ok { if _, ok := s.nameToResourceChains[name]; !ok {
return engine.ErrChainNameNotFound return engine.ErrChainNameNotFound
} }
chains, ok := s.nameToResourceChains[name][resource] chains, ok := s.nameToResourceChains[name][target]
if !ok { if !ok {
return engine.ErrResourceNotFound return engine.ErrResourceNotFound
} }
for i, c := range chains { for i, c := range chains {
if c.ID == chainID { if c.ID == chainID {
s.nameToResourceChains[name][resource] = append(chains[:i], chains[i+1:]...) s.nameToResourceChains[name][target] = append(chains[:i], chains[i+1:]...)
return nil return nil
} }
} }
return engine.ErrChainNotFound return engine.ErrChainNotFound
} }
func (s *inmemoryLocalStorage) ListOverrides(name chain.Name, resource string) ([]*chain.Chain, error) { func (s *inmemoryLocalStorage) ListOverrides(name chain.Name, target engine.Target) ([]*chain.Chain, error) {
rcs, ok := s.nameToResourceChains[name] rcs, ok := s.nameToResourceChains[name]
if !ok { if !ok {
return []*chain.Chain{}, nil return []*chain.Chain{}, nil
} }
for container, chains := range rcs { for t, chains := range rcs {
if !util.GlobMatch(resource, container) { if t.Type != target.Type {
continue
}
if !util.GlobMatch(target.Name, t.Name) {
continue continue
} }
return chains, nil return chains, nil

View file

@ -9,11 +9,15 @@ import (
) )
const ( const (
resrc = "native:::object/ExYw/*" container = "native:::object/ExYw/*"
chainID = "ingress:ExYw" chainID = "ingress:ExYw"
nonExistChainId = "ingress:LxGyWyL" nonExistChainId = "ingress:LxGyWyL"
) )
var (
resrc = engine.ContainerTarget(container)
)
func testInmemLocalStorage() *inmemoryLocalStorage { func testInmemLocalStorage() *inmemoryLocalStorage {
return NewInmemoryLocalStorage().(*inmemoryLocalStorage) return NewInmemoryLocalStorage().(*inmemoryLocalStorage)
} }

View file

@ -20,9 +20,9 @@ func NewInmemoryMorphRuleChainStorage() engine.MorphRuleChainStorage {
func (s *inmemoryMorphRuleChainStorage) AddMorphRuleChain(name chain.Name, target engine.Target, c *chain.Chain) (err error) { func (s *inmemoryMorphRuleChainStorage) AddMorphRuleChain(name chain.Name, target engine.Target, c *chain.Chain) (err error) {
switch target.Type { switch target.Type {
case engine.Namespace: case engine.Namespace:
_, err = s.nameToNamespaceChains.AddOverride(name, target.Name, c) _, err = s.nameToNamespaceChains.AddOverride(name, target, c)
case engine.Container: case engine.Container:
_, err = s.nameToContainerChains.AddOverride(name, target.Name, c) _, err = s.nameToContainerChains.AddOverride(name, target, c)
default: default:
err = engine.ErrUnknownTarget err = engine.ErrUnknownTarget
} }
@ -32,9 +32,9 @@ func (s *inmemoryMorphRuleChainStorage) AddMorphRuleChain(name chain.Name, targe
func (s *inmemoryMorphRuleChainStorage) RemoveMorphRuleChain(name chain.Name, target engine.Target, chainID chain.ID) error { func (s *inmemoryMorphRuleChainStorage) RemoveMorphRuleChain(name chain.Name, target engine.Target, chainID chain.ID) error {
switch target.Type { switch target.Type {
case engine.Namespace: case engine.Namespace:
return s.nameToNamespaceChains.RemoveOverride(name, target.Name, chainID) return s.nameToNamespaceChains.RemoveOverride(name, target, chainID)
case engine.Container: case engine.Container:
return s.nameToContainerChains.RemoveOverride(name, target.Name, chainID) return s.nameToContainerChains.RemoveOverride(name, target, chainID)
default: default:
return engine.ErrUnknownTarget return engine.ErrUnknownTarget
} }
@ -43,9 +43,9 @@ func (s *inmemoryMorphRuleChainStorage) RemoveMorphRuleChain(name chain.Name, ta
func (s *inmemoryMorphRuleChainStorage) ListMorphRuleChains(name chain.Name, target engine.Target) ([]*chain.Chain, error) { func (s *inmemoryMorphRuleChainStorage) ListMorphRuleChains(name chain.Name, target engine.Target) ([]*chain.Chain, error) {
switch target.Type { switch target.Type {
case engine.Namespace: case engine.Namespace:
return s.nameToNamespaceChains.ListOverrides(name, target.Name) return s.nameToNamespaceChains.ListOverrides(name, target)
case engine.Container: case engine.Container:
return s.nameToContainerChains.ListOverrides(name, target.Name) return s.nameToContainerChains.ListOverrides(name, target)
default: default:
} }
return nil, engine.ErrUnknownTarget return nil, engine.ErrUnknownTarget

View file

@ -14,13 +14,13 @@ type ChainRouter interface {
// LocalOverrideStorage is the interface to manage local overrides defined // LocalOverrideStorage is the interface to manage local overrides defined
// for a node. Local overrides have a higher priority than chains got from morph storage. // for a node. Local overrides have a higher priority than chains got from morph storage.
type LocalOverrideStorage interface { type LocalOverrideStorage interface {
AddOverride(name chain.Name, resource string, c *chain.Chain) (chain.ID, error) AddOverride(name chain.Name, target Target, c *chain.Chain) (chain.ID, error)
GetOverride(name chain.Name, resource string, chainID chain.ID) (*chain.Chain, error) GetOverride(name chain.Name, target Target, chainID chain.ID) (*chain.Chain, error)
RemoveOverride(name chain.Name, resource string, chainID chain.ID) error RemoveOverride(name chain.Name, target Target, chainID chain.ID) error
ListOverrides(name chain.Name, resource string) ([]*chain.Chain, error) ListOverrides(name chain.Name, target Target) ([]*chain.Chain, error)
DropAllOverrides(name chain.Name) error DropAllOverrides(name chain.Name) error
} }