diff --git a/pkg/consensus/consensus.go b/pkg/consensus/consensus.go index fa135d362..4a17d52fc 100644 --- a/pkg/consensus/consensus.go +++ b/pkg/consensus/consensus.go @@ -23,6 +23,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/vm/opcode" "github.com/nspcc-dev/neo-go/pkg/wallet" + "go.uber.org/atomic" "go.uber.org/zap" ) @@ -65,6 +66,9 @@ type service struct { blockEvents chan *coreb.Block lastProposal []util.Uint256 wallet *wallet.Wallet + // started is a flag set with Start method that runs an event handling + // goroutine. + started *atomic.Bool } // Config is a configuration for consensus services. @@ -105,6 +109,7 @@ func NewService(cfg Config) (Service, error) { transactions: make(chan *transaction.Transaction, 100), blockEvents: make(chan *coreb.Block, 1), + started: atomic.NewBool(false), } if cfg.Wallet == nil { @@ -160,9 +165,11 @@ var ( ) func (s *service) Start() { - s.dbft.Start() - s.Chain.SubscribeForBlocks(s.blockEvents) - go s.eventLoop() + if s.started.CAS(false, true) { + s.dbft.Start() + s.Chain.SubscribeForBlocks(s.blockEvents) + go s.eventLoop() + } } func (s *service) eventLoop() { @@ -305,8 +312,8 @@ func (s *service) OnPayload(cp *Payload) { s.Config.Broadcast(cp) s.cache.Add(cp) - if s.dbft == nil { - log.Debug("dbft is nil") + if s.dbft == nil || !s.started.Load() { + log.Debug("dbft is inactive or not started yet") return } @@ -318,14 +325,6 @@ func (s *service) OnPayload(cp *Payload) { } } - // we use switch here because other payloads could be possibly added in future - switch cp.Type() { - case payload.PrepareRequestType: - req := cp.GetPrepareRequest().(*prepareRequest) - s.txx.Add(&req.minerTx) - s.lastProposal = req.transactionHashes - } - s.messages <- *cp } @@ -391,17 +390,20 @@ func (s *service) verifyBlock(b block.Block) bool { } func (s *service) verifyRequest(p payload.ConsensusPayload) error { - if !s.stateRootEnabled() { - return nil - } - r, err := s.Chain.GetStateRoot(s.dbft.BlockIndex - 1) - if err != nil { - return fmt.Errorf("can't get local state root: %v", err) - } - rb := &p.GetPrepareRequest().(*prepareRequest).proposalStateRoot - if !r.Equals(rb) { - return errors.New("state root mismatch") + req := p.GetPrepareRequest().(*prepareRequest) + if s.stateRootEnabled() { + r, err := s.Chain.GetStateRoot(s.dbft.BlockIndex - 1) + if err != nil { + return fmt.Errorf("can't get local state root: %v", err) + } + if !r.Equals(&req.proposalStateRoot) { + return errors.New("state root mismatch") + } } + // Save lastProposal for getVerified(). + s.txx.Add(&req.minerTx) + s.lastProposal = req.transactionHashes + return nil } diff --git a/pkg/consensus/consensus_test.go b/pkg/consensus/consensus_test.go index e15c6ccf7..4a74e9105 100644 --- a/pkg/consensus/consensus_test.go +++ b/pkg/consensus/consensus_test.go @@ -2,6 +2,7 @@ package consensus import ( "testing" + "time" "github.com/nspcc-dev/dbft/block" "github.com/nspcc-dev/dbft/payload" @@ -34,6 +35,8 @@ func TestNewService(t *testing.T) { func TestService_GetVerified(t *testing.T) { srv := newTestService(t) + srv.dbft.Start() + txs := []*transaction.Transaction{ newMinerTx(1), newMinerTx(2), @@ -44,20 +47,30 @@ func TestService_GetVerified(t *testing.T) { hashes := []util.Uint256{txs[0].Hash(), txs[1].Hash(), txs[2].Hash()} - p := new(Payload) - p.message = &message{} - p.SetType(payload.PrepareRequestType) - p.SetPayload(&prepareRequest{transactionHashes: hashes, minerTx: *newMinerTx(999)}) - p.SetValidatorIndex(1) + // Everyone sends a message. + for i := 0; i < 4; i++ { + p := new(Payload) + p.message = &message{} + // One PrepareRequest and three ChangeViews. + if i == 1 { + p.SetType(payload.PrepareRequestType) + p.SetPayload(&prepareRequest{transactionHashes: hashes, minerTx: *newMinerTx(999)}) + } else { + p.SetType(payload.ChangeViewType) + p.SetPayload(&changeView{newViewNumber: 1, timestamp: uint32(time.Now().Unix())}) + } + p.SetHeight(1) + p.SetValidatorIndex(uint16(i)) - priv, _ := getTestValidator(1) - require.NoError(t, p.Sign(priv)) + priv, _ := getTestValidator(i) + require.NoError(t, p.Sign(priv)) - srv.OnPayload(p) + // Skip srv.OnPayload, because the service is not really started. + srv.dbft.OnReceive(p) + } + require.Equal(t, uint8(1), srv.dbft.ViewNumber) require.Equal(t, hashes, srv.lastProposal) - srv.dbft.ViewNumber = 1 - t.Run("new transactions will be proposed in case of failure", func(t *testing.T) { txx := srv.getVerifiedTx() require.Equal(t, 2, len(txx), "there is only 1 tx in mempool") @@ -141,6 +154,10 @@ func TestService_getTx(t *testing.T) { func TestService_OnPayload(t *testing.T) { srv := newTestService(t) + // This test directly reads things from srv.messages that normally + // is read by internal goroutine started with Start(). So let's + // pretend we really did start already. + srv.started.Store(true) priv, _ := getTestValidator(1) p := new(Payload)