vm: restrict BigInteger item size

This commit is contained in:
Evgenii Stratonikov 2019-11-07 12:14:36 +03:00
parent f686069f37
commit 439cd72294
2 changed files with 156 additions and 7 deletions

View file

@ -44,6 +44,9 @@ const (
// MaxInvocationStackSize is the maximum size of an invocation stack. // MaxInvocationStackSize is the maximum size of an invocation stack.
MaxInvocationStackSize = 1024 MaxInvocationStackSize = 1024
// MaxBigIntegerSizeBits is the maximum size of BigInt item in bits.
MaxBigIntegerSizeBits = 32 * 8
// MaxStackSize is the maximum number of items allowed to be // MaxStackSize is the maximum number of items allowed to be
// on all stacks at once. // on all stacks at once.
MaxStackSize = 2 * 1024 MaxStackSize = 2 * 1024
@ -693,27 +696,48 @@ func (v *VM) execute(ctx *Context, op Instruction, parameter []byte) (err error)
// Numeric operations. // Numeric operations.
case ADD: case ADD:
a := v.estack.Pop().BigInt() a := v.estack.Pop().BigInt()
v.checkBigIntSize(a)
b := v.estack.Pop().BigInt() b := v.estack.Pop().BigInt()
v.estack.PushVal(new(big.Int).Add(a, b)) v.checkBigIntSize(b)
c := new(big.Int).Add(a, b)
v.checkBigIntSize(c)
v.estack.PushVal(c)
case SUB: case SUB:
b := v.estack.Pop().BigInt() b := v.estack.Pop().BigInt()
v.checkBigIntSize(b)
a := v.estack.Pop().BigInt() a := v.estack.Pop().BigInt()
v.estack.PushVal(new(big.Int).Sub(a, b)) v.checkBigIntSize(a)
c := new(big.Int).Sub(a, b)
v.checkBigIntSize(c)
v.estack.PushVal(c)
case DIV: case DIV:
b := v.estack.Pop().BigInt() b := v.estack.Pop().BigInt()
v.checkBigIntSize(b)
a := v.estack.Pop().BigInt() a := v.estack.Pop().BigInt()
v.checkBigIntSize(a)
v.estack.PushVal(new(big.Int).Div(a, b)) v.estack.PushVal(new(big.Int).Div(a, b))
case MUL: case MUL:
a := v.estack.Pop().BigInt() a := v.estack.Pop().BigInt()
v.checkBigIntSize(a)
b := v.estack.Pop().BigInt() b := v.estack.Pop().BigInt()
v.estack.PushVal(new(big.Int).Mul(a, b)) v.checkBigIntSize(b)
c := new(big.Int).Mul(a, b)
v.checkBigIntSize(c)
v.estack.PushVal(c)
case MOD: case MOD:
b := v.estack.Pop().BigInt() b := v.estack.Pop().BigInt()
v.checkBigIntSize(b)
a := v.estack.Pop().BigInt() a := v.estack.Pop().BigInt()
v.checkBigIntSize(a)
v.estack.PushVal(new(big.Int).Mod(a, b)) v.estack.PushVal(new(big.Int).Mod(a, b))
case SHL, SHR: case SHL, SHR:
@ -724,12 +748,18 @@ func (v *VM) execute(ctx *Context, op Instruction, parameter []byte) (err error)
panic(fmt.Sprintf("operand must be between %d and %d", minSHLArg, maxSHLArg)) panic(fmt.Sprintf("operand must be between %d and %d", minSHLArg, maxSHLArg))
} }
a := v.estack.Pop().BigInt() a := v.estack.Pop().BigInt()
v.checkBigIntSize(a)
var item big.Int
if op == SHL { if op == SHL {
v.estack.PushVal(new(big.Int).Lsh(a, uint(b))) item.Lsh(a, uint(b))
} else { } else {
v.estack.PushVal(new(big.Int).Rsh(a, uint(b))) item.Rsh(a, uint(b))
} }
v.checkBigIntSize(&item)
v.estack.PushVal(&item)
case BOOLAND: case BOOLAND:
b := v.estack.Pop().Bool() b := v.estack.Pop().Bool()
a := v.estack.Pop().Bool() a := v.estack.Pop().Bool()
@ -796,11 +826,15 @@ func (v *VM) execute(ctx *Context, op Instruction, parameter []byte) (err error)
case INC: case INC:
x := v.estack.Pop().BigInt() x := v.estack.Pop().BigInt()
v.estack.PushVal(new(big.Int).Add(x, big.NewInt(1))) a := new(big.Int).Add(x, big.NewInt(1))
v.checkBigIntSize(a)
v.estack.PushVal(a)
case DEC: case DEC:
x := v.estack.Pop().BigInt() x := v.estack.Pop().BigInt()
v.estack.PushVal(new(big.Int).Sub(x, big.NewInt(1))) a := new(big.Int).Sub(x, big.NewInt(1))
v.checkBigIntSize(a)
v.estack.PushVal(a)
case SIGN: case SIGN:
x := v.estack.Pop().BigInt() x := v.estack.Pop().BigInt()
@ -1381,3 +1415,9 @@ func (v *VM) checkInvocationStackSize() {
panic("invocation stack is too big") panic("invocation stack is too big")
} }
} }
func (v *VM) checkBigIntSize(a *big.Int) {
if a.BitLen() > MaxBigIntegerSizeBits {
panic("big integer is too big")
}
}

