consensus: verify payloads correctly

This commit is contained in:
Evgenii Stratonikov 2019-12-16 11:57:49 +03:00
parent 714c466c2c
commit 64d24d8ddd
5 changed files with 75 additions and 18 deletions

View file

@ -9,6 +9,7 @@ import (
"github.com/CityOfZion/neo-go/config" "github.com/CityOfZion/neo-go/config"
"github.com/CityOfZion/neo-go/pkg/core" "github.com/CityOfZion/neo-go/pkg/core"
"github.com/CityOfZion/neo-go/pkg/core/transaction" "github.com/CityOfZion/neo-go/pkg/core/transaction"
"github.com/CityOfZion/neo-go/pkg/crypto/hash"
"github.com/CityOfZion/neo-go/pkg/crypto/keys" "github.com/CityOfZion/neo-go/pkg/crypto/keys"
"github.com/CityOfZion/neo-go/pkg/smartcontract" "github.com/CityOfZion/neo-go/pkg/smartcontract"
"github.com/CityOfZion/neo-go/pkg/util" "github.com/CityOfZion/neo-go/pkg/util"
@ -163,6 +164,19 @@ func (s *service) eventLoop() {
} }
} }
func (s *service) validatePayload(p *Payload) bool {
validators := s.getValidators()
if int(p.validatorIndex) >= len(validators) {
return false
}
pub := validators[p.validatorIndex]
vs := pub.(*publicKey).GetVerificationScript()
h := hash.Hash160(vs)
return p.Verify(h)
}
func getKeyPair(cfg *config.WalletConfig) (crypto.PrivateKey, crypto.PublicKey) { func getKeyPair(cfg *config.WalletConfig) (crypto.PrivateKey, crypto.PublicKey) {
acc, err := wallet.DecryptAccount(cfg.Path, cfg.Password) acc, err := wallet.DecryptAccount(cfg.Path, cfg.Password)
if err != nil { if err != nil {
@ -179,10 +193,7 @@ func getKeyPair(cfg *config.WalletConfig) (crypto.PrivateKey, crypto.PublicKey)
// OnPayload handles Payload receive. // OnPayload handles Payload receive.
func (s *service) OnPayload(cp *Payload) { func (s *service) OnPayload(cp *Payload) {
if !cp.Verify() { if !s.validatePayload(cp) || s.cache.Has(cp.Hash()) {
s.log.Debug("can't verify payload from #%d", cp.validatorIndex)
return
} else if s.cache.Has(cp.Hash()) {
return return
} }

View file

@ -27,6 +27,35 @@ func TestNewService(t *testing.T) {
require.Equal(t, tx, txx[1]) require.Equal(t, tx, txx[1])
} }
func TestService_ValidatePayload(t *testing.T) {
srv := newTestService(t)
priv, _ := getTestValidator(1)
p := new(Payload)
p.SetPayload(&prepareRequest{})
t.Run("invalid validator index", func(t *testing.T) {
p.SetValidatorIndex(11)
require.NoError(t, p.Sign(priv))
var ok bool
require.NotPanics(t, func() { ok = srv.validatePayload(p) })
require.False(t, ok)
})
t.Run("wrong validator index", func(t *testing.T) {
p.SetValidatorIndex(2)
require.NoError(t, p.Sign(priv))
require.False(t, srv.validatePayload(p))
})
t.Run("normal case", func(t *testing.T) {
p.SetValidatorIndex(1)
require.NoError(t, p.Sign(priv))
require.True(t, srv.validatePayload(p))
})
}
func TestService_OnPayload(t *testing.T) { func TestService_OnPayload(t *testing.T) {
srv := newTestService(t) srv := newTestService(t)

View file

@ -4,6 +4,7 @@ import (
"crypto/sha256" "crypto/sha256"
"fmt" "fmt"
"github.com/CityOfZion/neo-go/pkg/core"
"github.com/CityOfZion/neo-go/pkg/core/transaction" "github.com/CityOfZion/neo-go/pkg/core/transaction"
"github.com/CityOfZion/neo-go/pkg/crypto/hash" "github.com/CityOfZion/neo-go/pkg/crypto/hash"
"github.com/CityOfZion/neo-go/pkg/io" "github.com/CityOfZion/neo-go/pkg/io"
@ -195,18 +196,27 @@ func (p *Payload) Sign(key *privateKey) error {
} }
// Verify verifies payload using provided Witness. // Verify verifies payload using provided Witness.
func (p *Payload) Verify() bool { func (p *Payload) Verify(scriptHash util.Uint160) bool {
h := sha256.Sum256(p.MarshalUnsigned()) verification, err := core.ScriptFromWitness(scriptHash, &p.Witness)
v := vm.New() if err != nil {
v.SetCheckedHash(h[:])
v.Load(append(p.Witness.InvocationScript, p.Witness.VerificationScript...))
if err := v.Run(); err != nil || v.Estack().Len() == 0 {
return false return false
} }
result, err := v.Estack().Top().TryBool() v := vm.New()
h := sha256.Sum256(p.MarshalUnsigned())
return err == nil && result v.SetCheckedHash(h[:])
v.LoadScript(verification)
v.LoadScript(p.Witness.InvocationScript)
err = v.Run()
if err != nil || v.HasFailed() || v.Estack().Len() != 1 {
return false
}
res, err := v.Estack().Pop().TryBool()
return err == nil && res
} }
// DecodeBinaryUnsigned reads payload from w excluding signature. // DecodeBinaryUnsigned reads payload from w excluding signature.

View file

@ -300,9 +300,9 @@ func TestPayload_Sign(t *testing.T) {
priv := &privateKey{key} priv := &privateKey{key}
p := randomPayload(t, prepareRequestType) p := randomPayload(t, prepareRequestType)
require.False(t, p.Verify()) require.False(t, p.Verify(util.Uint160{}))
require.NoError(t, p.Sign(priv)) require.NoError(t, p.Sign(priv))
require.True(t, p.Verify()) require.True(t, p.Verify(p.Witness.ScriptHash()))
} }
func TestMessageType_String(t *testing.T) { func TestMessageType_String(t *testing.T) {

View file

@ -55,7 +55,8 @@ func TestRecoveryMessage_Setters(t *testing.T) {
ps := r.GetPrepareResponses(p, pubs) ps := r.GetPrepareResponses(p, pubs)
require.Len(t, ps, 1) require.Len(t, ps, 1)
require.Equal(t, p2, ps[0]) require.Equal(t, p2, ps[0])
require.True(t, ps[0].(*Payload).Verify()) ps0 := ps[0].(*Payload)
require.True(t, ps0.Verify(ps0.Witness.ScriptHash()))
}) })
t.Run("prepare request is added", func(t *testing.T) { t.Run("prepare request is added", func(t *testing.T) {
@ -66,7 +67,9 @@ func TestRecoveryMessage_Setters(t *testing.T) {
pr = r.GetPrepareRequest(p, pubs, p1.ValidatorIndex()) pr = r.GetPrepareRequest(p, pubs, p1.ValidatorIndex())
require.NotNil(t, pr) require.NotNil(t, pr)
require.Equal(t, p1, pr) require.Equal(t, p1, pr)
require.True(t, pr.(*Payload).Verify())
pl := pr.(*Payload)
require.True(t, pl.Verify(pl.Witness.ScriptHash()))
}) })
t.Run("change view is added", func(t *testing.T) { t.Run("change view is added", func(t *testing.T) {
@ -84,7 +87,9 @@ func TestRecoveryMessage_Setters(t *testing.T) {
ps := r.GetChangeViews(p, pubs) ps := r.GetChangeViews(p, pubs)
require.Len(t, ps, 1) require.Len(t, ps, 1)
require.Equal(t, p3, ps[0]) require.Equal(t, p3, ps[0])
require.True(t, ps[0].(*Payload).Verify())
ps0 := ps[0].(*Payload)
require.True(t, ps0.Verify(ps0.Witness.ScriptHash()))
}) })
t.Run("commit is added", func(t *testing.T) { t.Run("commit is added", func(t *testing.T) {
@ -99,7 +104,9 @@ func TestRecoveryMessage_Setters(t *testing.T) {
ps := r.GetCommits(p, pubs) ps := r.GetCommits(p, pubs)
require.Len(t, ps, 1) require.Len(t, ps, 1)
require.Equal(t, p4, ps[0]) require.Equal(t, p4, ps[0])
require.True(t, ps[0].(*Payload).Verify())
ps0 := ps[0].(*Payload)
require.True(t, ps0.Verify(ps0.Witness.ScriptHash()))
}) })
} }