Simplify ui.StdioWrapper.Write

Instead of looping to find line breaks, make it look for the last one.
This commit is contained in:
greatroar 2020-10-17 20:23:36 +02:00
parent 863a590a81
commit 35419de232
2 changed files with 18 additions and 39 deletions

View file

@ -55,22 +55,12 @@ func (w *lineWriter) Write(data []byte) (n int, err error) {
// look for line breaks // look for line breaks
buf := w.buf.Bytes() buf := w.buf.Bytes()
skip := 0 i := bytes.LastIndexByte(buf, '\n')
for i := 0; i < len(buf); { if i != -1 {
if buf[i] == '\n' {
// found line
w.print(string(buf[:i+1])) w.print(string(buf[:i+1]))
buf = buf[i+1:] w.buf.Next(i + 1)
skip += i + 1
i = 0
continue
} }
i++
}
_ = w.buf.Next(skip)
return n, err return n, err
} }

View file

@ -1,6 +1,7 @@
package ui package ui
import ( import (
"strings"
"testing" "testing"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
@ -9,15 +10,13 @@ import (
func TestStdioWrapper(t *testing.T) { func TestStdioWrapper(t *testing.T) {
var tests = []struct { var tests = []struct {
inputs [][]byte inputs [][]byte
outputs []string output string
}{ }{
{ {
inputs: [][]byte{ inputs: [][]byte{
[]byte("foo"), []byte("foo"),
}, },
outputs: []string{ output: "foo\n",
"foo\n",
},
}, },
{ {
inputs: [][]byte{ inputs: [][]byte{
@ -26,23 +25,19 @@ func TestStdioWrapper(t *testing.T) {
[]byte("\n"), []byte("\n"),
[]byte("baz"), []byte("baz"),
}, },
outputs: []string{ output: "foobar\n" +
"foobar\n",
"baz\n", "baz\n",
}, },
},
{ {
inputs: [][]byte{ inputs: [][]byte{
[]byte("foo"), []byte("foo"),
[]byte("bar\nbaz\n"), []byte("bar\nbaz\n"),
[]byte("bump\n"), []byte("bump\n"),
}, },
outputs: []string{ output: "foobar\n" +
"foobar\n", "baz\n" +
"baz\n",
"bump\n", "bump\n",
}, },
},
{ {
inputs: [][]byte{ inputs: [][]byte{
[]byte("foo"), []byte("foo"),
@ -53,23 +48,17 @@ func TestStdioWrapper(t *testing.T) {
[]byte("x"), []byte("x"),
[]byte("z"), []byte("z"),
}, },
outputs: []string{ output: "foobar\n" +
"foobar\n", "baz\n" +
"baz\n", "bump\n" +
"bump\n",
"xxxz\n", "xxxz\n",
}, },
},
} }
for _, test := range tests { for _, test := range tests {
t.Run("", func(t *testing.T) { t.Run("", func(t *testing.T) {
var lines []string var output strings.Builder
print := func(s string) { w := newLineWriter(func(s string) { output.WriteString(s) })
lines = append(lines, s)
}
w := newLineWriter(print)
for _, data := range test.inputs { for _, data := range test.inputs {
n, err := w.Write(data) n, err := w.Write(data)
@ -87,8 +76,8 @@ func TestStdioWrapper(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
if !cmp.Equal(test.outputs, lines) { if outstr := output.String(); outstr != test.output {
t.Error(cmp.Diff(test.outputs, lines)) t.Error(cmp.Diff(test.output, outstr))
} }
}) })
} }