package inmemory

import (
	"bytes"
	"fmt"
	"math/rand"
	"sort"
	"strings"
	"sync"

	"git.frostfs.info/TrueCloudLab/policy-engine/pkg/chain"
	"git.frostfs.info/TrueCloudLab/policy-engine/pkg/engine"
	"git.frostfs.info/TrueCloudLab/policy-engine/util"
)

const (
	// initialChainIDSuffixSpace is the number of distinct random suffixes tried
	// first when generating a chain ID (keeps generated IDs short).
	initialChainIDSuffixSpace uint32 = 100
	// maxChainIDSuffixSpace caps how far the suffix space is widened once the
	// smaller spaces appear exhausted. It still fits in a uint32.
	maxChainIDSuffixSpace uint32 = 1000000000
)

// targetToChain maps a target to the chains stored for it, in insertion order.
type targetToChain map[engine.Target][]*chain.Chain

// inmemoryLocalStorage keeps local override chains in process memory.
// All access to the maps is serialized by guard.
type inmemoryLocalStorage struct {
	// usedChainID records every chain ID handed out or supplied by callers,
	// so generated IDs never collide with existing ones.
	usedChainID          map[string]struct{}
	nameToResourceChains map[chain.Name]targetToChain
	guard                *sync.RWMutex
}

// NewInmemoryLocalStorage returns an empty, mutex-guarded in-memory
// implementation of engine.LocalOverrideStorage.
func NewInmemoryLocalStorage() engine.LocalOverrideStorage {
	return &inmemoryLocalStorage{
		usedChainID:          map[string]struct{}{},
		nameToResourceChains: make(map[chain.Name]targetToChain),
		guard:                &sync.RWMutex{},
	}
}

// generateChainID returns a fresh chain ID of the form "<name>:<target>:<n>"
// after stripping '*' and normalizing separators, and marks it as used.
// The caller must hold the write lock.
//
// Suffixes are first drawn from [0, 100). Whenever as many consecutive
// candidates as the current space size have all collided, the space is
// widened tenfold (up to maxChainIDSuffixSpace), so the loop always terminates
// even after many IDs have been generated for the same name and target.
func (s *inmemoryLocalStorage) generateChainID(name chain.Name, target engine.Target) chain.ID {
	space := initialChainIDSuffixSpace
	misses := uint32(0)
	for {
		suffix := rand.Uint32() % space
		sid := fmt.Sprintf("%s:%s/%d", name, target.Name, suffix)
		sid = strings.ReplaceAll(sid, "*", "")
		sid = strings.ReplaceAll(sid, "/", ":")
		sid = strings.ReplaceAll(sid, "::", ":")

		if _, used := s.usedChainID[sid]; !used {
			s.usedChainID[sid] = struct{}{}
			return chain.ID(sid)
		}

		misses++
		if misses >= space && space <= maxChainIDSuffixSpace/10 {
			space *= 10
			misses = 0
		}
	}
}

// AddOverride stores c under the chain name and target and returns its ID.
//
// If c has no ID, a unique one is generated and assigned to c.ID (the
// caller's chain is modified). If a chain with the same ID already exists for
// this name and target, it is replaced in place; otherwise c is appended.
// An empty target name is treated as "root", as in the lookup methods.
func (s *inmemoryLocalStorage) AddOverride(name chain.Name, target engine.Target, c *chain.Chain) (chain.ID, error) {
	s.guard.Lock()
	defer s.guard.Unlock()

	// Normalize like GetOverride/RemoveOverride/ListOverrides do, otherwise a
	// chain added with an empty target name could never be looked up again.
	if target.Name == "" {
		target.Name = "root"
	}

	if len(c.ID) == 0 {
		c.ID = s.generateChainID(name, target)
	} else {
		// Reserve caller-supplied IDs too, so a later generated ID cannot
		// collide with this chain and silently replace it.
		s.usedChainID[string(c.ID)] = struct{}{}
	}

	rc := s.nameToResourceChains[name]
	if rc == nil {
		rc = make(targetToChain)
		s.nameToResourceChains[name] = rc
	}

	chains := rc[target]
	for i := range chains {
		if bytes.Equal(chains[i].ID, c.ID) {
			chains[i] = c
			return c.ID, nil
		}
	}
	rc[target] = append(chains, c)
	return c.ID, nil
}

