diff --git a/util/proto/marshal.go b/util/proto/marshal.go index 3200176..025d9b4 100644 --- a/util/proto/marshal.go +++ b/util/proto/marshal.go @@ -120,6 +120,14 @@ func Int32Size(field int, v int32) int { return UInt64Size(field, uint64(v)) } +func EnumMarshal(field int, buf []byte, v int32) (int, error) { + return UInt64Marshal(field, buf, uint64(v)) +} + +func EnumSize(field int, v int32) int { + return UInt64Size(field, uint64(v)) +} + // varUIntSize returns length of varint byte sequence for uint64 value 'x'. func VarUIntSize(x uint64) int { return (bits.Len64(x|1) + 6) / 7 diff --git a/util/proto/marshal_test.go b/util/proto/marshal_test.go index 0abcef9..6802523 100644 --- a/util/proto/marshal_test.go +++ b/util/proto/marshal_test.go @@ -10,6 +10,8 @@ import ( "github.com/stretchr/testify/require" ) +type SomeEnum int32 + type stablePrimitives struct { FieldA []byte FieldB string @@ -18,8 +20,15 @@ type stablePrimitives struct { FieldE uint32 FieldF int64 FieldG uint64 + FieldH SomeEnum } +const ( + ENUM_UNKNOWN SomeEnum = 0 + ENUM_POSITIVE = 1 + ENUM_NEGATIVE = -1 +) + func (s *stablePrimitives) stableMarshal(buf []byte, wrongField bool) ([]byte, error) { if s == nil { return []byte{}, nil @@ -103,6 +112,16 @@ func (s *stablePrimitives) stableMarshal(buf []byte, wrongField bool) ([]byte, e } i += offset + fieldNum = 300 + if wrongField { + fieldNum++ + } + offset, err = proto.EnumMarshal(fieldNum, buf, int32(s.FieldH)) + if err != nil { + return nil, errors.Wrap(err, "can't marshal field g") + } + i += offset + return buf, nil } @@ -113,7 +132,8 @@ func (s *stablePrimitives) stableSize() int { proto.Int32Size(201, s.FieldD) + proto.UInt32Size(202, s.FieldE) + proto.Int64Size(203, s.FieldF) + - proto.UInt64Size(204, s.FieldG) + proto.UInt64Size(204, s.FieldG) + + proto.EnumSize(300, int32(s.FieldH)) } func TestBytesMarshal(t *testing.T) { @@ -209,6 +229,14 @@ func TestUInt64Marshal(t *testing.T) { }) } +func TestEnumMarshal(t *testing.T) { + testEnumMarshal(t, ENUM_UNKNOWN, false) + testEnumMarshal(t, ENUM_POSITIVE, false) + testEnumMarshal(t, ENUM_POSITIVE, true) + testEnumMarshal(t, ENUM_NEGATIVE, false) + testEnumMarshal(t, ENUM_NEGATIVE, true) +} + func testBytesMarshal(t *testing.T, data []byte, wrongField bool) { var ( wire []byte @@ -398,3 +426,36 @@ func testUInt64Marshal(t *testing.T, n uint64, wrongField bool) { require.EqualValues(t, 0, result.FieldG) } } + +func testEnumMarshal(t *testing.T, e SomeEnum, wrongField bool) { + var ( + wire []byte + err error + + custom = stablePrimitives{FieldH: e} + transport = test.Primitives{FieldH: test.Primitives_SomeEnum(e)} + ) + + wire, err = custom.stableMarshal(nil, wrongField) + require.NoError(t, err) + + wireGen, err := transport.Marshal() + require.NoError(t, err) + + if !wrongField { + // we can check equality because single field cannot be unstable marshalled + require.Equal(t, wireGen, wire) + } else { + require.NotEqual(t, wireGen, wire) + } + + result := new(test.Primitives) + err = result.Unmarshal(wire) + require.NoError(t, err) + + if !wrongField { + require.EqualValues(t, custom.FieldH, result.FieldH) + } else { + require.EqualValues(t, 0, result.FieldH) + } +} diff --git a/util/proto/test/test.pb.go b/util/proto/test/test.pb.go index 6790854..0c31467 100644 Binary files a/util/proto/test/test.pb.go and b/util/proto/test/test.pb.go differ diff --git a/util/proto/test/test.proto b/util/proto/test/test.proto index d3b43e6..5aed2b2 100644 --- a/util/proto/test/test.proto +++ b/util/proto/test/test.proto @@ -10,4 +10,11 @@ message Primitives { uint32 field_e = 202; int64 field_f = 203; uint64 field_g = 204; + + enum SomeEnum { + UNKNOWN = 0; + POSITIVE = 1; + NEGATIVE = -1; + } + SomeEnum field_h = 300; } \ No newline at end of file