From 56c72b5c67a8885d0f273af2dbac49311debf063 Mon Sep 17 00:00:00 2001 From: Roman Khimov Date: Mon, 16 Sep 2019 15:58:26 +0300 Subject: [PATCH] io: redo GetVarSize for Serializable things Use writes to a fake io.Writer that counts the bytes. Allows us to kill Size() methods as useless and duplicating lots of functionality. --- pkg/io/size.go | 29 +++++++++++++++++++++++++++-- pkg/io/size_test.go | 4 +++- 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/pkg/io/size.go b/pkg/io/size.go index bde30be19..ca966740c 100644 --- a/pkg/io/size.go +++ b/pkg/io/size.go @@ -17,6 +17,20 @@ var ( i64 int64 ) +// This structure is used to calculate the wire size of the serializable +// structure. It's an io.Writer that doesn't do any real writes, but instead +// just counts the number of bytes to be written. +type counterWriter struct { + counter int +} + +// Write implements the io.Writer interface +func (cw *counterWriter) Write(p []byte) (int, error) { + n := len(p) + cw.counter += n + return n, nil +} + // GetVarIntSize returns the size in number of bytes of a variable integer // (reference: GetVarSize(int value), https://github.com/neo-project/neo/blob/master/neo/IO/Helper.cs) func GetVarIntSize(value int) int { @@ -59,6 +73,18 @@ func GetVarSize(value interface{}) int { reflect.Uint32, reflect.Uint64: return GetVarIntSize(int(v.Uint())) + case reflect.Ptr: + vser, ok := v.Interface().(Serializable) + if !ok { + panic(fmt.Sprintf("unable to calculate GetVarSize for a non-Serializable pointer")) + } + cw := counterWriter{} + w := NewBinWriterFromIO(&cw) + err := vser.EncodeBinary(w) + if err != nil { + panic(fmt.Sprintf("error serializing %s: %s", reflect.TypeOf(value), err.Error())) + } + return cw.counter case reflect.Slice, reflect.Array: valueLength := v.Len() valueSize := 0 @@ -67,8 +93,7 @@ func GetVarSize(value interface{}) int { switch reflect.ValueOf(value).Index(0).Interface().(type) { case Serializable: for i := 0; i < valueLength; i++ { - elem := v.Index(i).Interface().(Serializable) - valueSize += elem.Size() + valueSize += GetVarSize(v.Index(i).Interface()) } case uint8, int8: valueSize = valueLength diff --git a/pkg/io/size_test.go b/pkg/io/size_test.go index e3028accd..9a5d2de66 100644 --- a/pkg/io/size_test.go +++ b/pkg/io/size_test.go @@ -10,13 +10,15 @@ import ( // Mock structure to test getting size of an array of serializable things type smthSerializable struct { + some [42]byte } func (*smthSerializable) DecodeBinary(*BinReader) error { return nil } -func (*smthSerializable) EncodeBinary(*BinWriter) error { +func (ss *smthSerializable) EncodeBinary(bw *BinWriter) error { + bw.WriteLE(ss.some) return nil }