diff --git a/.docker/privnet-entrypoint.sh b/.docker/privnet-entrypoint.sh index 65f38b80f..d59632169 100755 --- a/.docker/privnet-entrypoint.sh +++ b/.docker/privnet-entrypoint.sh @@ -2,15 +2,11 @@ BIN=/usr/bin/neo-go -if [ -z "$ACC"]; then - ACC=/6000-privnet-blocks.acc.gz -fi - case $@ in "node"*) echo "=> Try to restore blocks before running node" - if test -f $ACC; then - gunzip --stdout /$ACC > /privnet.acc + if [ -n "$ACC" -a -f "$ACC" ]; then + gunzip --stdout "$ACC" > /privnet.acc ${BIN} db restore -p --config-path /config -i /privnet.acc fi ;; diff --git a/pkg/consensus/consensus.go b/pkg/consensus/consensus.go index bb3a61bf9..07a9d822b 100644 --- a/pkg/consensus/consensus.go +++ b/pkg/consensus/consensus.go @@ -20,6 +20,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/vm/emit" "github.com/nspcc-dev/neo-go/pkg/wallet" + "go.uber.org/atomic" "go.uber.org/zap" ) @@ -63,6 +64,9 @@ type service struct { lastProposal []util.Uint256 wallet *wallet.Wallet network netmode.Magic + // started is a flag set with Start method that runs an event handling + // goroutine. + started *atomic.Bool } // Config is a configuration for consensus services. @@ -104,6 +108,7 @@ func NewService(cfg Config) (Service, error) { transactions: make(chan *transaction.Transaction, 100), blockEvents: make(chan *coreb.Block, 1), network: cfg.Chain.GetConfig().Magic, + started: atomic.NewBool(false), } if cfg.Wallet == nil { @@ -143,6 +148,7 @@ func NewService(cfg Config) (Service, error) { dbft.WithNewCommit(func() payload.Commit { return new(commit) }), dbft.WithNewRecoveryRequest(func() payload.RecoveryRequest { return new(recoveryRequest) }), dbft.WithNewRecoveryMessage(func() payload.RecoveryMessage { return new(recoveryMessage) }), + dbft.WithVerifyPrepareRequest(srv.verifyRequest), ) if srv.dbft == nil { @@ -169,9 +175,11 @@ func (s *service) newPayload() payload.ConsensusPayload { } func (s *service) Start() { - s.dbft.Start() - s.Chain.SubscribeForBlocks(s.blockEvents) - go s.eventLoop() + if s.started.CAS(false, true) { + s.dbft.Start() + s.Chain.SubscribeForBlocks(s.blockEvents) + go s.eventLoop() + } } func (s *service) eventLoop() { @@ -267,8 +275,8 @@ func (s *service) OnPayload(cp *Payload) { s.Config.Broadcast(cp) s.cache.Add(cp) - if s.dbft == nil { - log.Debug("dbft is nil") + if s.dbft == nil || !s.started.Load() { + log.Debug("dbft is inactive or not started yet") return } @@ -280,13 +288,6 @@ func (s *service) OnPayload(cp *Payload) { } } - // we use switch here because other payloads could be possibly added in future - switch cp.Type() { - case payload.PrepareRequestType: - req := cp.GetPrepareRequest().(*prepareRequest) - s.lastProposal = req.transactionHashes - } - s.messages <- *cp } @@ -347,6 +348,14 @@ func (s *service) verifyBlock(b block.Block) bool { return true } +func (s *service) verifyRequest(p payload.ConsensusPayload) error { + req := p.GetPrepareRequest().(*prepareRequest) + // Save lastProposal for getVerified(). + s.lastProposal = req.transactionHashes + + return nil +} + func (s *service) processBlock(b block.Block) { bb := &b.(*neoBlock).Block bb.Script = *(s.getBlockWitness(bb)) diff --git a/pkg/consensus/consensus_test.go b/pkg/consensus/consensus_test.go index 196bdd5cf..fb4c82278 100644 --- a/pkg/consensus/consensus_test.go +++ b/pkg/consensus/consensus_test.go @@ -2,6 +2,7 @@ package consensus import ( "testing" + "time" "github.com/nspcc-dev/dbft/block" "github.com/nspcc-dev/dbft/payload" @@ -39,6 +40,7 @@ func TestNewService(t *testing.T) { func TestService_GetVerified(t *testing.T) { srv := newTestService(t) + srv.dbft.Start() var txs []*transaction.Transaction for i := 0; i < 4; i++ { tx := transaction.New(netmode.UnitTestNet, []byte{byte(opcode.PUSH1)}, 0) @@ -52,22 +54,30 @@ func TestService_GetVerified(t *testing.T) { hashes := []util.Uint256{txs[0].Hash(), txs[1].Hash(), txs[2].Hash()} - p := new(Payload) - p.message = &message{} - p.SetType(payload.PrepareRequestType) - tx := transaction.New(netmode.UnitTestNet, []byte{byte(opcode.PUSH1)}, 0) - tx.Nonce = 999 - p.SetPayload(&prepareRequest{transactionHashes: hashes}) - p.SetValidatorIndex(1) + // Everyone sends a message. + for i := 0; i < 4; i++ { + p := new(Payload) + p.message = &message{} + // One PrepareRequest and three ChangeViews. + if i == 1 { + p.SetType(payload.PrepareRequestType) + p.SetPayload(&prepareRequest{transactionHashes: hashes}) + } else { + p.SetType(payload.ChangeViewType) + p.SetPayload(&changeView{newViewNumber: 1, timestamp: uint32(time.Now().Unix())}) + } + p.SetHeight(1) + p.SetValidatorIndex(uint16(i)) - priv, _ := getTestValidator(1) - require.NoError(t, p.Sign(priv)) + priv, _ := getTestValidator(i) + require.NoError(t, p.Sign(priv)) - srv.OnPayload(p) + // Skip srv.OnPayload, because the service is not really started. + srv.dbft.OnReceive(p) + } + require.Equal(t, uint8(1), srv.dbft.ViewNumber) require.Equal(t, hashes, srv.lastProposal) - srv.dbft.ViewNumber = 1 - t.Run("new transactions will be proposed in case of failure", func(t *testing.T) { txx := srv.getVerifiedTx() require.Equal(t, 1, len(txx), "there is only 1 tx in mempool") @@ -157,6 +167,10 @@ func TestService_getTx(t *testing.T) { func TestService_OnPayload(t *testing.T) { srv := newTestService(t) + // This test directly reads things from srv.messages that normally + // is read by internal goroutine started with Start(). So let's + // pretend we really did start already. + srv.started.Store(true) priv, _ := getTestValidator(1) p := new(Payload) diff --git a/pkg/consensus/recovery_message.go b/pkg/consensus/recovery_message.go index 030db04ab..8e894b884 100644 --- a/pkg/consensus/recovery_message.go +++ b/pkg/consensus/recovery_message.go @@ -96,7 +96,7 @@ func (p *changeViewCompact) DecodeBinary(r *io.BinReader) { p.ValidatorIndex = r.ReadU16LE() p.OriginalViewNumber = r.ReadB() p.Timestamp = r.ReadU32LE() - p.InvocationScript = r.ReadVarBytes() + p.InvocationScript = r.ReadVarBytes(1024) } // EncodeBinary implements io.Serializable interface. @@ -112,7 +112,7 @@ func (p *commitCompact) DecodeBinary(r *io.BinReader) { p.ViewNumber = r.ReadB() p.ValidatorIndex = r.ReadU16LE() r.ReadBytes(p.Signature[:]) - p.InvocationScript = r.ReadVarBytes() + p.InvocationScript = r.ReadVarBytes(1024) } // EncodeBinary implements io.Serializable interface. @@ -126,7 +126,7 @@ func (p *commitCompact) EncodeBinary(w *io.BinWriter) { // DecodeBinary implements io.Serializable interface. func (p *preparationCompact) DecodeBinary(r *io.BinReader) { p.ValidatorIndex = r.ReadU16LE() - p.InvocationScript = r.ReadVarBytes() + p.InvocationScript = r.ReadVarBytes(1024) } // EncodeBinary implements io.Serializable interface. @@ -234,6 +234,7 @@ func (m *recoveryMessage) GetChangeViews(p payload.ConsensusPayload, validators newViewNumber: cv.OriginalViewNumber + 1, timestamp: cv.Timestamp, }) + c.message.ViewNumber = cv.OriginalViewNumber c.SetValidatorIndex(cv.ValidatorIndex) c.Witness.InvocationScript = cv.InvocationScript c.Witness.VerificationScript = getVerificationScript(cv.ValidatorIndex, validators) diff --git a/pkg/core/blockchain_test.go b/pkg/core/blockchain_test.go index 3d359afdd..f2269dff5 100644 --- a/pkg/core/blockchain_test.go +++ b/pkg/core/blockchain_test.go @@ -188,6 +188,7 @@ func TestGetTransaction(t *testing.T) { */ func TestGetClaimable(t *testing.T) { bc := newTestChain(t) + defer bc.Close() bc.generationAmount = []int{4, 3, 2, 1} bc.decrementInterval = 2 diff --git a/pkg/crypto/keys/publickey.go b/pkg/crypto/keys/publickey.go index 1d1c0103a..e4df477e1 100644 --- a/pkg/crypto/keys/publickey.go +++ b/pkg/crypto/keys/publickey.go @@ -1,7 +1,6 @@ package keys import ( - "bytes" "crypto/ecdsa" "crypto/elliptic" "crypto/x509" @@ -19,6 +18,9 @@ import ( "github.com/pkg/errors" ) +// coordLen is the number of bytes in serialized X or Y coordinate. +const coordLen = 32 + // PublicKeys is a list of public keys. type PublicKeys []*PublicKey @@ -113,23 +115,49 @@ func NewPublicKeyFromBytes(b []byte) (*PublicKey, error) { return pubKey, nil } -// Bytes returns the byte array representation of the public key. -func (p *PublicKey) Bytes() []byte { +// getBytes serializes X and Y using compressed or uncompressed format. +func (p *PublicKey) getBytes(compressed bool) []byte { if p.IsInfinity() { return []byte{0x00} } - var ( - x = p.X.Bytes() - paddedX = append(bytes.Repeat([]byte{0x00}, 32-len(x)), x...) - prefix = byte(0x03) - ) - - if p.Y.Bit(0) == 0 { - prefix = byte(0x02) + var resLen = 1 + coordLen + if !compressed { + resLen += coordLen } + var res = make([]byte, resLen) + var prefix byte - return append([]byte{prefix}, paddedX...) + xBytes := p.X.Bytes() + copy(res[1+coordLen-len(xBytes):], xBytes) + if compressed { + if p.Y.Bit(0) == 0 { + prefix = 0x02 + } else { + prefix = 0x03 + } + } else { + prefix = 0x04 + yBytes := p.Y.Bytes() + copy(res[1+coordLen+coordLen-len(yBytes):], yBytes) + + } + res[0] = prefix + + return res +} + +// Bytes returns byte array representation of the public key in compressed +// form (33 bytes with 0x02 or 0x03 prefix, except infinity which is always 0). +func (p *PublicKey) Bytes() []byte { + return p.getBytes(true) +} + +// UncompressedBytes returns byte array representation of the public key in +// uncompressed form (65 bytes with 0x04 prefix, except infinity which is +// always 0). +func (p *PublicKey) UncompressedBytes() []byte { + return p.getBytes(false) } // NewPublicKeyFromASN1 returns a NEO PublicKey from the ASN.1 serialized key. diff --git a/pkg/crypto/keys/publickey_test.go b/pkg/crypto/keys/publickey_test.go index a2aa2060b..848808919 100644 --- a/pkg/crypto/keys/publickey_test.go +++ b/pkg/crypto/keys/publickey_test.go @@ -112,10 +112,14 @@ func TestPubkeyToAddress(t *testing.T) { func TestDecodeBytes(t *testing.T) { pubKey := getPubKey(t) - decodedPubKey := &PublicKey{} - err := decodedPubKey.DecodeBytes(pubKey.Bytes()) - require.NoError(t, err) - require.Equal(t, pubKey, decodedPubKey) + var testBytesFunction = func(t *testing.T, bytesFunction func() []byte) { + decodedPubKey := &PublicKey{} + err := decodedPubKey.DecodeBytes(bytesFunction()) + require.NoError(t, err) + require.Equal(t, pubKey, decodedPubKey) + } + t.Run("compressed", func(t *testing.T) { testBytesFunction(t, pubKey.Bytes) }) + t.Run("uncompressed", func(t *testing.T) { testBytesFunction(t, pubKey.UncompressedBytes) }) } func TestSort(t *testing.T) { diff --git a/pkg/io/binaryReader.go b/pkg/io/binaryReader.go index fd23355a2..b8c935c80 100644 --- a/pkg/io/binaryReader.go +++ b/pkg/io/binaryReader.go @@ -8,9 +8,9 @@ import ( "reflect" ) -// maxArraySize is a maximums size of an array which can be decoded. +// MaxArraySize is the maximum size of an array which can be decoded. // It is taken from https://github.com/neo-project/neo/blob/master/neo/IO/Helper.cs#L130 -const maxArraySize = 0x1000000 +const MaxArraySize = 0x1000000 // BinReader is a convenient wrapper around a io.Reader and err object. // Used to simplify error handling when reading into a struct with many fields. @@ -110,7 +110,7 @@ func (r *BinReader) ReadArray(t interface{}, maxSize ...int) { elemType := sliceType.Elem() isPtr := elemType.Kind() == reflect.Ptr - ms := maxArraySize + ms := MaxArraySize if len(maxSize) != 0 { ms = maxSize[0] } @@ -168,8 +168,16 @@ func (r *BinReader) ReadVarUint() uint64 { // ReadVarBytes reads the next set of bytes from the underlying reader. // ReadVarUInt() is used to determine how large that slice is -func (r *BinReader) ReadVarBytes() []byte { +func (r *BinReader) ReadVarBytes(maxSize ...int) []byte { n := r.ReadVarUint() + ms := MaxArraySize + if len(maxSize) != 0 { + ms = maxSize[0] + } + if n > uint64(ms) { + r.Err = fmt.Errorf("byte-slice is too big (%d)", n) + return nil + } b := make([]byte, n) r.ReadBytes(b) return b diff --git a/pkg/io/binaryrw_test.go b/pkg/io/binaryrw_test.go index d5e1cf8c6..fd998d503 100644 --- a/pkg/io/binaryrw_test.go +++ b/pkg/io/binaryrw_test.go @@ -143,6 +143,35 @@ func TestBufBinWriter_Len(t *testing.T) { require.Equal(t, 1, bw.Len()) } +func TestBinReader_ReadVarBytes(t *testing.T) { + buf := make([]byte, 11) + for i := range buf { + buf[i] = byte(i) + } + w := NewBufBinWriter() + w.WriteVarBytes(buf) + require.NoError(t, w.Err) + data := w.Bytes() + + t.Run("NoArguments", func(t *testing.T) { + r := NewBinReaderFromBuf(data) + actual := r.ReadVarBytes() + require.NoError(t, r.Err) + require.Equal(t, buf, actual) + }) + t.Run("Good", func(t *testing.T) { + r := NewBinReaderFromBuf(data) + actual := r.ReadVarBytes(11) + require.NoError(t, r.Err) + require.Equal(t, buf, actual) + }) + t.Run("Bad", func(t *testing.T) { + r := NewBinReaderFromBuf(data) + r.ReadVarBytes(10) + require.Error(t, r.Err) + }) +} + func TestWriterErrHandling(t *testing.T) { var badio = &badRW{} bw := NewBinWriterFromIO(badio) diff --git a/pkg/rpc/client/wsclient_test.go b/pkg/rpc/client/wsclient_test.go index 50fed7ccf..09e331d8e 100644 --- a/pkg/rpc/client/wsclient_test.go +++ b/pkg/rpc/client/wsclient_test.go @@ -183,8 +183,8 @@ func TestWSFilteredSubscriptions(t *testing.T) { require.NoError(t, err) }, func(t *testing.T, p *request.Params) { - param, ok := p.Value(1) - require.Equal(t, true, ok) + param := p.Value(1) + require.NotNil(t, param) require.Equal(t, request.BlockFilterT, param.Type) filt, ok := param.Value.(request.BlockFilter) require.Equal(t, true, ok) @@ -198,8 +198,8 @@ func TestWSFilteredSubscriptions(t *testing.T) { require.NoError(t, err) }, func(t *testing.T, p *request.Params) { - param, ok := p.Value(1) - require.Equal(t, true, ok) + param := p.Value(1) + require.NotNil(t, param) require.Equal(t, request.TxFilterT, param.Type) filt, ok := param.Value.(request.TxFilter) require.Equal(t, true, ok) @@ -214,8 +214,8 @@ func TestWSFilteredSubscriptions(t *testing.T) { require.NoError(t, err) }, func(t *testing.T, p *request.Params) { - param, ok := p.Value(1) - require.Equal(t, true, ok) + param := p.Value(1) + require.NotNil(t, param) require.Equal(t, request.TxFilterT, param.Type) filt, ok := param.Value.(request.TxFilter) require.Equal(t, true, ok) @@ -231,8 +231,8 @@ func TestWSFilteredSubscriptions(t *testing.T) { require.NoError(t, err) }, func(t *testing.T, p *request.Params) { - param, ok := p.Value(1) - require.Equal(t, true, ok) + param := p.Value(1) + require.NotNil(t, param) require.Equal(t, request.TxFilterT, param.Type) filt, ok := param.Value.(request.TxFilter) require.Equal(t, true, ok) @@ -247,8 +247,8 @@ func TestWSFilteredSubscriptions(t *testing.T) { require.NoError(t, err) }, func(t *testing.T, p *request.Params) { - param, ok := p.Value(1) - require.Equal(t, true, ok) + param := p.Value(1) + require.NotNil(t, param) require.Equal(t, request.NotificationFilterT, param.Type) filt, ok := param.Value.(request.NotificationFilter) require.Equal(t, true, ok) @@ -262,8 +262,8 @@ func TestWSFilteredSubscriptions(t *testing.T) { require.NoError(t, err) }, func(t *testing.T, p *request.Params) { - param, ok := p.Value(1) - require.Equal(t, true, ok) + param := p.Value(1) + require.NotNil(t, param) require.Equal(t, request.ExecutionFilterT, param.Type) filt, ok := param.Value.(request.ExecutionFilter) require.Equal(t, true, ok) diff --git a/pkg/rpc/request/param.go b/pkg/rpc/request/param.go index 87af906fc..4918645bc 100644 --- a/pkg/rpc/request/param.go +++ b/pkg/rpc/request/param.go @@ -69,12 +69,17 @@ const ( Cosigner ) +var errMissingParameter = errors.New("parameter is missing") + func (p Param) String() string { return fmt.Sprintf("%v", p.Value) } // GetString returns string value of the parameter. -func (p Param) GetString() (string, error) { +func (p *Param) GetString() (string, error) { + if p == nil { + return "", errMissingParameter + } str, ok := p.Value.(string) if !ok { return "", errors.New("not a string") @@ -82,8 +87,26 @@ func (p Param) GetString() (string, error) { return str, nil } +// GetBoolean returns boolean value of the parameter. +func (p *Param) GetBoolean() bool { + if p == nil { + return false + } + switch p.Type { + case NumberT: + return p.Value != 0 + case StringT: + return p.Value != "" + default: + return true + } +} + // GetInt returns int value of te parameter. -func (p Param) GetInt() (int, error) { +func (p *Param) GetInt() (int, error) { + if p == nil { + return 0, errMissingParameter + } i, ok := p.Value.(int) if ok { return i, nil @@ -94,7 +117,10 @@ func (p Param) GetInt() (int, error) { } // GetArray returns a slice of Params stored in the parameter. -func (p Param) GetArray() ([]Param, error) { +func (p *Param) GetArray() ([]Param, error) { + if p == nil { + return nil, errMissingParameter + } a, ok := p.Value.([]Param) if !ok { return nil, errors.New("not an array") @@ -103,7 +129,7 @@ func (p Param) GetArray() ([]Param, error) { } // GetUint256 returns Uint256 value of the parameter. -func (p Param) GetUint256() (util.Uint256, error) { +func (p *Param) GetUint256() (util.Uint256, error) { s, err := p.GetString() if err != nil { return util.Uint256{}, err @@ -113,7 +139,7 @@ func (p Param) GetUint256() (util.Uint256, error) { } // GetUint160FromHex returns Uint160 value of the parameter encoded in hex. -func (p Param) GetUint160FromHex() (util.Uint160, error) { +func (p *Param) GetUint160FromHex() (util.Uint160, error) { s, err := p.GetString() if err != nil { return util.Uint160{}, err @@ -127,7 +153,7 @@ func (p Param) GetUint160FromHex() (util.Uint160, error) { // GetUint160FromAddress returns Uint160 value of the parameter that was // supplied as an address. -func (p Param) GetUint160FromAddress() (util.Uint160, error) { +func (p *Param) GetUint160FromAddress() (util.Uint160, error) { s, err := p.GetString() if err != nil { return util.Uint160{}, err @@ -137,7 +163,10 @@ func (p Param) GetUint160FromAddress() (util.Uint160, error) { } // GetFuncParam returns current parameter as a function call parameter. -func (p Param) GetFuncParam() (FuncParam, error) { +func (p *Param) GetFuncParam() (FuncParam, error) { + if p == nil { + return FuncParam{}, errMissingParameter + } fp, ok := p.Value.(FuncParam) if !ok { return FuncParam{}, errors.New("not a function parameter") @@ -147,7 +176,7 @@ func (p Param) GetFuncParam() (FuncParam, error) { // GetBytesHex returns []byte value of the parameter if // it is a hex-encoded string. -func (p Param) GetBytesHex() ([]byte, error) { +func (p *Param) GetBytesHex() ([]byte, error) { s, err := p.GetString() if err != nil { return nil, err @@ -214,6 +243,11 @@ func (p *Param) UnmarshalJSON(data []byte) error { {ArrayT, &[]Param{}}, } + if bytes.Equal(data, []byte("null")) { + p.Type = defaultT + return nil + } + for _, cur := range attempts { r := bytes.NewReader(data) jd := json.NewDecoder(r) diff --git a/pkg/rpc/request/param_test.go b/pkg/rpc/request/param_test.go index 92980516a..70a40b41f 100644 --- a/pkg/rpc/request/param_test.go +++ b/pkg/rpc/request/param_test.go @@ -14,7 +14,7 @@ import ( ) func TestParam_UnmarshalJSON(t *testing.T) { - msg := `["str1", 123, ["str2", 3], [{"type": "String", "value": "jajaja"}], + msg := `["str1", 123, null, ["str2", 3], [{"type": "String", "value": "jajaja"}], {"primary": 1}, {"sender": "f84d6a337fbc3d3a201d41da99e86b479e7a2554"}, {"cosigner": "f84d6a337fbc3d3a201d41da99e86b479e7a2554"}, @@ -36,6 +36,9 @@ func TestParam_UnmarshalJSON(t *testing.T) { Type: NumberT, Value: 123, }, + { + Type: defaultT, + }, { Type: ArrayT, Value: []Param{ diff --git a/pkg/rpc/request/params.go b/pkg/rpc/request/params.go index 8b1945cb1..dd2ac35b9 100644 --- a/pkg/rpc/request/params.go +++ b/pkg/rpc/request/params.go @@ -7,20 +7,19 @@ type ( // Value returns the param struct for the given // index if it exists. -func (p Params) Value(index int) (*Param, bool) { +func (p Params) Value(index int) *Param { if len(p) > index { - return &p[index], true + return &p[index] } - return nil, false + return nil } // ValueWithType returns the param struct at the given index if it // exists and matches the given type. -func (p Params) ValueWithType(index int, valType paramType) (*Param, bool) { - if val, ok := p.Value(index); ok && val.Type == valType { - return val, true +func (p Params) ValueWithType(index int, valType paramType) *Param { + if val := p.Value(index); val != nil && val.Type == valType { + return val } - - return nil, false + return nil } diff --git a/pkg/rpc/server/server.go b/pkg/rpc/server/server.go index d9299574d..36c5e367b 100644 --- a/pkg/rpc/server/server.go +++ b/pkg/rpc/server/server.go @@ -386,6 +386,10 @@ func (s *Server) getConnectionCount(_ request.Params) (interface{}, *response.Er func (s *Server) blockHashFromParam(param *request.Param) (util.Uint256, *response.Error) { var hash util.Uint256 + if param == nil { + return hash, response.ErrInvalidParams + } + switch param.Type { case request.StringT: var err error @@ -406,11 +410,7 @@ func (s *Server) blockHashFromParam(param *request.Param) (util.Uint256, *respon } func (s *Server) getBlock(reqParams request.Params) (interface{}, *response.Error) { - param, ok := reqParams.Value(0) - if !ok { - return nil, response.ErrInvalidParams - } - + param := reqParams.Value(0) hash, respErr := s.blockHashFromParam(param) if respErr != nil { return nil, respErr @@ -421,7 +421,7 @@ func (s *Server) getBlock(reqParams request.Params) (interface{}, *response.Erro return nil, response.NewInternalServerError(fmt.Sprintf("Problem locating block with hash: %s", hash), err) } - if len(reqParams) == 2 && reqParams[1].Value == 1 { + if reqParams.Value(1).GetBoolean() { return result.NewBlock(block, s.chain), nil } writer := io.NewBufBinWriter() @@ -430,8 +430,8 @@ func (s *Server) getBlock(reqParams request.Params) (interface{}, *response.Erro } func (s *Server) getBlockHash(reqParams request.Params) (interface{}, *response.Error) { - param, ok := reqParams.ValueWithType(0, request.NumberT) - if !ok { + param := reqParams.ValueWithType(0, request.NumberT) + if param == nil { return nil, response.ErrInvalidParams } num, err := s.blockHeightFromParam(param) @@ -472,8 +472,8 @@ func (s *Server) getRawMempool(_ request.Params) (interface{}, *response.Error) } func (s *Server) validateAddress(reqParams request.Params) (interface{}, *response.Error) { - param, ok := reqParams.Value(0) - if !ok { + param := reqParams.Value(0) + if param == nil { return nil, response.ErrInvalidParams } return validateAddress(param.Value), nil @@ -481,12 +481,7 @@ func (s *Server) validateAddress(reqParams request.Params) (interface{}, *respon // getApplicationLog returns the contract log based on the specified txid. func (s *Server) getApplicationLog(reqParams request.Params) (interface{}, *response.Error) { - param, ok := reqParams.Value(0) - if !ok { - return nil, response.ErrInvalidParams - } - - txHash, err := param.GetUint256() + txHash, err := reqParams.Value(0).GetUint256() if err != nil { return nil, response.ErrInvalidParams } @@ -500,11 +495,7 @@ func (s *Server) getApplicationLog(reqParams request.Params) (interface{}, *resp } func (s *Server) getNEP5Balances(ps request.Params) (interface{}, *response.Error) { - p, ok := ps.ValueWithType(0, request.StringT) - if !ok { - return nil, response.ErrInvalidParams - } - u, err := p.GetUint160FromHex() + u, err := ps.ValueWithType(0, request.StringT).GetUint160FromHex() if err != nil { return nil, response.ErrInvalidParams } @@ -533,11 +524,7 @@ func (s *Server) getNEP5Balances(ps request.Params) (interface{}, *response.Erro } func (s *Server) getNEP5Transfers(ps request.Params) (interface{}, *response.Error) { - p, ok := ps.ValueWithType(0, request.StringT) - if !ok { - return nil, response.ErrInvalidParams - } - u, err := p.GetUint160FromAddress() + u, err := ps.ValueWithType(0, request.StringT).GetUint160FromAddress() if err != nil { return nil, response.ErrInvalidParams } @@ -636,6 +623,9 @@ func (s *Server) getDecimals(h util.Uint160, cache map[util.Uint160]int64) (int6 func (s *Server) contractIDFromParam(param *request.Param) (int32, *response.Error) { var result int32 + if param == nil { + return 0, response.ErrInvalidParams + } switch param.Type { case request.StringT: var err error @@ -661,11 +651,7 @@ func (s *Server) contractIDFromParam(param *request.Param) (int32, *response.Err } func (s *Server) getStorage(ps request.Params) (interface{}, *response.Error) { - param, ok := ps.Value(0) - if !ok { - return nil, response.ErrInvalidParams - } - id, rErr := s.contractIDFromParam(param) + id, rErr := s.contractIDFromParam(ps.Value(0)) if rErr == response.ErrUnknown { return nil, nil } @@ -673,12 +659,7 @@ func (s *Server) getStorage(ps request.Params) (interface{}, *response.Error) { return nil, rErr } - param, ok = ps.Value(1) - if !ok { - return nil, response.ErrInvalidParams - } - - key, err := param.GetBytesHex() + key, err := ps.Value(1).GetBytesHex() if err != nil { return nil, response.ErrInvalidParams } @@ -695,30 +676,17 @@ func (s *Server) getrawtransaction(reqParams request.Params) (interface{}, *resp var resultsErr *response.Error var results interface{} - if param0, ok := reqParams.Value(0); !ok { - return nil, response.ErrInvalidParams - } else if txHash, err := param0.GetUint256(); err != nil { + if txHash, err := reqParams.Value(0).GetUint256(); err != nil { resultsErr = response.ErrInvalidParams } else if tx, height, err := s.chain.GetTransaction(txHash); err != nil { err = errors.Wrapf(err, "Invalid transaction hash: %s", txHash) return nil, response.NewRPCError("Unknown transaction", err.Error(), err) - } else if len(reqParams) >= 2 { + } else if reqParams.Value(1).GetBoolean() { _header := s.chain.GetHeaderHash(int(height)) header, err := s.chain.GetHeader(_header) if err != nil { resultsErr = response.NewInvalidParamsError(err.Error(), err) - } - - param1, _ := reqParams.Value(1) - switch v := param1.Value.(type) { - - case int, float64, bool, string: - if v == 0 || v == "0" || v == 0.0 || v == false || v == "false" { - results = hex.EncodeToString(tx.Bytes()) - } else { - results = result.NewTransactionOutputRaw(tx, header, s.chain) - } - default: + } else { results = result.NewTransactionOutputRaw(tx, header, s.chain) } } else { @@ -729,12 +697,7 @@ func (s *Server) getrawtransaction(reqParams request.Params) (interface{}, *resp } func (s *Server) getTransactionHeight(ps request.Params) (interface{}, *response.Error) { - p, ok := ps.Value(0) - if !ok { - return nil, response.ErrInvalidParams - } - - h, err := p.GetUint256() + h, err := ps.Value(0).GetUint256() if err != nil { return nil, response.ErrInvalidParams } @@ -749,27 +712,21 @@ func (s *Server) getTransactionHeight(ps request.Params) (interface{}, *response // getContractState returns contract state (contract information, according to the contract script hash). func (s *Server) getContractState(reqParams request.Params) (interface{}, *response.Error) { - var results interface{} - - param, ok := reqParams.ValueWithType(0, request.StringT) - if !ok { + scriptHash, err := reqParams.ValueWithType(0, request.StringT).GetUint160FromHex() + if err != nil { return nil, response.ErrInvalidParams - } else if scriptHash, err := param.GetUint160FromHex(); err != nil { - return nil, response.ErrInvalidParams - } else { - cs := s.chain.GetContractState(scriptHash) - if cs == nil { - return nil, response.NewRPCError("Unknown contract", "", nil) - } - results = cs } - return results, nil + cs := s.chain.GetContractState(scriptHash) + if cs == nil { + return nil, response.NewRPCError("Unknown contract", "", nil) + } + return cs, nil } // getBlockSysFee returns the system fees of the block, based on the specified index. func (s *Server) getBlockSysFee(reqParams request.Params) (interface{}, *response.Error) { - param, ok := reqParams.ValueWithType(0, request.NumberT) - if !ok { + param := reqParams.ValueWithType(0, request.NumberT) + if param == nil { return 0, response.ErrInvalidParams } @@ -794,27 +751,13 @@ func (s *Server) getBlockSysFee(reqParams request.Params) (interface{}, *respons // getBlockHeader returns the corresponding block header information according to the specified script hash. func (s *Server) getBlockHeader(reqParams request.Params) (interface{}, *response.Error) { - var verbose bool - - param, ok := reqParams.Value(0) - if !ok { - return nil, response.ErrInvalidParams - } - + param := reqParams.Value(0) hash, respErr := s.blockHashFromParam(param) if respErr != nil { return nil, respErr } - param, ok = reqParams.ValueWithType(1, request.NumberT) - if ok { - v, err := param.GetInt() - if err != nil { - return nil, response.ErrInvalidParams - } - verbose = v != 0 - } - + verbose := reqParams.Value(1).GetBoolean() h, err := s.chain.GetHeader(hash) if err != nil { return nil, response.NewRPCError("unknown block", "", nil) @@ -834,11 +777,7 @@ func (s *Server) getBlockHeader(reqParams request.Params) (interface{}, *respons // getUnclaimedGas returns unclaimed GAS amount of the specified address. func (s *Server) getUnclaimedGas(ps request.Params) (interface{}, *response.Error) { - p, ok := ps.ValueWithType(0, request.StringT) - if !ok { - return nil, response.ErrInvalidParams - } - u, err := p.GetUint160FromAddress() + u, err := ps.ValueWithType(0, request.StringT).GetUint160FromAddress() if err != nil { return nil, response.ErrInvalidParams } @@ -876,11 +815,7 @@ func (s *Server) getValidators(_ request.Params) (interface{}, *response.Error) // invokeFunction implements the `invokeFunction` RPC call. func (s *Server) invokeFunction(reqParams request.Params) (interface{}, *response.Error) { - scriptHashHex, ok := reqParams.ValueWithType(0, request.StringT) - if !ok { - return nil, response.ErrInvalidParams - } - scriptHash, err := scriptHashHex.GetUint160FromHex() + scriptHash, err := reqParams.ValueWithType(0, request.StringT).GetUint160FromHex() if err != nil { return nil, response.ErrInvalidParams } @@ -941,11 +876,7 @@ func (s *Server) runScriptInVM(script []byte, tx *transaction.Transaction) *resu // submitBlock broadcasts a raw block over the NEO network. func (s *Server) submitBlock(reqParams request.Params) (interface{}, *response.Error) { - param, ok := reqParams.ValueWithType(0, request.StringT) - if !ok { - return nil, response.ErrInvalidParams - } - blockBytes, err := param.GetBytesHex() + blockBytes, err := reqParams.ValueWithType(0, request.StringT).GetBytesHex() if err != nil { return nil, response.ErrInvalidParams } @@ -1004,11 +935,7 @@ func (s *Server) sendrawtransaction(reqParams request.Params) (interface{}, *res // subscribe handles subscription requests from websocket clients. func (s *Server) subscribe(reqParams request.Params, sub *subscriber) (interface{}, *response.Error) { - p, ok := reqParams.Value(0) - if !ok { - return nil, response.ErrInvalidParams - } - streamName, err := p.GetString() + streamName, err := reqParams.Value(0).GetString() if err != nil { return nil, response.ErrInvalidParams } @@ -1018,8 +945,7 @@ func (s *Server) subscribe(reqParams request.Params, sub *subscriber) (interface } // Optional filter. var filter interface{} - p, ok = reqParams.Value(1) - if ok { + if p := reqParams.Value(1); p != nil { switch event { case response.BlockEventID: if p.Type != request.BlockFilterT { @@ -1093,11 +1019,7 @@ func (s *Server) subscribeToChannel(event response.EventID) { // unsubscribe handles unsubscription requests from websocket clients. func (s *Server) unsubscribe(reqParams request.Params, sub *subscriber) (interface{}, *response.Error) { - p, ok := reqParams.Value(0) - if !ok { - return nil, response.ErrInvalidParams - } - id, err := p.GetInt() + id, err := reqParams.Value(0).GetInt() if err != nil || id < 0 { return nil, response.ErrInvalidParams }