diff --git a/pkg/core/interop/runtime/engine.go b/pkg/core/interop/runtime/engine.go index 73654db19..daadcda5a 100644 --- a/pkg/core/interop/runtime/engine.go +++ b/pkg/core/interop/runtime/engine.go @@ -73,9 +73,20 @@ func Notify(ic *interop.Context) error { if len(name) > MaxEventNameLen { return fmt.Errorf("event name must be less than %d", MaxEventNameLen) } - if !ic.VM.Context().IsDeployed() { + curHash := ic.VM.GetCurrentScriptHash() + ctr, err := ic.GetContract(curHash) + if err != nil { return errors.New("notifications are not allowed in dynamic scripts") } + ev := ctr.Manifest.ABI.GetEvent(name) + if ev == nil { + ic.Log.Info("bad notification", zap.String("contract", curHash.StringLE()), zap.String("event", name), zap.Error(fmt.Errorf("event %s does not exist", name))) + } else { + err = ev.CheckCompliance(args) + if err != nil { + ic.Log.Info("bad notification", zap.String("contract", curHash.StringLE()), zap.String("event", name), zap.Error(err)) + } + } // But it has to be serializable, otherwise we either have some broken // (recursive) structure inside or an interop item that can't be used @@ -87,7 +98,7 @@ func Notify(ic *interop.Context) error { if len(bytes) > MaxNotificationSize { return fmt.Errorf("notification size shouldn't exceed %d", MaxNotificationSize) } - ic.AddNotification(ic.VM.GetCurrentScriptHash(), name, stackitem.DeepCopy(stackitem.NewArray(args), true).(*stackitem.Array)) + ic.AddNotification(curHash, name, stackitem.DeepCopy(stackitem.NewArray(args), true).(*stackitem.Array)) return nil } diff --git a/pkg/core/interop/runtime/engine_test.go b/pkg/core/interop/runtime/engine_test.go index 4ddf628ae..656a78dcb 100644 --- a/pkg/core/interop/runtime/engine_test.go +++ b/pkg/core/interop/runtime/engine_test.go @@ -8,11 +8,9 @@ import ( "github.com/nspcc-dev/neo-go/internal/random" "github.com/nspcc-dev/neo-go/pkg/core/block" - "github.com/nspcc-dev/neo-go/pkg/core/dao" "github.com/nspcc-dev/neo-go/pkg/core/interop" "github.com/nspcc-dev/neo-go/pkg/crypto/hash" "github.com/nspcc-dev/neo-go/pkg/smartcontract/callflag" - "github.com/nspcc-dev/neo-go/pkg/smartcontract/nef" "github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger" "github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/vm" @@ -131,55 +129,3 @@ func TestLog(t *testing.T) { require.Equal(t, h.StringLE(), logMsg["script"]) }) } - -func TestNotify(t *testing.T) { - h := random.Uint160() - caller := random.Uint160() - exe, err := nef.NewFile([]byte{1}) - require.NoError(t, err) - newIC := func(name string, args interface{}) *interop.Context { - ic := &interop.Context{VM: vm.New(), DAO: &dao.Simple{}} - ic.VM.LoadNEFMethod(exe, caller, h, callflag.NoneFlag, true, 0, -1, nil) - ic.VM.Estack().PushVal(args) - ic.VM.Estack().PushVal(name) - return ic - } - t.Run("big name", func(t *testing.T) { - ic := newIC(string(make([]byte, MaxEventNameLen+1)), stackitem.NewArray([]stackitem.Item{stackitem.Null{}})) - require.Error(t, Notify(ic)) - }) - t.Run("dynamic script", func(t *testing.T) { - ic := &interop.Context{VM: vm.New(), DAO: &dao.Simple{}} - ic.VM.LoadScriptWithHash([]byte{1}, h, callflag.NoneFlag) - ic.VM.Estack().PushVal(stackitem.NewArray([]stackitem.Item{stackitem.Make(42)})) - ic.VM.Estack().PushVal("event") - require.Error(t, Notify(ic)) - }) - t.Run("recursive struct", func(t *testing.T) { - arr := stackitem.NewArray([]stackitem.Item{stackitem.Null{}}) - arr.Append(arr) - ic := newIC("event", arr) - require.Error(t, Notify(ic)) - }) - t.Run("big notification", func(t *testing.T) { - bs := stackitem.NewByteArray(make([]byte, MaxNotificationSize+1)) - arr := stackitem.NewArray([]stackitem.Item{bs}) - ic := newIC("event", arr) - require.Error(t, Notify(ic)) - }) - t.Run("good", func(t *testing.T) { - arr := stackitem.NewArray([]stackitem.Item{stackitem.Make(42)}) - ic := newIC("good event", arr) - require.NoError(t, Notify(ic)) - require.Equal(t, 1, len(ic.Notifications)) - - arr.MarkAsReadOnly() // tiny hack for test to be able to compare object references. - ev := ic.Notifications[0] - require.Equal(t, "good event", ev.Name) - require.Equal(t, h, ev.ScriptHash) - require.Equal(t, arr, ev.Item) - // Check deep copy. - arr.Value().([]stackitem.Item)[0] = stackitem.Null{} - require.NotEqual(t, arr, ev.Item) - }) -} diff --git a/pkg/core/interop/runtime/ext_test.go b/pkg/core/interop/runtime/ext_test.go index 43e5fff56..d5a878036 100644 --- a/pkg/core/interop/runtime/ext_test.go +++ b/pkg/core/interop/runtime/ext_test.go @@ -79,7 +79,7 @@ func loadScriptWithHashAndFlags(ic *interop.Context, script []byte, hash util.Ui ic.VM.GasLimit = -1 } -func TestBurnGas(t *testing.T) { +func getDeployedInternal(t *testing.T) (*neotest.Executor, neotest.Signer, *core.Blockchain, *state.Contract) { bc, acc := chain.NewSingle(t) e := neotest.NewExecutor(t, bc, acc, acc) managementInvoker := e.ValidatorInvoker(e.NativeHash(t, nativenames.Management)) @@ -92,6 +92,12 @@ func TestBurnGas(t *testing.T) { tx := managementInvoker.PrepareInvoke(t, "deploy", rawNef, rawManifest) e.AddNewBlock(t, tx) e.CheckHalt(t, tx.Hash()) + + return e, acc, bc, cs +} + +func TestBurnGas(t *testing.T) { + e, acc, _, cs := getDeployedInternal(t) cInvoker := e.ValidatorInvoker(cs.Hash) t.Run("good", func(t *testing.T) { @@ -539,3 +545,53 @@ func TestGetRandomCompatibility(t *testing.T) { require.NoError(t, runtime.GetRandom(ic)) require.Equal(t, "247152297361212656635216876565962360375", ic.VM.Estack().Pop().BigInt().String()) } + +func TestNotify(t *testing.T) { + caller := random.Uint160() + newIC := func(name string, args interface{}) *interop.Context { + _, _, bc, cs := getDeployedInternal(t) + ic := bc.GetTestVM(trigger.Application, nil, nil) + ic.VM.LoadNEFMethod(&cs.NEF, caller, cs.Hash, callflag.NoneFlag, true, 0, -1, nil) + ic.VM.Estack().PushVal(args) + ic.VM.Estack().PushVal(name) + return ic + } + t.Run("big name", func(t *testing.T) { + ic := newIC(string(make([]byte, runtime.MaxEventNameLen+1)), stackitem.NewArray([]stackitem.Item{stackitem.Null{}})) + require.Error(t, runtime.Notify(ic)) + }) + t.Run("dynamic script", func(t *testing.T) { + ic := newIC("some", stackitem.Null{}) + ic.VM.LoadScriptWithHash([]byte{1}, random.Uint160(), callflag.NoneFlag) + ic.VM.Estack().PushVal(stackitem.NewArray([]stackitem.Item{stackitem.Make(42)})) + ic.VM.Estack().PushVal("event") + require.Error(t, runtime.Notify(ic)) + }) + t.Run("recursive struct", func(t *testing.T) { + arr := stackitem.NewArray([]stackitem.Item{stackitem.Null{}}) + arr.Append(arr) + ic := newIC("event", arr) + require.Error(t, runtime.Notify(ic)) + }) + t.Run("big notification", func(t *testing.T) { + bs := stackitem.NewByteArray(make([]byte, runtime.MaxNotificationSize+1)) + arr := stackitem.NewArray([]stackitem.Item{bs}) + ic := newIC("event", arr) + require.Error(t, runtime.Notify(ic)) + }) + t.Run("good", func(t *testing.T) { + arr := stackitem.NewArray([]stackitem.Item{stackitem.Make(42)}) + ic := newIC("good event", arr) + require.NoError(t, runtime.Notify(ic)) + require.Equal(t, 1, len(ic.Notifications)) + + arr.MarkAsReadOnly() // tiny hack for test to be able to compare object references. + ev := ic.Notifications[0] + require.Equal(t, "good event", ev.Name) + require.Equal(t, ic.VM.GetCurrentScriptHash(), ev.ScriptHash) + require.Equal(t, arr, ev.Item) + // Check deep copy. + arr.Value().([]stackitem.Item)[0] = stackitem.Null{} + require.NotEqual(t, arr, ev.Item) + }) +} diff --git a/pkg/smartcontract/manifest/event.go b/pkg/smartcontract/manifest/event.go index 00a87e988..04a1f348d 100644 --- a/pkg/smartcontract/manifest/event.go +++ b/pkg/smartcontract/manifest/event.go @@ -2,6 +2,7 @@ package manifest import ( "errors" + "fmt" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" ) @@ -60,3 +61,17 @@ func (e *Event) FromStackItem(item stackitem.Item) error { } return nil } + +// CheckCompliance checks compliance of the given array of items with the +// current event. +func (e *Event) CheckCompliance(items []stackitem.Item) error { + if len(items) != len(e.Parameters) { + return errors.New("mismatch between the number of parameters and items") + } + for i := range items { + if !e.Parameters[i].Type.Match(items[i]) { + return fmt.Errorf("parameter %d type mismatch: %s vs %s", i, e.Parameters[i].Type.String(), items[i].Type().String()) + } + } + return nil +} diff --git a/pkg/smartcontract/manifest/event_test.go b/pkg/smartcontract/manifest/event_test.go index 2e50bdd46..65b32183a 100644 --- a/pkg/smartcontract/manifest/event_test.go +++ b/pkg/smartcontract/manifest/event_test.go @@ -64,3 +64,13 @@ func TestEvent_FromStackItemErrors(t *testing.T) { }) } } + +func TestEventCheckCompliance(t *testing.T) { + m := &Event{ + Name: "mur", + Parameters: []Parameter{{Name: "p1", Type: smartcontract.BoolType}}, + } + require.Error(t, m.CheckCompliance([]stackitem.Item{})) + require.Error(t, m.CheckCompliance([]stackitem.Item{stackitem.Make("something")})) + require.NoError(t, m.CheckCompliance([]stackitem.Item{stackitem.Make(true)})) +} diff --git a/pkg/smartcontract/param_type.go b/pkg/smartcontract/param_type.go index 5b36de52f..2455e994e 100644 --- a/pkg/smartcontract/param_type.go +++ b/pkg/smartcontract/param_type.go @@ -170,6 +170,52 @@ func (pt ParamType) EncodeDefaultValue(w *io.BinWriter) { } } +func checkBytesWithLen(vt stackitem.Type, v stackitem.Item, l int) bool { + if vt == stackitem.AnyT { + return true + } + if vt != stackitem.ByteArrayT && vt != stackitem.BufferT { + return false + } + b, _ := v.TryBytes() // Can't fail, we know the type exactly. + return len(b) == l +} + +func (pt ParamType) Match(v stackitem.Item) bool { + vt := v.Type() + + // Pointer can't be matched at all. + if vt == stackitem.PointerT { + return false + } + switch pt { + case AnyType: + return true + case BoolType: + return vt == stackitem.BooleanT + case IntegerType: + return vt == stackitem.IntegerT + case ByteArrayType, StringType: + return vt == stackitem.ByteArrayT || vt == stackitem.BufferT || vt == stackitem.AnyT + case Hash160Type: + return checkBytesWithLen(vt, v, 20) + case Hash256Type: + return checkBytesWithLen(vt, v, 32) + case PublicKeyType: + return checkBytesWithLen(vt, v, 33) + case SignatureType: + return checkBytesWithLen(vt, v, 64) + case ArrayType: + return vt == stackitem.AnyT || vt == stackitem.ArrayT || vt == stackitem.StructT + case MapType: + return vt == stackitem.AnyT || vt == stackitem.MapT + case InteropInterfaceType: + return vt == stackitem.AnyT || vt == stackitem.InteropT + default: + return false + } +} + // ParseParamType is a user-friendly string to ParamType converter, it's // case-insensitive and makes the following conversions: // diff --git a/pkg/smartcontract/param_type_test.go b/pkg/smartcontract/param_type_test.go index a9f52c11d..59cfb95e5 100644 --- a/pkg/smartcontract/param_type_test.go +++ b/pkg/smartcontract/param_type_test.go @@ -418,3 +418,58 @@ func TestConvertToStackitemType(t *testing.T) { UnknownType.ConvertToStackitemType() }) } + +func TestParamTypeMatch(t *testing.T) { + for itm, pt := range map[stackitem.Item]ParamType{ + &stackitem.Pointer{}: BoolType, + &stackitem.Pointer{}: MapType, + stackitem.Make(0): BoolType, + stackitem.Make(0): ByteArrayType, + stackitem.Make(0): StringType, + stackitem.Make(false): ByteArrayType, + stackitem.Make(true): StringType, + stackitem.Make([]byte{1}): Hash160Type, + stackitem.Make([]byte{1}): Hash256Type, + stackitem.Make([]byte{1}): PublicKeyType, + stackitem.Make([]byte{1}): SignatureType, + stackitem.Make(0): Hash160Type, + stackitem.Make(0): Hash256Type, + stackitem.Make(0): PublicKeyType, + stackitem.Make(0): SignatureType, + stackitem.Make(0): ArrayType, + stackitem.Make(0): MapType, + stackitem.Make(0): InteropInterfaceType, + stackitem.Make(0): VoidType, + } { + require.Falsef(t, pt.Match(itm), "%s - %s", pt.String(), itm.String()) + } + for itm, pt := range map[stackitem.Item]ParamType{ + stackitem.Make(false): BoolType, + stackitem.Make(true): BoolType, + stackitem.Make(0): IntegerType, + stackitem.Make(100500): IntegerType, + stackitem.Make([]byte{1}): ByteArrayType, + stackitem.Make([]byte{1}): StringType, + stackitem.NewBuffer([]byte{1}): ByteArrayType, + stackitem.NewBuffer([]byte{1}): StringType, + stackitem.Null{}: ByteArrayType, + stackitem.Null{}: StringType, + stackitem.Make(util.Uint160{}.BytesBE()): Hash160Type, + stackitem.Make(util.Uint256{}.BytesBE()): Hash256Type, + stackitem.Null{}: Hash160Type, + stackitem.Null{}: Hash256Type, + stackitem.Make(make([]byte, 33)): PublicKeyType, + stackitem.Null{}: PublicKeyType, + stackitem.Make(make([]byte, 64)): SignatureType, + stackitem.Null{}: SignatureType, + stackitem.Make([]stackitem.Item{}): ArrayType, + stackitem.NewStruct([]stackitem.Item{}): ArrayType, + stackitem.Null{}: ArrayType, + stackitem.NewMap(): MapType, + stackitem.Null{}: MapType, + stackitem.NewInterop(true): InteropInterfaceType, + stackitem.Null{}: InteropInterfaceType, + } { + require.Truef(t, pt.Match(itm), "%s - %s", pt.String(), itm.String()) + } +}