rpc: make RPC Client thread-safe

This commit is contained in:
AnnaShaleva 2022-02-21 16:06:43 +03:00 committed by Anna Shaleva
parent 0092330fe1
commit 408f6b050e
2 changed files with 21 additions and 0 deletions

View file

@ -9,6 +9,7 @@ import (
"net"
"net/http"
"net/url"
"sync"
"time"
"github.com/nspcc-dev/neo-go/pkg/config/netmode"
@ -34,6 +35,7 @@ type Client struct {
opts Options
requestF func(*request.Raw) (*response.Raw, error)
cacheLock sync.RWMutex
// cache stores RPC node related information client is bound to.
// cache is mostly filled in during Init(), but can also be updated
// during regular Client lifecycle.
@ -123,6 +125,10 @@ func (c *Client) Init() error {
if err != nil {
return fmt.Errorf("failed to get network magic: %w", err)
}
c.cacheLock.Lock()
defer c.cacheLock.Unlock()
c.cache.network = version.Protocol.Network
c.cache.stateRootInHeader = version.Protocol.StateRootInHeader
if version.Protocol.MillisecondsPerBlock == 0 {

View file

@ -924,18 +924,23 @@ func (c *Client) CalculateValidUntilBlock() (uint32, error) {
return result, fmt.Errorf("can't get block count: %w", err)
}
c.cacheLock.RLock()
if c.cache.calculateValidUntilBlock.expiresAt > blockCount {
validatorsCount = c.cache.calculateValidUntilBlock.validatorsCount
c.cacheLock.RUnlock()
} else {
c.cacheLock.RUnlock()
validators, err := c.GetNextBlockValidators()
if err != nil {
return result, fmt.Errorf("can't get validators: %w", err)
}
validatorsCount = uint32(len(validators))
c.cacheLock.Lock()
c.cache.calculateValidUntilBlock = calculateValidUntilBlockCache{
validatorsCount: validatorsCount,
expiresAt: blockCount + cacheTimeout,
}
c.cacheLock.Unlock()
}
return blockCount + validatorsCount + 1, nil
}
@ -993,6 +998,9 @@ func (c *Client) AddNetworkFee(tx *transaction.Transaction, extraFee int64, accs
// GetNetwork returns the network magic of the RPC node client connected to.
func (c *Client) GetNetwork() (netmode.Magic, error) {
c.cacheLock.RLock()
defer c.cacheLock.RUnlock()
if !c.cache.initDone {
return 0, errNetworkNotInitialized
}
@ -1002,6 +1010,9 @@ func (c *Client) GetNetwork() (netmode.Magic, error) {
// StateRootInHeader returns true if state root is contained in block header.
// You should initialize Client cache with Init() before calling StateRootInHeader.
func (c *Client) StateRootInHeader() (bool, error) {
c.cacheLock.RLock()
defer c.cacheLock.RUnlock()
if !c.cache.initDone {
return false, errNetworkNotInitialized
}
@ -1010,7 +1021,9 @@ func (c *Client) StateRootInHeader() (bool, error) {
// GetNativeContractHash returns native contract hash by its name.
func (c *Client) GetNativeContractHash(name string) (util.Uint160, error) {
c.cacheLock.RLock()
hash, ok := c.cache.nativeHashes[name]
c.cacheLock.RUnlock()
if ok {
return hash, nil
}
@ -1018,6 +1031,8 @@ func (c *Client) GetNativeContractHash(name string) (util.Uint160, error) {
if err != nil {
return util.Uint160{}, err
}
c.cacheLock.Lock()
c.cache.nativeHashes[name] = cs.Hash
c.cacheLock.Unlock()
return cs.Hash, nil
}