rpc: make RPC WSClient thread-safe
Add ability to use unique request IDs for RPC requests.
This commit is contained in:
parent
5b2e88b916
commit
8991ee91cd
5 changed files with 93 additions and 20 deletions
|
@ -4,6 +4,7 @@ import (
|
|||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
|
@ -31,14 +32,16 @@ type WSClient struct {
|
|||
// be closed, so make sure to handle this.
|
||||
Notifications chan Notification
|
||||
|
||||
ws *websocket.Conn
|
||||
done chan struct{}
|
||||
responses chan *response.Raw
|
||||
requests chan *request.Raw
|
||||
shutdown chan struct{}
|
||||
ws *websocket.Conn
|
||||
done chan struct{}
|
||||
requests chan *request.Raw
|
||||
shutdown chan struct{}
|
||||
|
||||
subscriptionsLock sync.RWMutex
|
||||
subscriptions map[string]bool
|
||||
|
||||
respLock sync.RWMutex
|
||||
respChannels map[uint64]chan *response.Raw
|
||||
}
|
||||
|
||||
// Notification represents server-generated notification for client subscriptions.
|
||||
|
@ -95,7 +98,7 @@ func NewWS(ctx context.Context, endpoint string, opts Options) (*WSClient, error
|
|||
ws: ws,
|
||||
shutdown: make(chan struct{}),
|
||||
done: make(chan struct{}),
|
||||
responses: make(chan *response.Raw),
|
||||
respChannels: make(map[uint64]chan *response.Raw),
|
||||
requests: make(chan *request.Raw),
|
||||
subscriptions: make(map[string]bool),
|
||||
}
|
||||
|
@ -178,14 +181,27 @@ readloop:
|
|||
resp.JSONRPC = rr.JSONRPC
|
||||
resp.Error = rr.Error
|
||||
resp.Result = rr.Result
|
||||
c.responses <- resp
|
||||
id, err := strconv.Atoi(string(resp.ID))
|
||||
if err != nil {
|
||||
break // Malformed response (invalid response ID).
|
||||
}
|
||||
ch := c.getResponseChannel(uint64(id))
|
||||
if ch == nil {
|
||||
break // Unknown response (unexpected response ID).
|
||||
}
|
||||
ch <- resp
|
||||
} else {
|
||||
// Malformed response, neither valid request, nor valid response.
|
||||
break
|
||||
}
|
||||
}
|
||||
close(c.done)
|
||||
close(c.responses)
|
||||
c.respLock.Lock()
|
||||
for _, ch := range c.respChannels {
|
||||
close(ch)
|
||||
}
|
||||
c.respChannels = nil
|
||||
c.respLock.Unlock()
|
||||
close(c.Notifications)
|
||||
}
|
||||
|
||||
|
@ -220,16 +236,41 @@ func (c *WSClient) wsWriter() {
|
|||
}
|
||||
}
|
||||
|
||||
func (c *WSClient) registerRespChannel(id uint64, ch chan *response.Raw) {
|
||||
c.respLock.Lock()
|
||||
defer c.respLock.Unlock()
|
||||
c.respChannels[id] = ch
|
||||
}
|
||||
|
||||
func (c *WSClient) unregisterRespChannel(id uint64) {
|
||||
c.respLock.Lock()
|
||||
defer c.respLock.Unlock()
|
||||
if ch, ok := c.respChannels[id]; ok {
|
||||
delete(c.respChannels, id)
|
||||
close(ch)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *WSClient) getResponseChannel(id uint64) chan *response.Raw {
|
||||
c.respLock.RLock()
|
||||
defer c.respLock.RUnlock()
|
||||
return c.respChannels[id]
|
||||
}
|
||||
|
||||
func (c *WSClient) makeWsRequest(r *request.Raw) (*response.Raw, error) {
|
||||
ch := make(chan *response.Raw)
|
||||
c.registerRespChannel(r.ID, ch)
|
||||
|
||||
select {
|
||||
case <-c.done:
|
||||
return nil, errors.New("connection lost")
|
||||
return nil, errors.New("connection lost before sending the request")
|
||||
case c.requests <- r:
|
||||
}
|
||||
select {
|
||||
case <-c.done:
|
||||
return nil, errors.New("connection lost")
|
||||
case resp := <-c.responses:
|
||||
return nil, errors.New("connection lost while waiting for the response")
|
||||
case resp := <-ch:
|
||||
c.unregisterRespChannel(r.ID)
|
||||
return resp, nil
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue