424 lines
11 KiB
Go
424 lines
11 KiB
Go
// Copyright (C) 2019 Storj Labs, Inc.
|
|
// See LICENSE for copying information.
|
|
|
|
package drpcstream
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"sync"
|
|
|
|
"github.com/gogo/protobuf/proto"
|
|
"github.com/zeebo/errs"
|
|
|
|
"storj.io/drpc"
|
|
"storj.io/drpc/drpcdebug"
|
|
"storj.io/drpc/drpcsignal"
|
|
"storj.io/drpc/drpcwire"
|
|
)
|
|
|
|
// Options holds the tunable settings that control how a Stream behaves.
type Options struct {
	// SplitSize is the frame size handed to drpcwire.SplitN when RawWrite
	// breaks an outgoing packet into frames.
	SplitSize int
}
|
|
|
|
// Stream represents an rpc actively happening on a transport.
|
|
type Stream struct {
|
|
ctx context.Context
|
|
cancel func()
|
|
opts Options
|
|
|
|
writeMu chMutex
|
|
id drpcwire.ID
|
|
wr *drpcwire.Writer
|
|
|
|
mu sync.Mutex // protects state transitions
|
|
sigs struct {
|
|
send drpcsignal.Signal // set when done sending messages
|
|
recv drpcsignal.Signal // set when done receiving messages
|
|
term drpcsignal.Signal // set when in terminated state
|
|
finish drpcsignal.Signal // set when all writes are complete
|
|
cancel drpcsignal.Signal // set when externally canceled and transport will be closed
|
|
}
|
|
queue chan drpcwire.Packet
|
|
|
|
// avoids allocations of closures
|
|
pollWriteFn func(drpcwire.Frame) error
|
|
}
|
|
|
|
var _ drpc.Stream = (*Stream)(nil)
|
|
|
|
// New returns a new stream bound to the context with the given stream id and will
|
|
// use the writer to write messages on. It is important use monotonically increasing
|
|
// stream ids within a single transport.
|
|
func New(ctx context.Context, sid uint64, wr *drpcwire.Writer) *Stream {
|
|
return NewWithOptions(ctx, sid, wr, Options{})
|
|
}
|
|
|
|
// NewWithOptions returns a new stream bound to the context with the given stream id
|
|
// and will use the writer to write messages on. It is important use monotonically increasing
|
|
// stream ids within a single transport. The options are used to control details of how
|
|
// the Stream operates.
|
|
func NewWithOptions(ctx context.Context, sid uint64, wr *drpcwire.Writer, opts Options) *Stream {
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
|
|
s := &Stream{
|
|
ctx: ctx,
|
|
cancel: cancel,
|
|
opts: opts,
|
|
|
|
wr: wr,
|
|
|
|
id: drpcwire.ID{Stream: sid},
|
|
queue: make(chan drpcwire.Packet),
|
|
}
|
|
|
|
s.pollWriteFn = s.pollWrite
|
|
|
|
return s
|
|
}
|
|
|
|
//
|
|
// monitoring helpers
|
|
//
|
|
|
|
// monCtx returns a copy of the context for use with mon.Task so that there aren't
|
|
// races overwriting the stream's context.
|
|
func (s *Stream) monCtx() *context.Context {
|
|
ctx := s.ctx
|
|
return &ctx
|
|
}
|
|
|
|
//
|
|
// accessors
|
|
//
|
|
|
|
// Context returns the context associated with the stream. It is closed when
|
|
// the Stream will no longer issue any writes or reads.
|
|
func (s *Stream) Context() context.Context { return s.ctx }
|
|
|
|
// Terminated returns a channel when the stream has been terminated.
|
|
func (s *Stream) Terminated() <-chan struct{} { return s.sigs.term.Signal() }
|
|
|
|
// Finished returns true if the stream is fully finished and will no longer
|
|
// issue any writes or reads.
|
|
func (s *Stream) Finished() bool { return s.sigs.finish.IsSet() }
|
|
|
|
//
|
|
// packet handler
|
|
//
|
|
|
|
// HandlePacket advances the stream state machine by inspecting the packet. It returns
|
|
// any major errors that should terminate the transport the stream is operating on as
|
|
// well as a boolean indicating if the stream expects more packets.
|
|
func (s *Stream) HandlePacket(pkt drpcwire.Packet) (more bool, err error) {
|
|
defer mon.Task()(s.monCtx())(&err)
|
|
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
drpcdebug.Log(func() string { return fmt.Sprintf("STR[%p][%d]: %v", s, s.id.Stream, pkt) })
|
|
|
|
if pkt.ID.Stream != s.id.Stream {
|
|
return true, nil
|
|
}
|
|
|
|
switch pkt.Kind {
|
|
case drpcwire.KindInvoke:
|
|
err := drpc.ProtocolError.New("invoke on existing stream")
|
|
s.terminate(err)
|
|
return false, err
|
|
|
|
case drpcwire.KindMessage:
|
|
if s.sigs.recv.IsSet() || s.sigs.term.IsSet() {
|
|
return true, nil
|
|
}
|
|
|
|
// drop the mutex while we either send into the queue or we're told that
|
|
// receiving is done. we don't handle any more packets until the message
|
|
// is delivered, so the only way it can become set is from some of the
|
|
// stream terminating calls, in which case, shutting down the stream is
|
|
// racing with the message being received, so dropping it is valid.
|
|
s.mu.Unlock()
|
|
defer s.mu.Lock()
|
|
|
|
select {
|
|
case <-s.sigs.recv.Signal():
|
|
case <-s.sigs.term.Signal():
|
|
case s.queue <- pkt:
|
|
}
|
|
|
|
return true, nil
|
|
|
|
case drpcwire.KindError:
|
|
err := drpcwire.UnmarshalError(pkt.Data)
|
|
s.sigs.send.Set(io.EOF) // in this state, gRPC returns io.EOF on send.
|
|
s.terminate(err)
|
|
return false, nil
|
|
|
|
case drpcwire.KindClose:
|
|
s.sigs.recv.Set(io.EOF)
|
|
s.terminate(drpc.Error.New("remote closed the stream"))
|
|
return false, nil
|
|
|
|
case drpcwire.KindCloseSend:
|
|
s.sigs.recv.Set(io.EOF)
|
|
s.terminateIfBothClosed()
|
|
return false, nil
|
|
|
|
default:
|
|
err := drpc.InternalError.New("unknown packet kind: %s", pkt.Kind)
|
|
s.terminate(err)
|
|
return false, err
|
|
}
|
|
}
|
|
|
|
//
|
|
// helpers
|
|
//
|
|
|
|
// checkFinished checks to see if the stream is terminated, and if so, sets the finished
|
|
// flag. This must be called every time right before we release the write mutex.
|
|
func (s *Stream) checkFinished() {
|
|
if s.sigs.term.IsSet() {
|
|
s.sigs.finish.Set(nil)
|
|
}
|
|
}
|
|
|
|
// checkCancelError will replace the error with one from the cancel signal if it is
|
|
// set. This is to prevent errors from reads/writes to a transport after it has been
|
|
// asynchronously closed due to context cancelation.
|
|
func (s *Stream) checkCancelError(err error) error {
|
|
if sigErr, ok := s.sigs.cancel.Get(); ok {
|
|
return sigErr
|
|
}
|
|
return err
|
|
}
|
|
|
|
// newPackage bumps the internal message id and returns a packet. It must be called
|
|
// under a mutex.
|
|
func (s *Stream) newPacket(kind drpcwire.Kind, data []byte) drpcwire.Packet {
|
|
s.id.Message++
|
|
return drpcwire.Packet{
|
|
Data: data,
|
|
ID: s.id,
|
|
Kind: kind,
|
|
}
|
|
}
|
|
|
|
// pollWrite checks for any conditions that should cause a write to not happen and
|
|
// then issues the write of the frame.
|
|
func (s *Stream) pollWrite(fr drpcwire.Frame) (err error) {
|
|
switch {
|
|
case s.sigs.send.IsSet():
|
|
return s.sigs.send.Err()
|
|
case s.sigs.term.IsSet():
|
|
return s.sigs.term.Err()
|
|
}
|
|
|
|
return s.checkCancelError(errs.Wrap(s.wr.WriteFrame(fr)))
|
|
}
|
|
|
|
// sendPacket sends the packet in a single write and flushes. It does not check for
|
|
// any conditions to stop it from writing and is meant for internal stream use to
|
|
// do things like signal errors or closes to the remote side.
|
|
func (s *Stream) sendPacket(kind drpcwire.Kind, data []byte) (err error) {
|
|
defer mon.Task()(s.monCtx())(&err)
|
|
|
|
if err := s.wr.WritePacket(s.newPacket(kind, data)); err != nil {
|
|
return errs.Wrap(err)
|
|
}
|
|
if err := s.wr.Flush(); err != nil {
|
|
return errs.Wrap(err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// terminateIfBothClosed is a helper to terminate the stream if both sides have
|
|
// issued a CloseSend.
|
|
func (s *Stream) terminateIfBothClosed() {
|
|
if s.sigs.send.IsSet() && s.sigs.recv.IsSet() {
|
|
s.terminate(drpc.Error.New("stream terminated by both issuing close send"))
|
|
}
|
|
}
|
|
|
|
// terminate marks the stream as terminated with the given error. It also marks
|
|
// the stream as finished if no writes are happening at the time of the call.
|
|
func (s *Stream) terminate(err error) {
|
|
s.sigs.send.Set(err)
|
|
s.sigs.recv.Set(err)
|
|
s.sigs.term.Set(err)
|
|
s.cancel()
|
|
|
|
// if we can acquire the write mutex, then checkFinished. if not, then we know
|
|
// some other write is happening, and it will call checkFinished before it
|
|
// releases the mutex.
|
|
if s.writeMu.TryLock() {
|
|
s.checkFinished()
|
|
s.writeMu.Unlock()
|
|
}
|
|
}
|
|
|
|
//
|
|
// raw read/write
|
|
//
|
|
|
|
// RawWrite sends the data bytes with the given kind.
|
|
func (s *Stream) RawWrite(kind drpcwire.Kind, data []byte) (err error) {
|
|
defer mon.Task()(s.monCtx())(&err)
|
|
|
|
s.writeMu.Lock()
|
|
defer s.writeMu.Unlock()
|
|
defer s.checkFinished()
|
|
|
|
return drpcwire.SplitN(s.newPacket(kind, data), s.opts.SplitSize, s.pollWriteFn)
|
|
}
|
|
|
|
// RawFlush flushes any buffers of data.
|
|
func (s *Stream) RawFlush() (err error) {
|
|
defer mon.Task()(s.monCtx())(&err)
|
|
|
|
s.writeMu.Lock()
|
|
defer s.writeMu.Unlock()
|
|
defer s.checkFinished()
|
|
|
|
return s.checkCancelError(errs.Wrap(s.wr.Flush()))
|
|
}
|
|
|
|
// RawRecv returns the raw bytes received for a message.
|
|
func (s *Stream) RawRecv() (data []byte, err error) {
|
|
defer mon.Task()(s.monCtx())(&err)
|
|
|
|
if s.sigs.recv.IsSet() {
|
|
return nil, s.sigs.recv.Err()
|
|
}
|
|
|
|
select {
|
|
case <-s.sigs.recv.Signal():
|
|
return nil, s.sigs.recv.Err()
|
|
case pkt := <-s.queue:
|
|
return pkt.Data, nil
|
|
}
|
|
}
|
|
|
|
//
|
|
// msg read/write
|
|
//
|
|
|
|
// MsgSend marshals the message with protobuf, writes it, and flushes.
|
|
func (s *Stream) MsgSend(msg drpc.Message) (err error) {
|
|
defer mon.Task()(s.monCtx())(&err)
|
|
|
|
data, err := proto.Marshal(msg)
|
|
if err != nil {
|
|
return errs.Wrap(err)
|
|
}
|
|
if err := s.RawWrite(drpcwire.KindMessage, data); err != nil {
|
|
return err
|
|
}
|
|
if err := s.RawFlush(); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// MsgRecv recives some protobuf data and unmarshals it into msg.
|
|
func (s *Stream) MsgRecv(msg drpc.Message) (err error) {
|
|
defer mon.Task()(s.monCtx())(&err)
|
|
|
|
data, err := s.RawRecv()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return proto.Unmarshal(data, msg)
|
|
}
|
|
|
|
//
|
|
// terminal messages
|
|
//
|
|
|
|
// SendError terminates the stream and sends the error to the remote. It is a no-op if
|
|
// the stream is already terminated.
|
|
func (s *Stream) SendError(serr error) (err error) {
|
|
defer mon.Task()(s.monCtx())(&err)
|
|
|
|
s.mu.Lock()
|
|
if s.sigs.term.IsSet() {
|
|
s.mu.Unlock()
|
|
return nil
|
|
}
|
|
|
|
s.writeMu.Lock()
|
|
defer s.writeMu.Unlock()
|
|
defer s.checkFinished()
|
|
|
|
s.sigs.send.Set(io.EOF) // in this state, gRPC returns io.EOF on send.
|
|
s.terminate(drpc.Error.New("stream terminated by sending error"))
|
|
s.mu.Unlock()
|
|
|
|
return s.checkCancelError(s.sendPacket(drpcwire.KindError, drpcwire.MarshalError(serr)))
|
|
}
|
|
|
|
// Close terminates the stream and sends that the stream has been closed to the remote.
|
|
// It is a no-op if the stream is already terminated.
|
|
func (s *Stream) Close() (err error) {
|
|
defer mon.Task()(s.monCtx())(&err)
|
|
|
|
s.mu.Lock()
|
|
if s.sigs.term.IsSet() {
|
|
s.mu.Unlock()
|
|
return nil
|
|
}
|
|
|
|
s.writeMu.Lock()
|
|
defer s.writeMu.Unlock()
|
|
defer s.checkFinished()
|
|
|
|
s.terminate(drpc.Error.New("stream terminated by sending close"))
|
|
s.mu.Unlock()
|
|
|
|
return s.checkCancelError(s.sendPacket(drpcwire.KindClose, nil))
|
|
}
|
|
|
|
// CloseSend informs the remote that no more messages will be sent. If the remote has
|
|
// also already issued a CloseSend, the stream is terminated. It is a no-op if the
|
|
// stream already has sent a CloseSend or if it is terminated.
|
|
func (s *Stream) CloseSend() (err error) {
|
|
defer mon.Task()(s.monCtx())(&err)
|
|
|
|
s.mu.Lock()
|
|
if s.sigs.send.IsSet() || s.sigs.term.IsSet() {
|
|
s.mu.Unlock()
|
|
return nil
|
|
}
|
|
|
|
s.writeMu.Lock()
|
|
defer s.writeMu.Unlock()
|
|
defer s.checkFinished()
|
|
|
|
s.sigs.send.Set(drpc.Error.New("send closed"))
|
|
s.terminateIfBothClosed()
|
|
s.mu.Unlock()
|
|
|
|
return s.checkCancelError(s.sendPacket(drpcwire.KindCloseSend, nil))
|
|
}
|
|
|
|
// Cancel transitions the stream into a state where all writes to the transport will return
|
|
// the provided error, and terminates the stream. It is a no-op if the stream is already
|
|
// terminated.
|
|
func (s *Stream) Cancel(err error) {
|
|
defer mon.Task()(s.monCtx())(nil)
|
|
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
if s.sigs.term.IsSet() {
|
|
return
|
|
}
|
|
|
|
s.sigs.cancel.Set(err)
|
|
s.sigs.send.Set(io.EOF) // in this state, gRPC returns io.EOF on send.
|
|
s.terminate(err)
|
|
}
|