diff --git a/pkg/core/interop/contract/call.go b/pkg/core/interop/contract/call.go index d2ac7d4a4..3829ea290 100644 --- a/pkg/core/interop/contract/call.go +++ b/pkg/core/interop/contract/call.go @@ -5,7 +5,9 @@ import ( "fmt" "strings" + "github.com/nspcc-dev/neo-go/pkg/core/dao" "github.com/nspcc-dev/neo-go/pkg/core/interop" + "github.com/nspcc-dev/neo-go/pkg/core/native/nativenames" "github.com/nspcc-dev/neo-go/pkg/core/state" "github.com/nspcc-dev/neo-go/pkg/smartcontract" "github.com/nspcc-dev/neo-go/pkg/smartcontract/callflag" @@ -15,6 +17,10 @@ import ( "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" ) +type policyChecker interface { + IsBlockedInternal(dao.DAO, util.Uint160) bool +} + // LoadToken calls method specified by token id. func LoadToken(ic *interop.Context) func(id int32) error { return func(id int32) error { @@ -88,6 +94,15 @@ func callInternal(ic *interop.Context, cs *state.Contract, name string, f callfl // callExFromNative calls a contract with flags using provided calling hash. func callExFromNative(ic *interop.Context, caller util.Uint160, cs *state.Contract, name string, args []stackitem.Item, f callflag.CallFlag, hasReturn bool) error { + for _, nc := range ic.Natives { + if nc.Metadata().Name == nativenames.Policy { + var pch = nc.(policyChecker) + if pch.IsBlockedInternal(ic.DAO, cs.Hash) { + return fmt.Errorf("contract %s is blocked", cs.Hash.StringLE()) + } + break + } + } md := cs.Manifest.ABI.GetMethod(name, len(args)) if md == nil { return fmt.Errorf("method '%s' not found", name) diff --git a/pkg/core/native_policy_test.go b/pkg/core/native_policy_test.go index 631a1718f..d29f3f3c0 100644 --- a/pkg/core/native_policy_test.go +++ b/pkg/core/native_policy_test.go @@ -196,4 +196,23 @@ func TestBlockedAccounts(t *testing.T) { require.NoError(t, err) checkFAULTState(t, invokeRes) }) + + t.Run("block-unblock contract", func(t *testing.T) { + neoHash := chain.contracts.NEO.Metadata().Hash + res, err := invokeContractMethodGeneric(chain, 100000000, policyHash, "blockAccount", true, neoHash.BytesBE()) + require.NoError(t, err) + checkResult(t, res, stackitem.NewBool(true)) + + res, err = invokeContractMethodGeneric(chain, 100000000, neoHash, "balanceOf", true, account.BytesBE()) + require.NoError(t, err) + checkFAULTState(t, res) + + res, err = invokeContractMethodGeneric(chain, 100000000, policyHash, "unblockAccount", true, neoHash.BytesBE()) + require.NoError(t, err) + checkResult(t, res, stackitem.NewBool(true)) + + res, err = invokeContractMethodGeneric(chain, 100000000, neoHash, "balanceOf", true, account.BytesBE()) + require.NoError(t, err) + checkResult(t, res, stackitem.Make(0)) + }) }