neo-go/pkg/vm/stackitem/serialization.go
Roman Khimov 9ba18b5dfa stackitem: serialize/deserialize pointers, fix #2815
They of course can't be serialized, but in protected mode we still need to
handle them somehow.
2022-11-20 16:02:50 +03:00

353 lines
8.9 KiB
Go

package stackitem
import (
"errors"
"fmt"
"math/big"
"github.com/nspcc-dev/neo-go/pkg/encoding/bigint"
"github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/util"
)
// MaxDeserialized is the maximum number of items a single deserialized item
// can contain (the item itself is counted too).
const MaxDeserialized = 2048
// typicalNumOfItems is the number of items covering most serialization needs.
// It's only a capacity hint used when creating the map of already seen items,
// so it does not limit anything, it's just a microoptimization to avoid
// excessive reallocations. Most of the serialized items are structs, so there
// is at least one of them.
const typicalNumOfItems = 4
// ErrRecursive is returned upon an attempt to serialize a recursive stack item
// (like an array including an item with a reference to the same array).
var ErrRecursive = errors.New("recursive item")
// ErrUnserializable is returned upon an attempt to serialize an item that can't
// be serialized (like an Interop item or a Pointer).
var ErrUnserializable = errors.New("unserializable")
// SerializationContext is a stack item serialization context. It holds the
// output buffer along with the bookkeeping needed to detect recursive items
// and to reuse the encoding of compound items that were already serialized.
type SerializationContext struct {
// uv is a scratch buffer used for encoding variable-length unsigned integers.
uv [9]byte
// data is the output buffer serialized bytes are appended to.
data []byte
// allowInvalid turns on the "protected" mode in which Interop, Pointer and
// nil items are encoded instead of causing an error.
allowInvalid bool
// seen maps compound items (Array, Struct, Map) to the range of data that
// holds their encoding; an empty range marks an item that is still being
// serialized.
seen map[Item]sliceNoPointer
}
// deserContext is an internal deserialization context.
type deserContext struct {
// BinReader is the source of serialized data, decoding errors are reported
// via its Err field.
*io.BinReader
// allowInvalid permits decoding of Interop, Pointer and Invalid type markers.
allowInvalid bool
// limit is the number of items that can still be decoded, it's decremented
// for every item read.
limit int
}
// Serialize encodes the given Item into a byte slice. It works in strict
// mode, so any item that can't be serialized (Interop, Pointer, nil or a
// recursive one) makes it return an error and no data.
func Serialize(item Item) ([]byte, error) {
	// The zero value of allowInvalid already means strict mode.
	var sc SerializationContext

	sc.seen = make(map[Item]sliceNoPointer, typicalNumOfItems)
	if err := sc.serialize(item); err != nil {
		return nil, err
	}
	return sc.data, nil
}
// EncodeBinary serializes the given Item and writes the result into the given
// BinWriter. It's an analogue of io.Serializable's EncodeBinary for the Item
// interface. A serialization failure is stored in w.Err and nothing is
// written in this case.
func EncodeBinary(item Item, w *io.BinWriter) {
	data, err := Serialize(item)
	if err == nil {
		w.WriteBytes(data)
		return
	}
	w.Err = err
}
// EncodeBinaryProtected writes the serialized Item into the given BinWriter
// the way EncodeBinary does, but in protected mode: interop items are encoded
// as a bare type marker (the value is lost), pointers are encoded as a type
// marker with a position and serialization errors are never put into w.
// If the item can't be serialized at all (a recursive array, for example),
// a single InvalidT type byte is written instead of it.
func EncodeBinaryProtected(item Item, w *io.BinWriter) {
	var sc SerializationContext

	sc.allowInvalid = true
	sc.seen = make(map[Item]sliceNoPointer, typicalNumOfItems)
	if sc.serialize(item) != nil {
		// Partial output is useless, replace it with the invalid marker.
		w.WriteBytes([]byte{byte(InvalidT)})
		return
	}
	w.WriteBytes(sc.data)
}
// writeArray appends the element count and the serialized elements of arr
// (the contents of the given Array or Struct item) to the output buffer.
// start is the offset in the buffer where the encoding of the item itself
// begins (its type byte), it's used to remember the complete encoding of the
// item once all of the elements are written.
func (w *SerializationContext) writeArray(item Item, arr []Item, start int) error {
	// An empty range marks the item as being in progress, meeting it again
	// while its elements are processed means recursion.
	w.seen[item] = sliceNoPointer{}

	w.appendVarUint(uint64(len(arr)))
	for _, elem := range arr {
		err := w.serialize(elem)
		if err != nil {
			return err
		}
	}

	// The item is complete, its bytes can be reused for further occurrences.
	w.seen[item] = sliceNoPointer{start, len(w.data)}
	return nil
}
// NewSerializationContext creates a reusable stack item serialization context
// with a preallocated map of seen items.
func NewSerializationContext() *SerializationContext {
	sc := new(SerializationContext)
	sc.seen = make(map[Item]sliceNoPointer, typicalNumOfItems)
	return sc
}
// Serialize returns a flat slice of bytes with the given item serialized.
// If protected is true, bad elements don't lead to an error: Interop, Pointer
// and nil items are encoded in place and any other failure (like a recursive
// item) makes the whole result a single InvalidT byte. If protected is false,
// an error is returned on encountering any of them.
//
// The context reuses its internal buffer, so the returned slice is only valid
// until the next call to Serialize. The zero value of SerializationContext is
// usable too, the map of seen items is created on demand.
func (w *SerializationContext) Serialize(item Item, protected bool) ([]byte, error) {
	w.allowInvalid = protected
	// Drop the old contents, but keep the capacity (a nil buffer stays nil).
	w.data = w.data[:0]
	if w.seen == nil {
		// A context that wasn't created by NewSerializationContext has no
		// map yet and writing to a nil map panics.
		w.seen = make(map[Item]sliceNoPointer, typicalNumOfItems)
	} else {
		for k := range w.seen {
			delete(w.seen, k)
		}
	}

	err := w.serialize(item)
	if err != nil && protected {
		// Whatever was written before the failure is discarded.
		w.data = append(w.data[:0], byte(InvalidT))
		err = nil
	}
	return w.data, err
}
// serialize appends the encoding of the given item to w.data, recursing into
// the elements of arrays, structs and maps. Compound items are tracked in
// w.seen: meeting an item that is still being serialized yields ErrRecursive,
// while an item that is already complete has its bytes copied from the
// buffer instead of being encoded once more.
//
// Interop, Pointer, nil and unknown item types produce an error wrapping
// ErrUnserializable unless w.allowInvalid is set. An error is also returned
// when the output grows beyond MaxSize.
func (w *SerializationContext) serialize(item Item) error {
	if v, ok := w.seen[item]; ok {
		// An empty range means that the item is not finished yet, so it
		// contains itself.
		if v.start == v.end {
			return ErrRecursive
		}
		// NOTE(review): this path returns ErrTooBig while the final size
		// check below returns errTooBigSize — confirm that's intended.
		if len(w.data)+v.end-v.start > MaxSize {
			return ErrTooBig
		}
		w.data = append(w.data, w.data[v.start:v.end]...)
		return nil
	}

	// Offset of the item's type byte, compound items remember their range
	// starting from it.
	start := len(w.data)
	switch t := item.(type) {
	case *ByteArray:
		w.data = append(w.data, byte(ByteArrayT))
		w.appendVarUint(uint64(len(*t)))
		w.data = append(w.data, *t...)
	case *Buffer:
		w.data = append(w.data, byte(BufferT))
		w.appendVarUint(uint64(len(*t)))
		w.data = append(w.data, *t...)
	case Bool:
		w.data = append(w.data, byte(BooleanT))
		if t {
			w.data = append(w.data, 1)
		} else {
			w.data = append(w.data, 0)
		}
	case *BigInteger:
		w.data = append(w.data, byte(IntegerT))
		// Reserve one byte for the length, it's filled in as soon as the
		// number of bytes in the integer is known.
		ln := len(w.data)
		w.data = append(w.data, 0)
		// The empty tail of the buffer is passed as the destination,
		// presumably to let the conversion reuse its spare capacity.
		data := bigint.ToPreallocatedBytes((*big.Int)(t), w.data[len(w.data):])
		w.data[ln] = byte(len(data))
		w.data = append(w.data, data...)
	case *Interop:
		if w.allowInvalid {
			// Only the type is stored, the value is lost.
			w.data = append(w.data, byte(InteropT))
		} else {
			return fmt.Errorf("%w: Interop", ErrUnserializable)
		}
	case *Pointer:
		if w.allowInvalid {
			// Only the type and the position are stored.
			w.data = append(w.data, byte(PointerT))
			w.appendVarUint(uint64(t.pos))
		} else {
			return fmt.Errorf("%w: Pointer", ErrUnserializable)
		}
	case *Array:
		w.data = append(w.data, byte(ArrayT))
		if err := w.writeArray(item, t.value, start); err != nil {
			return err
		}
	case *Struct:
		w.data = append(w.data, byte(StructT))
		if err := w.writeArray(item, t.value, start); err != nil {
			return err
		}
	case *Map:
		// Mark the map as being in progress to detect recursion.
		w.seen[item] = sliceNoPointer{}

		elems := t.value
		w.data = append(w.data, byte(MapT))
		w.appendVarUint(uint64(len(elems)))
		for i := range elems {
			if err := w.serialize(elems[i].Key); err != nil {
				return err
			}
			if err := w.serialize(elems[i].Value); err != nil {
				return err
			}
		}
		// The map is complete, remember its encoding for reuse.
		w.seen[item] = sliceNoPointer{start, len(w.data)}
	case Null:
		w.data = append(w.data, byte(AnyT))
	case nil:
		if w.allowInvalid {
			w.data = append(w.data, byte(InvalidT))
		} else {
			return fmt.Errorf("%w: nil", ErrUnserializable)
		}
	default:
		// An Item implementation not known to the serializer. It's handled
		// the same way nil is, emitting nothing and reporting success would
		// silently corrupt the output.
		if w.allowInvalid {
			w.data = append(w.data, byte(InvalidT))
		} else {
			return fmt.Errorf("%w: %T", ErrUnserializable, item)
		}
	}

	if len(w.data) > MaxSize {
		return errTooBigSize
	}
	return nil
}
// appendVarUint appends the variable-length encoding of val to the output
// buffer using the context's scratch space for the conversion.
func (w *SerializationContext) appendVarUint(val uint64) {
	scratch := w.uv[:]
	size := io.PutVarUint(scratch, val)
	w.data = append(w.data, scratch[:size]...)
}
// Deserialize decodes an Item from the given byte slice. If decoding fails,
// the reader's error is returned and the item is nil.
func Deserialize(data []byte) (Item, error) {
	reader := io.NewBinReaderFromBuf(data)
	if item := DecodeBinary(reader); reader.Err == nil {
		return item, nil
	}
	return nil, reader.Err
}
// DecodeBinary reads a previously serialized Item from the given reader in
// strict mode, allowing at most MaxDeserialized items in total. It plays the
// role of io.Serializable's DecodeBinary(), but is a function because Item is
// an interface. Caveat: always check the reader's error value before using
// the returned Item.
func DecodeBinary(r *io.BinReader) Item {
	// The zero value of allowInvalid already means strict mode.
	var dc deserContext

	dc.BinReader = r
	dc.limit = MaxDeserialized
	return dc.decodeBinary()
}
// DecodeBinaryProtected works like DecodeBinary, but additionally accepts
// Interop, Pointer and Invalid values in the input, which makes it symmetric
// to EncodeBinaryProtected.
func DecodeBinaryProtected(r *io.BinReader) Item {
	var dc deserContext

	dc.BinReader = r
	dc.allowInvalid = true
	dc.limit = MaxDeserialized
	return dc.decodeBinary()
}
// decodeBinary reads one item (with all of its nested items) from the
// underlying reader. Every item read, nested ones included, is counted
// against r.limit. Errors are reported via r.Err, the returned value must not
// be used if it's set. In protected mode (r.allowInvalid) Interop items are
// decoded with a nil value, Pointers are decoded with a position only and the
// Invalid type is decoded as a nil Item.
func (r *deserContext) decodeBinary() Item {
	var t = Type(r.ReadB())
	if r.Err != nil {
		return nil
	}

	r.limit--
	if r.limit < 0 {
		r.Err = errTooBigElements
		return nil
	}
	switch t {
	case ByteArrayT, BufferT:
		data := r.ReadVarBytes(MaxSize)
		if t == ByteArrayT {
			return NewByteArray(data)
		}
		return NewBuffer(data)
	case BooleanT:
		var b = r.ReadBool()
		return NewBool(b)
	case IntegerT:
		data := r.ReadVarBytes(bigint.MaxBytesLen)
		num := bigint.FromBytes(data)
		return NewBigInteger(num)
	case ArrayT, StructT:
		// The size is checked before any conversion to int: a huge value
		// would become negative there, slip through the check and make
		// the allocation below panic.
		size := r.ReadVarUint()
		if size > MaxDeserialized {
			r.Err = errTooBigElements
			return nil
		}
		arr := make([]Item, size)
		for i := range arr {
			arr[i] = r.decodeBinary()
			if r.Err != nil {
				// Nothing more can be read, the rest stays nil.
				break
			}
		}
		if t == ArrayT {
			return NewArray(arr)
		}
		return NewStruct(arr)
	case MapT:
		// See the array case for the reason to check the size first.
		size := r.ReadVarUint()
		if size > MaxDeserialized {
			r.Err = errTooBigElements
			return nil
		}
		m := NewMap()
		for left := int(size); left > 0; left-- {
			key := r.decodeBinary()
			value := r.decodeBinary()
			if r.Err != nil {
				break
			}
			m.Add(key, value)
		}
		return m
	case AnyT:
		return Null{}
	case InteropT:
		if r.allowInvalid {
			return NewInterop(nil)
		}
		// Not allowed in strict mode, ends up in the default case.
		fallthrough
	case PointerT:
		if r.allowInvalid {
			pos := int(r.ReadVarUint())
			return NewPointerWithHash(pos, nil, util.Uint160{})
		}
		// Not allowed in strict mode, ends up in the default case.
		fallthrough
	default:
		if t == InvalidT && r.allowInvalid {
			return nil
		}
		r.Err = fmt.Errorf("%w: %v", ErrInvalidType, t)
		return nil
	}
}
// SerializeConvertible converts the given Convertible into a stack item and
// serializes it into a slice of bytes. Conversion errors are returned as is.
func SerializeConvertible(conv Convertible) ([]byte, error) {
	item, convErr := conv.ToStackItem()
	if convErr == nil {
		return Serialize(item)
	}
	return nil, convErr
}
// DeserializeConvertible decodes a stack item from the given slice of bytes
// and fills conv from it. It returns either the decoding error or the result
// of the conversion.
func DeserializeConvertible(data []byte, conv Convertible) error {
	item, err := Deserialize(data)
	if err == nil {
		err = conv.FromStackItem(item)
	}
	return err
}