diff --git a/pkg/rpcclient/local.go b/pkg/rpcclient/local.go index 398631495..d1a3c411e 100644 --- a/pkg/rpcclient/local.go +++ b/pkg/rpcclient/local.go @@ -29,7 +29,8 @@ func NewInternal(ctx context.Context, register InternalHook) (*Internal, error) Client: Client{}, shutdown: make(chan struct{}), - done: make(chan struct{}), + readerDone: make(chan struct{}), + writerDone: make(chan struct{}), subscriptions: make(map[string]notificationReceiver), receivers: make(map[any][]string), }, @@ -63,7 +64,7 @@ eventloop: c.notifySubscribers(ntf) } } - close(c.done) + close(c.readerDone) c.ctxCancel() // ctx is cancelled, server is notified and will finish soon. drainloop: diff --git a/pkg/rpcclient/wsclient.go b/pkg/rpcclient/wsclient.go index ef3a94e1d..f86b9e9f7 100644 --- a/pkg/rpcclient/wsclient.go +++ b/pkg/rpcclient/wsclient.go @@ -58,7 +58,8 @@ type WSClient struct { ws *websocket.Conn wsOpts WSOptions - done chan struct{} + readerDone chan struct{} + writerDone chan struct{} requests chan *neorpc.Request shutdown chan struct{} closeCalled atomic.Bool @@ -425,7 +426,8 @@ func NewWS(ctx context.Context, endpoint string, opts WSOptions) (*WSClient, err ws: ws, wsOpts: opts, shutdown: make(chan struct{}), - done: make(chan struct{}), + readerDone: make(chan struct{}), + writerDone: make(chan struct{}), respChannels: make(map[uint64]chan *neorpc.Response), requests: make(chan *neorpc.Request), subscriptions: make(map[string]notificationReceiver), @@ -457,7 +459,7 @@ func (c *WSClient) Close() { // Call to cancel will send signal to all users of Context(). c.Client.ctxCancel() } - <-c.done + <-c.readerDone } func (c *WSClient) wsReader() { @@ -551,7 +553,7 @@ readloop: if connCloseErr != nil { c.setCloseErr(connCloseErr) } - close(c.done) + close(c.readerDone) c.respLock.Lock() for _, ch := range c.respChannels { close(ch) @@ -583,13 +585,14 @@ func (c *WSClient) wsWriter() { pingTicker := time.NewTicker(wsPingPeriod) defer c.ws.Close() defer pingTicker.Stop() + defer close(c.writerDone) var connCloseErr error writeloop: for { select { case <-c.shutdown: return - case <-c.done: + case <-c.readerDone: return case req, ok := <-c.requests: if !ok { @@ -660,28 +663,42 @@ func (c *WSClient) getResponseChannel(id uint64) chan *neorpc.Response { return c.respChannels[id] } +// closeErrOrConnLost returns the error that may occur either in wsReader or wsWriter. +// If wsReader or wsWriter do not set the error, it returns ErrWSConnLost. +func (c *WSClient) closeErrOrConnLost() (err error) { + err = ErrWSConnLost + if closeErr := c.GetError(); closeErr != nil { + err = closeErr + } + return +} + func (c *WSClient) makeWsRequest(r *neorpc.Request) (*neorpc.Response, error) { ch := make(chan *neorpc.Response) c.respLock.Lock() select { - case <-c.done: + case <-c.readerDone: c.respLock.Unlock() - return nil, fmt.Errorf("%w: before registering response channel", ErrWSConnLost) + return nil, fmt.Errorf("%w: before registering response channel", c.closeErrOrConnLost()) default: c.respChannels[r.ID] = ch c.respLock.Unlock() } select { - case <-c.done: - return nil, fmt.Errorf("%w: before sending the request", ErrWSConnLost) + case <-c.readerDone: + return nil, fmt.Errorf("%w: before sending the request", c.closeErrOrConnLost()) + case <-c.writerDone: + return nil, fmt.Errorf("%w: before sending the request", c.closeErrOrConnLost()) case c.requests <- r: } select { - case <-c.done: - return nil, fmt.Errorf("%w: while waiting for the response", ErrWSConnLost) + case <-c.readerDone: + return nil, fmt.Errorf("%w: while waiting for the response", c.closeErrOrConnLost()) + case <-c.writerDone: + return nil, fmt.Errorf("%w: while waiting for the response", c.closeErrOrConnLost()) case resp, ok := <-ch: if !ok { - return nil, fmt.Errorf("%w: while waiting for the response", ErrWSConnLost) + return nil, fmt.Errorf("%w: while waiting for the response", c.closeErrOrConnLost()) } c.unregisterRespChannel(r.ID) return resp, nil