View file

@ -508,6 +508,13 @@ func TestNOTByteArray1(t *testing.T) {
assert.Equal(t, &BoolItem{false}, vm.estack.Pop().value) assert.Equal(t, &BoolItem{false}, vm.estack.Pop().value)
} }
// getBigInt returns 2^a+b
func getBigInt(a, b int64) *big.Int {
p := new(big.Int).Exp(big.NewInt(2), big.NewInt(a), nil)
p.Add(p, big.NewInt(b))
return p
}
func TestAdd(t *testing.T) { func TestAdd(t *testing.T) {
prog := makeProgram(ADD) prog := makeProgram(ADD)
vm := load(prog) vm := load(prog)
@ -517,6 +524,39 @@ func TestAdd(t *testing.T) {
assert.Equal(t, int64(6), vm.estack.Pop().BigInt().Int64()) assert.Equal(t, int64(6), vm.estack.Pop().BigInt().Int64())
} }
func TestADDBigResult(t *testing.T) {
prog := makeProgram(ADD)
vm := load(prog)
vm.estack.PushVal(getBigInt(MaxBigIntegerSizeBits, -1))
vm.estack.PushVal(1)
checkVMFailed(t, vm)
}
func testBigArgument(t *testing.T, inst Instruction) {
prog := makeProgram(inst)
x := getBigInt(MaxBigIntegerSizeBits, 0)
t.Run(inst.String()+" big 1-st argument", func(t *testing.T) {
vm := load(prog)
vm.estack.PushVal(x)
vm.estack.PushVal(0)
checkVMFailed(t, vm)
})
t.Run(inst.String()+" big 2-nd argument", func(t *testing.T) {
vm := load(prog)
vm.estack.PushVal(0)
vm.estack.PushVal(x)
checkVMFailed(t, vm)
})
}
func TestArithBigArgument(t *testing.T) {
testBigArgument(t, ADD)
testBigArgument(t, SUB)
testBigArgument(t, MUL)
testBigArgument(t, DIV)
testBigArgument(t, MOD)
}
func TestMul(t *testing.T) { func TestMul(t *testing.T) {
prog := makeProgram(MUL) prog := makeProgram(MUL)
vm := load(prog) vm := load(prog)
@ -526,6 +566,14 @@ func TestMul(t *testing.T) {
assert.Equal(t, int64(8), vm.estack.Pop().BigInt().Int64()) assert.Equal(t, int64(8), vm.estack.Pop().BigInt().Int64())
} }
func TestMULBigResult(t *testing.T) {
prog := makeProgram(MUL)
vm := load(prog)
vm.estack.PushVal(getBigInt(MaxBigIntegerSizeBits/2+1, 0))
vm.estack.PushVal(getBigInt(MaxBigIntegerSizeBits/2+1, 0))
checkVMFailed(t, vm)
}
func TestDiv(t *testing.T) { func TestDiv(t *testing.T) {
prog := makeProgram(DIV) prog := makeProgram(DIV)
vm := load(prog) vm := load(prog)
@ -544,6 +592,14 @@ func TestSub(t *testing.T) {
assert.Equal(t, int64(2), vm.estack.Pop().BigInt().Int64()) assert.Equal(t, int64(2), vm.estack.Pop().BigInt().Int64())
} }
func TestSUBBigResult(t *testing.T) {
prog := makeProgram(SUB)
vm := load(prog)
vm.estack.PushVal(getBigInt(MaxBigIntegerSizeBits, -1))
vm.estack.PushVal(-1)
checkVMFailed(t, vm)
}
func TestSHRGood(t *testing.T) { func TestSHRGood(t *testing.T) {
prog := makeProgram(SHR) prog := makeProgram(SHR)
vm := load(prog) vm := load(prog)
@ -572,6 +628,14 @@ func TestSHRSmallValue(t *testing.T) {
checkVMFailed(t, vm) checkVMFailed(t, vm)
} }
func TestSHRBigArgument(t *testing.T) {
prog := makeProgram(SHR)
vm := load(prog)
vm.estack.PushVal(getBigInt(MaxBigIntegerSizeBits, 0))
vm.estack.PushVal(1)
checkVMFailed(t, vm)
}
func TestSHLGood(t *testing.T) { func TestSHLGood(t *testing.T) {
prog := makeProgram(SHL) prog := makeProgram(SHL)
vm := load(prog) vm := load(prog)
@ -600,6 +664,22 @@ func TestSHLBigValue(t *testing.T) {
checkVMFailed(t, vm) checkVMFailed(t, vm)
} }
func TestSHLBigResult(t *testing.T) {
prog := makeProgram(SHL)
vm := load(prog)
vm.estack.PushVal(getBigInt(MaxBigIntegerSizeBits/2, 0))
vm.estack.PushVal(MaxBigIntegerSizeBits / 2)
checkVMFailed(t, vm)
}
func TestSHLBigArgument(t *testing.T) {
prog := makeProgram(SHR)
vm := load(prog)
vm.estack.PushVal(getBigInt(MaxBigIntegerSizeBits, 0))
vm.estack.PushVal(1)
checkVMFailed(t, vm)
}
func TestLT(t *testing.T) { func TestLT(t *testing.T) {
prog := makeProgram(LT) prog := makeProgram(LT)
vm := load(prog) vm := load(prog)
@ -734,6 +814,35 @@ func TestINC(t *testing.T) {
assert.Equal(t, big.NewInt(2), vm.estack.Pop().BigInt()) assert.Equal(t, big.NewInt(2), vm.estack.Pop().BigInt())
} }
func TestINCBigResult(t *testing.T) {
prog := makeProgram(INC, INC)
vm := load(prog)
x := getBigInt(MaxBigIntegerSizeBits, -2)
vm.estack.PushVal(x)
require.NoError(t, vm.Step())
require.False(t, vm.HasFailed())
require.Equal(t, 1, vm.estack.Len())
require.Equal(t, new(big.Int).Add(x, big.NewInt(1)), vm.estack.Top().BigInt())
checkVMFailed(t, vm)
}
func TestDECBigResult(t *testing.T) {
prog := makeProgram(DEC, DEC)
vm := load(prog)
x := getBigInt(MaxBigIntegerSizeBits, -2)
x.Neg(x)
vm.estack.PushVal(x)
require.NoError(t, vm.Step())
require.False(t, vm.HasFailed())
require.Equal(t, 1, vm.estack.Len())
require.Equal(t, new(big.Int).Sub(x, big.NewInt(1)), vm.estack.Top().BigInt())
checkVMFailed(t, vm)
}
func TestNEWARRAYInteger(t *testing.T) { func TestNEWARRAYInteger(t *testing.T) {
prog := makeProgram(NEWARRAY) prog := makeProgram(NEWARRAY)
vm := load(prog) vm := load(prog)