rclone/vendor/storj.io/drpc/drpcmanager/manager.go

336 lines
8.6 KiB
Go
Raw Normal View History

2020-05-11 18:57:46 +00:00
// Copyright (C) 2019 Storj Labs, Inc.
// See LICENSE for copying information.
package drpcmanager
import (
"context"
"fmt"
"sync"
"github.com/zeebo/errs"
"storj.io/drpc"
"storj.io/drpc/drpcctx"
"storj.io/drpc/drpcdebug"
"storj.io/drpc/drpcmetadata"
"storj.io/drpc/drpcsignal"
"storj.io/drpc/drpcstream"
"storj.io/drpc/drpcwire"
)
// managerClosed is the error recorded in the term signal when Close is called,
// and in the read signal when the packet-reading goroutine (manageReader) exits.
var managerClosed = errs.New("manager closed")
// Options controls configuration settings for a manager.
type Options struct {
	// WriterBufferSize controls the size of the buffer that we will fill before
	// flushing. Normal writes to streams typically issue a flush explicitly.
	// It is passed to drpcwire.NewWriter when the manager is constructed.
	WriterBufferSize int

	// Stream is passed to every stream the manager creates, for both client
	// and server streams.
	Stream drpcstream.Options
}
// Manager handles the logic of managing a transport for a drpc client or server.
// It ensures that the connection is always being read from, that it is closed
// in the case that the manager is, and that drpc protocol messages are forwarded
// to the appropriate stream.
type Manager struct {
	tr   drpc.Transport
	wr   *drpcwire.Writer
	rd   *drpcwire.Reader
	opts Options

	once sync.Once // guards the shutdown sequence run by Close

	sid uint64        // id of the most recently created client stream
	sem chan struct{} // capacity-1 semaphore: held while a stream is active, and by Close once it runs

	term  drpcsignal.Signal // set when the manager should start terminating
	read  drpcsignal.Signal // set after the goroutine reading from the transport is done
	tport drpcsignal.Signal // set after the transport has been closed

	queue chan drpcwire.Packet // packets from manageReader, consumed by NewServerStream and manageStreamPackets
	ctx   context.Context      // base context for new streams; carries the transport via drpcctx.WithTransport
}
// New returns a new Manager for the transport using the zero value of Options,
// i.e. it is shorthand for NewWithOptions(tr, Options{}).
func New(tr drpc.Transport) *Manager {
	var defaults Options
	return NewWithOptions(tr, defaults)
}
// NewWithOptions returns a new Manager for the transport, configured by opts.
// Before returning it starts two background goroutines: one that closes the
// transport once the manager terminates, and one that continuously reads
// packets from the transport.
func NewWithOptions(tr drpc.Transport, opts Options) *Manager {
	m := new(Manager)
	m.tr = tr
	m.wr = drpcwire.NewWriter(tr, opts.WriterBufferSize)
	m.rd = drpcwire.NewReader(tr)
	m.opts = opts

	// the semaphore bounds the number of concurrent streams; its capacity MUST be 1.
	m.sem = make(chan struct{}, 1)
	m.queue = make(chan drpcwire.Packet)
	m.ctx = drpcctx.WithTransport(context.Background(), tr)

	go m.manageTransport()
	go m.manageReader()

	return m
}
//
// helpers
//
// poll reports whether a receive from ch can proceed without blocking. If ch
// is buffered and holds a value, that value is consumed. A nil channel is
// never ready.
func poll(ch <-chan struct{}) (ready bool) {
	select {
	case <-ch:
		ready = true
	default:
	}
	return ready
}
// poll performs a non-blocking check of ctx and of the manager's term signal.
// It returns ctx.Err() if ctx is done, otherwise the term error if the manager
// is terminating, and nil if neither condition holds. The context is checked
// first.
func (m *Manager) poll(ctx context.Context) error {
	if poll(ctx.Done()) {
		return ctx.Err()
	}
	if poll(m.term.Signal()) {
		return m.term.Err()
	}
	return nil
}
// acquireSemaphore blocks until it takes the single stream slot. It gives up
// with ctx.Err() if ctx is done, or with the term error if the manager is
// terminating. Both conditions are checked once without blocking before
// waiting. Otherwise an already-dead context or manager could lose the
// select's random choice to a free semaphore.
func (m *Manager) acquireSemaphore(ctx context.Context) (err error) {
	if err = m.poll(ctx); err != nil {
		return err
	}

	select {
	case m.sem <- struct{}{}:
	case <-m.term.Signal():
		err = m.term.Err()
	case <-ctx.Done():
		err = ctx.Err()
	}
	return err
}
//
// exported interface
//
// Closed reports whether the manager's term signal has been set. This happens
// on Close and also after a fatal read, packet-handling, or stream-context
// error.
func (m *Manager) Closed() bool {
	terminated := m.term.IsSet()
	return terminated
}
// Close closes the transport the manager is using. Concurrent and repeated
// calls are allowed. Every call returns the error recorded in the tport signal,
// which manageTransport sets from closing the transport.
func (m *Manager) Close() error {
	// when closing, we set the manager terminated signal, wait for the goroutine
	// managing the transport to notice and close it, acquire the semaphore to ensure
	// there are no streams running, then wait for the goroutine reading packets to be done.
	// we protect it with a once to ensure both that we only do this once, and that
	// concurrent calls are sure that it has fully executed.
	m.once.Do(func() {
		m.term.Set(managerClosed)
		<-m.tport.Signal()
		// Close never gives this slot back, so it blocks until the active stream
		// (if any) releases the semaphore.
		m.sem <- struct{}{}
		<-m.read.Signal()
	})
	return m.tport.Err()
}
// NewClientStream starts a stream on the managed transport for use by a client.
// It waits for the single stream slot, which fails if ctx is done or the manager
// terminates. It then assigns the next stream id and hands the stream to a
// goroutine that manages it, tied to ctx.
func (m *Manager) NewClientStream(ctx context.Context) (stream *drpcstream.Stream, err error) {
	if err = m.acquireSemaphore(ctx); err != nil {
		return nil, err
	}

	// the semaphore is held from here on, so no other caller touches m.sid
	// concurrently; manageStream releases it when the stream is done.
	m.sid++
	id := m.sid

	stream = drpcstream.NewWithOptions(m.ctx, id, m.wr, m.opts.Stream)
	go m.manageStream(ctx, stream)
	return stream, nil
}
// NewServerStream starts a stream on the managed transport for use by a server. It does
// this by waiting for the client to issue an invoke message and returning the details.
//
// The stream semaphore is held while waiting. On every error return it is
// released, so later calls to NewServerStream (and Close) do not block forever.
// On success the semaphore stays held; manageStream releases it once the
// returned stream is done.
func (m *Manager) NewServerStream(ctx context.Context) (stream *drpcstream.Stream, rpc string, err error) {
	if err := m.acquireSemaphore(ctx); err != nil {
		return nil, "", err
	}

	var metadata drpcwire.Packet
	for {
		select {
		case <-ctx.Done():
			<-m.sem
			return nil, "", ctx.Err()

		case <-m.term.Signal():
			<-m.sem
			return nil, "", m.term.Err()

		case pkt := <-m.queue:
			switch pkt.Kind {
			case drpcwire.KindInvokeMetadata:
				// keep track of any metadata being sent before an invoke so that we can
				// include it if the stream id matches the eventual invoke.
				metadata = pkt
				continue

			case drpcwire.KindInvoke:
				streamCtx := m.ctx
				if metadata.ID.Stream == pkt.ID.Stream {
					md, err := drpcmetadata.Decode(metadata.Data)
					if err != nil {
						// release the semaphore acquired above: returning while still
						// holding it would block every later NewServerStream call and
						// deadlock Close.
						<-m.sem
						return nil, "", err
					}
					streamCtx = drpcmetadata.AddPairs(streamCtx, md)
				}
				stream = drpcstream.NewWithOptions(streamCtx, pkt.ID.Stream, m.wr, m.opts.Stream)
				go m.manageStream(ctx, stream)
				return stream, string(pkt.Data), nil

			default:
				// we ignore packets that arent invokes because perhaps older streams have
				// messages in the queue sent concurrently with our notification to them
				// that the stream they were sent for is done.
				continue
			}
		}
	}
}
//
// manage transport
//
// manageTransport blocks until the manager's term signal is set. It then closes
// the underlying transport and records that Close result in the tport signal,
// which Close waits on.
func (m *Manager) manageTransport() {
	defer mon.Task()(nil)(nil)

	<-m.term.Signal()
	closeErr := m.tr.Close()
	m.tport.Set(closeErr)
}
//
// manage reader
//
// manageReader loops reading packets from the transport and handing each one to
// the manager's queue. It stops in two cases:
//   - a read fails, in which case it sets the term signal with the wrapped error;
//   - the term signal is set while it is waiting to deliver a packet.
//
// On exit it sets the read signal, so Close can be sure nobody is still using
// the reader.
func (m *Manager) manageReader() {
	defer mon.Task()(nil)(nil)
	defer m.read.Set(managerClosed)

	for {
		pkt, readErr := m.rd.ReadPacket()
		if readErr != nil {
			m.term.Set(errs.Wrap(readErr))
			return
		}

		drpcdebug.Log(func() string { return fmt.Sprintf("MAN[%p]: %v", m, pkt) })

		select {
		case m.queue <- pkt:
			continue
		case <-m.term.Signal():
		}
		return
	}
}
//
// manage stream
//
// manageStream supervises one stream. It runs a packet forwarder and a context
// watcher concurrently and waits for both to return. Then it cancels the
// stream and releases the stream semaphore.
func (m *Manager) manageStream(ctx context.Context, stream *drpcstream.Stream) {
	defer mon.Task()(nil)(nil)

	var workers sync.WaitGroup
	workers.Add(2)
	go m.manageStreamPackets(&workers, stream)
	go m.manageStreamContext(ctx, &workers, stream)
	workers.Wait()

	// Unconditionally cancel the stream now that nobody manages it. Normally it
	// is already in a terminal state. The exception is when we are exiting
	// because the manager terminated, which only happens if the underlying
	// transport died, so treat that as the remote end issuing a cancel.
	stream.Cancel(context.Canceled)

	// give the single stream slot back.
	<-m.sem
}
// manageStreamPackets repeatedly reads from the manager's queue of packets and
// asks the stream to handle them. It returns in three cases:
//   - the manager terminates;
//   - the stream is terminated;
//   - HandlePacket reports, via its bool result, that the stream requires no
//     more packets.
//
// An error from HandlePacket is considered fatal to the manager, so the term
// signal is set before returning. It marks wg done when it exits; releasing the
// stream semaphore is left to manageStream.
func (m *Manager) manageStreamPackets(wg *sync.WaitGroup, stream *drpcstream.Stream) {
	defer mon.Task()(nil)(nil)
	defer wg.Done()

	for {
		select {
		case <-m.term.Signal():
			return

		case <-stream.Terminated():
			return

		case pkt := <-m.queue:
			drpcdebug.Log(func() string { return fmt.Sprintf("FWD[%p][%p]: %v", m, stream, pkt) })

			ok, err := stream.HandlePacket(pkt)
			if err != nil {
				m.term.Set(errs.Wrap(err))
				return
			}
			if !ok {
				return
			}
		}
	}
}
// manageStreamContext waits until the manager terminates, the stream
// terminates, or ctx is done. If ctx finishes first, it cancels the stream with
// the context's error. If the stream had not finished, it also sets the
// manager's term signal with that error, which causes manageTransport to close
// the underlying transport. It marks wg done when it exits.
func (m *Manager) manageStreamContext(ctx context.Context, wg *sync.WaitGroup, stream *drpcstream.Stream) {
	defer mon.Task()(nil)(nil)
	defer wg.Done()

	select {
	case <-ctx.Done():
		// once Done is closed, ctx.Err is non-nil and stable, so reading it once
		// is equivalent to reading it at each use.
		cause := ctx.Err()
		stream.Cancel(cause)
		if stream.Finished() {
			return
		}
		m.term.Set(cause)
	case <-stream.Terminated():
	case <-m.term.Signal():
	}
}