Add proto marshal helper for bytes

Signed-off-by: Alex Vanin <alexey@nspcc.ru>
This commit is contained in:
Alex Vanin 2020-08-14 13:57:19 +03:00 committed by Stanislav Bogatyrev
parent 51e1c3bbcb
commit 41f9c50424
4 changed files with 496 additions and 0 deletions

44
util/proto/marshal.go Normal file
View file

@ -0,0 +1,44 @@
/*
This package contains help functions for stable marshaller. Their usage is
totally optional. One can implement fast stable marshaller without these
runtime function calls.
*/
package proto
import (
"encoding/binary"
"math/bits"
)
func BytesMarshal(field int, buf, v []byte) (int, error) {
if len(v) == 0 {
return 0, nil
}
// buf length check can prevent panic at PutUvarint, but it will make
// marshaller a bit slower.
prefix := field<<3 | 0x2
i := binary.PutUvarint(buf, uint64(prefix))
i += binary.PutUvarint(buf[i:], uint64(len(v)))
i += copy(buf[i:], v)
return i, nil
}
func BytesSize(field int, v []byte) int {
ln := len(v)
if ln == 0 {
return 0
}
prefix := field<<3 | 0x2
return VarUIntSize(uint64(prefix)) + VarUIntSize(uint64(ln)) + ln
}
// varUIntSize returns length of varint byte sequence for uint64 value 'x'.
func VarUIntSize(x uint64) int {
return (bits.Len64(x|1) + 6) / 7
}

118
util/proto/marshal_test.go Normal file
View file

@ -0,0 +1,118 @@
package proto_test
import (
"testing"
"github.com/nspcc-dev/neofs-api-go/util/proto"
"github.com/nspcc-dev/neofs-api-go/util/proto/test"
"github.com/pkg/errors"
"github.com/stretchr/testify/require"
)
type stablePrimitives struct {
FieldA []byte
}
func (s *stablePrimitives) stableMarshal(buf []byte) ([]byte, error) {
if s == nil {
return []byte{}, nil
}
if buf == nil {
buf = make([]byte, s.stableSize())
}
var (
i, offset int
)
offset, err := proto.BytesMarshal(1, buf, s.FieldA)
if err != nil {
return nil, errors.Wrap(err, "can't marshal field a")
}
i += offset
return buf, nil
}
func (s *stablePrimitives) stableMarshalWrongFieldNum(buf []byte) ([]byte, error) {
if s == nil {
return []byte{}, nil
}
if buf == nil {
buf = make([]byte, s.stableSize())
}
var (
i, offset int
)
offset, err := proto.BytesMarshal(1+1, buf, s.FieldA)
if err != nil {
return nil, errors.Wrap(err, "can't marshal field a")
}
i += offset
return buf, nil
}
func (s *stablePrimitives) stableSize() int {
return proto.BytesSize(1, s.FieldA)
}
func TestBytesMarshal(t *testing.T) {
t.Run("not empty", func(t *testing.T) {
data := []byte("Hello World")
testBytesMarshal(t, data, false)
testBytesMarshal(t, data, true)
})
t.Run("empty", func(t *testing.T) {
testBytesMarshal(t, []byte{}, false)
})
t.Run("nil", func(t *testing.T) {
testBytesMarshal(t, nil, false)
})
}
func testBytesMarshal(t *testing.T, data []byte, wrongField bool) {
var (
wire []byte
err error
custom = stablePrimitives{FieldA: data}
transport = test.Primitives{FieldA: data}
)
if !wrongField {
wire, err = custom.stableMarshal(nil)
} else {
wire, err = custom.stableMarshalWrongFieldNum(nil)
}
require.NoError(t, err)
wireGen, err := transport.Marshal()
require.NoError(t, err)
if !wrongField {
// we can check equality because single field cannot be unstable marshalled
require.Equal(t, wireGen, wire)
} else {
require.NotEqual(t, wireGen, wire)
}
result := new(test.Primitives)
err = result.Unmarshal(wire)
require.NoError(t, err)
if !wrongField {
require.Len(t, result.FieldA, len(data))
if len(data) > 0 {
require.Equal(t, data, result.FieldA)
}
} else {
require.Len(t, result.FieldA, 0)
}
}

327
util/proto/test/test.pb.go Normal file
View file

@ -0,0 +1,327 @@
// Code generated by protoc-gen-gogo. DO NOT EDIT.
// source: util/proto/test/test.proto
package test
import (
fmt "fmt"
proto "github.com/golang/protobuf/proto"
io "io"
math "math"
math_bits "math/bits"
)
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package
type Primitives struct {
FieldA []byte `protobuf:"bytes,1,opt,name=field_a,json=fieldA,proto3" json:"field_a,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *Primitives) Reset() { *m = Primitives{} }
func (m *Primitives) String() string { return proto.CompactTextString(m) }
func (*Primitives) ProtoMessage() {}
func (*Primitives) Descriptor() ([]byte, []int) {
return fileDescriptor_998ad0e1a3de8558, []int{0}
}
func (m *Primitives) XXX_Unmarshal(b []byte) error {
return m.Unmarshal(b)
}
func (m *Primitives) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
if deterministic {
return xxx_messageInfo_Primitives.Marshal(b, m, deterministic)
} else {
b = b[:cap(b)]
n, err := m.MarshalToSizedBuffer(b)
if err != nil {
return nil, err
}
return b[:n], nil
}
}
func (m *Primitives) XXX_Merge(src proto.Message) {
xxx_messageInfo_Primitives.Merge(m, src)
}
func (m *Primitives) XXX_Size() int {
return m.Size()
}
func (m *Primitives) XXX_DiscardUnknown() {
xxx_messageInfo_Primitives.DiscardUnknown(m)
}
var xxx_messageInfo_Primitives proto.InternalMessageInfo
func (m *Primitives) GetFieldA() []byte {
if m != nil {
return m.FieldA
}
return nil
}
func init() {
proto.RegisterType((*Primitives)(nil), "test.Primitives")
}
func init() { proto.RegisterFile("util/proto/test/test.proto", fileDescriptor_998ad0e1a3de8558) }
var fileDescriptor_998ad0e1a3de8558 = []byte{
// 109 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0x92, 0x2a, 0x2d, 0xc9, 0xcc,
0xd1, 0x2f, 0x28, 0xca, 0x2f, 0xc9, 0xd7, 0x2f, 0x49, 0x2d, 0x2e, 0x01, 0x13, 0x7a, 0x60, 0xbe,
0x10, 0x0b, 0x88, 0xad, 0xa4, 0xca, 0xc5, 0x15, 0x50, 0x94, 0x99, 0x9b, 0x59, 0x92, 0x59, 0x96,
0x5a, 0x2c, 0x24, 0xce, 0xc5, 0x9e, 0x96, 0x99, 0x9a, 0x93, 0x12, 0x9f, 0x28, 0xc1, 0xa8, 0xc0,
0xa8, 0xc1, 0x13, 0xc4, 0x06, 0xe6, 0x3a, 0x3a, 0x09, 0x9c, 0x78, 0x24, 0xc7, 0x78, 0xe1, 0x91,
0x1c, 0xe3, 0x83, 0x47, 0x72, 0x8c, 0x33, 0x1e, 0xcb, 0x31, 0x24, 0xb1, 0x81, 0x4d, 0x31, 0x06,
0x04, 0x00, 0x00, 0xff, 0xff, 0xe5, 0xa7, 0x6b, 0x1f, 0x63, 0x00, 0x00, 0x00,
}
func (m *Primitives) Marshal() (dAtA []byte, err error) {
size := m.Size()
dAtA = make([]byte, size)
n, err := m.MarshalToSizedBuffer(dAtA[:size])
if err != nil {
return nil, err
}
return dAtA[:n], nil
}
func (m *Primitives) MarshalTo(dAtA []byte) (int, error) {
size := m.Size()
return m.MarshalToSizedBuffer(dAtA[:size])
}
func (m *Primitives) MarshalToSizedBuffer(dAtA []byte) (int, error) {
i := len(dAtA)
_ = i
var l int
_ = l
if m.XXX_unrecognized != nil {
i -= len(m.XXX_unrecognized)
copy(dAtA[i:], m.XXX_unrecognized)
}
if len(m.FieldA) > 0 {
i -= len(m.FieldA)
copy(dAtA[i:], m.FieldA)
i = encodeVarintTest(dAtA, i, uint64(len(m.FieldA)))
i--
dAtA[i] = 0xa
}
return len(dAtA) - i, nil
}
func encodeVarintTest(dAtA []byte, offset int, v uint64) int {
offset -= sovTest(v)
base := offset
for v >= 1<<7 {
dAtA[offset] = uint8(v&0x7f | 0x80)
v >>= 7
offset++
}
dAtA[offset] = uint8(v)
return base
}
func (m *Primitives) Size() (n int) {
if m == nil {
return 0
}
var l int
_ = l
l = len(m.FieldA)
if l > 0 {
n += 1 + l + sovTest(uint64(l))
}
if m.XXX_unrecognized != nil {
n += len(m.XXX_unrecognized)
}
return n
}
func sovTest(x uint64) (n int) {
return (math_bits.Len64(x|1) + 6) / 7
}
func sozTest(x uint64) (n int) {
return sovTest(uint64((x << 1) ^ uint64((int64(x) >> 63))))
}
func (m *Primitives) Unmarshal(dAtA []byte) error {
l := len(dAtA)
iNdEx := 0
for iNdEx < l {
preIndex := iNdEx
var wire uint64
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowTest
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
wire |= uint64(b&0x7F) << shift
if b < 0x80 {
break
}
}
fieldNum := int32(wire >> 3)
wireType := int(wire & 0x7)
if wireType == 4 {
return fmt.Errorf("proto: Primitives: wiretype end group for non-group")
}
if fieldNum <= 0 {
return fmt.Errorf("proto: Primitives: illegal tag %d (wire type %d)", fieldNum, wire)
}
switch fieldNum {
case 1:
if wireType != 2 {
return fmt.Errorf("proto: wrong wireType = %d for field FieldA", wireType)
}
var byteLen int
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowTest
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
byteLen |= int(b&0x7F) << shift
if b < 0x80 {
break
}
}
if byteLen < 0 {
return ErrInvalidLengthTest
}
postIndex := iNdEx + byteLen
if postIndex < 0 {
return ErrInvalidLengthTest
}
if postIndex > l {
return io.ErrUnexpectedEOF
}
m.FieldA = append(m.FieldA[:0], dAtA[iNdEx:postIndex]...)
if m.FieldA == nil {
m.FieldA = []byte{}
}
iNdEx = postIndex
default:
iNdEx = preIndex
skippy, err := skipTest(dAtA[iNdEx:])
if err != nil {
return err
}
if skippy < 0 {
return ErrInvalidLengthTest
}
if (iNdEx + skippy) < 0 {
return ErrInvalidLengthTest
}
if (iNdEx + skippy) > l {
return io.ErrUnexpectedEOF
}
m.XXX_unrecognized = append(m.XXX_unrecognized, dAtA[iNdEx:iNdEx+skippy]...)
iNdEx += skippy
}
}
if iNdEx > l {
return io.ErrUnexpectedEOF
}
return nil
}
func skipTest(dAtA []byte) (n int, err error) {
l := len(dAtA)
iNdEx := 0
depth := 0
for iNdEx < l {
var wire uint64
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return 0, ErrIntOverflowTest
}
if iNdEx >= l {
return 0, io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
wire |= (uint64(b) & 0x7F) << shift
if b < 0x80 {
break
}
}
wireType := int(wire & 0x7)
switch wireType {
case 0:
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return 0, ErrIntOverflowTest
}
if iNdEx >= l {
return 0, io.ErrUnexpectedEOF
}
iNdEx++
if dAtA[iNdEx-1] < 0x80 {
break
}
}
case 1:
iNdEx += 8
case 2:
var length int
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return 0, ErrIntOverflowTest
}
if iNdEx >= l {
return 0, io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
length |= (int(b) & 0x7F) << shift
if b < 0x80 {
break
}
}
if length < 0 {
return 0, ErrInvalidLengthTest
}
iNdEx += length
case 3:
depth++
case 4:
if depth == 0 {
return 0, ErrUnexpectedEndOfGroupTest
}
depth--
case 5:
iNdEx += 4
default:
return 0, fmt.Errorf("proto: illegal wireType %d", wireType)
}
if iNdEx < 0 {
return 0, ErrInvalidLengthTest
}
if depth == 0 {
return iNdEx, nil
}
}
return 0, io.ErrUnexpectedEOF
}
var (
ErrInvalidLengthTest = fmt.Errorf("proto: negative length found during unmarshaling")
ErrIntOverflowTest = fmt.Errorf("proto: integer overflow")
ErrUnexpectedEndOfGroupTest = fmt.Errorf("proto: unexpected end of group")
)

View file

@ -0,0 +1,7 @@
syntax = "proto3";
package test;
message Primitives {
bytes field_a = 1;
}