rpc: protect supscriptions of RPC WSClient from concurrent access

This commit is contained in:
AnnaShaleva 2022-02-21 16:40:50 +03:00 committed by Anna Shaleva
parent d77f188d10
commit 5b2e88b916

View file

@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"errors"
"sync"
"time"
"github.com/gorilla/websocket"
@ -35,6 +36,8 @@ type WSClient struct {
responses chan *response.Raw
requests chan *request.Raw
shutdown chan struct{}
subscriptionsLock sync.RWMutex
subscriptions map[string]bool
}
@ -237,6 +240,10 @@ func (c *WSClient) performSubscription(params request.RawParams) (string, error)
if err := c.performRequest("subscribe", params, &resp); err != nil {
return "", err
}
c.subscriptionsLock.Lock()
defer c.subscriptionsLock.Unlock()
c.subscriptions[resp] = true
return resp, nil
}
@ -244,6 +251,9 @@ func (c *WSClient) performSubscription(params request.RawParams) (string, error)
func (c *WSClient) performUnsubscription(id string) error {
var resp bool
c.subscriptionsLock.Lock()
defer c.subscriptionsLock.Unlock()
if !c.subscriptions[id] {
return errors.New("no subscription with this ID")
}
@ -325,11 +335,18 @@ func (c *WSClient) Unsubscribe(id string) error {
// UnsubscribeAll removes all active subscriptions of current client.
func (c *WSClient) UnsubscribeAll() error {
c.subscriptionsLock.Lock()
defer c.subscriptionsLock.Unlock()
for id := range c.subscriptions {
err := c.performUnsubscription(id)
if err != nil {
var resp bool
if err := c.performRequest("unsubscribe", request.NewRawParams(id), &resp); err != nil {
return err
}
if !resp {
return errors.New("unsubscribe method returned false result")
}
delete(c.subscriptions, id)
}
return nil
}