diff --git a/pkg/consensus/payload.go b/pkg/consensus/payload.go index 57b9ac32e..3b00979ad 100644 --- a/pkg/consensus/payload.go +++ b/pkg/consensus/payload.go @@ -35,6 +35,10 @@ type ( height uint32 Witness transaction.Witness + + hash util.Uint256 + signedHash util.Uint256 + signedpart []byte } ) @@ -106,11 +110,13 @@ func (p Payload) GetRecoveryMessage() payload.RecoveryMessage { } // MarshalUnsigned implements payload.ConsensusPayload interface. -func (p Payload) MarshalUnsigned() []byte { - w := io.NewBufBinWriter() - p.encodeHashData(w.BinWriter) - - return w.Bytes() +func (p *Payload) MarshalUnsigned() []byte { + if p.signedpart == nil { + w := io.NewBufBinWriter() + p.encodeHashData(w.BinWriter) + p.signedpart = w.Bytes() + } + return p.signedpart } // UnmarshalUnsigned implements payload.ConsensusPayload interface. @@ -179,7 +185,10 @@ func (p *Payload) EncodeBinaryUnsigned(w *io.BinWriter) { // EncodeBinary implements io.Serializable interface. func (p *Payload) EncodeBinary(w *io.BinWriter) { - p.EncodeBinaryUnsigned(w) + if p.signedpart == nil { + _ = p.MarshalUnsigned() + } + w.WriteBytes(p.signedpart[4:]) w.WriteB(1) p.Witness.EncodeBinary(w) @@ -193,10 +202,7 @@ func (p *Payload) encodeHashData(w *io.BinWriter) { // Sign signs payload using the private key. // It also sets corresponding verification and invocation scripts. func (p *Payload) Sign(key *privateKey) error { - sig, err := key.Sign(p.GetSignedPart()) - if err != nil { - return err - } + sig := key.SignHash(p.GetSignedHash()) buf := io.NewBufBinWriter() emit.Bytes(buf.BinWriter, sig) @@ -224,15 +230,41 @@ func (p *Payload) DecodeBinaryUnsigned(r *io.BinReader) { } } +// GetSignedHash returns a hash of the payload used to verify it. +func (p *Payload) GetSignedHash() util.Uint256 { + if p.signedHash.Equals(util.Uint256{}) { + if p.createHash() != nil { + panic("failed to compute hash!") + } + } + return p.signedHash +} + // Hash implements payload.ConsensusPayload interface. func (p *Payload) Hash() util.Uint256 { - w := io.NewBufBinWriter() - p.encodeHashData(w.BinWriter) - if w.Err != nil { - panic("failed to hash payload") + if p.hash.Equals(util.Uint256{}) { + if p.createHash() != nil { + panic("failed to compute hash!") + } } + return p.hash +} - return hash.DoubleSha256(w.Bytes()) +// createHash creates hashes of the payload. +func (p *Payload) createHash() error { + b := p.GetSignedPart() + if b == nil { + return errors.New("failed to serialize hashable data") + } + p.updateHashes(b) + return nil +} + +// updateHashes updates Payload's hashes based on the given buffer which should +// be a signable data slice. +func (p *Payload) updateHashes(b []byte) { + p.signedHash = hash.Sha256(b) + p.hash = hash.Sha256(p.signedHash.BytesBE()) } // DecodeBinary implements io.Serializable interface. diff --git a/pkg/consensus/payload_test.go b/pkg/consensus/payload_test.go index e67400095..1724715c8 100644 --- a/pkg/consensus/payload_test.go +++ b/pkg/consensus/payload_test.go @@ -104,12 +104,14 @@ func TestConsensusPayload_Serializable(t *testing.T) { require.Nil(t, actual.message) // message should now be decoded from actual.data byte array assert.NoError(t, actual.decodeData()) + assert.NotNil(t, actual.MarshalUnsigned()) require.Equal(t, p, actual) data = p.MarshalUnsigned() pu := NewPayload(netmode.Magic(rand.Uint32())) require.NoError(t, pu.UnmarshalUnsigned(data)) assert.NoError(t, pu.decodeData()) + _ = pu.MarshalUnsigned() p.Witness = transaction.Witness{} require.Equal(t, p, pu) @@ -153,6 +155,7 @@ func TestConsensusPayload_DecodeBinaryInvalid(t *testing.T) { p := new(Payload) require.NoError(t, testserdes.DecodeBinary(buf, p)) // decode `data` into `message` + _ = p.Hash() assert.NoError(t, p.decodeData()) require.Equal(t, expected, p) diff --git a/pkg/consensus/recovery_message_test.go b/pkg/consensus/recovery_message_test.go index 6b0cba590..b5ef23ceb 100644 --- a/pkg/consensus/recovery_message_test.go +++ b/pkg/consensus/recovery_message_test.go @@ -57,6 +57,8 @@ func TestRecoveryMessage_Setters(t *testing.T) { ps := r.GetPrepareResponses(p, pubs) require.Len(t, ps, 1) + // Update hashes and serialized data. + _ = ps[0].Hash() require.Equal(t, p2, ps[0]) ps0 := ps[0].(*Payload) require.True(t, srv.validatePayload(ps0)) @@ -91,6 +93,8 @@ func TestRecoveryMessage_Setters(t *testing.T) { ps := r.GetChangeViews(p, pubs) require.Len(t, ps, 1) + // update hashes and serialized data. + _ = ps[0].Hash() require.Equal(t, p3, ps[0]) ps0 := ps[0].(*Payload) @@ -109,6 +113,8 @@ func TestRecoveryMessage_Setters(t *testing.T) { ps := r.GetCommits(p, pubs) require.Len(t, ps, 1) + // update hashes and serialized data. + _ = ps[0].Hash() require.Equal(t, p4, ps[0]) ps0 := ps[0].(*Payload) diff --git a/pkg/core/block/block_base.go b/pkg/core/block/block_base.go index 0b5af8c95..259dcc4d5 100644 --- a/pkg/core/block/block_base.go +++ b/pkg/core/block/block_base.go @@ -78,8 +78,8 @@ func (b *Base) Hash() util.Uint256 { return b.hash } -// VerificationHash returns the hash of the block used to verify it. -func (b *Base) VerificationHash() util.Uint256 { +// GetSignedHash returns a hash of the block used to verify it. +func (b *Base) GetSignedHash() util.Uint256 { if b.verificationHash.Equals(util.Uint256{}) { b.createHash() } diff --git a/pkg/core/interop/crypto/ecdsa.go b/pkg/core/interop/crypto/ecdsa.go index faf5526f3..5f79e02ad 100644 --- a/pkg/core/interop/crypto/ecdsa.go +++ b/pkg/core/interop/crypto/ecdsa.go @@ -10,6 +10,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/crypto" "github.com/nspcc-dev/neo-go/pkg/crypto/hash" "github.com/nspcc-dev/neo-go/pkg/crypto/keys" + "github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/vm" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" ) @@ -30,15 +31,17 @@ func ECDSASecp256k1Verify(ic *interop.Context) error { // ecdsaVerify is internal representation of ECDSASecp256k1Verify and // ECDSASecp256r1Verify. func ecdsaVerify(ic *interop.Context, curve elliptic.Curve) error { - msg := getMessage(ic, ic.VM.Estack().Pop().Item()) - hashToCheck := hash.Sha256(msg).BytesBE() + hashToCheck, err := getMessageHash(ic, ic.VM.Estack().Pop().Item()) + if err != nil { + return err + } keyb := ic.VM.Estack().Pop().Bytes() signature := ic.VM.Estack().Pop().Bytes() pkey, err := keys.NewPublicKeyFromBytes(keyb, curve) if err != nil { return err } - res := pkey.Verify(signature, hashToCheck) + res := pkey.Verify(signature, hashToCheck.BytesBE()) ic.VM.Estack().PushVal(res) return nil } @@ -58,8 +61,10 @@ func ECDSASecp256k1CheckMultisig(ic *interop.Context) error { // ecdsaCheckMultisig is internal representation of ECDSASecp256r1CheckMultisig and // ECDSASecp256k1CheckMultisig func ecdsaCheckMultisig(ic *interop.Context, curve elliptic.Curve) error { - msg := getMessage(ic, ic.VM.Estack().Pop().Item()) - hashToCheck := hash.Sha256(msg).BytesBE() + hashToCheck, err := getMessageHash(ic, ic.VM.Estack().Pop().Item()) + if err != nil { + return err + } pkeys, err := ic.VM.Estack().PopSigElements() if err != nil { return fmt.Errorf("wrong parameters: %w", err) @@ -76,23 +81,23 @@ func ecdsaCheckMultisig(ic *interop.Context, curve elliptic.Curve) error { if len(pkeys) < len(sigs) { return errors.New("more signatures than there are keys") } - sigok := vm.CheckMultisigPar(ic.VM, curve, hashToCheck, pkeys, sigs) + sigok := vm.CheckMultisigPar(ic.VM, curve, hashToCheck.BytesBE(), pkeys, sigs) ic.VM.Estack().PushVal(sigok) return nil } -func getMessage(ic *interop.Context, item stackitem.Item) []byte { +func getMessageHash(ic *interop.Context, item stackitem.Item) (util.Uint256, error) { var msg []byte switch val := item.(type) { case *stackitem.Interop: - msg = val.Value().(crypto.Verifiable).GetSignedPart() + return val.Value().(crypto.Verifiable).GetSignedHash(), nil case stackitem.Null: - msg = ic.Container.GetSignedPart() + return ic.Container.GetSignedHash(), nil default: var err error if msg, err = val.TryBytes(); err != nil { - return nil + return util.Uint256{}, err } } - return msg + return hash.Sha256(msg), nil } diff --git a/pkg/core/interop/crypto/hash.go b/pkg/core/interop/crypto/hash.go index 106f545fd..f00efb3a3 100644 --- a/pkg/core/interop/crypto/hash.go +++ b/pkg/core/interop/crypto/hash.go @@ -2,20 +2,37 @@ package crypto import ( "github.com/nspcc-dev/neo-go/pkg/core/interop" + "github.com/nspcc-dev/neo-go/pkg/crypto" "github.com/nspcc-dev/neo-go/pkg/crypto/hash" + "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" ) // Sha256 returns sha256 hash of the data. func Sha256(ic *interop.Context) error { - msg := getMessage(ic, ic.VM.Estack().Pop().Item()) - h := hash.Sha256(msg).BytesBE() - ic.VM.Estack().PushVal(h) + h, err := getMessageHash(ic, ic.VM.Estack().Pop().Item()) + if err != nil { + return err + } + ic.VM.Estack().PushVal(h.BytesBE()) return nil } // RipeMD160 returns RipeMD160 hash of the data. func RipeMD160(ic *interop.Context) error { - msg := getMessage(ic, ic.VM.Estack().Pop().Item()) + var msg []byte + + item := ic.VM.Estack().Pop().Item() + switch val := item.(type) { + case *stackitem.Interop: + msg = val.Value().(crypto.Verifiable).GetSignedPart() + case stackitem.Null: + msg = ic.Container.GetSignedPart() + default: + var err error + if msg, err = val.TryBytes(); err != nil { + return err + } + } h := hash.RipeMD160(msg).BytesBE() ic.VM.Estack().PushVal(h) return nil diff --git a/pkg/core/interop_neo_test.go b/pkg/core/interop_neo_test.go index ef5bbfa28..3ce5fd8cb 100644 --- a/pkg/core/interop_neo_test.go +++ b/pkg/core/interop_neo_test.go @@ -216,7 +216,7 @@ func TestECDSAVerify(t *testing.T) { t.Run("invalid message", func(t *testing.T) { sign := priv.Sign(msg) - runCase(t, false, false, sign, priv.PublicKey().Bytes(), + runCase(t, true, false, sign, priv.PublicKey().Bytes(), stackitem.NewArray([]stackitem.Item{stackitem.NewByteArray(msg)})) }) } diff --git a/pkg/core/state/mpt_root.go b/pkg/core/state/mpt_root.go index dea3f62fa..f2c890781 100644 --- a/pkg/core/state/mpt_root.go +++ b/pkg/core/state/mpt_root.go @@ -59,6 +59,13 @@ func (s *MPTRootBase) GetSignedPart() []byte { return buf.Bytes() } +// GetSignedHash returns hash of MPTRootBase which needs to be signed. +func (s *MPTRootBase) GetSignedHash() util.Uint256 { + buf := io.NewBufBinWriter() + s.EncodeBinary(buf.BinWriter) + return hash.Sha256(buf.Bytes()) +} + // Equals checks if s == other. func (s *MPTRootBase) Equals(other *MPTRootBase) bool { return s.Version == other.Version && s.Index == other.Index && diff --git a/pkg/core/transaction/transaction.go b/pkg/core/transaction/transaction.go index 9e343c780..021708d67 100644 --- a/pkg/core/transaction/transaction.go +++ b/pkg/core/transaction/transaction.go @@ -110,8 +110,8 @@ func (t *Transaction) Hash() util.Uint256 { return t.hash } -// VerificationHash returns the hash of the transaction used to verify it. -func (t *Transaction) VerificationHash() util.Uint256 { +// GetSignedHash returns a hash of the transaction used to verify it. +func (t *Transaction) GetSignedHash() util.Uint256 { if t.verificationHash.Equals(util.Uint256{}) { if t.createHash() != nil { panic("failed to compute hash!") diff --git a/pkg/crypto/keys/private_key.go b/pkg/crypto/keys/private_key.go index 2771b80d7..c790980f2 100644 --- a/pkg/crypto/keys/private_key.go +++ b/pkg/crypto/keys/private_key.go @@ -128,15 +128,19 @@ func (p *PrivateKey) GetScriptHash() util.Uint160 { return pk.GetScriptHash() } -// Sign signs arbitrary length data using the private key. +// Sign signs arbitrary length data using the private key. It uses SHA256 to +// calculate hash and then SignHash to create a signature (so you can save on +// hash calculation if you already have it). func (p *PrivateKey) Sign(data []byte) []byte { - var ( - privateKey = &p.PrivateKey - digest = sha256.Sum256(data) - ) + var digest = sha256.Sum256(data) - r, s := rfc6979.SignECDSA(privateKey, digest[:], sha256.New) - return getSignatureSlice(privateKey.Curve, r, s) + return p.SignHash(digest) +} + +// SignHash signs particular hash the private key. +func (p *PrivateKey) SignHash(digest util.Uint256) []byte { + r, s := rfc6979.SignECDSA(&p.PrivateKey, digest[:], sha256.New) + return getSignatureSlice(p.PrivateKey.Curve, r, s) } func getSignatureSlice(curve elliptic.Curve, r, s *big.Int) []byte { diff --git a/pkg/crypto/verifiable.go b/pkg/crypto/verifiable.go index 341358f62..8be328d83 100644 --- a/pkg/crypto/verifiable.go +++ b/pkg/crypto/verifiable.go @@ -1,8 +1,11 @@ package crypto +import "github.com/nspcc-dev/neo-go/pkg/util" + // Verifiable represents an object which can be verified. type Verifiable interface { GetSignedPart() []byte + GetSignedHash() util.Uint256 } // VerifiableDecodable represents an object which can be both verified and