diff --git a/pkg/consensus/cache_test.go b/pkg/consensus/cache_test.go index 700b706eb..cd4ebe5a3 100644 --- a/pkg/consensus/cache_test.go +++ b/pkg/consensus/cache_test.go @@ -52,6 +52,7 @@ func getDifferentPayloads(t *testing.T, n int) (payloads []Payload) { var sign [signatureSize]byte random.Fill(sign[:]) + payloads[i].message = &message{} payloads[i].SetValidatorIndex(uint16(i)) payloads[i].SetType(payload.MessageType(commitType)) payloads[i].payload = &commit{ diff --git a/pkg/consensus/consensus.go b/pkg/consensus/consensus.go index bb09c9837..116dcbedc 100644 --- a/pkg/consensus/consensus.go +++ b/pkg/consensus/consensus.go @@ -140,7 +140,7 @@ func NewService(cfg Config) (Service, error) { dbft.WithGetValidators(srv.getValidators), dbft.WithGetConsensusAddress(srv.getConsensusAddress), - dbft.WithNewConsensusPayload(func() payload.ConsensusPayload { return new(Payload) }), + dbft.WithNewConsensusPayload(func() payload.ConsensusPayload { p := new(Payload); p.message = &message{}; return p }), dbft.WithNewPrepareRequest(func() payload.PrepareRequest { return new(prepareRequest) }), dbft.WithNewPrepareResponse(func() payload.PrepareResponse { return new(prepareResponse) }), dbft.WithNewChangeView(func() payload.ChangeView { return new(changeView) }), @@ -245,7 +245,7 @@ func (s *service) getKeyPair(pubs []crypto.PublicKey) (int, crypto.PrivateKey, c // OnPayload handles Payload receive. func (s *service) OnPayload(cp *Payload) { - log := s.log.With(zap.Stringer("hash", cp.Hash()), zap.Stringer("type", cp.Type())) + log := s.log.With(zap.Stringer("hash", cp.Hash())) if !s.validatePayload(cp) { log.Debug("can't validate payload") return @@ -262,6 +262,14 @@ func (s *service) OnPayload(cp *Payload) { return } + // decode payload data into message + if cp.message == nil { + if err := cp.decodeData(); err != nil { + log.Debug("can't decode payload data") + return + } + } + // we use switch here because other payloads could be possibly added in future switch cp.Type() { case payload.PrepareRequestType: diff --git a/pkg/consensus/consensus_test.go b/pkg/consensus/consensus_test.go index 640c31399..36813a37d 100644 --- a/pkg/consensus/consensus_test.go +++ b/pkg/consensus/consensus_test.go @@ -41,6 +41,7 @@ func TestService_GetVerified(t *testing.T) { hashes := []util.Uint256{txs[0].Hash(), txs[1].Hash(), txs[2].Hash()} p := new(Payload) + p.message = &message{} p.SetType(payload.PrepareRequestType) p.SetPayload(&prepareRequest{transactionHashes: hashes, minerTx: *transaction.NewMinerTXWithNonce(999)}) p.SetValidatorIndex(1) @@ -76,6 +77,7 @@ func TestService_ValidatePayload(t *testing.T) { srv := newTestService(t) priv, _ := getTestValidator(1) p := new(Payload) + p.message = &message{} p.SetPayload(&prepareRequest{}) @@ -138,6 +140,7 @@ func TestService_OnPayload(t *testing.T) { priv, _ := getTestValidator(1) p := new(Payload) + p.message = &message{} p.SetValidatorIndex(1) p.SetPayload(&prepareRequest{}) diff --git a/pkg/consensus/payload.go b/pkg/consensus/payload.go index 0d9bc8b66..925125fc2 100644 --- a/pkg/consensus/payload.go +++ b/pkg/consensus/payload.go @@ -27,8 +27,9 @@ type ( // Payload is a type for consensus-related messages. Payload struct { - message + *message + data []byte version uint32 validatorIndex uint16 prevHash util.Uint256 @@ -168,9 +169,12 @@ func (p *Payload) EncodeBinaryUnsigned(w *io.BinWriter) { w.WriteU16LE(p.validatorIndex) w.WriteU32LE(p.timestamp) - ww := io.NewBufBinWriter() - p.message.EncodeBinary(ww.BinWriter) - w.WriteVarBytes(ww.Bytes()) + if p.message != nil { + ww := io.NewBufBinWriter() + p.message.EncodeBinary(ww.BinWriter) + p.data = ww.Bytes() + } + w.WriteVarBytes(p.data) } // EncodeBinary implements io.Serializable interface. @@ -227,14 +231,10 @@ func (p *Payload) DecodeBinaryUnsigned(r *io.BinReader) { p.validatorIndex = r.ReadU16LE() p.timestamp = r.ReadU32LE() - data := r.ReadVarBytes() + p.data = r.ReadVarBytes() if r.Err != nil { return } - - rr := io.NewBinReaderFromBuf(data) - p.message.DecodeBinary(rr) - r.Err = rr.Err } // Hash implements payload.ConsensusPayload interface. @@ -318,3 +318,15 @@ func (t messageType) String() string { return fmt.Sprintf("UNKNOWN(0x%02x)", byte(t)) } } + +// decode data of payload into it's message +func (p *Payload) decodeData() error { + m := new(message) + br := io.NewBinReaderFromBuf(p.data) + m.DecodeBinary(br) + if br.Err != nil { + return errors.Wrap(br.Err, "cannot decode data into message") + } + p.message = m + return nil +} diff --git a/pkg/consensus/payload_test.go b/pkg/consensus/payload_test.go index 949f5a5c5..049038a33 100644 --- a/pkg/consensus/payload_test.go +++ b/pkg/consensus/payload_test.go @@ -28,6 +28,7 @@ var messageTypes = []messageType{ func TestConsensusPayload_Setters(t *testing.T) { var p Payload + p.message = &message{} p.SetVersion(1) assert.EqualValues(t, 1, p.Version()) @@ -86,11 +87,20 @@ func TestConsensusPayload_Hash(t *testing.T) { func TestConsensusPayload_Serializable(t *testing.T) { for _, mt := range messageTypes { p := randomPayload(t, mt) - testserdes.EncodeDecodeBinary(t, p, new(Payload)) + actual := new(Payload) + data, err := testserdes.EncodeBinary(p) + require.NoError(t, err) + require.NoError(t, testserdes.DecodeBinary(data, actual)) + // message is nil after decoding as we didn't yet call decodeData + require.Nil(t, actual.message) + // message should now be decoded from actual.data byte array + assert.NoError(t, actual.decodeData()) + require.Equal(t, p, actual) - data := p.MarshalUnsigned() + data = p.MarshalUnsigned() pu := new(Payload) require.NoError(t, pu.UnmarshalUnsigned(data)) + assert.NoError(t, pu.decodeData()) p.Witness = transaction.Witness{} require.Equal(t, p, pu) @@ -115,7 +125,7 @@ func TestConsensusPayload_DecodeBinaryInvalid(t *testing.T) { buf := make([]byte, 46+1+34+1+2) expected := &Payload{ - message: message{ + message: &message{ Type: prepareResponseType, payload: &prepareResponse{}, }, @@ -124,6 +134,8 @@ func TestConsensusPayload_DecodeBinaryInvalid(t *testing.T) { VerificationScript: []byte{}, }, } + // fill `data` for next check + _ = expected.Hash() // valid payload buf[delimeterIndex] = 1 @@ -131,11 +143,15 @@ func TestConsensusPayload_DecodeBinaryInvalid(t *testing.T) { buf[typeIndex] = byte(prepareResponseType) p := new(Payload) require.NoError(t, testserdes.DecodeBinary(buf, p)) + // decode `data` into `message` + assert.NoError(t, p.decodeData()) require.Equal(t, expected, p) // invalid type buf[typeIndex] = 0xFF - require.Error(t, testserdes.DecodeBinary(buf, new(Payload))) + actual := new(Payload) + require.NoError(t, testserdes.DecodeBinary(buf, actual)) + require.Error(t, actual.decodeData()) // invalid format buf[delimeterIndex] = 0 @@ -176,7 +192,7 @@ func TestRecoveryMessage_Serializable(t *testing.T) { func randomPayload(t *testing.T, mt messageType) *Payload { p := &Payload{ - message: message{ + message: &message{ Type: mt, ViewNumber: byte(rand.Uint32()), payload: randomMessage(t, mt), diff --git a/pkg/consensus/recovery_message.go b/pkg/consensus/recovery_message.go index 8bc30e269..030db04ab 100644 --- a/pkg/consensus/recovery_message.go +++ b/pkg/consensus/recovery_message.go @@ -285,7 +285,7 @@ func getVerificationScript(i uint16, validators []crypto.PublicKey) []byte { func fromPayload(t messageType, recovery *Payload, p io.Serializable) *Payload { return &Payload{ - message: message{ + message: &message{ Type: t, ViewNumber: recovery.message.ViewNumber, payload: p, diff --git a/pkg/consensus/recovery_message_test.go b/pkg/consensus/recovery_message_test.go index d6a0fdc94..d5f492d32 100644 --- a/pkg/consensus/recovery_message_test.go +++ b/pkg/consensus/recovery_message_test.go @@ -25,6 +25,7 @@ func TestRecoveryMessage_Setters(t *testing.T) { r := &recoveryMessage{} p := new(Payload) + p.message = &message{} p.SetType(payload.RecoveryMessageType) p.SetPayload(r) // sign payload to have verification script @@ -38,6 +39,7 @@ func TestRecoveryMessage_Setters(t *testing.T) { nextConsensus: util.Uint160{1, 2}, } p1 := new(Payload) + p1.message = &message{} p1.SetType(payload.PrepareRequestType) p1.SetPayload(req) p1.SetValidatorIndex(0) @@ -45,6 +47,7 @@ func TestRecoveryMessage_Setters(t *testing.T) { t.Run("prepare response is added", func(t *testing.T) { p2 := new(Payload) + p2.message = &message{} p2.SetType(payload.PrepareResponseType) p2.SetPayload(&prepareResponse{ preparationHash: p1.Hash(), @@ -70,6 +73,7 @@ func TestRecoveryMessage_Setters(t *testing.T) { r.AddPayload(p1) pr = r.GetPrepareRequest(p, pubs, p1.ValidatorIndex()) require.NotNil(t, pr) + require.Equal(t, p1.Hash(), pr.Hash()) require.Equal(t, p1, pr) pl := pr.(*Payload) @@ -78,6 +82,7 @@ func TestRecoveryMessage_Setters(t *testing.T) { t.Run("change view is added", func(t *testing.T) { p3 := new(Payload) + p3.message = &message{} p3.SetType(payload.ChangeViewType) p3.SetPayload(&changeView{ newViewNumber: 1, @@ -98,6 +103,7 @@ func TestRecoveryMessage_Setters(t *testing.T) { t.Run("commit is added", func(t *testing.T) { p4 := new(Payload) + p4.message = &message{} p4.SetType(payload.CommitType) p4.SetPayload(randomMessage(t, commitType)) p4.SetValidatorIndex(4)