forked from TrueCloudLab/rclone
restic: refactor to use lib/http
Co-authored-by: Nick Craig-Wood <nick@craig-wood.com>
This commit is contained in:
parent
4444d2d102
commit
52443c2444
7 changed files with 270 additions and 181 deletions
|
@ -9,20 +9,22 @@ import (
|
|||
|
||||
// cache implements a simple object cache
|
||||
type cache struct {
|
||||
mu sync.RWMutex // protects the cache
|
||||
items map[string]fs.Object // cache of objects
|
||||
mu sync.RWMutex // protects the cache
|
||||
items map[string]fs.Object // cache of objects
|
||||
cacheObjects bool // whether we are actually caching
|
||||
}
|
||||
|
||||
// create a new cache
|
||||
func newCache() *cache {
|
||||
func newCache(cacheObjects bool) *cache {
|
||||
return &cache{
|
||||
items: map[string]fs.Object{},
|
||||
items: map[string]fs.Object{},
|
||||
cacheObjects: cacheObjects,
|
||||
}
|
||||
}
|
||||
|
||||
// find the object at remote or return nil
|
||||
func (c *cache) find(remote string) fs.Object {
|
||||
if !cacheObjects {
|
||||
if !c.cacheObjects {
|
||||
return nil
|
||||
}
|
||||
c.mu.RLock()
|
||||
|
@ -33,7 +35,7 @@ func (c *cache) find(remote string) fs.Object {
|
|||
|
||||
// add the object to the cache
|
||||
func (c *cache) add(remote string, o fs.Object) {
|
||||
if !cacheObjects {
|
||||
if !c.cacheObjects {
|
||||
return
|
||||
}
|
||||
c.mu.Lock()
|
||||
|
@ -43,7 +45,7 @@ func (c *cache) add(remote string, o fs.Object) {
|
|||
|
||||
// remove the object from the cache
|
||||
func (c *cache) remove(remote string) {
|
||||
if !cacheObjects {
|
||||
if !c.cacheObjects {
|
||||
return
|
||||
}
|
||||
c.mu.Lock()
|
||||
|
@ -53,7 +55,7 @@ func (c *cache) remove(remote string) {
|
|||
|
||||
// remove all the items with prefix from the cache
|
||||
func (c *cache) removePrefix(prefix string) {
|
||||
if !cacheObjects {
|
||||
if !c.cacheObjects {
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -21,7 +21,7 @@ func (c *cache) String() string {
|
|||
}
|
||||
|
||||
func TestCacheCRUD(t *testing.T) {
|
||||
c := newCache()
|
||||
c := newCache(true)
|
||||
assert.Equal(t, "", c.String())
|
||||
assert.Nil(t, c.find("potato"))
|
||||
o := mockobject.New("potato")
|
||||
|
@ -35,7 +35,7 @@ func TestCacheCRUD(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestCacheRemovePrefix(t *testing.T) {
|
||||
c := newCache()
|
||||
c := newCache(true)
|
||||
for _, remote := range []string{
|
||||
"a",
|
||||
"b",
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"path"
|
||||
|
@ -12,34 +13,48 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
"github.com/rclone/rclone/cmd"
|
||||
"github.com/rclone/rclone/cmd/serve/httplib"
|
||||
"github.com/rclone/rclone/cmd/serve/httplib/httpflags"
|
||||
"github.com/rclone/rclone/fs"
|
||||
"github.com/rclone/rclone/fs/accounting"
|
||||
"github.com/rclone/rclone/fs/config/flags"
|
||||
"github.com/rclone/rclone/fs/operations"
|
||||
"github.com/rclone/rclone/fs/walk"
|
||||
libhttp "github.com/rclone/rclone/lib/http"
|
||||
"github.com/rclone/rclone/lib/http/serve"
|
||||
"github.com/rclone/rclone/lib/terminal"
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
var (
|
||||
stdio bool
|
||||
appendOnly bool
|
||||
privateRepos bool
|
||||
cacheObjects bool
|
||||
)
|
||||
// Options required for http server
|
||||
type Options struct {
|
||||
Auth libhttp.AuthConfig
|
||||
HTTP libhttp.Config
|
||||
Stdio bool
|
||||
AppendOnly bool
|
||||
PrivateRepos bool
|
||||
CacheObjects bool
|
||||
}
|
||||
|
||||
// DefaultOpt is the default values used for Options
|
||||
var DefaultOpt = Options{
|
||||
Auth: libhttp.DefaultAuthCfg(),
|
||||
HTTP: libhttp.DefaultCfg(),
|
||||
}
|
||||
|
||||
// Opt is options set by command line flags
|
||||
var Opt = DefaultOpt
|
||||
|
||||
func init() {
|
||||
httpflags.AddFlags(Command.Flags())
|
||||
flagSet := Command.Flags()
|
||||
flags.BoolVarP(flagSet, &stdio, "stdio", "", false, "Run an HTTP2 server on stdin/stdout")
|
||||
flags.BoolVarP(flagSet, &appendOnly, "append-only", "", false, "Disallow deletion of repository data")
|
||||
flags.BoolVarP(flagSet, &privateRepos, "private-repos", "", false, "Users can only access their private repo")
|
||||
flags.BoolVarP(flagSet, &cacheObjects, "cache-objects", "", true, "Cache listed objects")
|
||||
libhttp.AddAuthFlagsPrefix(flagSet, "", &Opt.Auth)
|
||||
libhttp.AddHTTPFlagsPrefix(flagSet, "", &Opt.HTTP)
|
||||
flags.BoolVarP(flagSet, &Opt.Stdio, "stdio", "", false, "Run an HTTP2 server on stdin/stdout")
|
||||
flags.BoolVarP(flagSet, &Opt.AppendOnly, "append-only", "", false, "Disallow deletion of repository data")
|
||||
flags.BoolVarP(flagSet, &Opt.PrivateRepos, "private-repos", "", false, "Users can only access their private repo")
|
||||
flags.BoolVarP(flagSet, &Opt.CacheObjects, "cache-objects", "", true, "Cache listed objects")
|
||||
}
|
||||
|
||||
// Command definition for cobra
|
||||
|
@ -127,16 +142,21 @@ these **must** end with /. Eg
|
|||
|
||||
The` + "`--private-repos`" + ` flag can be used to limit users to repositories starting
|
||||
with a path of ` + "`/<username>/`" + `.
|
||||
` + httplib.Help,
|
||||
` + libhttp.Help + libhttp.AuthHelp,
|
||||
Annotations: map[string]string{
|
||||
"versionIntroduced": "v1.40",
|
||||
},
|
||||
Run: func(command *cobra.Command, args []string) {
|
||||
ctx := context.Background()
|
||||
cmd.CheckArgs(1, 1, command, args)
|
||||
f := cmd.NewFsSrc(args)
|
||||
cmd.Run(false, true, command, func() error {
|
||||
s := NewServer(f, &httpflags.Opt)
|
||||
if stdio {
|
||||
s, err := newServer(ctx, f, &Opt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fs.Logf(s.f, "Serving restic REST API on %s", s.URLs())
|
||||
if s.opt.Stdio {
|
||||
if terminal.IsTerminal(int(os.Stdout.Fd())) {
|
||||
return errors.New("refusing to run HTTP2 server directly on a terminal, please let restic start rclone")
|
||||
}
|
||||
|
@ -148,16 +168,11 @@ with a path of ` + "`/<username>/`" + `.
|
|||
|
||||
httpSrv := &http2.Server{}
|
||||
opts := &http2.ServeConnOpts{
|
||||
Handler: s,
|
||||
Handler: s.Server.Router(),
|
||||
}
|
||||
httpSrv.ServeConn(conn, opts)
|
||||
return nil
|
||||
}
|
||||
err := s.Serve()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.Wait()
|
||||
return nil
|
||||
})
|
||||
},
|
||||
|
@ -167,101 +182,130 @@ const (
|
|||
resticAPIV2 = "application/vnd.x.restic.rest.v2"
|
||||
)
|
||||
|
||||
// Server contains everything to run the Server
|
||||
type Server struct {
|
||||
*httplib.Server
|
||||
type contextRemoteType struct{}
|
||||
|
||||
// ContextRemoteKey is a simple context key for storing the username of the request
|
||||
var ContextRemoteKey = &contextRemoteType{}
|
||||
|
||||
// WithRemote makes a remote from a URL path. This implements the backend layout
|
||||
// required by restic.
|
||||
func WithRemote(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var urlpath string
|
||||
rctx := chi.RouteContext(r.Context())
|
||||
if rctx != nil && rctx.RoutePath != "" {
|
||||
urlpath = rctx.RoutePath
|
||||
} else {
|
||||
urlpath = r.URL.Path
|
||||
}
|
||||
urlpath = strings.Trim(urlpath, "/")
|
||||
parts := matchData.FindStringSubmatch(urlpath)
|
||||
// if no data directory, layout is flat
|
||||
if parts != nil {
|
||||
// otherwise map
|
||||
// data/2159dd48 to
|
||||
// data/21/2159dd48
|
||||
fileName := parts[1]
|
||||
prefix := urlpath[:len(urlpath)-len(fileName)]
|
||||
urlpath = prefix + fileName[:2] + "/" + fileName
|
||||
}
|
||||
ctx := context.WithValue(r.Context(), ContextRemoteKey, urlpath)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
// Middleware to ensure authenticated user is accessing their own private folder
|
||||
func checkPrivate(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
user := chi.URLParam(r, "userID")
|
||||
userID, ok := libhttp.CtxGetUser(r.Context())
|
||||
if ok && user != "" && user == userID {
|
||||
next.ServeHTTP(w, r)
|
||||
} else {
|
||||
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// server contains everything to run the server
|
||||
type server struct {
|
||||
*libhttp.Server
|
||||
f fs.Fs
|
||||
cache *cache
|
||||
opt Options
|
||||
}
|
||||
|
||||
// NewServer returns an HTTP server that speaks the rest protocol
|
||||
func NewServer(f fs.Fs, opt *httplib.Options) *Server {
|
||||
mux := http.NewServeMux()
|
||||
s := &Server{
|
||||
Server: httplib.NewServer(mux, opt),
|
||||
f: f,
|
||||
cache: newCache(),
|
||||
func newServer(ctx context.Context, f fs.Fs, opt *Options) (s *server, err error) {
|
||||
s = &server{
|
||||
f: f,
|
||||
cache: newCache(opt.CacheObjects),
|
||||
opt: *opt,
|
||||
}
|
||||
mux.HandleFunc(s.Opt.BaseURL+"/", s.ServeHTTP)
|
||||
return s
|
||||
}
|
||||
|
||||
// Serve runs the http server in the background.
|
||||
//
|
||||
// Use s.Close() and s.Wait() to shutdown server
|
||||
func (s *Server) Serve() error {
|
||||
err := s.Server.Serve()
|
||||
s.Server, err = libhttp.NewServer(ctx,
|
||||
libhttp.WithConfig(opt.HTTP),
|
||||
libhttp.WithAuth(opt.Auth),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, fmt.Errorf("failed to init server: %w", err)
|
||||
}
|
||||
router := s.Router()
|
||||
s.Bind(router)
|
||||
s.Server.Serve()
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// bind helper for main Bind method
|
||||
func (s *server) bind(router chi.Router) {
|
||||
router.MethodFunc("GET", "/*", func(w http.ResponseWriter, r *http.Request) {
|
||||
urlpath := chi.URLParam(r, "*")
|
||||
if urlpath == "" || strings.HasSuffix(urlpath, "/") {
|
||||
s.listObjects(w, r)
|
||||
} else {
|
||||
s.serveObject(w, r)
|
||||
}
|
||||
})
|
||||
router.MethodFunc("POST", "/*", func(w http.ResponseWriter, r *http.Request) {
|
||||
urlpath := chi.URLParam(r, "*")
|
||||
if urlpath == "" || strings.HasSuffix(urlpath, "/") {
|
||||
s.createRepo(w, r)
|
||||
} else {
|
||||
s.postObject(w, r)
|
||||
}
|
||||
})
|
||||
router.MethodFunc("HEAD", "/*", s.serveObject)
|
||||
router.MethodFunc("DELETE", "/*", s.deleteObject)
|
||||
}
|
||||
|
||||
// Bind restic server routes to passed router
|
||||
func (s *server) Bind(router chi.Router) {
|
||||
// FIXME
|
||||
// if m := authX.Auth(authX.Opt); m != nil {
|
||||
// router.Use(m)
|
||||
// }
|
||||
router.Use(
|
||||
middleware.SetHeader("Accept-Ranges", "bytes"),
|
||||
middleware.SetHeader("Server", "rclone/"+fs.Version),
|
||||
WithRemote,
|
||||
)
|
||||
|
||||
if s.opt.PrivateRepos {
|
||||
router.Route("/{userID}", func(r chi.Router) {
|
||||
r.Use(checkPrivate)
|
||||
s.bind(r)
|
||||
})
|
||||
router.NotFound(func(w http.ResponseWriter, _ *http.Request) {
|
||||
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
|
||||
})
|
||||
} else {
|
||||
s.bind(router)
|
||||
}
|
||||
fs.Logf(s.f, "Serving restic REST API on %s", s.URL())
|
||||
return nil
|
||||
}
|
||||
|
||||
var matchData = regexp.MustCompile("(?:^|/)data/([^/]{2,})$")
|
||||
|
||||
// Makes a remote from a URL path. This implements the backend layout
|
||||
// required by restic.
|
||||
func makeRemote(path string) string {
|
||||
path = strings.Trim(path, "/")
|
||||
parts := matchData.FindStringSubmatch(path)
|
||||
// if no data directory, layout is flat
|
||||
if parts == nil {
|
||||
return path
|
||||
}
|
||||
// otherwise map
|
||||
// data/2159dd48 to
|
||||
// data/21/2159dd48
|
||||
fileName := parts[1]
|
||||
prefix := path[:len(path)-len(fileName)]
|
||||
return prefix + fileName[:2] + "/" + fileName
|
||||
}
|
||||
|
||||
// ServeHTTP reads incoming requests and dispatches them
|
||||
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Accept-Ranges", "bytes")
|
||||
w.Header().Set("Server", "rclone/"+fs.Version)
|
||||
|
||||
path, ok := s.Path(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
remote := makeRemote(path)
|
||||
fs.Debugf(s.f, "%s %s", r.Method, path)
|
||||
|
||||
v := r.Context().Value(httplib.ContextUserKey)
|
||||
if privateRepos && (v == nil || !strings.HasPrefix(path, "/"+v.(string)+"/")) {
|
||||
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
// Dispatch on path then method
|
||||
if strings.HasSuffix(path, "/") {
|
||||
switch r.Method {
|
||||
case "GET":
|
||||
s.listObjects(w, r, remote)
|
||||
case "POST":
|
||||
s.createRepo(w, r, remote)
|
||||
default:
|
||||
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
|
||||
}
|
||||
} else {
|
||||
switch r.Method {
|
||||
case "GET", "HEAD":
|
||||
s.serveObject(w, r, remote)
|
||||
case "POST":
|
||||
s.postObject(w, r, remote)
|
||||
case "DELETE":
|
||||
s.deleteObject(w, r, remote)
|
||||
default:
|
||||
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// newObject returns an object with the remote given either from the
|
||||
// cache or directly
|
||||
func (s *Server) newObject(ctx context.Context, remote string) (fs.Object, error) {
|
||||
func (s *server) newObject(ctx context.Context, remote string) (fs.Object, error) {
|
||||
o := s.cache.find(remote)
|
||||
if o != nil {
|
||||
return o, nil
|
||||
|
@ -275,7 +319,12 @@ func (s *Server) newObject(ctx context.Context, remote string) (fs.Object, error
|
|||
}
|
||||
|
||||
// get the remote
|
||||
func (s *Server) serveObject(w http.ResponseWriter, r *http.Request, remote string) {
|
||||
func (s *server) serveObject(w http.ResponseWriter, r *http.Request) {
|
||||
remote, ok := r.Context().Value(ContextRemoteKey).(string)
|
||||
if !ok {
|
||||
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
o, err := s.newObject(r.Context(), remote)
|
||||
if err != nil {
|
||||
fs.Debugf(remote, "%s request error: %v", r.Method, err)
|
||||
|
@ -286,8 +335,13 @@ func (s *Server) serveObject(w http.ResponseWriter, r *http.Request, remote stri
|
|||
}
|
||||
|
||||
// postObject posts an object to the repository
|
||||
func (s *Server) postObject(w http.ResponseWriter, r *http.Request, remote string) {
|
||||
if appendOnly {
|
||||
func (s *server) postObject(w http.ResponseWriter, r *http.Request) {
|
||||
remote, ok := r.Context().Value(ContextRemoteKey).(string)
|
||||
if !ok {
|
||||
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if s.opt.AppendOnly {
|
||||
// make sure the file does not exist yet
|
||||
_, err := s.newObject(r.Context(), remote)
|
||||
if err == nil {
|
||||
|
@ -312,8 +366,13 @@ func (s *Server) postObject(w http.ResponseWriter, r *http.Request, remote strin
|
|||
}
|
||||
|
||||
// delete the remote
|
||||
func (s *Server) deleteObject(w http.ResponseWriter, r *http.Request, remote string) {
|
||||
if appendOnly {
|
||||
func (s *server) deleteObject(w http.ResponseWriter, r *http.Request) {
|
||||
remote, ok := r.Context().Value(ContextRemoteKey).(string)
|
||||
if !ok {
|
||||
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if s.opt.AppendOnly {
|
||||
parts := strings.Split(r.URL.Path, "/")
|
||||
|
||||
// if path doesn't end in "/locks/:name", disallow the operation
|
||||
|
@ -362,14 +421,18 @@ func (ls *listItems) add(o fs.Object) {
|
|||
}
|
||||
|
||||
// listObjects lists all Objects of a given type in an arbitrary order.
|
||||
func (s *Server) listObjects(w http.ResponseWriter, r *http.Request, remote string) {
|
||||
fs.Debugf(remote, "list request")
|
||||
|
||||
if r.Header.Get("Accept") != resticAPIV2 {
|
||||
fs.Errorf(remote, "Restic v2 API required")
|
||||
http.Error(w, "Restic v2 API required", http.StatusBadRequest)
|
||||
func (s *server) listObjects(w http.ResponseWriter, r *http.Request) {
|
||||
remote, ok := r.Context().Value(ContextRemoteKey).(string)
|
||||
if !ok {
|
||||
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if r.Header.Get("Accept") != resticAPIV2 {
|
||||
fs.Errorf(remote, "Restic v2 API required for List Objects")
|
||||
http.Error(w, "Restic v2 API required for List Objects", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
fs.Debugf(remote, "list request")
|
||||
|
||||
// make sure an empty list is returned, and not a 'nil' value
|
||||
ls := listItems{}
|
||||
|
@ -408,7 +471,12 @@ func (s *Server) listObjects(w http.ResponseWriter, r *http.Request, remote stri
|
|||
// createRepo creates repository directories.
|
||||
//
|
||||
// We don't bother creating the data dirs as rclone will create them on the fly
|
||||
func (s *Server) createRepo(w http.ResponseWriter, r *http.Request, remote string) {
|
||||
func (s *server) createRepo(w http.ResponseWriter, r *http.Request) {
|
||||
remote, ok := r.Context().Value(ContextRemoteKey).(string)
|
||||
if !ok {
|
||||
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
fs.Infof(remote, "Creating repository")
|
||||
|
||||
if r.URL.Query().Get("create") != "true" {
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package restic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"io"
|
||||
|
@ -9,7 +10,6 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/rclone/rclone/cmd"
|
||||
"github.com/rclone/rclone/cmd/serve/httplib/httpflags"
|
||||
"github.com/rclone/rclone/fs/config/configfile"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
@ -62,6 +62,7 @@ func createOverwriteDeleteSeq(t testing.TB, path string) []TestRequest {
|
|||
|
||||
// TestResticHandler runs tests on the restic handler code, especially in append-only mode.
|
||||
func TestResticHandler(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
configfile.Install()
|
||||
buf := make([]byte, 32)
|
||||
_, err := io.ReadFull(rand.Reader, buf)
|
||||
|
@ -110,19 +111,18 @@ func TestResticHandler(t *testing.T) {
|
|||
// setup rclone with a local backend in a temporary directory
|
||||
tempdir := t.TempDir()
|
||||
|
||||
// globally set append-only mode
|
||||
prev := appendOnly
|
||||
appendOnly = true
|
||||
defer func() {
|
||||
appendOnly = prev // reset when done
|
||||
}()
|
||||
// set append-only mode
|
||||
opt := newOpt()
|
||||
opt.AppendOnly = true
|
||||
|
||||
// make a new file system in the temp dir
|
||||
f := cmd.NewFsSrc([]string{tempdir})
|
||||
srv := NewServer(f, &httpflags.Opt)
|
||||
s, err := newServer(ctx, f, &opt)
|
||||
require.NoError(t, err)
|
||||
router := s.Server.Router()
|
||||
|
||||
// create the repo
|
||||
checkRequest(t, srv.ServeHTTP,
|
||||
checkRequest(t, router.ServeHTTP,
|
||||
newRequest(t, "POST", "/?create=true", nil),
|
||||
[]wantFunc{wantCode(http.StatusOK)})
|
||||
|
||||
|
@ -130,7 +130,7 @@ func TestResticHandler(t *testing.T) {
|
|||
t.Run("", func(t *testing.T) {
|
||||
for i, seq := range test.seq {
|
||||
t.Logf("request %v: %v %v", i, seq.req.Method, seq.req.URL.Path)
|
||||
checkRequest(t, srv.ServeHTTP, seq.req, seq.want)
|
||||
checkRequest(t, router.ServeHTTP, seq.req, seq.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
@ -8,23 +8,21 @@ import (
|
|||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/rclone/rclone/cmd/serve/httplib"
|
||||
|
||||
"github.com/rclone/rclone/cmd"
|
||||
"github.com/rclone/rclone/cmd/serve/httplib/httpflags"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// newAuthenticatedRequest returns a new HTTP request with the given params.
|
||||
func newAuthenticatedRequest(t testing.TB, method, path string, body io.Reader) *http.Request {
|
||||
func newAuthenticatedRequest(t testing.TB, method, path string, body io.Reader, user, pass string) *http.Request {
|
||||
req := newRequest(t, method, path, body)
|
||||
req = req.WithContext(context.WithValue(req.Context(), httplib.ContextUserKey, "test"))
|
||||
req.SetBasicAuth(user, pass)
|
||||
req.Header.Add("Accept", resticAPIV2)
|
||||
return req
|
||||
}
|
||||
|
||||
// TestResticPrivateRepositories runs tests on the restic handler code for private repositories
|
||||
func TestResticPrivateRepositories(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
buf := make([]byte, 32)
|
||||
_, err := io.ReadFull(rand.Reader, buf)
|
||||
require.NoError(t, err)
|
||||
|
@ -32,42 +30,49 @@ func TestResticPrivateRepositories(t *testing.T) {
|
|||
// setup rclone with a local backend in a temporary directory
|
||||
tempdir := t.TempDir()
|
||||
|
||||
// globally set private-repos mode & test user
|
||||
prev := privateRepos
|
||||
prevUser := httpflags.Opt.BasicUser
|
||||
prevPassword := httpflags.Opt.BasicPass
|
||||
privateRepos = true
|
||||
httpflags.Opt.BasicUser = "test"
|
||||
httpflags.Opt.BasicPass = "password"
|
||||
// reset when done
|
||||
defer func() {
|
||||
privateRepos = prev
|
||||
httpflags.Opt.BasicUser = prevUser
|
||||
httpflags.Opt.BasicPass = prevPassword
|
||||
}()
|
||||
opt := newOpt()
|
||||
|
||||
// set private-repos mode & test user
|
||||
opt.PrivateRepos = true
|
||||
opt.Auth.BasicUser = "test"
|
||||
opt.Auth.BasicPass = "password"
|
||||
|
||||
// make a new file system in the temp dir
|
||||
f := cmd.NewFsSrc([]string{tempdir})
|
||||
srv := NewServer(f, &httpflags.Opt)
|
||||
s, err := newServer(ctx, f, &opt)
|
||||
require.NoError(t, err)
|
||||
router := s.Server.Router()
|
||||
|
||||
// Requesting /test/ should allow access
|
||||
reqs := []*http.Request{
|
||||
newAuthenticatedRequest(t, "POST", "/test/?create=true", nil),
|
||||
newAuthenticatedRequest(t, "POST", "/test/config", strings.NewReader("foobar test config")),
|
||||
newAuthenticatedRequest(t, "GET", "/test/config", nil),
|
||||
newAuthenticatedRequest(t, "POST", "/test/?create=true", nil, opt.Auth.BasicUser, opt.Auth.BasicPass),
|
||||
newAuthenticatedRequest(t, "POST", "/test/config", strings.NewReader("foobar test config"), opt.Auth.BasicUser, opt.Auth.BasicPass),
|
||||
newAuthenticatedRequest(t, "GET", "/test/config", nil, opt.Auth.BasicUser, opt.Auth.BasicPass),
|
||||
}
|
||||
for _, req := range reqs {
|
||||
checkRequest(t, srv.ServeHTTP, req, []wantFunc{wantCode(http.StatusOK)})
|
||||
checkRequest(t, router.ServeHTTP, req, []wantFunc{wantCode(http.StatusOK)})
|
||||
}
|
||||
|
||||
// Requesting with bad credentials should raise unauthorised errors
|
||||
reqs = []*http.Request{
|
||||
newRequest(t, "GET", "/test/config", nil),
|
||||
newAuthenticatedRequest(t, "GET", "/test/config", nil, opt.Auth.BasicUser, ""),
|
||||
newAuthenticatedRequest(t, "GET", "/test/config", nil, "", opt.Auth.BasicPass),
|
||||
newAuthenticatedRequest(t, "GET", "/test/config", nil, opt.Auth.BasicUser+"x", opt.Auth.BasicPass),
|
||||
newAuthenticatedRequest(t, "GET", "/test/config", nil, opt.Auth.BasicUser, opt.Auth.BasicPass+"x"),
|
||||
}
|
||||
for _, req := range reqs {
|
||||
checkRequest(t, router.ServeHTTP, req, []wantFunc{wantCode(http.StatusUnauthorized)})
|
||||
}
|
||||
|
||||
// Requesting everything else should raise forbidden errors
|
||||
reqs = []*http.Request{
|
||||
newAuthenticatedRequest(t, "GET", "/", nil),
|
||||
newAuthenticatedRequest(t, "POST", "/other_user", nil),
|
||||
newAuthenticatedRequest(t, "GET", "/other_user/config", nil),
|
||||
newAuthenticatedRequest(t, "GET", "/", nil, opt.Auth.BasicUser, opt.Auth.BasicPass),
|
||||
newAuthenticatedRequest(t, "POST", "/other_user", nil, opt.Auth.BasicUser, opt.Auth.BasicPass),
|
||||
newAuthenticatedRequest(t, "GET", "/other_user/config", nil, opt.Auth.BasicUser, opt.Auth.BasicPass),
|
||||
}
|
||||
for _, req := range reqs {
|
||||
checkRequest(t, srv.ServeHTTP, req, []wantFunc{wantCode(http.StatusForbidden)})
|
||||
checkRequest(t, router.ServeHTTP, req, []wantFunc{wantCode(http.StatusForbidden)})
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -5,14 +5,16 @@ package restic
|
|||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"os/exec"
|
||||
"testing"
|
||||
|
||||
_ "github.com/rclone/rclone/backend/all"
|
||||
"github.com/rclone/rclone/cmd/serve/httplib"
|
||||
"github.com/rclone/rclone/fstest"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -20,16 +22,24 @@ const (
|
|||
resticSource = "../../../../../restic/restic"
|
||||
)
|
||||
|
||||
func newOpt() Options {
|
||||
opt := DefaultOpt
|
||||
opt.HTTP.ListenAddr = []string{testBindAddress}
|
||||
return opt
|
||||
}
|
||||
|
||||
// TestRestic runs the restic server then runs the unit tests for the
|
||||
// restic remote against it.
|
||||
func TestRestic(t *testing.T) {
|
||||
//
|
||||
// Requires the restic source code in the location indicated by resticSource.
|
||||
func TestResticIntegration(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
_, err := os.Stat(resticSource)
|
||||
if err != nil {
|
||||
t.Skipf("Skipping test as restic source not found: %v", err)
|
||||
}
|
||||
|
||||
opt := httplib.DefaultOpt
|
||||
opt.ListenAddr = testBindAddress
|
||||
opt := newOpt()
|
||||
|
||||
fstest.Initialise()
|
||||
|
||||
|
@ -41,16 +51,16 @@ func TestRestic(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
|
||||
// Start the server
|
||||
w := NewServer(fremote, &opt)
|
||||
assert.NoError(t, w.Serve())
|
||||
s, err := newServer(ctx, fremote, &opt)
|
||||
require.NoError(t, err)
|
||||
testURL := s.Server.URLs()[0]
|
||||
defer func() {
|
||||
w.Close()
|
||||
w.Wait()
|
||||
_ = s.Shutdown()
|
||||
}()
|
||||
|
||||
// Change directory to run the tests
|
||||
err = os.Chdir(resticSource)
|
||||
assert.NoError(t, err, "failed to cd to restic source code")
|
||||
require.NoError(t, err, "failed to cd to restic source code")
|
||||
|
||||
// Run the restic tests
|
||||
runTests := func(path string) {
|
||||
|
@ -60,7 +70,7 @@ func TestRestic(t *testing.T) {
|
|||
}
|
||||
cmd := exec.Command("go", args...)
|
||||
cmd.Env = append(os.Environ(),
|
||||
"RESTIC_TEST_REST_REPOSITORY=rest:"+w.Server.URL()+path,
|
||||
"RESTIC_TEST_REST_REPOSITORY=rest:"+testURL+path,
|
||||
"GO111MODULE=on",
|
||||
)
|
||||
out, err := cmd.CombinedOutput()
|
||||
|
@ -81,7 +91,6 @@ func TestMakeRemote(t *testing.T) {
|
|||
for _, test := range []struct {
|
||||
in, want string
|
||||
}{
|
||||
{"", ""},
|
||||
{"/", ""},
|
||||
{"/data", "data"},
|
||||
{"/data/", "data"},
|
||||
|
@ -94,7 +103,14 @@ func TestMakeRemote(t *testing.T) {
|
|||
{"/keys/12", "keys/12"},
|
||||
{"/keys/123", "keys/123"},
|
||||
} {
|
||||
got := makeRemote(test.in)
|
||||
assert.Equal(t, test.want, got, test.in)
|
||||
r := httptest.NewRequest("GET", test.in, nil)
|
||||
w := httptest.NewRecorder()
|
||||
next := http.HandlerFunc(func(_ http.ResponseWriter, request *http.Request) {
|
||||
remote, ok := request.Context().Value(ContextRemoteKey).(string)
|
||||
assert.True(t, ok, "Failed to get remote from context")
|
||||
assert.Equal(t, test.want, remote, test.in)
|
||||
})
|
||||
got := WithRemote(next)
|
||||
got.ServeHTTP(w, r)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,7 +7,6 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// declare a few helper functions
|
||||
|
@ -15,11 +14,10 @@ import (
|
|||
// wantFunc tests the HTTP response in res and marks the test as errored if something is incorrect.
|
||||
type wantFunc func(t testing.TB, res *httptest.ResponseRecorder)
|
||||
|
||||
// newRequest returns a new HTTP request with the given params. On error, the
|
||||
// test is marked as failed.
|
||||
// newRequest returns a new HTTP request with the given params
|
||||
func newRequest(t testing.TB, method, path string, body io.Reader) *http.Request {
|
||||
req, err := http.NewRequest(method, path, body)
|
||||
require.NoError(t, err)
|
||||
req := httptest.NewRequest(method, path, body)
|
||||
req.Header.Add("Accept", resticAPIV2)
|
||||
return req
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue