Merge pull request #2525 from nspcc-dev/immutable-items

vm: implement immutable stackitems
This commit is contained in:
Roman Khimov 2022-05-31 10:36:56 +03:00 committed by GitHub
commit e1607e23c2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 112 additions and 21 deletions

View file

@ -319,7 +319,7 @@ func TestNEP11_ND_OwnerOf_BalanceOf_Transfer(t *testing.T) {
ScriptHash: verifyH, ScriptHash: verifyH,
Name: "OnNEP11Payment", Name: "OnNEP11Payment",
Item: stackitem.NewArray([]stackitem.Item{ Item: stackitem.NewArray([]stackitem.Item{
stackitem.NewBuffer(nftOwnerHash.BytesBE()), stackitem.NewByteArray(nftOwnerHash.BytesBE()),
stackitem.NewBigInteger(big.NewInt(1)), stackitem.NewBigInteger(big.NewInt(1)),
stackitem.NewByteArray(tokenID1), stackitem.NewByteArray(tokenID1),
stackitem.NewByteArray([]byte("some_data")), stackitem.NewByteArray([]byte("some_data")),

View file

@ -68,7 +68,7 @@ func Notify(ic *interop.Context) error {
if len(bytes) > MaxNotificationSize { if len(bytes) > MaxNotificationSize {
return fmt.Errorf("notification size shouldn't exceed %d", MaxNotificationSize) return fmt.Errorf("notification size shouldn't exceed %d", MaxNotificationSize)
} }
ic.AddNotification(ic.VM.GetCurrentScriptHash(), name, stackitem.DeepCopy(stackitem.NewArray(args)).(*stackitem.Array)) ic.AddNotification(ic.VM.GetCurrentScriptHash(), name, stackitem.DeepCopy(stackitem.NewArray(args), true).(*stackitem.Array))
return nil return nil
} }

View file

@ -161,6 +161,7 @@ func TestNotify(t *testing.T) {
require.NoError(t, Notify(ic)) require.NoError(t, Notify(ic))
require.Equal(t, 1, len(ic.Notifications)) require.Equal(t, 1, len(ic.Notifications))
arr.MarkAsReadOnly() // tiny hack for test to be able to compare object references.
ev := ic.Notifications[0] ev := ic.Notifications[0]
require.Equal(t, "good event", ev.Name) require.Equal(t, "good event", ev.Name)
require.Equal(t, h, ev.ScriptHash) require.Equal(t, h, ev.ScriptHash)

View file

@ -54,7 +54,7 @@ func GetNotifications(ic *interop.Context) error {
ev := stackitem.NewArray([]stackitem.Item{ ev := stackitem.NewArray([]stackitem.Item{
stackitem.NewByteArray(notifications[i].ScriptHash.BytesBE()), stackitem.NewByteArray(notifications[i].ScriptHash.BytesBE()),
stackitem.Make(notifications[i].Name), stackitem.Make(notifications[i].Name),
stackitem.DeepCopy(notifications[i].Item).(*stackitem.Array), notifications[i].Item,
}) })
arr.Append(ev) arr.Append(ev)
} }

View file

@ -54,6 +54,7 @@ func TestRuntimeGetNotifications(t *testing.T) {
name, err := stackitem.ToString(elem[1]) name, err := stackitem.ToString(elem[1])
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, ic.Notifications[i].Name, name) require.Equal(t, ic.Notifications[i].Name, name)
ic.Notifications[i].Item.MarkAsReadOnly() // tiny hack for test to be able to compare object references.
require.Equal(t, ic.Notifications[i].Item, elem[2]) require.Equal(t, ic.Notifications[i].Item, elem[2])
} }
}) })

View file

@ -111,6 +111,7 @@ func TestRuntimeGetNotifications(t *testing.T) {
name, err := stackitem.ToString(elem[1]) name, err := stackitem.ToString(elem[1])
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, ic.Notifications[i].Name, name) require.Equal(t, ic.Notifications[i].Name, name)
ic.Notifications[i].Item.MarkAsReadOnly() // tiny hack for test to be able to compare object references.
require.Equal(t, ic.Notifications[i].Item, elem[2]) require.Equal(t, ic.Notifications[i].Item, elem[2])
} }
}) })

View file

@ -65,7 +65,7 @@ func opParamSlotsPushVM(op opcode.Opcode, param []byte, sslot int, slotloc int,
for i := range items { for i := range items {
item, ok := items[i].(stackitem.Item) item, ok := items[i].(stackitem.Item)
if ok { if ok {
item = stackitem.DeepCopy(item) item = stackitem.DeepCopy(item, true)
} else { } else {
item = stackitem.Make(items[i]) item = stackitem.Make(items[i])
} }

View file

@ -0,0 +1,21 @@
package stackitem
type ro struct {
isReadOnly bool
}
// IsReadOnly implements Immutable interface.
func (r *ro) IsReadOnly() bool {
return r.isReadOnly
}
// MarkAsReadOnly implements immutable interface.
func (r *ro) MarkAsReadOnly() {
r.isReadOnly = true
}
// Immutable is an interface supported by compound types (Array, Map, Struct).
type Immutable interface {
IsReadOnly() bool
MarkAsReadOnly()
}

View file

@ -70,6 +70,8 @@ var (
// can also be returned by serialization functions if the resulting // can also be returned by serialization functions if the resulting
// value exceeds MaxSize. // value exceeds MaxSize.
ErrTooBig = errors.New("too big") ErrTooBig = errors.New("too big")
// ErrReadOnly is returned on attempt to modify immutable stack item.
ErrReadOnly = errors.New("item is read-only")
errTooBigComparable = fmt.Errorf("%w: uncomparable", ErrTooBig) errTooBigComparable = fmt.Errorf("%w: uncomparable", ErrTooBig)
errTooBigInteger = fmt.Errorf("%w: integer", ErrTooBig) errTooBigInteger = fmt.Errorf("%w: integer", ErrTooBig)
@ -185,6 +187,7 @@ func convertPrimitive(item Item, typ Type) (Item, error) {
type Struct struct { type Struct struct {
value []Item value []Item
rc rc
ro
} }
// NewStruct returns a new Struct object. // NewStruct returns a new Struct object.
@ -202,16 +205,25 @@ func (i *Struct) Value() interface{} {
// Remove removes the element at `pos` index from the Struct value. // Remove removes the element at `pos` index from the Struct value.
// It will panic if a bad index given. // It will panic if a bad index given.
func (i *Struct) Remove(pos int) { func (i *Struct) Remove(pos int) {
if i.IsReadOnly() {
panic(ErrReadOnly)
}
i.value = append(i.value[:pos], i.value[pos+1:]...) i.value = append(i.value[:pos], i.value[pos+1:]...)
} }
// Append adds an Item to the end of the Struct value. // Append adds an Item to the end of the Struct value.
func (i *Struct) Append(item Item) { func (i *Struct) Append(item Item) {
if i.IsReadOnly() {
panic(ErrReadOnly)
}
i.value = append(i.value, item) i.value = append(i.value, item)
} }
// Clear removes all elements from the Struct item value. // Clear removes all elements from the Struct item value.
func (i *Struct) Clear() { func (i *Struct) Clear() {
if i.IsReadOnly() {
panic(ErrReadOnly)
}
i.value = i.value[:0] i.value = i.value[:0]
} }
@ -662,6 +674,7 @@ func (i *ByteArray) Convert(typ Type) (Item, error) {
type Array struct { type Array struct {
value []Item value []Item
rc rc
ro
} }
// NewArray returns a new Array object. // NewArray returns a new Array object.
@ -679,16 +692,25 @@ func (i *Array) Value() interface{} {
// Remove removes the element at `pos` index from Array value. // Remove removes the element at `pos` index from Array value.
// It will panics on bad index. // It will panics on bad index.
func (i *Array) Remove(pos int) { func (i *Array) Remove(pos int) {
if i.IsReadOnly() {
panic(ErrReadOnly)
}
i.value = append(i.value[:pos], i.value[pos+1:]...) i.value = append(i.value[:pos], i.value[pos+1:]...)
} }
// Append adds an Item to the end of the Array value. // Append adds an Item to the end of the Array value.
func (i *Array) Append(item Item) { func (i *Array) Append(item Item) {
if i.IsReadOnly() {
panic(ErrReadOnly)
}
i.value = append(i.value, item) i.value = append(i.value, item)
} }
// Clear removes all elements from the Array item value. // Clear removes all elements from the Array item value.
func (i *Array) Clear() { func (i *Array) Clear() {
if i.IsReadOnly() {
panic(ErrReadOnly)
}
i.value = i.value[:0] i.value = i.value[:0]
} }
@ -763,6 +785,7 @@ type MapElement struct {
type Map struct { type Map struct {
value []MapElement value []MapElement
rc rc
ro
} }
// NewMap returns a new Map object. // NewMap returns a new Map object.
@ -789,6 +812,9 @@ func (i *Map) Value() interface{} {
// Clear removes all elements from the Map item value. // Clear removes all elements from the Map item value.
func (i *Map) Clear() { func (i *Map) Clear() {
if i.IsReadOnly() {
panic(ErrReadOnly)
}
i.value = i.value[:0] i.value = i.value[:0]
} }
@ -860,6 +886,9 @@ func (i *Map) Add(key, value Item) {
if err := IsValidMapKey(key); err != nil { if err := IsValidMapKey(key); err != nil {
panic(err) panic(err)
} }
if i.IsReadOnly() {
panic(ErrReadOnly)
}
index := i.Index(key) index := i.Index(key)
if index >= 0 { if index >= 0 {
i.value[index].Value = value i.value[index].Value = value
@ -870,6 +899,9 @@ func (i *Map) Add(key, value Item) {
// Drop removes the given index from the map (no bounds check done here). // Drop removes the given index from the map (no bounds check done here).
func (i *Map) Drop(index int) { func (i *Map) Drop(index int) {
if i.IsReadOnly() {
panic(ErrReadOnly)
}
copy(i.value[index:], i.value[index+1:]) copy(i.value[index:], i.value[index+1:])
i.value = i.value[:len(i.value)-1] i.value = i.value[:len(i.value)-1]
} }
@ -1139,12 +1171,12 @@ func (i *Buffer) Len() int {
// DeepCopy returns a new deep copy of the provided item. // DeepCopy returns a new deep copy of the provided item.
// Values of Interop items are not deeply copied. // Values of Interop items are not deeply copied.
// It does preserve duplicates only for non-primitive types. // It does preserve duplicates only for non-primitive types.
func DeepCopy(item Item) Item { func DeepCopy(item Item, asImmutable bool) Item {
seen := make(map[Item]Item, typicalNumOfItems) seen := make(map[Item]Item, typicalNumOfItems)
return deepCopy(item, seen) return deepCopy(item, seen, asImmutable)
} }
func deepCopy(item Item, seen map[Item]Item) Item { func deepCopy(item Item, seen map[Item]Item, asImmutable bool) Item {
if it := seen[item]; it != nil { if it := seen[item]; it != nil {
return it return it
} }
@ -1155,24 +1187,28 @@ func deepCopy(item Item, seen map[Item]Item) Item {
arr := NewArray(make([]Item, len(it.value))) arr := NewArray(make([]Item, len(it.value)))
seen[item] = arr seen[item] = arr
for i := range it.value { for i := range it.value {
arr.value[i] = deepCopy(it.value[i], seen) arr.value[i] = deepCopy(it.value[i], seen, asImmutable)
} }
arr.MarkAsReadOnly()
return arr return arr
case *Struct: case *Struct:
arr := NewStruct(make([]Item, len(it.value))) arr := NewStruct(make([]Item, len(it.value)))
seen[item] = arr seen[item] = arr
for i := range it.value { for i := range it.value {
arr.value[i] = deepCopy(it.value[i], seen) arr.value[i] = deepCopy(it.value[i], seen, asImmutable)
} }
arr.MarkAsReadOnly()
return arr return arr
case *Map: case *Map:
m := NewMap() m := NewMap()
seen[item] = m seen[item] = m
for i := range it.value { for i := range it.value {
key := deepCopy(it.value[i].Key, seen) key := deepCopy(it.value[i].Key, seen,
value := deepCopy(it.value[i].Value, seen) false) // Key is always primitive and not a Buffer.
value := deepCopy(it.value[i].Value, seen, asImmutable)
m.Add(key, value) m.Add(key, value)
} }
m.MarkAsReadOnly()
return m return m
case *BigInteger: case *BigInteger:
bi := new(big.Int).Set(it.Big()) bi := new(big.Int).Set(it.Big())
@ -1180,6 +1216,9 @@ func deepCopy(item Item, seen map[Item]Item) Item {
case *ByteArray: case *ByteArray:
return NewByteArray(slice.Copy(*it)) return NewByteArray(slice.Copy(*it))
case *Buffer: case *Buffer:
if asImmutable {
return NewByteArray(slice.Copy(*it))
}
return NewBuffer(slice.Copy(*it)) return NewBuffer(slice.Copy(*it))
case Bool: case Bool:
return it return it

View file

@ -530,7 +530,10 @@ func TestDeepCopy(t *testing.T) {
} }
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
actual := DeepCopy(tc.item) actual := DeepCopy(tc.item, false)
if immut, ok := tc.item.(Immutable); ok {
immut.MarkAsReadOnly() // tiny hack for test to be able to compare object references.
}
require.Equal(t, tc.item, actual) require.Equal(t, tc.item, actual)
if tc.item.Type() != BooleanT { if tc.item.Type() != BooleanT {
require.False(t, actual == tc.item) require.False(t, actual == tc.item)
@ -539,7 +542,7 @@ func TestDeepCopy(t *testing.T) {
} }
t.Run("Null", func(t *testing.T) { t.Run("Null", func(t *testing.T) {
require.Equal(t, Null{}, DeepCopy(Null{})) require.Equal(t, Null{}, DeepCopy(Null{}, false))
}) })
t.Run("Array", func(t *testing.T) { t.Run("Array", func(t *testing.T) {
@ -547,7 +550,8 @@ func TestDeepCopy(t *testing.T) {
arr.value[0] = NewBool(true) arr.value[0] = NewBool(true)
arr.value[1] = arr arr.value[1] = arr
actual := DeepCopy(arr) actual := DeepCopy(arr, false)
arr.isReadOnly = true // tiny hack for test to be able to compare object references.
require.Equal(t, arr, actual) require.Equal(t, arr, actual)
require.False(t, arr == actual) require.False(t, arr == actual)
require.True(t, actual == actual.(*Array).value[1]) require.True(t, actual == actual.(*Array).value[1])
@ -558,7 +562,8 @@ func TestDeepCopy(t *testing.T) {
arr.value[0] = NewBool(true) arr.value[0] = NewBool(true)
arr.value[1] = arr arr.value[1] = arr
actual := DeepCopy(arr) actual := DeepCopy(arr, false)
arr.isReadOnly = true // tiny hack for test to be able to compare object references.
require.Equal(t, arr, actual) require.Equal(t, arr, actual)
require.False(t, arr == actual) require.False(t, arr == actual)
require.True(t, actual == actual.(*Struct).value[1]) require.True(t, actual == actual.(*Struct).value[1])
@ -569,7 +574,8 @@ func TestDeepCopy(t *testing.T) {
m.value[0] = MapElement{Key: NewBool(true), Value: m} m.value[0] = MapElement{Key: NewBool(true), Value: m}
m.value[1] = MapElement{Key: NewBigInteger(big.NewInt(1)), Value: NewByteArray([]byte{1, 2, 3})} m.value[1] = MapElement{Key: NewBigInteger(big.NewInt(1)), Value: NewByteArray([]byte{1, 2, 3})}
actual := DeepCopy(m) actual := DeepCopy(m, false)
m.isReadOnly = true // tiny hack for test to be able to compare object references.
require.Equal(t, m, actual) require.Equal(t, m, actual)
require.False(t, m == actual) require.False(t, m == actual)
require.True(t, actual == actual.(*Map).value[0].Value) require.True(t, actual == actual.(*Map).value[0].Value)

View file

@ -1245,10 +1245,16 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro
v.throw(stackitem.NewByteArray([]byte(msg))) v.throw(stackitem.NewByteArray([]byte(msg)))
return return
} }
if t.(stackitem.Immutable).IsReadOnly() {
panic(stackitem.ErrReadOnly)
}
v.refs.Remove(arr[index]) v.refs.Remove(arr[index])
arr[index] = item arr[index] = item
v.refs.Add(arr[index]) v.refs.Add(arr[index])
case *stackitem.Map: case *stackitem.Map:
if t.IsReadOnly() {
panic(stackitem.ErrReadOnly)
}
if i := t.Index(key.value); i >= 0 { if i := t.Index(key.value); i >= 0 {
v.refs.Remove(t.Value().([]stackitem.MapElement)[i].Value) v.refs.Remove(t.Value().([]stackitem.MapElement)[i].Value)
} else { } else {
@ -1279,6 +1285,9 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro
item := v.estack.Pop() item := v.estack.Pop()
switch t := item.value.(type) { switch t := item.value.(type) {
case *stackitem.Array, *stackitem.Struct: case *stackitem.Array, *stackitem.Struct:
if t.(stackitem.Immutable).IsReadOnly() {
panic(stackitem.ErrReadOnly)
}
a := t.Value().([]stackitem.Item) a := t.Value().([]stackitem.Item)
for i, j := 0, len(a)-1; i < j; i, j = i+1, j-1 { for i, j := 0, len(a)-1; i < j; i, j = i+1, j-1 {
a[i], a[j] = a[j], a[i] a[i], a[j] = a[j], a[i]
@ -1301,24 +1310,28 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro
if k < 0 || k >= len(a) { if k < 0 || k >= len(a) {
panic("REMOVE: invalid index") panic("REMOVE: invalid index")
} }
v.refs.Remove(a[k]) toRemove := a[k]
t.Remove(k) t.Remove(k)
v.refs.Remove(toRemove)
case *stackitem.Struct: case *stackitem.Struct:
a := t.Value().([]stackitem.Item) a := t.Value().([]stackitem.Item)
k := toInt(key.BigInt()) k := toInt(key.BigInt())
if k < 0 || k >= len(a) { if k < 0 || k >= len(a) {
panic("REMOVE: invalid index") panic("REMOVE: invalid index")
} }
v.refs.Remove(a[k]) toRemove := a[k]
t.Remove(k) t.Remove(k)
v.refs.Remove(toRemove)
case *stackitem.Map: case *stackitem.Map:
index := t.Index(key.Item()) index := t.Index(key.Item())
// NEO 2.0 doesn't error on missing key. // NEO 2.0 doesn't error on missing key.
if index >= 0 { if index >= 0 {
elems := t.Value().([]stackitem.MapElement) elems := t.Value().([]stackitem.MapElement)
v.refs.Remove(elems[index].Key) key := elems[index].Key
v.refs.Remove(elems[index].Value) val := elems[index].Value
t.Drop(index) t.Drop(index)
v.refs.Remove(key)
v.refs.Remove(val)
} }
default: default:
panic("REMOVE: invalid type") panic("REMOVE: invalid type")
@ -1328,16 +1341,25 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro
elem := v.estack.Pop() elem := v.estack.Pop()
switch t := elem.value.(type) { switch t := elem.value.(type) {
case *stackitem.Array: case *stackitem.Array:
if t.IsReadOnly() {
panic(stackitem.ErrReadOnly)
}
for _, item := range t.Value().([]stackitem.Item) { for _, item := range t.Value().([]stackitem.Item) {
v.refs.Remove(item) v.refs.Remove(item)
} }
t.Clear() t.Clear()
case *stackitem.Struct: case *stackitem.Struct:
if t.IsReadOnly() {
panic(stackitem.ErrReadOnly)
}
for _, item := range t.Value().([]stackitem.Item) { for _, item := range t.Value().([]stackitem.Item) {
v.refs.Remove(item) v.refs.Remove(item)
} }
t.Clear() t.Clear()
case *stackitem.Map: case *stackitem.Map:
if t.IsReadOnly() {
panic(stackitem.ErrReadOnly)
}
elems := t.Value().([]stackitem.MapElement) elems := t.Value().([]stackitem.MapElement)
for i := range elems { for i := range elems {
v.refs.Remove(elems[i].Key) v.refs.Remove(elems[i].Key)
@ -1353,6 +1375,7 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro
elems := arr.Value().([]stackitem.Item) elems := arr.Value().([]stackitem.Item)
index := len(elems) - 1 index := len(elems) - 1
elem := elems[index] elem := elems[index]
v.estack.PushItem(elem) // push item on stack firstly, to match the reference behaviour.
switch item := arr.(type) { switch item := arr.(type) {
case *stackitem.Array: case *stackitem.Array:
item.Remove(index) item.Remove(index)
@ -1360,7 +1383,6 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro
item.Remove(index) item.Remove(index)
} }
v.refs.Remove(elem) v.refs.Remove(elem)
v.estack.PushItem(elem)
case opcode.SIZE: case opcode.SIZE:
elem := v.estack.Pop() elem := v.estack.Pop()