Merge pull request #2110 from nspcc-dev/optimize-tx-decoding

Optimize tx decoding
This commit is contained in:
Roman Khimov 2021-08-05 13:43:11 +03:00 committed by GitHub
commit f685c49cb2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 112 additions and 14 deletions

View file

@ -0,0 +1,55 @@
package transaction
import (
"encoding/base64"
"testing"
"github.com/nspcc-dev/neo-go/pkg/io"
"github.com/stretchr/testify/require"
)
// Some typical transfer tx from mainnet.
var (
benchTx []byte
benchTxB64 = "AK9KzFu0P5gAAAAAAIjOEgAAAAAA7jAAAAGIDdjSt7aj2J+dktSobkC9j0/CJwEAWwsCAMLrCwwUtXfkIuockX9HAVMNeEuQMxMlYkMMFIgN2NK3tqPYn52S1KhuQL2PT8InFMAfDAh0cmFuc2ZlcgwUz3bii9AGLEpHjuNVYQETGfPPpNJBYn1bUjkBQgxAUiZNae4OTSu2EOGW+6fwslLIpVsczOAR9o6R796tFf2KG+nLzs709tCQ7NELZOQ7zUzfF19ADLvH/efNT4v9LygMIQNT96/wFdPSBO7NUI9Kpn9EffTRXsS6ZJ9PqRvbenijVEFW57Mn"
benchTxJSON []byte
)
func init() {
var err error
benchTx, err = base64.StdEncoding.DecodeString(benchTxB64)
if err != nil {
panic(err)
}
t, err := NewTransactionFromBytes(benchTx)
if err != nil {
panic(err)
}
benchTxJSON, err = t.MarshalJSON()
if err != nil {
panic(err)
}
}
func BenchmarkDecodeBinary(t *testing.B) {
for n := 0; n < t.N; n++ {
r := io.NewBinReaderFromBuf(benchTx)
tx := &Transaction{}
tx.DecodeBinary(r)
require.NoError(t, r.Err)
}
}
func BenchmarkDecodeJSON(t *testing.B) {
for n := 0; n < t.N; n++ {
tx := &Transaction{}
require.NoError(t, tx.UnmarshalJSON(benchTxJSON))
}
}
func BenchmarkDecodeFromBytes(t *testing.B) {
for n := 0; n < t.N; n++ {
_, err := NewTransactionFromBytes(benchTx)
require.NoError(t, err)
}
}

View file

