From 593dd77d841aa6652377d3755684d0a968e25fff Mon Sep 17 00:00:00 2001 From: Dmitrii Stepanov Date: Thu, 30 Jan 2025 10:49:34 +0300 Subject: [PATCH] [#327] rpc: Fix mem leak gRPC stream must be closed by `cancel` to prevent memleak. Signed-off-by: Dmitrii Stepanov --- api/rpc/client/flows.go | 36 +++++++++++++++++++++++--------- api/rpc/client/init.go | 5 ++++- api/rpc/client/stream_wrapper.go | 11 +++++++++- 3 files changed, 40 insertions(+), 12 deletions(-) diff --git a/api/rpc/client/flows.go b/api/rpc/client/flows.go index 671c679..2a945b4 100644 --- a/api/rpc/client/flows.go +++ b/api/rpc/client/flows.go @@ -12,18 +12,20 @@ import ( // SendUnary initializes communication session by RPC info, performs unary RPC // and closes the session. func SendUnary(cli *Client, info common.CallMethodInfo, req, resp message.Message, opts ...CallOption) error { - rw, err := cli.Init(info, opts...) + rw, err := cli.initInternal(info, opts...) if err != nil { return err } err = rw.WriteMessage(req) if err != nil { + rw.cancel() return err } err = rw.ReadMessage(resp) if err != nil { + rw.cancel() return err } @@ -38,18 +40,28 @@ type MessageWriterCloser interface { } type clientStreamWriterCloser struct { - MessageReadWriter - + sw *streamWrapper resp message.Message } +// WriteMessage implements MessageWriterCloser. +func (c *clientStreamWriterCloser) WriteMessage(m message.Message) error { + return c.sw.WriteMessage(m) +} + func (c *clientStreamWriterCloser) Close() error { - err := c.MessageReadWriter.Close() + err := c.sw.closeSend() if err != nil { + c.sw.cancel() return err } - return c.ReadMessage(c.resp) + if err = c.sw.ReadMessage(c.resp); err != nil { + c.sw.cancel() + return err + } + + return c.sw.Close() } // OpenClientStream initializes communication session by RPC info, opens client-side stream @@ -57,14 +69,14 @@ func (c *clientStreamWriterCloser) Close() error { // // All stream writes must be performed before the closing. Close must be called once. func OpenClientStream(cli *Client, info common.CallMethodInfo, resp message.Message, opts ...CallOption) (MessageWriterCloser, error) { - rw, err := cli.Init(info, opts...) + rw, err := cli.initInternal(info, opts...) if err != nil { return nil, err } return &clientStreamWriterCloser{ - MessageReadWriter: rw, - resp: resp, + sw: rw, + resp: resp, }, nil } @@ -76,7 +88,7 @@ type MessageReaderCloser interface { } type serverStreamReaderCloser struct { - rw MessageReadWriter + rw *streamWrapper once sync.Once @@ -91,11 +103,15 @@ func (s *serverStreamReaderCloser) ReadMessage(msg message.Message) error { }) if err != nil { + s.rw.cancel() return err } err = s.rw.ReadMessage(msg) if !errors.Is(err, io.EOF) { + if err != nil { + s.rw.cancel() + } return err } @@ -112,7 +128,7 @@ func (s *serverStreamReaderCloser) ReadMessage(msg message.Message) error { // // All stream reads must be performed before the closing. Close must be called once. func OpenServerStream(cli *Client, info common.CallMethodInfo, req message.Message, opts ...CallOption) (MessageReader, error) { - rw, err := cli.Init(info, opts...) + rw, err := cli.initInternal(info, opts...) if err != nil { return nil, err } diff --git a/api/rpc/client/init.go b/api/rpc/client/init.go index 4edfd0b..08a9925 100644 --- a/api/rpc/client/init.go +++ b/api/rpc/client/init.go @@ -41,6 +41,10 @@ type MessageReadWriter interface { // Init initiates a messaging session and returns the interface for message transmitting. func (c *Client) Init(info common.CallMethodInfo, opts ...CallOption) (MessageReadWriter, error) { + return c.initInternal(info, opts...) +} + +func (c *Client) initInternal(info common.CallMethodInfo, opts ...CallOption) (*streamWrapper, error) { prm := defaultCallParameters() for _, opt := range opts { @@ -52,7 +56,6 @@ func (c *Client) Init(info common.CallMethodInfo, opts ...CallOption) (MessageRe } ctx, cancel := context.WithCancel(prm.ctx) - defer cancel() // `conn.NewStream` doesn't check if `conn` may turn up invalidated right before this invocation. // In such cases, the operation can hang indefinitely, with the context timeout being the only diff --git a/api/rpc/client/stream_wrapper.go b/api/rpc/client/stream_wrapper.go index 4c7bb1f..85d5ad5 100644 --- a/api/rpc/client/stream_wrapper.go +++ b/api/rpc/client/stream_wrapper.go @@ -34,10 +34,15 @@ func (w streamWrapper) WriteMessage(m message.Message) error { }) } -func (w *streamWrapper) Close() error { +func (w *streamWrapper) closeSend() error { return w.withTimeout(w.ClientStream.CloseSend) } +func (w *streamWrapper) Close() error { + w.cancel() + return nil +} + func (w *streamWrapper) withTimeout(closure func() error) error { ch := make(chan error, 1) go func() { @@ -50,6 +55,10 @@ func (w *streamWrapper) withTimeout(closure func() error) error { select { case err := <-ch: tt.Stop() + select { + case <-tt.C: + default: + } return err case <-tt.C: w.cancel()