Add proto marshal helper for string

Signed-off-by: Alex Vanin <alexey@nspcc.ru>
This commit is contained in:
Alex Vanin 2020-08-14 15:29:48 +03:00 committed by Stanislav Bogatyrev
parent 41f9c50424
commit 311b76b75b
4 changed files with 129 additions and 30 deletions

View file

@ -38,6 +38,14 @@ func BytesSize(field int, v []byte) int {
return VarUIntSize(uint64(prefix)) + VarUIntSize(uint64(ln)) + ln
}
func StringMarshal(field int, buf []byte, v string) (int, error) {
return BytesMarshal(field, buf, []byte(v))
}
func StringSize(field int, v string) int {
return BytesSize(field, []byte(v))
}
// varUIntSize returns length of varint byte sequence for uint64 value 'x'.
func VarUIntSize(x uint64) int {
return (bits.Len64(x|1) + 6) / 7

View file

@ -11,9 +11,10 @@ import (
type stablePrimitives struct {
FieldA []byte
FieldB string
}
func (s *stablePrimitives) stableMarshal(buf []byte) ([]byte, error) {
func (s *stablePrimitives) stableMarshal(buf []byte, wrongField bool) ([]byte, error) {
if s == nil {
return []byte{}, nil
}
@ -23,32 +24,24 @@ func (s *stablePrimitives) stableMarshal(buf []byte) ([]byte, error) {
}
var (
i, offset int
i, offset, fieldNum int
)
offset, err := proto.BytesMarshal(1, buf, s.FieldA)
fieldNum = 1
if wrongField {
fieldNum++
}
offset, err := proto.BytesMarshal(fieldNum, buf, s.FieldA)
if err != nil {
return nil, errors.Wrap(err, "can't marshal field a")
}
i += offset
return buf, nil
}
func (s *stablePrimitives) stableMarshalWrongFieldNum(buf []byte) ([]byte, error) {
if s == nil {
return []byte{}, nil
fieldNum = 2
if wrongField {
fieldNum++
}
if buf == nil {
buf = make([]byte, s.stableSize())
}
var (
i, offset int
)
offset, err := proto.BytesMarshal(1+1, buf, s.FieldA)
offset, err = proto.StringMarshal(fieldNum, buf, s.FieldB)
if err != nil {
return nil, errors.Wrap(err, "can't marshal field a")
}
@ -58,7 +51,8 @@ func (s *stablePrimitives) stableMarshalWrongFieldNum(buf []byte) ([]byte, error
}
func (s *stablePrimitives) stableSize() int {
return proto.BytesSize(1, s.FieldA)
return proto.BytesSize(1, s.FieldA) +
proto.StringSize(2, s.FieldB)
}
func TestBytesMarshal(t *testing.T) {
@ -86,11 +80,7 @@ func testBytesMarshal(t *testing.T, data []byte, wrongField bool) {
transport = test.Primitives{FieldA: data}
)
if !wrongField {
wire, err = custom.stableMarshal(nil)
} else {
wire, err = custom.stableMarshalWrongFieldNum(nil)
}
wire, err = custom.stableMarshal(nil, wrongField)
require.NoError(t, err)
wireGen, err := transport.Marshal()
@ -116,3 +106,51 @@ func testBytesMarshal(t *testing.T, data []byte, wrongField bool) {
require.Len(t, result.FieldA, 0)
}
}
func TestStringMarshal(t *testing.T) {
t.Run("not empty", func(t *testing.T) {
data := "Hello World"
testStringMarshal(t, data, false)
testStringMarshal(t, data, true)
})
t.Run("empty", func(t *testing.T) {
testStringMarshal(t, "", false)
})
}
func testStringMarshal(t *testing.T, s string, wrongField bool) {
var (
wire []byte
err error
custom = stablePrimitives{FieldB: s}
transport = test.Primitives{FieldB: s}
)
wire, err = custom.stableMarshal(nil, wrongField)
require.NoError(t, err)
wireGen, err := transport.Marshal()
require.NoError(t, err)
if !wrongField {
// we can check equality because single field cannot be unstable marshalled
require.Equal(t, wireGen, wire)
} else {
require.NotEqual(t, wireGen, wire)
}
result := new(test.Primitives)
err = result.Unmarshal(wire)
require.NoError(t, err)
if !wrongField {
require.Len(t, result.FieldB, len(s))
if len(s) > 0 {
require.Equal(t, s, result.FieldB)
}
} else {
require.Len(t, result.FieldB, 0)
}
}

View file

@ -24,6 +24,7 @@ const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package
type Primitives struct {
FieldA []byte `protobuf:"bytes,1,opt,name=field_a,json=fieldA,proto3" json:"field_a,omitempty"`
FieldB string `protobuf:"bytes,2,opt,name=field_b,json=fieldB,proto3" json:"field_b,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
@ -69,6 +70,13 @@ func (m *Primitives) GetFieldA() []byte {
return nil
}
func (m *Primitives) GetFieldB() string {
if m != nil {
return m.FieldB
}
return ""
}
func init() {
proto.RegisterType((*Primitives)(nil), "test.Primitives")
}
@ -76,14 +84,15 @@ func init() {
func init() { proto.RegisterFile("util/proto/test/test.proto", fileDescriptor_998ad0e1a3de8558) }
var fileDescriptor_998ad0e1a3de8558 = []byte{
// 109 bytes of a gzipped FileDescriptorProto
// 121 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0x92, 0x2a, 0x2d, 0xc9, 0xcc,
0xd1, 0x2f, 0x28, 0xca, 0x2f, 0xc9, 0xd7, 0x2f, 0x49, 0x2d, 0x2e, 0x01, 0x13, 0x7a, 0x60, 0xbe,
0x10, 0x0b, 0x88, 0xad, 0xa4, 0xca, 0xc5, 0x15, 0x50, 0x94, 0x99, 0x9b, 0x59, 0x92, 0x59, 0x96,
0x10, 0x0b, 0x88, 0xad, 0x64, 0xc7, 0xc5, 0x15, 0x50, 0x94, 0x99, 0x9b, 0x59, 0x92, 0x59, 0x96,
0x5a, 0x2c, 0x24, 0xce, 0xc5, 0x9e, 0x96, 0x99, 0x9a, 0x93, 0x12, 0x9f, 0x28, 0xc1, 0xa8, 0xc0,
0xa8, 0xc1, 0x13, 0xc4, 0x06, 0xe6, 0x3a, 0x3a, 0x09, 0x9c, 0x78, 0x24, 0xc7, 0x78, 0xe1, 0x91,
0x1c, 0xe3, 0x83, 0x47, 0x72, 0x8c, 0x33, 0x1e, 0xcb, 0x31, 0x24, 0xb1, 0x81, 0x4d, 0x31, 0x06,
0x04, 0x00, 0x00, 0xff, 0xff, 0xe5, 0xa7, 0x6b, 0x1f, 0x63, 0x00, 0x00, 0x00,
0xa8, 0xc1, 0x13, 0xc4, 0x06, 0xe6, 0x3a, 0x22, 0x24, 0x92, 0x24, 0x98, 0x14, 0x18, 0x35, 0x38,
0xa1, 0x12, 0x4e, 0x4e, 0x02, 0x27, 0x1e, 0xc9, 0x31, 0x5e, 0x78, 0x24, 0xc7, 0xf8, 0xe0, 0x91,
0x1c, 0xe3, 0x8c, 0xc7, 0x72, 0x0c, 0x49, 0x6c, 0x60, 0xe3, 0x8d, 0x01, 0x01, 0x00, 0x00, 0xff,
0xff, 0x1a, 0xa1, 0x65, 0x4f, 0x7c, 0x00, 0x00, 0x00,
}
func (m *Primitives) Marshal() (dAtA []byte, err error) {
@ -110,6 +119,13 @@ func (m *Primitives) MarshalToSizedBuffer(dAtA []byte) (int, error) {
i -= len(m.XXX_unrecognized)
copy(dAtA[i:], m.XXX_unrecognized)
}
if len(m.FieldB) > 0 {
i -= len(m.FieldB)
copy(dAtA[i:], m.FieldB)
i = encodeVarintTest(dAtA, i, uint64(len(m.FieldB)))
i--
dAtA[i] = 0x12
}
if len(m.FieldA) > 0 {
i -= len(m.FieldA)
copy(dAtA[i:], m.FieldA)
@ -141,6 +157,10 @@ func (m *Primitives) Size() (n int) {
if l > 0 {
n += 1 + l + sovTest(uint64(l))
}
l = len(m.FieldB)
if l > 0 {
n += 1 + l + sovTest(uint64(l))
}
if m.XXX_unrecognized != nil {
n += len(m.XXX_unrecognized)
}
@ -216,6 +236,38 @@ func (m *Primitives) Unmarshal(dAtA []byte) error {
m.FieldA = []byte{}
}
iNdEx = postIndex
case 2:
if wireType != 2 {
return fmt.Errorf("proto: wrong wireType = %d for field FieldB", wireType)
}
var stringLen uint64
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowTest
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
stringLen |= uint64(b&0x7F) << shift
if b < 0x80 {
break
}
}
intStringLen := int(stringLen)
if intStringLen < 0 {
return ErrInvalidLengthTest
}
postIndex := iNdEx + intStringLen
if postIndex < 0 {
return ErrInvalidLengthTest
}
if postIndex > l {
return io.ErrUnexpectedEOF
}
m.FieldB = string(dAtA[iNdEx:postIndex])
iNdEx = postIndex
default:
iNdEx = preIndex
skippy, err := skipTest(dAtA[iNdEx:])

View file

@ -4,4 +4,5 @@ package test;
message Primitives {
bytes field_a = 1;
string field_b = 2;
}