neo-go/pkg/core/interop/contract/call.go
Anna Shaleva 4945145b09 interop: use executing contract state for permissions checks
Do not use the updated contract state from native Management to perform
permissions checks. We need to use the currently executing state
instead got from the currently executing VM context until context is
unloaded.

Close #3471.

Signed-off-by: Anna Shaleva <shaleva.ann@nspcc.ru>
2024-06-03 12:32:10 +03:00

181 lines
5.8 KiB
Go

package contract
import (
"errors"
"fmt"
"math/big"
"strings"
"github.com/nspcc-dev/neo-go/pkg/core/dao"
"github.com/nspcc-dev/neo-go/pkg/core/interop"
"github.com/nspcc-dev/neo-go/pkg/core/native/nativenames"
"github.com/nspcc-dev/neo-go/pkg/core/state"
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
"github.com/nspcc-dev/neo-go/pkg/smartcontract/callflag"
"github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest"
"github.com/nspcc-dev/neo-go/pkg/util"
"github.com/nspcc-dev/neo-go/pkg/vm"
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
)
type policyChecker interface {
IsBlocked(*dao.Simple, util.Uint160) bool
}
// LoadToken calls method specified by the token id.
func LoadToken(ic *interop.Context, id int32) error {
ctx := ic.VM.Context()
if !ctx.GetCallFlags().Has(callflag.ReadStates | callflag.AllowCall) {
return errors.New("invalid call flags")
}
tok := ctx.GetNEF().Tokens[id]
if int(tok.ParamCount) > ctx.Estack().Len() {
return errors.New("stack is too small")
}
args := make([]stackitem.Item, tok.ParamCount)
for i := range args {
args[i] = ic.VM.Estack().Pop().Item()
}
cs, err := ic.GetContract(tok.Hash)
if err != nil {
return fmt.Errorf("token contract %s not found: %w", tok.Hash.StringLE(), err)
}
return callInternal(ic, cs, tok.Method, tok.CallFlag, tok.HasReturn, args, false)
}
// Call calls a contract with flags.
func Call(ic *interop.Context) error {
h := ic.VM.Estack().Pop().Bytes()
method := ic.VM.Estack().Pop().String()
fs := callflag.CallFlag(int32(ic.VM.Estack().Pop().BigInt().Int64()))
if fs&^callflag.All != 0 {
return errors.New("call flags out of range")
}
args := ic.VM.Estack().Pop().Array()
u, err := util.Uint160DecodeBytesBE(h)
if err != nil {
return errors.New("invalid contract hash")
}
cs, err := ic.GetContract(u)
if err != nil {
return fmt.Errorf("called contract %s not found: %w", u.StringLE(), err)
}
if strings.HasPrefix(method, "_") {
return errors.New("invalid method name (starts with '_')")
}
md := cs.Manifest.ABI.GetMethod(method, len(args))
if md == nil {
return fmt.Errorf("method not found: %s/%d", method, len(args))
}
hasReturn := md.ReturnType != smartcontract.VoidType
return callInternal(ic, cs, method, fs, hasReturn, args, true)
}
func callInternal(ic *interop.Context, cs *state.Contract, name string, f callflag.CallFlag,
hasReturn bool, args []stackitem.Item, isDynamic bool) error {
md := cs.Manifest.ABI.GetMethod(name, len(args))
if md.Safe {
f &^= (callflag.WriteStates | callflag.AllowNotify)
} else if ctx := ic.VM.Context(); ctx != nil && ctx.IsDeployed() {
curr, err := ic.GetContract(ic.VM.GetCurrentScriptHash())
if err == nil {
if !curr.Manifest.CanCall(cs.Hash, &cs.Manifest, name) {
return errors.New("disallowed method call")
}
}
}
return callExFromNative(ic, ic.VM.GetCurrentScriptHash(), cs, name, args, f, hasReturn, isDynamic, false)
}
// callExFromNative calls a contract with flags using the provided calling hash.
func callExFromNative(ic *interop.Context, caller util.Uint160, cs *state.Contract,
name string, args []stackitem.Item, f callflag.CallFlag, hasReturn bool, isDynamic bool, callFromNative bool) error {
for _, nc := range ic.Natives {
if nc.Metadata().Name == nativenames.Policy {
var pch = nc.(policyChecker)
if pch.IsBlocked(ic.DAO, cs.Hash) {
return fmt.Errorf("contract %s is blocked", cs.Hash.StringLE())
}
break
}
}
md := cs.Manifest.ABI.GetMethod(name, len(args))
if md == nil {
return fmt.Errorf("method '%s' not found", name)
}
if len(args) != len(md.Parameters) {
return fmt.Errorf("invalid argument count: %d (expected %d)", len(args), len(md.Parameters))
}
methodOff := md.Offset
initOff := -1
md = cs.Manifest.ABI.GetMethod(manifest.MethodInit, 0)
if md != nil {
initOff = md.Offset
}
ic.Invocations[cs.Hash]++
f = ic.VM.Context().GetCallFlags() & f
wrapped := ic.VM.ContractHasTryBlock() && // If the method is not wrapped into try-catch block, then changes should be discarded anyway if exception occurs.
f&(callflag.All^callflag.ReadOnly) != 0 // If the method is safe, then it's read-only and doesn't perform storage changes or emit notifications.
baseNtfCount := len(ic.Notifications)
baseDAO := ic.DAO
if wrapped {
ic.DAO = ic.DAO.GetPrivate()
}
onUnload := func(v *vm.VM, ctx *vm.Context, commit bool) error {
if wrapped {
if commit {
_, err := ic.DAO.Persist()
if err != nil {
return fmt.Errorf("failed to persist changes %w", err)
}
} else {
ic.Notifications = ic.Notifications[:baseNtfCount] // Rollback all notification changes made by current context.
}
ic.DAO = baseDAO
}
if callFromNative && !commit {
return fmt.Errorf("unhandled exception")
}
if isDynamic {
return vm.DynamicOnUnload(v, ctx, commit)
}
return nil
}
ic.VM.LoadNEFMethod(&cs.NEF, &cs.Manifest, caller, cs.Hash, f,
hasReturn, methodOff, initOff, onUnload)
for e, i := ic.VM.Estack(), len(args)-1; i >= 0; i-- {
e.PushItem(args[i])
}
return nil
}
// ErrNativeCall is returned for failed calls from native.
var ErrNativeCall = errors.New("failed native call")
// CallFromNative performs synchronous call from native contract.
func CallFromNative(ic *interop.Context, caller util.Uint160, cs *state.Contract, method string, args []stackitem.Item, hasReturn bool) error {
startSize := len(ic.VM.Istack())
if err := callExFromNative(ic, caller, cs, method, args, callflag.All, hasReturn, false, true); err != nil {
return err
}
for !ic.VM.HasStopped() && len(ic.VM.Istack()) > startSize {
if err := ic.VM.Step(); err != nil {
return fmt.Errorf("%w: %w", ErrNativeCall, err)
}
}
if ic.VM.HasFailed() {
return ErrNativeCall
}
return nil
}
// GetCallFlags returns current context calling flags.
func GetCallFlags(ic *interop.Context) error {
ic.VM.Estack().PushItem(stackitem.NewBigInteger(big.NewInt(int64(ic.VM.Context().GetCallFlags()))))
return nil
}