diff --git a/util/proto/marshal.go b/util/proto/marshal.go index cd9c493a..2d5f9e69 100644 --- a/util/proto/marshal.go +++ b/util/proto/marshal.go @@ -38,6 +38,14 @@ func BytesSize(field int, v []byte) int { return VarUIntSize(uint64(prefix)) + VarUIntSize(uint64(ln)) + ln } +func StringMarshal(field int, buf []byte, v string) (int, error) { + return BytesMarshal(field, buf, []byte(v)) +} + +func StringSize(field int, v string) int { + return BytesSize(field, []byte(v)) +} + // varUIntSize returns length of varint byte sequence for uint64 value 'x'. func VarUIntSize(x uint64) int { return (bits.Len64(x|1) + 6) / 7 diff --git a/util/proto/marshal_test.go b/util/proto/marshal_test.go index ccc1cca7..66a8b625 100644 --- a/util/proto/marshal_test.go +++ b/util/proto/marshal_test.go @@ -11,9 +11,10 @@ import ( type stablePrimitives struct { FieldA []byte + FieldB string } -func (s *stablePrimitives) stableMarshal(buf []byte) ([]byte, error) { +func (s *stablePrimitives) stableMarshal(buf []byte, wrongField bool) ([]byte, error) { if s == nil { return []byte{}, nil } @@ -23,32 +24,24 @@ func (s *stablePrimitives) stableMarshal(buf []byte) ([]byte, error) { } var ( - i, offset int + i, offset, fieldNum int ) - offset, err := proto.BytesMarshal(1, buf, s.FieldA) + fieldNum = 1 + if wrongField { + fieldNum++ + } + offset, err := proto.BytesMarshal(fieldNum, buf, s.FieldA) if err != nil { return nil, errors.Wrap(err, "can't marshal field a") } i += offset - return buf, nil -} - -func (s *stablePrimitives) stableMarshalWrongFieldNum(buf []byte) ([]byte, error) { - if s == nil { - return []byte{}, nil + fieldNum = 2 + if wrongField { + fieldNum++ } - - if buf == nil { - buf = make([]byte, s.stableSize()) - } - - var ( - i, offset int - ) - - offset, err := proto.BytesMarshal(1+1, buf, s.FieldA) + offset, err = proto.StringMarshal(fieldNum, buf, s.FieldB) if err != nil { return nil, errors.Wrap(err, "can't marshal field a") } @@ -58,7 +51,8 @@ func (s *stablePrimitives) stableMarshalWrongFieldNum(buf []byte) ([]byte, error } func (s *stablePrimitives) stableSize() int { - return proto.BytesSize(1, s.FieldA) + return proto.BytesSize(1, s.FieldA) + + proto.StringSize(2, s.FieldB) } func TestBytesMarshal(t *testing.T) { @@ -86,11 +80,7 @@ func testBytesMarshal(t *testing.T, data []byte, wrongField bool) { transport = test.Primitives{FieldA: data} ) - if !wrongField { - wire, err = custom.stableMarshal(nil) - } else { - wire, err = custom.stableMarshalWrongFieldNum(nil) - } + wire, err = custom.stableMarshal(nil, wrongField) require.NoError(t, err) wireGen, err := transport.Marshal() @@ -116,3 +106,51 @@ func testBytesMarshal(t *testing.T, data []byte, wrongField bool) { require.Len(t, result.FieldA, 0) } } + +func TestStringMarshal(t *testing.T) { + t.Run("not empty", func(t *testing.T) { + data := "Hello World" + testStringMarshal(t, data, false) + testStringMarshal(t, data, true) + }) + + t.Run("empty", func(t *testing.T) { + testStringMarshal(t, "", false) + }) +} + +func testStringMarshal(t *testing.T, s string, wrongField bool) { + var ( + wire []byte + err error + + custom = stablePrimitives{FieldB: s} + transport = test.Primitives{FieldB: s} + ) + + wire, err = custom.stableMarshal(nil, wrongField) + require.NoError(t, err) + + wireGen, err := transport.Marshal() + require.NoError(t, err) + + if !wrongField { + // we can check equality because single field cannot be unstable marshalled + require.Equal(t, wireGen, wire) + } else { + require.NotEqual(t, wireGen, wire) + } + + result := new(test.Primitives) + err = result.Unmarshal(wire) + require.NoError(t, err) + + if !wrongField { + require.Len(t, result.FieldB, len(s)) + if len(s) > 0 { + require.Equal(t, s, result.FieldB) + } + } else { + require.Len(t, result.FieldB, 0) + } +} diff --git a/util/proto/test/test.pb.go b/util/proto/test/test.pb.go index 5db45e5e..9a1afd35 100644 Binary files a/util/proto/test/test.pb.go and b/util/proto/test/test.pb.go differ diff --git a/util/proto/test/test.proto b/util/proto/test/test.proto index cf51a44d..81a1058a 100644 --- a/util/proto/test/test.proto +++ b/util/proto/test/test.proto @@ -4,4 +4,5 @@ package test; message Primitives { bytes field_a = 1; + string field_b = 2; } \ No newline at end of file