diff --git a/rpc/message/encoding.go b/rpc/message/encoding.go new file mode 100644 index 0000000..7d9f7e9 --- /dev/null +++ b/rpc/message/encoding.go @@ -0,0 +1,48 @@ +package message + +import ( + "github.com/nspcc-dev/neofs-api-go/rpc/grpc" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/proto" +) + +// GRPCConvertedMessage is an interface +// of the gRPC message that is used +// for Message encoding/decoding. +type GRPCConvertedMessage interface { + grpc.Message + proto.Message +} + +// Unmarshal decodes m from its Protobuf binary representation +// via related gRPC message. +// +// gm should be tof the same type as the m.ToGRPCMessage() return. +func Unmarshal(m Message, data []byte, gm GRPCConvertedMessage) error { + if err := proto.Unmarshal(data, gm); err != nil { + return err + } + + return m.FromGRPCMessage(gm) +} + +// MarshalJSON encodes m to Protobuf JSON representation. +func MarshalJSON(m Message) ([]byte, error) { + return protojson.MarshalOptions{ + EmitUnpopulated: true, + }.Marshal( + m.ToGRPCMessage().(proto.Message), + ) +} + +// UnmarshalJSON decodes m from its Protobuf JSON representation +// via related gRPC message. +// +// gm should be tof the same type as the m.ToGRPCMessage() return. +func UnmarshalJSON(m Message, data []byte, gm GRPCConvertedMessage) error { + if err := protojson.Unmarshal(data, gm); err != nil { + return err + } + + return m.FromGRPCMessage(gm) +} diff --git a/rpc/message/message.go b/rpc/message/message.go new file mode 100644 index 0000000..5cc3702 --- /dev/null +++ b/rpc/message/message.go @@ -0,0 +1,43 @@ +package message + +import ( + "fmt" + + "github.com/nspcc-dev/neofs-api-go/rpc/grpc" +) + +// Message represents raw Protobuf message +// that can be transmitted via several +// transport protocols. +type Message interface { + // Must return gRPC message that can + // be used for gRPC protocol transmission. + ToGRPCMessage() grpc.Message + + // Must restore the message from related + // gRPC message. + // + // If gRPC message is not a related one, + // ErrUnexpectedMessageType can be returned + // to indicate this. + FromGRPCMessage(grpc.Message) error +} + +// ErrUnexpectedMessageType is an error that +// is used to indicate message mismatch. +type ErrUnexpectedMessageType struct { + exp, act interface{} +} + +// NewUnexpectedMessageType initializes an error about message mismatch +// between act and exp. +func NewUnexpectedMessageType(act, exp interface{}) ErrUnexpectedMessageType { + return ErrUnexpectedMessageType{ + exp: exp, + act: act, + } +} + +func (e ErrUnexpectedMessageType) Error() string { + return fmt.Sprintf("unexpected message type %T: expected %T", e.act, e.exp) +} diff --git a/rpc/message/test/message.go b/rpc/message/test/message.go new file mode 100644 index 0000000..12ad8ad --- /dev/null +++ b/rpc/message/test/message.go @@ -0,0 +1,68 @@ +package messagetest + +import ( + "encoding/json" + "fmt" + "testing" + + "github.com/nspcc-dev/neofs-api-go/rpc/message" + "github.com/pkg/errors" + "github.com/stretchr/testify/require" +) + +type jsonMessage interface { + json.Marshaler + json.Unmarshaler +} + +type binaryMessage interface { + StableMarshal([]byte) ([]byte, error) + Unmarshal([]byte) error +} + +func TestRPCMessage(t *testing.T, msgGens ...func(empty bool) message.Message) { + for _, msgGen := range msgGens { + msg := msgGen(false) + + t.Run(fmt.Sprintf("convert_%T", msg), func(t *testing.T) { + msg := msgGen(false) + + err := msg.FromGRPCMessage(100) + + require.True(t, errors.As(err, new(message.ErrUnexpectedMessageType))) + + msg2 := msgGen(true) + + err = msg2.FromGRPCMessage(msg.ToGRPCMessage()) + require.NoError(t, err) + + require.Equal(t, msg, msg2) + }) + + t.Run("encoding", func(t *testing.T) { + if jm, ok := msg.(jsonMessage); ok { + t.Run(fmt.Sprintf("JSON_%T", msg), func(t *testing.T) { + data, err := jm.MarshalJSON() + require.NoError(t, err) + + jm2 := msgGen(true).(jsonMessage) + require.NoError(t, jm2.UnmarshalJSON(data)) + + require.Equal(t, jm, jm2) + }) + } + + if bm, ok := msg.(binaryMessage); ok { + t.Run(fmt.Sprintf("Binary_%T", msg), func(t *testing.T) { + data, err := bm.StableMarshal(nil) + require.NoError(t, err) + + bm2 := msgGen(true).(binaryMessage) + require.NoError(t, bm2.Unmarshal(data)) + + require.Equal(t, bm, bm2) + }) + } + }) + } +}