Add stable marshaller helper for bool

Signed-off-by: Alex Vanin <alexey@nspcc.ru>
This commit is contained in:
Alex Vanin 2020-08-14 15:53:57 +03:00 committed by Stanislav Bogatyrev
parent 311b76b75b
commit 89bd8f3915
4 changed files with 147 additions and 22 deletions

View file

@ -16,10 +16,10 @@ func BytesMarshal(field int, buf, v []byte) (int, error) {
return 0, nil
}
prefix := field<<3 | 0x2
// buf length check can prevent panic at PutUvarint, but it will make
// marshaller a bit slower.
prefix := field<<3 | 0x2
i := binary.PutUvarint(buf, uint64(prefix))
i += binary.PutUvarint(buf[i:], uint64(len(v)))
i += copy(buf[i:], v)
@ -46,6 +46,30 @@ func StringSize(field int, v string) int {
return BytesSize(field, []byte(v))
}
func BoolMarshal(field int, buf []byte, v bool) (int, error) {
if !v {
return 0, nil
}
prefix := field << 3
// buf length check can prevent panic at PutUvarint, but it will make
// marshaller a bit slower.
i := binary.PutUvarint(buf, uint64(prefix))
buf[i] = 0x1
return i + 1, nil
}
func BoolSize(field int, v bool) int {
if !v {
return 0
}
prefix := field << 3
return VarUIntSize(uint64(prefix)) + 1 // bool is always 1 byte long
}
// varUIntSize returns length of varint byte sequence for uint64 value 'x'.
func VarUIntSize(x uint64) int {
return (bits.Len64(x|1) + 6) / 7

View file

@ -12,6 +12,7 @@ import (
type stablePrimitives struct {
FieldA []byte
FieldB string
FieldC bool
}
func (s *stablePrimitives) stableMarshal(buf []byte, wrongField bool) ([]byte, error) {
@ -43,7 +44,17 @@ func (s *stablePrimitives) stableMarshal(buf []byte, wrongField bool) ([]byte, e
}
offset, err = proto.StringMarshal(fieldNum, buf, s.FieldB)
if err != nil {
return nil, errors.Wrap(err, "can't marshal field a")
return nil, errors.Wrap(err, "can't marshal field b")
}
i += offset
fieldNum = 200
if wrongField {
fieldNum++
}
offset, err = proto.BoolMarshal(fieldNum, buf, s.FieldC)
if err != nil {
return nil, errors.Wrap(err, "can't marshal field c")
}
i += offset
@ -52,7 +63,8 @@ func (s *stablePrimitives) stableMarshal(buf []byte, wrongField bool) ([]byte, e
func (s *stablePrimitives) stableSize() int {
return proto.BytesSize(1, s.FieldA) +
proto.StringSize(2, s.FieldB)
proto.StringSize(2, s.FieldB) +
proto.BoolSize(200, s.FieldC)
}
func TestBytesMarshal(t *testing.T) {
@ -71,6 +83,29 @@ func TestBytesMarshal(t *testing.T) {
})
}
func TestStringMarshal(t *testing.T) {
t.Run("not empty", func(t *testing.T) {
data := "Hello World"
testStringMarshal(t, data, false)
testStringMarshal(t, data, true)
})
t.Run("empty", func(t *testing.T) {
testStringMarshal(t, "", false)
})
}
func TestBoolMarshal(t *testing.T) {
t.Run("true", func(t *testing.T) {
testBoolMarshal(t, true, false)
testBoolMarshal(t, true, true)
})
t.Run("false", func(t *testing.T) {
testBoolMarshal(t, false, false)
})
}
func testBytesMarshal(t *testing.T, data []byte, wrongField bool) {
var (
wire []byte
@ -107,18 +142,6 @@ func testBytesMarshal(t *testing.T, data []byte, wrongField bool) {
}
}
func TestStringMarshal(t *testing.T) {
t.Run("not empty", func(t *testing.T) {
data := "Hello World"
testStringMarshal(t, data, false)
testStringMarshal(t, data, true)
})
t.Run("empty", func(t *testing.T) {
testStringMarshal(t, "", false)
})
}
func testStringMarshal(t *testing.T, s string, wrongField bool) {
var (
wire []byte
@ -154,3 +177,36 @@ func testStringMarshal(t *testing.T, s string, wrongField bool) {
require.Len(t, result.FieldB, 0)
}
}
func testBoolMarshal(t *testing.T, b bool, wrongField bool) {
var (
wire []byte
err error
custom = stablePrimitives{FieldC: b}
transport = test.Primitives{FieldC: b}
)
wire, err = custom.stableMarshal(nil, wrongField)
require.NoError(t, err)
wireGen, err := transport.Marshal()
require.NoError(t, err)
if !wrongField {
// we can check equality because single field cannot be unstable marshalled
require.Equal(t, wireGen, wire)
} else {
require.NotEqual(t, wireGen, wire)
}
result := new(test.Primitives)
err = result.Unmarshal(wire)
require.NoError(t, err)
if !wrongField {
require.Equal(t, b, result.FieldC)
} else {
require.False(t, false, result.FieldC)
}
}

View file

@ -25,6 +25,7 @@ const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package
type Primitives struct {
FieldA []byte `protobuf:"bytes,1,opt,name=field_a,json=fieldA,proto3" json:"field_a,omitempty"`
FieldB string `protobuf:"bytes,2,opt,name=field_b,json=fieldB,proto3" json:"field_b,omitempty"`
FieldC bool `protobuf:"varint,200,opt,name=field_c,json=fieldC,proto3" json:"field_c,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
@ -77,6 +78,13 @@ func (m *Primitives) GetFieldB() string {
return ""
}
func (m *Primitives) GetFieldC() bool {
if m != nil {
return m.FieldC
}
return false
}
func init() {
proto.RegisterType((*Primitives)(nil), "test.Primitives")
}
@ -84,15 +92,16 @@ func init() {
func init() { proto.RegisterFile("util/proto/test/test.proto", fileDescriptor_998ad0e1a3de8558) }
var fileDescriptor_998ad0e1a3de8558 = []byte{
// 121 bytes of a gzipped FileDescriptorProto
// 134 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0x92, 0x2a, 0x2d, 0xc9, 0xcc,
0xd1, 0x2f, 0x28, 0xca, 0x2f, 0xc9, 0xd7, 0x2f, 0x49, 0x2d, 0x2e, 0x01, 0x13, 0x7a, 0x60, 0xbe,
0x10, 0x0b, 0x88, 0xad, 0x64, 0xc7, 0xc5, 0x15, 0x50, 0x94, 0x99, 0x9b, 0x59, 0x92, 0x59, 0x96,
0x10, 0x0b, 0x88, 0xad, 0x14, 0xc1, 0xc5, 0x15, 0x50, 0x94, 0x99, 0x9b, 0x59, 0x92, 0x59, 0x96,
0x5a, 0x2c, 0x24, 0xce, 0xc5, 0x9e, 0x96, 0x99, 0x9a, 0x93, 0x12, 0x9f, 0x28, 0xc1, 0xa8, 0xc0,
0xa8, 0xc1, 0x13, 0xc4, 0x06, 0xe6, 0x3a, 0x22, 0x24, 0x92, 0x24, 0x98, 0x14, 0x18, 0x35, 0x38,
0xa1, 0x12, 0x4e, 0x4e, 0x02, 0x27, 0x1e, 0xc9, 0x31, 0x5e, 0x78, 0x24, 0xc7, 0xf8, 0xe0, 0x91,
0x1c, 0xe3, 0x8c, 0xc7, 0x72, 0x0c, 0x49, 0x6c, 0x60, 0xe3, 0x8d, 0x01, 0x01, 0x00, 0x00, 0xff,
0xff, 0x1a, 0xa1, 0x65, 0x4f, 0x7c, 0x00, 0x00, 0x00,
0xa1, 0x12, 0x4e, 0x42, 0x12, 0x30, 0x89, 0x64, 0x89, 0x13, 0x20, 0x2d, 0x1c, 0x50, 0x19, 0x67,
0x27, 0x81, 0x13, 0x8f, 0xe4, 0x18, 0x2f, 0x3c, 0x92, 0x63, 0x7c, 0xf0, 0x48, 0x8e, 0x71, 0xc6,
0x63, 0x39, 0x86, 0x24, 0x36, 0xb0, 0xc5, 0xc6, 0x80, 0x00, 0x00, 0x00, 0xff, 0xff, 0x1a, 0x10,
0x20, 0x0a, 0x96, 0x00, 0x00, 0x00,
}
func (m *Primitives) Marshal() (dAtA []byte, err error) {
@ -119,6 +128,18 @@ func (m *Primitives) MarshalToSizedBuffer(dAtA []byte) (int, error) {
i -= len(m.XXX_unrecognized)
copy(dAtA[i:], m.XXX_unrecognized)
}
if m.FieldC {
i--
if m.FieldC {
dAtA[i] = 1
} else {
dAtA[i] = 0
}
i--
dAtA[i] = 0xc
i--
dAtA[i] = 0xc0
}
if len(m.FieldB) > 0 {
i -= len(m.FieldB)
copy(dAtA[i:], m.FieldB)
@ -161,6 +182,9 @@ func (m *Primitives) Size() (n int) {
if l > 0 {
n += 1 + l + sovTest(uint64(l))
}
if m.FieldC {
n += 3
}
if m.XXX_unrecognized != nil {
n += len(m.XXX_unrecognized)
}
@ -268,6 +292,26 @@ func (m *Primitives) Unmarshal(dAtA []byte) error {
}
m.FieldB = string(dAtA[iNdEx:postIndex])
iNdEx = postIndex
case 200:
if wireType != 0 {
return fmt.Errorf("proto: wrong wireType = %d for field FieldC", wireType)
}
var v int
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowTest
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
v |= int(b&0x7F) << shift
if b < 0x80 {
break
}
}
m.FieldC = bool(v != 0)
default:
iNdEx = preIndex
skippy, err := skipTest(dAtA[iNdEx:])

View file

@ -3,6 +3,7 @@ syntax = "proto3";
package test;
message Primitives {
bytes field_a = 1;
bytes field_a = 1;
string field_b = 2;
bool field_c = 200;
}