diff --git a/util/proto/marshal.go b/util/proto/marshal.go index 025d9b4..cc055a5 100644 --- a/util/proto/marshal.go +++ b/util/proto/marshal.go @@ -128,6 +128,165 @@ func EnumSize(field int, v int32) int { return UInt64Size(field, uint64(v)) } +func RepeatedBytesMarshal(field int, buf []byte, v [][]byte) (int, error) { + var offset int + + for i := range v { + off, err := BytesMarshal(field, buf[offset:], v[i]) + if err != nil { + return 0, err + } + + offset += off + } + + return offset, nil +} + +func RepeatedBytesSize(field int, v [][]byte) (size int) { + for i := range v { + size += BytesSize(field, v[i]) + } + + return size +} + +func RepeatedStringMarshal(field int, buf []byte, v []string) (int, error) { + var offset int + + for i := range v { + off, err := StringMarshal(field, buf[offset:], v[i]) + if err != nil { + return 0, err + } + + offset += off + } + + return offset, nil +} + +func RepeatedStringSize(field int, v []string) (size int) { + for i := range v { + size += StringSize(field, v[i]) + } + + return size +} + +func RepeatedUInt64Marshal(field int, buf []byte, v []uint64) (int, error) { + if len(v) == 0 { + return 0, nil + } + + prefix := field<<3 | 0x02 + offset := binary.PutUvarint(buf, uint64(prefix)) + + _, arrSize := RepeatedUInt64Size(field, v) + offset += binary.PutUvarint(buf[offset:], uint64(arrSize)) + for i := range v { + offset += binary.PutUvarint(buf[offset:], v[i]) + } + + return offset, nil +} + +func RepeatedUInt64Size(field int, v []uint64) (size, arraySize int) { + if len(v) == 0 { + return 0, 0 + } + + for i := range v { + size += VarUIntSize(v[i]) + } + arraySize = size + + size += VarUIntSize(uint64(size)) + + prefix := field<<3 | 0x2 + size += VarUIntSize(uint64(prefix)) + + return size, arraySize +} + +func RepeatedInt64Marshal(field int, buf []byte, v []int64) (int, error) { + if len(v) == 0 { + return 0, nil + } + + convert := make([]uint64, len(v)) + for i := range v { + convert[i] = uint64(v[i]) + } + + return RepeatedUInt64Marshal(field, buf, convert) +} + +func RepeatedInt64Size(field int, v []int64) (size, arraySize int) { + if len(v) == 0 { + return 0, 0 + } + + convert := make([]uint64, len(v)) + for i := range v { + convert[i] = uint64(v[i]) + } + + return RepeatedUInt64Size(field, convert) +} + +func RepeatedUInt32Marshal(field int, buf []byte, v []uint32) (int, error) { + if len(v) == 0 { + return 0, nil + } + + convert := make([]uint64, len(v)) + for i := range v { + convert[i] = uint64(v[i]) + } + + return RepeatedUInt64Marshal(field, buf, convert) +} + +func RepeatedUInt32Size(field int, v []uint32) (size, arraySize int) { + if len(v) == 0 { + return 0, 0 + } + + convert := make([]uint64, len(v)) + for i := range v { + convert[i] = uint64(v[i]) + } + + return RepeatedUInt64Size(field, convert) +} + +func RepeatedInt32Marshal(field int, buf []byte, v []int32) (int, error) { + if len(v) == 0 { + return 0, nil + } + + convert := make([]uint64, len(v)) + for i := range v { + convert[i] = uint64(v[i]) + } + + return RepeatedUInt64Marshal(field, buf, convert) +} + +func RepeatedInt32Size(field int, v []int32) (size, arraySize int) { + if len(v) == 0 { + return 0, 0 + } + + convert := make([]uint64, len(v)) + for i := range v { + convert[i] = uint64(v[i]) + } + + return RepeatedUInt64Size(field, convert) +} + // varUIntSize returns length of varint byte sequence for uint64 value 'x'. func VarUIntSize(x uint64) int { return (bits.Len64(x|1) + 6) / 7 diff --git a/util/proto/marshal_test.go b/util/proto/marshal_test.go index 6802523..de7b1ca 100644 --- a/util/proto/marshal_test.go +++ b/util/proto/marshal_test.go @@ -23,6 +23,15 @@ type stablePrimitives struct { FieldH SomeEnum } +type stableRepPrimitives struct { + FieldA [][]byte + FieldB []string + FieldC []int32 + FieldD []uint32 + FieldE []int64 + FieldF []uint64 +} + const ( ENUM_UNKNOWN SomeEnum = 0 ENUM_POSITIVE = 1 @@ -118,7 +127,7 @@ func (s *stablePrimitives) stableMarshal(buf []byte, wrongField bool) ([]byte, e } offset, err = proto.EnumMarshal(fieldNum, buf, int32(s.FieldH)) if err != nil { - return nil, errors.Wrap(err, "can't marshal field g") + return nil, errors.Wrap(err, "can't marshal field h") } i += offset @@ -136,6 +145,93 @@ func (s *stablePrimitives) stableSize() int { proto.EnumSize(300, int32(s.FieldH)) } +func (s *stableRepPrimitives) stableMarshal(buf []byte, wrongField bool) ([]byte, error) { + if s == nil { + return []byte{}, nil + } + + if buf == nil { + buf = make([]byte, s.stableSize()) + } + + var ( + i, offset, fieldNum int + ) + + fieldNum = 1 + if wrongField { + fieldNum++ + } + offset, err := proto.RepeatedBytesMarshal(fieldNum, buf, s.FieldA) + if err != nil { + return nil, errors.Wrap(err, "can't marshal field a") + } + i += offset + + fieldNum = 2 + if wrongField { + fieldNum++ + } + offset, err = proto.RepeatedStringMarshal(fieldNum, buf, s.FieldB) + if err != nil { + return nil, errors.Wrap(err, "can't marshal field b") + } + i += offset + + fieldNum = 3 + if wrongField { + fieldNum++ + } + offset, err = proto.RepeatedInt32Marshal(fieldNum, buf, s.FieldC) + if err != nil { + return nil, errors.Wrap(err, "can't marshal field c") + } + i += offset + + fieldNum = 4 + if wrongField { + fieldNum++ + } + offset, err = proto.RepeatedUInt32Marshal(fieldNum, buf, s.FieldD) + if err != nil { + return nil, errors.Wrap(err, "can't marshal field d") + } + i += offset + + fieldNum = 5 + if wrongField { + fieldNum++ + } + offset, err = proto.RepeatedInt64Marshal(fieldNum, buf, s.FieldE) + if err != nil { + return nil, errors.Wrap(err, "can't marshal field e") + } + i += offset + + fieldNum = 6 + if wrongField { + fieldNum++ + } + offset, err = proto.RepeatedUInt64Marshal(fieldNum, buf, s.FieldF) + if err != nil { + return nil, errors.Wrap(err, "can't marshal field f") + } + i += offset + + return buf, nil +} + +func (s *stableRepPrimitives) stableSize() int { + f1 := proto.RepeatedBytesSize(1, s.FieldA) + f2 := proto.RepeatedStringSize(2, s.FieldB) + f3, _ := proto.RepeatedInt32Size(3, s.FieldC) + f4, _ := proto.RepeatedUInt32Size(4, s.FieldD) + f5, _ := proto.RepeatedInt64Size(5, s.FieldE) + f6, _ := proto.RepeatedUInt64Size(6, s.FieldF) + + return f1 + f2 + f3 + f4 + f5 + f6 +} + func TestBytesMarshal(t *testing.T) { t.Run("not empty", func(t *testing.T) { data := []byte("Hello World") @@ -237,112 +333,103 @@ func TestEnumMarshal(t *testing.T) { testEnumMarshal(t, ENUM_NEGATIVE, true) } -func testBytesMarshal(t *testing.T, data []byte, wrongField bool) { - var ( - wire []byte - err error +func TestRepeatedBytesMarshal(t *testing.T) { + t.Run("not empty", func(t *testing.T) { + data := [][]byte{[]byte("One"), []byte("Two"), []byte("Three")} + testRepeatedBytesMarshal(t, data, false) + testRepeatedBytesMarshal(t, data, true) + }) - custom = stablePrimitives{FieldA: data} - transport = test.Primitives{FieldA: data} - ) + t.Run("empty", func(t *testing.T) { + testRepeatedBytesMarshal(t, [][]byte{}, false) + }) - wire, err = custom.stableMarshal(nil, wrongField) - require.NoError(t, err) - - wireGen, err := transport.Marshal() - require.NoError(t, err) - - if !wrongField { - // we can check equality because single field cannot be unstable marshalled - require.Equal(t, wireGen, wire) - } else { - require.NotEqual(t, wireGen, wire) - } - - result := new(test.Primitives) - err = result.Unmarshal(wire) - require.NoError(t, err) - - if !wrongField { - require.Len(t, result.FieldA, len(data)) - if len(data) > 0 { - require.Equal(t, data, result.FieldA) - } - } else { - require.Len(t, result.FieldA, 0) - } + t.Run("nil", func(t *testing.T) { + testRepeatedBytesMarshal(t, nil, false) + }) } -func testStringMarshal(t *testing.T, s string, wrongField bool) { - var ( - wire []byte - err error +func TestRepeatedStringMarshal(t *testing.T) { + t.Run("not empty", func(t *testing.T) { + data := []string{"One", "Two", "Three"} + testRepeatedStringMarshal(t, data, false) + testRepeatedStringMarshal(t, data, true) + }) - custom = stablePrimitives{FieldB: s} - transport = test.Primitives{FieldB: s} - ) + t.Run("empty", func(t *testing.T) { + testRepeatedStringMarshal(t, []string{}, false) + }) - wire, err = custom.stableMarshal(nil, wrongField) - require.NoError(t, err) - - wireGen, err := transport.Marshal() - require.NoError(t, err) - - if !wrongField { - // we can check equality because single field cannot be unstable marshalled - require.Equal(t, wireGen, wire) - } else { - require.NotEqual(t, wireGen, wire) - } - - result := new(test.Primitives) - err = result.Unmarshal(wire) - require.NoError(t, err) - - if !wrongField { - require.Len(t, result.FieldB, len(s)) - if len(s) > 0 { - require.Equal(t, s, result.FieldB) - } - } else { - require.Len(t, result.FieldB, 0) - } + t.Run("nil", func(t *testing.T) { + testRepeatedStringMarshal(t, nil, false) + }) } -func testBoolMarshal(t *testing.T, b bool, wrongField bool) { - var ( - wire []byte - err error +func TestRepeatedInt32Marshal(t *testing.T) { + t.Run("not empty", func(t *testing.T) { + data := []int32{-1, 0, 1, 2, 3, 4, 5} + testRepeatedInt32Marshal(t, data, false) + testRepeatedInt32Marshal(t, data, true) + }) - custom = stablePrimitives{FieldC: b} - transport = test.Primitives{FieldC: b} - ) + t.Run("empty", func(t *testing.T) { + testRepeatedInt32Marshal(t, []int32{}, false) + }) - wire, err = custom.stableMarshal(nil, wrongField) - require.NoError(t, err) - - wireGen, err := transport.Marshal() - require.NoError(t, err) - - if !wrongField { - // we can check equality because single field cannot be unstable marshalled - require.Equal(t, wireGen, wire) - } else { - require.NotEqual(t, wireGen, wire) - } - - result := new(test.Primitives) - err = result.Unmarshal(wire) - require.NoError(t, err) - - if !wrongField { - require.Equal(t, b, result.FieldC) - } else { - require.False(t, false, result.FieldC) - } + t.Run("nil", func(t *testing.T) { + testRepeatedInt32Marshal(t, nil, false) + }) } -func testIntMarshal(t *testing.T, c stablePrimitives, tr test.Primitives, wrongField bool) *test.Primitives { +func TestRepeatedUInt32Marshal(t *testing.T) { + t.Run("not empty", func(t *testing.T) { + data := []uint32{0, 1, 2, 3, 4, 5} + testRepeatedUInt32Marshal(t, data, false) + testRepeatedUInt32Marshal(t, data, true) + }) + + t.Run("empty", func(t *testing.T) { + testRepeatedUInt32Marshal(t, []uint32{}, false) + }) + + t.Run("nil", func(t *testing.T) { + testRepeatedUInt32Marshal(t, nil, false) + }) +} + +func TestRepeatedInt64Marshal(t *testing.T) { + t.Run("not empty", func(t *testing.T) { + data := []int64{-1, 0, 1, 2, 3, 4, 5} + testRepeatedInt64Marshal(t, data, false) + testRepeatedInt64Marshal(t, data, true) + }) + + t.Run("empty", func(t *testing.T) { + testRepeatedInt64Marshal(t, []int64{}, false) + }) + + t.Run("nil", func(t *testing.T) { + testRepeatedInt64Marshal(t, nil, false) + }) +} + +func TestRepeatedUInt64Marshal(t *testing.T) { + t.Run("not empty", func(t *testing.T) { + data := []uint64{0, 1, 2, 3, 4, 5} + testRepeatedUInt64Marshal(t, data, false) + testRepeatedUInt64Marshal(t, data, true) + }) + + t.Run("empty", func(t *testing.T) { + testRepeatedUInt64Marshal(t, []uint64{}, false) + }) + + t.Run("nil", func(t *testing.T) { + testRepeatedUInt64Marshal(t, nil, false) + }) +} + +func testMarshal(t *testing.T, c stablePrimitives, tr test.Primitives, wrongField bool) *test.Primitives { var ( wire []byte err error @@ -367,13 +454,64 @@ func testIntMarshal(t *testing.T, c stablePrimitives, tr test.Primitives, wrongF return result } +func testBytesMarshal(t *testing.T, data []byte, wrongField bool) { + var ( + custom = stablePrimitives{FieldA: data} + transport = test.Primitives{FieldA: data} + ) + + result := testMarshal(t, custom, transport, wrongField) + + if !wrongField { + require.Len(t, result.FieldA, len(data)) + if len(data) > 0 { + require.Equal(t, data, result.FieldA) + } + } else { + require.Len(t, result.FieldA, 0) + } +} + +func testStringMarshal(t *testing.T, s string, wrongField bool) { + var ( + custom = stablePrimitives{FieldB: s} + transport = test.Primitives{FieldB: s} + ) + + result := testMarshal(t, custom, transport, wrongField) + + if !wrongField { + require.Len(t, result.FieldB, len(s)) + if len(s) > 0 { + require.Equal(t, s, result.FieldB) + } + } else { + require.Len(t, result.FieldB, 0) + } +} + +func testBoolMarshal(t *testing.T, b bool, wrongField bool) { + var ( + custom = stablePrimitives{FieldC: b} + transport = test.Primitives{FieldC: b} + ) + + result := testMarshal(t, custom, transport, wrongField) + + if !wrongField { + require.Equal(t, b, result.FieldC) + } else { + require.False(t, false, result.FieldC) + } +} + func testInt32Marshal(t *testing.T, n int32, wrongField bool) { var ( custom = stablePrimitives{FieldD: n} transport = test.Primitives{FieldD: n} ) - result := testIntMarshal(t, custom, transport, wrongField) + result := testMarshal(t, custom, transport, wrongField) if !wrongField { require.Equal(t, n, result.FieldD) @@ -388,7 +526,7 @@ func testUInt32Marshal(t *testing.T, n uint32, wrongField bool) { transport = test.Primitives{FieldE: n} ) - result := testIntMarshal(t, custom, transport, wrongField) + result := testMarshal(t, custom, transport, wrongField) if !wrongField { require.Equal(t, n, result.FieldE) @@ -403,7 +541,7 @@ func testInt64Marshal(t *testing.T, n int64, wrongField bool) { transport = test.Primitives{FieldF: n} ) - result := testIntMarshal(t, custom, transport, wrongField) + result := testMarshal(t, custom, transport, wrongField) if !wrongField { require.Equal(t, n, result.FieldF) @@ -418,7 +556,7 @@ func testUInt64Marshal(t *testing.T, n uint64, wrongField bool) { transport = test.Primitives{FieldG: n} ) - result := testIntMarshal(t, custom, transport, wrongField) + result := testMarshal(t, custom, transport, wrongField) if !wrongField { require.Equal(t, n, result.FieldG) @@ -429,17 +567,28 @@ func testUInt64Marshal(t *testing.T, n uint64, wrongField bool) { func testEnumMarshal(t *testing.T, e SomeEnum, wrongField bool) { var ( - wire []byte - err error - custom = stablePrimitives{FieldH: e} transport = test.Primitives{FieldH: test.Primitives_SomeEnum(e)} ) - wire, err = custom.stableMarshal(nil, wrongField) + result := testMarshal(t, custom, transport, wrongField) + + if !wrongField { + require.EqualValues(t, custom.FieldH, result.FieldH) + } else { + require.EqualValues(t, 0, result.FieldH) + } +} + +func testRepMarshal(t *testing.T, c stableRepPrimitives, tr test.RepPrimitives, wrongField bool) *test.RepPrimitives { + var ( + wire []byte + err error + ) + wire, err = c.stableMarshal(nil, wrongField) require.NoError(t, err) - wireGen, err := transport.Marshal() + wireGen, err := tr.Marshal() require.NoError(t, err) if !wrongField { @@ -449,13 +598,117 @@ func testEnumMarshal(t *testing.T, e SomeEnum, wrongField bool) { require.NotEqual(t, wireGen, wire) } - result := new(test.Primitives) + result := new(test.RepPrimitives) err = result.Unmarshal(wire) require.NoError(t, err) + return result +} + +func testRepeatedBytesMarshal(t *testing.T, data [][]byte, wrongField bool) { + var ( + custom = stableRepPrimitives{FieldA: data} + transport = test.RepPrimitives{FieldA: data} + ) + + result := testRepMarshal(t, custom, transport, wrongField) + if !wrongField { - require.EqualValues(t, custom.FieldH, result.FieldH) + require.Len(t, result.FieldA, len(data)) + if len(data) > 0 { + require.Equal(t, data, result.FieldA) + } } else { - require.EqualValues(t, 0, result.FieldH) + require.Len(t, result.FieldA, 0) + } +} + +func testRepeatedStringMarshal(t *testing.T, s []string, wrongField bool) { + var ( + custom = stableRepPrimitives{FieldB: s} + transport = test.RepPrimitives{FieldB: s} + ) + + result := testRepMarshal(t, custom, transport, wrongField) + + if !wrongField { + require.Len(t, result.FieldB, len(s)) + if len(s) > 0 { + require.Equal(t, s, result.FieldB) + } + } else { + require.Len(t, result.FieldB, 0) + } +} + +func testRepeatedInt32Marshal(t *testing.T, n []int32, wrongField bool) { + var ( + custom = stableRepPrimitives{FieldC: n} + transport = test.RepPrimitives{FieldC: n} + ) + + result := testRepMarshal(t, custom, transport, wrongField) + + if !wrongField { + require.Len(t, result.FieldC, len(n)) + if len(n) > 0 { + require.Equal(t, n, result.FieldC) + } + } else { + require.Len(t, result.FieldC, 0) + } +} + +func testRepeatedUInt32Marshal(t *testing.T, n []uint32, wrongField bool) { + var ( + custom = stableRepPrimitives{FieldD: n} + transport = test.RepPrimitives{FieldD: n} + ) + + result := testRepMarshal(t, custom, transport, wrongField) + + if !wrongField { + require.Len(t, result.FieldD, len(n)) + if len(n) > 0 { + require.Equal(t, n, result.FieldD) + } + } else { + require.Len(t, result.FieldD, 0) + } +} + +func testRepeatedInt64Marshal(t *testing.T, n []int64, wrongField bool) { + var ( + custom = stableRepPrimitives{FieldE: n} + transport = test.RepPrimitives{FieldE: n} + ) + + result := testRepMarshal(t, custom, transport, wrongField) + + if !wrongField { + require.Len(t, result.FieldE, len(n)) + if len(n) > 0 { + require.Equal(t, n, result.FieldE) + } + } else { + require.Len(t, result.FieldE, 0) + } +} + +func testRepeatedUInt64Marshal(t *testing.T, n []uint64, wrongField bool) { + var ( + custom = stableRepPrimitives{FieldF: n} + transport = test.RepPrimitives{FieldF: n} + ) + + result := testRepMarshal(t, custom, transport, wrongField) + + if !wrongField { + require.Len(t, result.FieldF, len(n)) + if len(n) > 0 { + require.Equal(t, n, result.FieldF) + } + } else { + require.Len(t, result.FieldF, 0) } } diff --git a/util/proto/test/test.pb.go b/util/proto/test/test.pb.go index 0c31467..686ef72 100644 Binary files a/util/proto/test/test.pb.go and b/util/proto/test/test.pb.go differ diff --git a/util/proto/test/test.proto b/util/proto/test/test.proto index 5aed2b2..8170b87 100644 --- a/util/proto/test/test.proto +++ b/util/proto/test/test.proto @@ -17,4 +17,13 @@ message Primitives { NEGATIVE = -1; } SomeEnum field_h = 300; +} + +message RepPrimitives { + repeated bytes field_a = 1; + repeated string field_b = 2; + repeated int32 field_c = 3; + repeated uint32 field_d = 4; + repeated int64 field_e = 5; + repeated uint64 field_f = 6; } \ No newline at end of file