From 311b76b75b1aeec29708f94c432ad81e30c9cf9c Mon Sep 17 00:00:00 2001 From: Alex Vanin Date: Fri, 14 Aug 2020 15:29:48 +0300 Subject: [PATCH] Add proto marshal helper for string Signed-off-by: Alex Vanin --- util/proto/marshal.go | 8 ++++ util/proto/marshal_test.go | 88 +++++++++++++++++++++++++++----------- util/proto/test/test.pb.go | 62 ++++++++++++++++++++++++--- util/proto/test/test.proto | 1 + 4 files changed, 129 insertions(+), 30 deletions(-) diff --git a/util/proto/marshal.go b/util/proto/marshal.go index cd9c493..2d5f9e6 100644 --- a/util/proto/marshal.go +++ b/util/proto/marshal.go @@ -38,6 +38,14 @@ func BytesSize(field int, v []byte) int { return VarUIntSize(uint64(prefix)) + VarUIntSize(uint64(ln)) + ln } +func StringMarshal(field int, buf []byte, v string) (int, error) { + return BytesMarshal(field, buf, []byte(v)) +} + +func StringSize(field int, v string) int { + return BytesSize(field, []byte(v)) +} + // varUIntSize returns length of varint byte sequence for uint64 value 'x'. func VarUIntSize(x uint64) int { return (bits.Len64(x|1) + 6) / 7 diff --git a/util/proto/marshal_test.go b/util/proto/marshal_test.go index ccc1cca..66a8b62 100644 --- a/util/proto/marshal_test.go +++ b/util/proto/marshal_test.go @@ -11,9 +11,10 @@ import ( type stablePrimitives struct { FieldA []byte + FieldB string } -func (s *stablePrimitives) stableMarshal(buf []byte) ([]byte, error) { +func (s *stablePrimitives) stableMarshal(buf []byte, wrongField bool) ([]byte, error) { if s == nil { return []byte{}, nil } @@ -23,32 +24,24 @@ func (s *stablePrimitives) stableMarshal(buf []byte) ([]byte, error) { } var ( - i, offset int + i, offset, fieldNum int ) - offset, err := proto.BytesMarshal(1, buf, s.FieldA) + fieldNum = 1 + if wrongField { + fieldNum++ + } + offset, err := proto.BytesMarshal(fieldNum, buf, s.FieldA) if err != nil { return nil, errors.Wrap(err, "can't marshal field a") } i += offset - return buf, nil -} - -func (s *stablePrimitives) stableMarshalWrongFieldNum(buf []byte) ([]byte, error) { - if s == nil { - return []byte{}, nil + fieldNum = 2 + if wrongField { + fieldNum++ } - - if buf == nil { - buf = make([]byte, s.stableSize()) - } - - var ( - i, offset int - ) - - offset, err := proto.BytesMarshal(1+1, buf, s.FieldA) + offset, err = proto.StringMarshal(fieldNum, buf, s.FieldB) if err != nil { return nil, errors.Wrap(err, "can't marshal field a") } @@ -58,7 +51,8 @@ func (s *stablePrimitives) stableMarshalWrongFieldNum(buf []byte) ([]byte, error } func (s *stablePrimitives) stableSize() int { - return proto.BytesSize(1, s.FieldA) + return proto.BytesSize(1, s.FieldA) + + proto.StringSize(2, s.FieldB) } func TestBytesMarshal(t *testing.T) { @@ -86,11 +80,7 @@ func testBytesMarshal(t *testing.T, data []byte, wrongField bool) { transport = test.Primitives{FieldA: data} ) - if !wrongField { - wire, err = custom.stableMarshal(nil) - } else { - wire, err = custom.stableMarshalWrongFieldNum(nil) - } + wire, err = custom.stableMarshal(nil, wrongField) require.NoError(t, err) wireGen, err := transport.Marshal() @@ -116,3 +106,51 @@ func testBytesMarshal(t *testing.T, data []byte, wrongField bool) { require.Len(t, result.FieldA, 0) } } + +func TestStringMarshal(t *testing.T) { + t.Run("not empty", func(t *testing.T) { + data := "Hello World" + testStringMarshal(t, data, false) + testStringMarshal(t, data, true) + }) + + t.Run("empty", func(t *testing.T) { + testStringMarshal(t, "", false) + }) +} + +func testStringMarshal(t *testing.T, s string, wrongField bool) { + var ( + wire []byte + err error + + custom = stablePrimitives{FieldB: s} + transport = test.Primitives{FieldB: s} + ) + + wire, err = custom.stableMarshal(nil, wrongField) + require.NoError(t, err) + + wireGen, err := transport.Marshal() + require.NoError(t, err) + + if !wrongField { + // we can check equality because single field cannot be unstable marshalled + require.Equal(t, wireGen, wire) + } else { + require.NotEqual(t, wireGen, wire) + } + + result := new(test.Primitives) + err = result.Unmarshal(wire) + require.NoError(t, err) + + if !wrongField { + require.Len(t, result.FieldB, len(s)) + if len(s) > 0 { + require.Equal(t, s, result.FieldB) + } + } else { + require.Len(t, result.FieldB, 0) + } +} diff --git a/util/proto/test/test.pb.go b/util/proto/test/test.pb.go index 5db45e5..9a1afd3 100644 --- a/util/proto/test/test.pb.go +++ b/util/proto/test/test.pb.go @@ -24,6 +24,7 @@ const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package type Primitives struct { FieldA []byte `protobuf:"bytes,1,opt,name=field_a,json=fieldA,proto3" json:"field_a,omitempty"` + FieldB string `protobuf:"bytes,2,opt,name=field_b,json=fieldB,proto3" json:"field_b,omitempty"` XXX_NoUnkeyedLiteral struct{} `json:"-"` XXX_unrecognized []byte `json:"-"` XXX_sizecache int32 `json:"-"` @@ -69,6 +70,13 @@ func (m *Primitives) GetFieldA() []byte { return nil } +func (m *Primitives) GetFieldB() string { + if m != nil { + return m.FieldB + } + return "" +} + func init() { proto.RegisterType((*Primitives)(nil), "test.Primitives") } @@ -76,14 +84,15 @@ func init() { func init() { proto.RegisterFile("util/proto/test/test.proto", fileDescriptor_998ad0e1a3de8558) } var fileDescriptor_998ad0e1a3de8558 = []byte{ - // 109 bytes of a gzipped FileDescriptorProto + // 121 bytes of a gzipped FileDescriptorProto 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0x92, 0x2a, 0x2d, 0xc9, 0xcc, 0xd1, 0x2f, 0x28, 0xca, 0x2f, 0xc9, 0xd7, 0x2f, 0x49, 0x2d, 0x2e, 0x01, 0x13, 0x7a, 0x60, 0xbe, - 0x10, 0x0b, 0x88, 0xad, 0xa4, 0xca, 0xc5, 0x15, 0x50, 0x94, 0x99, 0x9b, 0x59, 0x92, 0x59, 0x96, + 0x10, 0x0b, 0x88, 0xad, 0x64, 0xc7, 0xc5, 0x15, 0x50, 0x94, 0x99, 0x9b, 0x59, 0x92, 0x59, 0x96, 0x5a, 0x2c, 0x24, 0xce, 0xc5, 0x9e, 0x96, 0x99, 0x9a, 0x93, 0x12, 0x9f, 0x28, 0xc1, 0xa8, 0xc0, - 0xa8, 0xc1, 0x13, 0xc4, 0x06, 0xe6, 0x3a, 0x3a, 0x09, 0x9c, 0x78, 0x24, 0xc7, 0x78, 0xe1, 0x91, - 0x1c, 0xe3, 0x83, 0x47, 0x72, 0x8c, 0x33, 0x1e, 0xcb, 0x31, 0x24, 0xb1, 0x81, 0x4d, 0x31, 0x06, - 0x04, 0x00, 0x00, 0xff, 0xff, 0xe5, 0xa7, 0x6b, 0x1f, 0x63, 0x00, 0x00, 0x00, + 0xa8, 0xc1, 0x13, 0xc4, 0x06, 0xe6, 0x3a, 0x22, 0x24, 0x92, 0x24, 0x98, 0x14, 0x18, 0x35, 0x38, + 0xa1, 0x12, 0x4e, 0x4e, 0x02, 0x27, 0x1e, 0xc9, 0x31, 0x5e, 0x78, 0x24, 0xc7, 0xf8, 0xe0, 0x91, + 0x1c, 0xe3, 0x8c, 0xc7, 0x72, 0x0c, 0x49, 0x6c, 0x60, 0xe3, 0x8d, 0x01, 0x01, 0x00, 0x00, 0xff, + 0xff, 0x1a, 0xa1, 0x65, 0x4f, 0x7c, 0x00, 0x00, 0x00, } func (m *Primitives) Marshal() (dAtA []byte, err error) { @@ -110,6 +119,13 @@ func (m *Primitives) MarshalToSizedBuffer(dAtA []byte) (int, error) { i -= len(m.XXX_unrecognized) copy(dAtA[i:], m.XXX_unrecognized) } + if len(m.FieldB) > 0 { + i -= len(m.FieldB) + copy(dAtA[i:], m.FieldB) + i = encodeVarintTest(dAtA, i, uint64(len(m.FieldB))) + i-- + dAtA[i] = 0x12 + } if len(m.FieldA) > 0 { i -= len(m.FieldA) copy(dAtA[i:], m.FieldA) @@ -141,6 +157,10 @@ func (m *Primitives) Size() (n int) { if l > 0 { n += 1 + l + sovTest(uint64(l)) } + l = len(m.FieldB) + if l > 0 { + n += 1 + l + sovTest(uint64(l)) + } if m.XXX_unrecognized != nil { n += len(m.XXX_unrecognized) } @@ -216,6 +236,38 @@ func (m *Primitives) Unmarshal(dAtA []byte) error { m.FieldA = []byte{} } iNdEx = postIndex + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field FieldB", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowTest + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthTest + } + postIndex := iNdEx + intStringLen + if postIndex < 0 { + return ErrInvalidLengthTest + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.FieldB = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex default: iNdEx = preIndex skippy, err := skipTest(dAtA[iNdEx:]) diff --git a/util/proto/test/test.proto b/util/proto/test/test.proto index cf51a44..81a1058 100644 --- a/util/proto/test/test.proto +++ b/util/proto/test/test.proto @@ -4,4 +4,5 @@ package test; message Primitives { bytes field_a = 1; + string field_b = 2; } \ No newline at end of file