diff --git a/pkg/core/transaction/bench_test.go b/pkg/core/transaction/bench_test.go new file mode 100644 index 000000000..1cf85a1cb --- /dev/null +++ b/pkg/core/transaction/bench_test.go @@ -0,0 +1,55 @@ +package transaction + +import ( + "encoding/base64" + "testing" + + "github.com/nspcc-dev/neo-go/pkg/io" + "github.com/stretchr/testify/require" +) + +// Some typical transfer tx from mainnet. +var ( + benchTx []byte + benchTxB64 = "AK9KzFu0P5gAAAAAAIjOEgAAAAAA7jAAAAGIDdjSt7aj2J+dktSobkC9j0/CJwEAWwsCAMLrCwwUtXfkIuockX9HAVMNeEuQMxMlYkMMFIgN2NK3tqPYn52S1KhuQL2PT8InFMAfDAh0cmFuc2ZlcgwUz3bii9AGLEpHjuNVYQETGfPPpNJBYn1bUjkBQgxAUiZNae4OTSu2EOGW+6fwslLIpVsczOAR9o6R796tFf2KG+nLzs709tCQ7NELZOQ7zUzfF19ADLvH/efNT4v9LygMIQNT96/wFdPSBO7NUI9Kpn9EffTRXsS6ZJ9PqRvbenijVEFW57Mn" + benchTxJSON []byte +) + +func init() { + var err error + benchTx, err = base64.StdEncoding.DecodeString(benchTxB64) + if err != nil { + panic(err) + } + t, err := NewTransactionFromBytes(benchTx) + if err != nil { + panic(err) + } + benchTxJSON, err = t.MarshalJSON() + if err != nil { + panic(err) + } +} + +func BenchmarkDecodeBinary(t *testing.B) { + for n := 0; n < t.N; n++ { + r := io.NewBinReaderFromBuf(benchTx) + tx := &Transaction{} + tx.DecodeBinary(r) + require.NoError(t, r.Err) + } +} + +func BenchmarkDecodeJSON(t *testing.B) { + for n := 0; n < t.N; n++ { + tx := &Transaction{} + require.NoError(t, tx.UnmarshalJSON(benchTxJSON)) + } +} + +func BenchmarkDecodeFromBytes(t *testing.B) { + for n := 0; n < t.N; n++ { + _, err := NewTransactionFromBytes(benchTx) + require.NoError(t, err) + } +} diff --git a/pkg/core/transaction/transaction.go b/pkg/core/transaction/transaction.go index 25809657c..7dd9b8ccf 100644 --- a/pkg/core/transaction/transaction.go +++ b/pkg/core/transaction/transaction.go @@ -1,6 +1,7 @@ package transaction import ( + "crypto/sha256" "encoding/json" "errors" "fmt" @@ -135,30 +136,66 @@ func (t *Transaction) decodeHashableFields(br *io.BinReader) { t.SystemFee = int64(br.ReadU64LE()) t.NetworkFee = int64(br.ReadU64LE()) t.ValidUntilBlock = br.ReadU32LE() - br.ReadArray(&t.Signers, MaxAttributes) - br.ReadArray(&t.Attributes, MaxAttributes-len(t.Signers)) + nsigners := br.ReadVarUint() + if br.Err != nil { + return + } + if nsigners > MaxAttributes { + br.Err = errors.New("too many signers") + return + } else if nsigners == 0 { + br.Err = errors.New("missing signers") + return + } + t.Signers = make([]Signer, nsigners) + for i := 0; i < int(nsigners); i++ { + t.Signers[i].DecodeBinary(br) + } + nattrs := br.ReadVarUint() + if nattrs > MaxAttributes-nsigners { + br.Err = errors.New("too many attributes") + return + } + t.Attributes = make([]Attribute, nattrs) + for i := 0; i < int(nattrs); i++ { + t.Attributes[i].DecodeBinary(br) + } t.Script = br.ReadVarBytes(MaxScriptLength) if br.Err == nil { br.Err = t.isValid() } } -// DecodeBinary implements Serializable interface. -func (t *Transaction) DecodeBinary(br *io.BinReader) { +func (t *Transaction) decodeBinaryNoSize(br *io.BinReader) { t.decodeHashableFields(br) if br.Err != nil { return } - br.ReadArray(&t.Scripts, len(t.Signers)) - if len(t.Signers) != len(t.Scripts) { + nscripts := br.ReadVarUint() + if nscripts > MaxAttributes { + br.Err = errors.New("too many witnesses") + return + } else if int(nscripts) != len(t.Signers) { br.Err = fmt.Errorf("%w: %d vs %d", ErrInvalidWitnessNum, len(t.Signers), len(t.Scripts)) return } + t.Scripts = make([]Witness, nscripts) + for i := 0; i < int(nscripts); i++ { + t.Scripts[i].DecodeBinary(br) + } // Create the hash of the transaction at decode, so we dont need // to do it anymore. if br.Err == nil { br.Err = t.createHash() + } +} + +// DecodeBinary implements Serializable interface. +func (t *Transaction) DecodeBinary(br *io.BinReader) { + t.decodeBinaryNoSize(br) + + if br.Err == nil { _ = t.Size() } } @@ -198,13 +235,14 @@ func (t *Transaction) EncodeHashableFields() ([]byte, error) { // createHash creates the hash of the transaction. func (t *Transaction) createHash() error { - buf := io.NewBufBinWriter() - t.encodeHashableFields(buf.BinWriter) - if buf.Err != nil { - return buf.Err + shaHash := sha256.New() + bw := io.NewBinWriterFromIO(shaHash) + t.encodeHashableFields(bw) + if bw.Err != nil { + return bw.Err } - t.hash = hash.Sha256(buf.Bytes()) + shaHash.Sum(t.hash[:0]) return nil } @@ -240,7 +278,7 @@ func (t *Transaction) Bytes() []byte { func NewTransactionFromBytes(b []byte) (*Transaction, error) { tx := &Transaction{} r := io.NewBinReaderFromBuf(b) - tx.DecodeBinary(r) + tx.decodeBinaryNoSize(r) if r.Err != nil { return nil, r.Err } diff --git a/pkg/network/message.go b/pkg/network/message.go index f2892aae9..5eb360d9d 100644 --- a/pkg/network/message.go +++ b/pkg/network/message.go @@ -130,7 +130,6 @@ func (m *Message) decodePayload() error { buf = d } - r := io.NewBinReaderFromBuf(buf) var p payload.Payload switch m.Command { case CMDVersion: @@ -154,7 +153,12 @@ func (m *Message) decodePayload() error { case CMDHeaders: p = &payload.Headers{StateRootInHeader: m.StateRootInHeader} case CMDTX: - p = &transaction.Transaction{} + p, err := transaction.NewTransactionFromBytes(buf) + if err != nil { + return err + } + m.Payload = p + return nil case CMDMerkleBlock: p = &payload.MerkleBlock{} case CMDPing, CMDPong: @@ -164,6 +168,7 @@ func (m *Message) decodePayload() error { default: return fmt.Errorf("can't decode command %s", m.Command.String()) } + r := io.NewBinReaderFromBuf(buf) p.DecodeBinary(r) if r.Err == nil || r.Err == payload.ErrTooManyHeaders { m.Payload = p