diff --git a/pkg/consensus/consensus.go b/pkg/consensus/consensus.go index 0cff84edc..373b637a9 100644 --- a/pkg/consensus/consensus.go +++ b/pkg/consensus/consensus.go @@ -20,6 +20,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/vm/emit" "github.com/nspcc-dev/neo-go/pkg/wallet" + "go.uber.org/atomic" "go.uber.org/zap" ) @@ -63,6 +64,9 @@ type service struct { lastProposal []util.Uint256 wallet *wallet.Wallet network netmode.Magic + // started is a flag set with Start method that runs an event handling + // goroutine. + started *atomic.Bool } // Config is a configuration for consensus services. @@ -104,6 +108,7 @@ func NewService(cfg Config) (Service, error) { transactions: make(chan *transaction.Transaction, 100), blockEvents: make(chan *coreb.Block, 1), network: cfg.Chain.GetConfig().Magic, + started: atomic.NewBool(false), } if cfg.Wallet == nil { @@ -143,6 +148,7 @@ func NewService(cfg Config) (Service, error) { dbft.WithNewCommit(func() payload.Commit { return new(commit) }), dbft.WithNewRecoveryRequest(func() payload.RecoveryRequest { return new(recoveryRequest) }), dbft.WithNewRecoveryMessage(func() payload.RecoveryMessage { return new(recoveryMessage) }), + dbft.WithVerifyPrepareRequest(srv.verifyRequest), ) if srv.dbft == nil { @@ -169,9 +175,11 @@ func (s *service) newPayload() payload.ConsensusPayload { } func (s *service) Start() { - s.dbft.Start() - s.Chain.SubscribeForBlocks(s.blockEvents) - go s.eventLoop() + if s.started.CAS(false, true) { + s.dbft.Start() + s.Chain.SubscribeForBlocks(s.blockEvents) + go s.eventLoop() + } } func (s *service) eventLoop() { @@ -267,8 +275,8 @@ func (s *service) OnPayload(cp *Payload) { s.Config.Broadcast(cp) s.cache.Add(cp) - if s.dbft == nil { - log.Debug("dbft is nil") + if s.dbft == nil || !s.started.Load() { + log.Debug("dbft is inactive or not started yet") return } @@ -280,13 +288,6 @@ func (s *service) OnPayload(cp *Payload) { } } - // we use switch here because other payloads could be possibly added in future - switch cp.Type() { - case payload.PrepareRequestType: - req := cp.GetPrepareRequest().(*prepareRequest) - s.lastProposal = req.transactionHashes - } - s.messages <- *cp } @@ -347,6 +348,14 @@ func (s *service) verifyBlock(b block.Block) bool { return true } +func (s *service) verifyRequest(p payload.ConsensusPayload) error { + req := p.GetPrepareRequest().(*prepareRequest) + // Save lastProposal for getVerified(). + s.lastProposal = req.transactionHashes + + return nil +} + func (s *service) processBlock(b block.Block) { bb := &b.(*neoBlock).Block bb.Script = *(s.getBlockWitness(bb)) diff --git a/pkg/consensus/consensus_test.go b/pkg/consensus/consensus_test.go index 196bdd5cf..fb4c82278 100644 --- a/pkg/consensus/consensus_test.go +++ b/pkg/consensus/consensus_test.go @@ -2,6 +2,7 @@ package consensus import ( "testing" + "time" "github.com/nspcc-dev/dbft/block" "github.com/nspcc-dev/dbft/payload" @@ -39,6 +40,7 @@ func TestNewService(t *testing.T) { func TestService_GetVerified(t *testing.T) { srv := newTestService(t) + srv.dbft.Start() var txs []*transaction.Transaction for i := 0; i < 4; i++ { tx := transaction.New(netmode.UnitTestNet, []byte{byte(opcode.PUSH1)}, 0) @@ -52,22 +54,30 @@ func TestService_GetVerified(t *testing.T) { hashes := []util.Uint256{txs[0].Hash(), txs[1].Hash(), txs[2].Hash()} - p := new(Payload) - p.message = &message{} - p.SetType(payload.PrepareRequestType) - tx := transaction.New(netmode.UnitTestNet, []byte{byte(opcode.PUSH1)}, 0) - tx.Nonce = 999 - p.SetPayload(&prepareRequest{transactionHashes: hashes}) - p.SetValidatorIndex(1) + // Everyone sends a message. + for i := 0; i < 4; i++ { + p := new(Payload) + p.message = &message{} + // One PrepareRequest and three ChangeViews. + if i == 1 { + p.SetType(payload.PrepareRequestType) + p.SetPayload(&prepareRequest{transactionHashes: hashes}) + } else { + p.SetType(payload.ChangeViewType) + p.SetPayload(&changeView{newViewNumber: 1, timestamp: uint32(time.Now().Unix())}) + } + p.SetHeight(1) + p.SetValidatorIndex(uint16(i)) - priv, _ := getTestValidator(1) - require.NoError(t, p.Sign(priv)) + priv, _ := getTestValidator(i) + require.NoError(t, p.Sign(priv)) - srv.OnPayload(p) + // Skip srv.OnPayload, because the service is not really started. + srv.dbft.OnReceive(p) + } + require.Equal(t, uint8(1), srv.dbft.ViewNumber) require.Equal(t, hashes, srv.lastProposal) - srv.dbft.ViewNumber = 1 - t.Run("new transactions will be proposed in case of failure", func(t *testing.T) { txx := srv.getVerifiedTx() require.Equal(t, 1, len(txx), "there is only 1 tx in mempool") @@ -157,6 +167,10 @@ func TestService_getTx(t *testing.T) { func TestService_OnPayload(t *testing.T) { srv := newTestService(t) + // This test directly reads things from srv.messages that normally + // is read by internal goroutine started with Start(). So let's + // pretend we really did start already. + srv.started.Store(true) priv, _ := getTestValidator(1) p := new(Payload)