[#327] rpc: Fix mem leak

gRPC stream must be closed by `cancel` to prevent memleak.

Signed-off-by: Dmitrii Stepanov <d.stepanov@yadro.com>
This commit is contained in:
Dmitrii Stepanov 2025-01-30 10:49:34 +03:00
parent 2786fadb25
commit 593dd77d84
Signed by: dstepanov-yadro
GPG key ID: 237AF1A763293BC0
3 changed files with 40 additions and 12 deletions

View file

@ -12,18 +12,20 @@ import (
// SendUnary initializes communication session by RPC info, performs unary RPC // SendUnary initializes communication session by RPC info, performs unary RPC
// and closes the session. // and closes the session.
func SendUnary(cli *Client, info common.CallMethodInfo, req, resp message.Message, opts ...CallOption) error { func SendUnary(cli *Client, info common.CallMethodInfo, req, resp message.Message, opts ...CallOption) error {
rw, err := cli.Init(info, opts...) rw, err := cli.initInternal(info, opts...)
if err != nil { if err != nil {
return err return err
} }
err = rw.WriteMessage(req) err = rw.WriteMessage(req)
if err != nil { if err != nil {
rw.cancel()
return err return err
} }
err = rw.ReadMessage(resp) err = rw.ReadMessage(resp)
if err != nil { if err != nil {
rw.cancel()
return err return err
} }
@ -38,18 +40,28 @@ type MessageWriterCloser interface {
} }
type clientStreamWriterCloser struct { type clientStreamWriterCloser struct {
MessageReadWriter sw *streamWrapper
resp message.Message resp message.Message
} }
// WriteMessage implements MessageWriterCloser.
func (c *clientStreamWriterCloser) WriteMessage(m message.Message) error {
return c.sw.WriteMessage(m)
}
func (c *clientStreamWriterCloser) Close() error { func (c *clientStreamWriterCloser) Close() error {
err := c.MessageReadWriter.Close() err := c.sw.closeSend()
if err != nil { if err != nil {
c.sw.cancel()
return err return err
} }
return c.ReadMessage(c.resp) if err = c.sw.ReadMessage(c.resp); err != nil {
c.sw.cancel()
return err
}
return c.sw.Close()
} }
// OpenClientStream initializes communication session by RPC info, opens client-side stream // OpenClientStream initializes communication session by RPC info, opens client-side stream
@ -57,14 +69,14 @@ func (c *clientStreamWriterCloser) Close() error {
// //
// All stream writes must be performed before the closing. Close must be called once. // All stream writes must be performed before the closing. Close must be called once.
func OpenClientStream(cli *Client, info common.CallMethodInfo, resp message.Message, opts ...CallOption) (MessageWriterCloser, error) { func OpenClientStream(cli *Client, info common.CallMethodInfo, resp message.Message, opts ...CallOption) (MessageWriterCloser, error) {
rw, err := cli.Init(info, opts...) rw, err := cli.initInternal(info, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &clientStreamWriterCloser{ return &clientStreamWriterCloser{
MessageReadWriter: rw, sw: rw,
resp: resp, resp: resp,
}, nil }, nil
} }
@ -76,7 +88,7 @@ type MessageReaderCloser interface {
} }
type serverStreamReaderCloser struct { type serverStreamReaderCloser struct {
rw MessageReadWriter rw *streamWrapper
once sync.Once once sync.Once
@ -91,11 +103,15 @@ func (s *serverStreamReaderCloser) ReadMessage(msg message.Message) error {
}) })
if err != nil { if err != nil {
s.rw.cancel()
return err return err
} }
err = s.rw.ReadMessage(msg) err = s.rw.ReadMessage(msg)
if !errors.Is(err, io.EOF) { if !errors.Is(err, io.EOF) {
if err != nil {
s.rw.cancel()
}
return err return err
} }
@ -112,7 +128,7 @@ func (s *serverStreamReaderCloser) ReadMessage(msg message.Message) error {
// //
// All stream reads must be performed before the closing. Close must be called once. // All stream reads must be performed before the closing. Close must be called once.
func OpenServerStream(cli *Client, info common.CallMethodInfo, req message.Message, opts ...CallOption) (MessageReader, error) { func OpenServerStream(cli *Client, info common.CallMethodInfo, req message.Message, opts ...CallOption) (MessageReader, error) {
rw, err := cli.Init(info, opts...) rw, err := cli.initInternal(info, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -41,6 +41,10 @@ type MessageReadWriter interface {
// Init initiates a messaging session and returns the interface for message transmitting. // Init initiates a messaging session and returns the interface for message transmitting.
func (c *Client) Init(info common.CallMethodInfo, opts ...CallOption) (MessageReadWriter, error) { func (c *Client) Init(info common.CallMethodInfo, opts ...CallOption) (MessageReadWriter, error) {
return c.initInternal(info, opts...)
}
func (c *Client) initInternal(info common.CallMethodInfo, opts ...CallOption) (*streamWrapper, error) {
prm := defaultCallParameters() prm := defaultCallParameters()
for _, opt := range opts { for _, opt := range opts {
@ -52,7 +56,6 @@ func (c *Client) Init(info common.CallMethodInfo, opts ...CallOption) (MessageRe
} }
ctx, cancel := context.WithCancel(prm.ctx) ctx, cancel := context.WithCancel(prm.ctx)
defer cancel()
// `conn.NewStream` doesn't check if `conn` may turn up invalidated right before this invocation. // `conn.NewStream` doesn't check if `conn` may turn up invalidated right before this invocation.
// In such cases, the operation can hang indefinitely, with the context timeout being the only // In such cases, the operation can hang indefinitely, with the context timeout being the only

View file

@ -34,10 +34,15 @@ func (w streamWrapper) WriteMessage(m message.Message) error {
}) })
} }
func (w *streamWrapper) Close() error { func (w *streamWrapper) closeSend() error {
return w.withTimeout(w.ClientStream.CloseSend) return w.withTimeout(w.ClientStream.CloseSend)
} }
func (w *streamWrapper) Close() error {
w.cancel()
return nil
}
func (w *streamWrapper) withTimeout(closure func() error) error { func (w *streamWrapper) withTimeout(closure func() error) error {
ch := make(chan error, 1) ch := make(chan error, 1)
go func() { go func() {
@ -50,6 +55,10 @@ func (w *streamWrapper) withTimeout(closure func() error) error {
select { select {
case err := <-ch: case err := <-ch:
tt.Stop() tt.Stop()
select {
case <-tt.C:
default:
}
return err return err
case <-tt.C: case <-tt.C:
w.cancel() w.cancel()