From 7331127556e774ccfea79c4b9b0f2bb2d5af271f Mon Sep 17 00:00:00 2001 From: Evgenii Stratonikov Date: Thu, 21 Nov 2019 16:42:51 +0300 Subject: [PATCH 1/6] rpc: make parameter type an enum Signed-off-by: Evgenii Stratonikov --- pkg/rpc/param.go | 57 +++++++++++++++++++++++++--- pkg/rpc/params.go | 69 +++------------------------------- pkg/rpc/server.go | 96 +++++++++++++++++++++++------------------------ 3 files changed, 105 insertions(+), 117 deletions(-) diff --git a/pkg/rpc/param.go b/pkg/rpc/param.go index d86b2ef2f..52360830d 100644 --- a/pkg/rpc/param.go +++ b/pkg/rpc/param.go @@ -1,7 +1,11 @@ package rpc import ( + "encoding/json" "fmt" + + "github.com/CityOfZion/neo-go/pkg/util" + "github.com/pkg/errors" ) type ( @@ -9,13 +13,56 @@ type ( // the server or to send to a server using // the client. Param struct { - StringVal string - IntVal int - Type string - RawValue interface{} + Type paramType + Value interface{} } + + paramType int +) + +const ( + defaultT paramType = iota + stringT + numberT ) func (p Param) String() string { - return fmt.Sprintf("%v", p.RawValue) + return fmt.Sprintf("%v", p.Value) +} + +// GetString returns string value of the parameter. +func (p Param) GetString() string { return p.Value.(string) } + +// GetInt returns int value of te parameter. +func (p Param) GetInt() int { return p.Value.(int) } + +// GetUint256 returns Uint256 value of the parameter. +func (p Param) GetUint256() (util.Uint256, error) { + s, ok := p.Value.(string) + if !ok { + return util.Uint256{}, errors.New("must be a string") + } + + return util.Uint256DecodeReverseString(s) +} + +// UnmarshalJSON implements json.Unmarshaler interface. +func (p *Param) UnmarshalJSON(data []byte) error { + var s string + if err := json.Unmarshal(data, &s); err == nil { + p.Type = stringT + p.Value = s + + return nil + } + + var num float64 + if err := json.Unmarshal(data, &num); err == nil { + p.Type = numberT + p.Value = int(num) + + return nil + } + + return errors.New("unknown type") } diff --git a/pkg/rpc/params.go b/pkg/rpc/params.go index bef402f0a..1affcf267 100644 --- a/pkg/rpc/params.go +++ b/pkg/rpc/params.go @@ -1,49 +1,13 @@ package rpc -import ( - "encoding/json" -) - type ( // Params represents the JSON-RPC params. Params []Param ) -// UnmarshalJSON implements the Unmarshaller -// interface. -func (p *Params) UnmarshalJSON(data []byte) error { - var params []interface{} - - err := json.Unmarshal(data, ¶ms) - if err != nil { - return err - } - - for i := 0; i < len(params); i++ { - param := Param{ - RawValue: params[i], - } - - switch val := params[i].(type) { - case string: - param.StringVal = val - param.Type = "string" - - case float64: - newVal, _ := params[i].(float64) - param.IntVal = int(newVal) - param.Type = "number" - } - - *p = append(*p, param) - } - - return nil -} - -// ValueAt returns the param struct for the given +// Value returns the param struct for the given // index if it exists. -func (p Params) ValueAt(index int) (*Param, bool) { +func (p Params) Value(index int) (*Param, bool) { if len(p) > index { return &p[index], true } @@ -51,33 +15,12 @@ func (p Params) ValueAt(index int) (*Param, bool) { return nil, false } -// ValueAtAndType returns the param struct at the given index if it +// ValueWithType returns the param struct at the given index if it // exists and matches the given type. -func (p Params) ValueAtAndType(index int, valueType string) (*Param, bool) { - if len(p) > index && valueType == p[index].Type { - return &p[index], true +func (p Params) ValueWithType(index int, valType paramType) (*Param, bool) { + if val, ok := p.Value(index); ok && val.Type == valType { + return val, true } return nil, false } - -// Value returns the param struct for the given -// index if it exists. -func (p Params) Value(index int) (*Param, error) { - if len(p) <= index { - return nil, errInvalidParams - } - return &p[index], nil -} - -// ValueWithType returns the param struct at the given index if it -// exists and matches the given type. -func (p Params) ValueWithType(index int, valType string) (*Param, error) { - val, err := p.Value(index) - if err != nil { - return nil, err - } else if val.Type != valType { - return nil, errInvalidParams - } - return &p[index], nil -} diff --git a/pkg/rpc/server.go b/pkg/rpc/server.go index 8b521ed36..c9e730321 100644 --- a/pkg/rpc/server.go +++ b/pkg/rpc/server.go @@ -30,11 +30,9 @@ type ( } ) -var ( - invalidBlockHeightError = func(index int, height int) error { - return errors.Errorf("Param at index %d should be greater than or equal to 0 and less then or equal to current block height, got: %d", index, height) - } -) +var invalidBlockHeightError = func(index int, height int) error { + return errors.Errorf("Param at index %d should be greater than or equal to 0 and less then or equal to current block height, got: %d", index, height) +} // NewServer creates a new Server struct. func NewServer(chain core.Blockchainer, conf config.RPCConfig, coreServer *network.Server) Server { @@ -123,27 +121,28 @@ Methods: getbestblockCalled.Inc() var hash util.Uint256 - param, err := reqParams.Value(0) - if err != nil { - resultsErr = err + param, ok := reqParams.Value(0) + if !ok { + resultsErr = errInvalidParams break Methods } switch param.Type { - case "string": - hash, err = util.Uint256DecodeReverseString(param.StringVal) + case stringT: + var err error + hash, err = param.GetUint256() if err != nil { resultsErr = errInvalidParams break Methods } - case "number": + case numberT: if !s.validBlockHeight(param) { resultsErr = errInvalidParams break Methods } - hash = s.chain.GetHeaderHash(param.IntVal) - case "default": + hash = s.chain.GetHeaderHash(param.GetInt()) + case defaultT: resultsErr = errInvalidParams break Methods } @@ -161,16 +160,16 @@ Methods: case "getblockhash": getblockHashCalled.Inc() - param, err := reqParams.ValueWithType(0, "number") - if err != nil { - resultsErr = err + param, ok := reqParams.ValueWithType(0, numberT) + if !ok { + resultsErr = errInvalidParams break Methods } else if !s.validBlockHeight(param) { - resultsErr = invalidBlockHeightError(0, param.IntVal) + resultsErr = invalidBlockHeightError(0, param.GetInt()) break Methods } - results = s.chain.GetHeaderHash(param.IntVal) + results = s.chain.GetHeaderHash(param.GetInt()) case "getconnectioncount": getconnectioncountCalled.Inc() @@ -203,22 +202,22 @@ Methods: case "validateaddress": validateaddressCalled.Inc() - param, err := reqParams.Value(0) - if err != nil { - resultsErr = err + param, ok := reqParams.Value(0) + if !ok { + resultsErr = errInvalidParams break Methods } - results = wrappers.ValidateAddress(param.RawValue) + results = wrappers.ValidateAddress(param.Value) case "getassetstate": getassetstateCalled.Inc() - param, err := reqParams.ValueWithType(0, "string") - if err != nil { - resultsErr = err + param, ok := reqParams.ValueWithType(0, stringT) + if !ok { + resultsErr = errInvalidParams break Methods } - paramAssetID, err := util.Uint256DecodeReverseString(param.StringVal) + paramAssetID, err := param.GetUint256() if err != nil { resultsErr = errInvalidParams break @@ -266,14 +265,13 @@ func (s *Server) getrawtransaction(reqParams Params) (interface{}, error) { var resultsErr error var results interface{} - param0, err := reqParams.ValueWithType(0, "string") - if err != nil { - resultsErr = err - } else if txHash, err := util.Uint256DecodeReverseString(param0.StringVal); err != nil { + if param0, ok := reqParams.Value(0); !ok { + return nil, errInvalidParams + } else if txHash, err := param0.GetUint256(); err != nil { resultsErr = errInvalidParams } else if tx, height, err := s.chain.GetTransaction(txHash); err != nil { err = errors.Wrapf(err, "Invalid transaction hash: %s", txHash) - resultsErr = NewInvalidParamsError(err.Error(), err) + return nil, NewInvalidParamsError(err.Error(), err) } else if len(reqParams) >= 2 { _header := s.chain.GetHeaderHash(int(height)) header, err := s.chain.GetHeader(_header) @@ -281,8 +279,8 @@ func (s *Server) getrawtransaction(reqParams Params) (interface{}, error) { resultsErr = NewInvalidParamsError(err.Error(), err) } - param1, _ := reqParams.ValueAt(1) - switch v := param1.RawValue.(type) { + param1, _ := reqParams.Value(1) + switch v := param1.Value.(type) { case int, float64, bool, string: if v == 0 || v == "0" || v == 0.0 || v == false || v == "false" { @@ -305,14 +303,14 @@ func (s *Server) getAccountState(reqParams Params, unspents bool) (interface{}, var resultsErr error var results interface{} - param, err := reqParams.ValueWithType(0, "string") - if err != nil { - resultsErr = err - } else if scriptHash, err := crypto.Uint160DecodeAddress(param.StringVal); err != nil { - resultsErr = errInvalidParams + param, ok := reqParams.ValueWithType(0, stringT) + if !ok { + return nil, errInvalidParams + } else if scriptHash, err := crypto.Uint160DecodeAddress(param.GetString()); err != nil { + return nil, errInvalidParams } else if as := s.chain.GetAccountState(scriptHash); as != nil { if unspents { - results = wrappers.NewUnspents(as, s.chain, param.StringVal) + results = wrappers.NewUnspents(as, s.chain, param.GetString()) } else { results = wrappers.NewAccountState(as) } @@ -324,11 +322,11 @@ func (s *Server) getAccountState(reqParams Params, unspents bool) (interface{}, // invokescript implements the `invokescript` RPC call. func (s *Server) invokescript(reqParams Params) (interface{}, error) { - hexScript, err := reqParams.ValueWithType(0, "string") - if err != nil { - return nil, err + hexScript, ok := reqParams.ValueWithType(0, stringT) + if !ok { + return nil, errInvalidParams } - script, err := hex.DecodeString(hexScript.StringVal) + script, err := hex.DecodeString(hexScript.GetString()) if err != nil { return nil, err } @@ -338,7 +336,7 @@ func (s *Server) invokescript(reqParams Params) (interface{}, error) { result := &wrappers.InvokeResult{ State: vm.State(), GasConsumed: "0.1", - Script: hexScript.StringVal, + Script: hexScript.GetString(), Stack: vm.Estack(), } return result, nil @@ -348,10 +346,10 @@ func (s *Server) sendrawtransaction(reqParams Params) (interface{}, error) { var resultsErr error var results interface{} - param, err := reqParams.ValueWithType(0, "string") - if err != nil { - resultsErr = err - } else if byteTx, err := hex.DecodeString(param.StringVal); err != nil { + param, ok := reqParams.ValueWithType(0, stringT) + if !ok { + resultsErr = errInvalidParams + } else if byteTx, err := hex.DecodeString(param.GetString()); err != nil { resultsErr = errInvalidParams } else { r := io.NewBinReaderFromBuf(byteTx) @@ -387,5 +385,5 @@ func (s *Server) sendrawtransaction(reqParams Params) (interface{}, error) { } func (s Server) validBlockHeight(param *Param) bool { - return param.IntVal >= 0 && param.IntVal <= int(s.chain.BlockHeight()) + return param.GetInt() >= 0 && param.GetInt() <= int(s.chain.BlockHeight()) } From c8987eda32e8d80f70902211475ee3cd3a47274d Mon Sep 17 00:00:00 2001 From: Evgenii Stratonikov Date: Thu, 21 Nov 2019 17:42:02 +0300 Subject: [PATCH 2/6] rpc: add array param type and tests Signed-off-by: Evgenii Stratonikov --- pkg/rpc/param.go | 21 +++++++++++++++++++++ pkg/rpc/param_test.go | 42 ++++++++++++++++++++++++++++++++++++++++++ pkg/rpc/server.go | 20 ++++++++++---------- 3 files changed, 73 insertions(+), 10 deletions(-) create mode 100644 pkg/rpc/param_test.go diff --git a/pkg/rpc/param.go b/pkg/rpc/param.go index 52360830d..08f50615c 100644 --- a/pkg/rpc/param.go +++ b/pkg/rpc/param.go @@ -1,6 +1,7 @@ package rpc import ( + "encoding/hex" "encoding/json" "fmt" @@ -24,6 +25,7 @@ const ( defaultT paramType = iota stringT numberT + arrayT ) func (p Param) String() string { @@ -46,6 +48,17 @@ func (p Param) GetUint256() (util.Uint256, error) { return util.Uint256DecodeReverseString(s) } +// GetBytesHex returns []byte value of the parameter if +// it is a hex-encoded string. +func (p Param) GetBytesHex() ([]byte, error) { + s, ok := p.Value.(string) + if !ok { + return nil, errors.New("must be a string") + } + + return hex.DecodeString(s) +} + // UnmarshalJSON implements json.Unmarshaler interface. func (p *Param) UnmarshalJSON(data []byte) error { var s string @@ -64,5 +77,13 @@ func (p *Param) UnmarshalJSON(data []byte) error { return nil } + var ps []Param + if err := json.Unmarshal(data, &ps); err == nil { + p.Type = arrayT + p.Value = ps + + return nil + } + return errors.New("unknown type") } diff --git a/pkg/rpc/param_test.go b/pkg/rpc/param_test.go new file mode 100644 index 000000000..fca26196b --- /dev/null +++ b/pkg/rpc/param_test.go @@ -0,0 +1,42 @@ +package rpc + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestParam_UnmarshalJSON(t *testing.T) { + msg := `["str1", 123, ["str2", 3]]` + expected := Params{ + { + Type: stringT, + Value: "str1", + }, + { + Type: numberT, + Value: 123, + }, + { + Type: arrayT, + Value: []Param{ + { + Type: stringT, + Value: "str2", + }, + { + Type: numberT, + Value: 3, + }, + }, + }, + } + + var ps Params + require.NoError(t, json.Unmarshal([]byte(msg), &ps)) + require.Equal(t, expected, ps) + + msg = `[{"2": 3}]` + require.Error(t, json.Unmarshal([]byte(msg), &ps)) +} diff --git a/pkg/rpc/server.go b/pkg/rpc/server.go index c9e730321..507ae38f4 100644 --- a/pkg/rpc/server.go +++ b/pkg/rpc/server.go @@ -322,21 +322,22 @@ func (s *Server) getAccountState(reqParams Params, unspents bool) (interface{}, // invokescript implements the `invokescript` RPC call. func (s *Server) invokescript(reqParams Params) (interface{}, error) { - hexScript, ok := reqParams.ValueWithType(0, stringT) - if !ok { + if len(reqParams) < 1 { return nil, errInvalidParams } - script, err := hex.DecodeString(hexScript.GetString()) + + script, err := reqParams[0].GetBytesHex() if err != nil { - return nil, err + return nil, errInvalidParams } + vm, _ := s.chain.GetTestVM() vm.LoadScript(script) _ = vm.Run() result := &wrappers.InvokeResult{ State: vm.State(), GasConsumed: "0.1", - Script: hexScript.GetString(), + Script: reqParams[0].GetString(), Stack: vm.Estack(), } return result, nil @@ -346,11 +347,10 @@ func (s *Server) sendrawtransaction(reqParams Params) (interface{}, error) { var resultsErr error var results interface{} - param, ok := reqParams.ValueWithType(0, stringT) - if !ok { - resultsErr = errInvalidParams - } else if byteTx, err := hex.DecodeString(param.GetString()); err != nil { - resultsErr = errInvalidParams + if len(reqParams) < 1 { + return nil, errInvalidParams + } else if byteTx, err := reqParams[0].GetBytesHex(); err != nil { + return nil, errInvalidParams } else { r := io.NewBinReaderFromBuf(byteTx) tx := &transaction.Transaction{} From d5fa31cecd120fa08ca5d50ba78c0d929ea81401 Mon Sep 17 00:00:00 2001 From: Evgenii Stratonikov Date: Thu, 21 Nov 2019 18:05:18 +0300 Subject: [PATCH 3/6] rpc: trim spaces in tests once Signed-off-by: Evgenii Stratonikov --- pkg/rpc/server_test.go | 38 +++++++++++++++++++------------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/pkg/rpc/server_test.go b/pkg/rpc/server_test.go index e9cb66cfe..6c8f4e622 100644 --- a/pkg/rpc/server_test.go +++ b/pkg/rpc/server_test.go @@ -21,7 +21,7 @@ func TestRPC(t *testing.T) { body := doRPCCall(rpc, handler, t) checkErrResponse(t, body, false) var res StringResultResponse - err := json.Unmarshal(bytes.TrimSpace(body), &res) + err := json.Unmarshal(body, &res) assert.NoErrorf(t, err, "could not parse response: %s", body) assert.Equal(t, "0x"+chain.CurrentBlockHash().ReverseString(), res.Result) }) @@ -31,7 +31,7 @@ func TestRPC(t *testing.T) { body := doRPCCall(rpc, handler, t) checkErrResponse(t, body, false) var res GetBlockResponse - err := json.Unmarshal(bytes.TrimSpace(body), &res) + err := json.Unmarshal(body, &res) assert.NoErrorf(t, err, "could not parse response: %s", body) block, err := chain.GetBlock(chain.GetHeaderHash(1)) assert.NoErrorf(t, err, "could not get block") @@ -44,7 +44,7 @@ func TestRPC(t *testing.T) { body := doRPCCall(rpc, handler, t) checkErrResponse(t, body, false) var res IntResultResponse - err := json.Unmarshal(bytes.TrimSpace(body), &res) + err := json.Unmarshal(body, &res) assert.NoErrorf(t, err, "could not parse response: %s", body) assert.Equal(t, chain.BlockHeight()+1, uint32(res.Result)) }) @@ -54,7 +54,7 @@ func TestRPC(t *testing.T) { body := doRPCCall(rpc, handler, t) checkErrResponse(t, body, false) var res StringResultResponse - err := json.Unmarshal(bytes.TrimSpace(body), &res) + err := json.Unmarshal(body, &res) assert.NoErrorf(t, err, "could not parse response: %s", body) block, err := chain.GetBlock(chain.GetHeaderHash(1)) assert.NoErrorf(t, err, "could not get block") @@ -67,7 +67,7 @@ func TestRPC(t *testing.T) { body := doRPCCall(rpc, handler, t) checkErrResponse(t, body, false) var res IntResultResponse - err := json.Unmarshal(bytes.TrimSpace(body), &res) + err := json.Unmarshal(body, &res) assert.NoErrorf(t, err, "could not parse response: %s", body) assert.Equal(t, 0, res.Result) }) @@ -77,7 +77,7 @@ func TestRPC(t *testing.T) { body := doRPCCall(rpc, handler, t) checkErrResponse(t, body, false) var res GetVersionResponse - err := json.Unmarshal(bytes.TrimSpace(body), &res) + err := json.Unmarshal(body, &res) assert.NoErrorf(t, err, "could not parse response: %s", body) assert.Equal(t, "/NEO-GO:/", res.Result.UserAgent) }) @@ -87,7 +87,7 @@ func TestRPC(t *testing.T) { body := doRPCCall(rpc, handler, t) checkErrResponse(t, body, false) var res GetPeersResponse - err := json.Unmarshal(bytes.TrimSpace(body), &res) + err := json.Unmarshal(body, &res) assert.NoErrorf(t, err, "could not parse response: %s", body) assert.Equal(t, []int{}, res.Result.Bad) assert.Equal(t, []int{}, res.Result.Unconnected) @@ -99,7 +99,7 @@ func TestRPC(t *testing.T) { body := doRPCCall(rpc, handler, t) checkErrResponse(t, body, false) var res ValidateAddrResponse - err := json.Unmarshal(bytes.TrimSpace(body), &res) + err := json.Unmarshal(body, &res) assert.NoErrorf(t, err, "could not parse response: %s", body) assert.Equal(t, true, res.Result.IsValid) }) @@ -109,7 +109,7 @@ func TestRPC(t *testing.T) { body := doRPCCall(rpc, handler, t) checkErrResponse(t, body, false) var res ValidateAddrResponse - err := json.Unmarshal(bytes.TrimSpace(body), &res) + err := json.Unmarshal(body, &res) assert.NoErrorf(t, err, "could not parse response: %s", body) assert.Equal(t, false, res.Result.IsValid) }) @@ -119,7 +119,7 @@ func TestRPC(t *testing.T) { body := doRPCCall(rpc, handler, t) checkErrResponse(t, body, false) var res GetAssetResponse - err := json.Unmarshal(bytes.TrimSpace(body), &res) + err := json.Unmarshal(body, &res) assert.NoErrorf(t, err, "could not parse response: %s", body) assert.Equal(t, "00", res.Result.Owner) assert.Equal(t, "AWKECj9RD8rS8RPcpCgYVjk1DeYyHwxZm3", res.Result.Admin) @@ -130,7 +130,7 @@ func TestRPC(t *testing.T) { body := doRPCCall(rpc, handler, t) checkErrResponse(t, body, false) var res StringResultResponse - err := json.Unmarshal(bytes.TrimSpace(body), &res) + err := json.Unmarshal(body, &res) assert.NoErrorf(t, err, "could not parse response: %s", body) assert.Equal(t, "Invalid assetid", res.Result) }) @@ -140,7 +140,7 @@ func TestRPC(t *testing.T) { body := doRPCCall(rpc, handler, t) checkErrResponse(t, body, false) var res GetAccountStateResponse - err := json.Unmarshal(bytes.TrimSpace(body), &res) + err := json.Unmarshal(body, &res) assert.NoErrorf(t, err, "could not parse response: %s", body) assert.Equal(t, 1, len(res.Result.Balances)) assert.Equal(t, false, res.Result.Frozen) @@ -151,7 +151,7 @@ func TestRPC(t *testing.T) { body := doRPCCall(rpc, handler, t) checkErrResponse(t, body, false) var res GetUnspents - err := json.Unmarshal(bytes.TrimSpace(body), &res) + err := json.Unmarshal(body, &res) assert.NoErrorf(t, err, "could not parse response: %s", body) assert.Equal(t, 1, len(res.Result.Balance)) assert.Equal(t, 1, len(res.Result.Balance[0].Unspents)) @@ -162,7 +162,7 @@ func TestRPC(t *testing.T) { body := doRPCCall(rpc, handler, t) checkErrResponse(t, body, false) var res StringResultResponse - err := json.Unmarshal(bytes.TrimSpace(body), &res) + err := json.Unmarshal(body, &res) assert.NoErrorf(t, err, "could not parse response: %s", body) assert.Equal(t, "Invalid public account address", res.Result) }) @@ -172,7 +172,7 @@ func TestRPC(t *testing.T) { body := doRPCCall(rpc, handler, t) checkErrResponse(t, body, false) var res StringResultResponse - err := json.Unmarshal(bytes.TrimSpace(body), &res) + err := json.Unmarshal(body, &res) assert.NoErrorf(t, err, "could not parse response: %s", body) assert.Equal(t, "Invalid public account address", res.Result) }) @@ -184,7 +184,7 @@ func TestRPC(t *testing.T) { body := doRPCCall(rpc, handler, t) checkErrResponse(t, body, false) var res StringResultResponse - err := json.Unmarshal(bytes.TrimSpace(body), &res) + err := json.Unmarshal(body, &res) assert.NoErrorf(t, err, "could not parse response: %s", body) assert.Equal(t, "400000455b7b226c616e67223a227a682d434e222c226e616d65223a22e5b08fe89a81e882a1227d2c7b226c616e67223a22656e222c226e616d65223a22416e745368617265227d5d0000c16ff28623000000da1745e9b549bd0bfa1a569971c77eba30cd5a4b00000000", res.Result) }) @@ -194,7 +194,7 @@ func TestRPC(t *testing.T) { body := doRPCCall(rpc, handler, t) checkErrResponse(t, body, false) var res SendTXResponse - err := json.Unmarshal(bytes.TrimSpace(body), &res) + err := json.Unmarshal(body, &res) assert.NoErrorf(t, err, "could not parse response: %s", body) assert.Equal(t, true, res.Result) }) @@ -208,7 +208,7 @@ func TestRPC(t *testing.T) { func checkErrResponse(t *testing.T, body []byte, expectingFail bool) { var errresp ErrorResponse - err := json.Unmarshal(bytes.TrimSpace(body), &errresp) + err := json.Unmarshal(body, &errresp) assert.Nil(t, err) if expectingFail { assert.NotEqual(t, 0, errresp.Error.Code) @@ -227,5 +227,5 @@ func doRPCCall(rpcCall string, handler http.HandlerFunc, t *testing.T) []byte { resp := w.Result() body, err := ioutil.ReadAll(resp.Body) assert.NoErrorf(t, err, "could not read response from the request: %s", rpcCall) - return body + return bytes.TrimSpace(body) } From 3afcd784f0daeccb2cd36ba27b0027b63b8bb47c Mon Sep 17 00:00:00 2001 From: Evgenii Stratonikov Date: Thu, 21 Nov 2019 19:41:28 +0300 Subject: [PATCH 4/6] rpc: refactor tests Signed-off-by: Evgenii Stratonikov --- pkg/rpc/server_test.go | 519 +++++++++++++++++++++++++++-------------- 1 file changed, 350 insertions(+), 169 deletions(-) diff --git a/pkg/rpc/server_test.go b/pkg/rpc/server_test.go index 6c8f4e622..1e4e3ed6a 100644 --- a/pkg/rpc/server_test.go +++ b/pkg/rpc/server_test.go @@ -7,175 +7,335 @@ import ( "io/ioutil" "net/http" "net/http/httptest" + "reflect" "strings" "testing" + "github.com/CityOfZion/neo-go/pkg/core" + "github.com/CityOfZion/neo-go/pkg/rpc/wrappers" + "github.com/CityOfZion/neo-go/pkg/util" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) +type executor struct { + chain *core.Blockchain + handler http.HandlerFunc +} + +const ( + defaultJSONRPC = "2.0" + defaultID = 1 +) + +type rpcTestCase struct { + name string + params string + fail bool + result func(e *executor) interface{} + check func(t *testing.T, e *executor, result interface{}) +} + +var rpcTestCases = map[string][]rpcTestCase{ + "getaccountstate": { + { + name: "positive", + params: `["AZ81H31DMWzbSnFDLFkzh9vHwaDLayV7fU"]`, + result: func(e *executor) interface{} { return &GetAccountStateResponse{} }, + check: func(t *testing.T, e *executor, result interface{}) { + res, ok := result.(*GetAccountStateResponse) + require.True(t, ok) + assert.Equal(t, 1, len(res.Result.Balances)) + assert.Equal(t, false, res.Result.Frozen) + }, + }, + { + name: "negative", + params: `["AK2nJJpJr6o664CWJKi1QRXjqeic2zRp8y"]`, + result: func(e *executor) interface{} { return "Invalid public account address" }, + }, + { + name: "no params", + params: `[]`, + fail: true, + }, + { + name: "invalid address", + params: `["notabase58"]`, + fail: true, + }, + }, + "getassetstate": { + { + name: "positive", + params: `["602c79718b16e442de58778e148d0b1084e3b2dffd5de6b7b16cee7969282de7"]`, + result: func(e *executor) interface{} { return &GetAssetResponse{} }, + check: func(t *testing.T, e *executor, result interface{}) { + res, ok := result.(*GetAssetResponse) + require.True(t, ok) + assert.Equal(t, "00", res.Result.Owner) + assert.Equal(t, "AWKECj9RD8rS8RPcpCgYVjk1DeYyHwxZm3", res.Result.Admin) + }, + }, + { + name: "negative", + params: `["602c79718b16e442de58778e148d0b1084e3b2dffd5de6b7b16cee7969282de2"]`, + result: func(e *executor) interface{} { return "Invalid assetid" }, + }, + { + name: "no params", + params: `[]`, + fail: true, + }, + { + name: "invalid hash", + params: `["notahex"]`, + fail: true, + }, + }, + "getbestblockhash": { + { + params: "[]", + result: func(e *executor) interface{} { + return "0x" + e.chain.CurrentBlockHash().ReverseString() + }, + }, + { + params: "1", + fail: true, + }, + }, + "getblock": { + { + name: "positive", + params: "[1]", + result: func(e *executor) interface{} { return &GetBlockResponse{} }, + check: func(t *testing.T, e *executor, result interface{}) { + res, ok := result.(*GetBlockResponse) + require.True(t, ok) + + block, err := e.chain.GetBlock(e.chain.GetHeaderHash(1)) + require.NoErrorf(t, err, "could not get block") + + expectedHash := "0x" + block.Hash().ReverseString() + assert.Equal(t, expectedHash, res.Result.Hash) + }, + }, + { + name: "no params", + params: `[]`, + fail: true, + }, + { + name: "invalid height", + params: `[-1]`, + fail: true, + }, + { + name: "invalid hash", + params: `["notahex"]`, + fail: true, + }, + { + name: "missing hash", + params: `["` + util.Uint256{}.String() + `"]`, + fail: true, + }, + }, + "getblockcount": { + { + params: "[]", + result: func(e *executor) interface{} { return int(e.chain.BlockHeight() + 1) }, + }, + }, + "getblockhash": { + { + params: "[1]", + result: func(e *executor) interface{} { return "" }, + check: func(t *testing.T, e *executor, result interface{}) { + res, ok := result.(*StringResultResponse) + require.True(t, ok) + + block, err := e.chain.GetBlock(e.chain.GetHeaderHash(1)) + require.NoErrorf(t, err, "could not get block") + + expectedHash := "0x" + block.Hash().ReverseString() + assert.Equal(t, expectedHash, res.Result) + }, + }, + { + name: "string height", + params: `["first"]`, + fail: true, + }, + { + name: "invalid number height", + params: `[-2]`, + fail: true, + }, + }, + "getconnectioncount": { + { + params: "[]", + result: func(*executor) interface{} { return 0 }, + }, + }, + "getpeers": { + { + params: "[]", + result: func(*executor) interface{} { + return &GetPeersResponse{ + Jsonrpc: defaultJSONRPC, + Result: struct { + Unconnected []int `json:"unconnected"` + Connected []int `json:"connected"` + Bad []int `json:"bad"` + }{ + Unconnected: []int{}, + Connected: []int{}, + Bad: []int{}, + }, + ID: defaultID, + } + }, + }, + }, + "getrawtransaction": { + { + name: "no params", + params: `[]`, + fail: true, + }, + { + name: "invalid hash", + params: `["notahex"]`, + fail: true, + }, + { + name: "missing hash", + params: `["` + util.Uint256{}.String() + `"]`, + fail: true, + }, + }, + "getunspents": { + { + name: "positive", + params: `["AZ81H31DMWzbSnFDLFkzh9vHwaDLayV7fU"]`, + result: func(e *executor) interface{} { return &GetUnspents{} }, + check: func(t *testing.T, e *executor, result interface{}) { + res, ok := result.(*GetUnspents) + require.True(t, ok) + require.Equal(t, 1, len(res.Result.Balance)) + assert.Equal(t, 1, len(res.Result.Balance[0].Unspents)) + }, + }, + { + name: "negative", + params: `["AK2nJJpJr6o664CWJKi1QRXjqeic2zRp8y"]`, + result: func(e *executor) interface{} { return "Invalid public account address" }, + }, + }, + "getversion": { + { + params: "[]", + result: func(*executor) interface{} { return &GetVersionResponse{} }, + check: func(t *testing.T, e *executor, result interface{}) { + resp, ok := result.(*GetVersionResponse) + require.True(t, ok) + require.Equal(t, "/NEO-GO:/", resp.Result.UserAgent) + }, + }, + }, + "sendrawtransaction": { + { + name: "positive", + params: `["d1001b00046e616d6567d3d8602814a429a91afdbaa3914884a1c90c733101201cc9c05cefffe6cdd7b182816a9152ec218d2ec000000141403387ef7940a5764259621e655b3c621a6aafd869a611ad64adcc364d8dd1edf84e00a7f8b11b630a377eaef02791d1c289d711c08b7ad04ff0d6c9caca22cfe6232103cbb45da6072c14761c9da545749d9cfd863f860c351066d16df480602a2024c6ac"]`, + result: func(e *executor) interface{} { return &SendTXResponse{} }, + check: func(t *testing.T, e *executor, result interface{}) { + res, ok := result.(*SendTXResponse) + require.True(t, ok) + assert.True(t, res.Result) + }, + }, + { + name: "negative", + params: `["0274d792072617720636f6e7472616374207472616e73616374696f6e206465736372697074696f6e01949354ea0a8b57dfee1e257a1aedd1e0eea2e5837de145e8da9c0f101bfccc8e0100029b7cffdaa674beae0f930ebe6085af9093e5fe56b34a5c220ccdcf6efc336fc500a3e11100000000ea610aa6db39bd8c8556c9569d94b5e5a5d0ad199b7cffdaa674beae0f930ebe6085af9093e5fe56b34a5c220ccdcf6efc336fc5004f2418010000001cc9c05cefffe6cdd7b182816a9152ec218d2ec0014140dbd3cddac5cb2bd9bf6d93701f1a6f1c9dbe2d1b480c54628bbb2a4d536158c747a6af82698edf9f8af1cac3850bcb772bd9c8e4ac38f80704751cc4e0bd0e67232103cbb45da6072c14761c9da545749d9cfd863f860c351066d16df480602a2024c6ac"]`, + fail: true, + }, + { + name: "no params", + params: `[]`, + fail: true, + }, + { + name: "invalid string", + params: `["notahex"]`, + fail: true, + }, + { + name: "invalid tx", + params: `["0274d792072617720636f6e747261637"]`, + fail: true, + }, + }, + "validateaddress": { + { + name: "positive", + params: `["AQVh2pG732YvtNaxEGkQUei3YA4cvo7d2i"]`, + result: func(*executor) interface{} { return &ValidateAddrResponse{} }, + check: func(t *testing.T, e *executor, result interface{}) { + res, ok := result.(*ValidateAddrResponse) + require.True(t, ok) + assert.Equal(t, "AQVh2pG732YvtNaxEGkQUei3YA4cvo7d2i", res.Result.Address) + assert.True(t, res.Result.IsValid) + }, + }, + { + name: "negative", + params: "[1]", + result: func(*executor) interface{} { + return &ValidateAddrResponse{ + Jsonrpc: defaultJSONRPC, + Result: wrappers.ValidateAddressResponse{ + Address: float64(1), + IsValid: false, + }, + ID: defaultID, + } + }, + }, + }, +} + func TestRPC(t *testing.T) { chain, handler := initServerWithInMemoryChain(t) - t.Run("getbestblockhash", func(t *testing.T) { - rpc := `{"jsonrpc": "2.0", "id": 1, "method": "getbestblockhash", "params": []}` - body := doRPCCall(rpc, handler, t) - checkErrResponse(t, body, false) - var res StringResultResponse - err := json.Unmarshal(body, &res) - assert.NoErrorf(t, err, "could not parse response: %s", body) - assert.Equal(t, "0x"+chain.CurrentBlockHash().ReverseString(), res.Result) - }) + e := &executor{chain: chain, handler: handler} + for method, cases := range rpcTestCases { + t.Run(method, func(t *testing.T) { + rpc := `{"jsonrpc": "2.0", "id": 1, "method": "%s", "params": %s}` - t.Run("getblock", func(t *testing.T) { - rpc := `{"jsonrpc": "2.0", "id": 1, "method": "getblock", "params": [1]}` - body := doRPCCall(rpc, handler, t) - checkErrResponse(t, body, false) - var res GetBlockResponse - err := json.Unmarshal(body, &res) - assert.NoErrorf(t, err, "could not parse response: %s", body) - block, err := chain.GetBlock(chain.GetHeaderHash(1)) - assert.NoErrorf(t, err, "could not get block") - expectedHash := "0x" + block.Hash().ReverseString() - assert.Equal(t, expectedHash, res.Result.Hash) - }) + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + body := doRPCCall(fmt.Sprintf(rpc, method, tc.params), handler, t) + checkErrResponse(t, body, tc.fail) + if tc.fail { + return + } - t.Run("getblockcount", func(t *testing.T) { - rpc := `{"jsonrpc": "2.0", "id": 1, "method": "getblockcount", "params": []}` - body := doRPCCall(rpc, handler, t) - checkErrResponse(t, body, false) - var res IntResultResponse - err := json.Unmarshal(body, &res) - assert.NoErrorf(t, err, "could not parse response: %s", body) - assert.Equal(t, chain.BlockHeight()+1, uint32(res.Result)) - }) + expected, res := tc.getResultPair(e) + err := json.Unmarshal(body, res) + require.NoErrorf(t, err, "could not parse response: %s", body) - t.Run("getblockhash", func(t *testing.T) { - rpc := `{"jsonrpc": "2.0", "id": 1, "method": "getblockhash", "params": [1]}` - body := doRPCCall(rpc, handler, t) - checkErrResponse(t, body, false) - var res StringResultResponse - err := json.Unmarshal(body, &res) - assert.NoErrorf(t, err, "could not parse response: %s", body) - block, err := chain.GetBlock(chain.GetHeaderHash(1)) - assert.NoErrorf(t, err, "could not get block") - expectedHash := "0x" + block.Hash().ReverseString() - assert.Equal(t, expectedHash, res.Result) - }) - - t.Run("getconnectioncount", func(t *testing.T) { - rpc := `{"jsonrpc": "2.0", "id": 1, "method": "getconnectioncount", "params": []}` - body := doRPCCall(rpc, handler, t) - checkErrResponse(t, body, false) - var res IntResultResponse - err := json.Unmarshal(body, &res) - assert.NoErrorf(t, err, "could not parse response: %s", body) - assert.Equal(t, 0, res.Result) - }) - - t.Run("getversion", func(t *testing.T) { - rpc := `{"jsonrpc": "2.0", "id": 1, "method": "getversion", "params": []}` - body := doRPCCall(rpc, handler, t) - checkErrResponse(t, body, false) - var res GetVersionResponse - err := json.Unmarshal(body, &res) - assert.NoErrorf(t, err, "could not parse response: %s", body) - assert.Equal(t, "/NEO-GO:/", res.Result.UserAgent) - }) - - t.Run("getpeers", func(t *testing.T) { - rpc := `{"jsonrpc": "2.0", "id": 1, "method": "getpeers", "params": []}` - body := doRPCCall(rpc, handler, t) - checkErrResponse(t, body, false) - var res GetPeersResponse - err := json.Unmarshal(body, &res) - assert.NoErrorf(t, err, "could not parse response: %s", body) - assert.Equal(t, []int{}, res.Result.Bad) - assert.Equal(t, []int{}, res.Result.Unconnected) - assert.Equal(t, []int{}, res.Result.Connected) - }) - - t.Run("validateaddress_positive", func(t *testing.T) { - rpc := `{"jsonrpc": "2.0", "id": 1, "method": "validateaddress", "params": ["AQVh2pG732YvtNaxEGkQUei3YA4cvo7d2i"]}` - body := doRPCCall(rpc, handler, t) - checkErrResponse(t, body, false) - var res ValidateAddrResponse - err := json.Unmarshal(body, &res) - assert.NoErrorf(t, err, "could not parse response: %s", body) - assert.Equal(t, true, res.Result.IsValid) - }) - - t.Run("validateaddress_negative", func(t *testing.T) { - rpc := `{"jsonrpc": "2.0", "id": 1, "method": "validateaddress", "params": [1]}` - body := doRPCCall(rpc, handler, t) - checkErrResponse(t, body, false) - var res ValidateAddrResponse - err := json.Unmarshal(body, &res) - assert.NoErrorf(t, err, "could not parse response: %s", body) - assert.Equal(t, false, res.Result.IsValid) - }) - - t.Run("getassetstate_positive", func(t *testing.T) { - rpc := `{"jsonrpc": "2.0", "id": 1, "method": "getassetstate", "params": ["602c79718b16e442de58778e148d0b1084e3b2dffd5de6b7b16cee7969282de7"]}` - body := doRPCCall(rpc, handler, t) - checkErrResponse(t, body, false) - var res GetAssetResponse - err := json.Unmarshal(body, &res) - assert.NoErrorf(t, err, "could not parse response: %s", body) - assert.Equal(t, "00", res.Result.Owner) - assert.Equal(t, "AWKECj9RD8rS8RPcpCgYVjk1DeYyHwxZm3", res.Result.Admin) - }) - - t.Run("getassetstate_negative", func(t *testing.T) { - rpc := `{"jsonrpc": "2.0", "id": 1, "method": "getassetstate", "params": ["602c79718b16e442de58778e148d0b1084e3b2dffd5de6b7b16cee7969282de2"]}` - body := doRPCCall(rpc, handler, t) - checkErrResponse(t, body, false) - var res StringResultResponse - err := json.Unmarshal(body, &res) - assert.NoErrorf(t, err, "could not parse response: %s", body) - assert.Equal(t, "Invalid assetid", res.Result) - }) - - t.Run("getaccountstate_positive", func(t *testing.T) { - rpc := `{"jsonrpc": "2.0", "id": 1, "method": "getaccountstate", "params": ["AZ81H31DMWzbSnFDLFkzh9vHwaDLayV7fU"]}` - body := doRPCCall(rpc, handler, t) - checkErrResponse(t, body, false) - var res GetAccountStateResponse - err := json.Unmarshal(body, &res) - assert.NoErrorf(t, err, "could not parse response: %s", body) - assert.Equal(t, 1, len(res.Result.Balances)) - assert.Equal(t, false, res.Result.Frozen) - }) - - t.Run("getunspents_positive", func(t *testing.T) { - rpc := `{"jsonrpc": "2.0", "id": 1, "method": "getunspents", "params": ["AZ81H31DMWzbSnFDLFkzh9vHwaDLayV7fU"]}` - body := doRPCCall(rpc, handler, t) - checkErrResponse(t, body, false) - var res GetUnspents - err := json.Unmarshal(body, &res) - assert.NoErrorf(t, err, "could not parse response: %s", body) - assert.Equal(t, 1, len(res.Result.Balance)) - assert.Equal(t, 1, len(res.Result.Balance[0].Unspents)) - }) - - t.Run("getaccountstate_negative", func(t *testing.T) { - rpc := `{"jsonrpc": "2.0", "id": 1, "method": "getaccountstate", "params": ["AK2nJJpJr6o664CWJKi1QRXjqeic2zRp8y"]}` - body := doRPCCall(rpc, handler, t) - checkErrResponse(t, body, false) - var res StringResultResponse - err := json.Unmarshal(body, &res) - assert.NoErrorf(t, err, "could not parse response: %s", body) - assert.Equal(t, "Invalid public account address", res.Result) - }) - - t.Run("getunspents_negative", func(t *testing.T) { - rpc := `{"jsonrpc": "2.0", "id": 1, "method": "getunspents", "params": ["AK2nJJpJr6o664CWJKi1QRXjqeic2zRp8y"]}` - body := doRPCCall(rpc, handler, t) - checkErrResponse(t, body, false) - var res StringResultResponse - err := json.Unmarshal(body, &res) - assert.NoErrorf(t, err, "could not parse response: %s", body) - assert.Equal(t, "Invalid public account address", res.Result) - }) + if tc.check == nil { + assert.Equal(t, expected, res) + } else { + tc.check(t, e, res) + } + }) + } + }) + } t.Run("getrawtransaction", func(t *testing.T) { block, _ := chain.GetBlock(chain.GetHeaderHash(0)) @@ -185,31 +345,52 @@ func TestRPC(t *testing.T) { checkErrResponse(t, body, false) var res StringResultResponse err := json.Unmarshal(body, &res) - assert.NoErrorf(t, err, "could not parse response: %s", body) + require.NoErrorf(t, err, "could not parse response: %s", body) assert.Equal(t, "400000455b7b226c616e67223a227a682d434e222c226e616d65223a22e5b08fe89a81e882a1227d2c7b226c616e67223a22656e222c226e616d65223a22416e745368617265227d5d0000c16ff28623000000da1745e9b549bd0bfa1a569971c77eba30cd5a4b00000000", res.Result) }) - t.Run("sendrawtransaction_positive", func(t *testing.T) { - rpc := `{"jsonrpc": "2.0", "id": 1, "method": "sendrawtransaction", "params": ["d1001b00046e616d6567d3d8602814a429a91afdbaa3914884a1c90c733101201cc9c05cefffe6cdd7b182816a9152ec218d2ec000000141403387ef7940a5764259621e655b3c621a6aafd869a611ad64adcc364d8dd1edf84e00a7f8b11b630a377eaef02791d1c289d711c08b7ad04ff0d6c9caca22cfe6232103cbb45da6072c14761c9da545749d9cfd863f860c351066d16df480602a2024c6ac"]}` + t.Run("getrawtransaction 2 arguments", func(t *testing.T) { + block, _ := chain.GetBlock(chain.GetHeaderHash(0)) + TXHash := block.Transactions[1].Hash() + rpc := fmt.Sprintf(`{"jsonrpc": "2.0", "id": 1, "method": "getrawtransaction", "params": ["%s", 0]}"`, TXHash.ReverseString()) body := doRPCCall(rpc, handler, t) checkErrResponse(t, body, false) - var res SendTXResponse + var res StringResultResponse err := json.Unmarshal(body, &res) - assert.NoErrorf(t, err, "could not parse response: %s", body) - assert.Equal(t, true, res.Result) + require.NoErrorf(t, err, "could not parse response: %s", body) + assert.Equal(t, "400000455b7b226c616e67223a227a682d434e222c226e616d65223a22e5b08fe89a81e882a1227d2c7b226c616e67223a22656e222c226e616d65223a22416e745368617265227d5d0000c16ff28623000000da1745e9b549bd0bfa1a569971c77eba30cd5a4b00000000", res.Result) }) +} - t.Run("sendrawtransaction_negative", func(t *testing.T) { - rpc := `{"jsonrpc": "2.0", "id": 1, "method": "sendrawtransaction", "params": ["0274d792072617720636f6e7472616374207472616e73616374696f6e206465736372697074696f6e01949354ea0a8b57dfee1e257a1aedd1e0eea2e5837de145e8da9c0f101bfccc8e0100029b7cffdaa674beae0f930ebe6085af9093e5fe56b34a5c220ccdcf6efc336fc500a3e11100000000ea610aa6db39bd8c8556c9569d94b5e5a5d0ad199b7cffdaa674beae0f930ebe6085af9093e5fe56b34a5c220ccdcf6efc336fc5004f2418010000001cc9c05cefffe6cdd7b182816a9152ec218d2ec0014140dbd3cddac5cb2bd9bf6d93701f1a6f1c9dbe2d1b480c54628bbb2a4d536158c747a6af82698edf9f8af1cac3850bcb772bd9c8e4ac38f80704751cc4e0bd0e67232103cbb45da6072c14761c9da545749d9cfd863f860c351066d16df480602a2024c6ac"]}` - body := doRPCCall(rpc, handler, t) - checkErrResponse(t, body, true) - }) +func (tc rpcTestCase) getResultPair(e *executor) (expected interface{}, res interface{}) { + expected = tc.result(e) + switch exp := expected.(type) { + case string: + res = new(StringResultResponse) + expected = &StringResultResponse{ + Jsonrpc: defaultJSONRPC, + Result: exp, + ID: defaultID, + } + case int: + res = new(IntResultResponse) + expected = &IntResultResponse{ + Jsonrpc: defaultJSONRPC, + Result: exp, + ID: defaultID, + } + default: + resVal := reflect.New(reflect.TypeOf(expected).Elem()) + res = resVal.Interface() + } + + return } func checkErrResponse(t *testing.T, body []byte, expectingFail bool) { var errresp ErrorResponse err := json.Unmarshal(body, &errresp) - assert.Nil(t, err) + require.Nil(t, err) if expectingFail { assert.NotEqual(t, 0, errresp.Error.Code) assert.NotEqual(t, "", errresp.Error.Message) From 0f9024d177bc38cbbd01f0d3d965ade27a3ca0da Mon Sep 17 00:00:00 2001 From: Evgenii Stratonikov Date: Fri, 22 Nov 2019 12:08:33 +0300 Subject: [PATCH 5/6] rpc: make client default values constants Signed-off-by: Evgenii Stratonikov --- pkg/rpc/client.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pkg/rpc/client.go b/pkg/rpc/client.go index eae5d30ec..c471109cb 100644 --- a/pkg/rpc/client.go +++ b/pkg/rpc/client.go @@ -15,7 +15,7 @@ import ( "github.com/pkg/errors" ) -var ( +const ( defaultDialTimeout = 4 * time.Second defaultRequestTimeout = 4 * time.Second defaultClientVersion = "2.0" @@ -75,7 +75,6 @@ func NewClient(ctx context.Context, endpoint string, opts ClientOptions) (*Clien // TODO(@antdm): Enable SSL. if opts.Cert != "" && opts.Key != "" { - } if opts.Client.Timeout == 0 { From c9d5f8b89c1604830c8b4b39dfca34d5e65d37a1 Mon Sep 17 00:00:00 2001 From: Evgenii Stratonikov Date: Fri, 22 Nov 2019 12:22:19 +0300 Subject: [PATCH 6/6] rpc: cover stack_param with more tests Use require/assert instead of builtin facilities. Signed-off-by: Evgenii Stratonikov --- pkg/rpc/stack_param_test.go | 296 +++++++++++++++--------------------- 1 file changed, 125 insertions(+), 171 deletions(-) diff --git a/pkg/rpc/stack_param_test.go b/pkg/rpc/stack_param_test.go index 6aa602475..8e7b9965b 100644 --- a/pkg/rpc/stack_param_test.go +++ b/pkg/rpc/stack_param_test.go @@ -1,12 +1,12 @@ package rpc import ( - "encoding/hex" "encoding/json" "reflect" "testing" "github.com/CityOfZion/neo-go/pkg/util" + "github.com/stretchr/testify/assert" ) var testCases = []struct { @@ -41,6 +41,28 @@ var testCases = []struct { }, }, }, + { + input: `{"type": "Hash160", "value": "0bcd2978634d961c24f5aea0802297ff128724d6"}`, + result: StackParam{ + Type: Hash160, + Value: util.Uint160{ + 0x0b, 0xcd, 0x29, 0x78, 0x63, 0x4d, 0x96, 0x1c, 0x24, 0xf5, + 0xae, 0xa0, 0x80, 0x22, 0x97, 0xff, 0x12, 0x87, 0x24, 0xd6, + }, + }, + }, + { + input: `{"type": "Hash256", "value": "f037308fa0ab18155bccfc08485468c112409ea5064595699e98c545f245f32d"}`, + result: StackParam{ + Type: Hash256, + Value: util.Uint256{ + 0x2d, 0xf3, 0x45, 0xf2, 0x45, 0xc5, 0x98, 0x9e, + 0x69, 0x95, 0x45, 0x06, 0xa5, 0x9e, 0x40, 0x12, + 0xc1, 0x68, 0x54, 0x48, 0x08, 0xfc, 0xcc, 0x5b, + 0x15, 0x18, 0xab, 0xa0, 0x8f, 0x30, 0x37, 0xf0, + }, + }, + }, } var errorCases = []string{ @@ -55,192 +77,124 @@ var errorCases = []string{ `{"type": "Hash160","value": "0bcd"}`, // incorrect Uint160 value `{"type": "Hash256","value": "0bcd"}`, // incorrect Uint256 value `{"type": "Stringg","value": ""}`, // incorrect type + `{"type": {},"value": ""}`, // incorrect value + + `{"type": "InteropInterface","value": ""}`, // ununmarshable type } func TestStackParam_UnmarshalJSON(t *testing.T) { - var ( - err error - r, s StackParam - ) + var s StackParam for _, tc := range testCases { - if err = json.Unmarshal([]byte(tc.input), &s); err != nil { - t.Errorf("error while unmarhsalling: %v", err) - } else if !reflect.DeepEqual(s, tc.result) { - t.Errorf("got (%v), expected (%v)", s, tc.result) - } - } - - // Hash160 unmarshalling - err = json.Unmarshal([]byte(`{"type": "Hash160","value": "0bcd2978634d961c24f5aea0802297ff128724d6"}`), &s) - if err != nil { - t.Errorf("error while unmarhsalling: %v", err) - } - - h160, err := util.Uint160DecodeString("0bcd2978634d961c24f5aea0802297ff128724d6") - if err != nil { - t.Errorf("unmarshal error: %v", err) - } - - if r = (StackParam{Type: Hash160, Value: h160}); !reflect.DeepEqual(s, r) { - t.Errorf("got (%v), expected (%v)", s, r) - } - - // Hash256 unmarshalling - err = json.Unmarshal([]byte(`{"type": "Hash256","value": "f037308fa0ab18155bccfc08485468c112409ea5064595699e98c545f245f32d"}`), &s) - if err != nil { - t.Errorf("error while unmarhsalling: %v", err) - } - h256, err := util.Uint256DecodeReverseString("f037308fa0ab18155bccfc08485468c112409ea5064595699e98c545f245f32d") - if err != nil { - t.Errorf("unmarshal error: %v", err) - } - if r = (StackParam{Type: Hash256, Value: h256}); !reflect.DeepEqual(s, r) { - t.Errorf("got (%v), expected (%v)", s, r) + assert.NoError(t, json.Unmarshal([]byte(tc.input), &s)) + assert.Equal(t, s, tc.result) } for _, input := range errorCases { - if err = json.Unmarshal([]byte(input), &s); err == nil { - t.Errorf("expected error, got (nil)") - } + assert.Error(t, json.Unmarshal([]byte(input), &s)) } } -const ( - hash160 = "0bcd2978634d961c24f5aea0802297ff128724d6" - hash256 = "7fe610b7c8259ae949accacb091a1bc53219c51a1cb8752fbc6457674c13ec0b" - testString = "myteststring" -) +var tryParseTestCases = []struct { + input interface{} + expected interface{} +}{ + { + input: []byte{ + 0x0b, 0xcd, 0x29, 0x78, 0x63, 0x4d, 0x96, 0x1c, 0x24, 0xf5, + 0xae, 0xa0, 0x80, 0x22, 0x97, 0xff, 0x12, 0x87, 0x24, 0xd6, + }, + expected: util.Uint160{ + 0x0b, 0xcd, 0x29, 0x78, 0x63, 0x4d, 0x96, 0x1c, 0x24, 0xf5, + 0xae, 0xa0, 0x80, 0x22, 0x97, 0xff, 0x12, 0x87, 0x24, 0xd6, + }, + }, + { + input: []byte{ + 0xf0, 0x37, 0x30, 0x8f, 0xa0, 0xab, 0x18, 0x15, + 0x5b, 0xcc, 0xfc, 0x08, 0x48, 0x54, 0x68, 0xc1, + 0x12, 0x40, 0x9e, 0xa5, 0x06, 0x45, 0x95, 0x69, + 0x9e, 0x98, 0xc5, 0x45, 0xf2, 0x45, 0xf3, 0x2d, + }, + expected: util.Uint256{ + 0x2d, 0xf3, 0x45, 0xf2, 0x45, 0xc5, 0x98, 0x9e, + 0x69, 0x95, 0x45, 0x06, 0xa5, 0x9e, 0x40, 0x12, + 0xc1, 0x68, 0x54, 0x48, 0x08, 0xfc, 0xcc, 0x5b, + 0x15, 0x18, 0xab, 0xa0, 0x8f, 0x30, 0x37, 0xf0, + }, + }, + { + input: []byte{0, 1, 2, 3, 4, 9, 8, 6}, + expected: []byte{0, 1, 2, 3, 4, 9, 8, 6}, + }, + { + input: []byte{0x63, 0x78, 0x29, 0xcd, 0x0b}, + expected: int64(50686687331), + }, + { + input: []byte("this is a test string"), + expected: "this is a test string", + }, +} func TestStackParam_TryParse(t *testing.T) { - // ByteArray to util.Uint160 conversion - data, err := hex.DecodeString(hash160) - if err != nil { - t.Fatal(err) - } - - var ( - outputUint160, expectedUint160 util.Uint160 - input = StackParam{ - Type: ByteArray, - Value: data, - } - ) - expectedUint160, err = util.Uint160DecodeString(hash160) - if err != nil { - t.Fatal(err) - } - if err = input.TryParse(&outputUint160); err != nil { - t.Errorf("failed to parse stackparam to Uint160: %v", err) - } - if !reflect.DeepEqual(outputUint160, expectedUint160) { - t.Errorf("got (%v), expected (%v)", outputUint160, expectedUint160) - } - - // ByteArray to util.Uint256 conversion - data, err = hex.DecodeString(hash256) - if err != nil { - t.Fatal(err) - } - - var ( - outputUint256, expectedUint256 util.Uint256 - uint256input = StackParam{ - Type: ByteArray, - Value: data, - } - ) - expectedUint256, err = util.Uint256DecodeReverseString(hash256) - if err != nil { - t.Fatal(err) - } - if err = uint256input.TryParse(&outputUint256); err != nil { - t.Errorf("failed to parse stackparam to []byte: %v", err) - } - if !reflect.DeepEqual(outputUint256, expectedUint256) { - t.Errorf("got (%v), expected (%v)", outputUint256, expectedUint256) - } - - // ByteArray to []byte conversion - var ( - outputBytes []byte - expectedBytes = expectedUint160.Bytes() - ) - if err = input.TryParse(&outputBytes); err != nil { - t.Errorf("failed to parse stackparam to []byte: %v", err) - } - if !reflect.DeepEqual(outputBytes, expectedBytes) { - t.Errorf("got (%v), expected (%v)", outputBytes, expectedBytes) - } - - // ByteArray to int64 conversion - data, err = hex.DecodeString("637829cd0b") - if err != nil { - t.Fatal(err) - } - var ( - outputInt, expectedInt int64 - intinput = StackParam{ - Type: ByteArray, - Value: data, - } - ) - expectedInt = 50686687331 - if err = intinput.TryParse(&outputInt); err != nil { - t.Errorf("failed to parse stackparam to []byte: %v", err) - } - if !reflect.DeepEqual(outputInt, expectedInt) { - t.Errorf("got (%v), expected (%v)", outputInt, expectedInt) - } - - // ByteArray to string conversion - data = []byte(testString) - var ( - outputStr, expectedStr string - strinput = StackParam{ - Type: ByteArray, - Value: data, - } - ) - expectedStr = testString - if err = strinput.TryParse(&outputStr); err != nil { - t.Errorf("failed to parse stackparam to []byte: %v", err) - } - if !reflect.DeepEqual(outputStr, expectedStr) { - t.Errorf("got (%v), expected (%v)", outputStr, expectedStr) - } - - // StackParams to []util.Uint160 - data, err = hex.DecodeString(hash160) - if err != nil { - t.Fatal(err) - } - expUint160, err := util.Uint160DecodeString(hash160) - if err != nil { - t.Fatal(err) - } - var ( - params = StackParams{ - StackParam{ + for _, tc := range tryParseTestCases { + t.Run(reflect.TypeOf(tc.expected).String(), func(t *testing.T) { + input := StackParam{ Type: ByteArray, - Value: data, + Value: tc.input, + } + + val := reflect.New(reflect.TypeOf(tc.expected)) + assert.NoError(t, input.TryParse(val.Interface())) + assert.Equal(t, tc.expected, val.Elem().Interface()) + }) + } + + t.Run("[]Uint160", func(t *testing.T) { + exp1 := util.Uint160{1, 2, 3, 4, 5} + exp2 := util.Uint160{9, 8, 7, 6, 5} + + params := StackParams{ + { + Type: ByteArray, + Value: exp1.Bytes(), }, - StackParam{ + { Type: ByteArray, - Value: data, + Value: exp2.Bytes(), }, } - expectedArray = []util.Uint160{ - expUint160, - expUint160, - } - out1, out2 = &util.Uint160{}, &util.Uint160{} - ) - if err = params.TryParseArray(out1, out2); err != nil { - t.Errorf("failed to parse stackparam to []byte: %v", err) - } - outArray := []util.Uint160{*out1, *out2} - if !reflect.DeepEqual(outArray, expectedArray) { - t.Errorf("got (%v), expected (%v)", outArray, expectedArray) - } + var out1, out2 util.Uint160 + + assert.NoError(t, params.TryParseArray(&out1, &out2)) + assert.Equal(t, exp1, out1) + assert.Equal(t, exp2, out2) + }) +} + +func TestStackParamType_String(t *testing.T) { + types := []StackParamType{ + Signature, + Boolean, + Integer, + Hash160, + Hash256, + ByteArray, + PublicKey, + String, + Array, + InteropInterface, + Void, + } + + for _, exp := range types { + actual, err := StackParamTypeFromString(exp.String()) + assert.NoError(t, err) + assert.Equal(t, exp, actual) + } + + actual, err := StackParamTypeFromString(Unknown.String()) + assert.Error(t, err) + assert.Equal(t, Unknown, actual) }