[#35] local_storage: Make thread safe

Signed-off-by: Dmitrii Stepanov <d.stepanov@yadro.com>
This commit is contained in:
Dmitrii Stepanov 2023-12-21 11:13:22 +03:00 committed by Evgenii Stratonikov
parent 06e9c91014
commit ed93bb5cc5

View file

@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"math/rand" "math/rand"
"strings" "strings"
"sync"
"git.frostfs.info/TrueCloudLab/policy-engine/pkg/chain" "git.frostfs.info/TrueCloudLab/policy-engine/pkg/chain"
"git.frostfs.info/TrueCloudLab/policy-engine/pkg/engine" "git.frostfs.info/TrueCloudLab/policy-engine/pkg/engine"
@ -15,12 +16,14 @@ type targetToChain map[engine.Target][]*chain.Chain
type inmemoryLocalStorage struct { type inmemoryLocalStorage struct {
usedChainID map[chain.ID]struct{} usedChainID map[chain.ID]struct{}
nameToResourceChains map[chain.Name]targetToChain nameToResourceChains map[chain.Name]targetToChain
guard *sync.RWMutex
} }
func NewInmemoryLocalStorage() engine.LocalOverrideStorage { func NewInmemoryLocalStorage() engine.LocalOverrideStorage {
return &inmemoryLocalStorage{ return &inmemoryLocalStorage{
usedChainID: map[chain.ID]struct{}{}, usedChainID: map[chain.ID]struct{}{},
nameToResourceChains: make(map[chain.Name]targetToChain), nameToResourceChains: make(map[chain.Name]targetToChain),
guard: &sync.RWMutex{},
} }
} }
@ -44,6 +47,9 @@ func (s *inmemoryLocalStorage) generateChainID(name chain.Name, target engine.Ta
} }
func (s *inmemoryLocalStorage) AddOverride(name chain.Name, target engine.Target, c *chain.Chain) (chain.ID, error) { func (s *inmemoryLocalStorage) AddOverride(name chain.Name, target engine.Target, c *chain.Chain) (chain.ID, error) {
s.guard.Lock()
defer s.guard.Unlock()
// AddOverride assigns generated chain ID if it has not been assigned. // AddOverride assigns generated chain ID if it has not been assigned.
if c.ID == "" { if c.ID == "" {
c.ID = s.generateChainID(name, target) c.ID = s.generateChainID(name, target)
@ -63,6 +69,9 @@ func (s *inmemoryLocalStorage) AddOverride(name chain.Name, target engine.Target
} }
func (s *inmemoryLocalStorage) GetOverride(name chain.Name, target engine.Target, chainID chain.ID) (*chain.Chain, error) { func (s *inmemoryLocalStorage) GetOverride(name chain.Name, target engine.Target, chainID chain.ID) (*chain.Chain, error) {
s.guard.RLock()
defer s.guard.RUnlock()
if _, ok := s.nameToResourceChains[name]; !ok { if _, ok := s.nameToResourceChains[name]; !ok {
return nil, engine.ErrChainNameNotFound return nil, engine.ErrChainNameNotFound
} }
@ -79,6 +88,9 @@ func (s *inmemoryLocalStorage) GetOverride(name chain.Name, target engine.Target
} }
func (s *inmemoryLocalStorage) RemoveOverride(name chain.Name, target engine.Target, chainID chain.ID) error { func (s *inmemoryLocalStorage) RemoveOverride(name chain.Name, target engine.Target, chainID chain.ID) error {
s.guard.Lock()
defer s.guard.Unlock()
if _, ok := s.nameToResourceChains[name]; !ok { if _, ok := s.nameToResourceChains[name]; !ok {
return engine.ErrChainNameNotFound return engine.ErrChainNameNotFound
} }
@ -96,6 +108,9 @@ func (s *inmemoryLocalStorage) RemoveOverride(name chain.Name, target engine.Tar
} }
func (s *inmemoryLocalStorage) ListOverrides(name chain.Name, target engine.Target) ([]*chain.Chain, error) { func (s *inmemoryLocalStorage) ListOverrides(name chain.Name, target engine.Target) ([]*chain.Chain, error) {
s.guard.RLock()
defer s.guard.RUnlock()
rcs, ok := s.nameToResourceChains[name] rcs, ok := s.nameToResourceChains[name]
if !ok { if !ok {
return []*chain.Chain{}, nil return []*chain.Chain{}, nil
@ -113,6 +128,9 @@ func (s *inmemoryLocalStorage) ListOverrides(name chain.Name, target engine.Targ
} }
func (s *inmemoryLocalStorage) DropAllOverrides(name chain.Name) error { func (s *inmemoryLocalStorage) DropAllOverrides(name chain.Name) error {
s.guard.Lock()
defer s.guard.Unlock()
s.nameToResourceChains[name] = make(targetToChain) s.nameToResourceChains[name] = make(targetToChain)
return nil return nil
} }