From 54145916a9a2a3dabb2dcc76e6fc8271c5673806 Mon Sep 17 00:00:00 2001 From: Denis Kirillov Date: Mon, 18 Jul 2022 11:42:29 +0300 Subject: [PATCH] [#283] pool: Change batch of atomics to mutex Signed-off-by: Denis Kirillov --- pool/mock_test.go | 2 +- pool/pool.go | 73 ++++++++++++++++++++++++----------------------- pool/pool_test.go | 22 ++++++-------- pool/statistic.go | 6 ++-- 4 files changed, 50 insertions(+), 53 deletions(-) diff --git a/pool/mock_test.go b/pool/mock_test.go index 758b8a9..0286f6d 100644 --- a/pool/mock_test.go +++ b/pool/mock_test.go @@ -21,7 +21,7 @@ import ( type mockClient struct { key ecdsa.PrivateKey - *clientStatusMonitor + clientStatusMonitor errorOnCreateSession bool errorOnEndpointInfo bool diff --git a/pool/pool.go b/pool/pool.go index c12db01..88533f6 100644 --- a/pool/pool.go +++ b/pool/pool.go @@ -60,26 +60,27 @@ type clientStatus interface { address() string currentErrorRate() uint32 overallErrorRate() uint64 - resetErrorCounter() latency() time.Duration requests() uint64 } type clientStatusMonitor struct { - addr string - healthy *atomic.Bool - currentErrorCount *atomic.Uint32 - overallErrorCount *atomic.Uint64 - errorThreshold uint32 - allTime *atomic.Uint64 - allRequests *atomic.Uint64 + addr string + healthy *atomic.Bool + errorThreshold uint32 + + mu sync.RWMutex // protect counters + currentErrorCount uint32 + overallErrorCount uint64 + allTime uint64 + allRequests uint64 } // clientWrapper is used by default, alternative implementations are intended for testing purposes only. type clientWrapper struct { client sdkClient.Client key ecdsa.PrivateKey - *clientStatusMonitor + clientStatusMonitor } type wrapperPrm struct { @@ -112,20 +113,15 @@ func (x *wrapperPrm) setResponseInfoCallback(f func(sdkClient.ResponseMetaInfo) func newWrapper(prm wrapperPrm) (*clientWrapper, error) { var prmInit sdkClient.PrmInit - //prmInit.ResolveNeoFSFailures() prmInit.SetDefaultPrivateKey(prm.key) prmInit.SetResponseInfoCallback(prm.responseInfoCallback) res := &clientWrapper{ key: prm.key, - clientStatusMonitor: &clientStatusMonitor{ - addr: prm.address, - healthy: atomic.NewBool(true), - currentErrorCount: atomic.NewUint32(0), - overallErrorCount: atomic.NewUint64(0), - errorThreshold: prm.errorThreshold, - allTime: atomic.NewUint64(0), - allRequests: atomic.NewUint64(0), + clientStatusMonitor: clientStatusMonitor{ + addr: prm.address, + healthy: atomic.NewBool(true), + errorThreshold: prm.errorThreshold, }, } @@ -612,41 +608,48 @@ func (c *clientStatusMonitor) address() string { } func (c *clientStatusMonitor) incErrorRate() { - c.currentErrorCount.Inc() - c.overallErrorCount.Inc() - if c.currentErrorCount.Load() >= c.errorThreshold { + c.mu.Lock() + defer c.mu.Unlock() + c.currentErrorCount++ + c.overallErrorCount++ + if c.currentErrorCount >= c.errorThreshold { c.setHealthy(false) - c.resetErrorCounter() + c.currentErrorCount = 0 } } func (c *clientStatusMonitor) currentErrorRate() uint32 { - return c.currentErrorCount.Load() + c.mu.RLock() + defer c.mu.RUnlock() + return c.currentErrorCount } func (c *clientStatusMonitor) overallErrorRate() uint64 { - return c.overallErrorCount.Load() -} - -func (c *clientStatusMonitor) resetErrorCounter() { - c.currentErrorCount.Store(0) + c.mu.RLock() + defer c.mu.RUnlock() + return c.overallErrorCount } func (c *clientStatusMonitor) latency() time.Duration { - allRequests := c.requests() - if allRequests == 0 { + c.mu.RLock() + defer c.mu.RUnlock() + if c.allRequests == 0 { return 0 } - return time.Duration(c.allTime.Load() / allRequests) + return time.Duration(c.allTime / c.allRequests) } func (c *clientStatusMonitor) requests() uint64 { - return c.allRequests.Load() + c.mu.RLock() + defer c.mu.RUnlock() + return c.allRequests } func (c *clientStatusMonitor) incRequests(elapsed time.Duration) { - c.allTime.Add(uint64(elapsed)) - c.allRequests.Inc() + c.mu.Lock() + defer c.mu.Unlock() + c.allTime += uint64(elapsed) + c.allRequests++ } func (c *clientStatusMonitor) handleError(st apistatus.Status, err error) error { @@ -1971,7 +1974,7 @@ func (p Pool) Statistic() Statistic { for _, inner := range p.innerPools { inner.lock.RLock() for _, cl := range inner.clients { - node := &NodeStatistic{ + node := NodeStatistic{ address: cl.address(), latency: cl.latency(), requests: cl.requests(), diff --git a/pool/pool_test.go b/pool/pool_test.go index ec1053b..9a7f7bb 100644 --- a/pool/pool_test.go +++ b/pool/pool_test.go @@ -508,31 +508,25 @@ func TestWaitPresence(t *testing.T) { }) } -func newTestStatusMonitor(addr string) *clientStatusMonitor { - return &clientStatusMonitor{ - addr: addr, - healthy: atomic.NewBool(true), - currentErrorCount: atomic.NewUint32(0), - overallErrorCount: atomic.NewUint64(0), - errorThreshold: 10, - allTime: atomic.NewUint64(0), - allRequests: atomic.NewUint64(0), +func newTestStatusMonitor(addr string) clientStatusMonitor { + return clientStatusMonitor{ + addr: addr, + healthy: atomic.NewBool(true), + errorThreshold: 10, } } func TestStatusMonitor(t *testing.T) { monitor := newTestStatusMonitor("") + monitor.errorThreshold = 3 count := 10 - for i := 0; i < 10; i++ { + for i := 0; i < count; i++ { monitor.incErrorRate() - if i%3 == 0 { - monitor.resetErrorCounter() - } } require.Equal(t, uint64(count), monitor.overallErrorRate()) - require.Equal(t, uint32(0), monitor.currentErrorRate()) + require.Equal(t, uint32(1), monitor.currentErrorRate()) } func TestHandleError(t *testing.T) { diff --git a/pool/statistic.go b/pool/statistic.go index 89e228e..d8f3938 100644 --- a/pool/statistic.go +++ b/pool/statistic.go @@ -8,7 +8,7 @@ import ( // Statistic is metrics of the pool. type Statistic struct { overallErrors uint64 - nodes []*NodeStatistic + nodes []NodeStatistic } // OverallErrors returns sum of errors on all connections. It doesn't decrease. @@ -17,7 +17,7 @@ func (s Statistic) OverallErrors() uint64 { } // Nodes returns list of nodes statistic. -func (s Statistic) Nodes() []*NodeStatistic { +func (s Statistic) Nodes() []NodeStatistic { return s.nodes } @@ -29,7 +29,7 @@ var ErrUnknownNode = errors.New("unknown node") func (s Statistic) Node(address string) (*NodeStatistic, error) { for i := range s.nodes { if s.nodes[i].address == address { - return s.nodes[i], nil + return &s.nodes[i], nil } }