forked from TrueCloudLab/neoneo-go
crypto/consensus: sign hashes and cache them for consensus payloads
Avoid serializing payload again and again for various purposes. To sign it, we only need a hash. Some 2.4% gain in TPS could be achieved with this.
This commit is contained in:
parent
49e9c1aa0f
commit
53c014a0bb
11 changed files with 119 additions and 42 deletions
|
@ -35,6 +35,10 @@ type (
|
||||||
height uint32
|
height uint32
|
||||||
|
|
||||||
Witness transaction.Witness
|
Witness transaction.Witness
|
||||||
|
|
||||||
|
hash util.Uint256
|
||||||
|
signedHash util.Uint256
|
||||||
|
signedpart []byte
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -106,11 +110,13 @@ func (p Payload) GetRecoveryMessage() payload.RecoveryMessage {
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarshalUnsigned implements payload.ConsensusPayload interface.
|
// MarshalUnsigned implements payload.ConsensusPayload interface.
|
||||||
func (p Payload) MarshalUnsigned() []byte {
|
func (p *Payload) MarshalUnsigned() []byte {
|
||||||
w := io.NewBufBinWriter()
|
if p.signedpart == nil {
|
||||||
p.encodeHashData(w.BinWriter)
|
w := io.NewBufBinWriter()
|
||||||
|
p.encodeHashData(w.BinWriter)
|
||||||
return w.Bytes()
|
p.signedpart = w.Bytes()
|
||||||
|
}
|
||||||
|
return p.signedpart
|
||||||
}
|
}
|
||||||
|
|
||||||
// UnmarshalUnsigned implements payload.ConsensusPayload interface.
|
// UnmarshalUnsigned implements payload.ConsensusPayload interface.
|
||||||
|
@ -179,7 +185,10 @@ func (p *Payload) EncodeBinaryUnsigned(w *io.BinWriter) {
|
||||||
|
|
||||||
// EncodeBinary implements io.Serializable interface.
|
// EncodeBinary implements io.Serializable interface.
|
||||||
func (p *Payload) EncodeBinary(w *io.BinWriter) {
|
func (p *Payload) EncodeBinary(w *io.BinWriter) {
|
||||||
p.EncodeBinaryUnsigned(w)
|
if p.signedpart == nil {
|
||||||
|
_ = p.MarshalUnsigned()
|
||||||
|
}
|
||||||
|
w.WriteBytes(p.signedpart[4:])
|
||||||
|
|
||||||
w.WriteB(1)
|
w.WriteB(1)
|
||||||
p.Witness.EncodeBinary(w)
|
p.Witness.EncodeBinary(w)
|
||||||
|
@ -193,10 +202,7 @@ func (p *Payload) encodeHashData(w *io.BinWriter) {
|
||||||
// Sign signs payload using the private key.
|
// Sign signs payload using the private key.
|
||||||
// It also sets corresponding verification and invocation scripts.
|
// It also sets corresponding verification and invocation scripts.
|
||||||
func (p *Payload) Sign(key *privateKey) error {
|
func (p *Payload) Sign(key *privateKey) error {
|
||||||
sig, err := key.Sign(p.GetSignedPart())
|
sig := key.SignHash(p.GetSignedHash())
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
buf := io.NewBufBinWriter()
|
buf := io.NewBufBinWriter()
|
||||||
emit.Bytes(buf.BinWriter, sig)
|
emit.Bytes(buf.BinWriter, sig)
|
||||||
|
@ -224,15 +230,41 @@ func (p *Payload) DecodeBinaryUnsigned(r *io.BinReader) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetSignedHash returns a hash of the payload used to verify it.
|
||||||
|
func (p *Payload) GetSignedHash() util.Uint256 {
|
||||||
|
if p.signedHash.Equals(util.Uint256{}) {
|
||||||
|
if p.createHash() != nil {
|
||||||
|
panic("failed to compute hash!")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return p.signedHash
|
||||||
|
}
|
||||||
|
|
||||||
// Hash implements payload.ConsensusPayload interface.
|
// Hash implements payload.ConsensusPayload interface.
|
||||||
func (p *Payload) Hash() util.Uint256 {
|
func (p *Payload) Hash() util.Uint256 {
|
||||||
w := io.NewBufBinWriter()
|
if p.hash.Equals(util.Uint256{}) {
|
||||||
p.encodeHashData(w.BinWriter)
|
if p.createHash() != nil {
|
||||||
if w.Err != nil {
|
panic("failed to compute hash!")
|
||||||
panic("failed to hash payload")
|
}
|
||||||
}
|
}
|
||||||
|
return p.hash
|
||||||
|
}
|
||||||
|
|
||||||
return hash.DoubleSha256(w.Bytes())
|
// createHash creates hashes of the payload.
|
||||||
|
func (p *Payload) createHash() error {
|
||||||
|
b := p.GetSignedPart()
|
||||||
|
if b == nil {
|
||||||
|
return errors.New("failed to serialize hashable data")
|
||||||
|
}
|
||||||
|
p.updateHashes(b)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateHashes updates Payload's hashes based on the given buffer which should
|
||||||
|
// be a signable data slice.
|
||||||
|
func (p *Payload) updateHashes(b []byte) {
|
||||||
|
p.signedHash = hash.Sha256(b)
|
||||||
|
p.hash = hash.Sha256(p.signedHash.BytesBE())
|
||||||
}
|
}
|
||||||
|
|
||||||
// DecodeBinary implements io.Serializable interface.
|
// DecodeBinary implements io.Serializable interface.
|
||||||
|
|
|
@ -104,12 +104,14 @@ func TestConsensusPayload_Serializable(t *testing.T) {
|
||||||
require.Nil(t, actual.message)
|
require.Nil(t, actual.message)
|
||||||
// message should now be decoded from actual.data byte array
|
// message should now be decoded from actual.data byte array
|
||||||
assert.NoError(t, actual.decodeData())
|
assert.NoError(t, actual.decodeData())
|
||||||
|
assert.NotNil(t, actual.MarshalUnsigned())
|
||||||
require.Equal(t, p, actual)
|
require.Equal(t, p, actual)
|
||||||
|
|
||||||
data = p.MarshalUnsigned()
|
data = p.MarshalUnsigned()
|
||||||
pu := NewPayload(netmode.Magic(rand.Uint32()))
|
pu := NewPayload(netmode.Magic(rand.Uint32()))
|
||||||
require.NoError(t, pu.UnmarshalUnsigned(data))
|
require.NoError(t, pu.UnmarshalUnsigned(data))
|
||||||
assert.NoError(t, pu.decodeData())
|
assert.NoError(t, pu.decodeData())
|
||||||
|
_ = pu.MarshalUnsigned()
|
||||||
|
|
||||||
p.Witness = transaction.Witness{}
|
p.Witness = transaction.Witness{}
|
||||||
require.Equal(t, p, pu)
|
require.Equal(t, p, pu)
|
||||||
|
@ -153,6 +155,7 @@ func TestConsensusPayload_DecodeBinaryInvalid(t *testing.T) {
|
||||||
p := new(Payload)
|
p := new(Payload)
|
||||||
require.NoError(t, testserdes.DecodeBinary(buf, p))
|
require.NoError(t, testserdes.DecodeBinary(buf, p))
|
||||||
// decode `data` into `message`
|
// decode `data` into `message`
|
||||||
|
_ = p.Hash()
|
||||||
assert.NoError(t, p.decodeData())
|
assert.NoError(t, p.decodeData())
|
||||||
require.Equal(t, expected, p)
|
require.Equal(t, expected, p)
|
||||||
|
|
||||||
|
|
|
@ -57,6 +57,8 @@ func TestRecoveryMessage_Setters(t *testing.T) {
|
||||||
|
|
||||||
ps := r.GetPrepareResponses(p, pubs)
|
ps := r.GetPrepareResponses(p, pubs)
|
||||||
require.Len(t, ps, 1)
|
require.Len(t, ps, 1)
|
||||||
|
// Update hashes and serialized data.
|
||||||
|
_ = ps[0].Hash()
|
||||||
require.Equal(t, p2, ps[0])
|
require.Equal(t, p2, ps[0])
|
||||||
ps0 := ps[0].(*Payload)
|
ps0 := ps[0].(*Payload)
|
||||||
require.True(t, srv.validatePayload(ps0))
|
require.True(t, srv.validatePayload(ps0))
|
||||||
|
@ -91,6 +93,8 @@ func TestRecoveryMessage_Setters(t *testing.T) {
|
||||||
|
|
||||||
ps := r.GetChangeViews(p, pubs)
|
ps := r.GetChangeViews(p, pubs)
|
||||||
require.Len(t, ps, 1)
|
require.Len(t, ps, 1)
|
||||||
|
// update hashes and serialized data.
|
||||||
|
_ = ps[0].Hash()
|
||||||
require.Equal(t, p3, ps[0])
|
require.Equal(t, p3, ps[0])
|
||||||
|
|
||||||
ps0 := ps[0].(*Payload)
|
ps0 := ps[0].(*Payload)
|
||||||
|
@ -109,6 +113,8 @@ func TestRecoveryMessage_Setters(t *testing.T) {
|
||||||
|
|
||||||
ps := r.GetCommits(p, pubs)
|
ps := r.GetCommits(p, pubs)
|
||||||
require.Len(t, ps, 1)
|
require.Len(t, ps, 1)
|
||||||
|
// update hashes and serialized data.
|
||||||
|
_ = ps[0].Hash()
|
||||||
require.Equal(t, p4, ps[0])
|
require.Equal(t, p4, ps[0])
|
||||||
|
|
||||||
ps0 := ps[0].(*Payload)
|
ps0 := ps[0].(*Payload)
|
||||||
|
|
|
@ -78,8 +78,8 @@ func (b *Base) Hash() util.Uint256 {
|
||||||
return b.hash
|
return b.hash
|
||||||
}
|
}
|
||||||
|
|
||||||
// VerificationHash returns the hash of the block used to verify it.
|
// GetSignedHash returns a hash of the block used to verify it.
|
||||||
func (b *Base) VerificationHash() util.Uint256 {
|
func (b *Base) GetSignedHash() util.Uint256 {
|
||||||
if b.verificationHash.Equals(util.Uint256{}) {
|
if b.verificationHash.Equals(util.Uint256{}) {
|
||||||
b.createHash()
|
b.createHash()
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,6 +10,7 @@ import (
|
||||||
"github.com/nspcc-dev/neo-go/pkg/crypto"
|
"github.com/nspcc-dev/neo-go/pkg/crypto"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/crypto/hash"
|
"github.com/nspcc-dev/neo-go/pkg/crypto/hash"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/crypto/keys"
|
"github.com/nspcc-dev/neo-go/pkg/crypto/keys"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/vm"
|
"github.com/nspcc-dev/neo-go/pkg/vm"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
|
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
|
||||||
)
|
)
|
||||||
|
@ -30,15 +31,17 @@ func ECDSASecp256k1Verify(ic *interop.Context) error {
|
||||||
// ecdsaVerify is internal representation of ECDSASecp256k1Verify and
|
// ecdsaVerify is internal representation of ECDSASecp256k1Verify and
|
||||||
// ECDSASecp256r1Verify.
|
// ECDSASecp256r1Verify.
|
||||||
func ecdsaVerify(ic *interop.Context, curve elliptic.Curve) error {
|
func ecdsaVerify(ic *interop.Context, curve elliptic.Curve) error {
|
||||||
msg := getMessage(ic, ic.VM.Estack().Pop().Item())
|
hashToCheck, err := getMessageHash(ic, ic.VM.Estack().Pop().Item())
|
||||||
hashToCheck := hash.Sha256(msg).BytesBE()
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
keyb := ic.VM.Estack().Pop().Bytes()
|
keyb := ic.VM.Estack().Pop().Bytes()
|
||||||
signature := ic.VM.Estack().Pop().Bytes()
|
signature := ic.VM.Estack().Pop().Bytes()
|
||||||
pkey, err := keys.NewPublicKeyFromBytes(keyb, curve)
|
pkey, err := keys.NewPublicKeyFromBytes(keyb, curve)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
res := pkey.Verify(signature, hashToCheck)
|
res := pkey.Verify(signature, hashToCheck.BytesBE())
|
||||||
ic.VM.Estack().PushVal(res)
|
ic.VM.Estack().PushVal(res)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -58,8 +61,10 @@ func ECDSASecp256k1CheckMultisig(ic *interop.Context) error {
|
||||||
// ecdsaCheckMultisig is internal representation of ECDSASecp256r1CheckMultisig and
|
// ecdsaCheckMultisig is internal representation of ECDSASecp256r1CheckMultisig and
|
||||||
// ECDSASecp256k1CheckMultisig
|
// ECDSASecp256k1CheckMultisig
|
||||||
func ecdsaCheckMultisig(ic *interop.Context, curve elliptic.Curve) error {
|
func ecdsaCheckMultisig(ic *interop.Context, curve elliptic.Curve) error {
|
||||||
msg := getMessage(ic, ic.VM.Estack().Pop().Item())
|
hashToCheck, err := getMessageHash(ic, ic.VM.Estack().Pop().Item())
|
||||||
hashToCheck := hash.Sha256(msg).BytesBE()
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
pkeys, err := ic.VM.Estack().PopSigElements()
|
pkeys, err := ic.VM.Estack().PopSigElements()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("wrong parameters: %w", err)
|
return fmt.Errorf("wrong parameters: %w", err)
|
||||||
|
@ -76,23 +81,23 @@ func ecdsaCheckMultisig(ic *interop.Context, curve elliptic.Curve) error {
|
||||||
if len(pkeys) < len(sigs) {
|
if len(pkeys) < len(sigs) {
|
||||||
return errors.New("more signatures than there are keys")
|
return errors.New("more signatures than there are keys")
|
||||||
}
|
}
|
||||||
sigok := vm.CheckMultisigPar(ic.VM, curve, hashToCheck, pkeys, sigs)
|
sigok := vm.CheckMultisigPar(ic.VM, curve, hashToCheck.BytesBE(), pkeys, sigs)
|
||||||
ic.VM.Estack().PushVal(sigok)
|
ic.VM.Estack().PushVal(sigok)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getMessage(ic *interop.Context, item stackitem.Item) []byte {
|
func getMessageHash(ic *interop.Context, item stackitem.Item) (util.Uint256, error) {
|
||||||
var msg []byte
|
var msg []byte
|
||||||
switch val := item.(type) {
|
switch val := item.(type) {
|
||||||
case *stackitem.Interop:
|
case *stackitem.Interop:
|
||||||
msg = val.Value().(crypto.Verifiable).GetSignedPart()
|
return val.Value().(crypto.Verifiable).GetSignedHash(), nil
|
||||||
case stackitem.Null:
|
case stackitem.Null:
|
||||||
msg = ic.Container.GetSignedPart()
|
return ic.Container.GetSignedHash(), nil
|
||||||
default:
|
default:
|
||||||
var err error
|
var err error
|
||||||
if msg, err = val.TryBytes(); err != nil {
|
if msg, err = val.TryBytes(); err != nil {
|
||||||
return nil
|
return util.Uint256{}, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return msg
|
return hash.Sha256(msg), nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,20 +2,37 @@ package crypto
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/nspcc-dev/neo-go/pkg/core/interop"
|
"github.com/nspcc-dev/neo-go/pkg/core/interop"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/crypto"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/crypto/hash"
|
"github.com/nspcc-dev/neo-go/pkg/crypto/hash"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Sha256 returns sha256 hash of the data.
|
// Sha256 returns sha256 hash of the data.
|
||||||
func Sha256(ic *interop.Context) error {
|
func Sha256(ic *interop.Context) error {
|
||||||
msg := getMessage(ic, ic.VM.Estack().Pop().Item())
|
h, err := getMessageHash(ic, ic.VM.Estack().Pop().Item())
|
||||||
h := hash.Sha256(msg).BytesBE()
|
if err != nil {
|
||||||
ic.VM.Estack().PushVal(h)
|
return err
|
||||||
|
}
|
||||||
|
ic.VM.Estack().PushVal(h.BytesBE())
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// RipeMD160 returns RipeMD160 hash of the data.
|
// RipeMD160 returns RipeMD160 hash of the data.
|
||||||
func RipeMD160(ic *interop.Context) error {
|
func RipeMD160(ic *interop.Context) error {
|
||||||
msg := getMessage(ic, ic.VM.Estack().Pop().Item())
|
var msg []byte
|
||||||
|
|
||||||
|
item := ic.VM.Estack().Pop().Item()
|
||||||
|
switch val := item.(type) {
|
||||||
|
case *stackitem.Interop:
|
||||||
|
msg = val.Value().(crypto.Verifiable).GetSignedPart()
|
||||||
|
case stackitem.Null:
|
||||||
|
msg = ic.Container.GetSignedPart()
|
||||||
|
default:
|
||||||
|
var err error
|
||||||
|
if msg, err = val.TryBytes(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
h := hash.RipeMD160(msg).BytesBE()
|
h := hash.RipeMD160(msg).BytesBE()
|
||||||
ic.VM.Estack().PushVal(h)
|
ic.VM.Estack().PushVal(h)
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -216,7 +216,7 @@ func TestECDSAVerify(t *testing.T) {
|
||||||
|
|
||||||
t.Run("invalid message", func(t *testing.T) {
|
t.Run("invalid message", func(t *testing.T) {
|
||||||
sign := priv.Sign(msg)
|
sign := priv.Sign(msg)
|
||||||
runCase(t, false, false, sign, priv.PublicKey().Bytes(),
|
runCase(t, true, false, sign, priv.PublicKey().Bytes(),
|
||||||
stackitem.NewArray([]stackitem.Item{stackitem.NewByteArray(msg)}))
|
stackitem.NewArray([]stackitem.Item{stackitem.NewByteArray(msg)}))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -59,6 +59,13 @@ func (s *MPTRootBase) GetSignedPart() []byte {
|
||||||
return buf.Bytes()
|
return buf.Bytes()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetSignedHash returns hash of MPTRootBase which needs to be signed.
|
||||||
|
func (s *MPTRootBase) GetSignedHash() util.Uint256 {
|
||||||
|
buf := io.NewBufBinWriter()
|
||||||
|
s.EncodeBinary(buf.BinWriter)
|
||||||
|
return hash.Sha256(buf.Bytes())
|
||||||
|
}
|
||||||
|
|
||||||
// Equals checks if s == other.
|
// Equals checks if s == other.
|
||||||
func (s *MPTRootBase) Equals(other *MPTRootBase) bool {
|
func (s *MPTRootBase) Equals(other *MPTRootBase) bool {
|
||||||
return s.Version == other.Version && s.Index == other.Index &&
|
return s.Version == other.Version && s.Index == other.Index &&
|
||||||
|
|
|
@ -110,8 +110,8 @@ func (t *Transaction) Hash() util.Uint256 {
|
||||||
return t.hash
|
return t.hash
|
||||||
}
|
}
|
||||||
|
|
||||||
// VerificationHash returns the hash of the transaction used to verify it.
|
// GetSignedHash returns a hash of the transaction used to verify it.
|
||||||
func (t *Transaction) VerificationHash() util.Uint256 {
|
func (t *Transaction) GetSignedHash() util.Uint256 {
|
||||||
if t.verificationHash.Equals(util.Uint256{}) {
|
if t.verificationHash.Equals(util.Uint256{}) {
|
||||||
if t.createHash() != nil {
|
if t.createHash() != nil {
|
||||||
panic("failed to compute hash!")
|
panic("failed to compute hash!")
|
||||||
|
|
|
@ -128,15 +128,19 @@ func (p *PrivateKey) GetScriptHash() util.Uint160 {
|
||||||
return pk.GetScriptHash()
|
return pk.GetScriptHash()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sign signs arbitrary length data using the private key.
|
// Sign signs arbitrary length data using the private key. It uses SHA256 to
|
||||||
|
// calculate hash and then SignHash to create a signature (so you can save on
|
||||||
|
// hash calculation if you already have it).
|
||||||
func (p *PrivateKey) Sign(data []byte) []byte {
|
func (p *PrivateKey) Sign(data []byte) []byte {
|
||||||
var (
|
var digest = sha256.Sum256(data)
|
||||||
privateKey = &p.PrivateKey
|
|
||||||
digest = sha256.Sum256(data)
|
|
||||||
)
|
|
||||||
|
|
||||||
r, s := rfc6979.SignECDSA(privateKey, digest[:], sha256.New)
|
return p.SignHash(digest)
|
||||||
return getSignatureSlice(privateKey.Curve, r, s)
|
}
|
||||||
|
|
||||||
|
// SignHash signs particular hash the private key.
|
||||||
|
func (p *PrivateKey) SignHash(digest util.Uint256) []byte {
|
||||||
|
r, s := rfc6979.SignECDSA(&p.PrivateKey, digest[:], sha256.New)
|
||||||
|
return getSignatureSlice(p.PrivateKey.Curve, r, s)
|
||||||
}
|
}
|
||||||
|
|
||||||
func getSignatureSlice(curve elliptic.Curve, r, s *big.Int) []byte {
|
func getSignatureSlice(curve elliptic.Curve, r, s *big.Int) []byte {
|
||||||
|
|
|
@ -1,8 +1,11 @@
|
||||||
package crypto
|
package crypto
|
||||||
|
|
||||||
|
import "github.com/nspcc-dev/neo-go/pkg/util"
|
||||||
|
|
||||||
// Verifiable represents an object which can be verified.
|
// Verifiable represents an object which can be verified.
|
||||||
type Verifiable interface {
|
type Verifiable interface {
|
||||||
GetSignedPart() []byte
|
GetSignedPart() []byte
|
||||||
|
GetSignedHash() util.Uint256
|
||||||
}
|
}
|
||||||
|
|
||||||
// VerifiableDecodable represents an object which can be both verified and
|
// VerifiableDecodable represents an object which can be both verified and
|
||||||
|
|
Loading…
Reference in a new issue