From aacf58c9ab02388d3fb2e08c139be215c70e348d Mon Sep 17 00:00:00 2001 From: Roman Khimov Date: Sun, 15 Sep 2019 14:58:19 +0300 Subject: [PATCH] util: add 'constructors' for BinReader/BinWriter And an additional BufBinWriter to ease buffer management. --- pkg/core/account_state.go | 4 +-- pkg/core/asset_state.go | 4 +-- pkg/core/block.go | 8 ++--- pkg/core/block_base.go | 8 ++--- pkg/core/header_hash_list.go | 2 +- pkg/core/spent_coin_state.go | 4 +-- pkg/core/storage/helpers.go | 2 +- pkg/core/transaction/attribute.go | 4 +-- pkg/core/transaction/claim.go | 4 +-- pkg/core/transaction/input.go | 4 +-- pkg/core/transaction/invocation.go | 4 +-- pkg/core/transaction/output.go | 4 +-- pkg/core/transaction/publish.go | 4 +-- pkg/core/transaction/register.go | 4 +-- pkg/core/transaction/state.go | 4 +-- pkg/core/transaction/state_descriptor.go | 4 +-- pkg/core/transaction/transaction.go | 6 ++-- pkg/core/transaction/witness.go | 4 +-- pkg/core/unspent_coin_state.go | 4 +-- pkg/network/message.go | 4 +-- pkg/network/payload/address.go | 8 ++--- pkg/network/payload/getblocks.go | 4 +-- pkg/network/payload/headers.go | 4 +-- pkg/network/payload/inventory.go | 4 +-- pkg/network/payload/merkleblock.go | 2 +- pkg/network/payload/version.go | 4 +-- pkg/smartcontract/contract_test.go | 4 +-- pkg/util/binaryBufWriter.go | 29 +++++++++++++++++++ pkg/util/binaryReader.go | 26 ++++++++++++----- pkg/util/binaryWriter.go | 25 +++++++++------- pkg/util/binaryrw_test.go | 37 ++++++++++++------------ 31 files changed, 138 insertions(+), 95 deletions(-) create mode 100644 pkg/util/binaryBufWriter.go diff --git a/pkg/core/account_state.go b/pkg/core/account_state.go index 44d277e76..7325b1b4d 100644 --- a/pkg/core/account_state.go +++ b/pkg/core/account_state.go @@ -68,7 +68,7 @@ func NewAccountState(scriptHash util.Uint160) *AccountState { // DecodeBinary decodes AccountState from the given io.Reader. func (s *AccountState) DecodeBinary(r io.Reader) error { - br := util.BinReader{R: r} + br := util.NewBinReaderFromIO(r) br.ReadLE(&s.Version) br.ReadLE(&s.ScriptHash) br.ReadLE(&s.IsFrozen) @@ -96,7 +96,7 @@ func (s *AccountState) DecodeBinary(r io.Reader) error { // EncodeBinary encode AccountState to the given io.Writer. func (s *AccountState) EncodeBinary(w io.Writer) error { - bw := util.BinWriter{W: w} + bw := util.NewBinWriterFromIO(w) bw.WriteLE(s.Version) bw.WriteLE(s.ScriptHash) bw.WriteLE(s.IsFrozen) diff --git a/pkg/core/asset_state.go b/pkg/core/asset_state.go index 19f7db8c3..02707eee9 100644 --- a/pkg/core/asset_state.go +++ b/pkg/core/asset_state.go @@ -47,7 +47,7 @@ type AssetState struct { // DecodeBinary implements the Payload interface. func (a *AssetState) DecodeBinary(r io.Reader) error { - br := util.BinReader{R: r} + br := util.NewBinReaderFromIO(r) br.ReadLE(&a.ID) br.ReadLE(&a.AssetType) @@ -76,7 +76,7 @@ func (a *AssetState) DecodeBinary(r io.Reader) error { // EncodeBinary implements the Payload interface. func (a *AssetState) EncodeBinary(w io.Writer) error { - bw := util.BinWriter{W: w} + bw := util.NewBinWriterFromIO(w) bw.WriteLE(a.ID) bw.WriteLE(a.AssetType) bw.WriteString(a.Name) diff --git a/pkg/core/block.go b/pkg/core/block.go index 38d8a9122..e17483e4d 100644 --- a/pkg/core/block.go +++ b/pkg/core/block.go @@ -78,7 +78,7 @@ func NewBlockFromTrimmedBytes(b []byte) (*Block, error) { return block, err } - br := util.BinReader{R: r} + br := util.NewBinReaderFromIO(r) var padding uint8 br.ReadLE(&padding) if br.Err != nil { @@ -109,7 +109,7 @@ func (b *Block) Trim() ([]byte, error) { if err := b.encodeHashableFields(buf); err != nil { return nil, err } - bw := util.BinWriter{W: buf} + bw := util.NewBinWriterFromIO(buf) bw.WriteLE(uint8(1)) if bw.Err != nil { return nil, bw.Err @@ -134,7 +134,7 @@ func (b *Block) DecodeBinary(r io.Reader) error { return err } - br := util.BinReader{R: r} + br := util.NewBinReaderFromIO(r) lentx := br.ReadVarUint() if br.Err != nil { return br.Err @@ -156,7 +156,7 @@ func (b *Block) EncodeBinary(w io.Writer) error { if err != nil { return err } - bw := util.BinWriter{W: w} + bw := util.NewBinWriterFromIO(w) bw.WriteVarUint(uint64(len(b.Transactions))) if bw.Err != nil { return err diff --git a/pkg/core/block_base.go b/pkg/core/block_base.go index 00fd4d277..75f769340 100644 --- a/pkg/core/block_base.go +++ b/pkg/core/block_base.go @@ -66,7 +66,7 @@ func (b *BlockBase) DecodeBinary(r io.Reader) error { } var padding uint8 - br := util.BinReader{R: r} + br := util.NewBinReaderFromIO(r) br.ReadLE(&padding) if br.Err != nil { return br.Err @@ -84,7 +84,7 @@ func (b *BlockBase) EncodeBinary(w io.Writer) error { if err := b.encodeHashableFields(w); err != nil { return err } - bw := util.BinWriter{W: w} + bw := util.NewBinWriterFromIO(w) bw.WriteLE(uint8(1)) if bw.Err != nil { return bw.Err @@ -111,7 +111,7 @@ func (b *BlockBase) createHash() error { // encodeHashableFields will only encode the fields used for hashing. // see Hash() for more information about the fields. func (b *BlockBase) encodeHashableFields(w io.Writer) error { - bw := util.BinWriter{W: w} + bw := util.NewBinWriterFromIO(w) bw.WriteLE(b.Version) bw.WriteLE(b.PrevHash) bw.WriteLE(b.MerkleRoot) @@ -125,7 +125,7 @@ func (b *BlockBase) encodeHashableFields(w io.Writer) error { // decodeHashableFields will only decode the fields used for hashing. // see Hash() for more information about the fields. func (b *BlockBase) decodeHashableFields(r io.Reader) error { - br := util.BinReader{R: r} + br := util.NewBinReaderFromIO(r) br.ReadLE(&b.Version) br.ReadLE(&b.PrevHash) br.ReadLE(&b.MerkleRoot) diff --git a/pkg/core/header_hash_list.go b/pkg/core/header_hash_list.go index 35c190afa..d54723aba 100644 --- a/pkg/core/header_hash_list.go +++ b/pkg/core/header_hash_list.go @@ -59,7 +59,7 @@ func (l *HeaderHashList) Slice(start, end int) []util.Uint256 { // WriteTo will write n underlying hashes to the given io.Writer // starting from start. func (l *HeaderHashList) Write(w io.Writer, start, n int) error { - bw := util.BinWriter{W: w} + bw := util.NewBinWriterFromIO(w) bw.WriteVarUint(uint64(n)) hashes := l.Slice(start, start+n) for _, hash := range hashes { diff --git a/pkg/core/spent_coin_state.go b/pkg/core/spent_coin_state.go index fcfde1ae7..594fd3756 100644 --- a/pkg/core/spent_coin_state.go +++ b/pkg/core/spent_coin_state.go @@ -67,7 +67,7 @@ func NewSpentCoinState(hash util.Uint256, height uint32) *SpentCoinState { // DecodeBinary implements the Payload interface. func (s *SpentCoinState) DecodeBinary(r io.Reader) error { - br := util.BinReader{R: r} + br := util.NewBinReaderFromIO(r) br.ReadLE(&s.txHash) br.ReadLE(&s.txHeight) @@ -87,7 +87,7 @@ func (s *SpentCoinState) DecodeBinary(r io.Reader) error { // EncodeBinary implements the Payload interface. func (s *SpentCoinState) EncodeBinary(w io.Writer) error { - bw := util.BinWriter{W: w} + bw := util.NewBinWriterFromIO(w) bw.WriteLE(s.txHash) bw.WriteLE(s.txHeight) bw.WriteVarUint(uint64(len(s.items))) diff --git a/pkg/core/storage/helpers.go b/pkg/core/storage/helpers.go index 4a35d9ad7..ca3c1a7c2 100644 --- a/pkg/core/storage/helpers.go +++ b/pkg/core/storage/helpers.go @@ -84,7 +84,7 @@ func HeaderHashes(s Store) ([]util.Uint256, error) { // the given byte array. func read2000Uint256Hashes(b []byte) ([]util.Uint256, error) { r := bytes.NewReader(b) - br := util.BinReader{R: r} + br := util.NewBinReaderFromIO(r) lenHashes := br.ReadVarUint() hashes := make([]util.Uint256, lenHashes) br.ReadLE(hashes) diff --git a/pkg/core/transaction/attribute.go b/pkg/core/transaction/attribute.go index 1b3ca0c3b..c5df619f3 100644 --- a/pkg/core/transaction/attribute.go +++ b/pkg/core/transaction/attribute.go @@ -17,7 +17,7 @@ type Attribute struct { // DecodeBinary implements the Payload interface. func (attr *Attribute) DecodeBinary(r io.Reader) error { - br := util.BinReader{R: r} + br := util.NewBinReaderFromIO(r) br.ReadLE(&attr.Usage) // very special case @@ -54,7 +54,7 @@ func (attr *Attribute) DecodeBinary(r io.Reader) error { // EncodeBinary implements the Payload interface. func (attr *Attribute) EncodeBinary(w io.Writer) error { - bw := util.BinWriter{W: w} + bw := util.NewBinWriterFromIO(w) bw.WriteLE(&attr.Usage) switch attr.Usage { case ECDH02, ECDH03: diff --git a/pkg/core/transaction/claim.go b/pkg/core/transaction/claim.go index 468ce0f19..f5c7c8e8d 100644 --- a/pkg/core/transaction/claim.go +++ b/pkg/core/transaction/claim.go @@ -13,7 +13,7 @@ type ClaimTX struct { // DecodeBinary implements the Payload interface. func (tx *ClaimTX) DecodeBinary(r io.Reader) error { - br := util.BinReader{R: r} + br := util.NewBinReaderFromIO(r) lenClaims := br.ReadVarUint() if br.Err != nil { return br.Err @@ -30,7 +30,7 @@ func (tx *ClaimTX) DecodeBinary(r io.Reader) error { // EncodeBinary implements the Payload interface. func (tx *ClaimTX) EncodeBinary(w io.Writer) error { - bw := util.BinWriter{W: w} + bw := util.NewBinWriterFromIO(w) bw.WriteVarUint(uint64(len(tx.Claims))) if bw.Err != nil { return bw.Err diff --git a/pkg/core/transaction/input.go b/pkg/core/transaction/input.go index ab923981f..f2bd3e4df 100644 --- a/pkg/core/transaction/input.go +++ b/pkg/core/transaction/input.go @@ -17,7 +17,7 @@ type Input struct { // DecodeBinary implements the Payload interface. func (in *Input) DecodeBinary(r io.Reader) error { - br := util.BinReader{R: r} + br := util.NewBinReaderFromIO(r) br.ReadLE(&in.PrevHash) br.ReadLE(&in.PrevIndex) return br.Err @@ -25,7 +25,7 @@ func (in *Input) DecodeBinary(r io.Reader) error { // EncodeBinary implements the Payload interface. func (in *Input) EncodeBinary(w io.Writer) error { - bw := util.BinWriter{W: w} + bw := util.NewBinWriterFromIO(w) bw.WriteLE(in.PrevHash) bw.WriteLE(in.PrevIndex) return bw.Err diff --git a/pkg/core/transaction/invocation.go b/pkg/core/transaction/invocation.go index 0b255426c..36df0a948 100644 --- a/pkg/core/transaction/invocation.go +++ b/pkg/core/transaction/invocation.go @@ -34,7 +34,7 @@ func NewInvocationTX(script []byte) *Transaction { // DecodeBinary implements the Payload interface. func (tx *InvocationTX) DecodeBinary(r io.Reader) error { - br := util.BinReader{R: r} + br := util.NewBinReaderFromIO(r) tx.Script = br.ReadBytes() if tx.Version >= 1 { br.ReadLE(&tx.Gas) @@ -46,7 +46,7 @@ func (tx *InvocationTX) DecodeBinary(r io.Reader) error { // EncodeBinary implements the Payload interface. func (tx *InvocationTX) EncodeBinary(w io.Writer) error { - bw := util.BinWriter{W: w} + bw := util.NewBinWriterFromIO(w) bw.WriteBytes(tx.Script) if tx.Version >= 1 { bw.WriteLE(tx.Gas) diff --git a/pkg/core/transaction/output.go b/pkg/core/transaction/output.go index 60a55b5cf..1afa21e5a 100644 --- a/pkg/core/transaction/output.go +++ b/pkg/core/transaction/output.go @@ -35,7 +35,7 @@ func NewOutput(assetID util.Uint256, amount util.Fixed8, scriptHash util.Uint160 // DecodeBinary implements the Payload interface. func (out *Output) DecodeBinary(r io.Reader) error { - br := util.BinReader{R: r} + br := util.NewBinReaderFromIO(r) br.ReadLE(&out.AssetID) br.ReadLE(&out.Amount) br.ReadLE(&out.ScriptHash) @@ -44,7 +44,7 @@ func (out *Output) DecodeBinary(r io.Reader) error { // EncodeBinary implements the Payload interface. func (out *Output) EncodeBinary(w io.Writer) error { - bw := util.BinWriter{W: w} + bw := util.NewBinWriterFromIO(w) bw.WriteLE(out.AssetID) bw.WriteLE(out.Amount) bw.WriteLE(out.ScriptHash) diff --git a/pkg/core/transaction/publish.go b/pkg/core/transaction/publish.go index 3a3c735f3..4d20de820 100644 --- a/pkg/core/transaction/publish.go +++ b/pkg/core/transaction/publish.go @@ -24,7 +24,7 @@ type PublishTX struct { // DecodeBinary implements the Payload interface. func (tx *PublishTX) DecodeBinary(r io.Reader) error { - br := util.BinReader{R: r} + br := util.NewBinReaderFromIO(r) tx.Script = br.ReadBytes() lenParams := br.ReadVarUint() @@ -56,7 +56,7 @@ func (tx *PublishTX) DecodeBinary(r io.Reader) error { // EncodeBinary implements the Payload interface. func (tx *PublishTX) EncodeBinary(w io.Writer) error { - bw := util.BinWriter{W: w} + bw := util.NewBinWriterFromIO(w) bw.WriteBytes(tx.Script) bw.WriteVarUint(uint64(len(tx.ParamList))) for _, param := range tx.ParamList { diff --git a/pkg/core/transaction/register.go b/pkg/core/transaction/register.go index 96e5e1c6c..41ed9ca75 100644 --- a/pkg/core/transaction/register.go +++ b/pkg/core/transaction/register.go @@ -31,7 +31,7 @@ type RegisterTX struct { // DecodeBinary implements the Payload interface. func (tx *RegisterTX) DecodeBinary(r io.Reader) error { - br := util.BinReader{R: r} + br := util.NewBinReaderFromIO(r) br.ReadLE(&tx.AssetType) tx.Name = br.ReadString() @@ -53,7 +53,7 @@ func (tx *RegisterTX) DecodeBinary(r io.Reader) error { // EncodeBinary implements the Payload interface. func (tx *RegisterTX) EncodeBinary(w io.Writer) error { - bw := util.BinWriter{W: w} + bw := util.NewBinWriterFromIO(w) bw.WriteLE(tx.AssetType) bw.WriteString(tx.Name) bw.WriteLE(tx.Amount) diff --git a/pkg/core/transaction/state.go b/pkg/core/transaction/state.go index 5ee0c4a0d..79127d14d 100644 --- a/pkg/core/transaction/state.go +++ b/pkg/core/transaction/state.go @@ -13,7 +13,7 @@ type StateTX struct { // DecodeBinary implements the Payload interface. func (tx *StateTX) DecodeBinary(r io.Reader) error { - br := util.BinReader{R: r} + br := util.NewBinReaderFromIO(r) lenDesc := br.ReadVarUint() if br.Err != nil { return br.Err @@ -30,7 +30,7 @@ func (tx *StateTX) DecodeBinary(r io.Reader) error { // EncodeBinary implements the Payload interface. func (tx *StateTX) EncodeBinary(w io.Writer) error { - bw := util.BinWriter{W: w} + bw := util.NewBinWriterFromIO(w) bw.WriteVarUint(uint64(len(tx.Descriptors))) if bw.Err != nil { return bw.Err diff --git a/pkg/core/transaction/state_descriptor.go b/pkg/core/transaction/state_descriptor.go index 80a30ab04..2ad49c494 100644 --- a/pkg/core/transaction/state_descriptor.go +++ b/pkg/core/transaction/state_descriptor.go @@ -25,7 +25,7 @@ type StateDescriptor struct { // DecodeBinary implements the Payload interface. func (s *StateDescriptor) DecodeBinary(r io.Reader) error { - br := util.BinReader{R: r} + br := util.NewBinReaderFromIO(r) br.ReadLE(&s.Type) s.Key = br.ReadBytes() @@ -37,7 +37,7 @@ func (s *StateDescriptor) DecodeBinary(r io.Reader) error { // EncodeBinary implements the Payload interface. func (s *StateDescriptor) EncodeBinary(w io.Writer) error { - bw := util.BinWriter{W: w} + bw := util.NewBinWriterFromIO(w) bw.WriteLE(s.Type) bw.WriteBytes(s.Key) bw.WriteBytes(s.Value) diff --git a/pkg/core/transaction/transaction.go b/pkg/core/transaction/transaction.go index 90efcc777..8dedd17c1 100644 --- a/pkg/core/transaction/transaction.go +++ b/pkg/core/transaction/transaction.go @@ -78,7 +78,7 @@ func (t *Transaction) AddInput(in *Input) { // DecodeBinary implements the payload interface. func (t *Transaction) DecodeBinary(r io.Reader) error { - br := util.BinReader{R: r} + br := util.NewBinReaderFromIO(r) br.ReadLE(&t.Type) br.ReadLE(&t.Version) if br.Err != nil { @@ -173,7 +173,7 @@ func (t *Transaction) EncodeBinary(w io.Writer) error { if err := t.encodeHashableFields(w); err != nil { return err } - bw := util.BinWriter{W: w} + bw := util.NewBinWriterFromIO(w) bw.WriteVarUint(uint64(len(t.Scripts))) if bw.Err != nil { return bw.Err @@ -189,7 +189,7 @@ func (t *Transaction) EncodeBinary(w io.Writer) error { // encodeHashableFields will only encode the fields that are not used for // signing the transaction, which are all fields except the scripts. func (t *Transaction) encodeHashableFields(w io.Writer) error { - bw := util.BinWriter{W: w} + bw := util.NewBinWriterFromIO(w) bw.WriteLE(t.Type) bw.WriteLE(t.Version) diff --git a/pkg/core/transaction/witness.go b/pkg/core/transaction/witness.go index 125fede63..d3a07e44d 100644 --- a/pkg/core/transaction/witness.go +++ b/pkg/core/transaction/witness.go @@ -17,7 +17,7 @@ type Witness struct { // DecodeBinary implements the payload interface. func (w *Witness) DecodeBinary(r io.Reader) error { - br := util.BinReader{R: r} + br := util.NewBinReaderFromIO(r) w.InvocationScript = br.ReadBytes() w.VerificationScript = br.ReadBytes() @@ -26,7 +26,7 @@ func (w *Witness) DecodeBinary(r io.Reader) error { // EncodeBinary implements the payload interface. func (w *Witness) EncodeBinary(writer io.Writer) error { - bw := util.BinWriter{W: writer} + bw := util.NewBinWriterFromIO(writer) bw.WriteBytes(w.InvocationScript) bw.WriteBytes(w.VerificationScript) diff --git a/pkg/core/unspent_coin_state.go b/pkg/core/unspent_coin_state.go index e8972934f..1954a750c 100644 --- a/pkg/core/unspent_coin_state.go +++ b/pkg/core/unspent_coin_state.go @@ -67,7 +67,7 @@ func (u UnspentCoins) commit(b storage.Batch) error { // EncodeBinary encodes UnspentCoinState to the given io.Writer. func (s *UnspentCoinState) EncodeBinary(w io.Writer) error { - bw := util.BinWriter{W: w} + bw := util.NewBinWriterFromIO(w) bw.WriteVarUint(uint64(len(s.states))) for _, state := range s.states { bw.WriteLE(byte(state)) @@ -77,7 +77,7 @@ func (s *UnspentCoinState) EncodeBinary(w io.Writer) error { // DecodeBinary decodes UnspentCoinState from the given io.Reader. func (s *UnspentCoinState) DecodeBinary(r io.Reader) error { - br := util.BinReader{R: r} + br := util.NewBinReaderFromIO(r) lenStates := br.ReadVarUint() s.states = make([]CoinState, lenStates) for i := 0; i < int(lenStates); i++ { diff --git a/pkg/network/message.go b/pkg/network/message.go index 34bf4b137..01b1a0947 100644 --- a/pkg/network/message.go +++ b/pkg/network/message.go @@ -148,7 +148,7 @@ func (m *Message) CommandType() CommandType { // Decode a Message from the given reader. func (m *Message) Decode(r io.Reader) error { - br := util.BinReader{R: r} + br := util.NewBinReaderFromIO(r) br.ReadLE(&m.Magic) br.ReadLE(&m.Command) br.ReadLE(&m.Length) @@ -232,7 +232,7 @@ func (m *Message) decodePayload(r io.Reader) error { // Encode a Message to any given io.Writer. func (m *Message) Encode(w io.Writer) error { - br := util.BinWriter{W: w} + br := util.NewBinWriterFromIO(w) br.WriteLE(m.Magic) br.WriteLE(m.Command) br.WriteLE(m.Length) diff --git a/pkg/network/payload/address.go b/pkg/network/payload/address.go index a83498a7d..d0b1765cd 100644 --- a/pkg/network/payload/address.go +++ b/pkg/network/payload/address.go @@ -30,7 +30,7 @@ func NewAddressAndTime(e *net.TCPAddr, t time.Time) *AddressAndTime { // DecodeBinary implements the Payload interface. func (p *AddressAndTime) DecodeBinary(r io.Reader) error { - br := util.BinReader{R: r} + br := util.NewBinReaderFromIO(r) br.ReadLE(&p.Timestamp) br.ReadLE(&p.Services) br.ReadBE(&p.IP) @@ -40,7 +40,7 @@ func (p *AddressAndTime) DecodeBinary(r io.Reader) error { // EncodeBinary implements the Payload interface. func (p *AddressAndTime) EncodeBinary(w io.Writer) error { - bw := util.BinWriter{W: w} + bw := util.NewBinWriterFromIO(w) bw.WriteLE(p.Timestamp) bw.WriteLE(p.Services) bw.WriteBE(p.IP) @@ -72,7 +72,7 @@ func NewAddressList(n int) *AddressList { // DecodeBinary implements the Payload interface. func (p *AddressList) DecodeBinary(r io.Reader) error { - br := util.BinReader{R: r} + br := util.NewBinReaderFromIO(r) listLen := br.ReadVarUint() if br.Err != nil { return br.Err @@ -90,7 +90,7 @@ func (p *AddressList) DecodeBinary(r io.Reader) error { // EncodeBinary implements the Payload interface. func (p *AddressList) EncodeBinary(w io.Writer) error { - bw := util.BinWriter{W: w} + bw := util.NewBinWriterFromIO(w) bw.WriteVarUint(uint64(len(p.Addrs))) if bw.Err != nil { return bw.Err diff --git a/pkg/network/payload/getblocks.go b/pkg/network/payload/getblocks.go index 0e0a3419e..1246a9348 100644 --- a/pkg/network/payload/getblocks.go +++ b/pkg/network/payload/getblocks.go @@ -24,7 +24,7 @@ func NewGetBlocks(start []util.Uint256, stop util.Uint256) *GetBlocks { // DecodeBinary implements the payload interface. func (p *GetBlocks) DecodeBinary(r io.Reader) error { - br := util.BinReader{R: r} + br := util.NewBinReaderFromIO(r) lenStart := br.ReadVarUint() p.HashStart = make([]util.Uint256, lenStart) @@ -35,7 +35,7 @@ func (p *GetBlocks) DecodeBinary(r io.Reader) error { // EncodeBinary implements the payload interface. func (p *GetBlocks) EncodeBinary(w io.Writer) error { - bw := util.BinWriter{W: w} + bw := util.NewBinWriterFromIO(w) bw.WriteVarUint(uint64(len(p.HashStart))) bw.WriteLE(p.HashStart) bw.WriteLE(p.HashStop) diff --git a/pkg/network/payload/headers.go b/pkg/network/payload/headers.go index b0f45de9b..1e3904d1b 100644 --- a/pkg/network/payload/headers.go +++ b/pkg/network/payload/headers.go @@ -20,7 +20,7 @@ const ( // DecodeBinary implements the Payload interface. func (p *Headers) DecodeBinary(r io.Reader) error { - br := util.BinReader{R: r} + br := util.NewBinReaderFromIO(r) lenHeaders := br.ReadVarUint() if br.Err != nil { return br.Err @@ -46,7 +46,7 @@ func (p *Headers) DecodeBinary(r io.Reader) error { // EncodeBinary implements the Payload interface. func (p *Headers) EncodeBinary(w io.Writer) error { - bw := util.BinWriter{W: w} + bw := util.NewBinWriterFromIO(w) bw.WriteVarUint(uint64(len(p.Hdrs))) if bw.Err != nil { return bw.Err diff --git a/pkg/network/payload/inventory.go b/pkg/network/payload/inventory.go index 012a9cf82..cdfd783fe 100644 --- a/pkg/network/payload/inventory.go +++ b/pkg/network/payload/inventory.go @@ -57,7 +57,7 @@ func NewInventory(typ InventoryType, hashes []util.Uint256) *Inventory { // DecodeBinary implements the Payload interface. func (p *Inventory) DecodeBinary(r io.Reader) error { - br := util.BinReader{R: r} + br := util.NewBinReaderFromIO(r) br.ReadLE(&p.Type) listLen := br.ReadVarUint() @@ -71,7 +71,7 @@ func (p *Inventory) DecodeBinary(r io.Reader) error { // EncodeBinary implements the Payload interface. func (p *Inventory) EncodeBinary(w io.Writer) error { - bw := util.BinWriter{W: w} + bw := util.NewBinWriterFromIO(w) bw.WriteLE(p.Type) listLen := len(p.Hashes) diff --git a/pkg/network/payload/merkleblock.go b/pkg/network/payload/merkleblock.go index 3b1e90eae..2396bec85 100644 --- a/pkg/network/payload/merkleblock.go +++ b/pkg/network/payload/merkleblock.go @@ -21,7 +21,7 @@ func (m *MerkleBlock) DecodeBinary(r io.Reader) error { if err := m.BlockBase.DecodeBinary(r); err != nil { return err } - br := util.BinReader{R: r} + br := util.NewBinReaderFromIO(r) m.TxCount = int(br.ReadVarUint()) n := br.ReadVarUint() diff --git a/pkg/network/payload/version.go b/pkg/network/payload/version.go index 8f050d9f7..3402ed1cd 100644 --- a/pkg/network/payload/version.go +++ b/pkg/network/payload/version.go @@ -56,7 +56,7 @@ func NewVersion(id uint32, p uint16, ua string, h uint32, r bool) *Version { // DecodeBinary implements the Payload interface. func (p *Version) DecodeBinary(r io.Reader) error { - br := util.BinReader{R: r} + br := util.NewBinReaderFromIO(r) br.ReadLE(&p.Version) br.ReadLE(&p.Services) br.ReadLE(&p.Timestamp) @@ -70,7 +70,7 @@ func (p *Version) DecodeBinary(r io.Reader) error { // EncodeBinary implements the Payload interface. func (p *Version) EncodeBinary(w io.Writer) error { - br := util.BinWriter{W: w} + br := util.NewBinWriterFromIO(w) br.WriteLE(p.Version) br.WriteLE(p.Services) br.WriteLE(p.Timestamp) diff --git a/pkg/smartcontract/contract_test.go b/pkg/smartcontract/contract_test.go index 37c76d575..d42e41811 100644 --- a/pkg/smartcontract/contract_test.go +++ b/pkg/smartcontract/contract_test.go @@ -1,7 +1,6 @@ package smartcontract import ( - "bytes" "testing" "github.com/CityOfZion/neo-go/pkg/crypto/keys" @@ -22,8 +21,7 @@ func TestCreateMultiSigRedeemScript(t *testing.T) { t.Fatal(err) } - buf := bytes.NewBuffer(out) - br := util.BinReader{R: buf} + br := util.NewBinReaderFromBuf(out) var b uint8 br.ReadLE(&b) assert.Equal(t, vm.PUSH3, vm.Instruction(b)) diff --git a/pkg/util/binaryBufWriter.go b/pkg/util/binaryBufWriter.go new file mode 100644 index 000000000..0a2ba4cfc --- /dev/null +++ b/pkg/util/binaryBufWriter.go @@ -0,0 +1,29 @@ +package util + +import ( + "bytes" + "errors" +) + +// BufBinWriter is an additional layer on top of BinWriter that +// automatically creates buffer to write into that you can get after all +// writes via Bytes(). +type BufBinWriter struct { + *BinWriter + buf *bytes.Buffer +} + +// NewBufBinWriter makes a BufBinWriter with an empty byte buffer. +func NewBufBinWriter() *BufBinWriter { + b := new(bytes.Buffer) + return &BufBinWriter{BinWriter: NewBinWriterFromIO(b), buf: b} +} + +// Bytes returns resulting buffer and makes future writes return an error. +func (bw *BufBinWriter) Bytes() []byte { + if bw.Err != nil { + return nil + } + bw.Err = errors.New("buffer already drained") + return bw.buf.Bytes() +} diff --git a/pkg/util/binaryReader.go b/pkg/util/binaryReader.go index fe92ae61a..47694891a 100644 --- a/pkg/util/binaryReader.go +++ b/pkg/util/binaryReader.go @@ -1,6 +1,7 @@ package util import ( + "bytes" "encoding/binary" "io" ) @@ -8,17 +9,28 @@ import ( //BinReader is a convenient wrapper around a io.Reader and err object // Used to simplify error handling when reading into a struct with many fields type BinReader struct { - R io.Reader + r io.Reader Err error } +// NewBinReaderFromIO makes a BinReader from io.Reader. +func NewBinReaderFromIO(ior io.Reader) *BinReader { + return &BinReader{r: ior} +} + +// NewBinReaderFromBuf makes a BinReader from byte buffer. +func NewBinReaderFromBuf(b []byte) *BinReader { + r := bytes.NewReader(b) + return NewBinReaderFromIO(r) +} + // ReadLE reads from the underlying io.Reader // into the interface v in little-endian format func (r *BinReader) ReadLE(v interface{}) { if r.Err != nil { return } - r.Err = binary.Read(r.R, binary.LittleEndian, v) + r.Err = binary.Read(r.r, binary.LittleEndian, v) } // ReadBE reads from the underlying io.Reader @@ -27,7 +39,7 @@ func (r *BinReader) ReadBE(v interface{}) { if r.Err != nil { return } - r.Err = binary.Read(r.R, binary.BigEndian, v) + r.Err = binary.Read(r.r, binary.BigEndian, v) } // ReadVarUint reads a variable-length-encoded integer from the @@ -38,21 +50,21 @@ func (r *BinReader) ReadVarUint() uint64 { } var b uint8 - r.Err = binary.Read(r.R, binary.LittleEndian, &b) + r.Err = binary.Read(r.r, binary.LittleEndian, &b) if b == 0xfd { var v uint16 - r.Err = binary.Read(r.R, binary.LittleEndian, &v) + r.Err = binary.Read(r.r, binary.LittleEndian, &v) return uint64(v) } if b == 0xfe { var v uint32 - r.Err = binary.Read(r.R, binary.LittleEndian, &v) + r.Err = binary.Read(r.r, binary.LittleEndian, &v) return uint64(v) } if b == 0xff { var v uint64 - r.Err = binary.Read(r.R, binary.LittleEndian, &v) + r.Err = binary.Read(r.r, binary.LittleEndian, &v) return v } diff --git a/pkg/util/binaryWriter.go b/pkg/util/binaryWriter.go index 8f2c5229d..b4e58fc36 100644 --- a/pkg/util/binaryWriter.go +++ b/pkg/util/binaryWriter.go @@ -10,16 +10,21 @@ import ( // Used to simplify error handling when writing into a io.Writer // from a struct with many fields type BinWriter struct { - W io.Writer + w io.Writer Err error } +// NewBinWriterFromIO makes a BinWriter from io.Writer. +func NewBinWriterFromIO(iow io.Writer) *BinWriter { + return &BinWriter{w: iow} +} + // WriteLE writes into the underlying io.Writer from an object v in little-endian format func (w *BinWriter) WriteLE(v interface{}) { if w.Err != nil { return } - w.Err = binary.Write(w.W, binary.LittleEndian, v) + w.Err = binary.Write(w.w, binary.LittleEndian, v) } // WriteBE writes into the underlying io.Writer from an object v in big-endian format @@ -27,7 +32,7 @@ func (w *BinWriter) WriteBE(v interface{}) { if w.Err != nil { return } - w.Err = binary.Write(w.W, binary.BigEndian, v) + w.Err = binary.Write(w.w, binary.BigEndian, v) } // WriteVarUint writes a uint64 into the underlying writer using variable-length encoding @@ -46,23 +51,23 @@ func (w *BinWriter) WriteVarUint(val uint64) { } if val < 0xfd { - w.Err = binary.Write(w.W, binary.LittleEndian, uint8(val)) + w.Err = binary.Write(w.w, binary.LittleEndian, uint8(val)) return } if val < 0xFFFF { - w.Err = binary.Write(w.W, binary.LittleEndian, byte(0xfd)) - w.Err = binary.Write(w.W, binary.LittleEndian, uint16(val)) + w.Err = binary.Write(w.w, binary.LittleEndian, byte(0xfd)) + w.Err = binary.Write(w.w, binary.LittleEndian, uint16(val)) return } if val < 0xFFFFFFFF { - w.Err = binary.Write(w.W, binary.LittleEndian, byte(0xfe)) - w.Err = binary.Write(w.W, binary.LittleEndian, uint32(val)) + w.Err = binary.Write(w.w, binary.LittleEndian, byte(0xfe)) + w.Err = binary.Write(w.w, binary.LittleEndian, uint32(val)) return } - w.Err = binary.Write(w.W, binary.LittleEndian, byte(0xff)) - w.Err = binary.Write(w.W, binary.LittleEndian, val) + w.Err = binary.Write(w.w, binary.LittleEndian, byte(0xff)) + w.Err = binary.Write(w.w, binary.LittleEndian, val) } diff --git a/pkg/util/binaryrw_test.go b/pkg/util/binaryrw_test.go index 7bf320f20..34931afd2 100644 --- a/pkg/util/binaryrw_test.go +++ b/pkg/util/binaryrw_test.go @@ -1,7 +1,6 @@ package util import ( - "bytes" "testing" "github.com/stretchr/testify/assert" @@ -10,25 +9,25 @@ import ( func TestWriteVarUint1(t *testing.T) { var ( val = uint64(1) - buf = new(bytes.Buffer) ) - bw := BinWriter{W: buf} + bw := NewBufBinWriter() bw.WriteVarUint(val) assert.Nil(t, bw.Err) - assert.Equal(t, 1, buf.Len()) + buf := bw.Bytes() + assert.Equal(t, 1, len(buf)) } func TestWriteVarUint1000(t *testing.T) { var ( val = uint64(1000) - buf = new(bytes.Buffer) ) - bw := BinWriter{W: buf} + bw := NewBufBinWriter() bw.WriteVarUint(val) assert.Nil(t, bw.Err) - assert.Equal(t, 3, buf.Len()) - assert.Equal(t, byte(0xfd), buf.Bytes()[0]) - br := BinReader{R: buf} + buf := bw.Bytes() + assert.Equal(t, 3, len(buf)) + assert.Equal(t, byte(0xfd), buf[0]) + br := NewBinReaderFromBuf(buf) res := br.ReadVarUint() assert.Nil(t, br.Err) assert.Equal(t, val, res) @@ -37,14 +36,14 @@ func TestWriteVarUint1000(t *testing.T) { func TestWriteVarUint100000(t *testing.T) { var ( val = uint64(100000) - buf = new(bytes.Buffer) ) - bw := BinWriter{W: buf} + bw := NewBufBinWriter() bw.WriteVarUint(val) assert.Nil(t, bw.Err) - assert.Equal(t, 5, buf.Len()) - assert.Equal(t, byte(0xfe), buf.Bytes()[0]) - br := BinReader{R: buf} + buf := bw.Bytes() + assert.Equal(t, 5, len(buf)) + assert.Equal(t, byte(0xfe), buf[0]) + br := NewBinReaderFromBuf(buf) res := br.ReadVarUint() assert.Nil(t, br.Err) assert.Equal(t, val, res) @@ -53,14 +52,14 @@ func TestWriteVarUint100000(t *testing.T) { func TestWriteVarUint100000000000(t *testing.T) { var ( val = uint64(1000000000000) - buf = new(bytes.Buffer) ) - bw := BinWriter{W: buf} + bw := NewBufBinWriter() bw.WriteVarUint(val) assert.Nil(t, bw.Err) - assert.Equal(t, 9, buf.Len()) - assert.Equal(t, byte(0xff), buf.Bytes()[0]) - br := BinReader{R: buf} + buf := bw.Bytes() + assert.Equal(t, 9, len(buf)) + assert.Equal(t, byte(0xff), buf[0]) + br := NewBinReaderFromBuf(buf) res := br.ReadVarUint() assert.Nil(t, br.Err) assert.Equal(t, val, res)