336 lines
8.6 KiB
Go
336 lines
8.6 KiB
Go
|
// Copyright (C) 2019 Storj Labs, Inc.
|
||
|
// See LICENSE for copying information.
|
||
|
|
||
|
package drpcmanager
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"fmt"
|
||
|
"sync"
|
||
|
|
||
|
"github.com/zeebo/errs"
|
||
|
|
||
|
"storj.io/drpc"
|
||
|
"storj.io/drpc/drpcctx"
|
||
|
"storj.io/drpc/drpcdebug"
|
||
|
"storj.io/drpc/drpcmetadata"
|
||
|
"storj.io/drpc/drpcsignal"
|
||
|
"storj.io/drpc/drpcstream"
|
||
|
"storj.io/drpc/drpcwire"
|
||
|
)
|
||
|
|
||
|
var managerClosed = errs.New("manager closed")
|
||
|
|
||
|
// Options controls configuration settings for a manager.
|
||
|
type Options struct {
|
||
|
// WriterBufferSize controls the size of the buffer that we will fill before
|
||
|
// flushing. Normal writes to streams typically issue a flush explicitly.
|
||
|
WriterBufferSize int
|
||
|
|
||
|
// Stream are passed to any streams the manager creates.
|
||
|
Stream drpcstream.Options
|
||
|
}
|
||
|
|
||
|
// Manager handles the logic of managing a transport for a drpc client or server.
|
||
|
// It ensures that the connection is always being read from, that it is closed
|
||
|
// in the case that the manager is and forwarding drpc protocol messages to the
|
||
|
// appropriate stream.
|
||
|
type Manager struct {
|
||
|
tr drpc.Transport
|
||
|
wr *drpcwire.Writer
|
||
|
rd *drpcwire.Reader
|
||
|
opts Options
|
||
|
|
||
|
once sync.Once
|
||
|
|
||
|
sid uint64
|
||
|
sem chan struct{}
|
||
|
term drpcsignal.Signal // set when the manager should start terminating
|
||
|
read drpcsignal.Signal // set after the goroutine reading from the transport is done
|
||
|
tport drpcsignal.Signal // set after the transport has been closed
|
||
|
queue chan drpcwire.Packet
|
||
|
ctx context.Context
|
||
|
}
|
||
|
|
||
|
// New returns a new Manager for the transport.
|
||
|
func New(tr drpc.Transport) *Manager {
|
||
|
return NewWithOptions(tr, Options{})
|
||
|
}
|
||
|
|
||
|
// NewWithOptions returns a new manager for the transport. It uses the provided
|
||
|
// options to manage details of how it uses it.
|
||
|
func NewWithOptions(tr drpc.Transport, opts Options) *Manager {
|
||
|
m := &Manager{
|
||
|
tr: tr,
|
||
|
wr: drpcwire.NewWriter(tr, opts.WriterBufferSize),
|
||
|
rd: drpcwire.NewReader(tr),
|
||
|
opts: opts,
|
||
|
|
||
|
// this semaphore controls the number of concurrent streams. it MUST be 1.
|
||
|
sem: make(chan struct{}, 1),
|
||
|
queue: make(chan drpcwire.Packet),
|
||
|
ctx: drpcctx.WithTransport(context.Background(), tr),
|
||
|
}
|
||
|
|
||
|
go m.manageTransport()
|
||
|
go m.manageReader()
|
||
|
|
||
|
return m
|
||
|
}
|
||
|
|
||
|
//
|
||
|
// helpers
|
||
|
//
|
||
|
|
||
|
// poll checks if a channel is immediately ready.
|
||
|
func poll(ch <-chan struct{}) bool {
|
||
|
select {
|
||
|
case <-ch:
|
||
|
return true
|
||
|
default:
|
||
|
return false
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// poll checks if the context is canceled or the manager is terminated.
|
||
|
func (m *Manager) poll(ctx context.Context) error {
|
||
|
switch {
|
||
|
case poll(ctx.Done()):
|
||
|
return ctx.Err()
|
||
|
|
||
|
case poll(m.term.Signal()):
|
||
|
return m.term.Err()
|
||
|
|
||
|
default:
|
||
|
return nil
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// acquireSemaphore attempts to acquire the semaphore protecting streams. If the
|
||
|
// context is canceled or the manager is terminated, it returns an error.
|
||
|
func (m *Manager) acquireSemaphore(ctx context.Context) error {
|
||
|
if err := m.poll(ctx); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
select {
|
||
|
case <-ctx.Done():
|
||
|
return ctx.Err()
|
||
|
|
||
|
case <-m.term.Signal():
|
||
|
return m.term.Err()
|
||
|
|
||
|
case m.sem <- struct{}{}:
|
||
|
return nil
|
||
|
}
|
||
|
}
|
||
|
|
||
|
//
|
||
|
// exported interface
|
||
|
//
|
||
|
|
||
|
// Closed returns if the manager has been closed.
|
||
|
func (m *Manager) Closed() bool {
|
||
|
return m.term.IsSet()
|
||
|
}
|
||
|
|
||
|
// Close closes the transport the manager is using.
|
||
|
func (m *Manager) Close() error {
|
||
|
// when closing, we set the manager terminated signal, wait for the goroutine
|
||
|
// managing the transport to notice and close it, acquire the semaphore to ensure
|
||
|
// there are streams running, then wait for the goroutine reading packets to be done.
|
||
|
// we protect it with a once to ensure both that we only do this once, and that
|
||
|
// concurrent calls are sure that it has fully executed.
|
||
|
|
||
|
m.once.Do(func() {
|
||
|
m.term.Set(managerClosed)
|
||
|
<-m.tport.Signal()
|
||
|
m.sem <- struct{}{}
|
||
|
<-m.read.Signal()
|
||
|
})
|
||
|
|
||
|
return m.tport.Err()
|
||
|
}
|
||
|
|
||
|
// NewClientStream starts a stream on the managed transport for use by a client.
|
||
|
func (m *Manager) NewClientStream(ctx context.Context) (stream *drpcstream.Stream, err error) {
|
||
|
if err := m.acquireSemaphore(ctx); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
m.sid++
|
||
|
stream = drpcstream.NewWithOptions(m.ctx, m.sid, m.wr, m.opts.Stream)
|
||
|
go m.manageStream(ctx, stream)
|
||
|
return stream, nil
|
||
|
}
|
||
|
|
||
|
// NewServerStream starts a stream on the managed transport for use by a server. It does
|
||
|
// this by waiting for the client to issue an invoke message and returning the details.
|
||
|
func (m *Manager) NewServerStream(ctx context.Context) (stream *drpcstream.Stream, rpc string, err error) {
|
||
|
if err := m.acquireSemaphore(ctx); err != nil {
|
||
|
return nil, "", err
|
||
|
}
|
||
|
|
||
|
var metadata drpcwire.Packet
|
||
|
|
||
|
for {
|
||
|
select {
|
||
|
case <-ctx.Done():
|
||
|
<-m.sem
|
||
|
return nil, "", ctx.Err()
|
||
|
|
||
|
case <-m.term.Signal():
|
||
|
<-m.sem
|
||
|
return nil, "", m.term.Err()
|
||
|
|
||
|
case pkt := <-m.queue:
|
||
|
switch pkt.Kind {
|
||
|
case drpcwire.KindInvokeMetadata:
|
||
|
// keep track of any metadata being sent before an invoke so that we can
|
||
|
// include it if the stream id matches the eventual invoke.
|
||
|
metadata = pkt
|
||
|
continue
|
||
|
|
||
|
case drpcwire.KindInvoke:
|
||
|
streamCtx := m.ctx
|
||
|
if metadata.ID.Stream == pkt.ID.Stream {
|
||
|
md, err := drpcmetadata.Decode(metadata.Data)
|
||
|
if err != nil {
|
||
|
return nil, "", err
|
||
|
}
|
||
|
streamCtx = drpcmetadata.AddPairs(streamCtx, md)
|
||
|
}
|
||
|
|
||
|
stream = drpcstream.NewWithOptions(streamCtx, pkt.ID.Stream, m.wr, m.opts.Stream)
|
||
|
go m.manageStream(ctx, stream)
|
||
|
return stream, string(pkt.Data), nil
|
||
|
default:
|
||
|
// we ignore packets that arent invokes because perhaps older streams have
|
||
|
// messages in the queue sent concurrently with our notification to them
|
||
|
// that the stream they were sent for is done.
|
||
|
continue
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
//
|
||
|
// manage transport
|
||
|
//
|
||
|
|
||
|
// manageTransport ensures that if the manager's term signal is ever set, then
|
||
|
// the underlying transport is closed and the error is recorded.
|
||
|
func (m *Manager) manageTransport() {
|
||
|
defer mon.Task()(nil)(nil)
|
||
|
<-m.term.Signal()
|
||
|
m.tport.Set(m.tr.Close())
|
||
|
}
|
||
|
|
||
|
//
|
||
|
// manage reader
|
||
|
//
|
||
|
|
||
|
// manageReader is always reading a packet and sending it into the queue of packets
|
||
|
// the manager has. It sets the read signal when it exits so that one can wait to
|
||
|
// ensure that no one is reading on the reader. It sets the term signal if there is
|
||
|
// any error reading packets.
|
||
|
func (m *Manager) manageReader() {
|
||
|
defer mon.Task()(nil)(nil)
|
||
|
defer m.read.Set(managerClosed)
|
||
|
|
||
|
for {
|
||
|
pkt, err := m.rd.ReadPacket()
|
||
|
if err != nil {
|
||
|
m.term.Set(errs.Wrap(err))
|
||
|
return
|
||
|
}
|
||
|
|
||
|
drpcdebug.Log(func() string { return fmt.Sprintf("MAN[%p]: %v", m, pkt) })
|
||
|
|
||
|
select {
|
||
|
case <-m.term.Signal():
|
||
|
return
|
||
|
|
||
|
case m.queue <- pkt:
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
//
|
||
|
// manage stream
|
||
|
//
|
||
|
|
||
|
// manageStream watches the context and the stream and returns when the stream is
|
||
|
// finished, canceling the stream if the context is canceled.
|
||
|
func (m *Manager) manageStream(ctx context.Context, stream *drpcstream.Stream) {
|
||
|
defer mon.Task()(nil)(nil)
|
||
|
|
||
|
// create a wait group, launch the workers, and wait for them
|
||
|
wg := new(sync.WaitGroup)
|
||
|
wg.Add(2)
|
||
|
go m.manageStreamPackets(wg, stream)
|
||
|
go m.manageStreamContext(ctx, wg, stream)
|
||
|
wg.Wait()
|
||
|
|
||
|
// always ensure the stream is terminated if we're done managing it. the
|
||
|
// stream should already be in a terminal state unless we're exiting due
|
||
|
// to the manager terminating. that only happens if the underlying transport
|
||
|
// died, so just assume the remote end issued a cancel by terminating
|
||
|
// the transport.
|
||
|
stream.Cancel(context.Canceled)
|
||
|
|
||
|
// release semaphore
|
||
|
<-m.sem
|
||
|
}
|
||
|
|
||
|
// manageStreamPackets repeatedly reads from the queue of packets and asks the stream to
|
||
|
// handle them. If there is an error handling a packet, that is considered to
|
||
|
// be fatal to the manager, so we set term. HandlePacket also returns a bool to
|
||
|
// indicate that the stream requires no more packets, and so manageStream can
|
||
|
// just exit. It releases the semaphore whenever it exits.
|
||
|
func (m *Manager) manageStreamPackets(wg *sync.WaitGroup, stream *drpcstream.Stream) {
|
||
|
defer mon.Task()(nil)(nil)
|
||
|
defer wg.Done()
|
||
|
|
||
|
for {
|
||
|
select {
|
||
|
case <-m.term.Signal():
|
||
|
return
|
||
|
|
||
|
case <-stream.Terminated():
|
||
|
return
|
||
|
|
||
|
case pkt := <-m.queue:
|
||
|
drpcdebug.Log(func() string { return fmt.Sprintf("FWD[%p][%p]: %v", m, stream, pkt) })
|
||
|
|
||
|
ok, err := stream.HandlePacket(pkt)
|
||
|
if err != nil {
|
||
|
m.term.Set(errs.Wrap(err))
|
||
|
return
|
||
|
} else if !ok {
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// manageStreamContext ensures that if the stream context is canceled, we inform the stream and
|
||
|
// possibly abort the underlying transport if the stream isn't finished.
|
||
|
func (m *Manager) manageStreamContext(ctx context.Context, wg *sync.WaitGroup, stream *drpcstream.Stream) {
|
||
|
defer mon.Task()(nil)(nil)
|
||
|
defer wg.Done()
|
||
|
|
||
|
select {
|
||
|
case <-m.term.Signal():
|
||
|
return
|
||
|
|
||
|
case <-stream.Terminated():
|
||
|
return
|
||
|
|
||
|
case <-ctx.Done():
|
||
|
stream.Cancel(ctx.Err())
|
||
|
if !stream.Finished() {
|
||
|
m.term.Set(ctx.Err())
|
||
|
}
|
||
|
}
|
||
|
}
|