forked from TrueCloudLab/frostfs-sdk-go
[#327] rpc: Fix mem leak
gRPC stream must be closed by `cancel` to prevent memleak. Signed-off-by: Dmitrii Stepanov <d.stepanov@yadro.com>
This commit is contained in:
parent
2786fadb25
commit
593dd77d84
3 changed files with 40 additions and 12 deletions
|
@ -12,18 +12,20 @@ import (
|
||||||
// SendUnary initializes communication session by RPC info, performs unary RPC
|
// SendUnary initializes communication session by RPC info, performs unary RPC
|
||||||
// and closes the session.
|
// and closes the session.
|
||||||
func SendUnary(cli *Client, info common.CallMethodInfo, req, resp message.Message, opts ...CallOption) error {
|
func SendUnary(cli *Client, info common.CallMethodInfo, req, resp message.Message, opts ...CallOption) error {
|
||||||
rw, err := cli.Init(info, opts...)
|
rw, err := cli.initInternal(info, opts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = rw.WriteMessage(req)
|
err = rw.WriteMessage(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
rw.cancel()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = rw.ReadMessage(resp)
|
err = rw.ReadMessage(resp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
rw.cancel()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -38,18 +40,28 @@ type MessageWriterCloser interface {
|
||||||
}
|
}
|
||||||
|
|
||||||
type clientStreamWriterCloser struct {
|
type clientStreamWriterCloser struct {
|
||||||
MessageReadWriter
|
sw *streamWrapper
|
||||||
|
|
||||||
resp message.Message
|
resp message.Message
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WriteMessage implements MessageWriterCloser.
|
||||||
|
func (c *clientStreamWriterCloser) WriteMessage(m message.Message) error {
|
||||||
|
return c.sw.WriteMessage(m)
|
||||||
|
}
|
||||||
|
|
||||||
func (c *clientStreamWriterCloser) Close() error {
|
func (c *clientStreamWriterCloser) Close() error {
|
||||||
err := c.MessageReadWriter.Close()
|
err := c.sw.closeSend()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
c.sw.cancel()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return c.ReadMessage(c.resp)
|
if err = c.sw.ReadMessage(c.resp); err != nil {
|
||||||
|
c.sw.cancel()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.sw.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
// OpenClientStream initializes communication session by RPC info, opens client-side stream
|
// OpenClientStream initializes communication session by RPC info, opens client-side stream
|
||||||
|
@ -57,14 +69,14 @@ func (c *clientStreamWriterCloser) Close() error {
|
||||||
//
|
//
|
||||||
// All stream writes must be performed before the closing. Close must be called once.
|
// All stream writes must be performed before the closing. Close must be called once.
|
||||||
func OpenClientStream(cli *Client, info common.CallMethodInfo, resp message.Message, opts ...CallOption) (MessageWriterCloser, error) {
|
func OpenClientStream(cli *Client, info common.CallMethodInfo, resp message.Message, opts ...CallOption) (MessageWriterCloser, error) {
|
||||||
rw, err := cli.Init(info, opts...)
|
rw, err := cli.initInternal(info, opts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &clientStreamWriterCloser{
|
return &clientStreamWriterCloser{
|
||||||
MessageReadWriter: rw,
|
sw: rw,
|
||||||
resp: resp,
|
resp: resp,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -76,7 +88,7 @@ type MessageReaderCloser interface {
|
||||||
}
|
}
|
||||||
|
|
||||||
type serverStreamReaderCloser struct {
|
type serverStreamReaderCloser struct {
|
||||||
rw MessageReadWriter
|
rw *streamWrapper
|
||||||
|
|
||||||
once sync.Once
|
once sync.Once
|
||||||
|
|
||||||
|
@ -91,11 +103,15 @@ func (s *serverStreamReaderCloser) ReadMessage(msg message.Message) error {
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
s.rw.cancel()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = s.rw.ReadMessage(msg)
|
err = s.rw.ReadMessage(msg)
|
||||||
if !errors.Is(err, io.EOF) {
|
if !errors.Is(err, io.EOF) {
|
||||||
|
if err != nil {
|
||||||
|
s.rw.cancel()
|
||||||
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -112,7 +128,7 @@ func (s *serverStreamReaderCloser) ReadMessage(msg message.Message) error {
|
||||||
//
|
//
|
||||||
// All stream reads must be performed before the closing. Close must be called once.
|
// All stream reads must be performed before the closing. Close must be called once.
|
||||||
func OpenServerStream(cli *Client, info common.CallMethodInfo, req message.Message, opts ...CallOption) (MessageReader, error) {
|
func OpenServerStream(cli *Client, info common.CallMethodInfo, req message.Message, opts ...CallOption) (MessageReader, error) {
|
||||||
rw, err := cli.Init(info, opts...)
|
rw, err := cli.initInternal(info, opts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -41,6 +41,10 @@ type MessageReadWriter interface {
|
||||||
|
|
||||||
// Init initiates a messaging session and returns the interface for message transmitting.
|
// Init initiates a messaging session and returns the interface for message transmitting.
|
||||||
func (c *Client) Init(info common.CallMethodInfo, opts ...CallOption) (MessageReadWriter, error) {
|
func (c *Client) Init(info common.CallMethodInfo, opts ...CallOption) (MessageReadWriter, error) {
|
||||||
|
return c.initInternal(info, opts...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) initInternal(info common.CallMethodInfo, opts ...CallOption) (*streamWrapper, error) {
|
||||||
prm := defaultCallParameters()
|
prm := defaultCallParameters()
|
||||||
|
|
||||||
for _, opt := range opts {
|
for _, opt := range opts {
|
||||||
|
@ -52,7 +56,6 @@ func (c *Client) Init(info common.CallMethodInfo, opts ...CallOption) (MessageRe
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(prm.ctx)
|
ctx, cancel := context.WithCancel(prm.ctx)
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
// `conn.NewStream` doesn't check if `conn` may turn up invalidated right before this invocation.
|
// `conn.NewStream` doesn't check if `conn` may turn up invalidated right before this invocation.
|
||||||
// In such cases, the operation can hang indefinitely, with the context timeout being the only
|
// In such cases, the operation can hang indefinitely, with the context timeout being the only
|
||||||
|
|
|
@ -34,10 +34,15 @@ func (w streamWrapper) WriteMessage(m message.Message) error {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *streamWrapper) Close() error {
|
func (w *streamWrapper) closeSend() error {
|
||||||
return w.withTimeout(w.ClientStream.CloseSend)
|
return w.withTimeout(w.ClientStream.CloseSend)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *streamWrapper) Close() error {
|
||||||
|
w.cancel()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (w *streamWrapper) withTimeout(closure func() error) error {
|
func (w *streamWrapper) withTimeout(closure func() error) error {
|
||||||
ch := make(chan error, 1)
|
ch := make(chan error, 1)
|
||||||
go func() {
|
go func() {
|
||||||
|
@ -50,6 +55,10 @@ func (w *streamWrapper) withTimeout(closure func() error) error {
|
||||||
select {
|
select {
|
||||||
case err := <-ch:
|
case err := <-ch:
|
||||||
tt.Stop()
|
tt.Stop()
|
||||||
|
select {
|
||||||
|
case <-tt.C:
|
||||||
|
default:
|
||||||
|
}
|
||||||
return err
|
return err
|
||||||
case <-tt.C:
|
case <-tt.C:
|
||||||
w.cancel()
|
w.cancel()
|
||||||
|
|
Loading…
Add table
Reference in a new issue