diff --git a/pkg/core/account_state.go b/pkg/core/account_state.go index 7d7fdb807..541381f5b 100644 --- a/pkg/core/account_state.go +++ b/pkg/core/account_state.go @@ -94,12 +94,7 @@ func (s *AccountState) DecodeBinary(br *io.BinReader) { br.ReadLE(&s.Version) br.ReadLE(&s.ScriptHash) br.ReadLE(&s.IsFrozen) - lenVotes := br.ReadVarUint() - s.Votes = make([]*keys.PublicKey, lenVotes) - for i := 0; i < int(lenVotes); i++ { - s.Votes[i] = &keys.PublicKey{} - s.Votes[i].DecodeBinary(br) - } + s.Votes = br.ReadArray(keys.PublicKey{}).([]*keys.PublicKey) s.Balances = make(map[util.Uint256]util.Fixed8) lenBalances := br.ReadVarUint() @@ -117,10 +112,7 @@ func (s *AccountState) EncodeBinary(bw *io.BinWriter) { bw.WriteLE(s.Version) bw.WriteLE(s.ScriptHash) bw.WriteLE(s.IsFrozen) - bw.WriteVarUint(uint64(len(s.Votes))) - for _, point := range s.Votes { - point.EncodeBinary(bw) - } + bw.WriteArray(s.Votes) balances := s.nonZeroBalances() bw.WriteVarUint(uint64(len(balances))) diff --git a/pkg/core/block.go b/pkg/core/block.go index de790fc43..90b94a749 100644 --- a/pkg/core/block.go +++ b/pkg/core/block.go @@ -128,13 +128,7 @@ func (b *Block) Trim() ([]byte, error) { // Serializable interface. func (b *Block) DecodeBinary(br *io.BinReader) { b.BlockBase.DecodeBinary(br) - - lentx := br.ReadVarUint() - b.Transactions = make([]*transaction.Transaction, lentx) - for i := 0; i < int(lentx); i++ { - b.Transactions[i] = &transaction.Transaction{} - b.Transactions[i].DecodeBinary(br) - } + b.Transactions = br.ReadArray(transaction.Transaction{}).([]*transaction.Transaction) } // EncodeBinary encodes the block to the given BinWriter, implementing diff --git a/pkg/core/transaction/claim.go b/pkg/core/transaction/claim.go index 7140f6aad..ef7eecfae 100644 --- a/pkg/core/transaction/claim.go +++ b/pkg/core/transaction/claim.go @@ -11,18 +11,10 @@ type ClaimTX struct { // DecodeBinary implements Serializable interface. func (tx *ClaimTX) DecodeBinary(br *io.BinReader) { - lenClaims := br.ReadVarUint() - tx.Claims = make([]*Input, lenClaims) - for i := 0; i < int(lenClaims); i++ { - tx.Claims[i] = &Input{} - tx.Claims[i].DecodeBinary(br) - } + tx.Claims = br.ReadArray(Input{}).([]*Input) } // EncodeBinary implements Serializable interface. func (tx *ClaimTX) EncodeBinary(bw *io.BinWriter) { - bw.WriteVarUint(uint64(len(tx.Claims))) - for _, claim := range tx.Claims { - claim.EncodeBinary(bw) - } + bw.WriteArray(tx.Claims) } diff --git a/pkg/core/transaction/state.go b/pkg/core/transaction/state.go index 22a214e99..c3c181125 100644 --- a/pkg/core/transaction/state.go +++ b/pkg/core/transaction/state.go @@ -11,18 +11,10 @@ type StateTX struct { // DecodeBinary implements Serializable interface. func (tx *StateTX) DecodeBinary(r *io.BinReader) { - lenDesc := r.ReadVarUint() - tx.Descriptors = make([]*StateDescriptor, lenDesc) - for i := 0; i < int(lenDesc); i++ { - tx.Descriptors[i] = &StateDescriptor{} - tx.Descriptors[i].DecodeBinary(r) - } + tx.Descriptors = r.ReadArray(StateDescriptor{}).([]*StateDescriptor) } // EncodeBinary implements Serializable interface. func (tx *StateTX) EncodeBinary(w *io.BinWriter) { - w.WriteVarUint(uint64(len(tx.Descriptors))) - for _, desc := range tx.Descriptors { - desc.EncodeBinary(w) - } + w.WriteArray(tx.Descriptors) } diff --git a/pkg/core/transaction/transaction.go b/pkg/core/transaction/transaction.go index d878bb29f..712d44d78 100644 --- a/pkg/core/transaction/transaction.go +++ b/pkg/core/transaction/transaction.go @@ -95,33 +95,10 @@ func (t *Transaction) DecodeBinary(br *io.BinReader) { br.ReadLE(&t.Version) t.decodeData(br) - lenAttrs := br.ReadVarUint() - t.Attributes = make([]*Attribute, lenAttrs) - for i := 0; i < int(lenAttrs); i++ { - t.Attributes[i] = &Attribute{} - t.Attributes[i].DecodeBinary(br) - } - - lenInputs := br.ReadVarUint() - t.Inputs = make([]*Input, lenInputs) - for i := 0; i < int(lenInputs); i++ { - t.Inputs[i] = &Input{} - t.Inputs[i].DecodeBinary(br) - } - - lenOutputs := br.ReadVarUint() - t.Outputs = make([]*Output, lenOutputs) - for i := 0; i < int(lenOutputs); i++ { - t.Outputs[i] = &Output{} - t.Outputs[i].DecodeBinary(br) - } - - lenScripts := br.ReadVarUint() - t.Scripts = make([]*Witness, lenScripts) - for i := 0; i < int(lenScripts); i++ { - t.Scripts[i] = &Witness{} - t.Scripts[i].DecodeBinary(br) - } + t.Attributes = br.ReadArray(Attribute{}).([]*Attribute) + t.Inputs = br.ReadArray(Input{}).([]*Input) + t.Outputs = br.ReadArray(Output{}).([]*Output) + t.Scripts = br.ReadArray(Witness{}).([]*Witness) // Create the hash of the transaction at decode, so we dont need // to do it anymore. @@ -167,10 +144,7 @@ func (t *Transaction) decodeData(r *io.BinReader) { // EncodeBinary implements Serializable interface. func (t *Transaction) EncodeBinary(bw *io.BinWriter) { t.encodeHashableFields(bw) - bw.WriteVarUint(uint64(len(t.Scripts))) - for _, s := range t.Scripts { - s.EncodeBinary(bw) - } + bw.WriteArray(t.Scripts) } // encodeHashableFields encodes the fields that are not used for @@ -185,22 +159,13 @@ func (t *Transaction) encodeHashableFields(bw *io.BinWriter) { } // Attributes - bw.WriteVarUint(uint64(len(t.Attributes))) - for _, attr := range t.Attributes { - attr.EncodeBinary(bw) - } + bw.WriteArray(t.Attributes) // Inputs - bw.WriteVarUint(uint64(len(t.Inputs))) - for _, in := range t.Inputs { - in.EncodeBinary(bw) - } + bw.WriteArray(t.Inputs) // Outputs - bw.WriteVarUint(uint64(len(t.Outputs))) - for _, out := range t.Outputs { - out.EncodeBinary(bw) - } + bw.WriteArray(t.Outputs) } // createHash creates the hash of the transaction. diff --git a/pkg/io/binaryReader.go b/pkg/io/binaryReader.go index 2a00e20c6..689a80974 100644 --- a/pkg/io/binaryReader.go +++ b/pkg/io/binaryReader.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/binary" "io" + "reflect" ) // BinReader is a convenient wrapper around a io.Reader and err object. @@ -33,6 +34,38 @@ func (r *BinReader) ReadLE(v interface{}) { r.Err = binary.Read(r.r, binary.LittleEndian, v) } +// ReadArray reads a slice or an array of pointer to t from r and returns. +func (r *BinReader) ReadArray(t interface{}) interface{} { + elemType := reflect.ValueOf(t).Type() + method, ok := reflect.PtrTo(elemType).MethodByName("DecodeBinary") + if !ok || !isDecodeBinaryMethod(method) { + panic(elemType.String() + " does not have DecodeBinary(*io.BinReader)") + } + + sliceType := reflect.SliceOf(reflect.PtrTo(elemType)) + if r.Err != nil { + return reflect.Zero(sliceType).Interface() + } + + l := int(r.ReadVarUint()) + arr := reflect.MakeSlice(sliceType, l, l) + for i := 0; i < l; i++ { + elem := arr.Index(i) + method := elem.MethodByName("DecodeBinary") + elem.Set(reflect.New(elemType)) + method.Call([]reflect.Value{reflect.ValueOf(r)}) + } + + return arr.Interface() +} + +func isDecodeBinaryMethod(method reflect.Method) bool { + t := method.Type + return t != nil && + t.NumIn() == 2 && t.In(1) == reflect.TypeOf((*BinReader)(nil)) && + t.NumOut() == 0 +} + // ReadBE reads from the underlying io.Reader // into the interface v in big-endian format. func (r *BinReader) ReadBE(v interface{}) { diff --git a/pkg/io/binaryWriter.go b/pkg/io/binaryWriter.go index 912cb3c51..19086e255 100644 --- a/pkg/io/binaryWriter.go +++ b/pkg/io/binaryWriter.go @@ -3,6 +3,7 @@ package io import ( "encoding/binary" "io" + "reflect" ) // BinWriter is a convenient wrapper around a io.Writer and err object. @@ -34,6 +35,37 @@ func (w *BinWriter) WriteBE(v interface{}) { w.Err = binary.Write(w.w, binary.BigEndian, v) } +// WriteArray writes a slice or an array arr into w. +func (w *BinWriter) WriteArray(arr interface{}) { + switch val := reflect.ValueOf(arr); val.Kind() { + case reflect.Slice, reflect.Array: + typ := val.Type().Elem() + method, ok := typ.MethodByName("EncodeBinary") + if !ok || !isEncodeBinaryMethod(method) { + panic(typ.String() + " does not have EncodeBinary(*BinWriter)") + } + + if w.Err != nil { + return + } + + w.WriteVarUint(uint64(val.Len())) + for i := 0; i < val.Len(); i++ { + method := val.Index(i).MethodByName("EncodeBinary") + method.Call([]reflect.Value{reflect.ValueOf(w)}) + } + default: + panic("not an array") + } +} + +func isEncodeBinaryMethod(method reflect.Method) bool { + t := method.Type + return t != nil && + t.NumIn() == 2 && t.In(1) == reflect.TypeOf((*BinWriter)(nil)) && + t.NumOut() == 0 +} + // WriteVarUint writes a uint64 into the underlying writer using variable-length encoding. func (w *BinWriter) WriteVarUint(val uint64) { if w.Err != nil { diff --git a/pkg/io/binaryrw_test.go b/pkg/io/binaryrw_test.go index 1c147f8bf..304b6b651 100644 --- a/pkg/io/binaryrw_test.go +++ b/pkg/io/binaryrw_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) // mocks io.Reader and io.Writer, always fails to Write() or Read(). @@ -191,3 +192,83 @@ func TestWriteVarUint100000000000(t *testing.T) { assert.Nil(t, br.Err) assert.Equal(t, val, res) } + +type testSerializable uint16 + +// EncodeBinary implements io.Serializable interface. +func (t testSerializable) EncodeBinary(w *BinWriter) { + w.WriteLE(t) +} + +// DecodeBinary implements io.Serializable interface. +func (t *testSerializable) DecodeBinary(r *BinReader) { + r.ReadLE(t) +} + +func TestBinWriter_WriteArray(t *testing.T) { + var arr [3]testSerializable + for i := range arr { + arr[i] = testSerializable(i) + } + + expected := []byte{3, 0, 0, 1, 0, 2, 0} + + w := NewBufBinWriter() + w.WriteArray(arr) + require.NoError(t, w.Err) + require.Equal(t, expected, w.Bytes()) + + w.Reset() + w.WriteArray(arr[:]) + require.NoError(t, w.Err) + require.Equal(t, expected, w.Bytes()) + + arrS := make([]Serializable, len(arr)) + for i := range arrS { + arrS[i] = &arr[i] + } + + w.Reset() + w.WriteArray(arr) + require.NoError(t, w.Err) + require.Equal(t, expected, w.Bytes()) + + w.Reset() + require.Panics(t, func() { w.WriteArray(1) }) + + w.Reset() + w.Err = errors.New("error") + w.WriteArray(arr[:]) + require.Error(t, w.Err) + require.Equal(t, w.Bytes(), []byte(nil)) + + w.Reset() + w.Err = errors.New("error") + require.Panics(t, func() { w.WriteArray([]int{}) }) + + w.Reset() + w.Err = errors.New("error") + require.Panics(t, func() { w.WriteArray(make(chan testSerializable)) }) +} + +func TestBinReader_ReadArray(t *testing.T) { + data := []byte{3, 0, 0, 1, 0, 2, 0} + r := NewBinReaderFromBuf(data) + result := r.ReadArray(testSerializable(0)) + elems := []testSerializable{0, 1, 2} + require.Equal(t, []*testSerializable{&elems[0], &elems[1], &elems[2]}, result) + + r = NewBinReaderFromBuf(data) + r.Err = errors.New("error") + result = r.ReadArray(testSerializable(0)) + require.Error(t, r.Err) + require.Equal(t, ([]*testSerializable)(nil), result) + + r = NewBinReaderFromBuf([]byte{0}) + result = r.ReadArray(testSerializable(0)) + require.NoError(t, r.Err) + require.Equal(t, []*testSerializable{}, result) + + r = NewBinReaderFromBuf([]byte{0}) + require.Panics(t, func() { r.ReadArray(0) }) +} diff --git a/pkg/network/payload/address.go b/pkg/network/payload/address.go index 47c7b8b5b..2f7728eb2 100644 --- a/pkg/network/payload/address.go +++ b/pkg/network/payload/address.go @@ -67,19 +67,10 @@ func NewAddressList(n int) *AddressList { // DecodeBinary implements Serializable interface. func (p *AddressList) DecodeBinary(br *io.BinReader) { - listLen := br.ReadVarUint() - - p.Addrs = make([]*AddressAndTime, listLen) - for i := 0; i < int(listLen); i++ { - p.Addrs[i] = &AddressAndTime{} - p.Addrs[i].DecodeBinary(br) - } + p.Addrs = br.ReadArray(AddressAndTime{}).([]*AddressAndTime) } // EncodeBinary implements Serializable interface. func (p *AddressList) EncodeBinary(bw *io.BinWriter) { - bw.WriteVarUint(uint64(len(p.Addrs))) - for _, addr := range p.Addrs { - addr.EncodeBinary(bw) - } + bw.WriteArray(p.Addrs) } diff --git a/pkg/network/payload/headers.go b/pkg/network/payload/headers.go index 6ee94c5eb..8635cb1e6 100644 --- a/pkg/network/payload/headers.go +++ b/pkg/network/payload/headers.go @@ -37,9 +37,5 @@ func (p *Headers) DecodeBinary(br *io.BinReader) { // EncodeBinary implements Serializable interface. func (p *Headers) EncodeBinary(bw *io.BinWriter) { - bw.WriteVarUint(uint64(len(p.Hdrs))) - - for _, header := range p.Hdrs { - header.EncodeBinary(bw) - } + bw.WriteArray(p.Hdrs) }