From 408f6b050edd4d85035d6dce13232ce1ad2bd0bf Mon Sep 17 00:00:00 2001 From: AnnaShaleva Date: Mon, 21 Feb 2022 16:06:43 +0300 Subject: [PATCH] rpc: make RPC Client thread-safe --- pkg/rpc/client/client.go | 6 ++++++ pkg/rpc/client/rpc.go | 15 +++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/pkg/rpc/client/client.go b/pkg/rpc/client/client.go index 4f41bd10b..4dfc40ef4 100644 --- a/pkg/rpc/client/client.go +++ b/pkg/rpc/client/client.go @@ -9,6 +9,7 @@ import ( "net" "net/http" "net/url" + "sync" "time" "github.com/nspcc-dev/neo-go/pkg/config/netmode" @@ -34,6 +35,7 @@ type Client struct { opts Options requestF func(*request.Raw) (*response.Raw, error) + cacheLock sync.RWMutex // cache stores RPC node related information client is bound to. // cache is mostly filled in during Init(), but can also be updated // during regular Client lifecycle. @@ -123,6 +125,10 @@ func (c *Client) Init() error { if err != nil { return fmt.Errorf("failed to get network magic: %w", err) } + + c.cacheLock.Lock() + defer c.cacheLock.Unlock() + c.cache.network = version.Protocol.Network c.cache.stateRootInHeader = version.Protocol.StateRootInHeader if version.Protocol.MillisecondsPerBlock == 0 { diff --git a/pkg/rpc/client/rpc.go b/pkg/rpc/client/rpc.go index d00ede814..428cb30bb 100644 --- a/pkg/rpc/client/rpc.go +++ b/pkg/rpc/client/rpc.go @@ -924,18 +924,23 @@ func (c *Client) CalculateValidUntilBlock() (uint32, error) { return result, fmt.Errorf("can't get block count: %w", err) } + c.cacheLock.RLock() if c.cache.calculateValidUntilBlock.expiresAt > blockCount { validatorsCount = c.cache.calculateValidUntilBlock.validatorsCount + c.cacheLock.RUnlock() } else { + c.cacheLock.RUnlock() validators, err := c.GetNextBlockValidators() if err != nil { return result, fmt.Errorf("can't get validators: %w", err) } validatorsCount = uint32(len(validators)) + c.cacheLock.Lock() c.cache.calculateValidUntilBlock = calculateValidUntilBlockCache{ validatorsCount: validatorsCount, expiresAt: blockCount + cacheTimeout, } + c.cacheLock.Unlock() } return blockCount + validatorsCount + 1, nil } @@ -993,6 +998,9 @@ func (c *Client) AddNetworkFee(tx *transaction.Transaction, extraFee int64, accs // GetNetwork returns the network magic of the RPC node client connected to. func (c *Client) GetNetwork() (netmode.Magic, error) { + c.cacheLock.RLock() + defer c.cacheLock.RUnlock() + if !c.cache.initDone { return 0, errNetworkNotInitialized } @@ -1002,6 +1010,9 @@ func (c *Client) GetNetwork() (netmode.Magic, error) { // StateRootInHeader returns true if state root is contained in block header. // You should initialize Client cache with Init() before calling StateRootInHeader. func (c *Client) StateRootInHeader() (bool, error) { + c.cacheLock.RLock() + defer c.cacheLock.RUnlock() + if !c.cache.initDone { return false, errNetworkNotInitialized } @@ -1010,7 +1021,9 @@ func (c *Client) StateRootInHeader() (bool, error) { // GetNativeContractHash returns native contract hash by its name. func (c *Client) GetNativeContractHash(name string) (util.Uint160, error) { + c.cacheLock.RLock() hash, ok := c.cache.nativeHashes[name] + c.cacheLock.RUnlock() if ok { return hash, nil } @@ -1018,6 +1031,8 @@ func (c *Client) GetNativeContractHash(name string) (util.Uint160, error) { if err != nil { return util.Uint160{}, err } + c.cacheLock.Lock() c.cache.nativeHashes[name] = cs.Hash + c.cacheLock.Unlock() return cs.Hash, nil }