diff --git a/pkg/vm/vm.go b/pkg/vm/vm.go index 7fa745f7b..2c131c772 100644 --- a/pkg/vm/vm.go +++ b/pkg/vm/vm.go @@ -44,12 +44,15 @@ const ( // MaxInvocationStackSize is the maximum size of an invocation stack. MaxInvocationStackSize = 1024 + // MaxBigIntegerSizeBits is the maximum size of BigInt item in bits. + MaxBigIntegerSizeBits = 32 * 8 + // MaxStackSize is the maximum number of items allowed to be // on all stacks at once. MaxStackSize = 2 * 1024 - maxSHLArg = 256 - minSHLArg = -256 + maxSHLArg = MaxBigIntegerSizeBits + minSHLArg = -MaxBigIntegerSizeBits ) // VM represents the virtual machine. @@ -693,27 +696,48 @@ func (v *VM) execute(ctx *Context, op Instruction, parameter []byte) (err error) // Numeric operations. case ADD: a := v.estack.Pop().BigInt() + v.checkBigIntSize(a) b := v.estack.Pop().BigInt() - v.estack.PushVal(new(big.Int).Add(a, b)) + v.checkBigIntSize(b) + + c := new(big.Int).Add(a, b) + v.checkBigIntSize(c) + v.estack.PushVal(c) case SUB: b := v.estack.Pop().BigInt() + v.checkBigIntSize(b) a := v.estack.Pop().BigInt() - v.estack.PushVal(new(big.Int).Sub(a, b)) + v.checkBigIntSize(a) + + c := new(big.Int).Sub(a, b) + v.checkBigIntSize(c) + v.estack.PushVal(c) case DIV: b := v.estack.Pop().BigInt() + v.checkBigIntSize(b) a := v.estack.Pop().BigInt() + v.checkBigIntSize(a) + v.estack.PushVal(new(big.Int).Div(a, b)) case MUL: a := v.estack.Pop().BigInt() + v.checkBigIntSize(a) b := v.estack.Pop().BigInt() - v.estack.PushVal(new(big.Int).Mul(a, b)) + v.checkBigIntSize(b) + + c := new(big.Int).Mul(a, b) + v.checkBigIntSize(c) + v.estack.PushVal(c) case MOD: b := v.estack.Pop().BigInt() + v.checkBigIntSize(b) a := v.estack.Pop().BigInt() + v.checkBigIntSize(a) + v.estack.PushVal(new(big.Int).Mod(a, b)) case SHL, SHR: @@ -724,12 +748,18 @@ func (v *VM) execute(ctx *Context, op Instruction, parameter []byte) (err error) panic(fmt.Sprintf("operand must be between %d and %d", minSHLArg, maxSHLArg)) } a := v.estack.Pop().BigInt() + v.checkBigIntSize(a) + + var item big.Int if op == SHL { - v.estack.PushVal(new(big.Int).Lsh(a, uint(b))) + item.Lsh(a, uint(b)) } else { - v.estack.PushVal(new(big.Int).Rsh(a, uint(b))) + item.Rsh(a, uint(b)) } + v.checkBigIntSize(&item) + v.estack.PushVal(&item) + case BOOLAND: b := v.estack.Pop().Bool() a := v.estack.Pop().Bool() @@ -796,11 +826,15 @@ func (v *VM) execute(ctx *Context, op Instruction, parameter []byte) (err error) case INC: x := v.estack.Pop().BigInt() - v.estack.PushVal(new(big.Int).Add(x, big.NewInt(1))) + a := new(big.Int).Add(x, big.NewInt(1)) + v.checkBigIntSize(a) + v.estack.PushVal(a) case DEC: x := v.estack.Pop().BigInt() - v.estack.PushVal(new(big.Int).Sub(x, big.NewInt(1))) + a := new(big.Int).Sub(x, big.NewInt(1)) + v.checkBigIntSize(a) + v.estack.PushVal(a) case SIGN: x := v.estack.Pop().BigInt() @@ -1381,3 +1415,9 @@ func (v *VM) checkInvocationStackSize() { panic("invocation stack is too big") } } + +func (v *VM) checkBigIntSize(a *big.Int) { + if a.BitLen() > MaxBigIntegerSizeBits { + panic("big integer is too big") + } +} diff --git a/pkg/vm/vm_test.go b/pkg/vm/vm_test.go index 2819d37df..5ace7a9e3 100644 --- a/pkg/vm/vm_test.go +++ b/pkg/vm/vm_test.go @@ -508,6 +508,13 @@ func TestNOTByteArray1(t *testing.T) { assert.Equal(t, &BoolItem{false}, vm.estack.Pop().value) } +// getBigInt returns 2^a+b +func getBigInt(a, b int64) *big.Int { + p := new(big.Int).Exp(big.NewInt(2), big.NewInt(a), nil) + p.Add(p, big.NewInt(b)) + return p +} + func TestAdd(t *testing.T) { prog := makeProgram(ADD) vm := load(prog) @@ -517,6 +524,39 @@ func TestAdd(t *testing.T) { assert.Equal(t, int64(6), vm.estack.Pop().BigInt().Int64()) } +func TestADDBigResult(t *testing.T) { + prog := makeProgram(ADD) + vm := load(prog) + vm.estack.PushVal(getBigInt(MaxBigIntegerSizeBits, -1)) + vm.estack.PushVal(1) + checkVMFailed(t, vm) +} + +func testBigArgument(t *testing.T, inst Instruction) { + prog := makeProgram(inst) + x := getBigInt(MaxBigIntegerSizeBits, 0) + t.Run(inst.String()+" big 1-st argument", func(t *testing.T) { + vm := load(prog) + vm.estack.PushVal(x) + vm.estack.PushVal(0) + checkVMFailed(t, vm) + }) + t.Run(inst.String()+" big 2-nd argument", func(t *testing.T) { + vm := load(prog) + vm.estack.PushVal(0) + vm.estack.PushVal(x) + checkVMFailed(t, vm) + }) +} + +func TestArithBigArgument(t *testing.T) { + testBigArgument(t, ADD) + testBigArgument(t, SUB) + testBigArgument(t, MUL) + testBigArgument(t, DIV) + testBigArgument(t, MOD) +} + func TestMul(t *testing.T) { prog := makeProgram(MUL) vm := load(prog) @@ -526,6 +566,14 @@ func TestMul(t *testing.T) { assert.Equal(t, int64(8), vm.estack.Pop().BigInt().Int64()) } +func TestMULBigResult(t *testing.T) { + prog := makeProgram(MUL) + vm := load(prog) + vm.estack.PushVal(getBigInt(MaxBigIntegerSizeBits/2+1, 0)) + vm.estack.PushVal(getBigInt(MaxBigIntegerSizeBits/2+1, 0)) + checkVMFailed(t, vm) +} + func TestDiv(t *testing.T) { prog := makeProgram(DIV) vm := load(prog) @@ -544,6 +592,14 @@ func TestSub(t *testing.T) { assert.Equal(t, int64(2), vm.estack.Pop().BigInt().Int64()) } +func TestSUBBigResult(t *testing.T) { + prog := makeProgram(SUB) + vm := load(prog) + vm.estack.PushVal(getBigInt(MaxBigIntegerSizeBits, -1)) + vm.estack.PushVal(-1) + checkVMFailed(t, vm) +} + func TestSHRGood(t *testing.T) { prog := makeProgram(SHR) vm := load(prog) @@ -568,7 +624,15 @@ func TestSHRSmallValue(t *testing.T) { prog := makeProgram(SHR) vm := load(prog) vm.estack.PushVal(5) - vm.estack.PushVal(-257) + vm.estack.PushVal(minSHLArg - 1) + checkVMFailed(t, vm) +} + +func TestSHRBigArgument(t *testing.T) { + prog := makeProgram(SHR) + vm := load(prog) + vm.estack.PushVal(getBigInt(MaxBigIntegerSizeBits, 0)) + vm.estack.PushVal(1) checkVMFailed(t, vm) } @@ -596,7 +660,23 @@ func TestSHLBigValue(t *testing.T) { prog := makeProgram(SHL) vm := load(prog) vm.estack.PushVal(5) - vm.estack.PushVal(257) + vm.estack.PushVal(maxSHLArg + 1) + checkVMFailed(t, vm) +} + +func TestSHLBigResult(t *testing.T) { + prog := makeProgram(SHL) + vm := load(prog) + vm.estack.PushVal(getBigInt(MaxBigIntegerSizeBits/2, 0)) + vm.estack.PushVal(MaxBigIntegerSizeBits / 2) + checkVMFailed(t, vm) +} + +func TestSHLBigArgument(t *testing.T) { + prog := makeProgram(SHR) + vm := load(prog) + vm.estack.PushVal(getBigInt(MaxBigIntegerSizeBits, 0)) + vm.estack.PushVal(1) checkVMFailed(t, vm) } @@ -734,6 +814,35 @@ func TestINC(t *testing.T) { assert.Equal(t, big.NewInt(2), vm.estack.Pop().BigInt()) } +func TestINCBigResult(t *testing.T) { + prog := makeProgram(INC, INC) + vm := load(prog) + x := getBigInt(MaxBigIntegerSizeBits, -2) + vm.estack.PushVal(x) + + require.NoError(t, vm.Step()) + require.False(t, vm.HasFailed()) + require.Equal(t, 1, vm.estack.Len()) + require.Equal(t, new(big.Int).Add(x, big.NewInt(1)), vm.estack.Top().BigInt()) + + checkVMFailed(t, vm) +} + +func TestDECBigResult(t *testing.T) { + prog := makeProgram(DEC, DEC) + vm := load(prog) + x := getBigInt(MaxBigIntegerSizeBits, -2) + x.Neg(x) + vm.estack.PushVal(x) + + require.NoError(t, vm.Step()) + require.False(t, vm.HasFailed()) + require.Equal(t, 1, vm.estack.Len()) + require.Equal(t, new(big.Int).Sub(x, big.NewInt(1)), vm.estack.Top().BigInt()) + + checkVMFailed(t, vm) +} + func TestNEWARRAYInteger(t *testing.T) { prog := makeProgram(NEWARRAY) vm := load(prog)