diff --git a/cmd/serve/restic/cache.go b/cmd/serve/restic/cache.go index f7f376c82..4d2215129 100644 --- a/cmd/serve/restic/cache.go +++ b/cmd/serve/restic/cache.go @@ -9,20 +9,22 @@ import ( // cache implements a simple object cache type cache struct { - mu sync.RWMutex // protects the cache - items map[string]fs.Object // cache of objects + mu sync.RWMutex // protects the cache + items map[string]fs.Object // cache of objects + cacheObjects bool // whether we are actually caching } // create a new cache -func newCache() *cache { +func newCache(cacheObjects bool) *cache { return &cache{ - items: map[string]fs.Object{}, + items: map[string]fs.Object{}, + cacheObjects: cacheObjects, } } // find the object at remote or return nil func (c *cache) find(remote string) fs.Object { - if !cacheObjects { + if !c.cacheObjects { return nil } c.mu.RLock() @@ -33,7 +35,7 @@ func (c *cache) find(remote string) fs.Object { // add the object to the cache func (c *cache) add(remote string, o fs.Object) { - if !cacheObjects { + if !c.cacheObjects { return } c.mu.Lock() @@ -43,7 +45,7 @@ func (c *cache) add(remote string, o fs.Object) { // remove the object from the cache func (c *cache) remove(remote string) { - if !cacheObjects { + if !c.cacheObjects { return } c.mu.Lock() @@ -53,7 +55,7 @@ func (c *cache) remove(remote string) { // remove all the items with prefix from the cache func (c *cache) removePrefix(prefix string) { - if !cacheObjects { + if !c.cacheObjects { return } diff --git a/cmd/serve/restic/cache_test.go b/cmd/serve/restic/cache_test.go index 05687ad84..ffabfa773 100644 --- a/cmd/serve/restic/cache_test.go +++ b/cmd/serve/restic/cache_test.go @@ -21,7 +21,7 @@ func (c *cache) String() string { } func TestCacheCRUD(t *testing.T) { - c := newCache() + c := newCache(true) assert.Equal(t, "", c.String()) assert.Nil(t, c.find("potato")) o := mockobject.New("potato") @@ -35,7 +35,7 @@ func TestCacheCRUD(t *testing.T) { } func TestCacheRemovePrefix(t *testing.T) { - c := newCache() + c := newCache(true) for _, remote := range []string{ "a", "b", diff --git a/cmd/serve/restic/restic.go b/cmd/serve/restic/restic.go index 3662fc58c..0d4db85bd 100644 --- a/cmd/serve/restic/restic.go +++ b/cmd/serve/restic/restic.go @@ -5,6 +5,7 @@ import ( "context" "encoding/json" "errors" + "fmt" "net/http" "os" "path" @@ -12,34 +13,48 @@ import ( "strings" "time" + "github.com/go-chi/chi/v5" + "github.com/go-chi/chi/v5/middleware" "github.com/rclone/rclone/cmd" - "github.com/rclone/rclone/cmd/serve/httplib" - "github.com/rclone/rclone/cmd/serve/httplib/httpflags" "github.com/rclone/rclone/fs" "github.com/rclone/rclone/fs/accounting" "github.com/rclone/rclone/fs/config/flags" "github.com/rclone/rclone/fs/operations" "github.com/rclone/rclone/fs/walk" + libhttp "github.com/rclone/rclone/lib/http" "github.com/rclone/rclone/lib/http/serve" "github.com/rclone/rclone/lib/terminal" "github.com/spf13/cobra" "golang.org/x/net/http2" ) -var ( - stdio bool - appendOnly bool - privateRepos bool - cacheObjects bool -) +// Options required for http server +type Options struct { + Auth libhttp.AuthConfig + HTTP libhttp.Config + Stdio bool + AppendOnly bool + PrivateRepos bool + CacheObjects bool +} + +// DefaultOpt is the default values used for Options +var DefaultOpt = Options{ + Auth: libhttp.DefaultAuthCfg(), + HTTP: libhttp.DefaultCfg(), +} + +// Opt is options set by command line flags +var Opt = DefaultOpt func init() { - httpflags.AddFlags(Command.Flags()) flagSet := Command.Flags() - flags.BoolVarP(flagSet, &stdio, "stdio", "", false, "Run an HTTP2 server on stdin/stdout") - flags.BoolVarP(flagSet, &appendOnly, "append-only", "", false, "Disallow deletion of repository data") - flags.BoolVarP(flagSet, &privateRepos, "private-repos", "", false, "Users can only access their private repo") - flags.BoolVarP(flagSet, &cacheObjects, "cache-objects", "", true, "Cache listed objects") + libhttp.AddAuthFlagsPrefix(flagSet, "", &Opt.Auth) + libhttp.AddHTTPFlagsPrefix(flagSet, "", &Opt.HTTP) + flags.BoolVarP(flagSet, &Opt.Stdio, "stdio", "", false, "Run an HTTP2 server on stdin/stdout") + flags.BoolVarP(flagSet, &Opt.AppendOnly, "append-only", "", false, "Disallow deletion of repository data") + flags.BoolVarP(flagSet, &Opt.PrivateRepos, "private-repos", "", false, "Users can only access their private repo") + flags.BoolVarP(flagSet, &Opt.CacheObjects, "cache-objects", "", true, "Cache listed objects") } // Command definition for cobra @@ -127,16 +142,21 @@ these **must** end with /. Eg The` + "`--private-repos`" + ` flag can be used to limit users to repositories starting with a path of ` + "`//`" + `. -` + httplib.Help, +` + libhttp.Help + libhttp.AuthHelp, Annotations: map[string]string{ "versionIntroduced": "v1.40", }, Run: func(command *cobra.Command, args []string) { + ctx := context.Background() cmd.CheckArgs(1, 1, command, args) f := cmd.NewFsSrc(args) cmd.Run(false, true, command, func() error { - s := NewServer(f, &httpflags.Opt) - if stdio { + s, err := newServer(ctx, f, &Opt) + if err != nil { + return err + } + fs.Logf(s.f, "Serving restic REST API on %s", s.URLs()) + if s.opt.Stdio { if terminal.IsTerminal(int(os.Stdout.Fd())) { return errors.New("refusing to run HTTP2 server directly on a terminal, please let restic start rclone") } @@ -148,16 +168,11 @@ with a path of ` + "`//`" + `. httpSrv := &http2.Server{} opts := &http2.ServeConnOpts{ - Handler: s, + Handler: s.Server.Router(), } httpSrv.ServeConn(conn, opts) return nil } - err := s.Serve() - if err != nil { - return err - } - s.Wait() return nil }) }, @@ -167,101 +182,130 @@ const ( resticAPIV2 = "application/vnd.x.restic.rest.v2" ) -// Server contains everything to run the Server -type Server struct { - *httplib.Server +type contextRemoteType struct{} + +// ContextRemoteKey is a simple context key for storing the username of the request +var ContextRemoteKey = &contextRemoteType{} + +// WithRemote makes a remote from a URL path. This implements the backend layout +// required by restic. +func WithRemote(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var urlpath string + rctx := chi.RouteContext(r.Context()) + if rctx != nil && rctx.RoutePath != "" { + urlpath = rctx.RoutePath + } else { + urlpath = r.URL.Path + } + urlpath = strings.Trim(urlpath, "/") + parts := matchData.FindStringSubmatch(urlpath) + // if no data directory, layout is flat + if parts != nil { + // otherwise map + // data/2159dd48 to + // data/21/2159dd48 + fileName := parts[1] + prefix := urlpath[:len(urlpath)-len(fileName)] + urlpath = prefix + fileName[:2] + "/" + fileName + } + ctx := context.WithValue(r.Context(), ContextRemoteKey, urlpath) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +// Middleware to ensure authenticated user is accessing their own private folder +func checkPrivate(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + user := chi.URLParam(r, "userID") + userID, ok := libhttp.CtxGetUser(r.Context()) + if ok && user != "" && user == userID { + next.ServeHTTP(w, r) + } else { + http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden) + } + }) +} + +// server contains everything to run the server +type server struct { + *libhttp.Server f fs.Fs cache *cache + opt Options } -// NewServer returns an HTTP server that speaks the rest protocol -func NewServer(f fs.Fs, opt *httplib.Options) *Server { - mux := http.NewServeMux() - s := &Server{ - Server: httplib.NewServer(mux, opt), - f: f, - cache: newCache(), +func newServer(ctx context.Context, f fs.Fs, opt *Options) (s *server, err error) { + s = &server{ + f: f, + cache: newCache(opt.CacheObjects), + opt: *opt, } - mux.HandleFunc(s.Opt.BaseURL+"/", s.ServeHTTP) - return s -} - -// Serve runs the http server in the background. -// -// Use s.Close() and s.Wait() to shutdown server -func (s *Server) Serve() error { - err := s.Server.Serve() + s.Server, err = libhttp.NewServer(ctx, + libhttp.WithConfig(opt.HTTP), + libhttp.WithAuth(opt.Auth), + ) if err != nil { - return err + return nil, fmt.Errorf("failed to init server: %w", err) + } + router := s.Router() + s.Bind(router) + s.Server.Serve() + return s, nil +} + +// bind helper for main Bind method +func (s *server) bind(router chi.Router) { + router.MethodFunc("GET", "/*", func(w http.ResponseWriter, r *http.Request) { + urlpath := chi.URLParam(r, "*") + if urlpath == "" || strings.HasSuffix(urlpath, "/") { + s.listObjects(w, r) + } else { + s.serveObject(w, r) + } + }) + router.MethodFunc("POST", "/*", func(w http.ResponseWriter, r *http.Request) { + urlpath := chi.URLParam(r, "*") + if urlpath == "" || strings.HasSuffix(urlpath, "/") { + s.createRepo(w, r) + } else { + s.postObject(w, r) + } + }) + router.MethodFunc("HEAD", "/*", s.serveObject) + router.MethodFunc("DELETE", "/*", s.deleteObject) +} + +// Bind restic server routes to passed router +func (s *server) Bind(router chi.Router) { + // FIXME + // if m := authX.Auth(authX.Opt); m != nil { + // router.Use(m) + // } + router.Use( + middleware.SetHeader("Accept-Ranges", "bytes"), + middleware.SetHeader("Server", "rclone/"+fs.Version), + WithRemote, + ) + + if s.opt.PrivateRepos { + router.Route("/{userID}", func(r chi.Router) { + r.Use(checkPrivate) + s.bind(r) + }) + router.NotFound(func(w http.ResponseWriter, _ *http.Request) { + http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden) + }) + } else { + s.bind(router) } - fs.Logf(s.f, "Serving restic REST API on %s", s.URL()) - return nil } var matchData = regexp.MustCompile("(?:^|/)data/([^/]{2,})$") -// Makes a remote from a URL path. This implements the backend layout -// required by restic. -func makeRemote(path string) string { - path = strings.Trim(path, "/") - parts := matchData.FindStringSubmatch(path) - // if no data directory, layout is flat - if parts == nil { - return path - } - // otherwise map - // data/2159dd48 to - // data/21/2159dd48 - fileName := parts[1] - prefix := path[:len(path)-len(fileName)] - return prefix + fileName[:2] + "/" + fileName -} - -// ServeHTTP reads incoming requests and dispatches them -func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Accept-Ranges", "bytes") - w.Header().Set("Server", "rclone/"+fs.Version) - - path, ok := s.Path(w, r) - if !ok { - return - } - remote := makeRemote(path) - fs.Debugf(s.f, "%s %s", r.Method, path) - - v := r.Context().Value(httplib.ContextUserKey) - if privateRepos && (v == nil || !strings.HasPrefix(path, "/"+v.(string)+"/")) { - http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden) - return - } - - // Dispatch on path then method - if strings.HasSuffix(path, "/") { - switch r.Method { - case "GET": - s.listObjects(w, r, remote) - case "POST": - s.createRepo(w, r, remote) - default: - http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) - } - } else { - switch r.Method { - case "GET", "HEAD": - s.serveObject(w, r, remote) - case "POST": - s.postObject(w, r, remote) - case "DELETE": - s.deleteObject(w, r, remote) - default: - http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) - } - } -} - // newObject returns an object with the remote given either from the // cache or directly -func (s *Server) newObject(ctx context.Context, remote string) (fs.Object, error) { +func (s *server) newObject(ctx context.Context, remote string) (fs.Object, error) { o := s.cache.find(remote) if o != nil { return o, nil @@ -275,7 +319,12 @@ func (s *Server) newObject(ctx context.Context, remote string) (fs.Object, error } // get the remote -func (s *Server) serveObject(w http.ResponseWriter, r *http.Request, remote string) { +func (s *server) serveObject(w http.ResponseWriter, r *http.Request) { + remote, ok := r.Context().Value(ContextRemoteKey).(string) + if !ok { + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return + } o, err := s.newObject(r.Context(), remote) if err != nil { fs.Debugf(remote, "%s request error: %v", r.Method, err) @@ -286,8 +335,13 @@ func (s *Server) serveObject(w http.ResponseWriter, r *http.Request, remote stri } // postObject posts an object to the repository -func (s *Server) postObject(w http.ResponseWriter, r *http.Request, remote string) { - if appendOnly { +func (s *server) postObject(w http.ResponseWriter, r *http.Request) { + remote, ok := r.Context().Value(ContextRemoteKey).(string) + if !ok { + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return + } + if s.opt.AppendOnly { // make sure the file does not exist yet _, err := s.newObject(r.Context(), remote) if err == nil { @@ -312,8 +366,13 @@ func (s *Server) postObject(w http.ResponseWriter, r *http.Request, remote strin } // delete the remote -func (s *Server) deleteObject(w http.ResponseWriter, r *http.Request, remote string) { - if appendOnly { +func (s *server) deleteObject(w http.ResponseWriter, r *http.Request) { + remote, ok := r.Context().Value(ContextRemoteKey).(string) + if !ok { + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return + } + if s.opt.AppendOnly { parts := strings.Split(r.URL.Path, "/") // if path doesn't end in "/locks/:name", disallow the operation @@ -362,14 +421,18 @@ func (ls *listItems) add(o fs.Object) { } // listObjects lists all Objects of a given type in an arbitrary order. -func (s *Server) listObjects(w http.ResponseWriter, r *http.Request, remote string) { - fs.Debugf(remote, "list request") - - if r.Header.Get("Accept") != resticAPIV2 { - fs.Errorf(remote, "Restic v2 API required") - http.Error(w, "Restic v2 API required", http.StatusBadRequest) +func (s *server) listObjects(w http.ResponseWriter, r *http.Request) { + remote, ok := r.Context().Value(ContextRemoteKey).(string) + if !ok { + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return } + if r.Header.Get("Accept") != resticAPIV2 { + fs.Errorf(remote, "Restic v2 API required for List Objects") + http.Error(w, "Restic v2 API required for List Objects", http.StatusBadRequest) + return + } + fs.Debugf(remote, "list request") // make sure an empty list is returned, and not a 'nil' value ls := listItems{} @@ -408,7 +471,12 @@ func (s *Server) listObjects(w http.ResponseWriter, r *http.Request, remote stri // createRepo creates repository directories. // // We don't bother creating the data dirs as rclone will create them on the fly -func (s *Server) createRepo(w http.ResponseWriter, r *http.Request, remote string) { +func (s *server) createRepo(w http.ResponseWriter, r *http.Request) { + remote, ok := r.Context().Value(ContextRemoteKey).(string) + if !ok { + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return + } fs.Infof(remote, "Creating repository") if r.URL.Query().Get("create") != "true" { diff --git a/cmd/serve/restic/restic_appendonly_test.go b/cmd/serve/restic/restic_appendonly_test.go index b3562db9e..2499481a1 100644 --- a/cmd/serve/restic/restic_appendonly_test.go +++ b/cmd/serve/restic/restic_appendonly_test.go @@ -1,6 +1,7 @@ package restic import ( + "context" "crypto/rand" "encoding/hex" "io" @@ -9,7 +10,6 @@ import ( "testing" "github.com/rclone/rclone/cmd" - "github.com/rclone/rclone/cmd/serve/httplib/httpflags" "github.com/rclone/rclone/fs/config/configfile" "github.com/stretchr/testify/require" ) @@ -62,6 +62,7 @@ func createOverwriteDeleteSeq(t testing.TB, path string) []TestRequest { // TestResticHandler runs tests on the restic handler code, especially in append-only mode. func TestResticHandler(t *testing.T) { + ctx := context.Background() configfile.Install() buf := make([]byte, 32) _, err := io.ReadFull(rand.Reader, buf) @@ -110,19 +111,18 @@ func TestResticHandler(t *testing.T) { // setup rclone with a local backend in a temporary directory tempdir := t.TempDir() - // globally set append-only mode - prev := appendOnly - appendOnly = true - defer func() { - appendOnly = prev // reset when done - }() + // set append-only mode + opt := newOpt() + opt.AppendOnly = true // make a new file system in the temp dir f := cmd.NewFsSrc([]string{tempdir}) - srv := NewServer(f, &httpflags.Opt) + s, err := newServer(ctx, f, &opt) + require.NoError(t, err) + router := s.Server.Router() // create the repo - checkRequest(t, srv.ServeHTTP, + checkRequest(t, router.ServeHTTP, newRequest(t, "POST", "/?create=true", nil), []wantFunc{wantCode(http.StatusOK)}) @@ -130,7 +130,7 @@ func TestResticHandler(t *testing.T) { t.Run("", func(t *testing.T) { for i, seq := range test.seq { t.Logf("request %v: %v %v", i, seq.req.Method, seq.req.URL.Path) - checkRequest(t, srv.ServeHTTP, seq.req, seq.want) + checkRequest(t, router.ServeHTTP, seq.req, seq.want) } }) } diff --git a/cmd/serve/restic/restic_privaterepos_test.go b/cmd/serve/restic/restic_privaterepos_test.go index 4faae363c..1b89e8681 100644 --- a/cmd/serve/restic/restic_privaterepos_test.go +++ b/cmd/serve/restic/restic_privaterepos_test.go @@ -8,23 +8,21 @@ import ( "strings" "testing" - "github.com/rclone/rclone/cmd/serve/httplib" - "github.com/rclone/rclone/cmd" - "github.com/rclone/rclone/cmd/serve/httplib/httpflags" "github.com/stretchr/testify/require" ) // newAuthenticatedRequest returns a new HTTP request with the given params. -func newAuthenticatedRequest(t testing.TB, method, path string, body io.Reader) *http.Request { +func newAuthenticatedRequest(t testing.TB, method, path string, body io.Reader, user, pass string) *http.Request { req := newRequest(t, method, path, body) - req = req.WithContext(context.WithValue(req.Context(), httplib.ContextUserKey, "test")) + req.SetBasicAuth(user, pass) req.Header.Add("Accept", resticAPIV2) return req } // TestResticPrivateRepositories runs tests on the restic handler code for private repositories func TestResticPrivateRepositories(t *testing.T) { + ctx := context.Background() buf := make([]byte, 32) _, err := io.ReadFull(rand.Reader, buf) require.NoError(t, err) @@ -32,42 +30,49 @@ func TestResticPrivateRepositories(t *testing.T) { // setup rclone with a local backend in a temporary directory tempdir := t.TempDir() - // globally set private-repos mode & test user - prev := privateRepos - prevUser := httpflags.Opt.BasicUser - prevPassword := httpflags.Opt.BasicPass - privateRepos = true - httpflags.Opt.BasicUser = "test" - httpflags.Opt.BasicPass = "password" - // reset when done - defer func() { - privateRepos = prev - httpflags.Opt.BasicUser = prevUser - httpflags.Opt.BasicPass = prevPassword - }() + opt := newOpt() + + // set private-repos mode & test user + opt.PrivateRepos = true + opt.Auth.BasicUser = "test" + opt.Auth.BasicPass = "password" // make a new file system in the temp dir f := cmd.NewFsSrc([]string{tempdir}) - srv := NewServer(f, &httpflags.Opt) + s, err := newServer(ctx, f, &opt) + require.NoError(t, err) + router := s.Server.Router() // Requesting /test/ should allow access reqs := []*http.Request{ - newAuthenticatedRequest(t, "POST", "/test/?create=true", nil), - newAuthenticatedRequest(t, "POST", "/test/config", strings.NewReader("foobar test config")), - newAuthenticatedRequest(t, "GET", "/test/config", nil), + newAuthenticatedRequest(t, "POST", "/test/?create=true", nil, opt.Auth.BasicUser, opt.Auth.BasicPass), + newAuthenticatedRequest(t, "POST", "/test/config", strings.NewReader("foobar test config"), opt.Auth.BasicUser, opt.Auth.BasicPass), + newAuthenticatedRequest(t, "GET", "/test/config", nil, opt.Auth.BasicUser, opt.Auth.BasicPass), } for _, req := range reqs { - checkRequest(t, srv.ServeHTTP, req, []wantFunc{wantCode(http.StatusOK)}) + checkRequest(t, router.ServeHTTP, req, []wantFunc{wantCode(http.StatusOK)}) + } + + // Requesting with bad credentials should raise unauthorised errors + reqs = []*http.Request{ + newRequest(t, "GET", "/test/config", nil), + newAuthenticatedRequest(t, "GET", "/test/config", nil, opt.Auth.BasicUser, ""), + newAuthenticatedRequest(t, "GET", "/test/config", nil, "", opt.Auth.BasicPass), + newAuthenticatedRequest(t, "GET", "/test/config", nil, opt.Auth.BasicUser+"x", opt.Auth.BasicPass), + newAuthenticatedRequest(t, "GET", "/test/config", nil, opt.Auth.BasicUser, opt.Auth.BasicPass+"x"), + } + for _, req := range reqs { + checkRequest(t, router.ServeHTTP, req, []wantFunc{wantCode(http.StatusUnauthorized)}) } // Requesting everything else should raise forbidden errors reqs = []*http.Request{ - newAuthenticatedRequest(t, "GET", "/", nil), - newAuthenticatedRequest(t, "POST", "/other_user", nil), - newAuthenticatedRequest(t, "GET", "/other_user/config", nil), + newAuthenticatedRequest(t, "GET", "/", nil, opt.Auth.BasicUser, opt.Auth.BasicPass), + newAuthenticatedRequest(t, "POST", "/other_user", nil, opt.Auth.BasicUser, opt.Auth.BasicPass), + newAuthenticatedRequest(t, "GET", "/other_user/config", nil, opt.Auth.BasicUser, opt.Auth.BasicPass), } for _, req := range reqs { - checkRequest(t, srv.ServeHTTP, req, []wantFunc{wantCode(http.StatusForbidden)}) + checkRequest(t, router.ServeHTTP, req, []wantFunc{wantCode(http.StatusForbidden)}) } } diff --git a/cmd/serve/restic/restic_test.go b/cmd/serve/restic/restic_test.go index d9f343ca5..3db040bdd 100644 --- a/cmd/serve/restic/restic_test.go +++ b/cmd/serve/restic/restic_test.go @@ -5,14 +5,16 @@ package restic import ( "context" + "net/http" + "net/http/httptest" "os" "os/exec" "testing" _ "github.com/rclone/rclone/backend/all" - "github.com/rclone/rclone/cmd/serve/httplib" "github.com/rclone/rclone/fstest" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) const ( @@ -20,16 +22,24 @@ const ( resticSource = "../../../../../restic/restic" ) +func newOpt() Options { + opt := DefaultOpt + opt.HTTP.ListenAddr = []string{testBindAddress} + return opt +} + // TestRestic runs the restic server then runs the unit tests for the // restic remote against it. -func TestRestic(t *testing.T) { +// +// Requires the restic source code in the location indicated by resticSource. +func TestResticIntegration(t *testing.T) { + ctx := context.Background() _, err := os.Stat(resticSource) if err != nil { t.Skipf("Skipping test as restic source not found: %v", err) } - opt := httplib.DefaultOpt - opt.ListenAddr = testBindAddress + opt := newOpt() fstest.Initialise() @@ -41,16 +51,16 @@ func TestRestic(t *testing.T) { assert.NoError(t, err) // Start the server - w := NewServer(fremote, &opt) - assert.NoError(t, w.Serve()) + s, err := newServer(ctx, fremote, &opt) + require.NoError(t, err) + testURL := s.Server.URLs()[0] defer func() { - w.Close() - w.Wait() + _ = s.Shutdown() }() // Change directory to run the tests err = os.Chdir(resticSource) - assert.NoError(t, err, "failed to cd to restic source code") + require.NoError(t, err, "failed to cd to restic source code") // Run the restic tests runTests := func(path string) { @@ -60,7 +70,7 @@ func TestRestic(t *testing.T) { } cmd := exec.Command("go", args...) cmd.Env = append(os.Environ(), - "RESTIC_TEST_REST_REPOSITORY=rest:"+w.Server.URL()+path, + "RESTIC_TEST_REST_REPOSITORY=rest:"+testURL+path, "GO111MODULE=on", ) out, err := cmd.CombinedOutput() @@ -81,7 +91,6 @@ func TestMakeRemote(t *testing.T) { for _, test := range []struct { in, want string }{ - {"", ""}, {"/", ""}, {"/data", "data"}, {"/data/", "data"}, @@ -94,7 +103,14 @@ func TestMakeRemote(t *testing.T) { {"/keys/12", "keys/12"}, {"/keys/123", "keys/123"}, } { - got := makeRemote(test.in) - assert.Equal(t, test.want, got, test.in) + r := httptest.NewRequest("GET", test.in, nil) + w := httptest.NewRecorder() + next := http.HandlerFunc(func(_ http.ResponseWriter, request *http.Request) { + remote, ok := request.Context().Value(ContextRemoteKey).(string) + assert.True(t, ok, "Failed to get remote from context") + assert.Equal(t, test.want, remote, test.in) + }) + got := WithRemote(next) + got.ServeHTTP(w, r) } } diff --git a/cmd/serve/restic/restic_utils_test.go b/cmd/serve/restic/restic_utils_test.go index 0721f7c0f..587c7e1ab 100644 --- a/cmd/serve/restic/restic_utils_test.go +++ b/cmd/serve/restic/restic_utils_test.go @@ -7,7 +7,6 @@ import ( "testing" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) // declare a few helper functions @@ -15,11 +14,10 @@ import ( // wantFunc tests the HTTP response in res and marks the test as errored if something is incorrect. type wantFunc func(t testing.TB, res *httptest.ResponseRecorder) -// newRequest returns a new HTTP request with the given params. On error, the -// test is marked as failed. +// newRequest returns a new HTTP request with the given params func newRequest(t testing.TB, method, path string, body io.Reader) *http.Request { - req, err := http.NewRequest(method, path, body) - require.NoError(t, err) + req := httptest.NewRequest(method, path, body) + req.Header.Add("Accept", resticAPIV2) return req }