stackitem: refactor serialization code, add explicit context

This commit is contained in:
Roman Khimov 2021-07-06 19:32:52 +03:00
parent 1b7b7e4bec
commit 8472064bbc

View file

@ -9,10 +9,22 @@ import (
"github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/io"
) )
// serContext is an internal serialization context.
type serContext struct {
*io.BinWriter
allowInvalid bool
seen map[Item]bool
}
// SerializeItem encodes given Item into the byte slice. // SerializeItem encodes given Item into the byte slice.
func SerializeItem(item Item) ([]byte, error) { func SerializeItem(item Item) ([]byte, error) {
w := io.NewBufBinWriter() w := io.NewBufBinWriter()
EncodeBinaryStackItem(item, w.BinWriter) sc := serContext{
BinWriter: w.BinWriter,
allowInvalid: false,
seen: make(map[Item]bool),
}
sc.serialize(item)
if w.Err != nil { if w.Err != nil {
return nil, w.Err return nil, w.Err
} }
@ -23,14 +35,24 @@ func SerializeItem(item Item) ([]byte, error) {
// similar to io.Serializable's EncodeBinary, but works with Item // similar to io.Serializable's EncodeBinary, but works with Item
// interface. // interface.
func EncodeBinaryStackItem(item Item, w *io.BinWriter) { func EncodeBinaryStackItem(item Item, w *io.BinWriter) {
serializeItemTo(item, w, false, make(map[Item]bool)) sc := serContext{
BinWriter: w,
allowInvalid: false,
seen: make(map[Item]bool),
}
sc.serialize(item)
} }
// EncodeBinaryStackItemAppExec encodes given Item into the given BinWriter. It's // EncodeBinaryStackItemAppExec encodes given Item into the given BinWriter. It's
// similar to EncodeBinaryStackItem but allows to encode interop (only type, value is lost). // similar to EncodeBinaryStackItem but allows to encode interop (only type, value is lost).
func EncodeBinaryStackItemAppExec(item Item, w *io.BinWriter) { func EncodeBinaryStackItemAppExec(item Item, w *io.BinWriter) {
bw := io.NewBufBinWriter() bw := io.NewBufBinWriter()
serializeItemTo(item, bw.BinWriter, true, make(map[Item]bool)) sc := serContext{
BinWriter: bw.BinWriter,
allowInvalid: true,
seen: make(map[Item]bool),
}
sc.serialize(item)
if bw.Err != nil { if bw.Err != nil {
w.WriteBytes([]byte{byte(InvalidT)}) w.WriteBytes([]byte{byte(InvalidT)})
return return
@ -38,15 +60,15 @@ func EncodeBinaryStackItemAppExec(item Item, w *io.BinWriter) {
w.WriteBytes(bw.Bytes()) w.WriteBytes(bw.Bytes())
} }
func serializeItemTo(item Item, w *io.BinWriter, allowInvalid bool, seen map[Item]bool) { func (w *serContext) serialize(item Item) {
if w.Err != nil { if w.Err != nil {
return return
} }
if seen[item] { if w.seen[item] {
w.Err = errors.New("recursive structures can't be serialized") w.Err = errors.New("recursive structures can't be serialized")
return return
} }
if item == nil && allowInvalid { if item == nil && w.allowInvalid {
w.WriteBytes([]byte{byte(InvalidT)}) w.WriteBytes([]byte{byte(InvalidT)})
return return
} }
@ -65,13 +87,13 @@ func serializeItemTo(item Item, w *io.BinWriter, allowInvalid bool, seen map[Ite
w.WriteBytes([]byte{byte(IntegerT)}) w.WriteBytes([]byte{byte(IntegerT)})
w.WriteVarBytes(bigint.ToBytes(t.Value().(*big.Int))) w.WriteVarBytes(bigint.ToBytes(t.Value().(*big.Int)))
case *Interop: case *Interop:
if allowInvalid { if w.allowInvalid {
w.WriteBytes([]byte{byte(InteropT)}) w.WriteBytes([]byte{byte(InteropT)})
return } else {
w.Err = errors.New("interop item can't be serialized")
} }
w.Err = errors.New("interop item can't be serialized")
case *Array, *Struct: case *Array, *Struct:
seen[item] = true w.seen[item] = true
_, isArray := t.(*Array) _, isArray := t.(*Array)
if isArray { if isArray {
@ -83,19 +105,19 @@ func serializeItemTo(item Item, w *io.BinWriter, allowInvalid bool, seen map[Ite
arr := t.Value().([]Item) arr := t.Value().([]Item)
w.WriteVarUint(uint64(len(arr))) w.WriteVarUint(uint64(len(arr)))
for i := range arr { for i := range arr {
serializeItemTo(arr[i], w, allowInvalid, seen) w.serialize(arr[i])
} }
delete(seen, item) delete(w.seen, item)
case *Map: case *Map:
seen[item] = true w.seen[item] = true
w.WriteBytes([]byte{byte(MapT)}) w.WriteBytes([]byte{byte(MapT)})
w.WriteVarUint(uint64(len(t.Value().([]MapElement)))) w.WriteVarUint(uint64(len(t.Value().([]MapElement))))
for i := range t.Value().([]MapElement) { for i := range t.Value().([]MapElement) {
serializeItemTo(t.Value().([]MapElement)[i].Key, w, allowInvalid, seen) w.serialize(t.Value().([]MapElement)[i].Key)
serializeItemTo(t.Value().([]MapElement)[i].Value, w, allowInvalid, seen) w.serialize(t.Value().([]MapElement)[i].Value)
} }
delete(seen, item) delete(w.seen, item)
case Null: case Null:
w.WriteB(byte(AnyT)) w.WriteB(byte(AnyT))
} }