// Copyright (C) 2019 Storj Labs, Inc.
// See LICENSE for copying information.
package drpcmux
|
||
|
|
||
|
import (
|
||
|
"reflect"
|
||
|
|
||
|
"github.com/zeebo/errs"
|
||
|
|
||
|
"storj.io/drpc"
|
||
|
)
|
||
|
|
||
|
// Mux is an implementation of Handler to serve drpc connections to the
|
||
|
// appropriate Receivers registered by Descriptions.
|
||
|
type Mux struct {
|
||
|
rpcs map[string]rpcData
|
||
|
}
|
||
|
|
||
|
// New constructs a new Mux.
|
||
|
func New() *Mux {
|
||
|
return &Mux{
|
||
|
rpcs: make(map[string]rpcData),
|
||
|
}
|
||
|
}
|
||
|
|
||
|
var (
|
||
|
streamType = reflect.TypeOf((*drpc.Stream)(nil)).Elem()
|
||
|
messageType = reflect.TypeOf((*drpc.Message)(nil)).Elem()
|
||
|
)
|
||
|
|
||
|
type rpcData struct {
|
||
|
srv interface{}
|
||
|
receiver drpc.Receiver
|
||
|
in1 reflect.Type
|
||
|
in2 reflect.Type
|
||
|
}
|
||
|
|
||
|
// Register associates the rpcs described by the description in the server.
|
||
|
// It returns an error if there was a problem registering it.
|
||
|
func (m *Mux) Register(srv interface{}, desc drpc.Description) error {
|
||
|
n := desc.NumMethods()
|
||
|
for i := 0; i < n; i++ {
|
||
|
rpc, receiver, method, ok := desc.Method(i)
|
||
|
if !ok {
|
||
|
return errs.New("Description returned invalid method for index %d", i)
|
||
|
}
|
||
|
if err := m.registerOne(srv, rpc, receiver, method); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// registerOne does the work to register a single rpc.
|
||
|
func (m *Mux) registerOne(srv interface{}, rpc string, receiver drpc.Receiver, method interface{}) error {
|
||
|
data := rpcData{srv: srv, receiver: receiver}
|
||
|
|
||
|
switch mt := reflect.TypeOf(method); {
|
||
|
// unitary input, unitary output
|
||
|
case mt.NumOut() == 2:
|
||
|
data.in1 = mt.In(2)
|
||
|
if !data.in1.Implements(messageType) {
|
||
|
return errs.New("input argument not a drpc message: %v", data.in1)
|
||
|
}
|
||
|
|
||
|
// unitary input, stream output
|
||
|
case mt.NumIn() == 3:
|
||
|
data.in1 = mt.In(1)
|
||
|
if !data.in1.Implements(messageType) {
|
||
|
return errs.New("input argument not a drpc message: %v", data.in1)
|
||
|
}
|
||
|
data.in2 = streamType
|
||
|
|
||
|
// stream input
|
||
|
case mt.NumIn() == 2:
|
||
|
data.in1 = streamType
|
||
|
|
||
|
// code gen bug?
|
||
|
default:
|
||
|
return errs.New("unknown method type: %v", mt)
|
||
|
}
|
||
|
|
||
|
m.rpcs[rpc] = data
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// HandleRPC handles the rpc that has been requested by the stream.
|
||
|
func (m *Mux) HandleRPC(stream drpc.Stream, rpc string) (err error) {
|
||
|
data, ok := m.rpcs[rpc]
|
||
|
if !ok {
|
||
|
return drpc.ProtocolError.New("unknown rpc: %q", rpc)
|
||
|
}
|
||
|
|
||
|
in := interface{}(stream)
|
||
|
if data.in1 != streamType {
|
||
|
msg, ok := reflect.New(data.in1.Elem()).Interface().(drpc.Message)
|
||
|
if !ok {
|
||
|
return drpc.InternalError.New("invalid rpc input type")
|
||
|
}
|
||
|
if err := stream.MsgRecv(msg); err != nil {
|
||
|
return errs.Wrap(err)
|
||
|
}
|
||
|
in = msg
|
||
|
}
|
||
|
|
||
|
out, err := data.receiver(data.srv, stream.Context(), in, stream)
|
||
|
switch {
|
||
|
case err != nil:
|
||
|
return errs.Wrap(err)
|
||
|
case out != nil:
|
||
|
return stream.MsgSend(out)
|
||
|
default:
|
||
|
return nil
|
||
|
}
|
||
|
}
|