From 5794bdb1697256568328e1c286ed6811af86fcf6 Mon Sep 17 00:00:00 2001 From: Evgenii Stratonikov Date: Wed, 3 Jun 2020 18:09:28 +0300 Subject: [PATCH 1/9] state: implement JSON marshaling for MPT* items --- pkg/core/state/mpt_root.go | 55 ++++++++++++++++++++++++++++----- pkg/core/state/mpt_root_test.go | 39 +++++++++++++++++++++++ 2 files changed, 87 insertions(+), 7 deletions(-) diff --git a/pkg/core/state/mpt_root.go b/pkg/core/state/mpt_root.go index facf3da45..dea3f62fa 100644 --- a/pkg/core/state/mpt_root.go +++ b/pkg/core/state/mpt_root.go @@ -1,6 +1,9 @@ package state import ( + "encoding/json" + "errors" + "github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/crypto/hash" "github.com/nspcc-dev/neo-go/pkg/io" @@ -9,16 +12,16 @@ import ( // MPTRootBase represents storage state root. type MPTRootBase struct { - Version byte - Index uint32 - PrevHash util.Uint256 - Root util.Uint256 + Version byte `json:"version"` + Index uint32 `json:"index"` + PrevHash util.Uint256 `json:"prehash"` + Root util.Uint256 `json:"stateroot"` } // MPTRoot represents storage state root together with sign info. type MPTRoot struct { MPTRootBase - Witness *transaction.Witness + Witness *transaction.Witness `json:"witness,omitempty"` } // MPTRootStateFlag represents verification state of the state root. @@ -33,8 +36,8 @@ const ( // MPTRootState represents state root together with its verification state. type MPTRootState struct { - MPTRoot - Flag MPTRootStateFlag + MPTRoot `json:"stateroot"` + Flag MPTRootStateFlag `json:"flag"` } // EncodeBinary implements io.Serializable. @@ -103,3 +106,41 @@ func (s *MPTRoot) EncodeBinary(w *io.BinWriter) { w.WriteArray([]*transaction.Witness{s.Witness}) } } + +// String implements fmt.Stringer. +func (f MPTRootStateFlag) String() string { + switch f { + case Unverified: + return "Unverified" + case Verified: + return "Verified" + case Invalid: + return "Invalid" + default: + return "" + } +} + +// MarshalJSON implements json.Marshaler. +func (f MPTRootStateFlag) MarshalJSON() ([]byte, error) { + return []byte(`"` + f.String() + `"`), nil +} + +// UnmarshalJSON implements json.Unmarshaler. +func (f *MPTRootStateFlag) UnmarshalJSON(data []byte) error { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return err + } + switch s { + case "Unverified": + *f = Unverified + case "Verified": + *f = Verified + case "Invalid": + *f = Invalid + default: + return errors.New("unknown flag") + } + return nil +} diff --git a/pkg/core/state/mpt_root_test.go b/pkg/core/state/mpt_root_test.go index 15a3ca043..f1c0b5c61 100644 --- a/pkg/core/state/mpt_root_test.go +++ b/pkg/core/state/mpt_root_test.go @@ -1,12 +1,14 @@ package state import ( + "encoding/json" "math/rand" "testing" "github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/internal/random" "github.com/nspcc-dev/neo-go/pkg/internal/testserdes" + "github.com/nspcc-dev/neo-go/pkg/util" "github.com/stretchr/testify/require" ) @@ -59,3 +61,40 @@ func TestMPTRootStateUnverifiedByDefault(t *testing.T) { var r MPTRootState require.Equal(t, Unverified, r.Flag) } + +func TestMPTRoot_MarshalJSON(t *testing.T) { + t.Run("Good", func(t *testing.T) { + r := testStateRoot() + rs := &MPTRootState{ + MPTRoot: *r, + Flag: Verified, + } + testserdes.MarshalUnmarshalJSON(t, rs, new(MPTRootState)) + }) + + t.Run("Compatibility", func(t *testing.T) { + js := []byte(`{ + "flag": "Unverified", + "stateroot": { + "version": 1, + "index": 3000000, + "prehash": "0x4f30f43af8dd2262fc331c45bfcd9066ebbacda204e6e81371cbd884fe7d6c90", + "stateroot": "0xb2fd7e368a848ef70d27cf44940a35237333ed05f1d971c9408f0eb285e0b6f3" + }}`) + + rs := new(MPTRootState) + require.NoError(t, json.Unmarshal(js, &rs)) + + require.EqualValues(t, 1, rs.Version) + require.EqualValues(t, 3000000, rs.Index) + require.Nil(t, rs.Witness) + + u, err := util.Uint256DecodeStringLE("4f30f43af8dd2262fc331c45bfcd9066ebbacda204e6e81371cbd884fe7d6c90") + require.NoError(t, err) + require.Equal(t, u, rs.PrevHash) + + u, err = util.Uint256DecodeStringLE("b2fd7e368a848ef70d27cf44940a35237333ed05f1d971c9408f0eb285e0b6f3") + require.NoError(t, err) + require.Equal(t, u, rs.Root) + }) +} From 44f93c7c69e8ff18fdde0538e5923f8bb62905a5 Mon Sep 17 00:00:00 2001 From: Evgenii Stratonikov Date: Thu, 4 Jun 2020 14:58:47 +0300 Subject: [PATCH 2/9] rpc/server: simplify errors handling during parameter parsing --- pkg/rpc/client/wsclient_test.go | 12 +- pkg/rpc/request/param.go | 30 +++-- pkg/rpc/request/params.go | 15 ++- pkg/rpc/server/server.go | 199 +++++++++----------------------- 4 files changed, 91 insertions(+), 165 deletions(-) diff --git a/pkg/rpc/client/wsclient_test.go b/pkg/rpc/client/wsclient_test.go index 1eebe08bd..b548e74a3 100644 --- a/pkg/rpc/client/wsclient_test.go +++ b/pkg/rpc/client/wsclient_test.go @@ -181,8 +181,8 @@ func TestWSFilteredSubscriptions(t *testing.T) { require.NoError(t, err) }, func(t *testing.T, p *request.Params) { - param, ok := p.Value(1) - require.Equal(t, true, ok) + param := p.Value(1) + require.NotNil(t, param) require.Equal(t, request.TxFilterT, param.Type) filt, ok := param.Value.(request.TxFilter) require.Equal(t, true, ok) @@ -196,8 +196,8 @@ func TestWSFilteredSubscriptions(t *testing.T) { require.NoError(t, err) }, func(t *testing.T, p *request.Params) { - param, ok := p.Value(1) - require.Equal(t, true, ok) + param := p.Value(1) + require.NotNil(t, param) require.Equal(t, request.NotificationFilterT, param.Type) filt, ok := param.Value.(request.NotificationFilter) require.Equal(t, true, ok) @@ -211,8 +211,8 @@ func TestWSFilteredSubscriptions(t *testing.T) { require.NoError(t, err) }, func(t *testing.T, p *request.Params) { - param, ok := p.Value(1) - require.Equal(t, true, ok) + param := p.Value(1) + require.NotNil(t, param) require.Equal(t, request.ExecutionFilterT, param.Type) filt, ok := param.Value.(request.ExecutionFilter) require.Equal(t, true, ok) diff --git a/pkg/rpc/request/param.go b/pkg/rpc/request/param.go index 42159c336..5a961ffe9 100644 --- a/pkg/rpc/request/param.go +++ b/pkg/rpc/request/param.go @@ -61,12 +61,17 @@ const ( ExecutionFilterT ) +var errMissingParameter = errors.New("parameter is missing") + func (p Param) String() string { return fmt.Sprintf("%v", p.Value) } // GetString returns string value of the parameter. -func (p Param) GetString() (string, error) { +func (p *Param) GetString() (string, error) { + if p == nil { + return "", errMissingParameter + } str, ok := p.Value.(string) if !ok { return "", errors.New("not a string") @@ -75,7 +80,10 @@ func (p Param) GetString() (string, error) { } // GetInt returns int value of te parameter. -func (p Param) GetInt() (int, error) { +func (p *Param) GetInt() (int, error) { + if p == nil { + return 0, errMissingParameter + } i, ok := p.Value.(int) if ok { return i, nil @@ -86,7 +94,10 @@ func (p Param) GetInt() (int, error) { } // GetArray returns a slice of Params stored in the parameter. -func (p Param) GetArray() ([]Param, error) { +func (p *Param) GetArray() ([]Param, error) { + if p == nil { + return nil, errMissingParameter + } a, ok := p.Value.([]Param) if !ok { return nil, errors.New("not an array") @@ -95,7 +106,7 @@ func (p Param) GetArray() ([]Param, error) { } // GetUint256 returns Uint256 value of the parameter. -func (p Param) GetUint256() (util.Uint256, error) { +func (p *Param) GetUint256() (util.Uint256, error) { s, err := p.GetString() if err != nil { return util.Uint256{}, err @@ -105,7 +116,7 @@ func (p Param) GetUint256() (util.Uint256, error) { } // GetUint160FromHex returns Uint160 value of the parameter encoded in hex. -func (p Param) GetUint160FromHex() (util.Uint160, error) { +func (p *Param) GetUint160FromHex() (util.Uint160, error) { s, err := p.GetString() if err != nil { return util.Uint160{}, err @@ -119,7 +130,7 @@ func (p Param) GetUint160FromHex() (util.Uint160, error) { // GetUint160FromAddress returns Uint160 value of the parameter that was // supplied as an address. -func (p Param) GetUint160FromAddress() (util.Uint160, error) { +func (p *Param) GetUint160FromAddress() (util.Uint160, error) { s, err := p.GetString() if err != nil { return util.Uint160{}, err @@ -129,7 +140,10 @@ func (p Param) GetUint160FromAddress() (util.Uint160, error) { } // GetFuncParam returns current parameter as a function call parameter. -func (p Param) GetFuncParam() (FuncParam, error) { +func (p *Param) GetFuncParam() (FuncParam, error) { + if p == nil { + return FuncParam{}, errMissingParameter + } fp, ok := p.Value.(FuncParam) if !ok { return FuncParam{}, errors.New("not a function parameter") @@ -139,7 +153,7 @@ func (p Param) GetFuncParam() (FuncParam, error) { // GetBytesHex returns []byte value of the parameter if // it is a hex-encoded string. -func (p Param) GetBytesHex() ([]byte, error) { +func (p *Param) GetBytesHex() ([]byte, error) { s, err := p.GetString() if err != nil { return nil, err diff --git a/pkg/rpc/request/params.go b/pkg/rpc/request/params.go index 8b1945cb1..dd2ac35b9 100644 --- a/pkg/rpc/request/params.go +++ b/pkg/rpc/request/params.go @@ -7,20 +7,19 @@ type ( // Value returns the param struct for the given // index if it exists. -func (p Params) Value(index int) (*Param, bool) { +func (p Params) Value(index int) *Param { if len(p) > index { - return &p[index], true + return &p[index] } - return nil, false + return nil } // ValueWithType returns the param struct at the given index if it // exists and matches the given type. -func (p Params) ValueWithType(index int, valType paramType) (*Param, bool) { - if val, ok := p.Value(index); ok && val.Type == valType { - return val, true +func (p Params) ValueWithType(index int, valType paramType) *Param { + if val := p.Value(index); val != nil && val.Type == valType { + return val } - - return nil, false + return nil } diff --git a/pkg/rpc/server/server.go b/pkg/rpc/server/server.go index a91502e5e..f62a368fe 100644 --- a/pkg/rpc/server/server.go +++ b/pkg/rpc/server/server.go @@ -389,8 +389,8 @@ func (s *Server) getConnectionCount(_ request.Params) (interface{}, *response.Er func (s *Server) getBlock(reqParams request.Params) (interface{}, *response.Error) { var hash util.Uint256 - param, ok := reqParams.Value(0) - if !ok { + param := reqParams.Value(0) + if param == nil { return nil, response.ErrInvalidParams } @@ -425,8 +425,8 @@ func (s *Server) getBlock(reqParams request.Params) (interface{}, *response.Erro } func (s *Server) getBlockHash(reqParams request.Params) (interface{}, *response.Error) { - param, ok := reqParams.ValueWithType(0, request.NumberT) - if !ok { + param := reqParams.ValueWithType(0, request.NumberT) + if param == nil { return nil, response.ErrInvalidParams } num, err := s.blockHeightFromParam(param) @@ -463,20 +463,15 @@ func (s *Server) getRawMempool(_ request.Params) (interface{}, *response.Error) } func (s *Server) validateAddress(reqParams request.Params) (interface{}, *response.Error) { - param, ok := reqParams.Value(0) - if !ok { + param := reqParams.Value(0) + if param == nil { return nil, response.ErrInvalidParams } return validateAddress(param.Value), nil } func (s *Server) getAssetState(reqParams request.Params) (interface{}, *response.Error) { - param, ok := reqParams.ValueWithType(0, request.StringT) - if !ok { - return nil, response.ErrInvalidParams - } - - paramAssetID, err := param.GetUint256() + paramAssetID, err := reqParams.ValueWithType(0, request.StringT).GetUint256() if err != nil { return nil, response.ErrInvalidParams } @@ -490,12 +485,7 @@ func (s *Server) getAssetState(reqParams request.Params) (interface{}, *response // getApplicationLog returns the contract log based on the specified txid. func (s *Server) getApplicationLog(reqParams request.Params) (interface{}, *response.Error) { - param, ok := reqParams.Value(0) - if !ok { - return nil, response.ErrInvalidParams - } - - txHash, err := param.GetUint256() + txHash, err := reqParams.Value(0).GetUint256() if err != nil { return nil, response.ErrInvalidParams } @@ -522,10 +512,7 @@ func (s *Server) getApplicationLog(reqParams request.Params) (interface{}, *resp } func (s *Server) getClaimable(ps request.Params) (interface{}, *response.Error) { - p, ok := ps.ValueWithType(0, request.StringT) - if !ok { - return nil, response.ErrInvalidParams - } + p := ps.ValueWithType(0, request.StringT) u, err := p.GetUint160FromAddress() if err != nil { return nil, response.ErrInvalidParams @@ -574,11 +561,7 @@ func (s *Server) getClaimable(ps request.Params) (interface{}, *response.Error) } func (s *Server) getNEP5Balances(ps request.Params) (interface{}, *response.Error) { - p, ok := ps.ValueWithType(0, request.StringT) - if !ok { - return nil, response.ErrInvalidParams - } - u, err := p.GetUint160FromHex() + u, err := ps.ValueWithType(0, request.StringT).GetUint160FromHex() if err != nil { return nil, response.ErrInvalidParams } @@ -607,11 +590,7 @@ func (s *Server) getNEP5Balances(ps request.Params) (interface{}, *response.Erro } func (s *Server) getNEP5Transfers(ps request.Params) (interface{}, *response.Error) { - p, ok := ps.ValueWithType(0, request.StringT) - if !ok { - return nil, response.ErrInvalidParams - } - u, err := p.GetUint160FromAddress() + u, err := ps.ValueWithType(0, request.StringT).GetUint160FromAddress() if err != nil { return nil, response.ErrInvalidParams } @@ -709,24 +688,14 @@ func (s *Server) getDecimals(h util.Uint160, cache map[util.Uint160]int64) (int6 } func (s *Server) getStorage(ps request.Params) (interface{}, *response.Error) { - param, ok := ps.Value(0) - if !ok { - return nil, response.ErrInvalidParams - } - - scriptHash, err := param.GetUint160FromHex() + scriptHash, err := ps.Value(0).GetUint160FromHex() if err != nil { return nil, response.ErrInvalidParams } scriptHash = scriptHash.Reverse() - param, ok = ps.Value(1) - if !ok { - return nil, response.ErrInvalidParams - } - - key, err := param.GetBytesHex() + key, err := ps.Value(1).GetBytesHex() if err != nil { return nil, response.ErrInvalidParams } @@ -743,9 +712,7 @@ func (s *Server) getrawtransaction(reqParams request.Params) (interface{}, *resp var resultsErr *response.Error var results interface{} - if param0, ok := reqParams.Value(0); !ok { - return nil, response.ErrInvalidParams - } else if txHash, err := param0.GetUint256(); err != nil { + if txHash, err := reqParams.Value(0).GetUint256(); err != nil { resultsErr = response.ErrInvalidParams } else if tx, height, err := s.chain.GetTransaction(txHash); err != nil { err = errors.Wrapf(err, "Invalid transaction hash: %s", txHash) @@ -757,7 +724,10 @@ func (s *Server) getrawtransaction(reqParams request.Params) (interface{}, *resp resultsErr = response.NewInvalidParamsError(err.Error(), err) } - param1, _ := reqParams.Value(1) + param1 := reqParams.Value(1) + if param1 == nil { + param1 = &request.Param{} + } switch v := param1.Value.(type) { case int, float64, bool, string: @@ -777,12 +747,7 @@ func (s *Server) getrawtransaction(reqParams request.Params) (interface{}, *resp } func (s *Server) getTransactionHeight(ps request.Params) (interface{}, *response.Error) { - p, ok := ps.Value(0) - if !ok { - return nil, response.ErrInvalidParams - } - - h, err := p.GetUint256() + h, err := ps.Value(0).GetUint256() if err != nil { return nil, response.ErrInvalidParams } @@ -796,22 +761,12 @@ func (s *Server) getTransactionHeight(ps request.Params) (interface{}, *response } func (s *Server) getTxOut(ps request.Params) (interface{}, *response.Error) { - p, ok := ps.Value(0) - if !ok { - return nil, response.ErrInvalidParams - } - - h, err := p.GetUint256() + h, err := ps.Value(0).GetUint256() if err != nil { return nil, response.ErrInvalidParams } - p, ok = ps.ValueWithType(1, request.NumberT) - if !ok { - return nil, response.ErrInvalidParams - } - - num, err := p.GetInt() + num, err := ps.ValueWithType(1, request.NumberT).GetInt() if err != nil || num < 0 { return nil, response.ErrInvalidParams } @@ -833,18 +788,15 @@ func (s *Server) getTxOut(ps request.Params) (interface{}, *response.Error) { func (s *Server) getContractState(reqParams request.Params) (interface{}, *response.Error) { var results interface{} - param, ok := reqParams.ValueWithType(0, request.StringT) - if !ok { - return nil, response.ErrInvalidParams - } else if scriptHash, err := param.GetUint160FromHex(); err != nil { + scriptHash, err := reqParams.ValueWithType(0, request.StringT).GetUint160FromHex() + if err != nil { return nil, response.ErrInvalidParams + } + cs := s.chain.GetContractState(scriptHash) + if cs != nil { + results = result.NewContractState(cs) } else { - cs := s.chain.GetContractState(scriptHash) - if cs != nil { - results = result.NewContractState(cs) - } else { - return nil, response.NewRPCError("Unknown contract", "", nil) - } + return nil, response.NewRPCError("Unknown contract", "", nil) } return results, nil } @@ -862,33 +814,31 @@ func (s *Server) getAccountStateAux(reqParams request.Params, unspents bool) (in var resultsErr *response.Error var results interface{} - param, ok := reqParams.ValueWithType(0, request.StringT) - if !ok { - return nil, response.ErrInvalidParams - } else if scriptHash, err := param.GetUint160FromAddress(); err != nil { + param := reqParams.ValueWithType(0, request.StringT) + scriptHash, err := param.GetUint160FromAddress() + if err != nil { return nil, response.ErrInvalidParams + } + as := s.chain.GetAccountState(scriptHash) + if as == nil { + as = state.NewAccount(scriptHash) + } + if unspents { + str, err := param.GetString() + if err != nil { + return nil, response.ErrInvalidParams + } + results = result.NewUnspents(as, s.chain, str) } else { - as := s.chain.GetAccountState(scriptHash) - if as == nil { - as = state.NewAccount(scriptHash) - } - if unspents { - str, err := param.GetString() - if err != nil { - return nil, response.ErrInvalidParams - } - results = result.NewUnspents(as, s.chain, str) - } else { - results = result.NewAccountState(as) - } + results = result.NewAccountState(as) } return results, resultsErr } // getBlockSysFee returns the system fees of the block, based on the specified index. func (s *Server) getBlockSysFee(reqParams request.Params) (interface{}, *response.Error) { - param, ok := reqParams.ValueWithType(0, request.NumberT) - if !ok { + param := reqParams.ValueWithType(0, request.NumberT) + if param == nil { return 0, response.ErrInvalidParams } @@ -915,21 +865,13 @@ func (s *Server) getBlockSysFee(reqParams request.Params) (interface{}, *respons func (s *Server) getBlockHeader(reqParams request.Params) (interface{}, *response.Error) { var verbose bool - param, ok := reqParams.ValueWithType(0, request.StringT) - if !ok { - return nil, response.ErrInvalidParams - } - hash, err := param.GetUint256() + hash, err := reqParams.ValueWithType(0, request.StringT).GetUint256() if err != nil { return nil, response.ErrInvalidParams } - param, ok = reqParams.ValueWithType(1, request.NumberT) - if ok { - v, err := param.GetInt() - if err != nil { - return nil, response.ErrInvalidParams - } + v, err := reqParams.ValueWithType(1, request.NumberT).GetInt() + if err == nil { verbose = v != 0 } @@ -952,11 +894,7 @@ func (s *Server) getBlockHeader(reqParams request.Params) (interface{}, *respons // getUnclaimed returns unclaimed GAS amount of the specified address. func (s *Server) getUnclaimed(ps request.Params) (interface{}, *response.Error) { - p, ok := ps.ValueWithType(0, request.StringT) - if !ok { - return nil, response.ErrInvalidParams - } - u, err := p.GetUint160FromAddress() + u, err := ps.ValueWithType(0, request.StringT).GetUint160FromAddress() if err != nil { return nil, response.ErrInvalidParams } @@ -997,19 +935,11 @@ func (s *Server) getValidators(_ request.Params) (interface{}, *response.Error) // invoke implements the `invoke` RPC call. func (s *Server) invoke(reqParams request.Params) (interface{}, *response.Error) { - scriptHashHex, ok := reqParams.ValueWithType(0, request.StringT) - if !ok { - return nil, response.ErrInvalidParams - } - scriptHash, err := scriptHashHex.GetUint160FromHex() + scriptHash, err := reqParams.ValueWithType(0, request.StringT).GetUint160FromHex() if err != nil { return nil, response.ErrInvalidParams } - sliceP, ok := reqParams.ValueWithType(1, request.ArrayT) - if !ok { - return nil, response.ErrInvalidParams - } - slice, err := sliceP.GetArray() + slice, err := reqParams.ValueWithType(1, request.ArrayT).GetArray() if err != nil { return nil, response.ErrInvalidParams } @@ -1022,11 +952,7 @@ func (s *Server) invoke(reqParams request.Params) (interface{}, *response.Error) // invokescript implements the `invokescript` RPC call. func (s *Server) invokeFunction(reqParams request.Params) (interface{}, *response.Error) { - scriptHashHex, ok := reqParams.ValueWithType(0, request.StringT) - if !ok { - return nil, response.ErrInvalidParams - } - scriptHash, err := scriptHashHex.GetUint160FromHex() + scriptHash, err := reqParams.ValueWithType(0, request.StringT).GetUint160FromHex() if err != nil { return nil, response.ErrInvalidParams } @@ -1069,11 +995,7 @@ func (s *Server) runScriptInVM(script []byte) *result.Invoke { // submitBlock broadcasts a raw block over the NEO network. func (s *Server) submitBlock(reqParams request.Params) (interface{}, *response.Error) { - param, ok := reqParams.ValueWithType(0, request.StringT) - if !ok { - return nil, response.ErrInvalidParams - } - blockBytes, err := param.GetBytesHex() + blockBytes, err := reqParams.ValueWithType(0, request.StringT).GetBytesHex() if err != nil { return nil, response.ErrInvalidParams } @@ -1134,11 +1056,7 @@ func (s *Server) sendrawtransaction(reqParams request.Params) (interface{}, *res // subscribe handles subscription requests from websocket clients. func (s *Server) subscribe(reqParams request.Params, sub *subscriber) (interface{}, *response.Error) { - p, ok := reqParams.Value(0) - if !ok { - return nil, response.ErrInvalidParams - } - streamName, err := p.GetString() + streamName, err := reqParams.Value(0).GetString() if err != nil { return nil, response.ErrInvalidParams } @@ -1148,8 +1066,7 @@ func (s *Server) subscribe(reqParams request.Params, sub *subscriber) (interface } // Optional filter. var filter interface{} - p, ok = reqParams.Value(1) - if ok { + if p := reqParams.Value(1); p != nil { // It doesn't accept filters. if event == response.BlockEventID { return nil, response.ErrInvalidParams @@ -1224,11 +1141,7 @@ func (s *Server) subscribeToChannel(event response.EventID) { // unsubscribe handles unsubscription requests from websocket clients. func (s *Server) unsubscribe(reqParams request.Params, sub *subscriber) (interface{}, *response.Error) { - p, ok := reqParams.Value(0) - if !ok { - return nil, response.ErrInvalidParams - } - id, err := p.GetInt() + id, err := reqParams.Value(0).GetInt() if err != nil || id < 0 { return nil, response.ErrInvalidParams } From 7b1a54c9347743be1fd90ca68bb3fe7b74676999 Mon Sep 17 00:00:00 2001 From: Evgenii Stratonikov Date: Thu, 4 Jun 2020 15:43:37 +0300 Subject: [PATCH 3/9] rpc/server: unify boolean flag handling Implement (*Param).GetBoolean() for converting parameter to bool value. It is used for verbosity flag and is false iff it is either zero number or empty sting. --- pkg/rpc/request/param.go | 15 +++++++++++++++ pkg/rpc/server/server.go | 28 ++++------------------------ 2 files changed, 19 insertions(+), 24 deletions(-) diff --git a/pkg/rpc/request/param.go b/pkg/rpc/request/param.go index 5a961ffe9..de1e3dcb9 100644 --- a/pkg/rpc/request/param.go +++ b/pkg/rpc/request/param.go @@ -79,6 +79,21 @@ func (p *Param) GetString() (string, error) { return str, nil } +// GetBoolean returns boolean value of the parameter. +func (p *Param) GetBoolean() bool { + if p == nil { + return false + } + switch p.Type { + case NumberT: + return p.Value != 0 + case StringT: + return p.Value != "" + default: + return true + } +} + // GetInt returns int value of te parameter. func (p *Param) GetInt() (int, error) { if p == nil { diff --git a/pkg/rpc/server/server.go b/pkg/rpc/server/server.go index f62a368fe..e3703398c 100644 --- a/pkg/rpc/server/server.go +++ b/pkg/rpc/server/server.go @@ -416,7 +416,7 @@ func (s *Server) getBlock(reqParams request.Params) (interface{}, *response.Erro return nil, response.NewInternalServerError(fmt.Sprintf("Problem locating block with hash: %s", hash), err) } - if len(reqParams) == 2 && reqParams[1].Value == 1 { + if reqParams.Value(1).GetBoolean() { return result.NewBlock(block, s.chain), nil } writer := io.NewBufBinWriter() @@ -717,26 +717,12 @@ func (s *Server) getrawtransaction(reqParams request.Params) (interface{}, *resp } else if tx, height, err := s.chain.GetTransaction(txHash); err != nil { err = errors.Wrapf(err, "Invalid transaction hash: %s", txHash) return nil, response.NewRPCError("Unknown transaction", err.Error(), err) - } else if len(reqParams) >= 2 { + } else if reqParams.Value(1).GetBoolean() { _header := s.chain.GetHeaderHash(int(height)) header, err := s.chain.GetHeader(_header) if err != nil { resultsErr = response.NewInvalidParamsError(err.Error(), err) - } - - param1 := reqParams.Value(1) - if param1 == nil { - param1 = &request.Param{} - } - switch v := param1.Value.(type) { - - case int, float64, bool, string: - if v == 0 || v == "0" || v == 0.0 || v == false || v == "false" { - results = hex.EncodeToString(tx.Bytes()) - } else { - results = result.NewTransactionOutputRaw(tx, header, s.chain) - } - default: + } else { results = result.NewTransactionOutputRaw(tx, header, s.chain) } } else { @@ -863,18 +849,12 @@ func (s *Server) getBlockSysFee(reqParams request.Params) (interface{}, *respons // getBlockHeader returns the corresponding block header information according to the specified script hash. func (s *Server) getBlockHeader(reqParams request.Params) (interface{}, *response.Error) { - var verbose bool - hash, err := reqParams.ValueWithType(0, request.StringT).GetUint256() if err != nil { return nil, response.ErrInvalidParams } - v, err := reqParams.ValueWithType(1, request.NumberT).GetInt() - if err == nil { - verbose = v != 0 - } - + verbose := reqParams.Value(1).GetBoolean() h, err := s.chain.GetHeader(hash) if err != nil { return nil, response.NewRPCError("unknown block", "", nil) From dcaa82b32b6a54ccf3fb292d111e281a69e2d0e2 Mon Sep 17 00:00:00 2001 From: Evgenii Stratonikov Date: Mon, 8 Jun 2020 16:38:44 +0300 Subject: [PATCH 4/9] rpc: convert `null` value to a defaultT Right now we convert it is unmarshaler into a float64(0) so an error is supressed. --- pkg/rpc/request/param.go | 5 +++++ pkg/rpc/request/param_test.go | 5 ++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/pkg/rpc/request/param.go b/pkg/rpc/request/param.go index de1e3dcb9..205a2a93e 100644 --- a/pkg/rpc/request/param.go +++ b/pkg/rpc/request/param.go @@ -192,6 +192,11 @@ func (p *Param) UnmarshalJSON(data []byte) error { {ArrayT, &[]Param{}}, } + if bytes.Equal(data, []byte("null")) { + p.Type = defaultT + return nil + } + for _, cur := range attempts { r := bytes.NewReader(data) jd := json.NewDecoder(r) diff --git a/pkg/rpc/request/param_test.go b/pkg/rpc/request/param_test.go index da04ea540..7bf2ae22d 100644 --- a/pkg/rpc/request/param_test.go +++ b/pkg/rpc/request/param_test.go @@ -14,7 +14,7 @@ import ( ) func TestParam_UnmarshalJSON(t *testing.T) { - msg := `["str1", 123, ["str2", 3], [{"type": "String", "value": "jajaja"}], + msg := `["str1", 123, null, ["str2", 3], [{"type": "String", "value": "jajaja"}], {"type": "MinerTransaction"}, {"contract": "f84d6a337fbc3d3a201d41da99e86b479e7a2554"}, {"state": "HALT"}]` @@ -29,6 +29,9 @@ func TestParam_UnmarshalJSON(t *testing.T) { Type: NumberT, Value: 123, }, + { + Type: defaultT, + }, { Type: ArrayT, Value: []Param{ From 53dc7f27b60ea6735e0b414de914abcd69432c08 Mon Sep 17 00:00:00 2001 From: Evgenii Stratonikov Date: Wed, 3 Jun 2020 18:09:36 +0300 Subject: [PATCH 5/9] rpc/server: implement getstateroot RPC --- pkg/rpc/server/server.go | 23 +++++++++++++++++++++++ pkg/rpc/server/server_test.go | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 55 insertions(+) diff --git a/pkg/rpc/server/server.go b/pkg/rpc/server/server.go index e3703398c..b38129bb1 100644 --- a/pkg/rpc/server/server.go +++ b/pkg/rpc/server/server.go @@ -95,6 +95,7 @@ var rpcHandlers = map[string]func(*Server, request.Params) (interface{}, *respon "getpeers": (*Server).getPeers, "getrawmempool": (*Server).getRawMempool, "getrawtransaction": (*Server).getrawtransaction, + "getstateroot": (*Server).getStateRoot, "getstorage": (*Server).getStorage, "gettransactionheight": (*Server).getTransactionHeight, "gettxout": (*Server).getTxOut, @@ -687,6 +688,28 @@ func (s *Server) getDecimals(h util.Uint160, cache map[util.Uint160]int64) (int6 return d, nil } +func (s *Server) getStateRoot(ps request.Params) (interface{}, *response.Error) { + p := ps.Value(0) + if p == nil { + return nil, response.NewRPCError("Invalid parameter.", "", nil) + } + var rt *state.MPTRootState + var h util.Uint256 + height, err := p.GetInt() + if err == nil { + rt, err = s.chain.GetStateRoot(uint32(height)) + } else if h, err = p.GetUint256(); err == nil { + hdr, err := s.chain.GetHeader(h) + if err == nil { + rt, err = s.chain.GetStateRoot(hdr.Index) + } + } + if err != nil { + return nil, response.NewRPCError("Unknown state root.", "", err) + } + return rt, nil +} + func (s *Server) getStorage(ps request.Params) (interface{}, *response.Error) { scriptHash, err := ps.Value(0).GetUint160FromHex() if err != nil { diff --git a/pkg/rpc/server/server_test.go b/pkg/rpc/server/server_test.go index 6103156f7..c6d384f80 100644 --- a/pkg/rpc/server/server_test.go +++ b/pkg/rpc/server/server_test.go @@ -16,6 +16,7 @@ import ( "github.com/gorilla/websocket" "github.com/nspcc-dev/neo-go/pkg/core" + "github.com/nspcc-dev/neo-go/pkg/core/state" "github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/crypto/keys" "github.com/nspcc-dev/neo-go/pkg/encoding/address" @@ -213,6 +214,18 @@ var rpcTestCases = map[string][]rpcTestCase{ }, }, }, + "getstateroot": { + { + name: "no params", + params: `[]`, + fail: true, + }, + { + name: "invalid hash", + params: `["0x1234567890"]`, + fail: true, + }, + }, "getstorage": { { name: "positive", @@ -928,6 +941,25 @@ func testRPCProtocol(t *testing.T, doRPCCall func(string, string, *testing.T) [] }) } + t.Run("getstateroot", func(t *testing.T) { + testRoot := func(t *testing.T, p string) { + rpc := fmt.Sprintf(`{"jsonrpc": "2.0", "id": 1, "method": "getstateroot", "params": [%s]}`, p) + fmt.Println(rpc) + body := doRPCCall(rpc, httpSrv.URL, t) + rawRes := checkErrGetResult(t, body, false) + + res := new(state.MPTRootState) + require.NoError(t, json.Unmarshal(rawRes, res)) + require.NotEqual(t, util.Uint256{}, res.Root) // be sure this test uses valid height + + expected, err := e.chain.GetStateRoot(205) + require.NoError(t, err) + require.Equal(t, expected, res) + } + t.Run("ByHeight", func(t *testing.T) { testRoot(t, strconv.FormatInt(205, 10)) }) + t.Run("ByHash", func(t *testing.T) { testRoot(t, `"`+chain.GetHeaderHash(205).StringLE()+`"`) }) + }) + t.Run("getrawtransaction", func(t *testing.T) { block, _ := chain.GetBlock(chain.GetHeaderHash(0)) TXHash := block.Transactions[1].Hash() From fe8038e8b70899dab5be94db1b31c564ccd40954 Mon Sep 17 00:00:00 2001 From: Evgenii Stratonikov Date: Thu, 4 Jun 2020 11:09:07 +0300 Subject: [PATCH 6/9] rpc/server: implement getstateheight RPC --- pkg/rpc/response/result/mpt.go | 7 +++++++ pkg/rpc/server/server.go | 9 +++++++++ pkg/rpc/server/server_test.go | 15 +++++++++++++++ 3 files changed, 31 insertions(+) create mode 100644 pkg/rpc/response/result/mpt.go diff --git a/pkg/rpc/response/result/mpt.go b/pkg/rpc/response/result/mpt.go new file mode 100644 index 000000000..65224e726 --- /dev/null +++ b/pkg/rpc/response/result/mpt.go @@ -0,0 +1,7 @@ +package result + +// StateHeight is a result of getstateheight RPC. +type StateHeight struct { + BlockHeight uint32 `json:"blockHeight"` + StateHeight uint32 `json:"stateHeight"` +} diff --git a/pkg/rpc/server/server.go b/pkg/rpc/server/server.go index b38129bb1..dcb8de045 100644 --- a/pkg/rpc/server/server.go +++ b/pkg/rpc/server/server.go @@ -95,6 +95,7 @@ var rpcHandlers = map[string]func(*Server, request.Params) (interface{}, *respon "getpeers": (*Server).getPeers, "getrawmempool": (*Server).getRawMempool, "getrawtransaction": (*Server).getrawtransaction, + "getstateheight": (*Server).getStateHeight, "getstateroot": (*Server).getStateRoot, "getstorage": (*Server).getStorage, "gettransactionheight": (*Server).getTransactionHeight, @@ -688,6 +689,14 @@ func (s *Server) getDecimals(h util.Uint160, cache map[util.Uint160]int64) (int6 return d, nil } +func (s *Server) getStateHeight(_ request.Params) (interface{}, *response.Error) { + height := s.chain.BlockHeight() + return &result.StateHeight{ + BlockHeight: height, + StateHeight: height, + }, nil +} + func (s *Server) getStateRoot(ps request.Params) (interface{}, *response.Error) { p := ps.Value(0) if p == nil { diff --git a/pkg/rpc/server/server_test.go b/pkg/rpc/server/server_test.go index c6d384f80..f4853c8a8 100644 --- a/pkg/rpc/server/server_test.go +++ b/pkg/rpc/server/server_test.go @@ -214,6 +214,21 @@ var rpcTestCases = map[string][]rpcTestCase{ }, }, }, + "getstateheight": { + { + name: "positive", + params: `[]`, + result: func(_ *executor) interface{} { return new(result.StateHeight) }, + check: func(t *testing.T, e *executor, res interface{}) { + sh, ok := res.(*result.StateHeight) + require.True(t, ok) + + h := e.chain.BlockHeight() + require.Equal(t, h, sh.BlockHeight) + require.Equal(t, h, sh.StateHeight) + }, + }, + }, "getstateroot": { { name: "no params", From 8cbbddddaf252c65751ae0a1c421a326f4c582d4 Mon Sep 17 00:00:00 2001 From: Evgenii Stratonikov Date: Thu, 4 Jun 2020 11:59:22 +0300 Subject: [PATCH 7/9] rpc: implement getproof RPC --- pkg/core/blockchain.go | 7 +++ pkg/core/blockchainer.go | 1 + pkg/network/helper_test.go | 3 ++ pkg/rpc/response/result/mpt.go | 74 +++++++++++++++++++++++++++++ pkg/rpc/response/result/mpt_test.go | 57 ++++++++++++++++++++++ pkg/rpc/server/server.go | 27 +++++++++++ pkg/rpc/server/server_test.go | 22 +++++++++ 7 files changed, 191 insertions(+) create mode 100644 pkg/rpc/response/result/mpt_test.go diff --git a/pkg/core/blockchain.go b/pkg/core/blockchain.go index 88bb71db9..d0a308a8c 100644 --- a/pkg/core/blockchain.go +++ b/pkg/core/blockchain.go @@ -13,6 +13,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/core/block" "github.com/nspcc-dev/neo-go/pkg/core/dao" "github.com/nspcc-dev/neo-go/pkg/core/mempool" + "github.com/nspcc-dev/neo-go/pkg/core/mpt" "github.com/nspcc-dev/neo-go/pkg/core/state" "github.com/nspcc-dev/neo-go/pkg/core/storage" "github.com/nspcc-dev/neo-go/pkg/core/transaction" @@ -555,6 +556,12 @@ func (bc *Blockchain) getSystemFeeAmount(h util.Uint256) uint32 { return sf } +// GetStateProof returns proof of having key in the MPT with the specified root. +func (bc *Blockchain) GetStateProof(root util.Uint256, key []byte) ([][]byte, error) { + tr := mpt.NewTrie(mpt.NewHashNode(root), storage.NewMemCachedStore(bc.dao.Store)) + return tr.GetProof(key) +} + // GetStateRoot returns state root for a given height. func (bc *Blockchain) GetStateRoot(height uint32) (*state.MPTRootState, error) { return bc.dao.GetStateRoot(height) diff --git a/pkg/core/blockchainer.go b/pkg/core/blockchainer.go index eac6e4edc..db2d11abe 100644 --- a/pkg/core/blockchainer.go +++ b/pkg/core/blockchainer.go @@ -39,6 +39,7 @@ type Blockchainer interface { GetNEP5Balances(util.Uint160) *state.NEP5Balances GetValidators(txes ...*transaction.Transaction) ([]*keys.PublicKey, error) GetScriptHashesForVerifying(*transaction.Transaction) ([]util.Uint160, error) + GetStateProof(root util.Uint256, key []byte) ([][]byte, error) GetStateRoot(height uint32) (*state.MPTRootState, error) GetStorageItem(scripthash util.Uint160, key []byte) *state.StorageItem GetStorageItems(hash util.Uint160) (map[string]*state.StorageItem, error) diff --git a/pkg/network/helper_test.go b/pkg/network/helper_test.go index 157ebdba0..2e2b697ad 100644 --- a/pkg/network/helper_test.go +++ b/pkg/network/helper_test.go @@ -108,6 +108,9 @@ func (chain testChain) GetEnrollments() ([]*state.Validator, error) { func (chain testChain) GetScriptHashesForVerifying(*transaction.Transaction) ([]util.Uint160, error) { panic("TODO") } +func (chain testChain) GetStateProof(util.Uint256, []byte) ([][]byte, error) { + panic("TODO") +} func (chain testChain) GetStateRoot(height uint32) (*state.MPTRootState, error) { panic("TODO") } diff --git a/pkg/rpc/response/result/mpt.go b/pkg/rpc/response/result/mpt.go index 65224e726..4473d7ea3 100644 --- a/pkg/rpc/response/result/mpt.go +++ b/pkg/rpc/response/result/mpt.go @@ -1,7 +1,81 @@ package result +import ( + "encoding/hex" + "encoding/json" + + "github.com/nspcc-dev/neo-go/pkg/io" +) + // StateHeight is a result of getstateheight RPC. type StateHeight struct { BlockHeight uint32 `json:"blockHeight"` StateHeight uint32 `json:"stateHeight"` } + +// ProofWithKey represens key-proof pair. +type ProofWithKey struct { + Key []byte + Proof [][]byte +} + +// GetProof is a result of getproof RPC. +type GetProof struct { + Result ProofWithKey `json:"proof"` + Success bool `json:"success"` +} + +// MarshalJSON implements json.Marshaler. +func (p *ProofWithKey) MarshalJSON() ([]byte, error) { + w := io.NewBufBinWriter() + p.EncodeBinary(w.BinWriter) + if w.Err != nil { + return nil, w.Err + } + return []byte(`"` + hex.EncodeToString(w.Bytes()) + `"`), nil +} + +// EncodeBinary implements io.Serializable. +func (p *ProofWithKey) EncodeBinary(w *io.BinWriter) { + w.WriteVarBytes(p.Key) + w.WriteVarUint(uint64(len(p.Proof))) + for i := range p.Proof { + w.WriteVarBytes(p.Proof[i]) + } +} + +// DecodeBinary implements io.Serializable. +func (p *ProofWithKey) DecodeBinary(r *io.BinReader) { + p.Key = r.ReadVarBytes() + sz := r.ReadVarUint() + for i := uint64(0); i < sz; i++ { + p.Proof = append(p.Proof, r.ReadVarBytes()) + } +} + +// UnmarshalJSON implements json.Unmarshaler. +func (p *ProofWithKey) UnmarshalJSON(data []byte) error { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return err + } + return p.FromString(s) +} + +// String implements fmt.Stringer. +func (p *ProofWithKey) String() string { + w := io.NewBufBinWriter() + p.EncodeBinary(w.BinWriter) + return hex.EncodeToString(w.Bytes()) +} + +// FromString decodes p from hex-encoded string. +func (p *ProofWithKey) FromString(s string) error { + rawProof, err := hex.DecodeString(s) + if err != nil { + return err + } + r := io.NewBinReaderFromBuf(rawProof) + p.DecodeBinary(r) + return r.Err +} diff --git a/pkg/rpc/response/result/mpt_test.go b/pkg/rpc/response/result/mpt_test.go new file mode 100644 index 000000000..3a3497aee --- /dev/null +++ b/pkg/rpc/response/result/mpt_test.go @@ -0,0 +1,57 @@ +package result + +import ( + "encoding/json" + "testing" + + "github.com/nspcc-dev/neo-go/pkg/core/mpt" + "github.com/nspcc-dev/neo-go/pkg/internal/random" + "github.com/nspcc-dev/neo-go/pkg/internal/testserdes" + "github.com/nspcc-dev/neo-go/pkg/io" + "github.com/stretchr/testify/require" +) + +func testProofWithKey() *ProofWithKey { + return &ProofWithKey{ + Key: random.Bytes(10), + Proof: [][]byte{ + random.Bytes(12), + random.Bytes(0), + random.Bytes(34), + }, + } +} + +func TestGetProof_MarshalJSON(t *testing.T) { + t.Run("Good", func(t *testing.T) { + p := &GetProof{ + Result: *testProofWithKey(), + Success: true, + } + testserdes.MarshalUnmarshalJSON(t, p, new(GetProof)) + }) + t.Run("Compatibility", func(t *testing.T) { + js := []byte(`{ + "proof" : "25ddeb9aa1bfc353c9c54e21dffb470f65d9c22a0662616c616e63654f70000000000000000708fd12020020666eaa8a6e75d43a97d76e72b605c7e05189f0c57ec19d84acdb75810f18239d202c83028ce3d7abcf4e4f95d05fbfdfa5e18bde3a8fbb65a57559d6b5ea09425c2090c40d440744a848e3b407a00e4efb692a957245a1efc9cb8496cb05fd328ee620dd2652bf25dfc3ad5fee7b200ccf3e3ae50772ff8ed58907e4dab8e7d4b2489720d8a5d5ed75b5b0f256d0a2cf5c220b4ddae2a228ef0fc0212b689f3811dfa94620342cc0d73fabd2440ed2cc735a9608391a510e1981b321a9f4258682706adc9620ced036e52f39387b9c58ade7bf8c3ca8959b64d8031d36d9b1c62f3f1c51c7cb2031072c7c801b5c1614dae441383a65344acd238f13db28ff0a39c0626e597f002062552d64c616d8b2a6a93d22936055110c0065728aa2b4fbf4d76b108390b474203322d3c93c741674a307cf6455e77c02ceeda307d4ec23fd809a2a420b4243f82052ab92a9cedc6716ad4c66a8a3e423b195b05bdebde456f992bff48f2561e99720e6379995e7053823b8ba8fb8af9623cf48e89f60c989598445df5e711db42a6f20192894ed637e86561ff6a4b8dea4539dee8bddb2fb20bf4ae3499852985c88b120e0005edd09f2335aa6b59ff4723e1262b2192adaa5e3e56f79e662f07041f04c2033577f3e2c5bb0e58746980a07cdfad2f872e2b9a10bcc27b7c678c85576df8420f0f04180d15b6eaa0c43e62380084c75ad773d790700a7120c6c4da1fc51693000fd720100209648e8f10a5ff4c209009b9a09697babbe1b2150d0948c1970a560282a1bfa4720988af8f34859dd8309bffea0b1dff9c8cef0b9b0d6a1852d40786627729ae7be00206ebf4f1b7861bca041cbb8feca75158511ca43a1810d17e1e3017468e8cef0de20cac93064090a7da09f8202c17d1e6cbb9a16eb43afcb032e80719cbf05b3446d2019b76a10b91fb99ec08814e8108e5490b879fb09a190cb2c129dfd98335bd5de000020b1da1198bacacf2adc0d863929d77c285ce3a26e736203d0c0a69a1312255fb2207ee8aa092f49348bd89f9c4bf004b0bee2241a2d0acfe7b3ce08e414b04a5717205b0dda71eac8a4e4cdc6a7b939748c0a78abb54f2547a780e6df67b25530330f000020fc358fb9d1e0d36461e015ac8e35f97072a9f9e750a3c25722a2b1a858fcb82d203c52c9fac6d4694b351390158334a9166bc3478ceb9bea2b0b244915f918239e20d526344a24ff19ee6a9f5c5beb833f4eb6d51191590350e26fa50b138493473f005200000000000000000000002077c404fec0a4265568951dbd096572787d109fab105213f4f292a5f53ce72fca00000020b8d1c7a386eaba83ce83ee0700d4ca9b86e75d147d670ea05123e438231d895000004801250b090a0a010b0f0c0305030c090c05040e02010d0f0f0b0407000f06050d090c02020a0006202af2097cf9d3f42e49f6b3c3dd254e7cbdab3485b029721cbbbf1ad0455a810852000000000000002055170506f4b18bc573a909b51cb21bdd5d303ec511f6cdfb1c6a1ab8d8a1dad020ee774c1b9fe1d8ea8d05823837d959da48af74f384d52f06c42c9d146c5258e300000000000000000072000000204457a6fe530ee953ad1f9caf63daf7f86719c9986df2d0b6917021eb379800f00020406bfc79da4ba6f37452a679d13cca252585d34f7e94a480b047bad9427f233e00000000201ce15a2373d28e0dc5f2000cf308f155d06f72070a29e5af1528c8f05f29d248000000000000004301200601060c0601060e06030605040f0700000000000000000000000000000000072091b83866bbd7450115b462e8d48601af3c3e9a35e7018d2b98a23e107c15c200090307000410a328e800", + "success" : true + }`) + + var p GetProof + require.NoError(t, json.Unmarshal(js, &p)) + require.Equal(t, 8, len(p.Result.Proof)) + for i := range p.Result.Proof { // smoke test that every chunk is correctly encoded node + r := io.NewBinReaderFromBuf(p.Result.Proof[i]) + var n mpt.NodeObject + n.DecodeBinary(r) + require.NoError(t, r.Err) + require.NotNil(t, n.Node) + } + }) +} + +func TestProofWithKey_EncodeString(t *testing.T) { + expected := testProofWithKey() + var actual ProofWithKey + require.NoError(t, actual.FromString(expected.String())) + require.Equal(t, expected, &actual) +} diff --git a/pkg/rpc/server/server.go b/pkg/rpc/server/server.go index dcb8de045..cc1d178bf 100644 --- a/pkg/rpc/server/server.go +++ b/pkg/rpc/server/server.go @@ -15,6 +15,7 @@ import ( "github.com/gorilla/websocket" "github.com/nspcc-dev/neo-go/pkg/core" "github.com/nspcc-dev/neo-go/pkg/core/block" + "github.com/nspcc-dev/neo-go/pkg/core/mpt" "github.com/nspcc-dev/neo-go/pkg/core/state" "github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/crypto/hash" @@ -95,6 +96,7 @@ var rpcHandlers = map[string]func(*Server, request.Params) (interface{}, *respon "getpeers": (*Server).getPeers, "getrawmempool": (*Server).getRawMempool, "getrawtransaction": (*Server).getrawtransaction, + "getproof": (*Server).getProof, "getstateheight": (*Server).getStateHeight, "getstateroot": (*Server).getStateRoot, "getstorage": (*Server).getStorage, @@ -689,6 +691,31 @@ func (s *Server) getDecimals(h util.Uint160, cache map[util.Uint160]int64) (int6 return d, nil } +func (s *Server) getProof(ps request.Params) (interface{}, *response.Error) { + root, err := ps.Value(0).GetUint256() + if err != nil { + return nil, response.ErrInvalidParams + } + sc, err := ps.Value(1).GetUint160FromHex() + if err != nil { + return nil, response.ErrInvalidParams + } + sc = sc.Reverse() + key, err := ps.Value(2).GetBytesHex() + if err != nil { + return nil, response.ErrInvalidParams + } + skey := mpt.ToNeoStorageKey(append(sc.BytesBE(), key...)) + proof, err := s.chain.GetStateProof(root, skey) + return &result.GetProof{ + Result: result.ProofWithKey{ + Key: skey, + Proof: proof, + }, + Success: err == nil, + }, nil +} + func (s *Server) getStateHeight(_ request.Params) (interface{}, *response.Error) { height := s.chain.BlockHeight() return &result.StateHeight{ diff --git a/pkg/rpc/server/server_test.go b/pkg/rpc/server/server_test.go index f4853c8a8..bb6d8cc11 100644 --- a/pkg/rpc/server/server_test.go +++ b/pkg/rpc/server/server_test.go @@ -214,6 +214,28 @@ var rpcTestCases = map[string][]rpcTestCase{ }, }, }, + "getproof": { + { + name: "no params", + params: `[]`, + fail: true, + }, + { + name: "invalid root", + params: `["0xabcdef"]`, + fail: true, + }, + { + name: "invalid contract", + params: `["0000000000000000000000000000000000000000000000000000000000000000", "0xabcdef"]`, + fail: true, + }, + { + name: "invalid key", + params: `["0000000000000000000000000000000000000000000000000000000000000000", "` + testContractHash + `", "notahex"]`, + fail: true, + }, + }, "getstateheight": { { name: "positive", From 519a98039c53259b5f5d82666159626302bfac06 Mon Sep 17 00:00:00 2001 From: Evgenii Stratonikov Date: Fri, 5 Jun 2020 11:51:39 +0300 Subject: [PATCH 8/9] rpc: implement verifyproof RPC Test getproof and verifyproof together. --- pkg/rpc/response/result/mpt.go | 41 +++++++++++++++++++++++++++++ pkg/rpc/response/result/mpt_test.go | 11 ++++++++ pkg/rpc/server/server.go | 28 ++++++++++++++++++++ pkg/rpc/server/server_test.go | 28 ++++++++++++++++++++ 4 files changed, 108 insertions(+) diff --git a/pkg/rpc/response/result/mpt.go b/pkg/rpc/response/result/mpt.go index 4473d7ea3..10ef7e8c3 100644 --- a/pkg/rpc/response/result/mpt.go +++ b/pkg/rpc/response/result/mpt.go @@ -1,8 +1,10 @@ package result import ( + "bytes" "encoding/hex" "encoding/json" + "errors" "github.com/nspcc-dev/neo-go/pkg/io" ) @@ -25,6 +27,12 @@ type GetProof struct { Success bool `json:"success"` } +// VerifyProof is a result of verifyproof RPC. +// nil Value is considered invalid. +type VerifyProof struct { + Value []byte +} + // MarshalJSON implements json.Marshaler. func (p *ProofWithKey) MarshalJSON() ([]byte, error) { w := io.NewBufBinWriter() @@ -79,3 +87,36 @@ func (p *ProofWithKey) FromString(s string) error { p.DecodeBinary(r) return r.Err } + +// MarshalJSON implements json.Marshaler. +func (p *VerifyProof) MarshalJSON() ([]byte, error) { + if p.Value == nil { + return []byte(`"invalid"`), nil + } + return []byte(`{"value":"` + hex.EncodeToString(p.Value) + `"}`), nil +} + +// UnmarshalJSON implements json.Unmarshaler. +func (p *VerifyProof) UnmarshalJSON(data []byte) error { + if bytes.Equal(data, []byte(`"invalid"`)) { + p.Value = nil + return nil + } + var m map[string]string + if err := json.Unmarshal(data, &m); err != nil { + return err + } + if len(m) != 1 { + return errors.New("must have single key") + } + v, ok := m["value"] + if !ok { + return errors.New("invalid json") + } + b, err := hex.DecodeString(v) + if err != nil { + return err + } + p.Value = b + return nil +} diff --git a/pkg/rpc/response/result/mpt_test.go b/pkg/rpc/response/result/mpt_test.go index 3a3497aee..22e0c021c 100644 --- a/pkg/rpc/response/result/mpt_test.go +++ b/pkg/rpc/response/result/mpt_test.go @@ -55,3 +55,14 @@ func TestProofWithKey_EncodeString(t *testing.T) { require.NoError(t, actual.FromString(expected.String())) require.Equal(t, expected, &actual) } + +func TestVerifyProof_MarshalJSON(t *testing.T) { + t.Run("Good", func(t *testing.T) { + vp := &VerifyProof{random.Bytes(100)} + testserdes.MarshalUnmarshalJSON(t, vp, new(VerifyProof)) + }) + t.Run("NoValue", func(t *testing.T) { + vp := new(VerifyProof) + testserdes.MarshalUnmarshalJSON(t, vp, &VerifyProof{[]byte{1, 2, 3}}) + }) +} diff --git a/pkg/rpc/server/server.go b/pkg/rpc/server/server.go index cc1d178bf..6912c4e24 100644 --- a/pkg/rpc/server/server.go +++ b/pkg/rpc/server/server.go @@ -112,6 +112,7 @@ var rpcHandlers = map[string]func(*Server, request.Params) (interface{}, *respon "sendrawtransaction": (*Server).sendrawtransaction, "submitblock": (*Server).submitBlock, "validateaddress": (*Server).validateAddress, + "verifyproof": (*Server).verifyProof, } var rpcWsHandlers = map[string]func(*Server, request.Params, *subscriber) (interface{}, *response.Error){ @@ -716,6 +717,33 @@ func (s *Server) getProof(ps request.Params) (interface{}, *response.Error) { }, nil } +func (s *Server) verifyProof(ps request.Params) (interface{}, *response.Error) { + root, err := ps.Value(0).GetUint256() + if err != nil { + return nil, response.ErrInvalidParams + } + proofStr, err := ps.Value(1).GetString() + if err != nil { + return nil, response.ErrInvalidParams + } + var p result.ProofWithKey + if err := p.FromString(proofStr); err != nil { + return nil, response.ErrInvalidParams + } + vp := new(result.VerifyProof) + val, ok := mpt.VerifyProof(root, p.Key, p.Proof) + if ok { + var si state.StorageItem + r := io.NewBinReaderFromBuf(val[1:]) + si.DecodeBinary(r) + if r.Err != nil { + return nil, response.NewInternalServerError("invalid item in trie", r.Err) + } + vp.Value = si.Value + } + return vp, nil +} + func (s *Server) getStateHeight(_ request.Params) (interface{}, *response.Error) { height := s.chain.BlockHeight() return &result.StateHeight{ diff --git a/pkg/rpc/server/server_test.go b/pkg/rpc/server/server_test.go index bb6d8cc11..e53687baf 100644 --- a/pkg/rpc/server/server_test.go +++ b/pkg/rpc/server/server_test.go @@ -16,6 +16,7 @@ import ( "github.com/gorilla/websocket" "github.com/nspcc-dev/neo-go/pkg/core" + "github.com/nspcc-dev/neo-go/pkg/core/mpt" "github.com/nspcc-dev/neo-go/pkg/core/state" "github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/crypto/keys" @@ -978,6 +979,33 @@ func testRPCProtocol(t *testing.T, doRPCCall func(string, string, *testing.T) [] }) } + t.Run("getproof", func(t *testing.T) { + r, err := chain.GetStateRoot(205) + require.NoError(t, err) + + rpc := fmt.Sprintf(`{"jsonrpc": "2.0", "id": 1, "method": "getproof", "params": ["%s", "%s", "%x"]}`, + r.Root.StringLE(), testContractHash, []byte("testkey")) + fmt.Println(rpc) + body := doRPCCall(rpc, httpSrv.URL, t) + fmt.Println(string(body)) + rawRes := checkErrGetResult(t, body, false) + res := new(result.GetProof) + require.NoError(t, json.Unmarshal(rawRes, res)) + require.True(t, res.Success) + h, _ := hex.DecodeString(testContractHash) + skey := append(h, []byte("testkey")...) + require.Equal(t, mpt.ToNeoStorageKey(skey), res.Result.Key) + require.True(t, len(res.Result.Proof) > 0) + + rpc = fmt.Sprintf(`{"jsonrpc": "2.0", "id": 1, "method": "verifyproof", "params": ["%s", "%s"]}`, + r.Root.StringLE(), res.Result.String()) + body = doRPCCall(rpc, httpSrv.URL, t) + rawRes = checkErrGetResult(t, body, false) + vp := new(result.VerifyProof) + require.NoError(t, json.Unmarshal(rawRes, vp)) + require.Equal(t, []byte("testvalue"), vp.Value) + }) + t.Run("getstateroot", func(t *testing.T) { testRoot := func(t *testing.T, p string) { rpc := fmt.Sprintf(`{"jsonrpc": "2.0", "id": 1, "method": "getstateroot", "params": [%s]}`, p) From 1fd7938fd877b7073f29302f4ef7b79e57fb21aa Mon Sep 17 00:00:00 2001 From: Evgenii Stratonikov Date: Fri, 5 Jun 2020 12:11:22 +0300 Subject: [PATCH 9/9] network: process state roots properly --- pkg/consensus/consensus.go | 12 +++--- pkg/consensus/consensus_test.go | 3 +- pkg/{consensus => core/cache}/cache.go | 27 ++++++------- pkg/{consensus => core/cache}/cache_test.go | 32 ++++++++-------- pkg/network/server.go | 42 +++++++++++++++++---- 5 files changed, 72 insertions(+), 44 deletions(-) rename pkg/{consensus => core/cache}/cache.go (60%) rename pkg/{consensus => core/cache}/cache_test.go (68%) diff --git a/pkg/consensus/consensus.go b/pkg/consensus/consensus.go index 742e8d23f..183280313 100644 --- a/pkg/consensus/consensus.go +++ b/pkg/consensus/consensus.go @@ -13,6 +13,7 @@ import ( "github.com/nspcc-dev/dbft/payload" "github.com/nspcc-dev/neo-go/pkg/core" coreb "github.com/nspcc-dev/neo-go/pkg/core/block" + "github.com/nspcc-dev/neo-go/pkg/core/cache" "github.com/nspcc-dev/neo-go/pkg/core/mempool" "github.com/nspcc-dev/neo-go/pkg/core/state" "github.com/nspcc-dev/neo-go/pkg/core/transaction" @@ -50,9 +51,9 @@ type service struct { log *zap.Logger // cache is a fifo cache which stores recent payloads. - cache *relayCache + cache *cache.HashCache // txx is a fifo cache which stores miner transactions. - txx *relayCache + txx *cache.HashCache dbft *dbft.DBFT // messages and transactions are channels needed to process // everything in single thread. @@ -71,7 +72,7 @@ type Config struct { Logger *zap.Logger // Broadcast is a callback which is called to notify server // about new consensus payload to sent. - Broadcast func(p *Payload) + Broadcast func(cache.Hashable) // Chain is a core.Blockchainer instance. Chain core.Blockchainer // RequestTx is a callback to which will be called @@ -97,8 +98,8 @@ func NewService(cfg Config) (Service, error) { Config: cfg, log: cfg.Logger, - cache: newFIFOCache(cacheMaxCapacity), - txx: newFIFOCache(cacheMaxCapacity), + cache: cache.NewFIFOCache(cacheMaxCapacity), + txx: cache.NewFIFOCache(cacheMaxCapacity), messages: make(chan Payload, 100), transactions: make(chan *transaction.Transaction, 100), @@ -394,6 +395,7 @@ func (s *service) processBlock(b block.Block) { if err := s.Chain.AddStateRoot(r); err != nil { s.log.Warn("errors while adding state root", zap.Error(err)) } + s.Broadcast(r) } func (s *service) getBlockWitness(_ *coreb.Block) *transaction.Witness { diff --git a/pkg/consensus/consensus_test.go b/pkg/consensus/consensus_test.go index 285971622..a5713f7f3 100644 --- a/pkg/consensus/consensus_test.go +++ b/pkg/consensus/consensus_test.go @@ -7,6 +7,7 @@ import ( "github.com/nspcc-dev/dbft/payload" "github.com/nspcc-dev/neo-go/pkg/config" "github.com/nspcc-dev/neo-go/pkg/core" + "github.com/nspcc-dev/neo-go/pkg/core/cache" "github.com/nspcc-dev/neo-go/pkg/core/storage" "github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/crypto/keys" @@ -182,7 +183,7 @@ func shouldNotReceive(t *testing.T, ch chan Payload) { func newTestService(t *testing.T) *service { srv, err := NewService(Config{ Logger: zaptest.NewLogger(t), - Broadcast: func(*Payload) {}, + Broadcast: func(cache.Hashable) {}, Chain: newTestChain(t), RequestTx: func(...util.Uint256) {}, Wallet: &wallet.Config{ diff --git a/pkg/consensus/cache.go b/pkg/core/cache/cache.go similarity index 60% rename from pkg/consensus/cache.go rename to pkg/core/cache/cache.go index 4a6853803..962b779ed 100644 --- a/pkg/consensus/cache.go +++ b/pkg/core/cache/cache.go @@ -1,4 +1,4 @@ -package consensus +package cache import ( "container/list" @@ -7,9 +7,9 @@ import ( "github.com/nspcc-dev/neo-go/pkg/util" ) -// relayCache is a payload cache which is used to store +// HashCache is a payload cache which is used to store // last consensus payloads. -type relayCache struct { +type HashCache struct { *sync.RWMutex maxCap int @@ -17,13 +17,14 @@ type relayCache struct { queue *list.List } -// hashable is a type of items which can be stored in the relayCache. -type hashable interface { +// Hashable is a type of items which can be stored in the HashCache. +type Hashable interface { Hash() util.Uint256 } -func newFIFOCache(capacity int) *relayCache { - return &relayCache{ +// NewFIFOCache returns new FIFO cache with the specified capacity. +func NewFIFOCache(capacity int) *HashCache { + return &HashCache{ RWMutex: new(sync.RWMutex), maxCap: capacity, @@ -33,7 +34,7 @@ func newFIFOCache(capacity int) *relayCache { } // Add adds payload into a cache if it doesn't already exist. -func (c *relayCache) Add(p hashable) { +func (c *HashCache) Add(p Hashable) { c.Lock() defer c.Unlock() @@ -45,7 +46,7 @@ func (c *relayCache) Add(p hashable) { if c.queue.Len() >= c.maxCap { first := c.queue.Front() c.queue.Remove(first) - delete(c.elems, first.Value.(hashable).Hash()) + delete(c.elems, first.Value.(Hashable).Hash()) } e := c.queue.PushBack(p) @@ -53,7 +54,7 @@ func (c *relayCache) Add(p hashable) { } // Has checks if an item is already in cache. -func (c *relayCache) Has(h util.Uint256) bool { +func (c *HashCache) Has(h util.Uint256) bool { c.RLock() defer c.RUnlock() @@ -61,13 +62,13 @@ func (c *relayCache) Has(h util.Uint256) bool { } // Get returns payload with the specified hash from cache. -func (c *relayCache) Get(h util.Uint256) hashable { +func (c *HashCache) Get(h util.Uint256) Hashable { c.RLock() defer c.RUnlock() e, ok := c.elems[h] if !ok { - return hashable(nil) + return Hashable(nil) } - return e.Value.(hashable) + return e.Value.(Hashable) } diff --git a/pkg/consensus/cache_test.go b/pkg/core/cache/cache_test.go similarity index 68% rename from pkg/consensus/cache_test.go rename to pkg/core/cache/cache_test.go index cd4ebe5a3..e8288e2d7 100644 --- a/pkg/consensus/cache_test.go +++ b/pkg/core/cache/cache_test.go @@ -1,17 +1,19 @@ -package consensus +package cache import ( + "math/rand" "testing" - "github.com/nspcc-dev/dbft/payload" + "github.com/nspcc-dev/neo-go/pkg/crypto/hash" "github.com/nspcc-dev/neo-go/pkg/internal/random" + "github.com/nspcc-dev/neo-go/pkg/util" "github.com/stretchr/testify/require" ) func TestRelayCache_Add(t *testing.T) { const capacity = 3 - payloads := getDifferentPayloads(t, capacity+1) - c := newFIFOCache(capacity) + payloads := getDifferentItems(t, capacity+1) + c := NewFIFOCache(capacity) require.Equal(t, 0, c.queue.Len()) require.Equal(t, 0, len(c.elems)) @@ -46,19 +48,15 @@ func TestRelayCache_Add(t *testing.T) { require.Equal(t, nil, c.Get(payloads[1].Hash())) } -func getDifferentPayloads(t *testing.T, n int) (payloads []Payload) { - payloads = make([]Payload, n) - for i := range payloads { - var sign [signatureSize]byte - random.Fill(sign[:]) +type testHashable []byte - payloads[i].message = &message{} - payloads[i].SetValidatorIndex(uint16(i)) - payloads[i].SetType(payload.MessageType(commitType)) - payloads[i].payload = &commit{ - signature: sign, - } +// Hash implements Hashable. +func (h testHashable) Hash() util.Uint256 { return hash.Sha256(h) } + +func getDifferentItems(t *testing.T, n int) []testHashable { + items := make([]testHashable, n) + for i := range items { + items[i] = random.Bytes(rand.Int() % 10) } - - return + return items } diff --git a/pkg/network/server.go b/pkg/network/server.go index a9559eba7..1b0c3076c 100644 --- a/pkg/network/server.go +++ b/pkg/network/server.go @@ -13,6 +13,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/consensus" "github.com/nspcc-dev/neo-go/pkg/core" "github.com/nspcc-dev/neo-go/pkg/core/block" + "github.com/nspcc-dev/neo-go/pkg/core/cache" "github.com/nspcc-dev/neo-go/pkg/core/state" "github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/network/payload" @@ -29,6 +30,7 @@ const ( maxBlockBatch = 200 maxAddrsToSend = 200 minPoolCount = 30 + stateRootCacheSize = 100 ) var ( @@ -67,6 +69,7 @@ type ( transactions chan *transaction.Transaction + stateCache cache.HashCache consensusStarted *atomic.Bool log *zap.Logger @@ -99,6 +102,7 @@ func NewServer(config ServerConfig, chain core.Blockchainer, log *zap.Logger) (* unregister: make(chan peerDrop), peers: make(map[Peer]bool), consensusStarted: atomic.NewBool(false), + stateCache: *cache.NewFIFOCache(stateRootCacheSize), log: log, transactions: make(chan *transaction.Transaction, 64), } @@ -470,6 +474,7 @@ func (s *Server) handleInvCmd(p Peer, inv *payload.Inventory) error { cp := s.consensus.GetPayload(h) return cp != nil }, + payload.StateRootType: s.stateCache.Has, } if exists := typExists[inv.Type]; exists != nil { for _, hash := range inv.Hashes { @@ -509,7 +514,10 @@ func (s *Server) handleGetDataCmd(p Peer, inv *payload.Inventory) error { msg = s.MkMsg(CMDBlock, b) } case payload.StateRootType: - return nil // do nothing + r := s.stateCache.Get(hash) + if r != nil { + msg = s.MkMsg(CMDStateRoot, r.(*state.MPTRoot)) + } case payload.ConsensusType: if cp := s.consensus.GetPayload(hash); cp != nil { msg = s.MkMsg(CMDConsensus, cp) @@ -613,12 +621,21 @@ func (s *Server) handleGetRootsCmd(p Peer, gr *payload.GetStateRoots) error { // handleStateRootsCmd processees `roots` request. func (s *Server) handleRootsCmd(rs *payload.StateRoots) error { - return nil // TODO + for i := range rs.Roots { + _ = s.chain.AddStateRoot(&rs.Roots[i]) + } + return nil } // handleStateRootCmd processees `stateroot` request. func (s *Server) handleStateRootCmd(r *state.MPTRoot) error { - return nil // TODO + // we ignore error, because there is nothing wrong if we already have this state root + err := s.chain.AddStateRoot(r) + if err == nil && !s.stateCache.Has(r.Hash()) { + s.stateCache.Add(r) + s.broadcastMessage(s.MkMsg(CMDStateRoot, r)) + } + return nil } // handleConsensusCmd processes received consensus payload. @@ -782,11 +799,20 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error { return nil } -func (s *Server) handleNewPayload(p *consensus.Payload) { - msg := s.MkMsg(CMDInv, payload.NewInventory(payload.ConsensusType, []util.Uint256{p.Hash()})) - // It's high priority because it directly affects consensus process, - // even though it's just an inv. - s.broadcastHPMessage(msg) +func (s *Server) handleNewPayload(item cache.Hashable) { + switch p := item.(type) { + case *consensus.Payload: + msg := s.MkMsg(CMDInv, payload.NewInventory(payload.ConsensusType, []util.Uint256{p.Hash()})) + // It's high priority because it directly affects consensus process, + // even though it's just an inv. + s.broadcastHPMessage(msg) + case *state.MPTRoot: + s.stateCache.Add(p) + msg := s.MkMsg(CMDStateRoot, p) + s.broadcastMessage(msg) + default: + s.log.Warn("unknown item type", zap.String("type", fmt.Sprintf("%T", p))) + } } func (s *Server) requestTx(hashes ...util.Uint256) {