crypto/consensus: sign hashes and cache them for consensus payloads

Avoid serializing payload again and again for various purposes. To sign it, we
only need a hash.

Some 2.4% gain in TPS could be achieved with this.
This commit is contained in:
Roman Khimov 2020-08-29 18:44:45 +03:00
parent 49e9c1aa0f
commit 53c014a0bb
11 changed files with 119 additions and 42 deletions

View file

@ -35,6 +35,10 @@ type (
height uint32 height uint32
Witness transaction.Witness Witness transaction.Witness
hash util.Uint256
signedHash util.Uint256
signedpart []byte
} }
) )
@ -106,11 +110,13 @@ func (p Payload) GetRecoveryMessage() payload.RecoveryMessage {
} }
// MarshalUnsigned implements payload.ConsensusPayload interface. // MarshalUnsigned implements payload.ConsensusPayload interface.
func (p Payload) MarshalUnsigned() []byte { func (p *Payload) MarshalUnsigned() []byte {
w := io.NewBufBinWriter() if p.signedpart == nil {
p.encodeHashData(w.BinWriter) w := io.NewBufBinWriter()
p.encodeHashData(w.BinWriter)
return w.Bytes() p.signedpart = w.Bytes()
}
return p.signedpart
} }
// UnmarshalUnsigned implements payload.ConsensusPayload interface. // UnmarshalUnsigned implements payload.ConsensusPayload interface.
@ -179,7 +185,10 @@ func (p *Payload) EncodeBinaryUnsigned(w *io.BinWriter) {
// EncodeBinary implements io.Serializable interface. // EncodeBinary implements io.Serializable interface.
func (p *Payload) EncodeBinary(w *io.BinWriter) { func (p *Payload) EncodeBinary(w *io.BinWriter) {
p.EncodeBinaryUnsigned(w) if p.signedpart == nil {
_ = p.MarshalUnsigned()
}
w.WriteBytes(p.signedpart[4:])
w.WriteB(1) w.WriteB(1)
p.Witness.EncodeBinary(w) p.Witness.EncodeBinary(w)
@ -193,10 +202,7 @@ func (p *Payload) encodeHashData(w *io.BinWriter) {
// Sign signs payload using the private key. // Sign signs payload using the private key.
// It also sets corresponding verification and invocation scripts. // It also sets corresponding verification and invocation scripts.
func (p *Payload) Sign(key *privateKey) error { func (p *Payload) Sign(key *privateKey) error {
sig, err := key.Sign(p.GetSignedPart()) sig := key.SignHash(p.GetSignedHash())
if err != nil {
return err
}
buf := io.NewBufBinWriter() buf := io.NewBufBinWriter()
emit.Bytes(buf.BinWriter, sig) emit.Bytes(buf.BinWriter, sig)
@ -224,15 +230,41 @@ func (p *Payload) DecodeBinaryUnsigned(r *io.BinReader) {
} }
} }
// GetSignedHash returns a hash of the payload used to verify it.
func (p *Payload) GetSignedHash() util.Uint256 {
if p.signedHash.Equals(util.Uint256{}) {
if p.createHash() != nil {
panic("failed to compute hash!")
}
}
return p.signedHash
}
// Hash implements payload.ConsensusPayload interface. // Hash implements payload.ConsensusPayload interface.
func (p *Payload) Hash() util.Uint256 { func (p *Payload) Hash() util.Uint256 {
w := io.NewBufBinWriter() if p.hash.Equals(util.Uint256{}) {
p.encodeHashData(w.BinWriter) if p.createHash() != nil {
if w.Err != nil { panic("failed to compute hash!")
panic("failed to hash payload") }
} }
return p.hash
}
return hash.DoubleSha256(w.Bytes()) // createHash creates hashes of the payload.
func (p *Payload) createHash() error {
b := p.GetSignedPart()
if b == nil {
return errors.New("failed to serialize hashable data")
}
p.updateHashes(b)
return nil
}
// updateHashes updates Payload's hashes based on the given buffer which should
// be a signable data slice.
func (p *Payload) updateHashes(b []byte) {
p.signedHash = hash.Sha256(b)
p.hash = hash.Sha256(p.signedHash.BytesBE())
} }
// DecodeBinary implements io.Serializable interface. // DecodeBinary implements io.Serializable interface.

View file

@ -104,12 +104,14 @@ func TestConsensusPayload_Serializable(t *testing.T) {
require.Nil(t, actual.message) require.Nil(t, actual.message)
// message should now be decoded from actual.data byte array // message should now be decoded from actual.data byte array
assert.NoError(t, actual.decodeData()) assert.NoError(t, actual.decodeData())
assert.NotNil(t, actual.MarshalUnsigned())
require.Equal(t, p, actual) require.Equal(t, p, actual)
data = p.MarshalUnsigned() data = p.MarshalUnsigned()
pu := NewPayload(netmode.Magic(rand.Uint32())) pu := NewPayload(netmode.Magic(rand.Uint32()))
require.NoError(t, pu.UnmarshalUnsigned(data)) require.NoError(t, pu.UnmarshalUnsigned(data))
assert.NoError(t, pu.decodeData()) assert.NoError(t, pu.decodeData())
_ = pu.MarshalUnsigned()
p.Witness = transaction.Witness{} p.Witness = transaction.Witness{}
require.Equal(t, p, pu) require.Equal(t, p, pu)
@ -153,6 +155,7 @@ func TestConsensusPayload_DecodeBinaryInvalid(t *testing.T) {
p := new(Payload) p := new(Payload)
require.NoError(t, testserdes.DecodeBinary(buf, p)) require.NoError(t, testserdes.DecodeBinary(buf, p))
// decode `data` into `message` // decode `data` into `message`
_ = p.Hash()
assert.NoError(t, p.decodeData()) assert.NoError(t, p.decodeData())
require.Equal(t, expected, p) require.Equal(t, expected, p)

View file

@ -57,6 +57,8 @@ func TestRecoveryMessage_Setters(t *testing.T) {
ps := r.GetPrepareResponses(p, pubs) ps := r.GetPrepareResponses(p, pubs)
require.Len(t, ps, 1) require.Len(t, ps, 1)
// Update hashes and serialized data.
_ = ps[0].Hash()
require.Equal(t, p2, ps[0]) require.Equal(t, p2, ps[0])
ps0 := ps[0].(*Payload) ps0 := ps[0].(*Payload)
require.True(t, srv.validatePayload(ps0)) require.True(t, srv.validatePayload(ps0))
@ -91,6 +93,8 @@ func TestRecoveryMessage_Setters(t *testing.T) {
ps := r.GetChangeViews(p, pubs) ps := r.GetChangeViews(p, pubs)
require.Len(t, ps, 1) require.Len(t, ps, 1)
// update hashes and serialized data.
_ = ps[0].Hash()
require.Equal(t, p3, ps[0]) require.Equal(t, p3, ps[0])
ps0 := ps[0].(*Payload) ps0 := ps[0].(*Payload)
@ -109,6 +113,8 @@ func TestRecoveryMessage_Setters(t *testing.T) {
ps := r.GetCommits(p, pubs) ps := r.GetCommits(p, pubs)
require.Len(t, ps, 1) require.Len(t, ps, 1)
// update hashes and serialized data.
_ = ps[0].Hash()
require.Equal(t, p4, ps[0]) require.Equal(t, p4, ps[0])
ps0 := ps[0].(*Payload) ps0 := ps[0].(*Payload)

View file

@ -78,8 +78,8 @@ func (b *Base) Hash() util.Uint256 {
return b.hash return b.hash
} }
// VerificationHash returns the hash of the block used to verify it. // GetSignedHash returns a hash of the block used to verify it.
func (b *Base) VerificationHash() util.Uint256 { func (b *Base) GetSignedHash() util.Uint256 {
if b.verificationHash.Equals(util.Uint256{}) { if b.verificationHash.Equals(util.Uint256{}) {
b.createHash() b.createHash()
} }

View file

@ -10,6 +10,7 @@ import (
"github.com/nspcc-dev/neo-go/pkg/crypto" "github.com/nspcc-dev/neo-go/pkg/crypto"
"github.com/nspcc-dev/neo-go/pkg/crypto/hash" "github.com/nspcc-dev/neo-go/pkg/crypto/hash"
"github.com/nspcc-dev/neo-go/pkg/crypto/keys" "github.com/nspcc-dev/neo-go/pkg/crypto/keys"
"github.com/nspcc-dev/neo-go/pkg/util"
"github.com/nspcc-dev/neo-go/pkg/vm" "github.com/nspcc-dev/neo-go/pkg/vm"
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
) )
@ -30,15 +31,17 @@ func ECDSASecp256k1Verify(ic *interop.Context) error {
// ecdsaVerify is internal representation of ECDSASecp256k1Verify and // ecdsaVerify is internal representation of ECDSASecp256k1Verify and
// ECDSASecp256r1Verify. // ECDSASecp256r1Verify.
func ecdsaVerify(ic *interop.Context, curve elliptic.Curve) error { func ecdsaVerify(ic *interop.Context, curve elliptic.Curve) error {
msg := getMessage(ic, ic.VM.Estack().Pop().Item()) hashToCheck, err := getMessageHash(ic, ic.VM.Estack().Pop().Item())
hashToCheck := hash.Sha256(msg).BytesBE() if err != nil {
return err
}
keyb := ic.VM.Estack().Pop().Bytes() keyb := ic.VM.Estack().Pop().Bytes()
signature := ic.VM.Estack().Pop().Bytes() signature := ic.VM.Estack().Pop().Bytes()
pkey, err := keys.NewPublicKeyFromBytes(keyb, curve) pkey, err := keys.NewPublicKeyFromBytes(keyb, curve)
if err != nil { if err != nil {
return err return err
} }
res := pkey.Verify(signature, hashToCheck) res := pkey.Verify(signature, hashToCheck.BytesBE())
ic.VM.Estack().PushVal(res) ic.VM.Estack().PushVal(res)
return nil return nil
} }
@ -58,8 +61,10 @@ func ECDSASecp256k1CheckMultisig(ic *interop.Context) error {
// ecdsaCheckMultisig is internal representation of ECDSASecp256r1CheckMultisig and // ecdsaCheckMultisig is internal representation of ECDSASecp256r1CheckMultisig and
// ECDSASecp256k1CheckMultisig // ECDSASecp256k1CheckMultisig
func ecdsaCheckMultisig(ic *interop.Context, curve elliptic.Curve) error { func ecdsaCheckMultisig(ic *interop.Context, curve elliptic.Curve) error {
msg := getMessage(ic, ic.VM.Estack().Pop().Item()) hashToCheck, err := getMessageHash(ic, ic.VM.Estack().Pop().Item())
hashToCheck := hash.Sha256(msg).BytesBE() if err != nil {
return err
}
pkeys, err := ic.VM.Estack().PopSigElements() pkeys, err := ic.VM.Estack().PopSigElements()
if err != nil { if err != nil {
return fmt.Errorf("wrong parameters: %w", err) return fmt.Errorf("wrong parameters: %w", err)
@ -76,23 +81,23 @@ func ecdsaCheckMultisig(ic *interop.Context, curve elliptic.Curve) error {
if len(pkeys) < len(sigs) { if len(pkeys) < len(sigs) {
return errors.New("more signatures than there are keys") return errors.New("more signatures than there are keys")
} }
sigok := vm.CheckMultisigPar(ic.VM, curve, hashToCheck, pkeys, sigs) sigok := vm.CheckMultisigPar(ic.VM, curve, hashToCheck.BytesBE(), pkeys, sigs)
ic.VM.Estack().PushVal(sigok) ic.VM.Estack().PushVal(sigok)
return nil return nil
} }
func getMessage(ic *interop.Context, item stackitem.Item) []byte { func getMessageHash(ic *interop.Context, item stackitem.Item) (util.Uint256, error) {
var msg []byte var msg []byte
switch val := item.(type) { switch val := item.(type) {
case *stackitem.Interop: case *stackitem.Interop:
msg = val.Value().(crypto.Verifiable).GetSignedPart() return val.Value().(crypto.Verifiable).GetSignedHash(), nil
case stackitem.Null: case stackitem.Null:
msg = ic.Container.GetSignedPart() return ic.Container.GetSignedHash(), nil
default: default:
var err error var err error
if msg, err = val.TryBytes(); err != nil { if msg, err = val.TryBytes(); err != nil {
return nil return util.Uint256{}, err
} }
} }
return msg return hash.Sha256(msg), nil
} }

View file

@ -2,20 +2,37 @@ package crypto
import ( import (
"github.com/nspcc-dev/neo-go/pkg/core/interop" "github.com/nspcc-dev/neo-go/pkg/core/interop"
"github.com/nspcc-dev/neo-go/pkg/crypto"
"github.com/nspcc-dev/neo-go/pkg/crypto/hash" "github.com/nspcc-dev/neo-go/pkg/crypto/hash"
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
) )
// Sha256 returns sha256 hash of the data. // Sha256 returns sha256 hash of the data.
func Sha256(ic *interop.Context) error { func Sha256(ic *interop.Context) error {
msg := getMessage(ic, ic.VM.Estack().Pop().Item()) h, err := getMessageHash(ic, ic.VM.Estack().Pop().Item())
h := hash.Sha256(msg).BytesBE() if err != nil {
ic.VM.Estack().PushVal(h) return err
}
ic.VM.Estack().PushVal(h.BytesBE())
return nil return nil
} }
// RipeMD160 returns RipeMD160 hash of the data. // RipeMD160 returns RipeMD160 hash of the data.
func RipeMD160(ic *interop.Context) error { func RipeMD160(ic *interop.Context) error {
msg := getMessage(ic, ic.VM.Estack().Pop().Item()) var msg []byte
item := ic.VM.Estack().Pop().Item()
switch val := item.(type) {
case *stackitem.Interop:
msg = val.Value().(crypto.Verifiable).GetSignedPart()
case stackitem.Null:
msg = ic.Container.GetSignedPart()
default:
var err error
if msg, err = val.TryBytes(); err != nil {
return err
}
}
h := hash.RipeMD160(msg).BytesBE() h := hash.RipeMD160(msg).BytesBE()
ic.VM.Estack().PushVal(h) ic.VM.Estack().PushVal(h)
return nil return nil

View file

@ -216,7 +216,7 @@ func TestECDSAVerify(t *testing.T) {
t.Run("invalid message", func(t *testing.T) { t.Run("invalid message", func(t *testing.T) {
sign := priv.Sign(msg) sign := priv.Sign(msg)
runCase(t, false, false, sign, priv.PublicKey().Bytes(), runCase(t, true, false, sign, priv.PublicKey().Bytes(),
stackitem.NewArray([]stackitem.Item{stackitem.NewByteArray(msg)})) stackitem.NewArray([]stackitem.Item{stackitem.NewByteArray(msg)}))
}) })
} }

View file

@ -59,6 +59,13 @@ func (s *MPTRootBase) GetSignedPart() []byte {
return buf.Bytes() return buf.Bytes()
} }
// GetSignedHash returns hash of MPTRootBase which needs to be signed.
func (s *MPTRootBase) GetSignedHash() util.Uint256 {
buf := io.NewBufBinWriter()
s.EncodeBinary(buf.BinWriter)
return hash.Sha256(buf.Bytes())
}
// Equals checks if s == other. // Equals checks if s == other.
func (s *MPTRootBase) Equals(other *MPTRootBase) bool { func (s *MPTRootBase) Equals(other *MPTRootBase) bool {
return s.Version == other.Version && s.Index == other.Index && return s.Version == other.Version && s.Index == other.Index &&

View file

@ -110,8 +110,8 @@ func (t *Transaction) Hash() util.Uint256 {
return t.hash return t.hash
} }
// VerificationHash returns the hash of the transaction used to verify it. // GetSignedHash returns a hash of the transaction used to verify it.
func (t *Transaction) VerificationHash() util.Uint256 { func (t *Transaction) GetSignedHash() util.Uint256 {
if t.verificationHash.Equals(util.Uint256{}) { if t.verificationHash.Equals(util.Uint256{}) {
if t.createHash() != nil { if t.createHash() != nil {
panic("failed to compute hash!") panic("failed to compute hash!")

View file

@ -128,15 +128,19 @@ func (p *PrivateKey) GetScriptHash() util.Uint160 {
return pk.GetScriptHash() return pk.GetScriptHash()
} }
// Sign signs arbitrary length data using the private key. // Sign signs arbitrary length data using the private key. It uses SHA256 to
// calculate hash and then SignHash to create a signature (so you can save on
// hash calculation if you already have it).
func (p *PrivateKey) Sign(data []byte) []byte { func (p *PrivateKey) Sign(data []byte) []byte {
var ( var digest = sha256.Sum256(data)
privateKey = &p.PrivateKey
digest = sha256.Sum256(data)
)
r, s := rfc6979.SignECDSA(privateKey, digest[:], sha256.New) return p.SignHash(digest)
return getSignatureSlice(privateKey.Curve, r, s) }
// SignHash signs particular hash the private key.
func (p *PrivateKey) SignHash(digest util.Uint256) []byte {
r, s := rfc6979.SignECDSA(&p.PrivateKey, digest[:], sha256.New)
return getSignatureSlice(p.PrivateKey.Curve, r, s)
} }
func getSignatureSlice(curve elliptic.Curve, r, s *big.Int) []byte { func getSignatureSlice(curve elliptic.Curve, r, s *big.Int) []byte {

View file

@ -1,8 +1,11 @@
package crypto package crypto
import "github.com/nspcc-dev/neo-go/pkg/util"
// Verifiable represents an object which can be verified. // Verifiable represents an object which can be verified.
type Verifiable interface { type Verifiable interface {
GetSignedPart() []byte GetSignedPart() []byte
GetSignedHash() util.Uint256
} }
// VerifiableDecodable represents an object which can be both verified and // VerifiableDecodable represents an object which can be both verified and