// GetOverride returns the chain with chainID stored under name and target.
// An empty target name is treated as "root".
//
// It returns engine.ErrChainNameNotFound if nothing is stored under name,
// engine.ErrResourceNotFound if the target is unknown for that name, and
// engine.ErrChainNotFound if no chain with chainID exists for the target.
func (s *inmemoryLocalStorage) GetOverride(name chain.Name, target engine.Target, chainID chain.ID) (*chain.Chain, error) {
	s.guard.RLock()
	defer s.guard.RUnlock()

	targets, ok := s.nameToResourceChains[name]
	if !ok {
		return nil, engine.ErrChainNameNotFound
	}

	if target.Name == "" {
		target.Name = "root"
	}

	chains, ok := targets[target]
	if !ok {
		return nil, engine.ErrResourceNotFound
	}

	for _, c := range chains {
		if bytes.Equal(c.ID, chainID) {
			return c, nil
		}
	}
	return nil, engine.ErrChainNotFound
}

// RemoveOverride deletes the chain with chainID stored under name and target,
// preserving the order of the remaining chains. An empty target name is
// treated as "root". Errors mirror GetOverride.
func (s *inmemoryLocalStorage) RemoveOverride(name chain.Name, target engine.Target, chainID chain.ID) error {
	s.guard.Lock()
	defer s.guard.Unlock()

	targets, ok := s.nameToResourceChains[name]
	if !ok {
		return engine.ErrChainNameNotFound
	}

	if target.Name == "" {
		target.Name = "root"
	}

	chains, ok := targets[target]
	if !ok {
		return engine.ErrResourceNotFound
	}

	for i, c := range chains {
		if bytes.Equal(c.ID, chainID) {
			copy(chains[i:], chains[i+1:])
			// Clear the vacated tail slot so the removed chain is not kept
			// alive by the backing array.
			chains[len(chains)-1] = nil
			targets[target] = chains[:len(chains)-1]
			return nil
		}
	}
	return engine.ErrChainNotFound
}

// ListOverrides returns the chains stored under name for every target of
// target.Type whose name matches target.Name, which may be a glob pattern
// (an empty name is treated as "root").
//
// Matching targets are visited in ascending name order, so the result is
// deterministic. The returned slice is a fresh copy that the caller may keep
// or modify. It is empty (never nil) when nothing matches.
func (s *inmemoryLocalStorage) ListOverrides(name chain.Name, target engine.Target) ([]*chain.Chain, error) {
	s.guard.RLock()
	defer s.guard.RUnlock()

	rcs, ok := s.nameToResourceChains[name]
	if !ok {
		return []*chain.Chain{}, nil
	}

	if target.Name == "" {
		target.Name = "root"
	}

	var matched []engine.Target
	for t := range rcs {
		if t.Type != target.Type {
			continue
		}
		if !util.GlobMatch(target.Name, t.Name) {
			continue
		}
		matched = append(matched, t)
	}
	// All matched targets share one Type, so Name alone orders them totally.
	sort.Slice(matched, func(i, j int) bool { return matched[i].Name < matched[j].Name })

	// Copy into a new slice: the stored slices are mutated in place by
	// AddOverride and RemoveOverride, so callers must not alias them.
	result := []*chain.Chain{}
	for _, t := range matched {
		result = append(result, rcs[t]...)
	}
	return result, nil
}

// DropAllOverrides discards every chain stored under name, for all targets.
func (s *inmemoryLocalStorage) DropAllOverrides(name chain.Name) error {
	s.guard.Lock()
	defer s.guard.Unlock()

	s.nameToResourceChains[name] = make(targetToChain)
	return nil
}

// ListOverrideDefinedTargets returns every target that has an entry under
// name, in unspecified order. It returns nil when there are none.
func (s *inmemoryLocalStorage) ListOverrideDefinedTargets(name chain.Name) ([]engine.Target, error) {
	s.guard.RLock()
	defer s.guard.RUnlock()

	ttc := s.nameToResourceChains[name]
	var keys []engine.Target
	for k := range ttc {
		keys = append(keys, k)
	}
	return keys, nil
}