diff --git a/api/rpc/client/flows.go b/api/rpc/client/flows.go index 671c679..9728cb6 100644 --- a/api/rpc/client/flows.go +++ b/api/rpc/client/flows.go @@ -38,18 +38,26 @@ type MessageWriterCloser interface { } type clientStreamWriterCloser struct { - MessageReadWriter - + sw *streamWrapper resp message.Message } +// WriteMessage implements MessageWriterCloser. +func (c *clientStreamWriterCloser) WriteMessage(m message.Message) error { + return c.sw.WriteMessage(m) +} + func (c *clientStreamWriterCloser) Close() error { - err := c.MessageReadWriter.Close() + err := c.sw.closeSend() if err != nil { return err } - return c.ReadMessage(c.resp) + if err = c.sw.ReadMessage(c.resp); err != nil { + return err + } + + return c.sw.Close() } // OpenClientStream initializes communication session by RPC info, opens client-side stream @@ -57,14 +65,14 @@ func (c *clientStreamWriterCloser) Close() error { // // All stream writes must be performed before the closing. Close must be called once. func OpenClientStream(cli *Client, info common.CallMethodInfo, resp message.Message, opts ...CallOption) (MessageWriterCloser, error) { - rw, err := cli.Init(info, opts...) + rw, err := cli.initInternal(info, opts...) if err != nil { return nil, err } return &clientStreamWriterCloser{ - MessageReadWriter: rw, - resp: resp, + sw: rw, + resp: resp, }, nil } @@ -112,7 +120,7 @@ func (s *serverStreamReaderCloser) ReadMessage(msg message.Message) error { // // All stream reads must be performed before the closing. Close must be called once. func OpenServerStream(cli *Client, info common.CallMethodInfo, req message.Message, opts ...CallOption) (MessageReader, error) { - rw, err := cli.Init(info, opts...) + rw, err := cli.initInternal(info, opts...) if err != nil { return nil, err } diff --git a/api/rpc/client/init.go b/api/rpc/client/init.go index 4edfd0b..08a9925 100644 --- a/api/rpc/client/init.go +++ b/api/rpc/client/init.go @@ -41,6 +41,10 @@ type MessageReadWriter interface { // Init initiates a messaging session and returns the interface for message transmitting. func (c *Client) Init(info common.CallMethodInfo, opts ...CallOption) (MessageReadWriter, error) { + return c.initInternal(info, opts...) +} + +func (c *Client) initInternal(info common.CallMethodInfo, opts ...CallOption) (*streamWrapper, error) { prm := defaultCallParameters() for _, opt := range opts { @@ -52,7 +56,6 @@ func (c *Client) Init(info common.CallMethodInfo, opts ...CallOption) (MessageRe } ctx, cancel := context.WithCancel(prm.ctx) - defer cancel() // `conn.NewStream` doesn't check if `conn` may turn up invalidated right before this invocation. // In such cases, the operation can hang indefinitely, with the context timeout being the only diff --git a/api/rpc/client/stream_wrapper.go b/api/rpc/client/stream_wrapper.go index 4c7bb1f..6f13ec2 100644 --- a/api/rpc/client/stream_wrapper.go +++ b/api/rpc/client/stream_wrapper.go @@ -2,6 +2,7 @@ package client import ( "context" + "sync" "time" "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/api/rpc/message" @@ -10,11 +11,12 @@ import ( type streamWrapper struct { grpc.ClientStream - timeout time.Duration - cancel context.CancelFunc + timeout time.Duration + cancel context.CancelFunc + closeSendOnce sync.Once } -func (w streamWrapper) ReadMessage(m message.Message) error { +func (w *streamWrapper) ReadMessage(m message.Message) error { // Can be optimized: we can create blank message here. gm := m.ToGRPCMessage() @@ -28,14 +30,26 @@ func (w streamWrapper) ReadMessage(m message.Message) error { return m.FromGRPCMessage(gm) } -func (w streamWrapper) WriteMessage(m message.Message) error { +func (w *streamWrapper) WriteMessage(m message.Message) error { return w.withTimeout(func() error { return w.ClientStream.SendMsg(m.ToGRPCMessage()) }) } +func (w *streamWrapper) closeSend() error { + var err error + w.closeSendOnce.Do( + func() { + err = w.withTimeout(w.ClientStream.CloseSend) + }, + ) + return err +} + func (w *streamWrapper) Close() error { - return w.withTimeout(w.ClientStream.CloseSend) + err := w.closeSend() + w.cancel() + return err } func (w *streamWrapper) withTimeout(closure func() error) error { @@ -50,6 +64,10 @@ func (w *streamWrapper) withTimeout(closure func() error) error { select { case err := <-ch: tt.Stop() + select { + case <-tt.C: + default: + } return err case <-tt.C: w.cancel()