diff --git a/pkg/consensus/recovery_message.go b/pkg/consensus/recovery_message.go index af66daf7a..aa409b19a 100644 --- a/pkg/consensus/recovery_message.go +++ b/pkg/consensus/recovery_message.go @@ -97,7 +97,7 @@ func (p *changeViewCompact) DecodeBinary(r *io.BinReader) { p.ValidatorIndex = r.ReadU16LE() p.OriginalViewNumber = r.ReadB() p.Timestamp = r.ReadU32LE() - p.InvocationScript = r.ReadVarBytes() + p.InvocationScript = r.ReadVarBytes(1024) } // EncodeBinary implements io.Serializable interface. @@ -114,7 +114,7 @@ func (p *commitCompact) DecodeBinary(r *io.BinReader) { p.ValidatorIndex = r.ReadU16LE() r.ReadBytes(p.Signature[:]) r.ReadBytes(p.StateSignature[:]) - p.InvocationScript = r.ReadVarBytes() + p.InvocationScript = r.ReadVarBytes(1024) } // EncodeBinary implements io.Serializable interface. @@ -129,7 +129,7 @@ func (p *commitCompact) EncodeBinary(w *io.BinWriter) { // DecodeBinary implements io.Serializable interface. func (p *preparationCompact) DecodeBinary(r *io.BinReader) { p.ValidatorIndex = r.ReadU16LE() - p.InvocationScript = r.ReadVarBytes() + p.InvocationScript = r.ReadVarBytes(1024) } // EncodeBinary implements io.Serializable interface. diff --git a/pkg/io/binaryReader.go b/pkg/io/binaryReader.go index fd23355a2..bd62b53ee 100644 --- a/pkg/io/binaryReader.go +++ b/pkg/io/binaryReader.go @@ -168,8 +168,16 @@ func (r *BinReader) ReadVarUint() uint64 { // ReadVarBytes reads the next set of bytes from the underlying reader. // ReadVarUInt() is used to determine how large that slice is -func (r *BinReader) ReadVarBytes() []byte { +func (r *BinReader) ReadVarBytes(maxSize ...int) []byte { n := r.ReadVarUint() + ms := maxArraySize + if len(maxSize) != 0 { + ms = maxSize[0] + } + if n > uint64(ms) { + r.Err = fmt.Errorf("byte-slice is too big (%d)", n) + return nil + } b := make([]byte, n) r.ReadBytes(b) return b diff --git a/pkg/io/binaryrw_test.go b/pkg/io/binaryrw_test.go index d5e1cf8c6..fd998d503 100644 --- a/pkg/io/binaryrw_test.go +++ b/pkg/io/binaryrw_test.go @@ -143,6 +143,35 @@ func TestBufBinWriter_Len(t *testing.T) { require.Equal(t, 1, bw.Len()) } +func TestBinReader_ReadVarBytes(t *testing.T) { + buf := make([]byte, 11) + for i := range buf { + buf[i] = byte(i) + } + w := NewBufBinWriter() + w.WriteVarBytes(buf) + require.NoError(t, w.Err) + data := w.Bytes() + + t.Run("NoArguments", func(t *testing.T) { + r := NewBinReaderFromBuf(data) + actual := r.ReadVarBytes() + require.NoError(t, r.Err) + require.Equal(t, buf, actual) + }) + t.Run("Good", func(t *testing.T) { + r := NewBinReaderFromBuf(data) + actual := r.ReadVarBytes(11) + require.NoError(t, r.Err) + require.Equal(t, buf, actual) + }) + t.Run("Bad", func(t *testing.T) { + r := NewBinReaderFromBuf(data) + r.ReadVarBytes(10) + require.Error(t, r.Err) + }) +} + func TestWriterErrHandling(t *testing.T) { var badio = &badRW{} bw := NewBinWriterFromIO(badio)