rpc: protect supscriptions of RPC WSClient from concurrent access

This commit is contained in:
AnnaShaleva 2022-02-21 16:40:50 +03:00 committed by Anna Shaleva
parent d77f188d10
commit 5b2e88b916

View file

@ -4,6 +4,7 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"errors" "errors"
"sync"
"time" "time"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
@ -35,6 +36,8 @@ type WSClient struct {
responses chan *response.Raw responses chan *response.Raw
requests chan *request.Raw requests chan *request.Raw
shutdown chan struct{} shutdown chan struct{}
subscriptionsLock sync.RWMutex
subscriptions map[string]bool subscriptions map[string]bool
} }
@ -237,6 +240,10 @@ func (c *WSClient) performSubscription(params request.RawParams) (string, error)
if err := c.performRequest("subscribe", params, &resp); err != nil { if err := c.performRequest("subscribe", params, &resp); err != nil {
return "", err return "", err
} }
c.subscriptionsLock.Lock()
defer c.subscriptionsLock.Unlock()
c.subscriptions[resp] = true c.subscriptions[resp] = true
return resp, nil return resp, nil
} }
@ -244,6 +251,9 @@ func (c *WSClient) performSubscription(params request.RawParams) (string, error)
func (c *WSClient) performUnsubscription(id string) error { func (c *WSClient) performUnsubscription(id string) error {
var resp bool var resp bool
c.subscriptionsLock.Lock()
defer c.subscriptionsLock.Unlock()
if !c.subscriptions[id] { if !c.subscriptions[id] {
return errors.New("no subscription with this ID") return errors.New("no subscription with this ID")
} }
@ -325,11 +335,18 @@ func (c *WSClient) Unsubscribe(id string) error {
// UnsubscribeAll removes all active subscriptions of current client. // UnsubscribeAll removes all active subscriptions of current client.
func (c *WSClient) UnsubscribeAll() error { func (c *WSClient) UnsubscribeAll() error {
c.subscriptionsLock.Lock()
defer c.subscriptionsLock.Unlock()
for id := range c.subscriptions { for id := range c.subscriptions {
err := c.performUnsubscription(id) var resp bool
if err != nil { if err := c.performRequest("unsubscribe", request.NewRawParams(id), &resp); err != nil {
return err return err
} }
if !resp {
return errors.New("unsubscribe method returned false result")
}
delete(c.subscriptions, id)
} }
return nil return nil
} }