contract: block calls to contracts via Policy contract

See neo-project/neo#2567.
This commit is contained in:
Roman Khimov 2021-08-17 15:18:11 +03:00
parent 11351b9702
commit f477a48758
2 changed files with 34 additions and 0 deletions

View file

@ -5,7 +5,9 @@ import (
"fmt" "fmt"
"strings" "strings"
"github.com/nspcc-dev/neo-go/pkg/core/dao"
"github.com/nspcc-dev/neo-go/pkg/core/interop" "github.com/nspcc-dev/neo-go/pkg/core/interop"
"github.com/nspcc-dev/neo-go/pkg/core/native/nativenames"
"github.com/nspcc-dev/neo-go/pkg/core/state" "github.com/nspcc-dev/neo-go/pkg/core/state"
"github.com/nspcc-dev/neo-go/pkg/smartcontract" "github.com/nspcc-dev/neo-go/pkg/smartcontract"
"github.com/nspcc-dev/neo-go/pkg/smartcontract/callflag" "github.com/nspcc-dev/neo-go/pkg/smartcontract/callflag"
@ -15,6 +17,10 @@ import (
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
) )
type policyChecker interface {
IsBlockedInternal(dao.DAO, util.Uint160) bool
}
// LoadToken calls method specified by token id. // LoadToken calls method specified by token id.
func LoadToken(ic *interop.Context) func(id int32) error { func LoadToken(ic *interop.Context) func(id int32) error {
return func(id int32) error { return func(id int32) error {
@ -88,6 +94,15 @@ func callInternal(ic *interop.Context, cs *state.Contract, name string, f callfl
// callExFromNative calls a contract with flags using provided calling hash. // callExFromNative calls a contract with flags using provided calling hash.
func callExFromNative(ic *interop.Context, caller util.Uint160, cs *state.Contract, func callExFromNative(ic *interop.Context, caller util.Uint160, cs *state.Contract,
name string, args []stackitem.Item, f callflag.CallFlag, hasReturn bool) error { name string, args []stackitem.Item, f callflag.CallFlag, hasReturn bool) error {
for _, nc := range ic.Natives {
if nc.Metadata().Name == nativenames.Policy {
var pch = nc.(policyChecker)
if pch.IsBlockedInternal(ic.DAO, cs.Hash) {
return fmt.Errorf("contract %s is blocked", cs.Hash.StringLE())
}
break
}
}
md := cs.Manifest.ABI.GetMethod(name, len(args)) md := cs.Manifest.ABI.GetMethod(name, len(args))
if md == nil { if md == nil {
return fmt.Errorf("method '%s' not found", name) return fmt.Errorf("method '%s' not found", name)

View file

@ -196,4 +196,23 @@ func TestBlockedAccounts(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
checkFAULTState(t, invokeRes) checkFAULTState(t, invokeRes)
}) })
t.Run("block-unblock contract", func(t *testing.T) {
neoHash := chain.contracts.NEO.Metadata().Hash
res, err := invokeContractMethodGeneric(chain, 100000000, policyHash, "blockAccount", true, neoHash.BytesBE())
require.NoError(t, err)
checkResult(t, res, stackitem.NewBool(true))
res, err = invokeContractMethodGeneric(chain, 100000000, neoHash, "balanceOf", true, account.BytesBE())
require.NoError(t, err)
checkFAULTState(t, res)
res, err = invokeContractMethodGeneric(chain, 100000000, policyHash, "unblockAccount", true, neoHash.BytesBE())
require.NoError(t, err)
checkResult(t, res, stackitem.NewBool(true))
res, err = invokeContractMethodGeneric(chain, 100000000, neoHash, "balanceOf", true, account.BytesBE())
require.NoError(t, err)
checkResult(t, res, stackitem.Make(0))
})
} }