diff --git a/pkg/consensus/prepare_request.go b/pkg/consensus/prepare_request.go index 02c374747..8f6f500cd 100644 --- a/pkg/consensus/prepare_request.go +++ b/pkg/consensus/prepare_request.go @@ -20,10 +20,7 @@ func (p *prepareRequest) EncodeBinary(w *io.BinWriter) { w.WriteLE(p.Timestamp) w.WriteLE(p.Nonce) w.WriteBE(p.NextConsensus[:]) - w.WriteVarUint(uint64(len(p.TransactionHashes))) - for i := range p.TransactionHashes { - w.WriteBE(p.TransactionHashes[i][:]) - } + w.WriteArray(p.TransactionHashes) p.MinerTransaction.EncodeBinary(w) } @@ -32,12 +29,6 @@ func (p *prepareRequest) DecodeBinary(r *io.BinReader) { r.ReadLE(&p.Timestamp) r.ReadLE(&p.Nonce) r.ReadBE(p.NextConsensus[:]) - - lenHashes := r.ReadVarUint() - p.TransactionHashes = make([]util.Uint256, lenHashes) - for i := range p.TransactionHashes { - r.ReadBE(p.TransactionHashes[i][:]) - } - + r.ReadArray(&p.TransactionHashes) p.MinerTransaction.DecodeBinary(r) } diff --git a/pkg/consensus/recovery_message.go b/pkg/consensus/recovery_message.go index 8b248b271..e1dddc27e 100644 --- a/pkg/consensus/recovery_message.go +++ b/pkg/consensus/recovery_message.go @@ -40,7 +40,7 @@ const uint256size = 32 // DecodeBinary implements io.Serializable interface. func (m *recoveryMessage) DecodeBinary(r *io.BinReader) { - m.ChangeViewPayloads = r.ReadArray(changeViewCompact{}).([]*changeViewCompact) + r.ReadArray(&m.ChangeViewPayloads) var hasReq bool r.ReadLE(&hasReq) @@ -61,8 +61,8 @@ func (m *recoveryMessage) DecodeBinary(r *io.BinReader) { } } - m.PreparationPayloads = r.ReadArray(preparationCompact{}).([]*preparationCompact) - m.CommitPayloads = r.ReadArray(commitCompact{}).([]*commitCompact) + r.ReadArray(&m.PreparationPayloads) + r.ReadArray(&m.CommitPayloads) } // EncodeBinary implements io.Serializable interface. diff --git a/pkg/core/account_state.go b/pkg/core/account_state.go index 541381f5b..9404fd7bf 100644 --- a/pkg/core/account_state.go +++ b/pkg/core/account_state.go @@ -94,7 +94,7 @@ func (s *AccountState) DecodeBinary(br *io.BinReader) { br.ReadLE(&s.Version) br.ReadLE(&s.ScriptHash) br.ReadLE(&s.IsFrozen) - s.Votes = br.ReadArray(keys.PublicKey{}).([]*keys.PublicKey) + br.ReadArray(&s.Votes) s.Balances = make(map[util.Uint256]util.Fixed8) lenBalances := br.ReadVarUint() diff --git a/pkg/core/block.go b/pkg/core/block.go index 90b94a749..f02fb425f 100644 --- a/pkg/core/block.go +++ b/pkg/core/block.go @@ -128,17 +128,14 @@ func (b *Block) Trim() ([]byte, error) { // Serializable interface. func (b *Block) DecodeBinary(br *io.BinReader) { b.BlockBase.DecodeBinary(br) - b.Transactions = br.ReadArray(transaction.Transaction{}).([]*transaction.Transaction) + br.ReadArray(&b.Transactions) } // EncodeBinary encodes the block to the given BinWriter, implementing // Serializable interface. func (b *Block) EncodeBinary(bw *io.BinWriter) { b.BlockBase.EncodeBinary(bw) - bw.WriteVarUint(uint64(len(b.Transactions))) - for _, tx := range b.Transactions { - tx.EncodeBinary(bw) - } + bw.WriteArray(b.Transactions) } // Compare implements the queue Item interface. diff --git a/pkg/core/contract_state.go b/pkg/core/contract_state.go index b5e736b7a..e3efdce27 100644 --- a/pkg/core/contract_state.go +++ b/pkg/core/contract_state.go @@ -39,11 +39,7 @@ func (a Contracts) commit(store storage.Store) error { // DecodeBinary implements Serializable interface. func (cs *ContractState) DecodeBinary(br *io.BinReader) { cs.Script = br.ReadBytes() - paramBytes := br.ReadBytes() - cs.ParamList = make([]smartcontract.ParamType, len(paramBytes)) - for k := range paramBytes { - cs.ParamList[k] = smartcontract.ParamType(paramBytes[k]) - } + br.ReadArray(&cs.ParamList) br.ReadLE(&cs.ReturnType) br.ReadLE(&cs.Properties) cs.Name = br.ReadString() @@ -57,10 +53,7 @@ func (cs *ContractState) DecodeBinary(br *io.BinReader) { // EncodeBinary implements Serializable interface. func (cs *ContractState) EncodeBinary(bw *io.BinWriter) { bw.WriteBytes(cs.Script) - bw.WriteVarUint(uint64(len(cs.ParamList))) - for k := range cs.ParamList { - bw.WriteLE(cs.ParamList[k]) - } + bw.WriteArray(cs.ParamList) bw.WriteLE(cs.ReturnType) bw.WriteLE(cs.Properties) bw.WriteString(cs.Name) diff --git a/pkg/core/header_hash_list.go b/pkg/core/header_hash_list.go index 8f5bb6264..1a4b2fc05 100644 --- a/pkg/core/header_hash_list.go +++ b/pkg/core/header_hash_list.go @@ -58,10 +58,6 @@ func (l *HeaderHashList) Slice(start, end int) []util.Uint256 { // WriteTo writes n underlying hashes to the given BinWriter // starting from start. func (l *HeaderHashList) Write(bw *io.BinWriter, start, n int) error { - bw.WriteVarUint(uint64(n)) - hashes := l.Slice(start, start+n) - for _, hash := range hashes { - bw.WriteLE(hash) - } + bw.WriteArray(l.Slice(start, start+n)) return bw.Err } diff --git a/pkg/core/transaction/claim.go b/pkg/core/transaction/claim.go index ef7eecfae..423a4bc79 100644 --- a/pkg/core/transaction/claim.go +++ b/pkg/core/transaction/claim.go @@ -11,7 +11,7 @@ type ClaimTX struct { // DecodeBinary implements Serializable interface. func (tx *ClaimTX) DecodeBinary(br *io.BinReader) { - tx.Claims = br.ReadArray(Input{}).([]*Input) + br.ReadArray(&tx.Claims) } // EncodeBinary implements Serializable interface. diff --git a/pkg/core/transaction/state.go b/pkg/core/transaction/state.go index c3c181125..76178a205 100644 --- a/pkg/core/transaction/state.go +++ b/pkg/core/transaction/state.go @@ -11,7 +11,7 @@ type StateTX struct { // DecodeBinary implements Serializable interface. func (tx *StateTX) DecodeBinary(r *io.BinReader) { - tx.Descriptors = r.ReadArray(StateDescriptor{}).([]*StateDescriptor) + r.ReadArray(&tx.Descriptors) } // EncodeBinary implements Serializable interface. diff --git a/pkg/core/transaction/transaction.go b/pkg/core/transaction/transaction.go index 712d44d78..aefeb8c7b 100644 --- a/pkg/core/transaction/transaction.go +++ b/pkg/core/transaction/transaction.go @@ -95,10 +95,10 @@ func (t *Transaction) DecodeBinary(br *io.BinReader) { br.ReadLE(&t.Version) t.decodeData(br) - t.Attributes = br.ReadArray(Attribute{}).([]*Attribute) - t.Inputs = br.ReadArray(Input{}).([]*Input) - t.Outputs = br.ReadArray(Output{}).([]*Output) - t.Scripts = br.ReadArray(Witness{}).([]*Witness) + br.ReadArray(&t.Attributes) + br.ReadArray(&t.Inputs) + br.ReadArray(&t.Outputs) + br.ReadArray(&t.Scripts) // Create the hash of the transaction at decode, so we dont need // to do it anymore. diff --git a/pkg/io/binaryReader.go b/pkg/io/binaryReader.go index 689a80974..126593480 100644 --- a/pkg/io/binaryReader.go +++ b/pkg/io/binaryReader.go @@ -34,29 +34,50 @@ func (r *BinReader) ReadLE(v interface{}) { r.Err = binary.Read(r.r, binary.LittleEndian, v) } -// ReadArray reads a slice or an array of pointer to t from r and returns. -func (r *BinReader) ReadArray(t interface{}) interface{} { - elemType := reflect.ValueOf(t).Type() - method, ok := reflect.PtrTo(elemType).MethodByName("DecodeBinary") - if !ok || !isDecodeBinaryMethod(method) { - panic(elemType.String() + " does not have DecodeBinary(*io.BinReader)") +// ReadArray reads array into value which must be +// a pointer to a slice. +func (r *BinReader) ReadArray(t interface{}) { + value := reflect.ValueOf(t) + if value.Kind() != reflect.Ptr || value.Elem().Kind() != reflect.Slice { + panic(value.Type().String() + " is not a pointer to a slice") + } + + sliceType := value.Elem().Type() + elemType := sliceType.Elem() + isPtr := elemType.Kind() == reflect.Ptr + if isPtr { + checkHasDecodeBinary(elemType) + } else { + checkHasDecodeBinary(reflect.PtrTo(elemType)) } - sliceType := reflect.SliceOf(reflect.PtrTo(elemType)) if r.Err != nil { - return reflect.Zero(sliceType).Interface() + return } l := int(r.ReadVarUint()) arr := reflect.MakeSlice(sliceType, l, l) + for i := 0; i < l; i++ { - elem := arr.Index(i) + var elem reflect.Value + if isPtr { + elem = reflect.New(elemType.Elem()) + arr.Index(i).Set(elem) + } else { + elem = arr.Index(i).Addr() + } method := elem.MethodByName("DecodeBinary") - elem.Set(reflect.New(elemType)) method.Call([]reflect.Value{reflect.ValueOf(r)}) } - return arr.Interface() + value.Elem().Set(arr) +} + +func checkHasDecodeBinary(v reflect.Type) { + method, ok := v.MethodByName("DecodeBinary") + if !ok || !isDecodeBinaryMethod(method) { + panic(v.String() + " does not have DecodeBinary(*io.BinReader)") + } } func isDecodeBinaryMethod(method reflect.Method) bool { diff --git a/pkg/io/binaryrw_test.go b/pkg/io/binaryrw_test.go index 304b6b651..8bc9e1578 100644 --- a/pkg/io/binaryrw_test.go +++ b/pkg/io/binaryrw_test.go @@ -253,22 +253,53 @@ func TestBinWriter_WriteArray(t *testing.T) { func TestBinReader_ReadArray(t *testing.T) { data := []byte{3, 0, 0, 1, 0, 2, 0} - r := NewBinReaderFromBuf(data) - result := r.ReadArray(testSerializable(0)) elems := []testSerializable{0, 1, 2} - require.Equal(t, []*testSerializable{&elems[0], &elems[1], &elems[2]}, result) + + r := NewBinReaderFromBuf(data) + arrPtr := []*testSerializable{} + r.ReadArray(&arrPtr) + require.Equal(t, []*testSerializable{&elems[0], &elems[1], &elems[2]}, arrPtr) r = NewBinReaderFromBuf(data) - r.Err = errors.New("error") - result = r.ReadArray(testSerializable(0)) - require.Error(t, r.Err) - require.Equal(t, ([]*testSerializable)(nil), result) - - r = NewBinReaderFromBuf([]byte{0}) - result = r.ReadArray(testSerializable(0)) + arrVal := []testSerializable{} + r.ReadArray(&arrVal) require.NoError(t, r.Err) - require.Equal(t, []*testSerializable{}, result) + require.Equal(t, elems, arrVal) r = NewBinReaderFromBuf([]byte{0}) + r.ReadArray(&arrVal) + require.NoError(t, r.Err) + require.Equal(t, []testSerializable{}, arrVal) + + r = NewBinReaderFromBuf([]byte{0}) + r.Err = errors.New("error") + arrVal = ([]testSerializable)(nil) + r.ReadArray(&arrVal) + require.Error(t, r.Err) + require.Equal(t, ([]testSerializable)(nil), arrVal) + + r = NewBinReaderFromBuf([]byte{0}) + r.Err = errors.New("error") + arrPtr = ([]*testSerializable)(nil) + r.ReadArray(&arrVal) + require.Error(t, r.Err) + require.Equal(t, ([]*testSerializable)(nil), arrPtr) + + r = NewBinReaderFromBuf([]byte{0}) + arrVal = []testSerializable{1, 2} + r.ReadArray(&arrVal) + require.NoError(t, r.Err) + require.Equal(t, []testSerializable{}, arrVal) + + r = NewBinReaderFromBuf([]byte{0}) + r.Err = errors.New("error") + require.Panics(t, func() { r.ReadArray(&[]*int{}) }) + + r = NewBinReaderFromBuf([]byte{0}) + r.Err = errors.New("error") + require.Panics(t, func() { r.ReadArray(&[]int{}) }) + + r = NewBinReaderFromBuf([]byte{0}) + r.Err = errors.New("error") require.Panics(t, func() { r.ReadArray(0) }) } diff --git a/pkg/io/size_test.go b/pkg/io/size_test.go index 9b45b6311..4e397894c 100644 --- a/pkg/io/size_test.go +++ b/pkg/io/size_test.go @@ -1,9 +1,10 @@ -package io +package io_test import ( "fmt" "testing" + "github.com/CityOfZion/neo-go/pkg/io" "github.com/CityOfZion/neo-go/pkg/util" "github.com/stretchr/testify/assert" ) @@ -13,18 +14,18 @@ type smthSerializable struct { some [42]byte } -func (*smthSerializable) DecodeBinary(*BinReader) {} +func (*smthSerializable) DecodeBinary(*io.BinReader) {} -func (ss *smthSerializable) EncodeBinary(bw *BinWriter) { +func (ss *smthSerializable) EncodeBinary(bw *io.BinWriter) { bw.WriteLE(ss.some) } // Mock structure that gives error in EncodeBinary(). type smthNotReallySerializable struct{} -func (*smthNotReallySerializable) DecodeBinary(*BinReader) {} +func (*smthNotReallySerializable) DecodeBinary(*io.BinReader) {} -func (*smthNotReallySerializable) EncodeBinary(bw *BinWriter) { +func (*smthNotReallySerializable) EncodeBinary(bw *io.BinWriter) { bw.Err = fmt.Errorf("smth bad happened in smthNotReallySerializable") } @@ -182,7 +183,7 @@ func TestVarSize(t *testing.T) { for _, tc := range testCases { t.Run(fmt.Sprintf("run: %s", tc.name), func(t *testing.T) { - result := GetVarSize(tc.variable) + result := io.GetVarSize(tc.variable) assert.Equal(t, tc.expected, result) }) } @@ -194,7 +195,7 @@ func panicVarSize(t *testing.T, v interface{}) { assert.NotNil(t, r) }() - _ = GetVarSize(v) + _ = io.GetVarSize(v) // this should never execute assert.Nil(t, t) } diff --git a/pkg/network/payload/address.go b/pkg/network/payload/address.go index 2f7728eb2..0daee4456 100644 --- a/pkg/network/payload/address.go +++ b/pkg/network/payload/address.go @@ -67,7 +67,7 @@ func NewAddressList(n int) *AddressList { // DecodeBinary implements Serializable interface. func (p *AddressList) DecodeBinary(br *io.BinReader) { - p.Addrs = br.ReadArray(AddressAndTime{}).([]*AddressAndTime) + br.ReadArray(&p.Addrs) } // EncodeBinary implements Serializable interface. diff --git a/pkg/network/payload/getblocks.go b/pkg/network/payload/getblocks.go index 12255fcae..8701a1df0 100644 --- a/pkg/network/payload/getblocks.go +++ b/pkg/network/payload/getblocks.go @@ -23,16 +23,12 @@ func NewGetBlocks(start []util.Uint256, stop util.Uint256) *GetBlocks { // DecodeBinary implements Serializable interface. func (p *GetBlocks) DecodeBinary(br *io.BinReader) { - lenStart := br.ReadVarUint() - p.HashStart = make([]util.Uint256, lenStart) - - br.ReadLE(&p.HashStart) + br.ReadArray(&p.HashStart) br.ReadLE(&p.HashStop) } // EncodeBinary implements Serializable interface. func (p *GetBlocks) EncodeBinary(bw *io.BinWriter) { - bw.WriteVarUint(uint64(len(p.HashStart))) - bw.WriteLE(p.HashStart) + bw.WriteArray(p.HashStart) bw.WriteLE(p.HashStop) } diff --git a/pkg/network/payload/inventory.go b/pkg/network/payload/inventory.go index ccefe8527..036aa0279 100644 --- a/pkg/network/payload/inventory.go +++ b/pkg/network/payload/inventory.go @@ -57,21 +57,11 @@ func NewInventory(typ InventoryType, hashes []util.Uint256) *Inventory { // DecodeBinary implements Serializable interface. func (p *Inventory) DecodeBinary(br *io.BinReader) { br.ReadLE(&p.Type) - - listLen := br.ReadVarUint() - p.Hashes = make([]util.Uint256, listLen) - for i := 0; i < int(listLen); i++ { - br.ReadLE(&p.Hashes[i]) - } + br.ReadArray(&p.Hashes) } // EncodeBinary implements Serializable interface. func (p *Inventory) EncodeBinary(bw *io.BinWriter) { bw.WriteLE(p.Type) - - listLen := len(p.Hashes) - bw.WriteVarUint(uint64(listLen)) - for i := 0; i < listLen; i++ { - bw.WriteLE(p.Hashes[i]) - } + bw.WriteArray(p.Hashes) } diff --git a/pkg/network/payload/merkleblock.go b/pkg/network/payload/merkleblock.go index 996f32431..dad96f305 100644 --- a/pkg/network/payload/merkleblock.go +++ b/pkg/network/payload/merkleblock.go @@ -20,11 +20,7 @@ func (m *MerkleBlock) DecodeBinary(br *io.BinReader) { m.BlockBase.DecodeBinary(br) m.TxCount = int(br.ReadVarUint()) - n := br.ReadVarUint() - m.Hashes = make([]util.Uint256, n) - for i := 0; i < len(m.Hashes); i++ { - br.ReadLE(&m.Hashes[i]) - } + br.ReadArray(&m.Hashes) m.Flags = br.ReadBytes() } @@ -34,9 +30,6 @@ func (m *MerkleBlock) EncodeBinary(bw *io.BinWriter) { m.BlockBase.EncodeBinary(bw) bw.WriteVarUint(uint64(m.TxCount)) - bw.WriteVarUint(uint64(len(m.Hashes))) - for i := 0; i < len(m.Hashes); i++ { - bw.WriteLE(m.Hashes[i]) - } + bw.WriteArray(m.Hashes) bw.WriteBytes(m.Flags) } diff --git a/pkg/smartcontract/param_context.go b/pkg/smartcontract/param_context.go index f7414f6a0..98d32a1f9 100644 --- a/pkg/smartcontract/param_context.go +++ b/pkg/smartcontract/param_context.go @@ -1,6 +1,9 @@ package smartcontract -import "github.com/CityOfZion/neo-go/pkg/util" +import ( + "github.com/CityOfZion/neo-go/pkg/io" + "github.com/CityOfZion/neo-go/pkg/util" +) // ParamType represents the Type of the contract parameter. type ParamType byte @@ -67,6 +70,16 @@ func (pt ParamType) MarshalJSON() ([]byte, error) { return []byte(`"` + pt.String() + `"`), nil } +// EncodeBinary implements io.Serializable interface. +func (pt ParamType) EncodeBinary(w *io.BinWriter) { + w.WriteLE(pt) +} + +// DecodeBinary implements io.Serializable interface. +func (pt *ParamType) DecodeBinary(r *io.BinReader) { + r.ReadLE(pt) +} + // NewParameter returns a Parameter with proper initialized Value // of the given ParamType. func NewParameter(t ParamType) Parameter { diff --git a/pkg/util/uint256.go b/pkg/util/uint256.go index 40dfdb93a..76d69c468 100644 --- a/pkg/util/uint256.go +++ b/pkg/util/uint256.go @@ -6,6 +6,8 @@ import ( "encoding/json" "fmt" "strings" + + "github.com/CityOfZion/neo-go/pkg/io" ) // Uint256Size is the size of Uint256 in bytes. @@ -93,3 +95,13 @@ func (u Uint256) MarshalJSON() ([]byte, error) { // -1 implies u < other. // 0 implies u = other. func (u Uint256) CompareTo(other Uint256) int { return bytes.Compare(u[:], other[:]) } + +// EncodeBinary implements io.Serializable interface. +func (u Uint256) EncodeBinary(w *io.BinWriter) { + w.WriteBE(u) +} + +// DecodeBinary implements io.Serializable interface. +func (u *Uint256) DecodeBinary(r *io.BinReader) { + r.ReadBE(u[:]) +} diff --git a/pkg/util/uint256_test.go b/pkg/util/uint256_test.go index 59fff1838..975edec9f 100644 --- a/pkg/util/uint256_test.go +++ b/pkg/util/uint256_test.go @@ -4,7 +4,9 @@ import ( "encoding/hex" "testing" + "github.com/CityOfZion/neo-go/pkg/io" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestUint256UnmarshalJSON(t *testing.T) { @@ -75,3 +77,19 @@ func TestUInt256Equals(t *testing.T) { t.Fatalf("%s and %s must be equal", ua, ua) } } + +func TestUint256_Serializable(t *testing.T) { + a := Uint256{ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, + } + + w := io.NewBufBinWriter() + a.EncodeBinary(w.BinWriter) + require.NoError(t, w.Err) + + var b Uint256 + r := io.NewBinReaderFromBuf(w.Bytes()) + b.DecodeBinary(r) + require.Equal(t, a, b) +}