diff --git a/util/proto/marshal.go b/util/proto/marshal.go index 66824b3..a602b78 100644 --- a/util/proto/marshal.go +++ b/util/proto/marshal.go @@ -8,6 +8,7 @@ package proto import ( "encoding/binary" + "math" "math/bits" "reflect" ) @@ -360,3 +361,26 @@ func Fixed64Size(fNum int, v uint64) int { return VarUIntSize(uint64(prefix)) + 8 } + +func Float64Marshal(field int, buf []byte, v float64) (int, error) { + if v == 0 { + return 0, nil + } + + prefix := field<<3 | 1 + + i := binary.PutUvarint(buf, uint64(prefix)) + binary.LittleEndian.PutUint64(buf[i:], math.Float64bits(v)) + + return i + 8, nil +} + +func Float64Size(fNum int, v float64) int { + if v == 0 { + return 0 + } + + prefix := fNum<<3 | 1 + + return VarUIntSize(uint64(prefix)) + 8 +} diff --git a/util/proto/marshal_test.go b/util/proto/marshal_test.go index 6844309..34b5a44 100644 --- a/util/proto/marshal_test.go +++ b/util/proto/marshal_test.go @@ -23,6 +23,7 @@ type stablePrimitives struct { FieldG uint64 FieldH SomeEnum FieldI uint64 // fixed64 + FieldJ float64 } type stableRepPrimitives struct { @@ -133,6 +134,16 @@ func (s *stablePrimitives) stableMarshal(buf []byte, wrongField bool) ([]byte, e } i += offset + fieldNum = 206 + if wrongField { + fieldNum++ + } + offset, err = proto.Float64Marshal(fieldNum, buf, s.FieldJ) + if err != nil { + return nil, errors.Wrap(err, "can't marshal field J") + } + i += offset + fieldNum = 300 if wrongField { fieldNum++ @@ -155,6 +166,7 @@ func (s *stablePrimitives) stableSize() int { proto.Int64Size(203, s.FieldF) + proto.UInt64Size(204, s.FieldG) + proto.Fixed64Size(205, s.FieldI) + + proto.Float64Size(206, s.FieldJ) + proto.EnumSize(300, int32(s.FieldH)) } @@ -453,6 +465,19 @@ func TestFixed64Marshal(t *testing.T) { }) } +func TestFloat64Marshal(t *testing.T) { + t.Run("zero", func(t *testing.T) { + testFloat64Marshal(t, 0, false) + }) + + t.Run("non zero", func(t *testing.T) { + f := math.Float64frombits(12345677890) + + testFloat64Marshal(t, f, false) + testFloat64Marshal(t, f, true) + }) +} + func testMarshal(t *testing.T, c stablePrimitives, tr test.Primitives, wrongField bool) *test.Primitives { var ( wire []byte @@ -589,6 +614,21 @@ func testUInt64Marshal(t *testing.T, n uint64, wrongField bool) { } } +func testFloat64Marshal(t *testing.T, n float64, wrongField bool) { + var ( + custom = stablePrimitives{FieldJ: n} + transport = test.Primitives{FieldJ: n} + ) + + result := testMarshal(t, custom, transport, wrongField) + + if !wrongField { + require.Equal(t, n, result.FieldJ) + } else { + require.EqualValues(t, 0, result.FieldJ) + } +} + func testEnumMarshal(t *testing.T, e SomeEnum, wrongField bool) { var ( custom = stablePrimitives{FieldH: e} diff --git a/util/proto/test/test.pb.go b/util/proto/test/test.pb.go index a0808f8..b194e1f 100644 Binary files a/util/proto/test/test.pb.go and b/util/proto/test/test.pb.go differ diff --git a/util/proto/test/test.proto b/util/proto/test/test.proto index e9c2758..c568a31 100644 --- a/util/proto/test/test.proto +++ b/util/proto/test/test.proto @@ -13,6 +13,7 @@ message Primitives { int64 field_f = 203; uint64 field_g = 204; fixed64 field_i = 205; + double field_j = 206; enum SomeEnum { UNKNOWN = 0;