diff --git a/cli/smartcontract/smart_contract.go b/cli/smartcontract/smart_contract.go index d3f6796f5..accac7614 100644 --- a/cli/smartcontract/smart_contract.go +++ b/cli/smartcontract/smart_contract.go @@ -14,6 +14,7 @@ import ( "github.com/CityOfZion/neo-go/pkg/crypto/hash" "github.com/CityOfZion/neo-go/pkg/crypto/keys" "github.com/CityOfZion/neo-go/pkg/rpc" + "github.com/CityOfZion/neo-go/pkg/smartcontract" "github.com/CityOfZion/neo-go/pkg/util" "github.com/CityOfZion/neo-go/pkg/vm" "github.com/CityOfZion/neo-go/pkg/vm/compiler" @@ -27,6 +28,7 @@ var ( errNoInput = errors.New("no input file was found, specify an input file with the '--in or -i' flag") errNoConfFile = errors.New("no config file was found, specify a config file with the '--config' or '-c' flag") errNoWIF = errors.New("no WIF parameter found, specify it with the '--wif or -w' flag") + errNoScriptHash = errors.New("no smart contract hash was provided, specify one as the first argument") errNoSmartContractName = errors.New("no name was provided, specify the '--name or -n' flag") errFileExist = errors.New("A file with given smart-contract name already exists") ) @@ -96,9 +98,83 @@ func NewCommands() []cli.Command { }, }, { - Name: "testinvoke", - Usage: "Test an invocation of a smart contract on the blockchain", - Action: testInvoke, + Name: "testinvokefunction", + Usage: "invoke deployed contract on the blockchain (test mode)", + UsageText: "neo-go contract testinvokefunction -e endpoint scripthash [method] [arguments...]", + Description: `Executes given (as a script hash) deployed script with the given method and + arguments. If no method is given "" is passed to the script, if no arguments + are given, an empty array is passed. All of the given arguments are + encapsulated into array before invoking the script. The script thus should + follow the regular convention of smart contract arguments (method string and + an array of other arguments). + + Arguments always do have regular Neo smart contract parameter types, either + specified explicitly or being inferred from the value. To specify the type + manually use "type:value" syntax where the type is one of the following: + 'signature', 'bool', 'int', 'hash160', 'hash256', 'bytes', 'key' or 'string'. + Array types are not currently supported. + + Given values are type-checked against given types with the following + restrictions applied: + * 'signature' type values should be hex-encoded and have a (decoded) + length of 64 bytes. + * 'bool' type values are 'true' and 'false'. + * 'int' values are decimal integers that can be successfully converted + from the string. + * 'hash160' values are Neo addresses and hex-encoded 20-bytes long (after + decoding) strings. + * 'hash256' type values should be hex-encoded and have a (decoded) + length of 32 bytes. + * 'bytes' type values are any hex-encoded things. + * 'key' type values are hex-encoded marshalled public keys. + * 'string' type values are any valid UTF-8 strings. In the value's part of + the string the colon looses it's special meaning as a separator between + type and value and is taken literally. + + If no type is explicitly specified, it is inferred from the value using the + following logic: + - anything that can be interpreted as a decimal integer gets + an 'int' type + - 'true' and 'false' strings get 'bool' type + - valid Neo addresses and 20 bytes long hex-encoded strings get 'hash160' + type + - valid hex-encoded public keys get 'key' type + - 32 bytes long hex-encoded values get 'hash256' type + - 64 bytes long hex-encoded values get 'signature' type + - any other valid hex-encoded values get 'bytes' type + - anything else is a 'string' + + Backslash character is used as an escape character and allows to use colon in + an implicitly typed string. For any other characters it has no special + meaning, to get a literal backslash in the string use the '\\' sequence. + + Examples: + * 'int:42' is an integer with a value of 42 + * '42' is an integer with a value of 42 + * 'bad' is a string with a value of 'bad' + * 'dead' is a byte array with a value of 'dead' + * 'string:dead' is a string with a value of 'dead' + * 'AK2nJJpJr6o664CWJKi1QRXjqeic2zRp8y' is a hash160 with a value + of '23ba2703c53263e8d6e522dc32203339dcd8eee9' + * '\4\2' is an integer with a value of 42 + * '\\4\2' is a string with a value of '\42' + * 'string:string' is a string with a value of 'string' + * 'string\:string' is a string with a value of 'string:string' + * '03b209fd4f53a7170ea4444e0cb0a6bb6a53c2bd016926989cf85f9b0fba17a70c' is a + key with a value of '03b209fd4f53a7170ea4444e0cb0a6bb6a53c2bd016926989cf85f9b0fba17a70c' +`, + Action: testInvokeFunction, + Flags: []cli.Flag{ + cli.StringFlag{ + Name: "endpoint, e", + Usage: "RPC endpoint address (like 'http://seed4.ngd.network:20332')", + }, + }, + }, + { + Name: "testinvokescript", + Usage: "Invoke compiled AVM code on the blockchain (test mode, not creating a transaction for it)", + Action: testInvokeScript, Flags: []cli.Flag{ cli.StringFlag{ Name: "endpoint, e", @@ -211,7 +287,53 @@ func contractCompile(ctx *cli.Context) error { return nil } -func testInvoke(ctx *cli.Context) error { +func testInvokeFunction(ctx *cli.Context) error { + endpoint := ctx.String("endpoint") + if len(endpoint) == 0 { + return cli.NewExitError(errNoEndpoint, 1) + } + + args := ctx.Args() + if !args.Present() { + return cli.NewExitError(errNoScriptHash, 1) + } + script := args[0] + operation := "" + if len(args) > 1 { + operation = args[1] + } + params := make([]smartcontract.Parameter, 0) + if len(args) > 2 { + for k, s := range args[2:] { + param, err := smartcontract.NewParameterFromString(s) + if err != nil { + return cli.NewExitError(fmt.Errorf("failed to parse argument #%d: %v", k+2+1, err), 1) + } + params = append(params, *param) + } + } + + client, err := rpc.NewClient(context.TODO(), endpoint, rpc.ClientOptions{}) + if err != nil { + return cli.NewExitError(err, 1) + } + + resp, err := client.InvokeFunction(script, operation, params) + if err != nil { + return cli.NewExitError(err, 1) + } + + b, err := json.MarshalIndent(resp.Result, "", " ") + if err != nil { + return cli.NewExitError(err, 1) + } + + fmt.Println(string(b)) + + return nil +} + +func testInvokeScript(ctx *cli.Context) error { src := ctx.String("in") if len(src) == 0 { return cli.NewExitError(errNoInput, 1) diff --git a/docs/rpc.md b/docs/rpc.md index baf060fff..4c65245e8 100644 --- a/docs/rpc.md +++ b/docs/rpc.md @@ -55,12 +55,22 @@ which would yield the response: | `getunspents` | Yes | | `getversion` | Yes | | `invoke` | No (#346) | -| `invokefunction` | No (#347) | +| `invokefunction` | Yes | | `invokescript` | Yes | | `sendrawtransaction` | Yes | | `submitblock` | No (#344) | | `validateaddress` | Yes | +#### Implementation notices + +##### `invokefunction` + +neo-go's implementation of `invokefunction` does not return `tx` field in the +answer because that requires signing the transaction with some key in the +server which doesn't fit the model of our node-client interactions. Lacking +this signature the transaction is almost useless, so there is no point in +returning it. + ## Reference * [JSON-RPC 2.0 Specification](http://www.jsonrpc.org/specification) diff --git a/pkg/rpc/param.go b/pkg/rpc/param.go index 08f50615c..1e6ee7972 100644 --- a/pkg/rpc/param.go +++ b/pkg/rpc/param.go @@ -1,10 +1,12 @@ package rpc import ( + "bytes" "encoding/hex" "encoding/json" "fmt" + "github.com/CityOfZion/neo-go/pkg/crypto" "github.com/CityOfZion/neo-go/pkg/util" "github.com/pkg/errors" ) @@ -19,6 +21,12 @@ type ( } paramType int + // FuncParam represents a function argument parameter used in the + // invokefunction RPC method. + FuncParam struct { + Type StackParamType `json:"type"` + Value Param `json:"value"` + } ) const ( @@ -26,6 +34,7 @@ const ( stringT numberT arrayT + funcParamT ) func (p Param) String() string { @@ -33,27 +42,82 @@ func (p Param) String() string { } // GetString returns string value of the parameter. -func (p Param) GetString() string { return p.Value.(string) } +func (p Param) GetString() (string, error) { + str, ok := p.Value.(string) + if !ok { + return "", errors.New("not a string") + } + return str, nil +} // GetInt returns int value of te parameter. -func (p Param) GetInt() int { return p.Value.(int) } +func (p Param) GetInt() (int, error) { + i, ok := p.Value.(int) + if !ok { + return 0, errors.New("not an integer") + } + return i, nil +} + +// GetArray returns a slice of Params stored in the parameter. +func (p Param) GetArray() ([]Param, error) { + a, ok := p.Value.([]Param) + if !ok { + return nil, errors.New("not an array") + } + return a, nil +} // GetUint256 returns Uint256 value of the parameter. func (p Param) GetUint256() (util.Uint256, error) { - s, ok := p.Value.(string) - if !ok { - return util.Uint256{}, errors.New("must be a string") + s, err := p.GetString() + if err != nil { + return util.Uint256{}, err } return util.Uint256DecodeReverseString(s) } +// GetUint160FromHex returns Uint160 value of the parameter encoded in hex. +func (p Param) GetUint160FromHex() (util.Uint160, error) { + s, err := p.GetString() + if err != nil { + return util.Uint160{}, err + } + + scriptHashLE, err := util.Uint160DecodeString(s) + if err != nil { + return util.Uint160{}, err + } + return util.Uint160DecodeBytes(scriptHashLE.BytesReverse()) +} + +// GetUint160FromAddress returns Uint160 value of the parameter that was +// supplied as an address. +func (p Param) GetUint160FromAddress() (util.Uint160, error) { + s, err := p.GetString() + if err != nil { + return util.Uint160{}, err + } + + return crypto.Uint160DecodeAddress(s) +} + +// GetFuncParam returns current parameter as a function call parameter. +func (p Param) GetFuncParam() (FuncParam, error) { + fp, ok := p.Value.(FuncParam) + if !ok { + return FuncParam{}, errors.New("not a function parameter") + } + return fp, nil +} + // GetBytesHex returns []byte value of the parameter if // it is a hex-encoded string. func (p Param) GetBytesHex() ([]byte, error) { - s, ok := p.Value.(string) - if !ok { - return nil, errors.New("must be a string") + s, err := p.GetString() + if err != nil { + return nil, err } return hex.DecodeString(s) @@ -77,6 +141,17 @@ func (p *Param) UnmarshalJSON(data []byte) error { return nil } + r := bytes.NewReader(data) + jd := json.NewDecoder(r) + jd.DisallowUnknownFields() + var fp FuncParam + if err := jd.Decode(&fp); err == nil { + p.Type = funcParamT + p.Value = fp + + return nil + } + var ps []Param if err := json.Unmarshal(data, &ps); err == nil { p.Type = arrayT diff --git a/pkg/rpc/param_test.go b/pkg/rpc/param_test.go index fca26196b..cd4465bfb 100644 --- a/pkg/rpc/param_test.go +++ b/pkg/rpc/param_test.go @@ -1,14 +1,18 @@ package rpc import ( + "encoding/hex" "encoding/json" "testing" + "github.com/CityOfZion/neo-go/pkg/crypto" + "github.com/CityOfZion/neo-go/pkg/util" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestParam_UnmarshalJSON(t *testing.T) { - msg := `["str1", 123, ["str2", 3]]` + msg := `["str1", 123, ["str2", 3], [{"type": "String", "value": "jajaja"}]]` expected := Params{ { Type: stringT, @@ -31,6 +35,21 @@ func TestParam_UnmarshalJSON(t *testing.T) { }, }, }, + { + Type: arrayT, + Value: []Param{ + { + Type: funcParamT, + Value: FuncParam{ + Type: String, + Value: Param{ + Type: stringT, + Value: "jajaja", + }, + }, + }, + }, + }, } var ps Params @@ -40,3 +59,126 @@ func TestParam_UnmarshalJSON(t *testing.T) { msg = `[{"2": 3}]` require.Error(t, json.Unmarshal([]byte(msg), &ps)) } + +func TestParamGetString(t *testing.T) { + p := Param{stringT, "jajaja"} + str, err := p.GetString() + assert.Equal(t, "jajaja", str) + require.Nil(t, err) + + p = Param{stringT, int(100500)} + _, err = p.GetString() + require.NotNil(t, err) +} + +func TestParamGetInt(t *testing.T) { + p := Param{numberT, int(100500)} + i, err := p.GetInt() + assert.Equal(t, 100500, i) + require.Nil(t, err) + + p = Param{numberT, "jajaja"} + _, err = p.GetInt() + require.NotNil(t, err) +} + +func TestParamGetArray(t *testing.T) { + p := Param{arrayT, []Param{{numberT, 42}}} + a, err := p.GetArray() + assert.Equal(t, []Param{{numberT, 42}}, a) + require.Nil(t, err) + + p = Param{arrayT, 42} + _, err = p.GetArray() + require.NotNil(t, err) +} + +func TestParamGetUint256(t *testing.T) { + gas := "602c79718b16e442de58778e148d0b1084e3b2dffd5de6b7b16cee7969282de7" + u256, _ := util.Uint256DecodeReverseString(gas) + p := Param{stringT, gas} + u, err := p.GetUint256() + assert.Equal(t, u256, u) + require.Nil(t, err) + + p = Param{stringT, 42} + _, err = p.GetUint256() + require.NotNil(t, err) + + p = Param{stringT, "qq2c79718b16e442de58778e148d0b1084e3b2dffd5de6b7b16cee7969282de7"} + _, err = p.GetUint256() + require.NotNil(t, err) +} + +func TestParamGetUint160FromHex(t *testing.T) { + in := "50befd26fdf6e4d957c11e078b24ebce6291456f" + u160, _ := util.Uint160DecodeString(in) + u160, _ = util.Uint160DecodeBytes(util.ArrayReverse(u160[:])) + p := Param{stringT, in} + u, err := p.GetUint160FromHex() + assert.Equal(t, u160, u) + require.Nil(t, err) + + p = Param{stringT, 42} + _, err = p.GetUint160FromHex() + require.NotNil(t, err) + + p = Param{stringT, "wwbefd26fdf6e4d957c11e078b24ebce6291456f"} + _, err = p.GetUint160FromHex() + require.NotNil(t, err) +} + +func TestParamGetUint160FromAddress(t *testing.T) { + in := "AK2nJJpJr6o664CWJKi1QRXjqeic2zRp8y" + u160, _ := crypto.Uint160DecodeAddress(in) + p := Param{stringT, in} + u, err := p.GetUint160FromAddress() + assert.Equal(t, u160, u) + require.Nil(t, err) + + p = Param{stringT, 42} + _, err = p.GetUint160FromAddress() + require.NotNil(t, err) + + p = Param{stringT, "QK2nJJpJr6o664CWJKi1QRXjqeic2zRp8y"} + _, err = p.GetUint160FromAddress() + require.NotNil(t, err) +} + +func TestParamGetFuncParam(t *testing.T) { + fp := FuncParam{ + Type: String, + Value: Param{ + Type: stringT, + Value: "jajaja", + }, + } + p := Param{ + Type: funcParamT, + Value: fp, + } + newfp, err := p.GetFuncParam() + assert.Equal(t, fp, newfp) + require.Nil(t, err) + + p = Param{funcParamT, 42} + _, err = p.GetFuncParam() + require.NotNil(t, err) +} + +func TestParamGetBytesHex(t *testing.T) { + in := "602c79718b16e442de58778e148d0b1084e3b2dffd5de6b7b16cee7969282de7" + inb, _ := hex.DecodeString(in) + p := Param{stringT, in} + bh, err := p.GetBytesHex() + assert.Equal(t, inb, bh) + require.Nil(t, err) + + p = Param{stringT, 42} + _, err = p.GetBytesHex() + require.NotNil(t, err) + + p = Param{stringT, "qq2c79718b16e442de58778e148d0b1084e3b2dffd5de6b7b16cee7969282de7"} + _, err = p.GetBytesHex() + require.NotNil(t, err) +} diff --git a/pkg/rpc/server.go b/pkg/rpc/server.go index 507ae38f4..df89546b3 100644 --- a/pkg/rpc/server.go +++ b/pkg/rpc/server.go @@ -10,7 +10,6 @@ import ( "github.com/CityOfZion/neo-go/config" "github.com/CityOfZion/neo-go/pkg/core" "github.com/CityOfZion/neo-go/pkg/core/transaction" - "github.com/CityOfZion/neo-go/pkg/crypto" "github.com/CityOfZion/neo-go/pkg/io" "github.com/CityOfZion/neo-go/pkg/network" "github.com/CityOfZion/neo-go/pkg/rpc/result" @@ -136,13 +135,13 @@ Methods: break Methods } case numberT: - if !s.validBlockHeight(param) { + num, err := s.blockHeightFromParam(param) + if err != nil { resultsErr = errInvalidParams break Methods } - - hash = s.chain.GetHeaderHash(param.GetInt()) - case defaultT: + hash = s.chain.GetHeaderHash(num) + default: resultsErr = errInvalidParams break Methods } @@ -164,12 +163,14 @@ Methods: if !ok { resultsErr = errInvalidParams break Methods - } else if !s.validBlockHeight(param) { - resultsErr = invalidBlockHeightError(0, param.GetInt()) + } + num, err := s.blockHeightFromParam(param) + if err != nil { + resultsErr = errInvalidParams break Methods } - results = s.chain.GetHeaderHash(param.GetInt()) + results = s.chain.GetHeaderHash(num) case "getconnectioncount": getconnectioncountCalled.Inc() @@ -242,6 +243,9 @@ Methods: getunspentsCalled.Inc() results, resultsErr = s.getAccountState(reqParams, true) + case "invokefunction": + results, resultsErr = s.invokeFunction(reqParams) + case "invokescript": results, resultsErr = s.invokescript(reqParams) @@ -306,11 +310,15 @@ func (s *Server) getAccountState(reqParams Params, unspents bool) (interface{}, param, ok := reqParams.ValueWithType(0, stringT) if !ok { return nil, errInvalidParams - } else if scriptHash, err := crypto.Uint160DecodeAddress(param.GetString()); err != nil { + } else if scriptHash, err := param.GetUint160FromAddress(); err != nil { return nil, errInvalidParams } else if as := s.chain.GetAccountState(scriptHash); as != nil { if unspents { - results = wrappers.NewUnspents(as, s.chain, param.GetString()) + str, err := param.GetString() + if err != nil { + return nil, errInvalidParams + } + results = wrappers.NewUnspents(as, s.chain, str) } else { results = wrappers.NewAccountState(as) } @@ -320,6 +328,32 @@ func (s *Server) getAccountState(reqParams Params, unspents bool) (interface{}, return results, resultsErr } +// invokescript implements the `invokescript` RPC call. +func (s *Server) invokeFunction(reqParams Params) (interface{}, error) { + scriptHashHex, ok := reqParams.ValueWithType(0, stringT) + if !ok { + return nil, errInvalidParams + } + scriptHash, err := scriptHashHex.GetUint160FromHex() + if err != nil { + return nil, err + } + script, err := CreateFunctionInvocationScript(scriptHash, reqParams[1:]) + if err != nil { + return nil, err + } + vm, _ := s.chain.GetTestVM() + vm.LoadScript(script) + _ = vm.Run() + result := &wrappers.InvokeResult{ + State: vm.State(), + GasConsumed: "0.1", + Script: hex.EncodeToString(script), + Stack: vm.Estack(), + } + return result, nil +} + // invokescript implements the `invokescript` RPC call. func (s *Server) invokescript(reqParams Params) (interface{}, error) { if len(reqParams) < 1 { @@ -334,10 +368,12 @@ func (s *Server) invokescript(reqParams Params) (interface{}, error) { vm, _ := s.chain.GetTestVM() vm.LoadScript(script) _ = vm.Run() + // It's already being GetBytesHex'ed, so it's a correct string. + echo, _ := reqParams[0].GetString() result := &wrappers.InvokeResult{ State: vm.State(), GasConsumed: "0.1", - Script: reqParams[0].GetString(), + Script: echo, Stack: vm.Estack(), } return result, nil @@ -384,6 +420,14 @@ func (s *Server) sendrawtransaction(reqParams Params) (interface{}, error) { return results, resultsErr } -func (s Server) validBlockHeight(param *Param) bool { - return param.GetInt() >= 0 && param.GetInt() <= int(s.chain.BlockHeight()) +func (s Server) blockHeightFromParam(param *Param) (int, error) { + num, err := param.GetInt() + if err != nil { + return 0, nil + } + + if num < 0 || num > int(s.chain.BlockHeight()) { + return 0, invalidBlockHeightError(0, num) + } + return num, nil } diff --git a/pkg/rpc/server_helper_test.go b/pkg/rpc/server_helper_test.go index c57584bf5..493abbce9 100644 --- a/pkg/rpc/server_helper_test.go +++ b/pkg/rpc/server_helper_test.go @@ -32,6 +32,19 @@ type SendTXResponse struct { ID int `json:"id"` } +// InvokeFunctionResponse struct for testing. +type InvokeFunctionResponse struct { + Jsonrpc string `json:"jsonrpc"` + Result struct { + Script string `json:"script"` + State string `json:"state"` + GasConsumed string `json:"gas_consumed"` + Stack []FuncParam `json:"stack"` + TX string `json:"tx,omitempty"` + } `json:"result"` + ID int `json:"id"` +} + // ValidateAddrResponse struct for testing. type ValidateAddrResponse struct { Jsonrpc string `json:"jsonrpc"` diff --git a/pkg/rpc/server_test.go b/pkg/rpc/server_test.go index 1e4e3ed6a..c1860dd72 100644 --- a/pkg/rpc/server_test.go +++ b/pkg/rpc/server_test.go @@ -126,6 +126,11 @@ var rpcTestCases = map[string][]rpcTestCase{ params: `[]`, fail: true, }, + { + name: "bad params", + params: `[[]]`, + fail: true, + }, { name: "invalid height", params: `[-1]`, @@ -246,6 +251,69 @@ var rpcTestCases = map[string][]rpcTestCase{ }, }, }, + "invokefunction": { + { + name: "positive", + params: `["50befd26fdf6e4d957c11e078b24ebce6291456f", "test", []]`, + result: func(e *executor) interface{} { return &InvokeFunctionResponse{} }, + check: func(t *testing.T, e *executor, result interface{}) { + res, ok := result.(*InvokeFunctionResponse) + require.True(t, ok) + assert.NotEqual(t, "", res.Result.Script) + assert.NotEqual(t, "", res.Result.State) + assert.NotEqual(t, 0, res.Result.GasConsumed) + }, + }, + { + name: "no params", + params: `[]`, + fail: true, + }, + { + name: "not a string", + params: `[42, "test", []]`, + fail: true, + }, + { + name: "not a scripthash", + params: `["qwerty", "test", []]`, + fail: true, + }, + { + name: "bad params", + params: `["50befd26fdf6e4d957c11e078b24ebce6291456f", "test", [{"type": "Integer", "value": "qwerty"}]]`, + fail: true, + }, + }, + "invokescript": { + { + name: "positive", + params: `["51c56b0d48656c6c6f2c20776f726c6421680f4e656f2e52756e74696d652e4c6f67616c7566"]`, + result: func(e *executor) interface{} { return &InvokeFunctionResponse{} }, + check: func(t *testing.T, e *executor, result interface{}) { + res, ok := result.(*InvokeFunctionResponse) + require.True(t, ok) + assert.NotEqual(t, "", res.Result.Script) + assert.NotEqual(t, "", res.Result.State) + assert.NotEqual(t, 0, res.Result.GasConsumed) + }, + }, + { + name: "no params", + params: `[]`, + fail: true, + }, + { + name: "not a string", + params: `[42]`, + fail: true, + }, + { + name: "bas string", + params: `["qwerty"]`, + fail: true, + }, + }, "sendrawtransaction": { { name: "positive", diff --git a/pkg/rpc/stack_param.go b/pkg/rpc/stack_param.go index 907e3e885..cf7e4c72c 100644 --- a/pkg/rpc/stack_param.go +++ b/pkg/rpc/stack_param.go @@ -89,6 +89,11 @@ func StackParamTypeFromString(s string) (StackParamType, error) { } } +// MarshalJSON implements the json.Marshaler interface. +func (t *StackParamType) MarshalJSON() ([]byte, error) { + return []byte(`"` + t.String() + `"`), nil +} + // UnmarshalJSON sets StackParamType from JSON-encoded data. func (t *StackParamType) UnmarshalJSON(data []byte) (err error) { var ( diff --git a/pkg/rpc/txBuilder.go b/pkg/rpc/txBuilder.go index b13dfd972..5ad22d481 100644 --- a/pkg/rpc/txBuilder.go +++ b/pkg/rpc/txBuilder.go @@ -2,6 +2,9 @@ package rpc import ( "bytes" + "errors" + "fmt" + "strconv" "github.com/CityOfZion/neo-go/pkg/core/transaction" "github.com/CityOfZion/neo-go/pkg/crypto" @@ -162,3 +165,121 @@ func CreateDeploymentScript(avm []byte, contract *ContractDetails) ([]byte, erro } return script.Bytes(), nil } + +// CreateFunctionInvocationScript creates a script to invoke given contract with +// given parameters. +func CreateFunctionInvocationScript(contract util.Uint160, params Params) ([]byte, error) { + script := new(bytes.Buffer) + for i := len(params) - 1; i >= 0; i-- { + switch params[i].Type { + case stringT: + if err := vm.EmitString(script, params[i].String()); err != nil { + return nil, err + } + case numberT: + num, err := params[i].GetInt() + if err != nil { + return nil, err + } + if err := vm.EmitString(script, strconv.Itoa(num)); err != nil { + return nil, err + } + case arrayT: + slice, err := params[i].GetArray() + if err != nil { + return nil, err + } + for j := len(slice) - 1; j >= 0; j-- { + fp, err := slice[j].GetFuncParam() + if err != nil { + return nil, err + } + switch fp.Type { + case ByteArray, Signature: + str, err := fp.Value.GetBytesHex() + if err != nil { + return nil, err + } + if err := vm.EmitBytes(script, str); err != nil { + return nil, err + } + case String: + str, err := fp.Value.GetString() + if err != nil { + return nil, err + } + if err := vm.EmitString(script, str); err != nil { + return nil, err + } + case Hash160: + hash, err := fp.Value.GetUint160FromHex() + if err != nil { + return nil, err + } + if err := vm.EmitBytes(script, hash.Bytes()); err != nil { + return nil, err + } + case Hash256: + hash, err := fp.Value.GetUint256() + if err != nil { + return nil, err + } + if err := vm.EmitBytes(script, hash.Bytes()); err != nil { + return nil, err + } + case PublicKey: + str, err := fp.Value.GetString() + if err != nil { + return nil, err + } + key, err := keys.NewPublicKeyFromString(string(str)) + if err != nil { + return nil, err + } + if err := vm.EmitBytes(script, key.Bytes()); err != nil { + return nil, err + } + case Integer: + val, err := fp.Value.GetInt() + if err != nil { + return nil, err + } + if err := vm.EmitInt(script, int64(val)); err != nil { + return nil, err + } + case Boolean: + str, err := fp.Value.GetString() + if err != nil { + return nil, err + } + switch str { + case "true": + err = vm.EmitInt(script, 1) + case "false": + err = vm.EmitInt(script, 0) + default: + err = errors.New("wrong boolean value") + } + if err != nil { + return nil, err + } + default: + return nil, fmt.Errorf("parameter type %v is not supported", fp.Type) + } + } + err = vm.EmitInt(script, int64(len(slice))) + if err != nil { + return nil, err + } + err = vm.EmitOpcode(script, vm.PACK) + if err != nil { + return nil, err + } + } + } + + if err := vm.EmitAppCall(script, contract, false); err != nil { + return nil, err + } + return script.Bytes(), nil +} diff --git a/pkg/rpc/tx_builder_test.go b/pkg/rpc/tx_builder_test.go new file mode 100644 index 000000000..0c3e1fa16 --- /dev/null +++ b/pkg/rpc/tx_builder_test.go @@ -0,0 +1,89 @@ +package rpc + +import ( + "encoding/hex" + "testing" + + "github.com/CityOfZion/neo-go/pkg/util" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestInvocationScriptCreationGood(t *testing.T) { + p := Param{stringT, "50befd26fdf6e4d957c11e078b24ebce6291456f"} + contract, err := p.GetUint160FromHex() + require.Nil(t, err) + + var paramScripts = []struct { + ps Params + script string + }{{ + script: "676f459162ceeb248b071ec157d9e4f6fd26fdbe50", + }, { + ps: Params{{stringT, "transfer"}}, + script: "087472616e73666572676f459162ceeb248b071ec157d9e4f6fd26fdbe50", + }, { + ps: Params{{numberT, 42}}, + script: "023432676f459162ceeb248b071ec157d9e4f6fd26fdbe50", + }, { + ps: Params{{stringT, "a"}, {arrayT, []Param{}}}, + script: "00c10161676f459162ceeb248b071ec157d9e4f6fd26fdbe50", + }, { + ps: Params{{stringT, "a"}, {arrayT, []Param{{funcParamT, FuncParam{ByteArray, Param{stringT, "50befd26fdf6e4d957c11e078b24ebce6291456f"}}}}}}, + script: "1450befd26fdf6e4d957c11e078b24ebce6291456f51c10161676f459162ceeb248b071ec157d9e4f6fd26fdbe50", + }, { + ps: Params{{stringT, "a"}, {arrayT, []Param{{funcParamT, FuncParam{Signature, Param{stringT, "4edf5005771de04619235d5a4c7a9a11bb78e008541f1da7725f654c33380a3c87e2959a025da706d7255cb3a3fa07ebe9c6559d0d9e6213c68049168eb1056f"}}}}}}, + script: "404edf5005771de04619235d5a4c7a9a11bb78e008541f1da7725f654c33380a3c87e2959a025da706d7255cb3a3fa07ebe9c6559d0d9e6213c68049168eb1056f51c10161676f459162ceeb248b071ec157d9e4f6fd26fdbe50", + }, { + ps: Params{{stringT, "a"}, {arrayT, []Param{{funcParamT, FuncParam{String, Param{stringT, "50befd26fdf6e4d957c11e078b24ebce6291456f"}}}}}}, + script: "283530626566643236666466366534643935376331316530373862323465626365363239313435366651c10161676f459162ceeb248b071ec157d9e4f6fd26fdbe50", + }, { + ps: Params{{stringT, "a"}, {arrayT, []Param{{funcParamT, FuncParam{Hash160, Param{stringT, "50befd26fdf6e4d957c11e078b24ebce6291456f"}}}}}}, + script: "146f459162ceeb248b071ec157d9e4f6fd26fdbe5051c10161676f459162ceeb248b071ec157d9e4f6fd26fdbe50", + }, { + ps: Params{{stringT, "a"}, {arrayT, []Param{{funcParamT, FuncParam{Hash256, Param{stringT, "602c79718b16e442de58778e148d0b1084e3b2dffd5de6b7b16cee7969282de7"}}}}}}, + script: "20e72d286979ee6cb1b7e65dfddfb2e384100b8d148e7758de42e4168b71792c6051c10161676f459162ceeb248b071ec157d9e4f6fd26fdbe50", + }, { + ps: Params{{stringT, "a"}, {arrayT, []Param{{funcParamT, FuncParam{PublicKey, Param{stringT, "03c089d7122b840a4935234e82e26ae5efd0c2acb627239dc9f207311337b6f2c1"}}}}}}, + script: "2103c089d7122b840a4935234e82e26ae5efd0c2acb627239dc9f207311337b6f2c151c10161676f459162ceeb248b071ec157d9e4f6fd26fdbe50", + }, { + ps: Params{{stringT, "a"}, {arrayT, []Param{{funcParamT, FuncParam{Integer, Param{numberT, 42}}}}}}, + script: "012a51c10161676f459162ceeb248b071ec157d9e4f6fd26fdbe50", + }, { + ps: Params{{stringT, "a"}, {arrayT, []Param{{funcParamT, FuncParam{Boolean, Param{stringT, "true"}}}}}}, + script: "5151c10161676f459162ceeb248b071ec157d9e4f6fd26fdbe50", + }, { + ps: Params{{stringT, "a"}, {arrayT, []Param{{funcParamT, FuncParam{Boolean, Param{stringT, "false"}}}}}}, + script: "0051c10161676f459162ceeb248b071ec157d9e4f6fd26fdbe50", + }} + for _, ps := range paramScripts { + script, err := CreateFunctionInvocationScript(contract, ps.ps) + assert.Nil(t, err) + assert.Equal(t, ps.script, hex.EncodeToString(script)) + } +} + +func TestInvocationScriptCreationBad(t *testing.T) { + contract := util.Uint160{} + + var testParams = []Params{ + Params{{numberT, "qwerty"}}, + Params{{arrayT, 42}}, + Params{{arrayT, []Param{{numberT, 42}}}}, + Params{{arrayT, []Param{{funcParamT, FuncParam{ByteArray, Param{stringT, "qwerty"}}}}}}, + Params{{arrayT, []Param{{funcParamT, FuncParam{Signature, Param{stringT, "qwerty"}}}}}}, + Params{{arrayT, []Param{{funcParamT, FuncParam{String, Param{numberT, 42}}}}}}, + Params{{arrayT, []Param{{funcParamT, FuncParam{Hash160, Param{stringT, "qwerty"}}}}}}, + Params{{arrayT, []Param{{funcParamT, FuncParam{Hash256, Param{stringT, "qwerty"}}}}}}, + Params{{arrayT, []Param{{funcParamT, FuncParam{PublicKey, Param{numberT, 42}}}}}}, + Params{{arrayT, []Param{{funcParamT, FuncParam{PublicKey, Param{stringT, "qwerty"}}}}}}, + Params{{arrayT, []Param{{funcParamT, FuncParam{Integer, Param{stringT, "qwerty"}}}}}}, + Params{{arrayT, []Param{{funcParamT, FuncParam{Boolean, Param{numberT, 42}}}}}}, + Params{{arrayT, []Param{{funcParamT, FuncParam{Boolean, Param{stringT, "qwerty"}}}}}}, + Params{{arrayT, []Param{{funcParamT, FuncParam{Unknown, Param{}}}}}}, + } + for _, ps := range testParams { + _, err := CreateFunctionInvocationScript(contract, ps) + assert.NotNil(t, err) + } +} diff --git a/pkg/smartcontract/param_context.go b/pkg/smartcontract/param_context.go index 98d32a1f9..ad509f17d 100644 --- a/pkg/smartcontract/param_context.go +++ b/pkg/smartcontract/param_context.go @@ -1,6 +1,14 @@ package smartcontract import ( + "encoding/hex" + "errors" + "strconv" + "strings" + "unicode/utf8" + + "github.com/CityOfZion/neo-go/pkg/crypto" + "github.com/CityOfZion/neo-go/pkg/crypto/keys" "github.com/CityOfZion/neo-go/pkg/io" "github.com/CityOfZion/neo-go/pkg/util" ) @@ -89,6 +97,203 @@ func NewParameter(t ParamType) Parameter { } } +// parseParamType is a user-friendly string to ParamType converter, it's +// case-insensitive and makes the following conversions: +// signature -> SignatureType +// bool -> BoolType +// int -> IntegerType +// hash160 -> Hash160Type +// hash256 -> Hash256Type +// bytes -> ByteArrayType +// key -> PublicKeyType +// string -> StringType +// anything else generates an error. +func parseParamType(typ string) (ParamType, error) { + switch strings.ToLower(typ) { + case "signature": + return SignatureType, nil + case "bool": + return BoolType, nil + case "int": + return IntegerType, nil + case "hash160": + return Hash160Type, nil + case "hash256": + return Hash256Type, nil + case "bytes": + return ByteArrayType, nil + case "key": + return PublicKeyType, nil + case "string": + return StringType, nil + default: + // We deliberately don't support array here. + return 0, errors.New("wrong or unsupported parameter type") + } +} + +// adjustValToType is a value type-checker and converter. +func adjustValToType(typ ParamType, val string) (interface{}, error) { + switch typ { + case SignatureType: + b, err := hex.DecodeString(val) + if err != nil { + return nil, err + } + if len(b) != 64 { + return nil, errors.New("not a signature") + } + return val, nil + case BoolType: + switch val { + case "true": + return true, nil + case "false": + return false, nil + default: + return nil, errors.New("invalid boolean value") + } + case IntegerType: + return strconv.Atoi(val) + case Hash160Type: + u, err := crypto.Uint160DecodeAddress(val) + if err == nil { + return hex.EncodeToString(u.Bytes()), nil + } + b, err := hex.DecodeString(val) + if err != nil { + return nil, err + } + if len(b) != 20 { + return nil, errors.New("not a hash160") + } + return val, nil + case Hash256Type: + b, err := hex.DecodeString(val) + if err != nil { + return nil, err + } + if len(b) != 32 { + return nil, errors.New("not a hash256") + } + return val, nil + case ByteArrayType: + _, err := hex.DecodeString(val) + if err != nil { + return nil, err + } + return val, nil + case PublicKeyType: + _, err := keys.NewPublicKeyFromString(val) + if err != nil { + return nil, err + } + return val, nil + case StringType: + return val, nil + default: + return nil, errors.New("unsupported parameter type") + } +} + +// inferParamType tries to infer the value type from its contents. It returns +// IntegerType for anything that looks like decimal integer (can be converted +// with strconv.Atoi), BoolType for true and false values, Hash160Type for +// addresses and hex strings encoding 20 bytes long values, PublicKeyType for +// valid hex-encoded public keys, Hash256Type for hex-encoded 32 bytes values, +// SignatureType for hex-encoded 64 bytes values, ByteArrayType for any other +// valid hex-encoded values and StringType for anything else. +func inferParamType(val string) ParamType { + var err error + + _, err = strconv.Atoi(val) + if err == nil { + return IntegerType + } + + if val == "true" || val == "false" { + return BoolType + } + + _, err = crypto.Uint160DecodeAddress(val) + if err == nil { + return Hash160Type + } + + _, err = keys.NewPublicKeyFromString(val) + if err == nil { + return PublicKeyType + } + + unhexed, err := hex.DecodeString(val) + if err == nil { + switch len(unhexed) { + case 20: + return Hash160Type + case 32: + return Hash256Type + case 64: + return SignatureType + default: + return ByteArrayType + } + } + // Anything can be a string. + return StringType +} + +// NewParameterFromString returns a new Parameter initialized from the given +// string in neo-go-specific format. It is intended to be used in user-facing +// interfaces and has some heuristics in it to simplify parameter passing. Exact +// syntax is documented in the cli documentation. +func NewParameterFromString(in string) (*Parameter, error) { + var ( + char rune + val string + err error + r *strings.Reader + buf strings.Builder + escaped bool + hadType bool + res = &Parameter{} + ) + r = strings.NewReader(in) + for char, _, err = r.ReadRune(); err == nil && char != utf8.RuneError; char, _, err = r.ReadRune() { + if char == '\\' && !escaped { + escaped = true + continue + } + if char == ':' && !escaped && !hadType { + typStr := buf.String() + res.Type, err = parseParamType(typStr) + if err != nil { + return nil, err + } + buf.Reset() + hadType = true + continue + } + escaped = false + // We don't care about length and it never fails. + _, _ = buf.WriteRune(char) + } + if char == utf8.RuneError { + return nil, errors.New("bad UTF-8 string") + } + // The only other error `ReadRune` returns is io.EOF, which is fine and + // expected, so we don't check err here. + + val = buf.String() + if !hadType { + res.Type = inferParamType(val) + } + res.Value, err = adjustValToType(res.Type, val) + if err != nil { + return nil, err + } + return res, nil +} + // ContextItem represents a transaction context item. type ContextItem struct { Script util.Uint160 diff --git a/pkg/smartcontract/param_context_test.go b/pkg/smartcontract/param_context_test.go new file mode 100644 index 000000000..1f95d3128 --- /dev/null +++ b/pkg/smartcontract/param_context_test.go @@ -0,0 +1,343 @@ +package smartcontract + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestParseParamType(t *testing.T) { + var inouts = []struct { + in string + out ParamType + err bool + }{{ + in: "signature", + out: SignatureType, + }, { + in: "Signature", + out: SignatureType, + }, { + in: "SiGnAtUrE", + out: SignatureType, + }, { + in: "bool", + out: BoolType, + }, { + in: "int", + out: IntegerType, + }, { + in: "hash160", + out: Hash160Type, + }, { + in: "hash256", + out: Hash256Type, + }, { + in: "bytes", + out: ByteArrayType, + }, { + in: "key", + out: PublicKeyType, + }, { + in: "string", + out: StringType, + }, { + in: "array", + err: true, + }, { + in: "qwerty", + err: true, + }} + for _, inout := range inouts { + out, err := parseParamType(inout.in) + if inout.err { + assert.NotNil(t, err, "should error on '%s' input", inout.in) + } else { + assert.Nil(t, err, "shouldn't error on '%s' input", inout.in) + assert.Equal(t, inout.out, out, "bad output for '%s' input", inout.in) + } + } +} + +func TestInferParamType(t *testing.T) { + var inouts = []struct { + in string + out ParamType + }{{ + in: "42", + out: IntegerType, + }, { + in: "-42", + out: IntegerType, + }, { + in: "0", + out: IntegerType, + }, { + in: "2e10", + out: ByteArrayType, + }, { + in: "true", + out: BoolType, + }, { + in: "false", + out: BoolType, + }, { + in: "truee", + out: StringType, + }, { + in: "AK2nJJpJr6o664CWJKi1QRXjqeic2zRp8y", + out: Hash160Type, + }, { + in: "ZK2nJJpJr6o664CWJKi1QRXjqeic2zRp8y", + out: StringType, + }, { + in: "50befd26fdf6e4d957c11e078b24ebce6291456f", + out: Hash160Type, + }, { + in: "03b209fd4f53a7170ea4444e0cb0a6bb6a53c2bd016926989cf85f9b0fba17a70c", + out: PublicKeyType, + }, { + in: "30b209fd4f53a7170ea4444e0cb0a6bb6a53c2bd016926989cf85f9b0fba17a70c", + out: ByteArrayType, + }, { + in: "602c79718b16e442de58778e148d0b1084e3b2dffd5de6b7b16cee7969282de7", + out: Hash256Type, + }, { + in: "602c79718b16e442de58778e148d0b1084e3b2dffd5de6b7b16cee7969282de7da", + out: ByteArrayType, + }, { + in: "602c79718b16e442de58778e148d0b1084e3b2dffd5de6b7b16cee7969282de7c56f33fc6ecfcd0c225c4ab356fee59390af8560be0e930faebe74a6daff7c9b", + out: SignatureType, + }, { + in: "qwerty", + out: StringType, + }, { + in: "ab", + out: ByteArrayType, + }, { + in: "az", + out: StringType, + }, { + in: "bad", + out: StringType, + }, { + in: "фыва", + out: StringType, + }, { + in: "dead", + out: ByteArrayType, + }} + for _, inout := range inouts { + out := inferParamType(inout.in) + assert.Equal(t, inout.out, out, "bad output for '%s' input", inout.in) + } +} + +func TestAdjustValToType(t *testing.T) { + var inouts = []struct { + typ ParamType + val string + out interface{} + err bool + }{{ + typ: SignatureType, + val: "602c79718b16e442de58778e148d0b1084e3b2dffd5de6b7b16cee7969282de7c56f33fc6ecfcd0c225c4ab356fee59390af8560be0e930faebe74a6daff7c9b", + out: "602c79718b16e442de58778e148d0b1084e3b2dffd5de6b7b16cee7969282de7c56f33fc6ecfcd0c225c4ab356fee59390af8560be0e930faebe74a6daff7c9b", + }, { + typ: SignatureType, + val: "602c79718b16e442de58778e148d0b1084e3b2dffd5de6b7b16cee7969282de7c56f33fc6ecfcd0c225c4ab356fee59390af8560be0e930faebe74a6daff7c", + err: true, + }, { + typ: SignatureType, + val: "qwerty", + err: true, + }, { + typ: BoolType, + val: "false", + out: false, + }, { + typ: BoolType, + val: "true", + out: true, + }, { + typ: BoolType, + val: "qwerty", + err: true, + }, { + typ: BoolType, + val: "42", + err: true, + }, { + typ: BoolType, + val: "0", + err: true, + }, { + typ: IntegerType, + val: "0", + out: 0, + }, { + typ: IntegerType, + val: "42", + out: 42, + }, { + typ: IntegerType, + val: "-42", + out: -42, + }, { + typ: IntegerType, + val: "q", + err: true, + }, { + typ: Hash160Type, + val: "AK2nJJpJr6o664CWJKi1QRXjqeic2zRp8y", + out: "23ba2703c53263e8d6e522dc32203339dcd8eee9", + }, { + typ: Hash160Type, + val: "50befd26fdf6e4d957c11e078b24ebce6291456f", + out: "50befd26fdf6e4d957c11e078b24ebce6291456f", + }, { + typ: Hash160Type, + val: "befd26fdf6e4d957c11e078b24ebce6291456f", + err: true, + }, { + typ: Hash160Type, + val: "q", + err: true, + }, { + typ: Hash256Type, + val: "602c79718b16e442de58778e148d0b1084e3b2dffd5de6b7b16cee7969282de7", + out: "602c79718b16e442de58778e148d0b1084e3b2dffd5de6b7b16cee7969282de7", + }, { + typ: Hash256Type, + val: "602c79718b16e442de58778e148d0b1084e3b2dffd5de6b7b16cee7969282d", + err: true, + }, { + typ: Hash256Type, + val: "q", + err: true, + }, { + typ: ByteArrayType, + val: "602c79718b16e442de58778e148d0b1084e3b2dffd5de6b7b16cee7969282d", + out: "602c79718b16e442de58778e148d0b1084e3b2dffd5de6b7b16cee7969282d", + }, { + typ: ByteArrayType, + val: "602c79718b16e442de58778e148d0b1084e3b2dffd5de6b7b16cee7969282de7", + out: "602c79718b16e442de58778e148d0b1084e3b2dffd5de6b7b16cee7969282de7", + }, { + typ: ByteArrayType, + val: "50befd26fdf6e4d957c11e078b24ebce6291456f", + out: "50befd26fdf6e4d957c11e078b24ebce6291456f", + }, { + typ: ByteArrayType, + val: "AK2nJJpJr6o664CWJKi1QRXjqeic2zRp8y", + err: true, + }, { + typ: ByteArrayType, + val: "q", + err: true, + }, { + typ: ByteArrayType, + val: "ab", + out: "ab", + }, { + typ: PublicKeyType, + val: "03b209fd4f53a7170ea4444e0cb0a6bb6a53c2bd016926989cf85f9b0fba17a70c", + out: "03b209fd4f53a7170ea4444e0cb0a6bb6a53c2bd016926989cf85f9b0fba17a70c", + }, { + typ: PublicKeyType, + val: "01b209fd4f53a7170ea4444e0cb0a6bb6a53c2bd016926989cf85f9b0fba17a70c", + err: true, + }, { + typ: PublicKeyType, + val: "q", + err: true, + }, { + typ: StringType, + val: "q", + out: "q", + }, { + typ: StringType, + val: "dead", + out: "dead", + }, { + typ: StringType, + val: "йцукен", + out: "йцукен", + }, { + typ: ArrayType, + val: "", + err: true, + }} + + for _, inout := range inouts { + out, err := adjustValToType(inout.typ, inout.val) + if inout.err { + assert.NotNil(t, err, "should error on '%s/%s' input", inout.typ, inout.val) + } else { + assert.Nil(t, err, "shouldn't error on '%s/%s' input", inout.typ, inout.val) + assert.Equal(t, inout.out, out, "bad output for '%s/%s' input", inout.typ, inout.val) + } + } +} + +func TestNewParameterFromString(t *testing.T) { + var inouts = []struct { + in string + out Parameter + err bool + }{{ + in: "qwerty", + out: Parameter{StringType, "qwerty"}, + }, { + in: "42", + out: Parameter{IntegerType, 42}, + }, { + in: "Hello, 世界", + out: Parameter{StringType, "Hello, 世界"}, + }, { + in: `\4\2`, + out: Parameter{IntegerType, 42}, + }, { + in: `\\4\2`, + out: Parameter{StringType, `\42`}, + }, { + in: `\\\4\2`, + out: Parameter{StringType, `\42`}, + }, { + in: "int:42", + out: Parameter{IntegerType, 42}, + }, { + in: "true", + out: Parameter{BoolType, true}, + }, { + in: "string:true", + out: Parameter{StringType, "true"}, + }, { + in: "\xfe\xff", + err: true, + }, { + in: `string\:true`, + out: Parameter{StringType, "string:true"}, + }, { + in: "string:true:true", + out: Parameter{StringType, "true:true"}, + }, { + in: `string\\:true`, + err: true, + }, { + in: `qwerty:asdf`, + err: true, + }, { + in: `bool:asdf`, + err: true, + }} + for _, inout := range inouts { + out, err := NewParameterFromString(inout.in) + if inout.err { + assert.NotNil(t, err, "should error on '%s' input", inout.in) + } else { + assert.Nil(t, err, "shouldn't error on '%s' input", inout.in) + assert.Equal(t, inout.out, *out, "bad output for '%s' input", inout.in) + } + } +} diff --git a/pkg/vm/output.go b/pkg/vm/output.go index 14e83103b..46c27d3f9 100644 --- a/pkg/vm/output.go +++ b/pkg/vm/output.go @@ -9,13 +9,34 @@ type stackItem struct { Type string `json:"type"` } +func appendToItems(items *[]stackItem, val StackItem, seen map[StackItem]bool) { + if arr, ok := val.Value().([]StackItem); ok { + if seen[val] { + return + } + seen[val] = true + intItems := make([]stackItem, 0, len(arr)) + for _, v := range arr { + appendToItems(&intItems, v, seen) + } + *items = append(*items, stackItem{ + Value: intItems, + Type: val.String(), + }) + + } else { + *items = append(*items, stackItem{ + Value: val, + Type: val.String(), + }) + } +} + func stackToArray(s *Stack) []stackItem { items := make([]stackItem, 0, s.Len()) - s.Iter(func(e *Element) { - items = append(items, stackItem{ - Value: e.value, - Type: e.value.String(), - }) + seen := make(map[StackItem]bool) + s.IterBack(func(e *Element) { + appendToItems(&items, e.value, seen) }) return items } diff --git a/pkg/vm/stack.go b/pkg/vm/stack.go index 23a4b2247..cc8dabfb2 100644 --- a/pkg/vm/stack.go +++ b/pkg/vm/stack.go @@ -360,6 +360,17 @@ func (s *Stack) Iter(f func(*Element)) { } } +// IterBack iterates over all the elements of the stack, starting from the bottom +// of the stack. +// s.IterBack(func(elem *Element) { +// // do something with the element. +// }) +func (s *Stack) IterBack(f func(*Element)) { + for e := s.Back(); e != nil; e = e.Prev() { + f(e) + } +} + // popSigElements pops keys or signatures from the stack as needed for // CHECKMULTISIG. func (s *Stack) popSigElements() ([][]byte, error) { diff --git a/pkg/vm/stack_test.go b/pkg/vm/stack_test.go index ffc0f4c5d..c14a84414 100644 --- a/pkg/vm/stack_test.go +++ b/pkg/vm/stack_test.go @@ -154,19 +154,46 @@ func TestIterAfterRemove(t *testing.T) { func TestIteration(t *testing.T) { var ( + n = 10 s = NewStack("test") - elems = makeElements(10) + elems = makeElements(n) ) for _, elem := range elems { s.Push(elem) } assert.Equal(t, len(elems), s.Len()) - i := 0 + iteratedElems := make([]*Element, 0) + s.Iter(func(elem *Element) { - i++ + iteratedElems = append(iteratedElems, elem) }) - assert.Equal(t, len(elems), i) + // Top to bottom order of iteration. + poppedElems := make([]*Element, 0) + for elem := s.Pop(); elem != nil; elem = s.Pop() { + poppedElems = append(poppedElems, elem) + } + assert.Equal(t, poppedElems, iteratedElems) +} + +func TestBackIteration(t *testing.T) { + var ( + n = 10 + s = NewStack("test") + elems = makeElements(n) + ) + for _, elem := range elems { + s.Push(elem) + } + assert.Equal(t, len(elems), s.Len()) + + iteratedElems := make([]*Element, 0) + + s.IterBack(func(elem *Element) { + iteratedElems = append(iteratedElems, elem) + }) + // Bottom to the top order of iteration. + assert.Equal(t, elems, iteratedElems) } func TestPushVal(t *testing.T) {