From 0394a79ef8fc4aa6afaaa899aaea5cc9ffbafe22 Mon Sep 17 00:00:00 2001 From: Evgeniy Stratonikov Date: Wed, 9 Feb 2022 15:13:15 +0300 Subject: [PATCH] rpc: support big integers as request parameters Signed-off-by: Evgeniy Stratonikov --- pkg/rpc/request/param.go | 120 ++++++++++++++++++++--------- pkg/rpc/request/param_test.go | 44 +++++++++++ pkg/rpc/request/txBuilder.go | 4 +- pkg/rpc/request/tx_builder_test.go | 9 +++ pkg/vm/emit/emit.go | 28 ++++++- 5 files changed, 161 insertions(+), 44 deletions(-) diff --git a/pkg/rpc/request/param.go b/pkg/rpc/request/param.go index 77bef2e85..36f4b41c2 100644 --- a/pkg/rpc/request/param.go +++ b/pkg/rpc/request/param.go @@ -7,6 +7,7 @@ import ( "encoding/json" "errors" "fmt" + "math/big" "strconv" "strings" @@ -115,7 +116,7 @@ func (p *Param) GetString() (string, error) { if err == nil { p.cache = s } else { - var i int + var i int64 err = json.Unmarshal(p.RawMessage, &i) if err == nil { p.cache = i @@ -133,8 +134,8 @@ func (p *Param) GetString() (string, error) { switch t := p.cache.(type) { case string: return t, nil - case int: - return strconv.Itoa(t), nil + case int64: + return strconv.FormatInt(t, 10), nil case bool: if t { return "true", nil @@ -180,7 +181,7 @@ func (p *Param) GetBoolean() (bool, error) { if err == nil { p.cache = s } else { - var i int + var i int64 err = json.Unmarshal(p.RawMessage, &i) if err == nil { p.cache = i @@ -195,7 +196,7 @@ func (p *Param) GetBoolean() (bool, error) { return t, nil case string: return t != "", nil - case int: + case int64: return t != 0, nil default: return false, errNotABool @@ -210,20 +211,46 @@ func (p *Param) GetIntStrict() (int, error) { if p.IsNull() { return 0, errNotAnInt } - if p.cache == nil { - var i int - err := json.Unmarshal(p.RawMessage, &i) - if err != nil { - return i, errNotAnInt - } - p.cache = i + value, err := p.fillIntCache() + if err != nil { + return 0, err } - if i, ok := p.cache.(int); ok { - return i, nil + if i, ok := value.(int64); ok && i == int64(int(i)) { + return int(i), nil } return 0, errNotAnInt } +func (p *Param) fillIntCache() (interface{}, error) { + if p.cache != nil { + return p.cache, nil + } + + // We could also try unmarshalling to uint64, but JSON reliably supports numbers + // up to 53 bits in size. + var i int64 + err := json.Unmarshal(p.RawMessage, &i) + if err == nil { + p.cache = i + return i, nil + } + + var s string + err = json.Unmarshal(p.RawMessage, &s) + if err == nil { + p.cache = s + return s, nil + } + + var b bool + err = json.Unmarshal(p.RawMessage, &b) + if err == nil { + p.cache = b + return b, nil + } + return nil, errNotAnInt +} + // GetInt returns int value of the parameter or tries to cast parameter to an int value. func (p *Param) GetInt() (int, error) { if p == nil { @@ -232,30 +259,16 @@ func (p *Param) GetInt() (int, error) { if p.IsNull() { return 0, errNotAnInt } - if p.cache == nil { - var i int - err := json.Unmarshal(p.RawMessage, &i) - if err == nil { - p.cache = i - } else { - var s string - err = json.Unmarshal(p.RawMessage, &s) - if err == nil { - p.cache = s - } else { - var b bool - err = json.Unmarshal(p.RawMessage, &b) - if err == nil { - p.cache = b - } else { - return 0, errNotAnInt - } - } - } + value, err := p.fillIntCache() + if err != nil { + return 0, err } - switch t := p.cache.(type) { - case int: - return t, nil + switch t := value.(type) { + case int64: + if t == int64(int(t)) { + return int(t), nil + } + return 0, errNotAnInt case string: return strconv.Atoi(t) case bool: @@ -264,7 +277,38 @@ func (p *Param) GetInt() (int, error) { } return 0, nil default: - return 0, errNotAnInt + panic("unreachable") + } +} + +// GetBigInt returns big-interer value of the parameter. +func (p *Param) GetBigInt() (*big.Int, error) { + if p == nil { + return nil, errMissingParameter + } + if p.IsNull() { + return nil, errNotAnInt + } + value, err := p.fillIntCache() + if err != nil { + return nil, err + } + switch t := value.(type) { + case int64: + return big.NewInt(t), nil + case string: + bi, ok := new(big.Int).SetString(t, 10) + if !ok { + return nil, errNotAnInt + } + return bi, nil + case bool: + if t { + return big.NewInt(1), nil + } + return new(big.Int), nil + default: + panic("unreachable") } } diff --git a/pkg/rpc/request/param_test.go b/pkg/rpc/request/param_test.go index 1b5900700..91137b05c 100644 --- a/pkg/rpc/request/param_test.go +++ b/pkg/rpc/request/param_test.go @@ -5,6 +5,8 @@ import ( "encoding/hex" "encoding/json" "fmt" + "math" + "math/big" "testing" "github.com/nspcc-dev/neo-go/pkg/core/transaction" @@ -190,6 +192,48 @@ func TestParam_UnmarshalJSON(t *testing.T) { } } +func TestGetBigInt(t *testing.T) { + maxUint64 := new(big.Int).SetUint64(math.MaxUint64) + minInt64 := big.NewInt(math.MinInt64) + testCases := []struct { + raw string + expected *big.Int + }{ + {"true", big.NewInt(1)}, + {"false", new(big.Int)}, + {"42", big.NewInt(42)}, + {`"` + minInt64.String() + `"`, minInt64}, + {`"` + maxUint64.String() + `"`, maxUint64}, + {`"` + minInt64.String() + `000"`, new(big.Int).Mul(minInt64, big.NewInt(1000))}, + {`"` + maxUint64.String() + `000"`, new(big.Int).Mul(maxUint64, big.NewInt(1000))}, + {`"abc"`, nil}, + {`[]`, nil}, + {`null`, nil}, + } + + for _, tc := range testCases { + var p Param + require.NoError(t, json.Unmarshal([]byte(tc.raw), &p)) + + actual, err := p.GetBigInt() + if tc.expected == nil { + require.Error(t, err) + continue + } + require.NoError(t, err) + require.Equal(t, tc.expected, actual) + + expected := tc.expected.Int64() + actualInt, err := p.GetInt() + if !actual.IsInt64() || int64(int(expected)) != expected { + require.Error(t, err) + } else { + require.NoError(t, err) + require.Equal(t, int(expected), actualInt) + } + } +} + func TestGetWitness(t *testing.T) { accountHash, err := util.Uint160DecodeStringLE("cadb3dc2faa3ef14a13b619c9a43124755aa2569") require.NoError(t, err) diff --git a/pkg/rpc/request/txBuilder.go b/pkg/rpc/request/txBuilder.go index 4a18af412..044db0b2e 100644 --- a/pkg/rpc/request/txBuilder.go +++ b/pkg/rpc/request/txBuilder.go @@ -63,11 +63,11 @@ func ExpandArrayIntoScript(script *io.BinWriter, slice []Param) error { } emit.Bytes(script, key.Bytes()) case smartcontract.IntegerType: - val, err := fp.Value.GetInt() + bi, err := fp.Value.GetBigInt() if err != nil { return err } - emit.Int(script, int64(val)) + emit.BigInt(script, bi) case smartcontract.BoolType: val, err := fp.Value.GetBoolean() // not GetBooleanStrict(), because that's the way C# code works if err != nil { diff --git a/pkg/rpc/request/tx_builder_test.go b/pkg/rpc/request/tx_builder_test.go index 1cc837fe6..b5d056667 100644 --- a/pkg/rpc/request/tx_builder_test.go +++ b/pkg/rpc/request/tx_builder_test.go @@ -3,6 +3,7 @@ package request import ( "encoding/hex" "fmt" + "math/big" "testing" "github.com/nspcc-dev/neo-go/pkg/io" @@ -100,6 +101,10 @@ func TestInvocationScriptCreationBad(t *testing.T) { } func TestExpandArrayIntoScript(t *testing.T) { + bi := new(big.Int).Lsh(big.NewInt(1), 254) + rawInt := make([]byte, 32) + rawInt[31] = 0x40 + testCases := []struct { Input []Param Expected []byte @@ -112,6 +117,10 @@ func TestExpandArrayIntoScript(t *testing.T) { Input: []Param{{RawMessage: []byte(`{"type": "Array", "value": [{"type": "String", "value": "a"}]}`)}}, Expected: []byte{byte(opcode.PUSHDATA1), 1, byte('a'), byte(opcode.PUSH1), byte(opcode.PACK)}, }, + { + Input: []Param{{RawMessage: []byte(`{"type": "Integer", "value": "` + bi.String() + `"}`)}}, + Expected: append([]byte{byte(opcode.PUSHINT256)}, rawInt...), + }, } for _, c := range testCases { script := io.NewBufBinWriter() diff --git a/pkg/vm/emit/emit.go b/pkg/vm/emit/emit.go index e3146d76c..6c0417066 100644 --- a/pkg/vm/emit/emit.go +++ b/pkg/vm/emit/emit.go @@ -52,18 +52,38 @@ func padRight(s int, buf []byte) []byte { // Int emits a int type to the given buffer. func Int(w *io.BinWriter, i int64) { + if smallInt(w, i) { + return + } + bigInt(w, big.NewInt(i), false) +} + +// BigInt emits big-integer to the given buffer. +func BigInt(w *io.BinWriter, n *big.Int) { + bigInt(w, n, true) +} + +func smallInt(w *io.BinWriter, i int64) bool { switch { case i == -1: Opcodes(w, opcode.PUSHM1) case i >= 0 && i < 16: - val := opcode.Opcode(int(opcode.PUSH1) - 1 + int(i)) + val := opcode.Opcode(int(opcode.PUSH0) + int(i)) Opcodes(w, val) default: - bigInt(w, big.NewInt(i)) + return false } + return true } -func bigInt(w *io.BinWriter, n *big.Int) { +func bigInt(w *io.BinWriter, n *big.Int, trySmall bool) { + if w.Err != nil { + return + } + if trySmall && n.IsInt64() && smallInt(w, n.Int64()) { + return + } + buf := bigint.ToPreallocatedBytes(n, make([]byte, 0, 32)) if len(buf) == 0 { Opcodes(w, opcode.PUSH0) @@ -101,7 +121,7 @@ func Array(w *io.BinWriter, es ...interface{}) { case int: Int(w, int64(e)) case *big.Int: - bigInt(w, e) + BigInt(w, e) case string: String(w, e) case util.Uint160: