From ebd01a467599b469e486cb08a170ab04e787bc8a Mon Sep 17 00:00:00 2001 From: Michael Eischer Date: Wed, 1 May 2024 21:54:21 +0200 Subject: [PATCH] backend: add tests for watchdogRoundTripper --- internal/backend/watchdog_roundtriper_test.go | 201 ++++++++++++++++++ 1 file changed, 201 insertions(+) create mode 100644 internal/backend/watchdog_roundtriper_test.go diff --git a/internal/backend/watchdog_roundtriper_test.go b/internal/backend/watchdog_roundtriper_test.go new file mode 100644 index 000000000..a13d670e0 --- /dev/null +++ b/internal/backend/watchdog_roundtriper_test.go @@ -0,0 +1,201 @@ +package backend + +import ( + "bytes" + "context" + "fmt" + "io" + "net/http" + "net/http/httptest" + "testing" + "time" + + rtest "github.com/restic/restic/internal/test" +) + +func TestRead(t *testing.T) { + data := []byte("abcdef") + var ctr int + kick := func() { + ctr++ + } + var closed bool + onClose := func() { + closed = true + } + + wd := newWatchdogReadCloser(io.NopCloser(bytes.NewReader(data)), 1, kick, onClose) + + out, err := io.ReadAll(wd) + rtest.OK(t, err) + rtest.Equals(t, data, out, "data mismatch") + // the EOF read also triggers the kick function + rtest.Equals(t, len(data)*2+2, ctr, "unexpected number of kick calls") + + rtest.Equals(t, false, closed, "close function called too early") + rtest.OK(t, wd.Close()) + rtest.Equals(t, true, closed, "close function not called") +} + +func TestRoundtrip(t *testing.T) { + t.Parallel() + + // at the higher delay values, it takes longer to transmit the request/response body + // than the roundTripper timeout + for _, delay := range []int{0, 1, 10, 20} { + t.Run(fmt.Sprintf("%v", delay), func(t *testing.T) { + msg := []byte("ping-pong-data") + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + data, err := io.ReadAll(r.Body) + if err != nil { + w.WriteHeader(500) + return + } + w.WriteHeader(200) + + // slowly send the reply + for len(data) >= 2 { + _, _ = w.Write(data[:2]) + w.(http.Flusher).Flush() + data = data[2:] + time.Sleep(time.Duration(delay) * time.Millisecond) + } + _, _ = w.Write(data) + })) + defer srv.Close() + + rt := newWatchdogRoundtripper(http.DefaultTransport, 50*time.Millisecond, 2) + req, err := http.NewRequestWithContext(context.TODO(), "GET", srv.URL, io.NopCloser(newSlowReader(bytes.NewReader(msg), time.Duration(delay)*time.Millisecond))) + rtest.OK(t, err) + + resp, err := rt.RoundTrip(req) + rtest.OK(t, err) + rtest.Equals(t, 200, resp.StatusCode, "unexpected status code") + + response, err := io.ReadAll(resp.Body) + rtest.OK(t, err) + rtest.Equals(t, msg, response, "unexpected response") + + rtest.OK(t, resp.Body.Close()) + }) + } +} + +func TestCanceledRoundtrip(t *testing.T) { + rt := newWatchdogRoundtripper(http.DefaultTransport, time.Second, 2) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + req, err := http.NewRequestWithContext(ctx, "GET", "http://some.random.url.dfdgsfg", nil) + rtest.OK(t, err) + + resp, err := rt.RoundTrip(req) + rtest.Equals(t, context.Canceled, err) + // make linter happy + if resp != nil { + rtest.OK(t, resp.Body.Close()) + } +} + +type slowReader struct { + data io.Reader + delay time.Duration +} + +func newSlowReader(data io.Reader, delay time.Duration) *slowReader { + return &slowReader{ + data: data, + delay: delay, + } +} + +func (s *slowReader) Read(p []byte) (n int, err error) { + time.Sleep(s.delay) + return s.data.Read(p) +} + +func TestUploadTimeout(t *testing.T) { + t.Parallel() + + msg := []byte("ping") + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, err := io.ReadAll(r.Body) + if err != nil { + w.WriteHeader(500) + return + } + t.Error("upload should have been canceled") + })) + defer srv.Close() + + rt := newWatchdogRoundtripper(http.DefaultTransport, 10*time.Millisecond, 1024) + req, err := http.NewRequestWithContext(context.TODO(), "GET", srv.URL, io.NopCloser(newSlowReader(bytes.NewReader(msg), 100*time.Millisecond))) + rtest.OK(t, err) + + resp, err := rt.RoundTrip(req) + rtest.Equals(t, context.Canceled, err) + // make linter happy + if resp != nil { + rtest.OK(t, resp.Body.Close()) + } +} + +func TestProcessingTimeout(t *testing.T) { + t.Parallel() + + msg := []byte("ping") + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, err := io.ReadAll(r.Body) + if err != nil { + w.WriteHeader(500) + return + } + time.Sleep(100 * time.Millisecond) + w.WriteHeader(200) + })) + defer srv.Close() + + rt := newWatchdogRoundtripper(http.DefaultTransport, 10*time.Millisecond, 1024) + req, err := http.NewRequestWithContext(context.TODO(), "GET", srv.URL, io.NopCloser(bytes.NewReader(msg))) + rtest.OK(t, err) + + resp, err := rt.RoundTrip(req) + rtest.Equals(t, context.Canceled, err) + // make linter happy + if resp != nil { + rtest.OK(t, resp.Body.Close()) + } +} + +func TestDownloadTimeout(t *testing.T) { + t.Parallel() + + msg := []byte("ping") + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + data, err := io.ReadAll(r.Body) + if err != nil { + w.WriteHeader(500) + return + } + w.WriteHeader(200) + _, _ = w.Write(data[:2]) + w.(http.Flusher).Flush() + data = data[2:] + + time.Sleep(100 * time.Millisecond) + _, _ = w.Write(data) + + })) + defer srv.Close() + + rt := newWatchdogRoundtripper(http.DefaultTransport, 10*time.Millisecond, 1024) + req, err := http.NewRequestWithContext(context.TODO(), "GET", srv.URL, io.NopCloser(bytes.NewReader(msg))) + rtest.OK(t, err) + + resp, err := rt.RoundTrip(req) + rtest.OK(t, err) + rtest.Equals(t, 200, resp.StatusCode, "unexpected status code") + + _, err = io.ReadAll(resp.Body) + rtest.Equals(t, context.Canceled, err, "response download not canceled") + rtest.OK(t, resp.Body.Close()) +}