diff --git a/session/marshal.go b/session/marshal.go index cda9579..8632398 100644 --- a/session/marshal.go +++ b/session/marshal.go @@ -211,7 +211,7 @@ func (c *ObjectSessionContext) StableMarshal(buf []byte) []byte { } offset := proto.EnumMarshal(objectCtxVerbField, buf, int32(c.verb)) - proto.NestedStructureMarshal(objectCtxTargetField, buf[offset:], objectSessionContextTarget{ + proto.NestedStructureMarshalUnchecked(objectCtxTargetField, buf[offset:], objectSessionContextTarget{ cnr: c.cnr, objs: c.objs, }) @@ -225,7 +225,7 @@ func (c *ObjectSessionContext) StableSize() (size int) { } size += proto.EnumSize(objectCtxVerbField, int32(c.verb)) - size += proto.NestedStructureSize(objectCtxTargetField, objectSessionContextTarget{ + size += proto.NestedStructureSizeUnchecked(objectCtxTargetField, objectSessionContextTarget{ cnr: c.cnr, objs: c.objs, }) diff --git a/util/proto/marshal.go b/util/proto/marshal.go index b16375a..2704960 100644 --- a/util/proto/marshal.go +++ b/util/proto/marshal.go @@ -20,9 +20,10 @@ type ( StableSize() int } - setMarshalData interface { + setMarshalData[T any] interface { SetMarshalData([]byte) StableSize() int + ~*T } ) @@ -254,12 +255,21 @@ func VarUIntSize(x uint64) int { return (bits.Len64(x|1) + 6) / 7 } -func NestedStructureMarshal[T stableMarshaller](field int64, buf []byte, v T) int { - n := v.StableSize() - if n == 0 { +type ptrStableMarshaler[T any] interface { + stableMarshaller + ~*T +} + +func NestedStructureMarshal[T any, M ptrStableMarshaler[T]](field int64, buf []byte, v M) int { + if v == nil { return 0 } + return NestedStructureMarshalUnchecked(field, buf, v) +} + +func NestedStructureMarshalUnchecked[T stableMarshaller](field int64, buf []byte, v T) int { + n := v.StableSize() prefix := protowire.EncodeTag(protowire.Number(field), protowire.BytesType) offset := binary.PutUvarint(buf, prefix) offset += binary.PutUvarint(buf[offset:], uint64(n)) @@ -272,12 +282,12 @@ func NestedStructureMarshal[T stableMarshaller](field int64, buf []byte, v T) in // and calls SetMarshalData for nested structure. // // Returns marshalled data length of nested structure. -func NestedStructureSetMarshalData(field int64, parentData []byte, v setMarshalData) int { - n := v.StableSize() - if n == 0 { +func NestedStructureSetMarshalData[T any, M setMarshalData[T]](field int64, parentData []byte, v M) int { + if v == nil { return 0 } + n := v.StableSize() buf := make([]byte, binary.MaxVarintLen64) prefix := protowire.EncodeTag(protowire.Number(field), protowire.BytesType) offset := binary.PutUvarint(buf, prefix) @@ -288,13 +298,17 @@ func NestedStructureSetMarshalData(field int64, parentData []byte, v setMarshalD return offset + n } -func NestedStructureSize[T stableMarshaller](field int64, v T) (size int) { - n := v.StableSize() - if n == 0 { +func NestedStructureSize[T any, M ptrStableMarshaler[T]](field int64, v M) (size int) { + if v == nil { return 0 } - size = protowire.SizeGroup(protowire.Number(field), protowire.SizeBytes(n)) - return + + return NestedStructureSizeUnchecked(field, v) +} + +func NestedStructureSizeUnchecked[T stableMarshaller](field int64, v T) int { + n := v.StableSize() + return protowire.SizeGroup(protowire.Number(field), protowire.SizeBytes(n)) } func Fixed64Marshal(field int, buf []byte, v uint64) int {