@ -1,6 +1,7 @@
package transaction package transaction
import ( import (
"crypto/sha256"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@ -135,30 +136,66 @@ func (t *Transaction) decodeHashableFields(br *io.BinReader) {
t.SystemFee = int64(br.ReadU64LE()) t.SystemFee = int64(br.ReadU64LE())
t.NetworkFee = int64(br.ReadU64LE()) t.NetworkFee = int64(br.ReadU64LE())
t.ValidUntilBlock = br.ReadU32LE() t.ValidUntilBlock = br.ReadU32LE()
br.ReadArray(&t.Signers, MaxAttributes) nsigners := br.ReadVarUint()
br.ReadArray(&t.Attributes, MaxAttributes-len(t.Signers)) if br.Err != nil {
return
}
if nsigners > MaxAttributes {
br.Err = errors.New("too many signers")
return
} else if nsigners == 0 {
br.Err = errors.New("missing signers")
return
}
t.Signers = make([]Signer, nsigners)
for i := 0; i < int(nsigners); i++ {
t.Signers[i].DecodeBinary(br)
}
nattrs := br.ReadVarUint()
if nattrs > MaxAttributes-nsigners {
br.Err = errors.New("too many attributes")
return
}
t.Attributes = make([]Attribute, nattrs)
for i := 0; i < int(nattrs); i++ {
t.Attributes[i].DecodeBinary(br)
}
t.Script = br.ReadVarBytes(MaxScriptLength) t.Script = br.ReadVarBytes(MaxScriptLength)
if br.Err == nil { if br.Err == nil {
br.Err = t.isValid() br.Err = t.isValid()
} }
} }
// DecodeBinary implements Serializable interface. func (t *Transaction) decodeBinaryNoSize(br *io.BinReader) {
func (t *Transaction) DecodeBinary(br *io.BinReader) {
t.decodeHashableFields(br) t.decodeHashableFields(br)
if br.Err != nil { if br.Err != nil {
return return
} }
br.ReadArray(&t.Scripts, len(t.Signers)) nscripts := br.ReadVarUint()
if len(t.Signers) != len(t.Scripts) { if nscripts > MaxAttributes {
br.Err = errors.New("too many witnesses")
return
} else if int(nscripts) != len(t.Signers) {
br.Err = fmt.Errorf("%w: %d vs %d", ErrInvalidWitnessNum, len(t.Signers), len(t.Scripts)) br.Err = fmt.Errorf("%w: %d vs %d", ErrInvalidWitnessNum, len(t.Signers), len(t.Scripts))
return return
} }
t.Scripts = make([]Witness, nscripts)
for i := 0; i < int(nscripts); i++ {
t.Scripts[i].DecodeBinary(br)
}
// Create the hash of the transaction at decode, so we dont need // Create the hash of the transaction at decode, so we dont need
// to do it anymore. // to do it anymore.
if br.Err == nil { if br.Err == nil {
br.Err = t.createHash() br.Err = t.createHash()
}
}
// DecodeBinary implements Serializable interface.
func (t *Transaction) DecodeBinary(br *io.BinReader) {
t.decodeBinaryNoSize(br)
if br.Err == nil {
_ = t.Size() _ = t.Size()
} }
} }
@ -198,13 +235,14 @@ func (t *Transaction) EncodeHashableFields() ([]byte, error) {
// createHash creates the hash of the transaction. // createHash creates the hash of the transaction.
func (t *Transaction) createHash() error { func (t *Transaction) createHash() error {
buf := io.NewBufBinWriter() shaHash := sha256.New()
t.encodeHashableFields(buf.BinWriter) bw := io.NewBinWriterFromIO(shaHash)
if buf.Err != nil { t.encodeHashableFields(bw)
return buf.Err if bw.Err != nil {
return bw.Err
} }
t.hash = hash.Sha256(buf.Bytes()) shaHash.Sum(t.hash[:0])
return nil return nil
} }
@ -240,7 +278,7 @@ func (t *Transaction) Bytes() []byte {
func NewTransactionFromBytes(b []byte) (*Transaction, error) { func NewTransactionFromBytes(b []byte) (*Transaction, error) {
tx := &Transaction{} tx := &Transaction{}
r := io.NewBinReaderFromBuf(b) r := io.NewBinReaderFromBuf(b)
tx.DecodeBinary(r) tx.decodeBinaryNoSize(r)
if r.Err != nil { if r.Err != nil {
return nil, r.Err return nil, r.Err
} }

View file

@ -130,7 +130,6 @@ func (m *Message) decodePayload() error {
buf = d buf = d
} }
r := io.NewBinReaderFromBuf(buf)
var p payload.Payload var p payload.Payload
switch m.Command { switch m.Command {
case CMDVersion: case CMDVersion:
@ -154,7 +153,12 @@ func (m *Message) decodePayload() error {
case CMDHeaders: case CMDHeaders:
p = &payload.Headers{StateRootInHeader: m.StateRootInHeader} p = &payload.Headers{StateRootInHeader: m.StateRootInHeader}
case CMDTX: case CMDTX:
p = &transaction.Transaction{} p, err := transaction.NewTransactionFromBytes(buf)
if err != nil {
return err
}
m.Payload = p
return nil
case CMDMerkleBlock: case CMDMerkleBlock:
p = &payload.MerkleBlock{} p = &payload.MerkleBlock{}
case CMDPing, CMDPong: case CMDPing, CMDPong:
@ -164,6 +168,7 @@ func (m *Message) decodePayload() error {
default: default:
return fmt.Errorf("can't decode command %s", m.Command.String()) return fmt.Errorf("can't decode command %s", m.Command.String())
} }
r := io.NewBinReaderFromBuf(buf)
p.DecodeBinary(r) p.DecodeBinary(r)
if r.Err == nil || r.Err == payload.ErrTooManyHeaders { if r.Err == nil || r.Err == payload.ErrTooManyHeaders {
m.Payload = p m.Payload = p