diff --git a/pkg/consensus/consensus.go b/pkg/consensus/consensus.go index 996c30e78..39c7de9e2 100644 --- a/pkg/consensus/consensus.go +++ b/pkg/consensus/consensus.go @@ -9,6 +9,7 @@ import ( "github.com/CityOfZion/neo-go/config" "github.com/CityOfZion/neo-go/pkg/core" "github.com/CityOfZion/neo-go/pkg/core/transaction" + "github.com/CityOfZion/neo-go/pkg/crypto/hash" "github.com/CityOfZion/neo-go/pkg/crypto/keys" "github.com/CityOfZion/neo-go/pkg/smartcontract" "github.com/CityOfZion/neo-go/pkg/util" @@ -163,6 +164,19 @@ func (s *service) eventLoop() { } } +func (s *service) validatePayload(p *Payload) bool { + validators := s.getValidators() + if int(p.validatorIndex) >= len(validators) { + return false + } + + pub := validators[p.validatorIndex] + vs := pub.(*publicKey).GetVerificationScript() + h := hash.Hash160(vs) + + return p.Verify(h) +} + func getKeyPair(cfg *config.WalletConfig) (crypto.PrivateKey, crypto.PublicKey) { acc, err := wallet.DecryptAccount(cfg.Path, cfg.Password) if err != nil { @@ -179,10 +193,7 @@ func getKeyPair(cfg *config.WalletConfig) (crypto.PrivateKey, crypto.PublicKey) // OnPayload handles Payload receive. func (s *service) OnPayload(cp *Payload) { - if !cp.Verify() { - s.log.Debug("can't verify payload from #%d", cp.validatorIndex) - return - } else if s.cache.Has(cp.Hash()) { + if !s.validatePayload(cp) || s.cache.Has(cp.Hash()) { return } diff --git a/pkg/consensus/consensus_test.go b/pkg/consensus/consensus_test.go index 12af0e89d..d934906fc 100644 --- a/pkg/consensus/consensus_test.go +++ b/pkg/consensus/consensus_test.go @@ -27,6 +27,35 @@ func TestNewService(t *testing.T) { require.Equal(t, tx, txx[1]) } +func TestService_ValidatePayload(t *testing.T) { + srv := newTestService(t) + priv, _ := getTestValidator(1) + p := new(Payload) + + p.SetPayload(&prepareRequest{}) + + t.Run("invalid validator index", func(t *testing.T) { + p.SetValidatorIndex(11) + require.NoError(t, p.Sign(priv)) + + var ok bool + require.NotPanics(t, func() { ok = srv.validatePayload(p) }) + require.False(t, ok) + }) + + t.Run("wrong validator index", func(t *testing.T) { + p.SetValidatorIndex(2) + require.NoError(t, p.Sign(priv)) + require.False(t, srv.validatePayload(p)) + }) + + t.Run("normal case", func(t *testing.T) { + p.SetValidatorIndex(1) + require.NoError(t, p.Sign(priv)) + require.True(t, srv.validatePayload(p)) + }) +} + func TestService_OnPayload(t *testing.T) { srv := newTestService(t) diff --git a/pkg/consensus/payload.go b/pkg/consensus/payload.go index 0546f5422..9c9e39099 100644 --- a/pkg/consensus/payload.go +++ b/pkg/consensus/payload.go @@ -4,6 +4,7 @@ import ( "crypto/sha256" "fmt" + "github.com/CityOfZion/neo-go/pkg/core" "github.com/CityOfZion/neo-go/pkg/core/transaction" "github.com/CityOfZion/neo-go/pkg/crypto/hash" "github.com/CityOfZion/neo-go/pkg/io" @@ -195,18 +196,27 @@ func (p *Payload) Sign(key *privateKey) error { } // Verify verifies payload using provided Witness. -func (p *Payload) Verify() bool { - h := sha256.Sum256(p.MarshalUnsigned()) - v := vm.New() - v.SetCheckedHash(h[:]) - v.Load(append(p.Witness.InvocationScript, p.Witness.VerificationScript...)) - if err := v.Run(); err != nil || v.Estack().Len() == 0 { +func (p *Payload) Verify(scriptHash util.Uint160) bool { + verification, err := core.ScriptFromWitness(scriptHash, &p.Witness) + if err != nil { return false } - result, err := v.Estack().Top().TryBool() + v := vm.New() + h := sha256.Sum256(p.MarshalUnsigned()) - return err == nil && result + v.SetCheckedHash(h[:]) + v.LoadScript(verification) + v.LoadScript(p.Witness.InvocationScript) + + err = v.Run() + if err != nil || v.HasFailed() || v.Estack().Len() != 1 { + return false + } + + res, err := v.Estack().Pop().TryBool() + + return err == nil && res } // DecodeBinaryUnsigned reads payload from w excluding signature. diff --git a/pkg/consensus/payload_test.go b/pkg/consensus/payload_test.go index f61391e73..31ce8173f 100644 --- a/pkg/consensus/payload_test.go +++ b/pkg/consensus/payload_test.go @@ -300,9 +300,9 @@ func TestPayload_Sign(t *testing.T) { priv := &privateKey{key} p := randomPayload(t, prepareRequestType) - require.False(t, p.Verify()) + require.False(t, p.Verify(util.Uint160{})) require.NoError(t, p.Sign(priv)) - require.True(t, p.Verify()) + require.True(t, p.Verify(p.Witness.ScriptHash())) } func TestMessageType_String(t *testing.T) { diff --git a/pkg/consensus/recovery_message_test.go b/pkg/consensus/recovery_message_test.go index 061ac1731..ab798eca3 100644 --- a/pkg/consensus/recovery_message_test.go +++ b/pkg/consensus/recovery_message_test.go @@ -55,7 +55,8 @@ func TestRecoveryMessage_Setters(t *testing.T) { ps := r.GetPrepareResponses(p, pubs) require.Len(t, ps, 1) require.Equal(t, p2, ps[0]) - require.True(t, ps[0].(*Payload).Verify()) + ps0 := ps[0].(*Payload) + require.True(t, ps0.Verify(ps0.Witness.ScriptHash())) }) t.Run("prepare request is added", func(t *testing.T) { @@ -66,7 +67,9 @@ func TestRecoveryMessage_Setters(t *testing.T) { pr = r.GetPrepareRequest(p, pubs, p1.ValidatorIndex()) require.NotNil(t, pr) require.Equal(t, p1, pr) - require.True(t, pr.(*Payload).Verify()) + + pl := pr.(*Payload) + require.True(t, pl.Verify(pl.Witness.ScriptHash())) }) t.Run("change view is added", func(t *testing.T) { @@ -84,7 +87,9 @@ func TestRecoveryMessage_Setters(t *testing.T) { ps := r.GetChangeViews(p, pubs) require.Len(t, ps, 1) require.Equal(t, p3, ps[0]) - require.True(t, ps[0].(*Payload).Verify()) + + ps0 := ps[0].(*Payload) + require.True(t, ps0.Verify(ps0.Witness.ScriptHash())) }) t.Run("commit is added", func(t *testing.T) { @@ -99,7 +104,9 @@ func TestRecoveryMessage_Setters(t *testing.T) { ps := r.GetCommits(p, pubs) require.Len(t, ps, 1) require.Equal(t, p4, ps[0]) - require.True(t, ps[0].(*Payload).Verify()) + + ps0 := ps[0].(*Payload) + require.True(t, ps0.Verify(ps0.Witness.ScriptHash())) }) } diff --git a/pkg/core/blockchain.go b/pkg/core/blockchain.go index 863da412d..a44a03386 100644 --- a/pkg/core/blockchain.go +++ b/pkg/core/blockchain.go @@ -1372,21 +1372,30 @@ func (bc *Blockchain) GetTestVM() (*vm.VM, storage.Store) { return vm, tmpStore } -// verifyHashAgainstScript verifies given hash against the given witness. -func (bc *Blockchain) verifyHashAgainstScript(hash util.Uint160, witness *transaction.Witness, checkedHash util.Uint256, interopCtx *interopContext, useKeys bool) error { +// ScriptFromWitness returns verification script for provided witness. +// If hash is not equal to the witness script hash, error is returned. +func ScriptFromWitness(hash util.Uint160, witness *transaction.Witness) ([]byte, error) { verification := witness.VerificationScript if len(verification) == 0 { bb := new(bytes.Buffer) err := vm.EmitAppCall(bb, hash, false) if err != nil { - return err + return nil, err } verification = bb.Bytes() - } else { - if h := witness.ScriptHash(); hash != h { - return errors.New("witness hash mismatch") - } + } else if h := witness.ScriptHash(); hash != h { + return nil, errors.New("witness hash mismatch") + } + + return verification, nil +} + +// verifyHashAgainstScript verifies given hash against the given witness. +func (bc *Blockchain) verifyHashAgainstScript(hash util.Uint160, witness *transaction.Witness, checkedHash util.Uint256, interopCtx *interopContext, useKeys bool) error { + verification, err := ScriptFromWitness(hash, witness) + if err != nil { + return err } vm := bc.spawnVMWithInterops(interopCtx) @@ -1396,7 +1405,7 @@ func (bc *Blockchain) verifyHashAgainstScript(hash util.Uint160, witness *transa if useKeys && bc.keyCache[hash] != nil { vm.SetPublicKeys(bc.keyCache[hash]) } - err := vm.Run() + err = vm.Run() if vm.HasFailed() { return errors.Errorf("vm failed to execute the script with error: %s", err) } diff --git a/pkg/core/blockchain_test.go b/pkg/core/blockchain_test.go index f539b865b..0fa8b3aae 100644 --- a/pkg/core/blockchain_test.go +++ b/pkg/core/blockchain_test.go @@ -4,7 +4,10 @@ import ( "testing" "github.com/CityOfZion/neo-go/pkg/core/storage" + "github.com/CityOfZion/neo-go/pkg/core/transaction" + "github.com/CityOfZion/neo-go/pkg/crypto/hash" "github.com/CityOfZion/neo-go/pkg/io" + "github.com/CityOfZion/neo-go/pkg/util" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -65,6 +68,27 @@ func TestAddBlock(t *testing.T) { assert.Equal(t, lastBlock.Hash(), bc.CurrentHeaderHash()) } +func TestScriptFromWitness(t *testing.T) { + witness := &transaction.Witness{} + h := util.Uint160{1, 2, 3} + + res, err := ScriptFromWitness(h, witness) + require.NoError(t, err) + require.NotNil(t, res) + + witness.VerificationScript = []byte{4, 8, 15, 16, 23, 42} + h = hash.Hash160(witness.VerificationScript) + + res, err = ScriptFromWitness(h, witness) + require.NoError(t, err) + require.NotNil(t, res) + + h[0] = ^h[0] + res, err = ScriptFromWitness(h, witness) + require.Error(t, err) + require.Nil(t, res) +} + func TestGetHeader(t *testing.T) { bc := newTestChain(t) block := newBlock(1, newMinerTX())