From ed93bb5cc57465c1b9b5caa4c8e3b62cdf5357e1 Mon Sep 17 00:00:00 2001 From: Dmitrii Stepanov Date: Thu, 21 Dec 2023 11:13:22 +0300 Subject: [PATCH] [#35] local_storage: Make thread safe Signed-off-by: Dmitrii Stepanov --- pkg/engine/inmemory/local_storage.go | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/pkg/engine/inmemory/local_storage.go b/pkg/engine/inmemory/local_storage.go index 4055657..21d3d55 100644 --- a/pkg/engine/inmemory/local_storage.go +++ b/pkg/engine/inmemory/local_storage.go @@ -4,6 +4,7 @@ import ( "fmt" "math/rand" "strings" + "sync" "git.frostfs.info/TrueCloudLab/policy-engine/pkg/chain" "git.frostfs.info/TrueCloudLab/policy-engine/pkg/engine" @@ -15,12 +16,14 @@ type targetToChain map[engine.Target][]*chain.Chain type inmemoryLocalStorage struct { usedChainID map[chain.ID]struct{} nameToResourceChains map[chain.Name]targetToChain + guard *sync.RWMutex } func NewInmemoryLocalStorage() engine.LocalOverrideStorage { return &inmemoryLocalStorage{ usedChainID: map[chain.ID]struct{}{}, nameToResourceChains: make(map[chain.Name]targetToChain), + guard: &sync.RWMutex{}, } } @@ -44,6 +47,9 @@ func (s *inmemoryLocalStorage) generateChainID(name chain.Name, target engine.Ta } func (s *inmemoryLocalStorage) AddOverride(name chain.Name, target engine.Target, c *chain.Chain) (chain.ID, error) { + s.guard.Lock() + defer s.guard.Unlock() + // AddOverride assigns generated chain ID if it has not been assigned. if c.ID == "" { c.ID = s.generateChainID(name, target) @@ -63,6 +69,9 @@ func (s *inmemoryLocalStorage) AddOverride(name chain.Name, target engine.Target } func (s *inmemoryLocalStorage) GetOverride(name chain.Name, target engine.Target, chainID chain.ID) (*chain.Chain, error) { + s.guard.RLock() + defer s.guard.RUnlock() + if _, ok := s.nameToResourceChains[name]; !ok { return nil, engine.ErrChainNameNotFound } @@ -79,6 +88,9 @@ func (s *inmemoryLocalStorage) GetOverride(name chain.Name, target engine.Target } func (s *inmemoryLocalStorage) RemoveOverride(name chain.Name, target engine.Target, chainID chain.ID) error { + s.guard.Lock() + defer s.guard.Unlock() + if _, ok := s.nameToResourceChains[name]; !ok { return engine.ErrChainNameNotFound } @@ -96,6 +108,9 @@ func (s *inmemoryLocalStorage) RemoveOverride(name chain.Name, target engine.Tar } func (s *inmemoryLocalStorage) ListOverrides(name chain.Name, target engine.Target) ([]*chain.Chain, error) { + s.guard.RLock() + defer s.guard.RUnlock() + rcs, ok := s.nameToResourceChains[name] if !ok { return []*chain.Chain{}, nil @@ -113,6 +128,9 @@ func (s *inmemoryLocalStorage) ListOverrides(name chain.Name, target engine.Targ } func (s *inmemoryLocalStorage) DropAllOverrides(name chain.Name) error { + s.guard.Lock() + defer s.guard.Unlock() + s.nameToResourceChains[name] = make(targetToChain) return nil }