rpc: make RPC WSClient thread-safe

Add ability to use unique request IDs for RPC requests.
This commit is contained in:
AnnaShaleva 2022-02-18 20:28:13 +03:00
parent 5b2e88b916
commit 8991ee91cd
5 changed files with 93 additions and 20 deletions

View file

@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"errors"
"strconv"
"sync"
"time"
@ -31,14 +32,16 @@ type WSClient struct {
// be closed, so make sure to handle this.
Notifications chan Notification
ws *websocket.Conn
done chan struct{}
responses chan *response.Raw
requests chan *request.Raw
shutdown chan struct{}
ws *websocket.Conn
done chan struct{}
requests chan *request.Raw
shutdown chan struct{}
subscriptionsLock sync.RWMutex
subscriptions map[string]bool
respLock sync.RWMutex
respChannels map[uint64]chan *response.Raw
}
// Notification represents server-generated notification for client subscriptions.
@ -95,7 +98,7 @@ func NewWS(ctx context.Context, endpoint string, opts Options) (*WSClient, error
ws: ws,
shutdown: make(chan struct{}),
done: make(chan struct{}),
responses: make(chan *response.Raw),
respChannels: make(map[uint64]chan *response.Raw),
requests: make(chan *request.Raw),
subscriptions: make(map[string]bool),
}
@ -178,14 +181,27 @@ readloop:
resp.JSONRPC = rr.JSONRPC
resp.Error = rr.Error
resp.Result = rr.Result
c.responses <- resp
id, err := strconv.Atoi(string(resp.ID))
if err != nil {
break // Malformed response (invalid response ID).
}
ch := c.getResponseChannel(uint64(id))
if ch == nil {
break // Unknown response (unexpected response ID).
}
ch <- resp
} else {
// Malformed response, neither valid request, nor valid response.
break
}
}
close(c.done)
close(c.responses)
c.respLock.Lock()
for _, ch := range c.respChannels {
close(ch)
}
c.respChannels = nil
c.respLock.Unlock()
close(c.Notifications)
}
@ -220,16 +236,41 @@ func (c *WSClient) wsWriter() {
}
}
func (c *WSClient) registerRespChannel(id uint64, ch chan *response.Raw) {
c.respLock.Lock()
defer c.respLock.Unlock()
c.respChannels[id] = ch
}
func (c *WSClient) unregisterRespChannel(id uint64) {
c.respLock.Lock()
defer c.respLock.Unlock()
if ch, ok := c.respChannels[id]; ok {
delete(c.respChannels, id)
close(ch)
}
}
func (c *WSClient) getResponseChannel(id uint64) chan *response.Raw {
c.respLock.RLock()
defer c.respLock.RUnlock()
return c.respChannels[id]
}
func (c *WSClient) makeWsRequest(r *request.Raw) (*response.Raw, error) {
ch := make(chan *response.Raw)
c.registerRespChannel(r.ID, ch)
select {
case <-c.done:
return nil, errors.New("connection lost")
return nil, errors.New("connection lost before sending the request")
case c.requests <- r:
}
select {
case <-c.done:
return nil, errors.New("connection lost")
case resp := <-c.responses:
return nil, errors.New("connection lost while waiting for the response")
case resp := <-ch:
c.unregisterRespChannel(r.ID)
return resp, nil
}
}