diff --git a/pkg/consensus/consensus.go b/pkg/consensus/consensus.go index c404561d0..9aeeeb516 100644 --- a/pkg/consensus/consensus.go +++ b/pkg/consensus/consensus.go @@ -179,7 +179,10 @@ func getKeyPair(cfg *config.WalletConfig) (crypto.PrivateKey, crypto.PublicKey) // OnPayload handles Payload receive. func (s *service) OnPayload(cp *Payload) { - if s.cache.Has(cp.Hash()) { + if !cp.Verify() { + s.log.Debug("can't verify payload from #%d", cp.validatorIndex) + return + } else if s.cache.Has(cp.Hash()) { return } @@ -224,6 +227,10 @@ func (s *service) broadcast(p payload.ConsensusPayload) { pr.minerTx = *s.txx.Get(pr.transactionHashes[0]).(*transaction.Transaction) } + if err := p.(*Payload).Sign(s.dbft.Priv.(*privateKey)); err != nil { + s.log.Warnf("can't sign consensus payload: %v", err) + } + s.cache.Add(p) s.Config.Broadcast(p.(*Payload)) } diff --git a/pkg/consensus/consensus_test.go b/pkg/consensus/consensus_test.go index e62d53185..12af0e89d 100644 --- a/pkg/consensus/consensus_test.go +++ b/pkg/consensus/consensus_test.go @@ -27,6 +27,45 @@ func TestNewService(t *testing.T) { require.Equal(t, tx, txx[1]) } +func TestService_OnPayload(t *testing.T) { + srv := newTestService(t) + + priv, _ := getTestValidator(1) + p := new(Payload) + p.SetValidatorIndex(1) + p.SetPayload(&prepareRequest{}) + + // payload is not signed + srv.OnPayload(p) + shouldNotReceive(t, srv.messages) + require.Nil(t, srv.GetPayload(p.Hash())) + + require.NoError(t, p.Sign(priv)) + srv.OnPayload(p) + shouldReceive(t, srv.messages) + require.Equal(t, p, srv.GetPayload(p.Hash())) + + // payload has already been received + srv.OnPayload(p) + shouldNotReceive(t, srv.messages) +} + +func shouldReceive(t *testing.T, ch chan Payload) { + select { + case <-ch: + default: + require.Fail(t, "missing expected message") + } +} + +func shouldNotReceive(t *testing.T, ch chan Payload) { + select { + case <-ch: + require.Fail(t, "unexpected message receive") + default: + } +} + func newTestService(t *testing.T) *service { srv, err := NewService(Config{ Broadcast: func(*Payload) {}, @@ -42,6 +81,38 @@ func newTestService(t *testing.T) *service { return srv.(*service) } +func getTestValidator(i int) (*privateKey, *publicKey) { + var wallet *config.WalletConfig + switch i { + case 0: + wallet = &config.WalletConfig{ + Path: "6PYLmjBYJ4wQTCEfqvnznGJwZeW9pfUcV5m5oreHxqryUgqKpTRAFt9L8Y", + Password: "one", + } + case 1: + wallet = &config.WalletConfig{ + Path: "6PYXHjPaNvW8YknSXaKsTWjf9FRxo1s4naV2jdmSQEgzaqKGX368rndN3L", + Password: "two", + } + case 2: + wallet = &config.WalletConfig{ + Path: "6PYX86vYiHfUbpD95hfN1xgnvcSxy5skxfWYKu3ztjecxk6ikYs2kcWbeh", + Password: "three", + } + case 3: + wallet = &config.WalletConfig{ + Path: "6PYRXVwHSqFSukL3CuXxdQ75VmsKpjeLgQLEjt83FrtHf1gCVphHzdD4nc", + Password: "four", + } + default: + return nil, nil + } + + priv, pub := getKeyPair(wallet) + + return priv.(*privateKey), pub.(*publicKey) +} + func newTestChain(t *testing.T) *core.Blockchain { unitTestNetCfg, err := config.Load("../../config", config.ModeUnitTestNet) require.NoError(t, err)