Merge pull request #3142 from Ayrtat/fix/wsclient_hang

rpcclient: fix wsclient hang on making request
This commit is contained in:
Roman Khimov 2023-10-16 15:26:13 +03:00 committed by GitHub
commit 38f77c39d0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 32 additions and 14 deletions

View file

@ -29,7 +29,8 @@ func NewInternal(ctx context.Context, register InternalHook) (*Internal, error)
Client: Client{}, Client: Client{},
shutdown: make(chan struct{}), shutdown: make(chan struct{}),
done: make(chan struct{}), readerDone: make(chan struct{}),
writerDone: make(chan struct{}),
subscriptions: make(map[string]notificationReceiver), subscriptions: make(map[string]notificationReceiver),
receivers: make(map[any][]string), receivers: make(map[any][]string),
}, },
@ -63,7 +64,7 @@ eventloop:
c.notifySubscribers(ntf) c.notifySubscribers(ntf)
} }
} }
close(c.done) close(c.readerDone)
c.ctxCancel() c.ctxCancel()
// ctx is cancelled, server is notified and will finish soon. // ctx is cancelled, server is notified and will finish soon.
drainloop: drainloop:

View file

@ -58,7 +58,8 @@ type WSClient struct {
ws *websocket.Conn ws *websocket.Conn
wsOpts WSOptions wsOpts WSOptions
done chan struct{} readerDone chan struct{}
writerDone chan struct{}
requests chan *neorpc.Request requests chan *neorpc.Request
shutdown chan struct{} shutdown chan struct{}
closeCalled atomic.Bool closeCalled atomic.Bool
@ -425,7 +426,8 @@ func NewWS(ctx context.Context, endpoint string, opts WSOptions) (*WSClient, err
ws: ws, ws: ws,
wsOpts: opts, wsOpts: opts,
shutdown: make(chan struct{}), shutdown: make(chan struct{}),
done: make(chan struct{}), readerDone: make(chan struct{}),
writerDone: make(chan struct{}),
respChannels: make(map[uint64]chan *neorpc.Response), respChannels: make(map[uint64]chan *neorpc.Response),
requests: make(chan *neorpc.Request), requests: make(chan *neorpc.Request),
subscriptions: make(map[string]notificationReceiver), subscriptions: make(map[string]notificationReceiver),
@ -457,7 +459,7 @@ func (c *WSClient) Close() {
// Call to cancel will send signal to all users of Context(). // Call to cancel will send signal to all users of Context().
c.Client.ctxCancel() c.Client.ctxCancel()
} }
<-c.done <-c.readerDone
} }
func (c *WSClient) wsReader() { func (c *WSClient) wsReader() {
@ -551,7 +553,7 @@ readloop:
if connCloseErr != nil { if connCloseErr != nil {
c.setCloseErr(connCloseErr) c.setCloseErr(connCloseErr)
} }
close(c.done) close(c.readerDone)
c.respLock.Lock() c.respLock.Lock()
for _, ch := range c.respChannels { for _, ch := range c.respChannels {
close(ch) close(ch)
@ -583,13 +585,14 @@ func (c *WSClient) wsWriter() {
pingTicker := time.NewTicker(wsPingPeriod) pingTicker := time.NewTicker(wsPingPeriod)
defer c.ws.Close() defer c.ws.Close()
defer pingTicker.Stop() defer pingTicker.Stop()
defer close(c.writerDone)
var connCloseErr error var connCloseErr error
writeloop: writeloop:
for { for {
select { select {
case <-c.shutdown: case <-c.shutdown:
return return
case <-c.done: case <-c.readerDone:
return return
case req, ok := <-c.requests: case req, ok := <-c.requests:
if !ok { if !ok {
@ -660,28 +663,42 @@ func (c *WSClient) getResponseChannel(id uint64) chan *neorpc.Response {
return c.respChannels[id] return c.respChannels[id]
} }
// closeErrOrConnLost returns the error that may occur either in wsReader or wsWriter.
// If wsReader or wsWriter do not set the error, it returns ErrWSConnLost.
func (c *WSClient) closeErrOrConnLost() (err error) {
err = ErrWSConnLost
if closeErr := c.GetError(); closeErr != nil {
err = closeErr
}
return
}
func (c *WSClient) makeWsRequest(r *neorpc.Request) (*neorpc.Response, error) { func (c *WSClient) makeWsRequest(r *neorpc.Request) (*neorpc.Response, error) {
ch := make(chan *neorpc.Response) ch := make(chan *neorpc.Response)
c.respLock.Lock() c.respLock.Lock()
select { select {
case <-c.done: case <-c.readerDone:
c.respLock.Unlock() c.respLock.Unlock()
return nil, fmt.Errorf("%w: before registering response channel", ErrWSConnLost) return nil, fmt.Errorf("%w: before registering response channel", c.closeErrOrConnLost())
default: default:
c.respChannels[r.ID] = ch c.respChannels[r.ID] = ch
c.respLock.Unlock() c.respLock.Unlock()
} }
select { select {
case <-c.done: case <-c.readerDone:
return nil, fmt.Errorf("%w: before sending the request", ErrWSConnLost) return nil, fmt.Errorf("%w: before sending the request", c.closeErrOrConnLost())
case <-c.writerDone:
return nil, fmt.Errorf("%w: before sending the request", c.closeErrOrConnLost())
case c.requests <- r: case c.requests <- r:
} }
select { select {
case <-c.done: case <-c.readerDone:
return nil, fmt.Errorf("%w: while waiting for the response", ErrWSConnLost) return nil, fmt.Errorf("%w: while waiting for the response", c.closeErrOrConnLost())
case <-c.writerDone:
return nil, fmt.Errorf("%w: while waiting for the response", c.closeErrOrConnLost())
case resp, ok := <-ch: case resp, ok := <-ch:
if !ok { if !ok {
return nil, fmt.Errorf("%w: while waiting for the response", ErrWSConnLost) return nil, fmt.Errorf("%w: while waiting for the response", c.closeErrOrConnLost())
} }
c.unregisterRespChannel(r.ID) c.unregisterRespChannel(r.ID)
return resp, nil return resp, nil