mirror of
https://github.com/nspcc-dev/neo-go.git
synced 2025-01-22 19:43:46 +00:00
rpc: protect supscriptions of RPC WSClient from concurrent access
This commit is contained in:
parent
d77f188d10
commit
5b2e88b916
1 changed files with 25 additions and 8 deletions
|
@ -4,6 +4,7 @@ import (
|
|||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
|
@ -30,12 +31,14 @@ type WSClient struct {
|
|||
// be closed, so make sure to handle this.
|
||||
Notifications chan Notification
|
||||
|
||||
ws *websocket.Conn
|
||||
done chan struct{}
|
||||
responses chan *response.Raw
|
||||
requests chan *request.Raw
|
||||
shutdown chan struct{}
|
||||
subscriptions map[string]bool
|
||||
ws *websocket.Conn
|
||||
done chan struct{}
|
||||
responses chan *response.Raw
|
||||
requests chan *request.Raw
|
||||
shutdown chan struct{}
|
||||
|
||||
subscriptionsLock sync.RWMutex
|
||||
subscriptions map[string]bool
|
||||
}
|
||||
|
||||
// Notification represents server-generated notification for client subscriptions.
|
||||
|
@ -237,6 +240,10 @@ func (c *WSClient) performSubscription(params request.RawParams) (string, error)
|
|||
if err := c.performRequest("subscribe", params, &resp); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
c.subscriptionsLock.Lock()
|
||||
defer c.subscriptionsLock.Unlock()
|
||||
|
||||
c.subscriptions[resp] = true
|
||||
return resp, nil
|
||||
}
|
||||
|
@ -244,6 +251,9 @@ func (c *WSClient) performSubscription(params request.RawParams) (string, error)
|
|||
func (c *WSClient) performUnsubscription(id string) error {
|
||||
var resp bool
|
||||
|
||||
c.subscriptionsLock.Lock()
|
||||
defer c.subscriptionsLock.Unlock()
|
||||
|
||||
if !c.subscriptions[id] {
|
||||
return errors.New("no subscription with this ID")
|
||||
}
|
||||
|
@ -325,11 +335,18 @@ func (c *WSClient) Unsubscribe(id string) error {
|
|||
|
||||
// UnsubscribeAll removes all active subscriptions of current client.
|
||||
func (c *WSClient) UnsubscribeAll() error {
|
||||
c.subscriptionsLock.Lock()
|
||||
defer c.subscriptionsLock.Unlock()
|
||||
|
||||
for id := range c.subscriptions {
|
||||
err := c.performUnsubscription(id)
|
||||
if err != nil {
|
||||
var resp bool
|
||||
if err := c.performRequest("unsubscribe", request.NewRawParams(id), &resp); err != nil {
|
||||
return err
|
||||
}
|
||||
if !resp {
|
||||
return errors.New("unsubscribe method returned false result")
|
||||
}
|
||||
delete(c.subscriptions, id)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue