From 41f9c504240dfb6f51d200d5e66f2b82af3c82e7 Mon Sep 17 00:00:00 2001 From: Alex Vanin Date: Fri, 14 Aug 2020 13:57:19 +0300 Subject: [PATCH] Add proto marshal helper for bytes Signed-off-by: Alex Vanin --- util/proto/marshal.go | 44 +++++ util/proto/marshal_test.go | 118 +++++++++++++ util/proto/test/test.pb.go | 327 +++++++++++++++++++++++++++++++++++++ util/proto/test/test.proto | 7 + 4 files changed, 496 insertions(+) create mode 100644 util/proto/marshal.go create mode 100644 util/proto/marshal_test.go create mode 100644 util/proto/test/test.pb.go create mode 100644 util/proto/test/test.proto diff --git a/util/proto/marshal.go b/util/proto/marshal.go new file mode 100644 index 0000000..cd9c493 --- /dev/null +++ b/util/proto/marshal.go @@ -0,0 +1,44 @@ +/* +This package contains help functions for stable marshaller. Their usage is +totally optional. One can implement fast stable marshaller without these +runtime function calls. +*/ + +package proto + +import ( + "encoding/binary" + "math/bits" +) + +func BytesMarshal(field int, buf, v []byte) (int, error) { + if len(v) == 0 { + return 0, nil + } + + // buf length check can prevent panic at PutUvarint, but it will make + // marshaller a bit slower. + + prefix := field<<3 | 0x2 + i := binary.PutUvarint(buf, uint64(prefix)) + i += binary.PutUvarint(buf[i:], uint64(len(v))) + i += copy(buf[i:], v) + + return i, nil +} + +func BytesSize(field int, v []byte) int { + ln := len(v) + if ln == 0 { + return 0 + } + + prefix := field<<3 | 0x2 + + return VarUIntSize(uint64(prefix)) + VarUIntSize(uint64(ln)) + ln +} + +// varUIntSize returns length of varint byte sequence for uint64 value 'x'. +func VarUIntSize(x uint64) int { + return (bits.Len64(x|1) + 6) / 7 +} diff --git a/util/proto/marshal_test.go b/util/proto/marshal_test.go new file mode 100644 index 0000000..ccc1cca --- /dev/null +++ b/util/proto/marshal_test.go @@ -0,0 +1,118 @@ +package proto_test + +import ( + "testing" + + "github.com/nspcc-dev/neofs-api-go/util/proto" + "github.com/nspcc-dev/neofs-api-go/util/proto/test" + "github.com/pkg/errors" + "github.com/stretchr/testify/require" +) + +type stablePrimitives struct { + FieldA []byte +} + +func (s *stablePrimitives) stableMarshal(buf []byte) ([]byte, error) { + if s == nil { + return []byte{}, nil + } + + if buf == nil { + buf = make([]byte, s.stableSize()) + } + + var ( + i, offset int + ) + + offset, err := proto.BytesMarshal(1, buf, s.FieldA) + if err != nil { + return nil, errors.Wrap(err, "can't marshal field a") + } + i += offset + + return buf, nil +} + +func (s *stablePrimitives) stableMarshalWrongFieldNum(buf []byte) ([]byte, error) { + if s == nil { + return []byte{}, nil + } + + if buf == nil { + buf = make([]byte, s.stableSize()) + } + + var ( + i, offset int + ) + + offset, err := proto.BytesMarshal(1+1, buf, s.FieldA) + if err != nil { + return nil, errors.Wrap(err, "can't marshal field a") + } + i += offset + + return buf, nil +} + +func (s *stablePrimitives) stableSize() int { + return proto.BytesSize(1, s.FieldA) +} + +func TestBytesMarshal(t *testing.T) { + t.Run("not empty", func(t *testing.T) { + data := []byte("Hello World") + testBytesMarshal(t, data, false) + testBytesMarshal(t, data, true) + }) + + t.Run("empty", func(t *testing.T) { + testBytesMarshal(t, []byte{}, false) + }) + + t.Run("nil", func(t *testing.T) { + testBytesMarshal(t, nil, false) + }) +} + +func testBytesMarshal(t *testing.T, data []byte, wrongField bool) { + var ( + wire []byte + err error + + custom = stablePrimitives{FieldA: data} + transport = test.Primitives{FieldA: data} + ) + + if !wrongField { + wire, err = custom.stableMarshal(nil) + } else { + wire, err = custom.stableMarshalWrongFieldNum(nil) + } + require.NoError(t, err) + + wireGen, err := transport.Marshal() + require.NoError(t, err) + + if !wrongField { + // we can check equality because single field cannot be unstable marshalled + require.Equal(t, wireGen, wire) + } else { + require.NotEqual(t, wireGen, wire) + } + + result := new(test.Primitives) + err = result.Unmarshal(wire) + require.NoError(t, err) + + if !wrongField { + require.Len(t, result.FieldA, len(data)) + if len(data) > 0 { + require.Equal(t, data, result.FieldA) + } + } else { + require.Len(t, result.FieldA, 0) + } +} diff --git a/util/proto/test/test.pb.go b/util/proto/test/test.pb.go new file mode 100644 index 0000000..5db45e5 --- /dev/null +++ b/util/proto/test/test.pb.go @@ -0,0 +1,327 @@ +// Code generated by protoc-gen-gogo. DO NOT EDIT. +// source: util/proto/test/test.proto + +package test + +import ( + fmt "fmt" + proto "github.com/golang/protobuf/proto" + io "io" + math "math" + math_bits "math/bits" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package + +type Primitives struct { + FieldA []byte `protobuf:"bytes,1,opt,name=field_a,json=fieldA,proto3" json:"field_a,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *Primitives) Reset() { *m = Primitives{} } +func (m *Primitives) String() string { return proto.CompactTextString(m) } +func (*Primitives) ProtoMessage() {} +func (*Primitives) Descriptor() ([]byte, []int) { + return fileDescriptor_998ad0e1a3de8558, []int{0} +} +func (m *Primitives) XXX_Unmarshal(b []byte) error { + return m.Unmarshal(b) +} +func (m *Primitives) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + if deterministic { + return xxx_messageInfo_Primitives.Marshal(b, m, deterministic) + } else { + b = b[:cap(b)] + n, err := m.MarshalToSizedBuffer(b) + if err != nil { + return nil, err + } + return b[:n], nil + } +} +func (m *Primitives) XXX_Merge(src proto.Message) { + xxx_messageInfo_Primitives.Merge(m, src) +} +func (m *Primitives) XXX_Size() int { + return m.Size() +} +func (m *Primitives) XXX_DiscardUnknown() { + xxx_messageInfo_Primitives.DiscardUnknown(m) +} + +var xxx_messageInfo_Primitives proto.InternalMessageInfo + +func (m *Primitives) GetFieldA() []byte { + if m != nil { + return m.FieldA + } + return nil +} + +func init() { + proto.RegisterType((*Primitives)(nil), "test.Primitives") +} + +func init() { proto.RegisterFile("util/proto/test/test.proto", fileDescriptor_998ad0e1a3de8558) } + +var fileDescriptor_998ad0e1a3de8558 = []byte{ + // 109 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0x92, 0x2a, 0x2d, 0xc9, 0xcc, + 0xd1, 0x2f, 0x28, 0xca, 0x2f, 0xc9, 0xd7, 0x2f, 0x49, 0x2d, 0x2e, 0x01, 0x13, 0x7a, 0x60, 0xbe, + 0x10, 0x0b, 0x88, 0xad, 0xa4, 0xca, 0xc5, 0x15, 0x50, 0x94, 0x99, 0x9b, 0x59, 0x92, 0x59, 0x96, + 0x5a, 0x2c, 0x24, 0xce, 0xc5, 0x9e, 0x96, 0x99, 0x9a, 0x93, 0x12, 0x9f, 0x28, 0xc1, 0xa8, 0xc0, + 0xa8, 0xc1, 0x13, 0xc4, 0x06, 0xe6, 0x3a, 0x3a, 0x09, 0x9c, 0x78, 0x24, 0xc7, 0x78, 0xe1, 0x91, + 0x1c, 0xe3, 0x83, 0x47, 0x72, 0x8c, 0x33, 0x1e, 0xcb, 0x31, 0x24, 0xb1, 0x81, 0x4d, 0x31, 0x06, + 0x04, 0x00, 0x00, 0xff, 0xff, 0xe5, 0xa7, 0x6b, 0x1f, 0x63, 0x00, 0x00, 0x00, +} + +func (m *Primitives) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBuffer(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *Primitives) MarshalTo(dAtA []byte) (int, error) { + size := m.Size() + return m.MarshalToSizedBuffer(dAtA[:size]) +} + +func (m *Primitives) MarshalToSizedBuffer(dAtA []byte) (int, error) { + i := len(dAtA) + _ = i + var l int + _ = l + if m.XXX_unrecognized != nil { + i -= len(m.XXX_unrecognized) + copy(dAtA[i:], m.XXX_unrecognized) + } + if len(m.FieldA) > 0 { + i -= len(m.FieldA) + copy(dAtA[i:], m.FieldA) + i = encodeVarintTest(dAtA, i, uint64(len(m.FieldA))) + i-- + dAtA[i] = 0xa + } + return len(dAtA) - i, nil +} + +func encodeVarintTest(dAtA []byte, offset int, v uint64) int { + offset -= sovTest(v) + base := offset + for v >= 1<<7 { + dAtA[offset] = uint8(v&0x7f | 0x80) + v >>= 7 + offset++ + } + dAtA[offset] = uint8(v) + return base +} +func (m *Primitives) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + l = len(m.FieldA) + if l > 0 { + n += 1 + l + sovTest(uint64(l)) + } + if m.XXX_unrecognized != nil { + n += len(m.XXX_unrecognized) + } + return n +} + +func sovTest(x uint64) (n int) { + return (math_bits.Len64(x|1) + 6) / 7 +} +func sozTest(x uint64) (n int) { + return sovTest(uint64((x << 1) ^ uint64((int64(x) >> 63)))) +} +func (m *Primitives) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowTest + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: Primitives: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: Primitives: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field FieldA", wireType) + } + var byteLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowTest + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + byteLen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if byteLen < 0 { + return ErrInvalidLengthTest + } + postIndex := iNdEx + byteLen + if postIndex < 0 { + return ErrInvalidLengthTest + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.FieldA = append(m.FieldA[:0], dAtA[iNdEx:postIndex]...) + if m.FieldA == nil { + m.FieldA = []byte{} + } + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipTest(dAtA[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthTest + } + if (iNdEx + skippy) < 0 { + return ErrInvalidLengthTest + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + m.XXX_unrecognized = append(m.XXX_unrecognized, dAtA[iNdEx:iNdEx+skippy]...) + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func skipTest(dAtA []byte) (n int, err error) { + l := len(dAtA) + iNdEx := 0 + depth := 0 + for iNdEx < l { + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowTest + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + wireType := int(wire & 0x7) + switch wireType { + case 0: + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowTest + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + iNdEx++ + if dAtA[iNdEx-1] < 0x80 { + break + } + } + case 1: + iNdEx += 8 + case 2: + var length int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowTest + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + length |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if length < 0 { + return 0, ErrInvalidLengthTest + } + iNdEx += length + case 3: + depth++ + case 4: + if depth == 0 { + return 0, ErrUnexpectedEndOfGroupTest + } + depth-- + case 5: + iNdEx += 4 + default: + return 0, fmt.Errorf("proto: illegal wireType %d", wireType) + } + if iNdEx < 0 { + return 0, ErrInvalidLengthTest + } + if depth == 0 { + return iNdEx, nil + } + } + return 0, io.ErrUnexpectedEOF +} + +var ( + ErrInvalidLengthTest = fmt.Errorf("proto: negative length found during unmarshaling") + ErrIntOverflowTest = fmt.Errorf("proto: integer overflow") + ErrUnexpectedEndOfGroupTest = fmt.Errorf("proto: unexpected end of group") +) diff --git a/util/proto/test/test.proto b/util/proto/test/test.proto new file mode 100644 index 0000000..cf51a44 --- /dev/null +++ b/util/proto/test/test.proto @@ -0,0 +1,7 @@ +syntax = "proto3"; + +package test; + +message Primitives { + bytes field_a = 1; +} \ No newline at end of file