From 5b2e88b91639cb2284e4513128ff9f85f81a9d19 Mon Sep 17 00:00:00 2001 From: AnnaShaleva Date: Mon, 21 Feb 2022 16:40:50 +0300 Subject: [PATCH] rpc: protect supscriptions of RPC WSClient from concurrent access --- pkg/rpc/client/wsclient.go | 33 +++++++++++++++++++++++++-------- 1 file changed, 25 insertions(+), 8 deletions(-) diff --git a/pkg/rpc/client/wsclient.go b/pkg/rpc/client/wsclient.go index 5239d84b6..6cb5e6049 100644 --- a/pkg/rpc/client/wsclient.go +++ b/pkg/rpc/client/wsclient.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "errors" + "sync" "time" "github.com/gorilla/websocket" @@ -30,12 +31,14 @@ type WSClient struct { // be closed, so make sure to handle this. Notifications chan Notification - ws *websocket.Conn - done chan struct{} - responses chan *response.Raw - requests chan *request.Raw - shutdown chan struct{} - subscriptions map[string]bool + ws *websocket.Conn + done chan struct{} + responses chan *response.Raw + requests chan *request.Raw + shutdown chan struct{} + + subscriptionsLock sync.RWMutex + subscriptions map[string]bool } // Notification represents server-generated notification for client subscriptions. @@ -237,6 +240,10 @@ func (c *WSClient) performSubscription(params request.RawParams) (string, error) if err := c.performRequest("subscribe", params, &resp); err != nil { return "", err } + + c.subscriptionsLock.Lock() + defer c.subscriptionsLock.Unlock() + c.subscriptions[resp] = true return resp, nil } @@ -244,6 +251,9 @@ func (c *WSClient) performSubscription(params request.RawParams) (string, error) func (c *WSClient) performUnsubscription(id string) error { var resp bool + c.subscriptionsLock.Lock() + defer c.subscriptionsLock.Unlock() + if !c.subscriptions[id] { return errors.New("no subscription with this ID") } @@ -325,11 +335,18 @@ func (c *WSClient) Unsubscribe(id string) error { // UnsubscribeAll removes all active subscriptions of current client. func (c *WSClient) UnsubscribeAll() error { + c.subscriptionsLock.Lock() + defer c.subscriptionsLock.Unlock() + for id := range c.subscriptions { - err := c.performUnsubscription(id) - if err != nil { + var resp bool + if err := c.performRequest("unsubscribe", request.NewRawParams(id), &resp); err != nil { return err } + if !resp { + return errors.New("unsubscribe method returned false result") + } + delete(c.subscriptions, id) } return nil }