vm: check Integer size on creation

This commit is contained in:
Evgenii Stratonikov 2020-04-29 17:13:46 +03:00
parent a64a0f2681
commit 70f0c656b0
3 changed files with 9 additions and 68 deletions

View file

@ -85,9 +85,7 @@ func makeStackItem(v interface{}) StackItem {
value: val, value: val,
} }
case *big.Int: case *big.Int:
return &BigIntegerItem{ return NewBigIntegerItem(val)
value: val,
}
case StackItem: case StackItem:
return val return val
case []int: case []int:
@ -311,6 +309,9 @@ type BigIntegerItem struct {
// NewBigIntegerItem returns an new BigIntegerItem object. // NewBigIntegerItem returns an new BigIntegerItem object.
func NewBigIntegerItem(value *big.Int) *BigIntegerItem { func NewBigIntegerItem(value *big.Int) *BigIntegerItem {
if value.BitLen() > MaxBigIntegerSizeBits {
panic("integer is too big")
}
return &BigIntegerItem{ return &BigIntegerItem{
value: value, value: value,
} }
@ -519,7 +520,11 @@ func (i *ByteArrayItem) TryBytes() ([]byte, error) {
// TryInteger implements StackItem interface. // TryInteger implements StackItem interface.
func (i *ByteArrayItem) TryInteger() (*big.Int, error) { func (i *ByteArrayItem) TryInteger() (*big.Int, error) {
return emit.BytesToInt(i.value), nil bi := emit.BytesToInt(i.value)
if bi.BitLen() > MaxBigIntegerSizeBits {
return nil, errors.New("integer is too big")
}
return bi, nil
} }
// Equals implements StackItem interface. // Equals implements StackItem interface.

View file

@ -775,58 +775,43 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro
case opcode.INC: case opcode.INC:
x := v.estack.Pop().BigInt() x := v.estack.Pop().BigInt()
a := new(big.Int).Add(x, big.NewInt(1)) a := new(big.Int).Add(x, big.NewInt(1))
v.checkBigIntSize(a)
v.estack.PushVal(a) v.estack.PushVal(a)
case opcode.DEC: case opcode.DEC:
x := v.estack.Pop().BigInt() x := v.estack.Pop().BigInt()
a := new(big.Int).Sub(x, big.NewInt(1)) a := new(big.Int).Sub(x, big.NewInt(1))
v.checkBigIntSize(a)
v.estack.PushVal(a) v.estack.PushVal(a)
case opcode.ADD: case opcode.ADD:
a := v.estack.Pop().BigInt() a := v.estack.Pop().BigInt()
v.checkBigIntSize(a)
b := v.estack.Pop().BigInt() b := v.estack.Pop().BigInt()
v.checkBigIntSize(b)
c := new(big.Int).Add(a, b) c := new(big.Int).Add(a, b)
v.checkBigIntSize(c)
v.estack.PushVal(c) v.estack.PushVal(c)
case opcode.SUB: case opcode.SUB:
b := v.estack.Pop().BigInt() b := v.estack.Pop().BigInt()
v.checkBigIntSize(b)
a := v.estack.Pop().BigInt() a := v.estack.Pop().BigInt()
v.checkBigIntSize(a)
c := new(big.Int).Sub(a, b) c := new(big.Int).Sub(a, b)
v.checkBigIntSize(c)
v.estack.PushVal(c) v.estack.PushVal(c)
case opcode.MUL: case opcode.MUL:
a := v.estack.Pop().BigInt() a := v.estack.Pop().BigInt()
v.checkBigIntSize(a)
b := v.estack.Pop().BigInt() b := v.estack.Pop().BigInt()
v.checkBigIntSize(b)
c := new(big.Int).Mul(a, b) c := new(big.Int).Mul(a, b)
v.checkBigIntSize(c)
v.estack.PushVal(c) v.estack.PushVal(c)
case opcode.DIV: case opcode.DIV:
b := v.estack.Pop().BigInt() b := v.estack.Pop().BigInt()
v.checkBigIntSize(b)
a := v.estack.Pop().BigInt() a := v.estack.Pop().BigInt()
v.checkBigIntSize(a)
v.estack.PushVal(new(big.Int).Quo(a, b)) v.estack.PushVal(new(big.Int).Quo(a, b))
case opcode.MOD: case opcode.MOD:
b := v.estack.Pop().BigInt() b := v.estack.Pop().BigInt()
v.checkBigIntSize(b)
a := v.estack.Pop().BigInt() a := v.estack.Pop().BigInt()
v.checkBigIntSize(a)
v.estack.PushVal(new(big.Int).Rem(a, b)) v.estack.PushVal(new(big.Int).Rem(a, b))
@ -838,7 +823,6 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro
panic(fmt.Sprintf("operand must be between %d and %d", 0, maxSHLArg)) panic(fmt.Sprintf("operand must be between %d and %d", 0, maxSHLArg))
} }
a := v.estack.Pop().BigInt() a := v.estack.Pop().BigInt()
v.checkBigIntSize(a)
var item big.Int var item big.Int
if op == opcode.SHL { if op == opcode.SHL {
@ -847,7 +831,6 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro
item.Rsh(a, uint(b)) item.Rsh(a, uint(b))
} }
v.checkBigIntSize(&item)
v.estack.PushVal(&item) v.estack.PushVal(&item)
case opcode.NOT: case opcode.NOT:
@ -1527,12 +1510,6 @@ func (v *VM) checkInvocationStackSize() {
} }
} }
func (v *VM) checkBigIntSize(a *big.Int) {
if a.BitLen() > MaxBigIntegerSizeBits {
panic("big integer is too big")
}
}
// bytesToPublicKey is a helper deserializing keys using cache and panicing on // bytesToPublicKey is a helper deserializing keys using cache and panicing on
// error. // error.
func (v *VM) bytesToPublicKey(b []byte) *keys.PublicKey { func (v *VM) bytesToPublicKey(b []byte) *keys.PublicKey {

View file

@ -1067,31 +1067,6 @@ func TestADDBigResult(t *testing.T) {
checkVMFailed(t, vm) checkVMFailed(t, vm)
} }
func testBigArgument(t *testing.T, inst opcode.Opcode) {
prog := makeProgram(inst)
x := getBigInt(MaxBigIntegerSizeBits, 0)
t.Run(inst.String()+" big 1-st argument", func(t *testing.T) {
vm := load(prog)
vm.estack.PushVal(x)
vm.estack.PushVal(0)
checkVMFailed(t, vm)
})
t.Run(inst.String()+" big 2-nd argument", func(t *testing.T) {
vm := load(prog)
vm.estack.PushVal(0)
vm.estack.PushVal(x)
checkVMFailed(t, vm)
})
}
func TestArithBigArgument(t *testing.T) {
testBigArgument(t, opcode.ADD)
testBigArgument(t, opcode.SUB)
testBigArgument(t, opcode.MUL)
testBigArgument(t, opcode.DIV)
testBigArgument(t, opcode.MOD)
}
func TestMul(t *testing.T) { func TestMul(t *testing.T) {
prog := makeProgram(opcode.MUL) prog := makeProgram(opcode.MUL)
vm := load(prog) vm := load(prog)
@ -1190,14 +1165,6 @@ func TestSHRNegative(t *testing.T) {
checkVMFailed(t, vm) checkVMFailed(t, vm)
} }
func TestSHRBigArgument(t *testing.T) {
prog := makeProgram(opcode.SHR)
vm := load(prog)
vm.estack.PushVal(getBigInt(MaxBigIntegerSizeBits, 0))
vm.estack.PushVal(1)
checkVMFailed(t, vm)
}
func TestSHLGood(t *testing.T) { func TestSHLGood(t *testing.T) {
prog := makeProgram(opcode.SHL) prog := makeProgram(opcode.SHL)
vm := load(prog) vm := load(prog)
@ -1234,14 +1201,6 @@ func TestSHLBigResult(t *testing.T) {
checkVMFailed(t, vm) checkVMFailed(t, vm)
} }
func TestSHLBigArgument(t *testing.T) {
prog := makeProgram(opcode.SHR)
vm := load(prog)
vm.estack.PushVal(getBigInt(MaxBigIntegerSizeBits, 0))
vm.estack.PushVal(1)
checkVMFailed(t, vm)
}
func TestLT(t *testing.T) { func TestLT(t *testing.T) {
prog := makeProgram(opcode.LT) prog := makeProgram(opcode.LT)
vm := load(prog) vm := load(prog)