diff --git a/pkg/io/binaryWriter.go b/pkg/io/binaryWriter.go index 5ad499988..5ec3e42fa 100644 --- a/pkg/io/binaryWriter.go +++ b/pkg/io/binaryWriter.go @@ -11,6 +11,7 @@ import ( // from a struct with many fields. type BinWriter struct { w io.Writer + uv []byte u64 []byte u32 []byte u16 []byte @@ -20,11 +21,12 @@ type BinWriter struct { // NewBinWriterFromIO makes a BinWriter from io.Writer. func NewBinWriterFromIO(iow io.Writer) *BinWriter { - u64 := make([]byte, 8) + uv := make([]byte, 9) + u64 := uv[:8] u32 := u64[:4] u16 := u64[:2] u8 := u64[:1] - return &BinWriter{w: iow, u64: u64, u32: u32, u16: u16, u8: u8} + return &BinWriter{w: iow, uv: uv, u64: u64, u32: u32, u16: u16, u8: u8} } // WriteU64LE writes an uint64 value into the underlying io.Writer in @@ -106,23 +108,31 @@ func (w *BinWriter) WriteVarUint(val uint64) { return } + n := PutVarUint(w.uv, val) + w.WriteBytes(w.uv[:n]) +} + +// PutVarUint puts val in varint form to the pre-allocated buffer. +func PutVarUint(data []byte, val uint64) int { + _ = data[8] if val < 0xfd { - w.WriteB(byte(val)) - return + data[0] = byte(val) + return 1 } if val < 0xFFFF { - w.WriteB(byte(0xfd)) - w.WriteU16LE(uint16(val)) - return + data[0] = byte(0xfd) + binary.LittleEndian.PutUint16(data[1:], uint16(val)) + return 3 } if val < 0xFFFFFFFF { - w.WriteB(byte(0xfe)) - w.WriteU32LE(uint32(val)) - return + data[0] = byte(0xfe) + binary.LittleEndian.PutUint32(data[1:], uint32(val)) + return 5 } - w.WriteB(byte(0xff)) - w.WriteU64LE(val) + data[0] = byte(0xff) + binary.LittleEndian.PutUint64(data[1:], val) + return 9 } // WriteBytes writes a variable byte into the underlying io.Writer without prefix. diff --git a/pkg/vm/stackitem/serialization.go b/pkg/vm/stackitem/serialization.go index cbc54b4b0..60b98091c 100644 --- a/pkg/vm/stackitem/serialization.go +++ b/pkg/vm/stackitem/serialization.go @@ -19,38 +19,35 @@ var ErrUnserializable = errors.New("unserializable") // serContext is an internal serialization context. type serContext struct { - *io.BinWriter - buf *io.BufBinWriter + uv [9]byte + data []byte allowInvalid bool seen map[Item]bool } // Serialize encodes given Item into the byte slice. func Serialize(item Item) ([]byte, error) { - w := io.NewBufBinWriter() sc := serContext{ - BinWriter: w.BinWriter, - buf: w, allowInvalid: false, seen: make(map[Item]bool), } - sc.serialize(item) - if w.Err != nil { - return nil, w.Err + err := sc.serialize(item) + if err != nil { + return nil, err } - return w.Bytes(), nil + return sc.data, nil } // EncodeBinary encodes given Item into the given BinWriter. It's // similar to io.Serializable's EncodeBinary, but works with Item // interface. func EncodeBinary(item Item, w *io.BinWriter) { - sc := serContext{ - BinWriter: w, - allowInvalid: false, - seen: make(map[Item]bool), + data, err := Serialize(item) + if err != nil { + w.Err = err + return } - sc.serialize(item) + w.WriteBytes(data) } // EncodeBinaryProtected encodes given Item into the given BinWriter. It's @@ -59,88 +56,104 @@ func EncodeBinary(item Item, w *io.BinWriter) { // (like recursive array) is encountered it just writes special InvalidT // type of element to w. func EncodeBinaryProtected(item Item, w *io.BinWriter) { - bw := io.NewBufBinWriter() sc := serContext{ - BinWriter: bw.BinWriter, - buf: bw, allowInvalid: true, seen: make(map[Item]bool), } - sc.serialize(item) - if bw.Err != nil { + err := sc.serialize(item) + if err != nil { w.WriteBytes([]byte{byte(InvalidT)}) return } - w.WriteBytes(bw.Bytes()) + w.WriteBytes(sc.data) } -func (w *serContext) serialize(item Item) { - if w.Err != nil { - return - } +func (w *serContext) serialize(item Item) error { if w.seen[item] { - w.Err = ErrRecursive - return + return ErrRecursive } switch t := item.(type) { case *ByteArray: - w.WriteBytes([]byte{byte(ByteArrayT)}) - w.WriteVarBytes(t.Value().([]byte)) + w.data = append(w.data, byte(ByteArrayT)) + data := t.Value().([]byte) + w.appendVarUint(uint64(len(data))) + w.data = append(w.data, data...) case *Buffer: - w.WriteBytes([]byte{byte(BufferT)}) - w.WriteVarBytes(t.Value().([]byte)) + w.data = append(w.data, byte(BufferT)) + data := t.Value().([]byte) + w.appendVarUint(uint64(len(data))) + w.data = append(w.data, data...) case *Bool: - w.WriteBytes([]byte{byte(BooleanT)}) - w.WriteBool(t.Value().(bool)) + w.data = append(w.data, byte(BooleanT)) + if t.Value().(bool) { + w.data = append(w.data, 1) + } else { + w.data = append(w.data, 0) + } case *BigInteger: - w.WriteBytes([]byte{byte(IntegerT)}) - w.WriteVarBytes(bigint.ToBytes(t.Value().(*big.Int))) + w.data = append(w.data, byte(IntegerT)) + data := bigint.ToBytes(t.Value().(*big.Int)) + w.appendVarUint(uint64(len(data))) + w.data = append(w.data, data...) case *Interop: if w.allowInvalid { - w.WriteBytes([]byte{byte(InteropT)}) + w.data = append(w.data, byte(InteropT)) } else { - w.Err = fmt.Errorf("%w: Interop", ErrUnserializable) + return fmt.Errorf("%w: Interop", ErrUnserializable) } case *Array, *Struct: w.seen[item] = true _, isArray := t.(*Array) if isArray { - w.WriteBytes([]byte{byte(ArrayT)}) + w.data = append(w.data, byte(ArrayT)) } else { - w.WriteBytes([]byte{byte(StructT)}) + w.data = append(w.data, byte(StructT)) } arr := t.Value().([]Item) - w.WriteVarUint(uint64(len(arr))) + w.appendVarUint(uint64(len(arr))) for i := range arr { - w.serialize(arr[i]) + if err := w.serialize(arr[i]); err != nil { + return err + } } delete(w.seen, item) case *Map: w.seen[item] = true - w.WriteBytes([]byte{byte(MapT)}) - w.WriteVarUint(uint64(len(t.Value().([]MapElement)))) - for i := range t.Value().([]MapElement) { - w.serialize(t.Value().([]MapElement)[i].Key) - w.serialize(t.Value().([]MapElement)[i].Value) + elems := t.Value().([]MapElement) + w.data = append(w.data, byte(MapT)) + w.appendVarUint(uint64(len(elems))) + for i := range elems { + if err := w.serialize(elems[i].Key); err != nil { + return err + } + if err := w.serialize(elems[i].Value); err != nil { + return err + } } delete(w.seen, item) case Null: - w.WriteB(byte(AnyT)) + w.data = append(w.data, byte(AnyT)) case nil: if w.allowInvalid { - w.WriteBytes([]byte{byte(InvalidT)}) + w.data = append(w.data, byte(InvalidT)) } else { - w.Err = fmt.Errorf("%w: nil", ErrUnserializable) + return fmt.Errorf("%w: nil", ErrUnserializable) } } - if w.Err == nil && w.buf != nil && w.buf.Len() > MaxSize { - w.Err = errTooBigSize + if len(w.data) > MaxSize { + return errTooBigSize } + return nil +} + +func (w *serContext) appendVarUint(val uint64) { + n := io.PutVarUint(w.uv[:], val) + w.data = append(w.data, w.uv[:n]...) } // Deserialize decodes Item from the given byte slice.