cmd: add support for private repositories in serve restic - fixes #3247

This commit is contained in:
Florian Apolloner 2019-06-07 19:47:46 +02:00 committed by Nick Craig-Wood
parent 64fb4effa7
commit 939b19c3b7
5 changed files with 164 additions and 50 deletions

View file

@ -2,6 +2,7 @@
package httplib package httplib
import ( import (
"context"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"encoding/base64" "encoding/base64"
@ -114,6 +115,11 @@ type Server struct {
HTMLTemplate *template.Template // HTML template for web interface HTMLTemplate *template.Template // HTML template for web interface
} }
type contextUserType struct{}
// ContextUserKey is a simple context key
var ContextUserKey = &contextUserType{}
// singleUserProvider provides the encrypted password for a single user // singleUserProvider provides the encrypted password for a single user
func (s *Server) singleUserProvider(user, realm string) string { func (s *Server) singleUserProvider(user, realm string) string {
if user == s.Opt.BasicUser { if user == s.Opt.BasicUser {
@ -172,6 +178,7 @@ func NewServer(handler http.Handler, opt *Options) *Server {
} }
authenticator.RequireAuth(w, r) authenticator.RequireAuth(w, r)
} else { } else {
r = r.WithContext(context.WithValue(r.Context(), ContextUserKey, username))
oldHandler.ServeHTTP(w, r) oldHandler.ServeHTTP(w, r)
} }
}) })

View file

@ -29,14 +29,16 @@ import (
) )
var ( var (
stdio bool stdio bool
appendOnly bool appendOnly bool
privateRepos bool
) )
func init() { func init() {
httpflags.AddFlags(Command.Flags()) httpflags.AddFlags(Command.Flags())
Command.Flags().BoolVar(&stdio, "stdio", false, "run an HTTP2 server on stdin/stdout") Command.Flags().BoolVar(&stdio, "stdio", false, "run an HTTP2 server on stdin/stdout")
Command.Flags().BoolVar(&appendOnly, "append-only", false, "disallow deletion of repository data") Command.Flags().BoolVar(&appendOnly, "append-only", false, "disallow deletion of repository data")
Command.Flags().BoolVar(&privateRepos, "private-repos", false, "users can only access their private repo")
} }
// Command definition for cobra // Command definition for cobra
@ -94,14 +96,14 @@ For example:
$ export RESTIC_PASSWORD=yourpassword $ export RESTIC_PASSWORD=yourpassword
$ restic init $ restic init
created restic backend 8b1a4b56ae at rest:http://localhost:8080/ created restic backend 8b1a4b56ae at rest:http://localhost:8080/
Please note that knowledge of your password is required to access Please note that knowledge of your password is required to access
the repository. Losing your password means that your data is the repository. Losing your password means that your data is
irrecoverably lost. irrecoverably lost.
$ restic backup /path/to/files/to/backup $ restic backup /path/to/files/to/backup
scan [/path/to/files/to/backup] scan [/path/to/files/to/backup]
scanned 189 directories, 312 files in 0:00 scanned 189 directories, 312 files in 0:00
[0:00] 100.00% 38.128 MiB / 38.128 MiB 501 / 501 items 0 errors ETA 0:00 [0:00] 100.00% 38.128 MiB / 38.128 MiB 501 / 501 items 0 errors ETA 0:00
duration: 0:00 duration: 0:00
snapshot 45c8fdd8 saved snapshot 45c8fdd8 saved
@ -116,6 +118,10 @@ these **must** end with /. Eg
$ export RESTIC_REPOSITORY=rest:http://localhost:8080/user2repo/ $ export RESTIC_REPOSITORY=rest:http://localhost:8080/user2repo/
# backup user2 stuff # backup user2 stuff
#### Private repositories ####
The "--private-repos" flag can be used to limit users to repositories starting
with a path of "/<username>/".
` + httplib.Help, ` + httplib.Help,
Run: func(command *cobra.Command, args []string) { Run: func(command *cobra.Command, args []string) {
cmd.CheckArgs(1, 1, command, args) cmd.CheckArgs(1, 1, command, args)
@ -209,6 +215,12 @@ func (s *server) handler(w http.ResponseWriter, r *http.Request) {
remote := makeRemote(path) remote := makeRemote(path)
fs.Debugf(s.f, "%s %s", r.Method, path) fs.Debugf(s.f, "%s %s", r.Method, path)
v := r.Context().Value(httplib.ContextUserKey)
if privateRepos && (v == nil || !strings.HasPrefix(path, "/"+v.(string)+"/")) {
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
return
}
// Dispatch on path then method // Dispatch on path then method
if strings.HasSuffix(path, "/") { if strings.HasSuffix(path, "/") {
switch r.Method { switch r.Method {

View file

@ -8,61 +8,15 @@ import (
"io" "io"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/http/httptest"
"os" "os"
"strings" "strings"
"testing" "testing"
"github.com/ncw/rclone/cmd" "github.com/ncw/rclone/cmd"
"github.com/ncw/rclone/cmd/serve/httplib/httpflags" "github.com/ncw/rclone/cmd/serve/httplib/httpflags"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
// declare a few helper functions
// wantFunc tests the HTTP response in res and marks the test as errored if something is incorrect.
type wantFunc func(t testing.TB, res *httptest.ResponseRecorder)
// newRequest returns a new HTTP request with the given params. On error, the
// test is marked as failed.
func newRequest(t testing.TB, method, path string, body io.Reader) *http.Request {
req, err := http.NewRequest(method, path, body)
require.NoError(t, err)
return req
}
// wantCode returns a function which checks that the response has the correct HTTP status code.
func wantCode(code int) wantFunc {
return func(t testing.TB, res *httptest.ResponseRecorder) {
assert.Equal(t, code, res.Code)
}
}
// wantBody returns a function which checks that the response has the data in the body.
func wantBody(body string) wantFunc {
return func(t testing.TB, res *httptest.ResponseRecorder) {
assert.NotNil(t, res.Body)
assert.Equal(t, res.Body.Bytes(), []byte(body))
}
}
// checkRequest uses f to process the request and runs the checker functions on the result.
func checkRequest(t testing.TB, f http.HandlerFunc, req *http.Request, want []wantFunc) {
rr := httptest.NewRecorder()
f(rr, req)
for _, fn := range want {
fn(t, rr)
}
}
// TestRequest is a sequence of HTTP requests with (optional) tests for the response.
type TestRequest struct {
req *http.Request
want []wantFunc
}
// createOverwriteDeleteSeq returns a sequence which will create a new file at // createOverwriteDeleteSeq returns a sequence which will create a new file at
// path, and then try to overwrite and delete it. // path, and then try to overwrite and delete it.
func createOverwriteDeleteSeq(t testing.TB, path string) []TestRequest { func createOverwriteDeleteSeq(t testing.TB, path string) []TestRequest {

View file

@ -0,0 +1,84 @@
// +build go1.9
package restic
import (
"context"
"crypto/rand"
"io"
"io/ioutil"
"net/http"
"os"
"strings"
"testing"
"github.com/ncw/rclone/cmd/serve/httplib"
"github.com/ncw/rclone/cmd"
"github.com/ncw/rclone/cmd/serve/httplib/httpflags"
"github.com/stretchr/testify/require"
)
// newAuthenticatedRequest returns a new HTTP request with the given params.
func newAuthenticatedRequest(t testing.TB, method, path string, body io.Reader) *http.Request {
req := newRequest(t, method, path, body)
req = req.WithContext(context.WithValue(req.Context(), httplib.ContextUserKey, "test"))
req.Header.Add("Accept", resticAPIV2)
return req
}
// TestResticPrivateRepositories runs tests on the restic handler code for private repositories
func TestResticPrivateRepositories(t *testing.T) {
buf := make([]byte, 32)
_, err := io.ReadFull(rand.Reader, buf)
require.NoError(t, err)
// setup rclone with a local backend in a temporary directory
tempdir, err := ioutil.TempDir("", "rclone-restic-test-")
require.NoError(t, err)
// make sure the tempdir is properly removed
defer func() {
err := os.RemoveAll(tempdir)
require.NoError(t, err)
}()
// globally set private-repos mode & test user
prev := privateRepos
prevUser := httpflags.Opt.BasicUser
prevPassword := httpflags.Opt.BasicPass
privateRepos = true
httpflags.Opt.BasicUser = "test"
httpflags.Opt.BasicPass = "password"
// reset when done
defer func() {
privateRepos = prev
httpflags.Opt.BasicUser = prevUser
httpflags.Opt.BasicPass = prevPassword
}()
// make a new file system in the temp dir
f := cmd.NewFsSrc([]string{tempdir})
srv := newServer(f, &httpflags.Opt)
// Requesting /test/ should allow access
reqs := []*http.Request{
newAuthenticatedRequest(t, "POST", "/test/?create=true", nil),
newAuthenticatedRequest(t, "POST", "/test/config", strings.NewReader("foobar test config")),
newAuthenticatedRequest(t, "GET", "/test/config", nil),
}
for _, req := range reqs {
checkRequest(t, srv.handler, req, []wantFunc{wantCode(http.StatusOK)})
}
// Requesting everything else should raise forbidden errors
reqs = []*http.Request{
newAuthenticatedRequest(t, "GET", "/", nil),
newAuthenticatedRequest(t, "POST", "/other_user", nil),
newAuthenticatedRequest(t, "GET", "/other_user/config", nil),
}
for _, req := range reqs {
checkRequest(t, srv.handler, req, []wantFunc{wantCode(http.StatusForbidden)})
}
}

View file

@ -0,0 +1,57 @@
// +build go1.9
package restic
import (
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// declare a few helper functions
// wantFunc tests the HTTP response in res and marks the test as errored if something is incorrect.
type wantFunc func(t testing.TB, res *httptest.ResponseRecorder)
// newRequest returns a new HTTP request with the given params. On error, the
// test is marked as failed.
func newRequest(t testing.TB, method, path string, body io.Reader) *http.Request {
req, err := http.NewRequest(method, path, body)
require.NoError(t, err)
return req
}
// wantCode returns a function which checks that the response has the correct HTTP status code.
func wantCode(code int) wantFunc {
return func(t testing.TB, res *httptest.ResponseRecorder) {
assert.Equal(t, code, res.Code)
}
}
// wantBody returns a function which checks that the response has the data in the body.
func wantBody(body string) wantFunc {
return func(t testing.TB, res *httptest.ResponseRecorder) {
assert.NotNil(t, res.Body)
assert.Equal(t, res.Body.Bytes(), []byte(body))
}
}
// checkRequest uses f to process the request and runs the checker functions on the result.
func checkRequest(t testing.TB, f http.HandlerFunc, req *http.Request, want []wantFunc) {
rr := httptest.NewRecorder()
f(rr, req)
for _, fn := range want {
fn(t, rr)
}
}
// TestRequest is a sequence of HTTP requests with (optional) tests for the response.
type TestRequest struct {
req *http.Request
want []wantFunc
}