diff --git a/pkg/core/transaction/transaction.go b/pkg/core/transaction/transaction.go index 012ceb142..f4c267b0e 100644 --- a/pkg/core/transaction/transaction.go +++ b/pkg/core/transaction/transaction.go @@ -130,7 +130,12 @@ func (t *Transaction) GetAttributes(typ AttrType) []Attribute { // decodeHashableFields decodes the fields that are used for signing the // transaction, which are all fields except the scripts. -func (t *Transaction) decodeHashableFields(br *io.BinReader) { +func (t *Transaction) decodeHashableFields(br *io.BinReader, buf []byte) { + var start, end int + + if buf != nil { + start = len(buf) - br.Len() + } t.Version = uint8(br.ReadB()) t.Nonce = br.ReadU32LE() t.SystemFee = int64(br.ReadU64LE()) @@ -164,10 +169,14 @@ func (t *Transaction) decodeHashableFields(br *io.BinReader) { if br.Err == nil { br.Err = t.isValid() } + if buf != nil { + end = len(buf) - br.Len() + t.hash = hash.Sha256(buf[start:end]) + } } -func (t *Transaction) decodeBinaryNoSize(br *io.BinReader) { - t.decodeHashableFields(br) +func (t *Transaction) decodeBinaryNoSize(br *io.BinReader, buf []byte) { + t.decodeHashableFields(br, buf) if br.Err != nil { return } @@ -186,14 +195,14 @@ func (t *Transaction) decodeBinaryNoSize(br *io.BinReader) { // Create the hash of the transaction at decode, so we dont need // to do it anymore. - if br.Err == nil { + if br.Err == nil && buf == nil { br.Err = t.createHash() } } // DecodeBinary implements Serializable interface. func (t *Transaction) DecodeBinary(br *io.BinReader) { - t.decodeBinaryNoSize(br) + t.decodeBinaryNoSize(br, nil) if br.Err == nil { _ = t.Size() @@ -258,18 +267,15 @@ func (t *Transaction) createHash() error { // DecodeHashableFields decodes a part of transaction which should be hashed. func (t *Transaction) DecodeHashableFields(buf []byte) error { r := io.NewBinReaderFromBuf(buf) - t.decodeHashableFields(r) + t.decodeHashableFields(r, buf) if r.Err != nil { return r.Err } // Ensure all the data was read. - _ = r.ReadB() - if r.Err == nil { + if r.Len() != 0 { return errors.New("additional data after the signed part") } t.Scripts = make([]Witness, 0) - - t.hash = hash.Sha256(buf) return nil } @@ -287,12 +293,11 @@ func (t *Transaction) Bytes() []byte { func NewTransactionFromBytes(b []byte) (*Transaction, error) { tx := &Transaction{} r := io.NewBinReaderFromBuf(b) - tx.decodeBinaryNoSize(r) + tx.decodeBinaryNoSize(r, b) if r.Err != nil { return nil, r.Err } - _ = r.ReadB() - if r.Err == nil { + if r.Len() != 0 { return nil, errors.New("additional data after the transaction") } tx.size = len(b) diff --git a/pkg/core/transaction/transaction_test.go b/pkg/core/transaction/transaction_test.go index 62a46123c..52318fdd8 100644 --- a/pkg/core/transaction/transaction_test.go +++ b/pkg/core/transaction/transaction_test.go @@ -97,6 +97,11 @@ func TestNewTransactionFromBytes(t *testing.T) { require.NoError(t, err) require.Equal(t, tx, tx1) + tx2 := new(Transaction) + err = testserdes.DecodeBinary(data, tx2) + require.NoError(t, err) + require.Equal(t, tx1, tx2) + data = append(data, 42) _, err = NewTransactionFromBytes(data) require.Error(t, err) diff --git a/pkg/io/binaryReader.go b/pkg/io/binaryReader.go index fdcb08249..e31921400 100644 --- a/pkg/io/binaryReader.go +++ b/pkg/io/binaryReader.go @@ -16,20 +16,13 @@ const MaxArraySize = 0x1000000 // Used to simplify error handling when reading into a struct with many fields. type BinReader struct { r io.Reader - u64 []byte - u32 []byte - u16 []byte - u8 []byte + uv [8]byte Err error } // NewBinReaderFromIO makes a BinReader from io.Reader. func NewBinReaderFromIO(ior io.Reader) *BinReader { - u64 := make([]byte, 8) - u32 := u64[:4] - u16 := u64[:2] - u8 := u64[:1] - return &BinReader{r: ior, u64: u64, u32: u32, u16: u16, u8: u8} + return &BinReader{r: ior} } // NewBinReaderFromBuf makes a BinReader from byte buffer. @@ -38,54 +31,65 @@ func NewBinReaderFromBuf(b []byte) *BinReader { return NewBinReaderFromIO(r) } +// Len returns the number of bytes of the unread portion of the buffer if +// reading from bytes.Reader or -1 otherwise. +func (r *BinReader) Len() int { + var res = -1 + byteReader, ok := r.r.(*bytes.Reader) + if ok { + res = byteReader.Len() + } + return res +} + // ReadU64LE reads a little-endian encoded uint64 value from the underlying // io.Reader. On read failures it returns zero. func (r *BinReader) ReadU64LE() uint64 { - r.ReadBytes(r.u64) + r.ReadBytes(r.uv[:8]) if r.Err != nil { return 0 } - return binary.LittleEndian.Uint64(r.u64) + return binary.LittleEndian.Uint64(r.uv[:8]) } // ReadU32LE reads a little-endian encoded uint32 value from the underlying // io.Reader. On read failures it returns zero. func (r *BinReader) ReadU32LE() uint32 { - r.ReadBytes(r.u32) + r.ReadBytes(r.uv[:4]) if r.Err != nil { return 0 } - return binary.LittleEndian.Uint32(r.u32) + return binary.LittleEndian.Uint32(r.uv[:4]) } // ReadU16LE reads a little-endian encoded uint16 value from the underlying // io.Reader. On read failures it returns zero. func (r *BinReader) ReadU16LE() uint16 { - r.ReadBytes(r.u16) + r.ReadBytes(r.uv[:2]) if r.Err != nil { return 0 } - return binary.LittleEndian.Uint16(r.u16) + return binary.LittleEndian.Uint16(r.uv[:2]) } // ReadU16BE reads a big-endian encoded uint16 value from the underlying // io.Reader. On read failures it returns zero. func (r *BinReader) ReadU16BE() uint16 { - r.ReadBytes(r.u16) + r.ReadBytes(r.uv[:2]) if r.Err != nil { return 0 } - return binary.BigEndian.Uint16(r.u16) + return binary.BigEndian.Uint16(r.uv[:2]) } // ReadB reads a byte from the underlying io.Reader. On read failures it // returns zero. func (r *BinReader) ReadB() byte { - r.ReadBytes(r.u8) + r.ReadBytes(r.uv[:1]) if r.Err != nil { return 0 } - return r.u8[0] + return r.uv[0] } // ReadBool reads a boolean value encoded in a zero/non-zero byte from the diff --git a/pkg/network/payload/notary_request.go b/pkg/network/payload/notary_request.go index 0d2d67b5c..1222c0263 100644 --- a/pkg/network/payload/notary_request.go +++ b/pkg/network/payload/notary_request.go @@ -29,8 +29,7 @@ func NewP2PNotaryRequestFromBytes(b []byte) (*P2PNotaryRequest, error) { if br.Err != nil { return nil, br.Err } - _ = br.ReadB() - if br.Err == nil { + if br.Len() != 0 { return nil, errors.New("additional data after the payload") } return req, nil