diff --git a/pkg/vm/vm.go b/pkg/vm/vm.go index eae4aaab5..77c69c528 100644 --- a/pkg/vm/vm.go +++ b/pkg/vm/vm.go @@ -63,6 +63,9 @@ type VM struct { // callbacks to get interops. getInterop []InteropGetterFunc + // callback to get interop price + getPrice func(*VM, opcode.Opcode, []byte) util.Fixed8 + // callback to get scripts. getScript func(util.Uint160) []byte @@ -76,6 +79,8 @@ type VM struct { itemCount map[StackItem]int size int + gasConsumed util.Fixed8 + // Public keys cache. keys map[string]*keys.PublicKey } @@ -114,6 +119,17 @@ func (v *VM) RegisterInteropGetter(f InteropGetterFunc) { v.getInterop = append(v.getInterop, f) } +// SetPriceGetter registers the given PriceGetterFunc in v. +// f accepts vm's Context, current instruction and instruction parameter. +func (v *VM) SetPriceGetter(f func(*VM, opcode.Opcode, []byte) util.Fixed8) { + v.getPrice = f +} + +// GasConsumed returns the amount of GAS consumed during execution. +func (v *VM) GasConsumed() util.Fixed8 { + return v.gasConsumed +} + // Estack returns the evaluation stack so interop hooks can utilize this. func (v *VM) Estack() *Stack { return v.estack @@ -225,6 +241,7 @@ func (v *VM) Load(prog []byte) { v.estack.Clear() v.astack.Clear() v.state = noneState + v.gasConsumed = 0 v.LoadScript(prog) } @@ -464,6 +481,10 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro } }() + if v.getPrice != nil && ctx.ip < len(ctx.prog) { + v.gasConsumed += v.getPrice(v, op, parameter) + } + if op >= opcode.PUSHBYTES1 && op <= opcode.PUSHBYTES75 { v.estack.PushVal(parameter) return diff --git a/pkg/vm/vm_test.go b/pkg/vm/vm_test.go index 2e9fb9ae2..012515124 100644 --- a/pkg/vm/vm_test.go +++ b/pkg/vm/vm_test.go @@ -62,6 +62,40 @@ func TestRegisterInteropGetter(t *testing.T) { assert.Equal(t, currRegistered+1, len(v.getInterop)) } +func TestVM_SetPriceGetter(t *testing.T) { + v := New() + prog := []byte{ + byte(opcode.PUSH4), byte(opcode.PUSH2), + byte(opcode.PUSHDATA1), 0x01, 0x01, + byte(opcode.PUSHDATA1), 0x02, 0xCA, 0xFE, + byte(opcode.PUSH4), byte(opcode.RET), + } + + t.Run("no price getter", func(t *testing.T) { + v.Load(prog) + runVM(t, v) + + require.EqualValues(t, 0, v.GasConsumed()) + }) + + v.SetPriceGetter(func(_ *VM, op opcode.Opcode, p []byte) util.Fixed8 { + if op == opcode.PUSH4 { + return 1 + } else if op == opcode.PUSHDATA1 && bytes.Equal(p, []byte{0xCA, 0xFE}) { + return 7 + } + + return 0 + }) + + t.Run("with price getter", func(t *testing.T) { + v.Load(prog) + runVM(t, v) + + require.EqualValues(t, 9, v.gasConsumed) + }) +} + func TestBytesToPublicKey(t *testing.T) { v := New() cache := v.GetPublicKeys()