From 64d24d8ddd39ac5c95cdc8d46631f8074e1a606e Mon Sep 17 00:00:00 2001 From: Evgenii Stratonikov Date: Mon, 16 Dec 2019 11:57:49 +0300 Subject: [PATCH] consensus: verify payloads correctly --- pkg/consensus/consensus.go | 19 +++++++++++++---- pkg/consensus/consensus_test.go | 29 ++++++++++++++++++++++++++ pkg/consensus/payload.go | 26 ++++++++++++++++------- pkg/consensus/payload_test.go | 4 ++-- pkg/consensus/recovery_message_test.go | 15 +++++++++---- 5 files changed, 75 insertions(+), 18 deletions(-) diff --git a/pkg/consensus/consensus.go b/pkg/consensus/consensus.go index 996c30e78..39c7de9e2 100644 --- a/pkg/consensus/consensus.go +++ b/pkg/consensus/consensus.go @@ -9,6 +9,7 @@ import ( "github.com/CityOfZion/neo-go/config" "github.com/CityOfZion/neo-go/pkg/core" "github.com/CityOfZion/neo-go/pkg/core/transaction" + "github.com/CityOfZion/neo-go/pkg/crypto/hash" "github.com/CityOfZion/neo-go/pkg/crypto/keys" "github.com/CityOfZion/neo-go/pkg/smartcontract" "github.com/CityOfZion/neo-go/pkg/util" @@ -163,6 +164,19 @@ func (s *service) eventLoop() { } } +func (s *service) validatePayload(p *Payload) bool { + validators := s.getValidators() + if int(p.validatorIndex) >= len(validators) { + return false + } + + pub := validators[p.validatorIndex] + vs := pub.(*publicKey).GetVerificationScript() + h := hash.Hash160(vs) + + return p.Verify(h) +} + func getKeyPair(cfg *config.WalletConfig) (crypto.PrivateKey, crypto.PublicKey) { acc, err := wallet.DecryptAccount(cfg.Path, cfg.Password) if err != nil { @@ -179,10 +193,7 @@ func getKeyPair(cfg *config.WalletConfig) (crypto.PrivateKey, crypto.PublicKey) // OnPayload handles Payload receive. func (s *service) OnPayload(cp *Payload) { - if !cp.Verify() { - s.log.Debug("can't verify payload from #%d", cp.validatorIndex) - return - } else if s.cache.Has(cp.Hash()) { + if !s.validatePayload(cp) || s.cache.Has(cp.Hash()) { return } diff --git a/pkg/consensus/consensus_test.go b/pkg/consensus/consensus_test.go index 12af0e89d..d934906fc 100644 --- a/pkg/consensus/consensus_test.go +++ b/pkg/consensus/consensus_test.go @@ -27,6 +27,35 @@ func TestNewService(t *testing.T) { require.Equal(t, tx, txx[1]) } +func TestService_ValidatePayload(t *testing.T) { + srv := newTestService(t) + priv, _ := getTestValidator(1) + p := new(Payload) + + p.SetPayload(&prepareRequest{}) + + t.Run("invalid validator index", func(t *testing.T) { + p.SetValidatorIndex(11) + require.NoError(t, p.Sign(priv)) + + var ok bool + require.NotPanics(t, func() { ok = srv.validatePayload(p) }) + require.False(t, ok) + }) + + t.Run("wrong validator index", func(t *testing.T) { + p.SetValidatorIndex(2) + require.NoError(t, p.Sign(priv)) + require.False(t, srv.validatePayload(p)) + }) + + t.Run("normal case", func(t *testing.T) { + p.SetValidatorIndex(1) + require.NoError(t, p.Sign(priv)) + require.True(t, srv.validatePayload(p)) + }) +} + func TestService_OnPayload(t *testing.T) { srv := newTestService(t) diff --git a/pkg/consensus/payload.go b/pkg/consensus/payload.go index 0546f5422..9c9e39099 100644 --- a/pkg/consensus/payload.go +++ b/pkg/consensus/payload.go @@ -4,6 +4,7 @@ import ( "crypto/sha256" "fmt" + "github.com/CityOfZion/neo-go/pkg/core" "github.com/CityOfZion/neo-go/pkg/core/transaction" "github.com/CityOfZion/neo-go/pkg/crypto/hash" "github.com/CityOfZion/neo-go/pkg/io" @@ -195,18 +196,27 @@ func (p *Payload) Sign(key *privateKey) error { } // Verify verifies payload using provided Witness. -func (p *Payload) Verify() bool { - h := sha256.Sum256(p.MarshalUnsigned()) - v := vm.New() - v.SetCheckedHash(h[:]) - v.Load(append(p.Witness.InvocationScript, p.Witness.VerificationScript...)) - if err := v.Run(); err != nil || v.Estack().Len() == 0 { +func (p *Payload) Verify(scriptHash util.Uint160) bool { + verification, err := core.ScriptFromWitness(scriptHash, &p.Witness) + if err != nil { return false } - result, err := v.Estack().Top().TryBool() + v := vm.New() + h := sha256.Sum256(p.MarshalUnsigned()) - return err == nil && result + v.SetCheckedHash(h[:]) + v.LoadScript(verification) + v.LoadScript(p.Witness.InvocationScript) + + err = v.Run() + if err != nil || v.HasFailed() || v.Estack().Len() != 1 { + return false + } + + res, err := v.Estack().Pop().TryBool() + + return err == nil && res } // DecodeBinaryUnsigned reads payload from w excluding signature. diff --git a/pkg/consensus/payload_test.go b/pkg/consensus/payload_test.go index f61391e73..31ce8173f 100644 --- a/pkg/consensus/payload_test.go +++ b/pkg/consensus/payload_test.go @@ -300,9 +300,9 @@ func TestPayload_Sign(t *testing.T) { priv := &privateKey{key} p := randomPayload(t, prepareRequestType) - require.False(t, p.Verify()) + require.False(t, p.Verify(util.Uint160{})) require.NoError(t, p.Sign(priv)) - require.True(t, p.Verify()) + require.True(t, p.Verify(p.Witness.ScriptHash())) } func TestMessageType_String(t *testing.T) { diff --git a/pkg/consensus/recovery_message_test.go b/pkg/consensus/recovery_message_test.go index 061ac1731..ab798eca3 100644 --- a/pkg/consensus/recovery_message_test.go +++ b/pkg/consensus/recovery_message_test.go @@ -55,7 +55,8 @@ func TestRecoveryMessage_Setters(t *testing.T) { ps := r.GetPrepareResponses(p, pubs) require.Len(t, ps, 1) require.Equal(t, p2, ps[0]) - require.True(t, ps[0].(*Payload).Verify()) + ps0 := ps[0].(*Payload) + require.True(t, ps0.Verify(ps0.Witness.ScriptHash())) }) t.Run("prepare request is added", func(t *testing.T) { @@ -66,7 +67,9 @@ func TestRecoveryMessage_Setters(t *testing.T) { pr = r.GetPrepareRequest(p, pubs, p1.ValidatorIndex()) require.NotNil(t, pr) require.Equal(t, p1, pr) - require.True(t, pr.(*Payload).Verify()) + + pl := pr.(*Payload) + require.True(t, pl.Verify(pl.Witness.ScriptHash())) }) t.Run("change view is added", func(t *testing.T) { @@ -84,7 +87,9 @@ func TestRecoveryMessage_Setters(t *testing.T) { ps := r.GetChangeViews(p, pubs) require.Len(t, ps, 1) require.Equal(t, p3, ps[0]) - require.True(t, ps[0].(*Payload).Verify()) + + ps0 := ps[0].(*Payload) + require.True(t, ps0.Verify(ps0.Witness.ScriptHash())) }) t.Run("commit is added", func(t *testing.T) { @@ -99,7 +104,9 @@ func TestRecoveryMessage_Setters(t *testing.T) { ps := r.GetCommits(p, pubs) require.Len(t, ps, 1) require.Equal(t, p4, ps[0]) - require.True(t, ps[0].(*Payload).Verify()) + + ps0 := ps[0].(*Payload) + require.True(t, ps0.Verify(ps0.Witness.ScriptHash())) }) }