diff --git a/pkg/compiler/syscall_test.go b/pkg/compiler/syscall_test.go index f14d8629c..9fc86cfdf 100644 --- a/pkg/compiler/syscall_test.go +++ b/pkg/compiler/syscall_test.go @@ -16,6 +16,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/interop/contract" "github.com/nspcc-dev/neo-go/pkg/interop/storage" "github.com/nspcc-dev/neo-go/pkg/smartcontract/callflag" + "github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/vm" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" "github.com/stretchr/testify/assert" @@ -375,3 +376,118 @@ func TestOpcode(t *testing.T) { }) }) } + +func TestInteropTypesComparison(t *testing.T) { + typeCheck := func(t *testing.T, typeName string, typeLen int) { + t.Run(typeName, func(t *testing.T) { + var ha, hb string + for i := 0; i < typeLen; i++ { + if i == typeLen-1 { + ha += "2" + hb += "3" + } else { + ha += "1, " + hb += "1, " + } + } + check := func(t *testing.T, a, b string, expected bool) { + src := `package foo + import "github.com/nspcc-dev/neo-go/pkg/interop" + func Main() bool { + a := interop.` + typeName + `{` + a + `} + b := interop.` + typeName + `{` + b + `} + return a.Equals(b) + }` + eval(t, src, expected) + } + t.Run("same type", func(t *testing.T) { + check(t, ha, ha, true) + check(t, ha, hb, false) + }) + t.Run("a is nil", func(t *testing.T) { + src := `package foo + import "github.com/nspcc-dev/neo-go/pkg/interop" + + func Main() bool { + var a interop.` + typeName + ` + b := interop.` + typeName + `{` + hb + `} + return a.Equals(b) + }` + eval(t, src, false) + }) + t.Run("b is nil", func(t *testing.T) { + src := `package foo + import "github.com/nspcc-dev/neo-go/pkg/interop" + + func Main() bool { + a := interop.` + typeName + `{` + ha + `} + var b interop.` + typeName + ` + return a.Equals(b) + }` + eval(t, src, false) + }) + t.Run("both nil", func(t *testing.T) { + src := `package foo + import "github.com/nspcc-dev/neo-go/pkg/interop" + + func Main() bool { + var a interop.` + typeName + ` + var b interop.` + typeName + ` + return a.Equals(b) + }` + eval(t, src, true) + }) + t.Run("different types", func(t *testing.T) { + src := `package foo + import "github.com/nspcc-dev/neo-go/pkg/interop" + + func Main() bool { + a := interop.` + typeName + `{` + ha + `} + b := 123 + return a.Equals(b) + }` + eval(t, src, false) + }) + t.Run("b is Buffer", func(t *testing.T) { + src := `package foo + import "github.com/nspcc-dev/neo-go/pkg/interop" + + func Main() bool { + a := interop.` + typeName + `{` + ha + `} + b := []byte{` + ha + `} + return a.Equals(b) + }` + eval(t, src, true) + }) + t.Run("b is ByteString", func(t *testing.T) { + src := `package foo + import "github.com/nspcc-dev/neo-go/pkg/interop" + + func Main() bool { + a := interop.` + typeName + `{` + ha + `} + b := string([]byte{` + ha + `}) + return a.Equals(b) + }` + eval(t, src, true) + }) + t.Run("b is compound type", func(t *testing.T) { + src := `package foo + import "github.com/nspcc-dev/neo-go/pkg/interop" + + func Main() bool { + a := interop.` + typeName + `{` + ha + `} + b := struct{}{} + return a.Equals(b) + }` + vm, _ := vmAndCompileInterop(t, src) + err := vm.Run() + require.Error(t, err) + require.True(t, strings.Contains(err.Error(), "invalid conversion: Struct/ByteString"), err) + }) + }) + } + typeCheck(t, "Hash160", util.Uint160Size) + typeCheck(t, "Hash256", util.Uint256Size) + typeCheck(t, "Signature", 64) + typeCheck(t, "PublicKey", 33) +} diff --git a/pkg/interop/types.go b/pkg/interop/types.go index 8025cc6be..c0ffd4ece 100644 --- a/pkg/interop/types.go +++ b/pkg/interop/types.go @@ -1,5 +1,7 @@ package interop +import "github.com/nspcc-dev/neo-go/pkg/interop/neogointernal" + const ( // Hash160Len is the length of proper Hash160 in bytes, use it to // sanitize input parameters. @@ -35,3 +37,46 @@ type PublicKey []byte // Interface represents interop interface type which is needed for // transparent handling of VM-internal types (e.g. storage.Context). type Interface interface{} + +// Equals compares Hash160 with the provided stackitem using EQUAL opcode. +// The provided stackitem `b` must be either one of the primitive type (int, +// bool, string, []byte) or derived from the primitive type, otherwise Equals +// will fail on .(string) conversion. +func (a Hash160) Equals(b interface{}) bool { + ha := interface{}(a) + return bytesEquals(ha, b) +} + +// Equals compares Hash256 with the provided stackitem using EQUAL opcode. +// The provided stackitem `b` must be either one of the primitive type (int, +// bool, string, []byte) or derived from the primitive type, otherwise Equals +// will fail on .(string) conversion. +func (a Hash256) Equals(b interface{}) bool { + ha := interface{}(a) + return bytesEquals(ha, b) +} + +// Equals compares PublicKey with the provided stackitem using EQUAL opcode. +// The provided stackitem `b` must be either one of the primitive type (int, +// bool, string, []byte) or derived from the primitive type, otherwise Equals +// will fail on .(string) conversion. +func (a PublicKey) Equals(b interface{}) bool { + ha := interface{}(a) + return bytesEquals(ha, b) +} + +// Equals compares Signature with the provided stackitem using EQUAL opcode. +// The provided stackitem `b` must be either one of the primitive types (int, +// bool, string, []byte) or derived from the primitive type, otherwise Equals +// will fail on .(string) conversion. +func (a Signature) Equals(b interface{}) bool { + ha := interface{}(a) + return bytesEquals(ha, b) +} + +// bytesEquals is an internal helper function allowed to compare types that can be +// converted to ByteString. +func bytesEquals(a interface{}, b interface{}) bool { + return (a == nil && b == nil) || + (a != nil && b != nil && neogointernal.Opcode2("EQUAL", a.(string), b.(string)).(bool)) +}