From 89bd8f39152c0cdef3e00e315e6b5edf830646b8 Mon Sep 17 00:00:00 2001 From: Alex Vanin Date: Fri, 14 Aug 2020 15:53:57 +0300 Subject: [PATCH] Add stable marshaller helper for bool Signed-off-by: Alex Vanin --- util/proto/marshal.go | 28 ++++++++++++- util/proto/marshal_test.go | 84 +++++++++++++++++++++++++++++++------- util/proto/test/test.pb.go | 54 +++++++++++++++++++++--- util/proto/test/test.proto | 3 +- 4 files changed, 147 insertions(+), 22 deletions(-) diff --git a/util/proto/marshal.go b/util/proto/marshal.go index 2d5f9e6..8570f6a 100644 --- a/util/proto/marshal.go +++ b/util/proto/marshal.go @@ -16,10 +16,10 @@ func BytesMarshal(field int, buf, v []byte) (int, error) { return 0, nil } + prefix := field<<3 | 0x2 + // buf length check can prevent panic at PutUvarint, but it will make // marshaller a bit slower. - - prefix := field<<3 | 0x2 i := binary.PutUvarint(buf, uint64(prefix)) i += binary.PutUvarint(buf[i:], uint64(len(v))) i += copy(buf[i:], v) @@ -46,6 +46,30 @@ func StringSize(field int, v string) int { return BytesSize(field, []byte(v)) } +func BoolMarshal(field int, buf []byte, v bool) (int, error) { + if !v { + return 0, nil + } + + prefix := field << 3 + + // buf length check can prevent panic at PutUvarint, but it will make + // marshaller a bit slower. + i := binary.PutUvarint(buf, uint64(prefix)) + buf[i] = 0x1 + + return i + 1, nil +} + +func BoolSize(field int, v bool) int { + if !v { + return 0 + } + + prefix := field << 3 + return VarUIntSize(uint64(prefix)) + 1 // bool is always 1 byte long +} + // varUIntSize returns length of varint byte sequence for uint64 value 'x'. func VarUIntSize(x uint64) int { return (bits.Len64(x|1) + 6) / 7 diff --git a/util/proto/marshal_test.go b/util/proto/marshal_test.go index 66a8b62..36847dd 100644 --- a/util/proto/marshal_test.go +++ b/util/proto/marshal_test.go @@ -12,6 +12,7 @@ import ( type stablePrimitives struct { FieldA []byte FieldB string + FieldC bool } func (s *stablePrimitives) stableMarshal(buf []byte, wrongField bool) ([]byte, error) { @@ -43,7 +44,17 @@ func (s *stablePrimitives) stableMarshal(buf []byte, wrongField bool) ([]byte, e } offset, err = proto.StringMarshal(fieldNum, buf, s.FieldB) if err != nil { - return nil, errors.Wrap(err, "can't marshal field a") + return nil, errors.Wrap(err, "can't marshal field b") + } + i += offset + + fieldNum = 200 + if wrongField { + fieldNum++ + } + offset, err = proto.BoolMarshal(fieldNum, buf, s.FieldC) + if err != nil { + return nil, errors.Wrap(err, "can't marshal field c") } i += offset @@ -52,7 +63,8 @@ func (s *stablePrimitives) stableMarshal(buf []byte, wrongField bool) ([]byte, e func (s *stablePrimitives) stableSize() int { return proto.BytesSize(1, s.FieldA) + - proto.StringSize(2, s.FieldB) + proto.StringSize(2, s.FieldB) + + proto.BoolSize(200, s.FieldC) } func TestBytesMarshal(t *testing.T) { @@ -71,6 +83,29 @@ func TestBytesMarshal(t *testing.T) { }) } +func TestStringMarshal(t *testing.T) { + t.Run("not empty", func(t *testing.T) { + data := "Hello World" + testStringMarshal(t, data, false) + testStringMarshal(t, data, true) + }) + + t.Run("empty", func(t *testing.T) { + testStringMarshal(t, "", false) + }) +} + +func TestBoolMarshal(t *testing.T) { + t.Run("true", func(t *testing.T) { + testBoolMarshal(t, true, false) + testBoolMarshal(t, true, true) + }) + + t.Run("false", func(t *testing.T) { + testBoolMarshal(t, false, false) + }) +} + func testBytesMarshal(t *testing.T, data []byte, wrongField bool) { var ( wire []byte @@ -107,18 +142,6 @@ func testBytesMarshal(t *testing.T, data []byte, wrongField bool) { } } -func TestStringMarshal(t *testing.T) { - t.Run("not empty", func(t *testing.T) { - data := "Hello World" - testStringMarshal(t, data, false) - testStringMarshal(t, data, true) - }) - - t.Run("empty", func(t *testing.T) { - testStringMarshal(t, "", false) - }) -} - func testStringMarshal(t *testing.T, s string, wrongField bool) { var ( wire []byte @@ -154,3 +177,36 @@ func testStringMarshal(t *testing.T, s string, wrongField bool) { require.Len(t, result.FieldB, 0) } } + +func testBoolMarshal(t *testing.T, b bool, wrongField bool) { + var ( + wire []byte + err error + + custom = stablePrimitives{FieldC: b} + transport = test.Primitives{FieldC: b} + ) + + wire, err = custom.stableMarshal(nil, wrongField) + require.NoError(t, err) + + wireGen, err := transport.Marshal() + require.NoError(t, err) + + if !wrongField { + // we can check equality because single field cannot be unstable marshalled + require.Equal(t, wireGen, wire) + } else { + require.NotEqual(t, wireGen, wire) + } + + result := new(test.Primitives) + err = result.Unmarshal(wire) + require.NoError(t, err) + + if !wrongField { + require.Equal(t, b, result.FieldC) + } else { + require.False(t, false, result.FieldC) + } +} diff --git a/util/proto/test/test.pb.go b/util/proto/test/test.pb.go index 9a1afd3..910812c 100644 --- a/util/proto/test/test.pb.go +++ b/util/proto/test/test.pb.go @@ -25,6 +25,7 @@ const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package type Primitives struct { FieldA []byte `protobuf:"bytes,1,opt,name=field_a,json=fieldA,proto3" json:"field_a,omitempty"` FieldB string `protobuf:"bytes,2,opt,name=field_b,json=fieldB,proto3" json:"field_b,omitempty"` + FieldC bool `protobuf:"varint,200,opt,name=field_c,json=fieldC,proto3" json:"field_c,omitempty"` XXX_NoUnkeyedLiteral struct{} `json:"-"` XXX_unrecognized []byte `json:"-"` XXX_sizecache int32 `json:"-"` @@ -77,6 +78,13 @@ func (m *Primitives) GetFieldB() string { return "" } +func (m *Primitives) GetFieldC() bool { + if m != nil { + return m.FieldC + } + return false +} + func init() { proto.RegisterType((*Primitives)(nil), "test.Primitives") } @@ -84,15 +92,16 @@ func init() { func init() { proto.RegisterFile("util/proto/test/test.proto", fileDescriptor_998ad0e1a3de8558) } var fileDescriptor_998ad0e1a3de8558 = []byte{ - // 121 bytes of a gzipped FileDescriptorProto + // 134 bytes of a gzipped FileDescriptorProto 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0x92, 0x2a, 0x2d, 0xc9, 0xcc, 0xd1, 0x2f, 0x28, 0xca, 0x2f, 0xc9, 0xd7, 0x2f, 0x49, 0x2d, 0x2e, 0x01, 0x13, 0x7a, 0x60, 0xbe, - 0x10, 0x0b, 0x88, 0xad, 0x64, 0xc7, 0xc5, 0x15, 0x50, 0x94, 0x99, 0x9b, 0x59, 0x92, 0x59, 0x96, + 0x10, 0x0b, 0x88, 0xad, 0x14, 0xc1, 0xc5, 0x15, 0x50, 0x94, 0x99, 0x9b, 0x59, 0x92, 0x59, 0x96, 0x5a, 0x2c, 0x24, 0xce, 0xc5, 0x9e, 0x96, 0x99, 0x9a, 0x93, 0x12, 0x9f, 0x28, 0xc1, 0xa8, 0xc0, 0xa8, 0xc1, 0x13, 0xc4, 0x06, 0xe6, 0x3a, 0x22, 0x24, 0x92, 0x24, 0x98, 0x14, 0x18, 0x35, 0x38, - 0xa1, 0x12, 0x4e, 0x4e, 0x02, 0x27, 0x1e, 0xc9, 0x31, 0x5e, 0x78, 0x24, 0xc7, 0xf8, 0xe0, 0x91, - 0x1c, 0xe3, 0x8c, 0xc7, 0x72, 0x0c, 0x49, 0x6c, 0x60, 0xe3, 0x8d, 0x01, 0x01, 0x00, 0x00, 0xff, - 0xff, 0x1a, 0xa1, 0x65, 0x4f, 0x7c, 0x00, 0x00, 0x00, + 0xa1, 0x12, 0x4e, 0x42, 0x12, 0x30, 0x89, 0x64, 0x89, 0x13, 0x20, 0x2d, 0x1c, 0x50, 0x19, 0x67, + 0x27, 0x81, 0x13, 0x8f, 0xe4, 0x18, 0x2f, 0x3c, 0x92, 0x63, 0x7c, 0xf0, 0x48, 0x8e, 0x71, 0xc6, + 0x63, 0x39, 0x86, 0x24, 0x36, 0xb0, 0xc5, 0xc6, 0x80, 0x00, 0x00, 0x00, 0xff, 0xff, 0x1a, 0x10, + 0x20, 0x0a, 0x96, 0x00, 0x00, 0x00, } func (m *Primitives) Marshal() (dAtA []byte, err error) { @@ -119,6 +128,18 @@ func (m *Primitives) MarshalToSizedBuffer(dAtA []byte) (int, error) { i -= len(m.XXX_unrecognized) copy(dAtA[i:], m.XXX_unrecognized) } + if m.FieldC { + i-- + if m.FieldC { + dAtA[i] = 1 + } else { + dAtA[i] = 0 + } + i-- + dAtA[i] = 0xc + i-- + dAtA[i] = 0xc0 + } if len(m.FieldB) > 0 { i -= len(m.FieldB) copy(dAtA[i:], m.FieldB) @@ -161,6 +182,9 @@ func (m *Primitives) Size() (n int) { if l > 0 { n += 1 + l + sovTest(uint64(l)) } + if m.FieldC { + n += 3 + } if m.XXX_unrecognized != nil { n += len(m.XXX_unrecognized) } @@ -268,6 +292,26 @@ func (m *Primitives) Unmarshal(dAtA []byte) error { } m.FieldB = string(dAtA[iNdEx:postIndex]) iNdEx = postIndex + case 200: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field FieldC", wireType) + } + var v int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowTest + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + m.FieldC = bool(v != 0) default: iNdEx = preIndex skippy, err := skipTest(dAtA[iNdEx:]) diff --git a/util/proto/test/test.proto b/util/proto/test/test.proto index 81a1058..4be07a4 100644 --- a/util/proto/test/test.proto +++ b/util/proto/test/test.proto @@ -3,6 +3,7 @@ syntax = "proto3"; package test; message Primitives { - bytes field_a = 1; + bytes field_a = 1; string field_b = 2; + bool field_c = 200; } \ No newline at end of file