policy-engine/pkg/engine/inmemory/local_storage.go
Dmitrii Stepanov ed93bb5cc5 [#35] local_storage: Make thread safe
Signed-off-by: Dmitrii Stepanov <d.stepanov@yadro.com>
2023-12-21 12:13:54 +00:00

136 lines
3.3 KiB
Go

package inmemory
import (
"fmt"
"math/rand"
"strings"
"sync"
"git.frostfs.info/TrueCloudLab/policy-engine/pkg/chain"
"git.frostfs.info/TrueCloudLab/policy-engine/pkg/engine"
"git.frostfs.info/TrueCloudLab/policy-engine/util"
)
type (
	// targetToChain maps a target to the chains registered for it.
	targetToChain map[engine.Target][]*chain.Chain

	// inmemoryLocalStorage is a map-backed override storage; its exported
	// methods take guard before touching either map.
	inmemoryLocalStorage struct {
		// usedChainID is the set of reserved chain IDs; generateChainID never
		// returns an ID that is already present here.
		usedChainID map[chain.ID]struct{}
		// nameToResourceChains holds the stored chains keyed by chain name,
		// then by target.
		nameToResourceChains map[chain.Name]targetToChain
		// guard serializes access to the two maps above.
		guard *sync.RWMutex
	}
)
// NewInmemoryLocalStorage returns an empty, map-backed
// engine.LocalOverrideStorage. Every exported method of the returned value
// locks an internal sync.RWMutex before accessing its state.
func NewInmemoryLocalStorage() engine.LocalOverrideStorage {
	s := new(inmemoryLocalStorage)
	s.guard = new(sync.RWMutex)
	s.nameToResourceChains = map[chain.Name]targetToChain{}
	s.usedChainID = make(map[chain.ID]struct{})
	return s
}
// generateChainID builds a chain ID for name and target that is not yet in
// s.usedChainID, reserves it there and returns it.
//
// The raw form is "<name>:<target.Name>/<suffix>". Every '*' is then removed,
// '/' becomes ':' and "::" collapses to ':'. The suffix is random. It is first
// drawn from [0, 100), and after every 100 consecutive collisions the range
// grows tenfold (capped at 10^18). Growing the range guarantees the loop
// terminates even when all small suffixes for this name/target are already
// reserved; with a fixed range of 100 it would spin forever under the lock.
//
// It writes s.usedChainID without locking, so callers must hold s.guard
// exclusively.
func (s *inmemoryLocalStorage) generateChainID(name chain.Name, target engine.Target) chain.ID {
	const (
		initialBound     = int64(100)
		maxBound         = int64(1e18)
		attemptsPerBound = 100
	)
	bound := initialBound
	for attempt := 1; ; attempt++ {
		suffix := rand.Int63n(bound)
		sid := fmt.Sprintf("%s:%s/%d", name, target.Name, suffix)
		sid = strings.ReplaceAll(sid, "*", "")
		sid = strings.ReplaceAll(sid, "/", ":")
		sid = strings.ReplaceAll(sid, "::", ":")
		id := chain.ID(sid)
		if _, taken := s.usedChainID[id]; !taken {
			s.usedChainID[id] = struct{}{}
			return id
		}
		// Widen the suffix space periodically so a saturated range cannot
		// cause an endless retry loop.
		if attempt%attemptsPerBound == 0 && bound < maxBound {
			bound *= 10
		}
	}
}
// AddOverride stores chain c under name and target and returns its ID.
//
// If c.ID is empty, a fresh ID is generated and assigned to c (the caller's
// chain is mutated). Otherwise the explicit ID is reserved in usedChainID, so
// that a later generated ID cannot coincide with it. If a chain with the same
// ID is already stored for name/target it is replaced; otherwise c is
// appended. The returned error is always nil.
func (s *inmemoryLocalStorage) AddOverride(name chain.Name, target engine.Target, c *chain.Chain) (chain.ID, error) {
	s.guard.Lock()
	defer s.guard.Unlock()

	if c.ID == "" {
		c.ID = s.generateChainID(name, target)
	} else {
		// Without this reservation generateChainID could later return the
		// same ID for this name/target, and the replace-by-ID loop below
		// would silently overwrite this chain.
		s.usedChainID[c.ID] = struct{}{}
	}

	rc := s.nameToResourceChains[name]
	if rc == nil {
		rc = make(targetToChain)
		s.nameToResourceChains[name] = rc
	}

	chains := rc[target]
	for i := range chains {
		if chains[i].ID == c.ID {
			chains[i] = c
			return c.ID, nil
		}
	}
	rc[target] = append(chains, c)
	return c.ID, nil
}
// GetOverride looks up the chain with chainID stored under name and target.
// It returns engine.ErrChainNameNotFound when nothing is stored under name,
// engine.ErrResourceNotFound when name has no entry for target, and
// engine.ErrChainNotFound when none of the chains for name/target has chainID.
func (s *inmemoryLocalStorage) GetOverride(name chain.Name, target engine.Target, chainID chain.ID) (*chain.Chain, error) {
	s.guard.RLock()
	defer s.guard.RUnlock()

	byTarget, found := s.nameToResourceChains[name]
	if !found {
		return nil, engine.ErrChainNameNotFound
	}
	candidates, found := byTarget[target]
	if !found {
		return nil, engine.ErrResourceNotFound
	}
	for i := range candidates {
		if candidates[i].ID == chainID {
			return candidates[i], nil
		}
	}
	return nil, engine.ErrChainNotFound
}
// RemoveOverride deletes the chain with chainID stored under name and target.
// It returns engine.ErrChainNameNotFound when nothing is stored under name,
// engine.ErrResourceNotFound when name has no entry for target, and
// engine.ErrChainNotFound when no chain for name/target has chainID.
//
// The remaining chains are copied into a new slice instead of being shifted in
// place. The old backing array may still be referenced by callers that got it
// earlier, and shifting would rewrite their view. It would also leave a stale
// pointer to the removed chain in the array's tail. usedChainID is not
// touched, so an ID reserved there stays reserved.
func (s *inmemoryLocalStorage) RemoveOverride(name chain.Name, target engine.Target, chainID chain.ID) error {
	s.guard.Lock()
	defer s.guard.Unlock()

	rcs, ok := s.nameToResourceChains[name]
	if !ok {
		return engine.ErrChainNameNotFound
	}
	chains, ok := rcs[target]
	if !ok {
		return engine.ErrResourceNotFound
	}
	for i, c := range chains {
		if c.ID != chainID {
			continue
		}
		remaining := make([]*chain.Chain, 0, len(chains)-1)
		remaining = append(remaining, chains[:i]...)
		remaining = append(remaining, chains[i+1:]...)
		rcs[target] = remaining
		return nil
	}
	return engine.ErrChainNotFound
}
// ListOverrides returns the chains stored under name for the first stored
// target t with t.Type == target.Type for which
// util.GlobMatch(target.Name, t.Name) reports a match. When name is unknown or
// no stored target matches, it returns an empty, non-nil slice. The returned
// error is always nil.
//
// The result is a copy. Returning the internal slice would let callers read it
// after the read lock is released while writers such as AddOverride modify the
// same backing array in place (replace-by-ID, append), which is a data race.
//
// NOTE(review): if several stored targets match, Go map iteration order picks
// which one is returned, so the result is not deterministic. Confirm that
// overlapping patterns are not expected.
func (s *inmemoryLocalStorage) ListOverrides(name chain.Name, target engine.Target) ([]*chain.Chain, error) {
	s.guard.RLock()
	defer s.guard.RUnlock()

	rcs, ok := s.nameToResourceChains[name]
	if !ok {
		return []*chain.Chain{}, nil
	}
	for t, chains := range rcs {
		if t.Type != target.Type || !util.GlobMatch(target.Name, t.Name) {
			continue
		}
		result := make([]*chain.Chain, len(chains))
		copy(result, chains)
		return result, nil
	}
	return []*chain.Chain{}, nil
}
// DropAllOverrides discards every chain stored under name, across all targets.
// An empty target map is left in place for name (created if name was absent).
// IDs already reserved in usedChainID stay reserved. It always returns nil.
func (s *inmemoryLocalStorage) DropAllOverrides(name chain.Name) error {
	emptied := targetToChain{}

	s.guard.Lock()
	defer s.guard.Unlock()

	s.nameToResourceChains[name] = emptied
	return nil
}