forked from TrueCloudLab/neoneo-go
Merge pull request #3142 from Ayrtat/fix/wsclient_hang
rpcclient: fix wsclient hang on making request
This commit is contained in:
commit
38f77c39d0
2 changed files with 32 additions and 14 deletions
|
@ -29,7 +29,8 @@ func NewInternal(ctx context.Context, register InternalHook) (*Internal, error)
|
||||||
Client: Client{},
|
Client: Client{},
|
||||||
|
|
||||||
shutdown: make(chan struct{}),
|
shutdown: make(chan struct{}),
|
||||||
done: make(chan struct{}),
|
readerDone: make(chan struct{}),
|
||||||
|
writerDone: make(chan struct{}),
|
||||||
subscriptions: make(map[string]notificationReceiver),
|
subscriptions: make(map[string]notificationReceiver),
|
||||||
receivers: make(map[any][]string),
|
receivers: make(map[any][]string),
|
||||||
},
|
},
|
||||||
|
@ -63,7 +64,7 @@ eventloop:
|
||||||
c.notifySubscribers(ntf)
|
c.notifySubscribers(ntf)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
close(c.done)
|
close(c.readerDone)
|
||||||
c.ctxCancel()
|
c.ctxCancel()
|
||||||
// ctx is cancelled, server is notified and will finish soon.
|
// ctx is cancelled, server is notified and will finish soon.
|
||||||
drainloop:
|
drainloop:
|
||||||
|
|
|
@ -58,7 +58,8 @@ type WSClient struct {
|
||||||
|
|
||||||
ws *websocket.Conn
|
ws *websocket.Conn
|
||||||
wsOpts WSOptions
|
wsOpts WSOptions
|
||||||
done chan struct{}
|
readerDone chan struct{}
|
||||||
|
writerDone chan struct{}
|
||||||
requests chan *neorpc.Request
|
requests chan *neorpc.Request
|
||||||
shutdown chan struct{}
|
shutdown chan struct{}
|
||||||
closeCalled atomic.Bool
|
closeCalled atomic.Bool
|
||||||
|
@ -425,7 +426,8 @@ func NewWS(ctx context.Context, endpoint string, opts WSOptions) (*WSClient, err
|
||||||
ws: ws,
|
ws: ws,
|
||||||
wsOpts: opts,
|
wsOpts: opts,
|
||||||
shutdown: make(chan struct{}),
|
shutdown: make(chan struct{}),
|
||||||
done: make(chan struct{}),
|
readerDone: make(chan struct{}),
|
||||||
|
writerDone: make(chan struct{}),
|
||||||
respChannels: make(map[uint64]chan *neorpc.Response),
|
respChannels: make(map[uint64]chan *neorpc.Response),
|
||||||
requests: make(chan *neorpc.Request),
|
requests: make(chan *neorpc.Request),
|
||||||
subscriptions: make(map[string]notificationReceiver),
|
subscriptions: make(map[string]notificationReceiver),
|
||||||
|
@ -457,7 +459,7 @@ func (c *WSClient) Close() {
|
||||||
// Call to cancel will send signal to all users of Context().
|
// Call to cancel will send signal to all users of Context().
|
||||||
c.Client.ctxCancel()
|
c.Client.ctxCancel()
|
||||||
}
|
}
|
||||||
<-c.done
|
<-c.readerDone
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *WSClient) wsReader() {
|
func (c *WSClient) wsReader() {
|
||||||
|
@ -551,7 +553,7 @@ readloop:
|
||||||
if connCloseErr != nil {
|
if connCloseErr != nil {
|
||||||
c.setCloseErr(connCloseErr)
|
c.setCloseErr(connCloseErr)
|
||||||
}
|
}
|
||||||
close(c.done)
|
close(c.readerDone)
|
||||||
c.respLock.Lock()
|
c.respLock.Lock()
|
||||||
for _, ch := range c.respChannels {
|
for _, ch := range c.respChannels {
|
||||||
close(ch)
|
close(ch)
|
||||||
|
@ -583,13 +585,14 @@ func (c *WSClient) wsWriter() {
|
||||||
pingTicker := time.NewTicker(wsPingPeriod)
|
pingTicker := time.NewTicker(wsPingPeriod)
|
||||||
defer c.ws.Close()
|
defer c.ws.Close()
|
||||||
defer pingTicker.Stop()
|
defer pingTicker.Stop()
|
||||||
|
defer close(c.writerDone)
|
||||||
var connCloseErr error
|
var connCloseErr error
|
||||||
writeloop:
|
writeloop:
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-c.shutdown:
|
case <-c.shutdown:
|
||||||
return
|
return
|
||||||
case <-c.done:
|
case <-c.readerDone:
|
||||||
return
|
return
|
||||||
case req, ok := <-c.requests:
|
case req, ok := <-c.requests:
|
||||||
if !ok {
|
if !ok {
|
||||||
|
@ -660,28 +663,42 @@ func (c *WSClient) getResponseChannel(id uint64) chan *neorpc.Response {
|
||||||
return c.respChannels[id]
|
return c.respChannels[id]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// closeErrOrConnLost returns the error that may occur either in wsReader or wsWriter.
|
||||||
|
// If wsReader or wsWriter do not set the error, it returns ErrWSConnLost.
|
||||||
|
func (c *WSClient) closeErrOrConnLost() (err error) {
|
||||||
|
err = ErrWSConnLost
|
||||||
|
if closeErr := c.GetError(); closeErr != nil {
|
||||||
|
err = closeErr
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
func (c *WSClient) makeWsRequest(r *neorpc.Request) (*neorpc.Response, error) {
|
func (c *WSClient) makeWsRequest(r *neorpc.Request) (*neorpc.Response, error) {
|
||||||
ch := make(chan *neorpc.Response)
|
ch := make(chan *neorpc.Response)
|
||||||
c.respLock.Lock()
|
c.respLock.Lock()
|
||||||
select {
|
select {
|
||||||
case <-c.done:
|
case <-c.readerDone:
|
||||||
c.respLock.Unlock()
|
c.respLock.Unlock()
|
||||||
return nil, fmt.Errorf("%w: before registering response channel", ErrWSConnLost)
|
return nil, fmt.Errorf("%w: before registering response channel", c.closeErrOrConnLost())
|
||||||
default:
|
default:
|
||||||
c.respChannels[r.ID] = ch
|
c.respChannels[r.ID] = ch
|
||||||
c.respLock.Unlock()
|
c.respLock.Unlock()
|
||||||
}
|
}
|
||||||
select {
|
select {
|
||||||
case <-c.done:
|
case <-c.readerDone:
|
||||||
return nil, fmt.Errorf("%w: before sending the request", ErrWSConnLost)
|
return nil, fmt.Errorf("%w: before sending the request", c.closeErrOrConnLost())
|
||||||
|
case <-c.writerDone:
|
||||||
|
return nil, fmt.Errorf("%w: before sending the request", c.closeErrOrConnLost())
|
||||||
case c.requests <- r:
|
case c.requests <- r:
|
||||||
}
|
}
|
||||||
select {
|
select {
|
||||||
case <-c.done:
|
case <-c.readerDone:
|
||||||
return nil, fmt.Errorf("%w: while waiting for the response", ErrWSConnLost)
|
return nil, fmt.Errorf("%w: while waiting for the response", c.closeErrOrConnLost())
|
||||||
|
case <-c.writerDone:
|
||||||
|
return nil, fmt.Errorf("%w: while waiting for the response", c.closeErrOrConnLost())
|
||||||
case resp, ok := <-ch:
|
case resp, ok := <-ch:
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("%w: while waiting for the response", ErrWSConnLost)
|
return nil, fmt.Errorf("%w: while waiting for the response", c.closeErrOrConnLost())
|
||||||
}
|
}
|
||||||
c.unregisterRespChannel(r.ID)
|
c.unregisterRespChannel(r.ID)
|
||||||
return resp, nil
|
return resp, nil
|
||||||
|
|
Loading…
Reference in a new issue