diff --git a/config/protocol.unit_testnet.yml b/config/protocol.unit_testnet.yml index f00fa93ff..c21e1c3f0 100644 --- a/config/protocol.unit_testnet.yml +++ b/config/protocol.unit_testnet.yml @@ -49,9 +49,10 @@ ApplicationConfiguration: AttemptConnPeers: 5 MinPeers: 1 RPC: + Address: 127.0.0.1 Enabled: true EnableCORSWorkaround: false - Port: 20332 + Port: 0 # let the system choose port dynamically Prometheus: Enabled: false #since it's not useful for unit tests. Port: 2112 diff --git a/pkg/consensus/consensus.go b/pkg/consensus/consensus.go index 1474b64ee..e12b9d37e 100644 --- a/pkg/consensus/consensus.go +++ b/pkg/consensus/consensus.go @@ -41,9 +41,6 @@ type Service interface { OnTransaction(tx *transaction.Transaction) // GetPayload returns Payload with specified hash if it is present in the local cache. GetPayload(h util.Uint256) *Payload - // OnNewBlock notifies consensus service that there is a new block in - // the chain (without explicitly passing it to the service). - OnNewBlock() } type service struct { @@ -61,7 +58,7 @@ type service struct { transactions chan *transaction.Transaction // blockEvents is used to pass a new block event to the consensus // process. - blockEvents chan struct{} + blockEvents chan *coreb.Block lastProposal []util.Uint256 wallet *wallet.Wallet } @@ -73,9 +70,6 @@ type Config struct { // Broadcast is a callback which is called to notify server // about new consensus payload to sent. Broadcast func(p *Payload) - // RelayBlock is a callback that is called to notify server - // about the new block that needs to be broadcasted. - RelayBlock func(b *coreb.Block) // Chain is a core.Blockchainer instance. Chain core.Blockchainer // RequestTx is a callback to which will be called @@ -106,7 +100,7 @@ func NewService(cfg Config) (Service, error) { messages: make(chan Payload, 100), transactions: make(chan *transaction.Transaction, 100), - blockEvents: make(chan struct{}, 1), + blockEvents: make(chan *coreb.Block, 1), } if cfg.Wallet == nil { @@ -163,7 +157,7 @@ var ( func (s *service) Start() { s.dbft.Start() - + s.Chain.SubscribeForBlocks(s.blockEvents) go s.eventLoop() } @@ -203,11 +197,14 @@ func (s *service) eventLoop() { s.dbft.OnReceive(&msg) case tx := <-s.transactions: s.dbft.OnTransaction(tx) - case <-s.blockEvents: - s.log.Debug("new block in the chain", - zap.Uint32("dbft index", s.dbft.BlockIndex), - zap.Uint32("chain index", s.Chain.BlockHeight())) - s.dbft.InitializeConsensus(0) + case b := <-s.blockEvents: + // We also receive our own blocks here, so check for index. + if b.Index >= s.dbft.BlockIndex { + s.log.Debug("new block in the chain", + zap.Uint32("dbft index", s.dbft.BlockIndex), + zap.Uint32("chain index", s.Chain.BlockHeight())) + s.dbft.InitializeConsensus(0) + } } } } @@ -287,20 +284,6 @@ func (s *service) OnTransaction(tx *transaction.Transaction) { } } -// OnNewBlock notifies consensus process that there is a new block in the chain -// and dbft should probably be reinitialized. -func (s *service) OnNewBlock() { - if s.dbft != nil { - // If there is something in the queue already, the second - // consecutive event doesn't make much sense (reinitializing - // dbft twice doesn't improve it in any way). - select { - case s.blockEvents <- struct{}{}: - default: - } - } -} - // GetPayload returns payload stored in cache. func (s *service) GetPayload(h util.Uint256) *Payload { p := s.cache.Get(h) @@ -366,8 +349,6 @@ func (s *service) processBlock(b block.Block) { if _, errget := s.Chain.GetBlock(bb.Hash()); errget != nil { s.log.Warn("error on add block", zap.Error(err)) } - } else { - s.Config.RelayBlock(bb) } } diff --git a/pkg/core/blockchain.go b/pkg/core/blockchain.go index c9e82bf3e..243f6bbbf 100644 --- a/pkg/core/blockchain.go +++ b/pkg/core/blockchain.go @@ -63,7 +63,9 @@ var ( persistInterval = 1 * time.Second ) -// Blockchain represents the blockchain. +// Blockchain represents the blockchain. It maintans internal state representing +// the state of the ledger that can be accessed in various ways and changed by +// adding new blocks or headers. type Blockchain struct { config config.ProtocolConfiguration @@ -122,12 +124,27 @@ type Blockchain struct { log *zap.Logger lastBatch *storage.MemBatch + + // Notification subsystem. + events chan bcEvent + subCh chan interface{} + unsubCh chan interface{} +} + +// bcEvent is an internal event generated by the Blockchain and then +// broadcasted to other parties. It joins the new block and associated +// invocation logs, all the other events visible from outside can be produced +// from this combination. +type bcEvent struct { + block *block.Block + appExecResults []*state.AppExecResult } type headersOpFunc func(headerList *HeaderHashList) // NewBlockchain returns a new blockchain object the will use the -// given Store as its underlying storage. +// given Store as its underlying storage. For it to work correctly you need +// to spawn a goroutine for its Run method after this initialization. func NewBlockchain(s storage.Store, cfg config.ProtocolConfiguration, log *zap.Logger) (*Blockchain, error) { if log == nil { return nil, errors.New("empty logger") @@ -163,6 +180,9 @@ func NewBlockchain(s storage.Store, cfg config.ProtocolConfiguration, log *zap.L memPool: mempool.NewMemPool(cfg.MemPoolSize), keyCache: make(map[util.Uint160]map[string]*keys.PublicKey), log: log, + events: make(chan bcEvent), + subCh: make(chan interface{}), + unsubCh: make(chan interface{}), generationAmount: genAmount, decrementInterval: decrementInterval, @@ -264,7 +284,8 @@ func (bc *Blockchain) init() error { return nil } -// Run runs chain loop. +// Run runs chain loop, it needs to be run as goroutine and executing it is +// critical for correct Blockchain operation. func (bc *Blockchain) Run() { persistTimer := time.NewTimer(persistInterval) defer func() { @@ -277,6 +298,7 @@ func (bc *Blockchain) Run() { } close(bc.runToExitCh) }() + go bc.notificationDispatcher() for { select { case <-bc.stopCh: @@ -296,6 +318,82 @@ func (bc *Blockchain) Run() { } } +// notificationDispatcher manages subscription to events and broadcasts new events. +func (bc *Blockchain) notificationDispatcher() { + var ( + // These are just sets of subscribers, though modelled as maps + // for ease of management (not a lot of subscriptions is really + // expected, but maps are convenient for adding/deleting elements). + blockFeed = make(map[chan<- *block.Block]bool) + txFeed = make(map[chan<- *transaction.Transaction]bool) + notificationFeed = make(map[chan<- *state.NotificationEvent]bool) + executionFeed = make(map[chan<- *state.AppExecResult]bool) + ) + for { + select { + case <-bc.stopCh: + return + case sub := <-bc.subCh: + switch ch := sub.(type) { + case chan<- *block.Block: + blockFeed[ch] = true + case chan<- *transaction.Transaction: + txFeed[ch] = true + case chan<- *state.NotificationEvent: + notificationFeed[ch] = true + case chan<- *state.AppExecResult: + executionFeed[ch] = true + default: + panic(fmt.Sprintf("bad subscription: %T", sub)) + } + case unsub := <-bc.unsubCh: + switch ch := unsub.(type) { + case chan<- *block.Block: + delete(blockFeed, ch) + case chan<- *transaction.Transaction: + delete(txFeed, ch) + case chan<- *state.NotificationEvent: + delete(notificationFeed, ch) + case chan<- *state.AppExecResult: + delete(executionFeed, ch) + default: + panic(fmt.Sprintf("bad unsubscription: %T", unsub)) + } + case event := <-bc.events: + // We don't want to waste time looping through transactions when there are no + // subscribers. + if len(txFeed) != 0 || len(notificationFeed) != 0 || len(executionFeed) != 0 { + var aerIdx int + for _, tx := range event.block.Transactions { + if tx.Type == transaction.InvocationType { + aer := event.appExecResults[aerIdx] + if !aer.TxHash.Equals(tx.Hash()) { + panic("inconsistent application execution results") + } + aerIdx++ + for ch := range executionFeed { + ch <- aer + } + if aer.VMState == "HALT" { + for i := range aer.Events { + for ch := range notificationFeed { + ch <- &aer.Events[i] + } + } + } + } + for ch := range txFeed { + ch <- tx + } + } + } + for ch := range blockFeed { + ch <- event.block + } + } + } +} + // Close stops Blockchain's internal loop, syncs changes to persistent storage // and closes it. The Blockchain is no longer functional after the call to Close. func (bc *Blockchain) Close() { @@ -459,6 +557,7 @@ func (bc *Blockchain) getSystemFeeAmount(h util.Uint256) uint32 { // and all tests are in place, we can make a more optimized and cleaner implementation. func (bc *Blockchain) storeBlock(block *block.Block) error { cache := dao.NewCached(bc.dao) + appExecResults := make([]*state.AppExecResult, 0, len(block.Transactions)) fee := bc.getSystemFeeAmount(block.PrevHash) for _, tx := range block.Transactions { fee += uint32(bc.SystemFee(tx).IntegralValue()) @@ -712,27 +811,36 @@ func (bc *Blockchain) storeBlock(block *block.Block) error { Stack: v.Estack().ToContractParameters(), Events: systemInterop.notifications, } + appExecResults = append(appExecResults, aer) err = cache.PutAppExecResult(aer) if err != nil { return errors.Wrap(err, "failed to Store notifications") } } } - bc.lock.Lock() - defer bc.lock.Unlock() if bc.config.SaveStorageBatch { bc.lastBatch = cache.DAO.GetBatch() } + bc.lock.Lock() _, err := cache.Persist() if err != nil { + bc.lock.Unlock() return err } bc.topBlock.Store(block) atomic.StoreUint32(&bc.blockHeight, block.Index) - updateBlockHeightMetric(block.Index) bc.memPool.RemoveStale(bc.isTxStillRelevant) + bc.lock.Unlock() + + updateBlockHeightMetric(block.Index) + // Genesis block is stored when Blockchain is not yet running, so there + // is no one to read this event. And it doesn't make much sense as event + // anyway. + if block.Index != 0 { + bc.events <- bcEvent{block, appExecResults} + } return nil } @@ -1179,6 +1287,68 @@ func (bc *Blockchain) GetConfig() config.ProtocolConfiguration { return bc.config } +// SubscribeForBlocks adds given channel to new block event broadcasting, so when +// there is a new block added to the chain you'll receive it via this channel. +// Make sure it's read from regularly as not reading these events might affect +// other Blockchain functions. +func (bc *Blockchain) SubscribeForBlocks(ch chan<- *block.Block) { + bc.subCh <- ch +} + +// SubscribeForTransactions adds given channel to new transaction event +// broadcasting, so when there is a new transaction added to the chain (in a +// block) you'll receive it via this channel. Make sure it's read from regularly +// as not reading these events might affect other Blockchain functions. +func (bc *Blockchain) SubscribeForTransactions(ch chan<- *transaction.Transaction) { + bc.subCh <- ch +} + +// SubscribeForNotifications adds given channel to new notifications event +// broadcasting, so when an in-block transaction execution generates a +// notification you'll receive it via this channel. Only notifications from +// successful transactions are broadcasted, if you're interested in failed +// transactions use SubscribeForExecutions instead. Make sure this channel is +// read from regularly as not reading these events might affect other Blockchain +// functions. +func (bc *Blockchain) SubscribeForNotifications(ch chan<- *state.NotificationEvent) { + bc.subCh <- ch +} + +// SubscribeForExecutions adds given channel to new transaction execution event +// broadcasting, so when an in-block transaction execution happens you'll receive +// the result of it via this channel. Make sure it's read from regularly as not +// reading these events might affect other Blockchain functions. +func (bc *Blockchain) SubscribeForExecutions(ch chan<- *state.AppExecResult) { + bc.subCh <- ch +} + +// UnsubscribeFromBlocks unsubscribes given channel from new block notifications, +// you can close it afterwards. Passing non-subscribed channel is a no-op. +func (bc *Blockchain) UnsubscribeFromBlocks(ch chan<- *block.Block) { + bc.unsubCh <- ch +} + +// UnsubscribeFromTransactions unsubscribes given channel from new transaction +// notifications, you can close it afterwards. Passing non-subscribed channel is +// a no-op. +func (bc *Blockchain) UnsubscribeFromTransactions(ch chan<- *transaction.Transaction) { + bc.unsubCh <- ch +} + +// UnsubscribeFromNotifications unsubscribes given channel from new +// execution-generated notifications, you can close it afterwards. Passing +// non-subscribed channel is a no-op. +func (bc *Blockchain) UnsubscribeFromNotifications(ch chan<- *state.NotificationEvent) { + bc.unsubCh <- ch +} + +// UnsubscribeFromExecutions unsubscribes given channel from new execution +// notifications, you can close it afterwards. Passing non-subscribed channel is +// a no-op. +func (bc *Blockchain) UnsubscribeFromExecutions(ch chan<- *state.AppExecResult) { + bc.unsubCh <- ch +} + // CalculateClaimable calculates the amount of GAS which can be claimed for a transaction with value. // First return value is GAS generated between startHeight and endHeight. // Second return value is GAS returned from accumulated SystemFees between startHeight and endHeight. diff --git a/pkg/core/blockchain_test.go b/pkg/core/blockchain_test.go index 0313aa63e..d344671b5 100644 --- a/pkg/core/blockchain_test.go +++ b/pkg/core/blockchain_test.go @@ -2,13 +2,19 @@ package core import ( "testing" + "time" "github.com/nspcc-dev/neo-go/pkg/core/block" + "github.com/nspcc-dev/neo-go/pkg/core/state" "github.com/nspcc-dev/neo-go/pkg/core/storage" "github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/crypto/hash" + "github.com/nspcc-dev/neo-go/pkg/encoding/address" "github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/util" + "github.com/nspcc-dev/neo-go/pkg/vm/emit" + "github.com/nspcc-dev/neo-go/pkg/vm/opcode" + "github.com/nspcc-dev/neo-go/pkg/wallet" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -229,3 +235,108 @@ func TestClose(t *testing.T) { // This should never be executed. assert.Nil(t, t) } + +func TestSubscriptions(t *testing.T) { + // We use buffering here as a substitute for reader goroutines, events + // get queued up and we read them one by one here. + const chBufSize = 16 + blockCh := make(chan *block.Block, chBufSize) + txCh := make(chan *transaction.Transaction, chBufSize) + notificationCh := make(chan *state.NotificationEvent, chBufSize) + executionCh := make(chan *state.AppExecResult, chBufSize) + + bc := newTestChain(t) + bc.SubscribeForBlocks(blockCh) + bc.SubscribeForTransactions(txCh) + bc.SubscribeForNotifications(notificationCh) + bc.SubscribeForExecutions(executionCh) + + assert.Empty(t, notificationCh) + assert.Empty(t, executionCh) + assert.Empty(t, blockCh) + assert.Empty(t, txCh) + + blocks, err := bc.genBlocks(1) + require.NoError(t, err) + assert.Eventually(t, func() bool { return len(blockCh) != 0 && len(txCh) != 0 }, time.Second, 10*time.Millisecond) + assert.Empty(t, notificationCh) + assert.Empty(t, executionCh) + + b := <-blockCh + tx := <-txCh + assert.Equal(t, blocks[0], b) + assert.Equal(t, blocks[0].Transactions[0], tx) + assert.Empty(t, blockCh) + assert.Empty(t, txCh) + + acc0, err := wallet.NewAccountFromWIF(privNetKeys[0]) + require.NoError(t, err) + addr0, err := address.StringToUint160(acc0.Address) + require.NoError(t, err) + + script := io.NewBufBinWriter() + emit.Bytes(script.BinWriter, []byte("yay!")) + emit.Syscall(script.BinWriter, "Neo.Runtime.Notify") + require.NoError(t, script.Err) + txGood1 := transaction.NewInvocationTX(script.Bytes(), 0) + txGood1.AddVerificationHash(addr0) + require.NoError(t, acc0.SignTx(txGood1)) + + // Reset() reuses the script buffer and we need to keep scripts. + script = io.NewBufBinWriter() + emit.Bytes(script.BinWriter, []byte("nay!")) + emit.Syscall(script.BinWriter, "Neo.Runtime.Notify") + emit.Opcode(script.BinWriter, opcode.THROW) + require.NoError(t, script.Err) + txBad := transaction.NewInvocationTX(script.Bytes(), 0) + txBad.AddVerificationHash(addr0) + require.NoError(t, acc0.SignTx(txBad)) + + script = io.NewBufBinWriter() + emit.Bytes(script.BinWriter, []byte("yay! yay! yay!")) + emit.Syscall(script.BinWriter, "Neo.Runtime.Notify") + require.NoError(t, script.Err) + txGood2 := transaction.NewInvocationTX(script.Bytes(), 0) + txGood2.AddVerificationHash(addr0) + require.NoError(t, acc0.SignTx(txGood2)) + + txMiner := newMinerTX() + invBlock := newBlock(bc.config, bc.BlockHeight()+1, bc.CurrentHeaderHash(), txMiner, txGood1, txBad, txGood2) + require.NoError(t, bc.AddBlock(invBlock)) + + require.Eventually(t, func() bool { + return len(blockCh) != 0 && len(txCh) != 0 && + len(notificationCh) != 0 && len(executionCh) != 0 + }, time.Second, 10*time.Millisecond) + + b = <-blockCh + require.Equal(t, invBlock, b) + assert.Empty(t, blockCh) + + // Follow in-block transaction order. + for _, txExpected := range invBlock.Transactions { + tx = <-txCh + require.Equal(t, txExpected, tx) + if txExpected.Type == transaction.InvocationType { + exec := <-executionCh + require.Equal(t, tx.Hash(), exec.TxHash) + if exec.VMState == "HALT" { + notif := <-notificationCh + inv := tx.Data.(*transaction.InvocationTX) + require.Equal(t, hash.Hash160(inv.Script), notif.ScriptHash) + } + } + } + assert.Empty(t, txCh) + assert.Empty(t, notificationCh) + assert.Empty(t, executionCh) + + bc.UnsubscribeFromBlocks(blockCh) + bc.UnsubscribeFromTransactions(txCh) + bc.UnsubscribeFromNotifications(notificationCh) + bc.UnsubscribeFromExecutions(executionCh) + + // Ensure that new blocks are processed correctly after unsubscription. + _, err = bc.genBlocks(2 * chBufSize) + require.NoError(t, err) +} diff --git a/pkg/core/blockchainer.go b/pkg/core/blockchainer.go index a633bb523..d3e0309de 100644 --- a/pkg/core/blockchainer.go +++ b/pkg/core/blockchainer.go @@ -46,6 +46,14 @@ type Blockchainer interface { References(t *transaction.Transaction) ([]transaction.InOut, error) mempool.Feer // fee interface PoolTx(*transaction.Transaction) error + SubscribeForBlocks(ch chan<- *block.Block) + SubscribeForExecutions(ch chan<- *state.AppExecResult) + SubscribeForNotifications(ch chan<- *state.NotificationEvent) + SubscribeForTransactions(ch chan<- *transaction.Transaction) VerifyTx(*transaction.Transaction, *block.Block) error GetMemPool() *mempool.Pool + UnsubscribeFromBlocks(ch chan<- *block.Block) + UnsubscribeFromExecutions(ch chan<- *state.AppExecResult) + UnsubscribeFromNotifications(ch chan<- *state.NotificationEvent) + UnsubscribeFromTransactions(ch chan<- *transaction.Transaction) } diff --git a/pkg/core/doc.go b/pkg/core/doc.go new file mode 100644 index 000000000..c459ea132 --- /dev/null +++ b/pkg/core/doc.go @@ -0,0 +1,29 @@ +/* +Package core implements Neo ledger functionality. +It's built around the Blockchain structure that maintains state of the ledger. + +Events + +You can subscribe to Blockchain events using a set of Subscribe and Unsubscribe +methods. These methods accept channels that will be used to send appropriate +events, so you can control buffering. Channels are never closed by Blockchain, +you can close them after unsubscription. + +Unlike RPC-level subscriptions these don't allow event filtering because it +doesn't improve overall efficiency much (when you're using Blockchain you're +in the same process with it and filtering on your side is not that different +from filtering on Blockchain side). + +The same level of ordering guarantees as with RPC subscriptions is provided, +albeit for a set of event channels, so at first transaction execution is +announced via appropriate channels, then followed by notifications generated +during this execution, then followed by transaction announcement and then +followed by block announcement. Transaction announcements are ordered the same +way they're stored in the block. + +Be careful using these subscriptions, this mechanism is not intended to be used +by lots of subscribers and failing to read from event channels can affect +other Blockchain operations. + +*/ +package core diff --git a/pkg/core/helper_test.go b/pkg/core/helper_test.go index f1e8e1657..f6300c492 100644 --- a/pkg/core/helper_test.go +++ b/pkg/core/helper_test.go @@ -59,6 +59,9 @@ func newBlock(cfg config.ProtocolConfiguration, index uint32, prev util.Uint256, witness := transaction.Witness{ VerificationScript: valScript, } + if len(txs) == 0 { + txs = []*transaction.Transaction{newMinerTX()} + } b := &block.Block{ Base: block.Base{ Version: 0, @@ -71,7 +74,10 @@ func newBlock(cfg config.ProtocolConfiguration, index uint32, prev util.Uint256, }, Transactions: txs, } - _ = b.RebuildMerkleRoot() + err := b.RebuildMerkleRoot() + if err != nil { + panic(err) + } invScript := make([]byte, 0) for _, wif := range privNetKeys { diff --git a/pkg/network/helper_test.go b/pkg/network/helper_test.go index 25e1fed7d..a719d012d 100644 --- a/pkg/network/helper_test.go +++ b/pkg/network/helper_test.go @@ -146,10 +146,36 @@ func (chain testChain) PoolTx(*transaction.Transaction) error { panic("TODO") } +func (chain testChain) SubscribeForBlocks(ch chan<- *block.Block) { + panic("TODO") +} +func (chain testChain) SubscribeForExecutions(ch chan<- *state.AppExecResult) { + panic("TODO") +} +func (chain testChain) SubscribeForNotifications(ch chan<- *state.NotificationEvent) { + panic("TODO") +} +func (chain testChain) SubscribeForTransactions(ch chan<- *transaction.Transaction) { + panic("TODO") +} + func (chain testChain) VerifyTx(*transaction.Transaction, *block.Block) error { panic("TODO") } +func (chain testChain) UnsubscribeFromBlocks(ch chan<- *block.Block) { + panic("TODO") +} +func (chain testChain) UnsubscribeFromExecutions(ch chan<- *state.AppExecResult) { + panic("TODO") +} +func (chain testChain) UnsubscribeFromNotifications(ch chan<- *state.NotificationEvent) { + panic("TODO") +} +func (chain testChain) UnsubscribeFromTransactions(ch chan<- *transaction.Transaction) { + panic("TODO") +} + type testDiscovery struct{} func (d testDiscovery) BackFill(addrs ...string) {} diff --git a/pkg/network/server.go b/pkg/network/server.go index efd6d5e3f..1836cdf92 100644 --- a/pkg/network/server.go +++ b/pkg/network/server.go @@ -102,21 +102,17 @@ func NewServer(config ServerConfig, chain core.Blockchainer, log *zap.Logger) (* transactions: make(chan *transaction.Transaction, 64), } s.bQueue = newBlockQueue(maxBlockBatch, chain, log, func(b *block.Block) { - if s.consensusStarted.Load() { - s.consensus.OnNewBlock() - } else { + if !s.consensusStarted.Load() { s.tryStartConsensus() } - s.relayBlock(b) }) srv, err := consensus.NewService(consensus.Config{ - Logger: log, - Broadcast: s.handleNewPayload, - RelayBlock: s.relayBlock, - Chain: chain, - RequestTx: s.requestTx, - Wallet: config.Wallet, + Logger: log, + Broadcast: s.handleNewPayload, + Chain: chain, + RequestTx: s.requestTx, + Wallet: config.Wallet, TimePerBlock: config.TimePerBlock, }) @@ -178,6 +174,7 @@ func (s *Server) Start(errChan chan error) { s.discovery.BackFill(s.Seeds...) go s.broadcastTxLoop() + go s.relayBlocksLoop() go s.bQueue.run() go s.transport.Accept() setServerAndNodeVersions(s.UserAgent, strconv.FormatUint(uint64(s.id), 10)) @@ -790,14 +787,25 @@ func (s *Server) broadcastHPMessage(msg *Message) { s.iteratePeersWithSendMsg(msg, Peer.EnqueueHPPacket, nil) } -// relayBlock tells all the other connected nodes about the given block. -func (s *Server) relayBlock(b *block.Block) { - msg := s.MkMsg(CMDInv, payload.NewInventory(payload.BlockType, []util.Uint256{b.Hash()})) - // Filter out nodes that are more current (avoid spamming the network - // during initial sync). - s.iteratePeersWithSendMsg(msg, Peer.EnqueuePacket, func(p Peer) bool { - return p.Handshaked() && p.LastBlockIndex() < b.Index - }) +// relayBlocksLoop subscribes to new blocks in the ledger and broadcasts them +// to the network. Intended to be run as a separate goroutine. +func (s *Server) relayBlocksLoop() { + ch := make(chan *block.Block, 2) // Some buffering to smooth out possible egressing delays. + s.chain.SubscribeForBlocks(ch) + for { + select { + case <-s.quit: + s.chain.UnsubscribeFromBlocks(ch) + return + case b := <-ch: + msg := s.MkMsg(CMDInv, payload.NewInventory(payload.BlockType, []util.Uint256{b.Hash()})) + // Filter out nodes that are more current (avoid spamming the network + // during initial sync). + s.iteratePeersWithSendMsg(msg, Peer.EnqueuePacket, func(p Peer) bool { + return p.Handshaked() && p.LastBlockIndex() < b.Index + }) + } + } } // verifyAndPoolTX verifies the TX and adds it to the local mempool. diff --git a/pkg/rpc/client/wsclient.go b/pkg/rpc/client/wsclient.go index bdac24816..6774dd421 100644 --- a/pkg/rpc/client/wsclient.go +++ b/pkg/rpc/client/wsclient.go @@ -7,8 +7,11 @@ import ( "time" "github.com/gorilla/websocket" + "github.com/nspcc-dev/neo-go/pkg/core/block" + "github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/rpc/request" "github.com/nspcc-dev/neo-go/pkg/rpc/response" + "github.com/nspcc-dev/neo-go/pkg/rpc/response/result" ) // WSClient is a websocket-enabled RPC client that can be used with appropriate @@ -17,12 +20,28 @@ import ( // that is only provided via websockets (like event subscription mechanism). type WSClient struct { Client + // Notifications is a channel that is used to send events received from + // server. Client's code is supposed to be reading from this channel if + // it wants to use subscription mechanism, failing to do so will cause + // WSClient to block even regular requests. This channel is not buffered. + // In case of protocol error or upon connection closure this channel will + // be closed, so make sure to handle this. + Notifications chan Notification + ws *websocket.Conn done chan struct{} - notifications chan *request.In responses chan *response.Raw requests chan *request.Raw shutdown chan struct{} + subscriptions map[string]bool +} + +// Notification represents server-generated notification for client subscriptions. +// Value can be one of block.Block, result.ApplicationLog, result.NotificationEvent +// or transaction.Transaction based on Type. +type Notification struct { + Type response.EventID + Value interface{} } // requestResponse is a combined type for request and response since we can get @@ -59,12 +78,15 @@ func NewWS(ctx context.Context, endpoint string, opts Options) (*WSClient, error return nil, err } wsc := &WSClient{ - Client: *cl, - ws: ws, - shutdown: make(chan struct{}), - done: make(chan struct{}), - responses: make(chan *response.Raw), - requests: make(chan *request.Raw), + Client: *cl, + Notifications: make(chan Notification), + + ws: ws, + shutdown: make(chan struct{}), + done: make(chan struct{}), + responses: make(chan *response.Raw), + requests: make(chan *request.Raw), + subscriptions: make(map[string]bool), } go wsc.wsReader() go wsc.wsWriter() @@ -86,6 +108,7 @@ func (c *WSClient) Close() { func (c *WSClient) wsReader() { c.ws.SetReadLimit(wsReadLimit) c.ws.SetPongHandler(func(string) error { c.ws.SetReadDeadline(time.Now().Add(wsPongLimit)); return nil }) +readloop: for { rr := new(requestResponse) c.ws.SetReadDeadline(time.Now().Add(wsPongLimit)) @@ -95,9 +118,37 @@ func (c *WSClient) wsReader() { break } if rr.RawID == nil && rr.Method != "" { - if c.notifications != nil { - c.notifications <- &rr.In + event, err := response.GetEventIDFromString(rr.Method) + if err != nil { + // Bad event received. + break } + var slice []json.RawMessage + err = json.Unmarshal(rr.RawParams, &slice) + if err != nil || len(slice) != 1 { + // Bad event received. + break + } + var val interface{} + switch event { + case response.BlockEventID: + val = new(block.Block) + case response.TransactionEventID: + val = new(transaction.Transaction) + case response.NotificationEventID: + val = new(result.NotificationEvent) + case response.ExecutionEventID: + val = new(result.ApplicationLog) + default: + // Bad event received. + break readloop + } + err = json.Unmarshal(slice[0], val) + if err != nil || len(slice) != 1 { + // Bad event received. + break + } + c.Notifications <- Notification{event, val} } else if rr.RawID != nil && (rr.Error != nil || rr.Result != nil) { resp := new(response.Raw) resp.ID = rr.RawID @@ -112,9 +163,7 @@ func (c *WSClient) wsReader() { } close(c.done) close(c.responses) - if c.notifications != nil { - close(c.notifications) - } + close(c.Notifications) } func (c *WSClient) wsWriter() { @@ -158,3 +207,73 @@ func (c *WSClient) makeWsRequest(r *request.Raw) (*response.Raw, error) { return resp, nil } } + +func (c *WSClient) performSubscription(params request.RawParams) (string, error) { + var resp string + + if err := c.performRequest("subscribe", params, &resp); err != nil { + return "", err + } + c.subscriptions[resp] = true + return resp, nil +} + +func (c *WSClient) performUnsubscription(id string) error { + var resp bool + + if !c.subscriptions[id] { + return errors.New("no subscription with this ID") + } + if err := c.performRequest("unsubscribe", request.NewRawParams(id), &resp); err != nil { + return err + } + if !resp { + return errors.New("unsubscribe method returned false result") + } + delete(c.subscriptions, id) + return nil +} + +// SubscribeForNewBlocks adds subscription for new block events to this instance +// of client. +func (c *WSClient) SubscribeForNewBlocks() (string, error) { + params := request.NewRawParams("block_added") + return c.performSubscription(params) +} + +// SubscribeForNewTransactions adds subscription for new transaction events to +// this instance of client. +func (c *WSClient) SubscribeForNewTransactions() (string, error) { + params := request.NewRawParams("transaction_added") + return c.performSubscription(params) +} + +// SubscribeForExecutionNotifications adds subscription for notifications +// generated during transaction execution to this instance of client. +func (c *WSClient) SubscribeForExecutionNotifications() (string, error) { + params := request.NewRawParams("notification_from_execution") + return c.performSubscription(params) +} + +// SubscribeForTransactionExecutions adds subscription for application execution +// results generated during transaction execution to this instance of client. +func (c *WSClient) SubscribeForTransactionExecutions() (string, error) { + params := request.NewRawParams("transaction_executed") + return c.performSubscription(params) +} + +// Unsubscribe removes subscription for given event stream. +func (c *WSClient) Unsubscribe(id string) error { + return c.performUnsubscription(id) +} + +// UnsubscribeAll removes all active subscriptions of current client. +func (c *WSClient) UnsubscribeAll() error { + for id := range c.subscriptions { + err := c.performUnsubscription(id) + if err != nil { + return err + } + } + return nil +} diff --git a/pkg/rpc/client/wsclient_test.go b/pkg/rpc/client/wsclient_test.go index 2a996999a..f747c1710 100644 --- a/pkg/rpc/client/wsclient_test.go +++ b/pkg/rpc/client/wsclient_test.go @@ -2,8 +2,12 @@ package client import ( "context" + "net/http" + "net/http/httptest" "testing" + "time" + "github.com/gorilla/websocket" "github.com/stretchr/testify/require" ) @@ -14,3 +18,129 @@ func TestWSClientClose(t *testing.T) { require.NoError(t, err) wsc.Close() } + +func TestWSClientSubscription(t *testing.T) { + var cases = map[string]func(*WSClient) (string, error){ + "blocks": (*WSClient).SubscribeForNewBlocks, + "transactions": (*WSClient).SubscribeForNewTransactions, + "notifications": (*WSClient).SubscribeForExecutionNotifications, + "executions": (*WSClient).SubscribeForTransactionExecutions, + } + t.Run("good", func(t *testing.T) { + for name, f := range cases { + t.Run(name, func(t *testing.T) { + srv := initTestServer(t, `{"jsonrpc": "2.0", "id": 1, "result": "55aaff00"}`) + defer srv.Close() + wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) + require.NoError(t, err) + id, err := f(wsc) + require.NoError(t, err) + require.Equal(t, "55aaff00", id) + }) + } + }) + t.Run("bad", func(t *testing.T) { + for name, f := range cases { + t.Run(name, func(t *testing.T) { + srv := initTestServer(t, `{"jsonrpc": "2.0", "id": 1, "error":{"code":-32602,"message":"Invalid Params"}}`) + defer srv.Close() + wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) + require.NoError(t, err) + _, err = f(wsc) + require.Error(t, err) + }) + } + }) +} + +func TestWSClientUnsubscription(t *testing.T) { + type responseCheck struct { + response string + code func(*testing.T, *WSClient) + } + var cases = map[string]responseCheck{ + "good": {`{"jsonrpc": "2.0", "id": 1, "result": true}`, func(t *testing.T, wsc *WSClient) { + // We can't really subscribe using this stub server, so set up wsc internals. + wsc.subscriptions["0"] = true + err := wsc.Unsubscribe("0") + require.NoError(t, err) + }}, + "all": {`{"jsonrpc": "2.0", "id": 1, "result": true}`, func(t *testing.T, wsc *WSClient) { + // We can't really subscribe using this stub server, so set up wsc internals. + wsc.subscriptions["0"] = true + err := wsc.UnsubscribeAll() + require.NoError(t, err) + require.Equal(t, 0, len(wsc.subscriptions)) + }}, + "not subscribed": {`{"jsonrpc": "2.0", "id": 1, "result": true}`, func(t *testing.T, wsc *WSClient) { + err := wsc.Unsubscribe("0") + require.Error(t, err) + }}, + "error returned": {`{"jsonrpc": "2.0", "id": 1, "error":{"code":-32602,"message":"Invalid Params"}}`, func(t *testing.T, wsc *WSClient) { + // We can't really subscribe using this stub server, so set up wsc internals. + wsc.subscriptions["0"] = true + err := wsc.Unsubscribe("0") + require.Error(t, err) + }}, + "false returned": {`{"jsonrpc": "2.0", "id": 1, "result": false}`, func(t *testing.T, wsc *WSClient) { + // We can't really subscribe using this stub server, so set up wsc internals. + wsc.subscriptions["0"] = true + err := wsc.Unsubscribe("0") + require.Error(t, err) + }}, + } + for name, rc := range cases { + t.Run(name, func(t *testing.T) { + srv := initTestServer(t, rc.response) + defer srv.Close() + wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) + require.NoError(t, err) + rc.code(t, wsc) + }) + } +} + +func TestWSClientEvents(t *testing.T) { + var ok bool + // Events from RPC server test chain. + var events = []string{ + `{"jsonrpc":"2.0","method":"transaction_executed","params":[{"txid":"0x93670859cc8a42f6ea994869c944879678d33d7501d388f5a446a8c7de147df7","executions":[{"trigger":"Application","contract":"0x0000000000000000000000000000000000000000","vmstate":"HALT","gas_consumed":"1.048","stack":[{"type":"Integer","value":"1"}],"notifications":[{"contract":"0xc2789e5ab9bab828743833965b1df0d5fbcc206f","state":{"type":"Array","value":[{"type":"ByteArray","value":"636f6e74726163742063616c6c"},{"type":"ByteArray","value":"507574"},{"type":"Array","value":[{"type":"ByteArray","value":"746573746b6579"},{"type":"ByteArray","value":"7465737476616c7565"}]}]}}]}]}]}`, + `{"jsonrpc":"2.0","method":"notification_from_execution","params":[{"contract":"0xc2789e5ab9bab828743833965b1df0d5fbcc206f","state":{"type":"Array","value":[{"type":"ByteArray","value":"636f6e74726163742063616c6c"},{"type":"ByteArray","value":"507574"},{"type":"Array","value":[{"type":"ByteArray","value":"746573746b6579"},{"type":"ByteArray","value":"7465737476616c7565"}]}]}}]}`, + `{"jsonrpc":"2.0","method":"transaction_added","params":[{"txid":"0x93670859cc8a42f6ea994869c944879678d33d7501d388f5a446a8c7de147df7","size":60,"type":"InvocationTransaction","version":1,"attributes":[],"vin":[],"vout":[],"scripts":[],"script":"097465737476616c756507746573746b657952c103507574676f20ccfbd5f01d5b9633387428b8bab95a9e78c2"}]}`, + `{"jsonrpc":"2.0","method":"block_added","params":[{"version":0,"previousblockhash":"0x33f3e0e24542b2ec3b6420e6881c31f6460a39a4e733d88f7557cbcc3b5ed560","merkleroot":"0x9d922c5cfd4c8cd1da7a6b2265061998dc438bd0dea7145192e2858155e6c57a","time":1586154525,"height":205,"nonce":1111,"next_consensus":"0xa21e4f7178607089e4fe9fab1300d1f5a3d348be","script":{"invocation":"4047a444a51218ac856f1cbc629f251c7c88187910534d6ba87847c86a9a73ed4951d203fd0a87f3e65657a7259269473896841f65c0a0c8efc79d270d917f4ff640435ee2f073c94a02f0276dfe4465037475e44e1c34c0decb87ec9c2f43edf688059fc4366a41c673d72ba772b4782c39e79f01cb981247353216d52d2df1651140527eb0dfd80a800fdd7ac8fbe68fc9366db2d71655d8ba235525a97a69a7181b1e069b82091be711c25e504a17c3c55eee6e76e6af13cb488fbe35d5c5d025c34041f39a02ebe9bb08be0e4aaa890f447dc9453209bbfb4705d8f2d869c2b55ee2d41dbec2ee476a059d77fb7c26400284328d05aece5f3168b48f1db1c6f7be0b","verification":"532102103a7f7dd016558597f7960d27c516a4394fd968b9e65155eb4b013e4040406e2102a7bc55fe8684e0119768d104ba30795bdcc86619e864add26156723ed185cd622102b3622bf4017bdfe317c58aed5f4c753f206b7db896046fa7d774bbc4bf7f8dc22103d90c07df63e690ce77912e10ab51acc944b66860237b608c4f8f8309e71ee69954ae"},"tx":[{"txid":"0xf9adfde059810f37b3d0686d67f6b29034e0c669537df7e59b40c14a0508b9ed","size":10,"type":"MinerTransaction","version":0,"attributes":[],"vin":[],"vout":[],"scripts":[]},{"txid":"0x93670859cc8a42f6ea994869c944879678d33d7501d388f5a446a8c7de147df7","size":60,"type":"InvocationTransaction","version":1,"attributes":[],"vin":[],"vout":[],"scripts":[],"script":"097465737476616c756507746573746b657952c103507574676f20ccfbd5f01d5b9633387428b8bab95a9e78c2"}]}]}`, + } + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + if req.URL.Path == "/ws" && req.Method == "GET" { + var upgrader = websocket.Upgrader{} + ws, err := upgrader.Upgrade(w, req, nil) + require.NoError(t, err) + for _, event := range events { + ws.SetWriteDeadline(time.Now().Add(2 * time.Second)) + err = ws.WriteMessage(1, []byte(event)) + if err != nil { + break + } + } + ws.Close() + return + } + })) + + wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) + require.NoError(t, err) + for range events { + select { + case _, ok = <-wsc.Notifications: + case <-time.After(time.Second): + t.Fatal("timeout waiting for event") + } + require.Equal(t, true, ok) + } + select { + case _, ok = <-wsc.Notifications: + case <-time.After(time.Second): + t.Fatal("timeout waiting for event") + } + // Connection closed by server. + require.Equal(t, false, ok) +} diff --git a/pkg/rpc/response/events.go b/pkg/rpc/response/events.go new file mode 100644 index 000000000..1efba39a5 --- /dev/null +++ b/pkg/rpc/response/events.go @@ -0,0 +1,79 @@ +package response + +import ( + "encoding/json" + + "github.com/pkg/errors" +) + +type ( + // EventID represents an event type happening on the chain. + EventID byte +) + +const ( + // InvalidEventID is an invalid event id that is the default value of + // EventID. It's only used as an initial value similar to nil. + InvalidEventID EventID = iota + // BlockEventID is a `block_added` event. + BlockEventID + // TransactionEventID corresponds to `transaction_added` event. + TransactionEventID + // NotificationEventID represents `notification_from_execution` events. + NotificationEventID + // ExecutionEventID is used for `transaction_executed` events. + ExecutionEventID +) + +// String is a good old Stringer implementation. +func (e EventID) String() string { + switch e { + case BlockEventID: + return "block_added" + case TransactionEventID: + return "transaction_added" + case NotificationEventID: + return "notification_from_execution" + case ExecutionEventID: + return "transaction_executed" + default: + return "unknown" + } +} + +// GetEventIDFromString converts input string into an EventID if it's possible. +func GetEventIDFromString(s string) (EventID, error) { + switch s { + case "block_added": + return BlockEventID, nil + case "transaction_added": + return TransactionEventID, nil + case "notification_from_execution": + return NotificationEventID, nil + case "transaction_executed": + return ExecutionEventID, nil + default: + return 255, errors.New("invalid stream name") + } +} + +// MarshalJSON implements json.Marshaler interface. +func (e EventID) MarshalJSON() ([]byte, error) { + return json.Marshal(e.String()) +} + +// UnmarshalJSON implements json.Unmarshaler interface. +func (e *EventID) UnmarshalJSON(b []byte) error { + var s string + + err := json.Unmarshal(b, &s) + if err != nil { + return err + } + id, err := GetEventIDFromString(s) + if err != nil { + return err + } + *e = id + return nil +} diff --git a/pkg/rpc/response/result/application_log.go b/pkg/rpc/response/result/application_log.go index da436f2f4..ee59fc43d 100644 --- a/pkg/rpc/response/result/application_log.go +++ b/pkg/rpc/response/result/application_log.go @@ -30,16 +30,22 @@ type NotificationEvent struct { Item smartcontract.Parameter `json:"state"` } +// StateEventToResultNotification converts state.NotificationEvent to +// result.NotificationEvent. +func StateEventToResultNotification(event state.NotificationEvent) NotificationEvent { + seen := make(map[vm.StackItem]bool) + item := event.Item.ToContractParameter(seen) + return NotificationEvent{ + Contract: event.ScriptHash, + Item: item, + } +} + // NewApplicationLog creates a new ApplicationLog wrapper. func NewApplicationLog(appExecRes *state.AppExecResult, scriptHash util.Uint160) ApplicationLog { events := make([]NotificationEvent, 0, len(appExecRes.Events)) for _, e := range appExecRes.Events { - seen := make(map[vm.StackItem]bool) - item := e.Item.ToContractParameter(seen) - events = append(events, NotificationEvent{ - Contract: e.ScriptHash, - Item: item, - }) + events = append(events, StateEventToResultNotification(e)) } triggerString := appExecRes.Trigger.String() diff --git a/pkg/rpc/response/types.go b/pkg/rpc/response/types.go index 0b236826a..ba23c7677 100644 --- a/pkg/rpc/response/types.go +++ b/pkg/rpc/response/types.go @@ -37,3 +37,12 @@ type GetRawTx struct { HeaderAndError Result *result.TransactionOutputRaw `json:"result"` } + +// Notification is a type used to represent wire format of events, they're +// special in that they look like requests but they don't have IDs and their +// "method" is actually an event name. +type Notification struct { + JSONRPC string `json:"jsonrpc"` + Event EventID `json:"method"` + Payload []interface{} `json:"params"` +} diff --git a/pkg/rpc/server/server.go b/pkg/rpc/server/server.go index 87f43b519..2c2a2da41 100644 --- a/pkg/rpc/server/server.go +++ b/pkg/rpc/server/server.go @@ -9,6 +9,7 @@ import ( "net" "net/http" "strconv" + "sync" "time" "github.com/gorilla/websocket" @@ -41,6 +42,19 @@ type ( coreServer *network.Server log *zap.Logger https *http.Server + shutdown chan struct{} + + subsLock sync.RWMutex + subscribers map[*subscriber]bool + subsGroup sync.WaitGroup + blockSubs int + executionSubs int + notificationSubs int + transactionSubs int + blockCh chan *block.Block + executionCh chan *state.AppExecResult + notificationCh chan *state.NotificationEvent + transactionCh chan *transaction.Transaction } ) @@ -56,6 +70,11 @@ const ( // Write deadline. wsWriteLimit = wsPingPeriod / 2 + + // Maximum number of subscribers per Server. Each websocket client is + // treated like subscriber, so technically it's a limit on websocket + // connections. + maxSubscribers = 64 ) var rpcHandlers = map[string]func(*Server, request.Params) (interface{}, *response.Error){ @@ -91,6 +110,11 @@ var rpcHandlers = map[string]func(*Server, request.Params) (interface{}, *respon "validateaddress": (*Server).validateAddress, } +var rpcWsHandlers = map[string]func(*Server, request.Params, *subscriber) (interface{}, *response.Error){ + "subscribe": (*Server).subscribe, + "unsubscribe": (*Server).unsubscribe, +} + var invalidBlockHeightError = func(index int, height int) *response.Error { return response.NewRPCError(fmt.Sprintf("Param at index %d should be greater than or equal to 0 and less then or equal to current block height, got: %d", index, height), "", nil) } @@ -119,11 +143,20 @@ func New(chain core.Blockchainer, conf rpc.Config, coreServer *network.Server, l coreServer: coreServer, log: log, https: tlsServer, + shutdown: make(chan struct{}), + + subscribers: make(map[*subscriber]bool), + // These are NOT buffered to preserve original order of events. + blockCh: make(chan *block.Block), + executionCh: make(chan *state.AppExecResult), + notificationCh: make(chan *state.NotificationEvent), + transactionCh: make(chan *transaction.Transaction), } } -// Start creates a new JSON-RPC server -// listening on the configured port. +// Start creates a new JSON-RPC server listening on the configured port. It's +// supposed to be run as a separate goroutine (like http.Server's Serve) and it +// returns its errors via given errChan. func (s *Server) Start(errChan chan error) { if !s.config.Enabled { s.log.Info("RPC server is not enabled") @@ -132,6 +165,7 @@ func (s *Server) Start(errChan chan error) { s.Handler = http.HandlerFunc(s.handleHTTPRequest) s.log.Info("starting rpc-server", zap.String("endpoint", s.Addr)) + go s.handleSubEvents() if cfg := s.config.TLSConfig; cfg.Enabled { s.https.Handler = http.HandlerFunc(s.handleHTTPRequest) s.log.Info("starting rpc-server (https)", zap.String("endpoint", s.https.Addr)) @@ -154,6 +188,10 @@ func (s *Server) Start(errChan chan error) { // method. func (s *Server) Shutdown() error { var httpsErr error + + // Signal to websocket writer routines and handleSubEvents. + close(s.shutdown) + if s.config.TLSConfig.Enabled { s.log.Info("shutting down rpc-server (https)", zap.String("endpoint", s.https.Addr)) httpsErr = s.https.Shutdown(context.Background()) @@ -161,6 +199,10 @@ func (s *Server) Shutdown() error { s.log.Info("shutting down rpc-server", zap.String("endpoint", s.Addr)) err := s.Server.Shutdown(context.Background()) + + // Wait for handleSubEvents to finish. + <-s.executionCh + if err == nil { return httpsErr } @@ -168,20 +210,40 @@ func (s *Server) Shutdown() error { } func (s *Server) handleHTTPRequest(w http.ResponseWriter, httpRequest *http.Request) { + req := request.NewIn() + if httpRequest.URL.Path == "/ws" && httpRequest.Method == "GET" { + // Technically there is a race between this check and + // s.subscribers modification 20 lines below, but it's tiny + // and not really critical to bother with it. Some additional + // clients may sneak in, no big deal. + s.subsLock.RLock() + numOfSubs := len(s.subscribers) + s.subsLock.RUnlock() + if numOfSubs >= maxSubscribers { + s.writeHTTPErrorResponse( + req, + w, + response.NewInternalServerError("websocket users limit reached", nil), + ) + return + } ws, err := upgrader.Upgrade(w, httpRequest, nil) if err != nil { s.log.Info("websocket connection upgrade failed", zap.Error(err)) return } resChan := make(chan response.Raw) - go s.handleWsWrites(ws, resChan) - s.handleWsReads(ws, resChan) + subChan := make(chan *websocket.PreparedMessage, notificationBufSize) + subscr := &subscriber{writer: subChan, ws: ws} + s.subsLock.Lock() + s.subscribers[subscr] = true + s.subsLock.Unlock() + go s.handleWsWrites(ws, resChan, subChan) + s.handleWsReads(ws, resChan, subscr) return } - req := request.NewIn() - if httpRequest.Method != "POST" { s.writeHTTPErrorResponse( req, @@ -199,11 +261,14 @@ func (s *Server) handleHTTPRequest(w http.ResponseWriter, httpRequest *http.Requ return } - resp := s.handleRequest(req) + resp := s.handleRequest(req, nil) s.writeHTTPServerResponse(req, w, resp) } -func (s *Server) handleRequest(req *request.In) response.Raw { +func (s *Server) handleRequest(req *request.In, sub *subscriber) response.Raw { + var res interface{} + var resErr *response.Error + reqParams, err := req.Params() if err != nil { return s.packResponseToRaw(req, nil, response.NewInvalidParamsError("Problem parsing request parameters", err)) @@ -215,20 +280,37 @@ func (s *Server) handleRequest(req *request.In) response.Raw { incCounter(req.Method) + resErr = response.NewMethodNotFoundError(fmt.Sprintf("Method '%s' not supported", req.Method), nil) handler, ok := rpcHandlers[req.Method] - if !ok { - return s.packResponseToRaw(req, nil, response.NewMethodNotFoundError(fmt.Sprintf("Method '%s' not supported", req.Method), nil)) + if ok { + res, resErr = handler(s, *reqParams) + } else if sub != nil { + handler, ok := rpcWsHandlers[req.Method] + if ok { + res, resErr = handler(s, *reqParams, sub) + } } - res, resErr := handler(s, *reqParams) return s.packResponseToRaw(req, res, resErr) } -func (s *Server) handleWsWrites(ws *websocket.Conn, resChan <-chan response.Raw) { +func (s *Server) handleWsWrites(ws *websocket.Conn, resChan <-chan response.Raw, subChan <-chan *websocket.PreparedMessage) { pingTicker := time.NewTicker(wsPingPeriod) defer ws.Close() defer pingTicker.Stop() for { select { + case <-s.shutdown: + // Signal to the reader routine. + ws.Close() + return + case event, ok := <-subChan: + if !ok { + return + } + ws.SetWriteDeadline(time.Now().Add(wsWriteLimit)) + if err := ws.WritePreparedMessage(event); err != nil { + return + } case res, ok := <-resChan: if !ok { return @@ -246,22 +328,36 @@ func (s *Server) handleWsWrites(ws *websocket.Conn, resChan <-chan response.Raw) } } -func (s *Server) handleWsReads(ws *websocket.Conn, resChan chan<- response.Raw) { +func (s *Server) handleWsReads(ws *websocket.Conn, resChan chan<- response.Raw, subscr *subscriber) { ws.SetReadLimit(wsReadLimit) ws.SetReadDeadline(time.Now().Add(wsPongLimit)) ws.SetPongHandler(func(string) error { ws.SetReadDeadline(time.Now().Add(wsPongLimit)); return nil }) +requestloop: for { req := new(request.In) err := ws.ReadJSON(req) if err != nil { break } - res := s.handleRequest(req) + res := s.handleRequest(req, subscr) if res.Error != nil { s.logRequestError(req, res.Error) } - resChan <- res + select { + case <-s.shutdown: + break requestloop + case resChan <- res: + } + } + s.subsLock.Lock() + delete(s.subscribers, subscr) + for _, e := range subscr.feeds { + if e != response.InvalidEventID { + s.unsubscribeFromChannel(e) + } + } + s.subsLock.Unlock() close(resChan) ws.Close() } @@ -1024,6 +1120,201 @@ func (s *Server) sendrawtransaction(reqParams request.Params) (interface{}, *res return results, resultsErr } +// subscribe handles subscription requests from websocket clients. +func (s *Server) subscribe(reqParams request.Params, sub *subscriber) (interface{}, *response.Error) { + p, ok := reqParams.Value(0) + if !ok { + return nil, response.ErrInvalidParams + } + streamName, err := p.GetString() + if err != nil { + return nil, response.ErrInvalidParams + } + event, err := response.GetEventIDFromString(streamName) + if err != nil { + return nil, response.ErrInvalidParams + } + s.subsLock.Lock() + defer s.subsLock.Unlock() + select { + case <-s.shutdown: + return nil, response.NewInternalServerError("server is shutting down", nil) + default: + } + var id int + for ; id < len(sub.feeds); id++ { + if sub.feeds[id] == response.InvalidEventID { + break + } + } + if id == len(sub.feeds) { + return nil, response.NewInternalServerError("maximum number of subscriptions is reached", nil) + } + sub.feeds[id] = event + s.subscribeToChannel(event) + return strconv.FormatInt(int64(id), 10), nil +} + +// subscribeToChannel subscribes RPC server to appropriate chain events if +// it's not yet subscribed for them. It's supposed to be called with s.subsLock +// taken by the caller. +func (s *Server) subscribeToChannel(event response.EventID) { + switch event { + case response.BlockEventID: + if s.blockSubs == 0 { + s.chain.SubscribeForBlocks(s.blockCh) + } + s.blockSubs++ + case response.TransactionEventID: + if s.transactionSubs == 0 { + s.chain.SubscribeForTransactions(s.transactionCh) + } + s.transactionSubs++ + case response.NotificationEventID: + if s.notificationSubs == 0 { + s.chain.SubscribeForNotifications(s.notificationCh) + } + s.notificationSubs++ + case response.ExecutionEventID: + if s.executionSubs == 0 { + s.chain.SubscribeForExecutions(s.executionCh) + } + s.executionSubs++ + } +} + +// unsubscribe handles unsubscription requests from websocket clients. +func (s *Server) unsubscribe(reqParams request.Params, sub *subscriber) (interface{}, *response.Error) { + p, ok := reqParams.Value(0) + if !ok { + return nil, response.ErrInvalidParams + } + id, err := p.GetInt() + if err != nil || id < 0 { + return nil, response.ErrInvalidParams + } + s.subsLock.Lock() + defer s.subsLock.Unlock() + if len(sub.feeds) <= id || sub.feeds[id] == response.InvalidEventID { + return nil, response.ErrInvalidParams + } + event := sub.feeds[id] + sub.feeds[id] = response.InvalidEventID + s.unsubscribeFromChannel(event) + return true, nil +} + +// unsubscribeFromChannel unsubscribes RPC server from appropriate chain events +// if there are no other subscribers for it. It's supposed to be called with +// s.subsLock taken by the caller. +func (s *Server) unsubscribeFromChannel(event response.EventID) { + switch event { + case response.BlockEventID: + s.blockSubs-- + if s.blockSubs == 0 { + s.chain.UnsubscribeFromBlocks(s.blockCh) + } + case response.TransactionEventID: + s.transactionSubs-- + if s.transactionSubs == 0 { + s.chain.UnsubscribeFromTransactions(s.transactionCh) + } + case response.NotificationEventID: + s.notificationSubs-- + if s.notificationSubs == 0 { + s.chain.UnsubscribeFromNotifications(s.notificationCh) + } + case response.ExecutionEventID: + s.executionSubs-- + if s.executionSubs == 0 { + s.chain.UnsubscribeFromExecutions(s.executionCh) + } + } +} + +func (s *Server) handleSubEvents() { +chloop: + for { + var resp = response.Notification{ + JSONRPC: request.JSONRPCVersion, + Payload: make([]interface{}, 1), + } + var msg *websocket.PreparedMessage + select { + case <-s.shutdown: + break chloop + case b := <-s.blockCh: + resp.Event = response.BlockEventID + resp.Payload[0] = b + case execution := <-s.executionCh: + resp.Event = response.ExecutionEventID + resp.Payload[0] = result.NewApplicationLog(execution, util.Uint160{}) + case notification := <-s.notificationCh: + resp.Event = response.NotificationEventID + resp.Payload[0] = result.StateEventToResultNotification(*notification) + case tx := <-s.transactionCh: + resp.Event = response.TransactionEventID + resp.Payload[0] = tx + } + s.subsLock.RLock() + subloop: + for sub := range s.subscribers { + for _, subID := range sub.feeds { + if subID == resp.Event { + if msg == nil { + b, err := json.Marshal(resp) + if err != nil { + s.log.Error("failed to marshal notification", + zap.Error(err), + zap.String("type", resp.Event.String())) + break subloop + } + msg, err = websocket.NewPreparedMessage(websocket.TextMessage, b) + if err != nil { + s.log.Error("failed to prepare notification message", + zap.Error(err), + zap.String("type", resp.Event.String())) + break subloop + } + } + sub.writer <- msg + // The message is sent only once per subscriber. + break + } + } + } + s.subsLock.RUnlock() + } + // It's important to do it with lock held because no subscription routine + // should be running concurrently to this one. And even if one is to run + // after unlock, it'll see closed s.shutdown and won't subscribe. + s.subsLock.Lock() + // There might be no subscription in reality, but it's not a problem as + // core.Blockchain allows unsubscribing non-subscribed channels. + s.chain.UnsubscribeFromBlocks(s.blockCh) + s.chain.UnsubscribeFromTransactions(s.transactionCh) + s.chain.UnsubscribeFromNotifications(s.notificationCh) + s.chain.UnsubscribeFromExecutions(s.executionCh) + s.subsLock.Unlock() +drainloop: + for { + select { + case <-s.blockCh: + case <-s.executionCh: + case <-s.notificationCh: + case <-s.transactionCh: + default: + break drainloop + } + } + // It's not required closing these, but since they're drained already + // this is safe and it also allows to give a signal to Shutdown routine. + close(s.blockCh) + close(s.transactionCh) + close(s.notificationCh) + close(s.executionCh) +} + func (s *Server) blockHeightFromParam(param *request.Param) (int, *response.Error) { num, err := param.GetInt() if err != nil { diff --git a/pkg/rpc/server/server_helper_test.go b/pkg/rpc/server/server_helper_test.go index c6ee3167e..61bcc2e0d 100644 --- a/pkg/rpc/server/server_helper_test.go +++ b/pkg/rpc/server/server_helper_test.go @@ -15,12 +15,11 @@ import ( "github.com/nspcc-dev/neo-go/pkg/network" "github.com/nspcc-dev/neo-go/pkg/util" "github.com/stretchr/testify/require" + "go.uber.org/zap" "go.uber.org/zap/zaptest" ) -func initServerWithInMemoryChain(t *testing.T) (*core.Blockchain, *httptest.Server) { - var nBlocks uint32 - +func getUnitTestChain(t *testing.T) (*core.Blockchain, config.Config, *zap.Logger) { net := config.ModeUnitTestNet configPath := "../../../config" cfg, err := config.Load(configPath, net) @@ -33,6 +32,11 @@ func initServerWithInMemoryChain(t *testing.T) (*core.Blockchain, *httptest.Serv go chain.Run() + return chain, cfg, logger +} + +func getTestBlocks(t *testing.T) []*block.Block { + blocks := make([]*block.Block, 0) // File "./testdata/testblocks.acc" was generated by function core._ // ("neo-go/pkg/core/helper_test.go"). // To generate new "./testdata/testblocks.acc", follow the steps: @@ -42,25 +46,41 @@ func initServerWithInMemoryChain(t *testing.T) (*core.Blockchain, *httptest.Serv f, err := os.Open("testdata/testblocks.acc") require.Nil(t, err) br := io.NewBinReaderFromIO(f) - nBlocks = br.ReadU32LE() + nBlocks := br.ReadU32LE() require.Nil(t, br.Err) for i := 0; i < int(nBlocks); i++ { _ = br.ReadU32LE() b := &block.Block{} b.DecodeBinary(br) require.Nil(t, br.Err) - require.NoError(t, chain.AddBlock(b)) + blocks = append(blocks, b) } + return blocks +} + +func initClearServerWithInMemoryChain(t *testing.T) (*core.Blockchain, *Server, *httptest.Server) { + chain, cfg, logger := getUnitTestChain(t) serverConfig := network.NewServerConfig(cfg) server, err := network.NewServer(serverConfig, chain, logger) require.NoError(t, err) rpcServer := New(chain, cfg.ApplicationConfiguration.RPC, server, logger) + errCh := make(chan error, 2) + go rpcServer.Start(errCh) handler := http.HandlerFunc(rpcServer.handleHTTPRequest) srv := httptest.NewServer(handler) - return chain, srv + return chain, &rpcServer, srv +} + +func initServerWithInMemoryChain(t *testing.T) (*core.Blockchain, *Server, *httptest.Server) { + chain, rpcServer, srv := initClearServerWithInMemoryChain(t) + + for _, b := range getTestBlocks(t) { + require.NoError(t, chain.AddBlock(b)) + } + return chain, rpcServer, srv } type FeerStub struct{} diff --git a/pkg/rpc/server/server_test.go b/pkg/rpc/server/server_test.go index 8c2fb826a..9def78be5 100644 --- a/pkg/rpc/server/server_test.go +++ b/pkg/rpc/server/server_test.go @@ -896,9 +896,10 @@ func TestRPC(t *testing.T) { // calls. Some tests change the chain state, thus we reinitialize the chain from // scratch here. func testRPCProtocol(t *testing.T, doRPCCall func(string, string, *testing.T) []byte) { - chain, httpSrv := initServerWithInMemoryChain(t) + chain, rpcSrv, httpSrv := initServerWithInMemoryChain(t) defer chain.Close() + defer rpcSrv.Shutdown() e := &executor{chain: chain, httpSrv: httpSrv} for method, cases := range rpcTestCases { diff --git a/pkg/rpc/server/subscription.go b/pkg/rpc/server/subscription.go new file mode 100644 index 000000000..10c9e25ec --- /dev/null +++ b/pkg/rpc/server/subscription.go @@ -0,0 +1,35 @@ +package server + +import ( + "github.com/gorilla/websocket" + "github.com/nspcc-dev/neo-go/pkg/rpc/response" +) + +type ( + // subscriber is an event subscriber. + subscriber struct { + writer chan<- *websocket.PreparedMessage + ws *websocket.Conn + + // These work like slots as there is not a lot of them (it's + // cheaper doing it this way rather than creating a map), + // pointing to EventID is an obvious overkill at the moment, but + // that's not for long. + feeds [maxFeeds]response.EventID + } +) + +const ( + // Maximum number of subscriptions per one client. + maxFeeds = 16 + + // This sets notification messages buffer depth, it may seem to be quite + // big, but there is a big gap in speed between internal event processing + // and networking communication that is combined with spiky nature of our + // event generation process, which leads to lots of events generated in + // short time and they will put some pressure to this buffer (consider + // ~500 invocation txs in one block with some notifications). At the same + // time this channel is about sending pointers, so it's doesn't cost + // a lot in terms of memory used. + notificationBufSize = 1024 +) diff --git a/pkg/rpc/server/subscription_test.go b/pkg/rpc/server/subscription_test.go new file mode 100644 index 000000000..bd4fcb792 --- /dev/null +++ b/pkg/rpc/server/subscription_test.go @@ -0,0 +1,227 @@ +package server + +import ( + "encoding/json" + "fmt" + "strings" + "testing" + "time" + + "github.com/gorilla/websocket" + "github.com/nspcc-dev/neo-go/pkg/core" + "github.com/nspcc-dev/neo-go/pkg/core/transaction" + "github.com/nspcc-dev/neo-go/pkg/rpc/response" + "github.com/stretchr/testify/require" + "go.uber.org/atomic" +) + +func wsReader(t *testing.T, ws *websocket.Conn, msgCh chan<- []byte, isFinished *atomic.Bool) { + for { + ws.SetReadDeadline(time.Now().Add(time.Second)) + _, body, err := ws.ReadMessage() + if isFinished.Load() { + require.Error(t, err) + break + } + require.NoError(t, err) + msgCh <- body + } +} + +func callWSGetRaw(t *testing.T, ws *websocket.Conn, msg string, respCh <-chan []byte) *response.Raw { + var resp = new(response.Raw) + + ws.SetWriteDeadline(time.Now().Add(time.Second)) + require.NoError(t, ws.WriteMessage(websocket.TextMessage, []byte(msg))) + + body := <-respCh + require.NoError(t, json.Unmarshal(body, resp)) + return resp +} + +func getNotification(t *testing.T, respCh <-chan []byte) *response.Notification { + var resp = new(response.Notification) + body := <-respCh + require.NoError(t, json.Unmarshal(body, resp)) + return resp +} + +func initCleanServerAndWSClient(t *testing.T) (*core.Blockchain, *Server, *websocket.Conn, chan []byte, *atomic.Bool) { + chain, rpcSrv, httpSrv := initClearServerWithInMemoryChain(t) + + dialer := websocket.Dialer{HandshakeTimeout: time.Second} + url := "ws" + strings.TrimPrefix(httpSrv.URL, "http") + "/ws" + ws, _, err := dialer.Dial(url, nil) + require.NoError(t, err) + + // Use buffered channel to read server's messages and then read expected + // responses from it. + respMsgs := make(chan []byte, 16) + finishedFlag := atomic.NewBool(false) + go wsReader(t, ws, respMsgs, finishedFlag) + return chain, rpcSrv, ws, respMsgs, finishedFlag +} + +func TestSubscriptions(t *testing.T) { + var subIDs = make([]string, 0) + var subFeeds = []string{"block_added", "transaction_added", "notification_from_execution", "transaction_executed"} + + chain, rpcSrv, c, respMsgs, finishedFlag := initCleanServerAndWSClient(t) + + defer chain.Close() + defer rpcSrv.Shutdown() + + for _, feed := range subFeeds { + var s string + resp := callWSGetRaw(t, c, fmt.Sprintf(`{ + "jsonrpc": "2.0", + "method": "subscribe", + "params": ["%s"], + "id": 1 +}`, feed), respMsgs) + require.Nil(t, resp.Error) + require.NotNil(t, resp.Result) + require.NoError(t, json.Unmarshal(resp.Result, &s)) + subIDs = append(subIDs, s) + } + + for _, b := range getTestBlocks(t) { + require.NoError(t, chain.AddBlock(b)) + for _, tx := range b.Transactions { + var mayNotify bool + + if tx.Type == transaction.InvocationType { + resp := getNotification(t, respMsgs) + require.Equal(t, response.ExecutionEventID, resp.Event) + mayNotify = true + } + for { + resp := getNotification(t, respMsgs) + if mayNotify && resp.Event == response.NotificationEventID { + continue + } + require.Equal(t, response.TransactionEventID, resp.Event) + break + } + } + resp := getNotification(t, respMsgs) + require.Equal(t, response.BlockEventID, resp.Event) + } + + for _, id := range subIDs { + var b bool + + resp := callWSGetRaw(t, c, fmt.Sprintf(`{ + "jsonrpc": "2.0", + "method": "unsubscribe", + "params": ["%s"], + "id": 1 +}`, id), respMsgs) + require.Nil(t, resp.Error) + require.NotNil(t, resp.Result) + require.NoError(t, json.Unmarshal(resp.Result, &b)) + require.Equal(t, true, b) + } + finishedFlag.CAS(false, true) + c.Close() +} + +func TestMaxSubscriptions(t *testing.T) { + var subIDs = make([]string, 0) + chain, rpcSrv, c, respMsgs, finishedFlag := initCleanServerAndWSClient(t) + + defer chain.Close() + defer rpcSrv.Shutdown() + + for i := 0; i < maxFeeds+1; i++ { + var s string + resp := callWSGetRaw(t, c, `{"jsonrpc": "2.0", "method": "subscribe", "params": ["block_added"], "id": 1}`, respMsgs) + if i < maxFeeds { + require.Nil(t, resp.Error) + require.NotNil(t, resp.Result) + require.NoError(t, json.Unmarshal(resp.Result, &s)) + // Each ID must be unique. + for _, id := range subIDs { + require.NotEqual(t, id, s) + } + subIDs = append(subIDs, s) + } else { + require.NotNil(t, resp.Error) + require.Nil(t, resp.Result) + } + } + + finishedFlag.CAS(false, true) + c.Close() +} + +func TestBadSubUnsub(t *testing.T) { + var subCases = map[string]string{ + "no params": `{"jsonrpc": "2.0", "method": "subscribe", "params": [], "id": 1}`, + "bad (non-string) event": `{"jsonrpc": "2.0", "method": "subscribe", "params": [1], "id": 1}`, + "bad (wrong) event": `{"jsonrpc": "2.0", "method": "subscribe", "params": ["block_removed"], "id": 1}`, + } + var unsubCases = map[string]string{ + "no params": `{"jsonrpc": "2.0", "method": "unsubscribe", "params": [], "id": 1}`, + "bad id": `{"jsonrpc": "2.0", "method": "unsubscribe", "params": ["vasiliy"], "id": 1}`, + "not subscribed id": `{"jsonrpc": "2.0", "method": "unsubscribe", "params": ["7"], "id": 1}`, + } + chain, rpcSrv, c, respMsgs, finishedFlag := initCleanServerAndWSClient(t) + + defer chain.Close() + defer rpcSrv.Shutdown() + + testF := func(t *testing.T, cases map[string]string) func(t *testing.T) { + return func(t *testing.T) { + for n, s := range cases { + t.Run(n, func(t *testing.T) { + resp := callWSGetRaw(t, c, s, respMsgs) + require.NotNil(t, resp.Error) + require.Nil(t, resp.Result) + }) + } + } + } + t.Run("subscribe", testF(t, subCases)) + t.Run("unsubscribe", testF(t, unsubCases)) + + finishedFlag.CAS(false, true) + c.Close() +} + +func doSomeWSRequest(t *testing.T, ws *websocket.Conn) { + ws.SetWriteDeadline(time.Now().Add(time.Second)) + // It could be just about anything including invalid request, + // we only care about server handling being active. + require.NoError(t, ws.WriteMessage(websocket.TextMessage, []byte(`{"jsonrpc": "2.0", "method": "getversion", "params": [], "id": 1}`))) + ws.SetReadDeadline(time.Now().Add(time.Second)) + _, _, err := ws.ReadMessage() + require.NoError(t, err) +} + +func TestWSClientsLimit(t *testing.T) { + chain, rpcSrv, httpSrv := initClearServerWithInMemoryChain(t) + defer chain.Close() + defer rpcSrv.Shutdown() + + dialer := websocket.Dialer{HandshakeTimeout: time.Second} + url := "ws" + strings.TrimPrefix(httpSrv.URL, "http") + "/ws" + wss := make([]*websocket.Conn, maxSubscribers) + + for i := 0; i < len(wss)+1; i++ { + ws, _, err := dialer.Dial(url, nil) + if i < maxSubscribers { + require.NoError(t, err) + wss[i] = ws + // Check that it's completely ready. + doSomeWSRequest(t, ws) + } else { + require.Error(t, err) + } + } + // Check connections are still alive (it actually is necessary to add + // some use of wss to keep connections alive). + for i := 0; i < len(wss); i++ { + doSomeWSRequest(t, wss[i]) + } +} diff --git a/pkg/smartcontract/parameter.go b/pkg/smartcontract/parameter.go index 379a75593..5dad09926 100644 --- a/pkg/smartcontract/parameter.go +++ b/pkg/smartcontract/parameter.go @@ -85,6 +85,8 @@ func (p *Parameter) MarshalJSON() ([]byte, error) { case MapType: ppair := p.Value.([]ParameterPair) resultRawValue, resultErr = json.Marshal(ppair) + case InteropInterfaceType: + resultRawValue = []byte("null") default: resultErr = errors.Errorf("Marshaller for type %s not implemented", p.Type) } @@ -166,6 +168,9 @@ func (p *Parameter) UnmarshalJSON(data []byte) (err error) { return } p.Value = h + case InteropInterfaceType: + // stub, ignore value, it can only be null + p.Value = nil default: return errors.Errorf("Unmarshaller for type %s not implemented", p.Type) } diff --git a/pkg/smartcontract/parameter_test.go b/pkg/smartcontract/parameter_test.go index 61d2fd6a6..72ce299a4 100644 --- a/pkg/smartcontract/parameter_test.go +++ b/pkg/smartcontract/parameter_test.go @@ -122,6 +122,13 @@ var marshalJSONTestCases = []struct { }, result: `{"type":"Hash256","value":"0xf037308fa0ab18155bccfc08485468c112409ea5064595699e98c545f245f32d"}`, }, + { + input: Parameter{ + Type: InteropInterfaceType, + Value: nil, + }, + result: `{"type":"InteropInterface","value":null}`, + }, } var marshalJSONErrorCases = []Parameter{ @@ -129,10 +136,6 @@ var marshalJSONErrorCases = []Parameter{ Type: UnknownType, Value: nil, }, - { - Type: InteropInterfaceType, - Value: nil, - }, { Type: IntegerType, Value: math.Inf(1), @@ -252,6 +255,27 @@ var unmarshalJSONTestCases = []struct { }, input: `{"type":"PublicKey","value":"03b3bf1502fbdc05449b506aaf04579724024b06542e49262bfaa3f70e200040a9"}`, }, + { + input: `{"type":"InteropInterface","value":null}`, + result: Parameter{ + Type: InteropInterfaceType, + Value: nil, + }, + }, + { + input: `{"type":"InteropInterface","value":""}`, + result: Parameter{ + Type: InteropInterfaceType, + Value: nil, + }, + }, + { + input: `{"type":"InteropInterface","value":"Hundertwasser"}`, + result: Parameter{ + Type: InteropInterfaceType, + Value: nil, + }, + }, } var unmarshalJSONErrorCases = []string{ @@ -272,8 +296,6 @@ var unmarshalJSONErrorCases = []string{ `{"type": "Map","value": ["key": {}]}`, // incorrect Map value `{"type": "Map","value": ["key": {"type":"String", "value":"qwer"}, "value": {"type":"Boolean"}]}`, // incorrect Map Value value `{"type": "Map","value": ["key": {"type":"String"}, "value": {"type":"Boolean", "value":true}]}`, // incorrect Map Key value - - `{"type": "InteropInterface","value": ""}`, // ununmarshable type } func TestParam_UnmarshalJSON(t *testing.T) { diff --git a/pkg/wallet/account.go b/pkg/wallet/account.go index 273584259..c2892db7f 100644 --- a/pkg/wallet/account.go +++ b/pkg/wallet/account.go @@ -132,6 +132,9 @@ func (a *Account) SignTx(t *transaction.Transaction) error { return errors.New("account is not unlocked") } data := t.GetSignedPart() + if data == nil { + return errors.New("failed to get transaction's signed part") + } sign := a.privateKey.Sign(data) t.Scripts = append(t.Scripts, transaction.Witness{