diff --git a/util/proto/marshal.go b/util/proto/marshal.go index 2d5f9e6..8570f6a 100644 --- a/util/proto/marshal.go +++ b/util/proto/marshal.go @@ -16,10 +16,10 @@ func BytesMarshal(field int, buf, v []byte) (int, error) { return 0, nil } + prefix := field<<3 | 0x2 + // buf length check can prevent panic at PutUvarint, but it will make // marshaller a bit slower. - - prefix := field<<3 | 0x2 i := binary.PutUvarint(buf, uint64(prefix)) i += binary.PutUvarint(buf[i:], uint64(len(v))) i += copy(buf[i:], v) @@ -46,6 +46,30 @@ func StringSize(field int, v string) int { return BytesSize(field, []byte(v)) } +func BoolMarshal(field int, buf []byte, v bool) (int, error) { + if !v { + return 0, nil + } + + prefix := field << 3 + + // buf length check can prevent panic at PutUvarint, but it will make + // marshaller a bit slower. + i := binary.PutUvarint(buf, uint64(prefix)) + buf[i] = 0x1 + + return i + 1, nil +} + +func BoolSize(field int, v bool) int { + if !v { + return 0 + } + + prefix := field << 3 + return VarUIntSize(uint64(prefix)) + 1 // bool is always 1 byte long +} + // varUIntSize returns length of varint byte sequence for uint64 value 'x'. func VarUIntSize(x uint64) int { return (bits.Len64(x|1) + 6) / 7 diff --git a/util/proto/marshal_test.go b/util/proto/marshal_test.go index 66a8b62..36847dd 100644 --- a/util/proto/marshal_test.go +++ b/util/proto/marshal_test.go @@ -12,6 +12,7 @@ import ( type stablePrimitives struct { FieldA []byte FieldB string + FieldC bool } func (s *stablePrimitives) stableMarshal(buf []byte, wrongField bool) ([]byte, error) { @@ -43,7 +44,17 @@ func (s *stablePrimitives) stableMarshal(buf []byte, wrongField bool) ([]byte, e } offset, err = proto.StringMarshal(fieldNum, buf, s.FieldB) if err != nil { - return nil, errors.Wrap(err, "can't marshal field a") + return nil, errors.Wrap(err, "can't marshal field b") + } + i += offset + + fieldNum = 200 + if wrongField { + fieldNum++ + } + offset, err = proto.BoolMarshal(fieldNum, buf, s.FieldC) + if err != nil { + return nil, errors.Wrap(err, "can't marshal field c") } i += offset @@ -52,7 +63,8 @@ func (s *stablePrimitives) stableMarshal(buf []byte, wrongField bool) ([]byte, e func (s *stablePrimitives) stableSize() int { return proto.BytesSize(1, s.FieldA) + - proto.StringSize(2, s.FieldB) + proto.StringSize(2, s.FieldB) + + proto.BoolSize(200, s.FieldC) } func TestBytesMarshal(t *testing.T) { @@ -71,6 +83,29 @@ func TestBytesMarshal(t *testing.T) { }) } +func TestStringMarshal(t *testing.T) { + t.Run("not empty", func(t *testing.T) { + data := "Hello World" + testStringMarshal(t, data, false) + testStringMarshal(t, data, true) + }) + + t.Run("empty", func(t *testing.T) { + testStringMarshal(t, "", false) + }) +} + +func TestBoolMarshal(t *testing.T) { + t.Run("true", func(t *testing.T) { + testBoolMarshal(t, true, false) + testBoolMarshal(t, true, true) + }) + + t.Run("false", func(t *testing.T) { + testBoolMarshal(t, false, false) + }) +} + func testBytesMarshal(t *testing.T, data []byte, wrongField bool) { var ( wire []byte @@ -107,18 +142,6 @@ func testBytesMarshal(t *testing.T, data []byte, wrongField bool) { } } -func TestStringMarshal(t *testing.T) { - t.Run("not empty", func(t *testing.T) { - data := "Hello World" - testStringMarshal(t, data, false) - testStringMarshal(t, data, true) - }) - - t.Run("empty", func(t *testing.T) { - testStringMarshal(t, "", false) - }) -} - func testStringMarshal(t *testing.T, s string, wrongField bool) { var ( wire []byte @@ -154,3 +177,36 @@ func testStringMarshal(t *testing.T, s string, wrongField bool) { require.Len(t, result.FieldB, 0) } } + +func testBoolMarshal(t *testing.T, b bool, wrongField bool) { + var ( + wire []byte + err error + + custom = stablePrimitives{FieldC: b} + transport = test.Primitives{FieldC: b} + ) + + wire, err = custom.stableMarshal(nil, wrongField) + require.NoError(t, err) + + wireGen, err := transport.Marshal() + require.NoError(t, err) + + if !wrongField { + // we can check equality because single field cannot be unstable marshalled + require.Equal(t, wireGen, wire) + } else { + require.NotEqual(t, wireGen, wire) + } + + result := new(test.Primitives) + err = result.Unmarshal(wire) + require.NoError(t, err) + + if !wrongField { + require.Equal(t, b, result.FieldC) + } else { + require.False(t, false, result.FieldC) + } +} diff --git a/util/proto/test/test.pb.go b/util/proto/test/test.pb.go index 9a1afd3..910812c 100644 Binary files a/util/proto/test/test.pb.go and b/util/proto/test/test.pb.go differ diff --git a/util/proto/test/test.proto b/util/proto/test/test.proto index 81a1058..4be07a4 100644 --- a/util/proto/test/test.proto +++ b/util/proto/test/test.proto @@ -3,6 +3,7 @@ syntax = "proto3"; package test; message Primitives { - bytes field_a = 1; + bytes field_a = 1; string field_b = 2; + bool field_c = 200; } \ No newline at end of file