diff --git a/pkg/consensus/recovery_message.go b/pkg/consensus/recovery_message.go index 8b248b271..e1dddc27e 100644 --- a/pkg/consensus/recovery_message.go +++ b/pkg/consensus/recovery_message.go @@ -40,7 +40,7 @@ const uint256size = 32 // DecodeBinary implements io.Serializable interface. func (m *recoveryMessage) DecodeBinary(r *io.BinReader) { - m.ChangeViewPayloads = r.ReadArray(changeViewCompact{}).([]*changeViewCompact) + r.ReadArray(&m.ChangeViewPayloads) var hasReq bool r.ReadLE(&hasReq) @@ -61,8 +61,8 @@ func (m *recoveryMessage) DecodeBinary(r *io.BinReader) { } } - m.PreparationPayloads = r.ReadArray(preparationCompact{}).([]*preparationCompact) - m.CommitPayloads = r.ReadArray(commitCompact{}).([]*commitCompact) + r.ReadArray(&m.PreparationPayloads) + r.ReadArray(&m.CommitPayloads) } // EncodeBinary implements io.Serializable interface. diff --git a/pkg/core/account_state.go b/pkg/core/account_state.go index 541381f5b..9404fd7bf 100644 --- a/pkg/core/account_state.go +++ b/pkg/core/account_state.go @@ -94,7 +94,7 @@ func (s *AccountState) DecodeBinary(br *io.BinReader) { br.ReadLE(&s.Version) br.ReadLE(&s.ScriptHash) br.ReadLE(&s.IsFrozen) - s.Votes = br.ReadArray(keys.PublicKey{}).([]*keys.PublicKey) + br.ReadArray(&s.Votes) s.Balances = make(map[util.Uint256]util.Fixed8) lenBalances := br.ReadVarUint() diff --git a/pkg/core/block.go b/pkg/core/block.go index 90b94a749..45bff104c 100644 --- a/pkg/core/block.go +++ b/pkg/core/block.go @@ -128,7 +128,7 @@ func (b *Block) Trim() ([]byte, error) { // Serializable interface. func (b *Block) DecodeBinary(br *io.BinReader) { b.BlockBase.DecodeBinary(br) - b.Transactions = br.ReadArray(transaction.Transaction{}).([]*transaction.Transaction) + br.ReadArray(&b.Transactions) } // EncodeBinary encodes the block to the given BinWriter, implementing diff --git a/pkg/core/transaction/claim.go b/pkg/core/transaction/claim.go index ef7eecfae..423a4bc79 100644 --- a/pkg/core/transaction/claim.go +++ b/pkg/core/transaction/claim.go @@ -11,7 +11,7 @@ type ClaimTX struct { // DecodeBinary implements Serializable interface. func (tx *ClaimTX) DecodeBinary(br *io.BinReader) { - tx.Claims = br.ReadArray(Input{}).([]*Input) + br.ReadArray(&tx.Claims) } // EncodeBinary implements Serializable interface. diff --git a/pkg/core/transaction/state.go b/pkg/core/transaction/state.go index c3c181125..76178a205 100644 --- a/pkg/core/transaction/state.go +++ b/pkg/core/transaction/state.go @@ -11,7 +11,7 @@ type StateTX struct { // DecodeBinary implements Serializable interface. func (tx *StateTX) DecodeBinary(r *io.BinReader) { - tx.Descriptors = r.ReadArray(StateDescriptor{}).([]*StateDescriptor) + r.ReadArray(&tx.Descriptors) } // EncodeBinary implements Serializable interface. diff --git a/pkg/core/transaction/transaction.go b/pkg/core/transaction/transaction.go index 712d44d78..aefeb8c7b 100644 --- a/pkg/core/transaction/transaction.go +++ b/pkg/core/transaction/transaction.go @@ -95,10 +95,10 @@ func (t *Transaction) DecodeBinary(br *io.BinReader) { br.ReadLE(&t.Version) t.decodeData(br) - t.Attributes = br.ReadArray(Attribute{}).([]*Attribute) - t.Inputs = br.ReadArray(Input{}).([]*Input) - t.Outputs = br.ReadArray(Output{}).([]*Output) - t.Scripts = br.ReadArray(Witness{}).([]*Witness) + br.ReadArray(&t.Attributes) + br.ReadArray(&t.Inputs) + br.ReadArray(&t.Outputs) + br.ReadArray(&t.Scripts) // Create the hash of the transaction at decode, so we dont need // to do it anymore. diff --git a/pkg/io/binaryReader.go b/pkg/io/binaryReader.go index 689a80974..126593480 100644 --- a/pkg/io/binaryReader.go +++ b/pkg/io/binaryReader.go @@ -34,29 +34,50 @@ func (r *BinReader) ReadLE(v interface{}) { r.Err = binary.Read(r.r, binary.LittleEndian, v) } -// ReadArray reads a slice or an array of pointer to t from r and returns. -func (r *BinReader) ReadArray(t interface{}) interface{} { - elemType := reflect.ValueOf(t).Type() - method, ok := reflect.PtrTo(elemType).MethodByName("DecodeBinary") - if !ok || !isDecodeBinaryMethod(method) { - panic(elemType.String() + " does not have DecodeBinary(*io.BinReader)") +// ReadArray reads array into value which must be +// a pointer to a slice. +func (r *BinReader) ReadArray(t interface{}) { + value := reflect.ValueOf(t) + if value.Kind() != reflect.Ptr || value.Elem().Kind() != reflect.Slice { + panic(value.Type().String() + " is not a pointer to a slice") + } + + sliceType := value.Elem().Type() + elemType := sliceType.Elem() + isPtr := elemType.Kind() == reflect.Ptr + if isPtr { + checkHasDecodeBinary(elemType) + } else { + checkHasDecodeBinary(reflect.PtrTo(elemType)) } - sliceType := reflect.SliceOf(reflect.PtrTo(elemType)) if r.Err != nil { - return reflect.Zero(sliceType).Interface() + return } l := int(r.ReadVarUint()) arr := reflect.MakeSlice(sliceType, l, l) + for i := 0; i < l; i++ { - elem := arr.Index(i) + var elem reflect.Value + if isPtr { + elem = reflect.New(elemType.Elem()) + arr.Index(i).Set(elem) + } else { + elem = arr.Index(i).Addr() + } method := elem.MethodByName("DecodeBinary") - elem.Set(reflect.New(elemType)) method.Call([]reflect.Value{reflect.ValueOf(r)}) } - return arr.Interface() + value.Elem().Set(arr) +} + +func checkHasDecodeBinary(v reflect.Type) { + method, ok := v.MethodByName("DecodeBinary") + if !ok || !isDecodeBinaryMethod(method) { + panic(v.String() + " does not have DecodeBinary(*io.BinReader)") + } } func isDecodeBinaryMethod(method reflect.Method) bool { diff --git a/pkg/io/binaryrw_test.go b/pkg/io/binaryrw_test.go index 304b6b651..8bc9e1578 100644 --- a/pkg/io/binaryrw_test.go +++ b/pkg/io/binaryrw_test.go @@ -253,22 +253,53 @@ func TestBinWriter_WriteArray(t *testing.T) { func TestBinReader_ReadArray(t *testing.T) { data := []byte{3, 0, 0, 1, 0, 2, 0} - r := NewBinReaderFromBuf(data) - result := r.ReadArray(testSerializable(0)) elems := []testSerializable{0, 1, 2} - require.Equal(t, []*testSerializable{&elems[0], &elems[1], &elems[2]}, result) + + r := NewBinReaderFromBuf(data) + arrPtr := []*testSerializable{} + r.ReadArray(&arrPtr) + require.Equal(t, []*testSerializable{&elems[0], &elems[1], &elems[2]}, arrPtr) r = NewBinReaderFromBuf(data) - r.Err = errors.New("error") - result = r.ReadArray(testSerializable(0)) - require.Error(t, r.Err) - require.Equal(t, ([]*testSerializable)(nil), result) - - r = NewBinReaderFromBuf([]byte{0}) - result = r.ReadArray(testSerializable(0)) + arrVal := []testSerializable{} + r.ReadArray(&arrVal) require.NoError(t, r.Err) - require.Equal(t, []*testSerializable{}, result) + require.Equal(t, elems, arrVal) r = NewBinReaderFromBuf([]byte{0}) + r.ReadArray(&arrVal) + require.NoError(t, r.Err) + require.Equal(t, []testSerializable{}, arrVal) + + r = NewBinReaderFromBuf([]byte{0}) + r.Err = errors.New("error") + arrVal = ([]testSerializable)(nil) + r.ReadArray(&arrVal) + require.Error(t, r.Err) + require.Equal(t, ([]testSerializable)(nil), arrVal) + + r = NewBinReaderFromBuf([]byte{0}) + r.Err = errors.New("error") + arrPtr = ([]*testSerializable)(nil) + r.ReadArray(&arrVal) + require.Error(t, r.Err) + require.Equal(t, ([]*testSerializable)(nil), arrPtr) + + r = NewBinReaderFromBuf([]byte{0}) + arrVal = []testSerializable{1, 2} + r.ReadArray(&arrVal) + require.NoError(t, r.Err) + require.Equal(t, []testSerializable{}, arrVal) + + r = NewBinReaderFromBuf([]byte{0}) + r.Err = errors.New("error") + require.Panics(t, func() { r.ReadArray(&[]*int{}) }) + + r = NewBinReaderFromBuf([]byte{0}) + r.Err = errors.New("error") + require.Panics(t, func() { r.ReadArray(&[]int{}) }) + + r = NewBinReaderFromBuf([]byte{0}) + r.Err = errors.New("error") require.Panics(t, func() { r.ReadArray(0) }) } diff --git a/pkg/network/payload/address.go b/pkg/network/payload/address.go index 2f7728eb2..0daee4456 100644 --- a/pkg/network/payload/address.go +++ b/pkg/network/payload/address.go @@ -67,7 +67,7 @@ func NewAddressList(n int) *AddressList { // DecodeBinary implements Serializable interface. func (p *AddressList) DecodeBinary(br *io.BinReader) { - p.Addrs = br.ReadArray(AddressAndTime{}).([]*AddressAndTime) + br.ReadArray(&p.Addrs) } // EncodeBinary implements Serializable interface.