package restic

import (
	"context"
	"crypto/rand"
	"io"
	"io/ioutil"
	"net/http"
	"os"
	"strings"
	"testing"

	"github.com/rclone/rclone/cmd/serve/httplib"

	"github.com/rclone/rclone/cmd"
	"github.com/rclone/rclone/cmd/serve/httplib/httpflags"
	"github.com/stretchr/testify/require"
)

// newAuthenticatedRequest builds a request with newRequest and makes it look
// as if it had already passed authentication as the user "test": that user
// name is stored in the request context under httplib.ContextUserKey, and the
// Accept header is set to resticAPIV2.
func newAuthenticatedRequest(t testing.TB, method, path string, body io.Reader) *http.Request {
	base := newRequest(t, method, path, body)

	// WithContext returns a shallow copy carrying the new context; the header
	// is added on that copy, which is the request handed back to the caller.
	userCtx := context.WithValue(base.Context(), httplib.ContextUserKey, "test")
	authed := base.WithContext(userCtx)
	authed.Header.Add("Accept", resticAPIV2)

	return authed
}

// TestResticPrivateRepositories exercises the restic handler with private
// repositories enabled. It checks that the authenticated user "test" can create,
// write and read its own repository under /test/, and that requests for the
// root or another user's repository are rejected with 403 Forbidden.
func TestResticPrivateRepositories(t *testing.T) {
	// NOTE(review): these random bytes are never read after being filled. The
	// read is kept so that crypto/rand stays referenced. Consider removing the
	// read together with the import.
	_, err := io.ReadFull(rand.Reader, make([]byte, 32))
	require.NoError(t, err)

	// The local backend lives in a throwaway directory that is deleted at the end.
	dir, err := ioutil.TempDir("", "rclone-restic-test-")
	require.NoError(t, err)
	defer func() {
		require.NoError(t, os.RemoveAll(dir))
	}()

	// Turn on private-repos mode and configure the test credentials globally.
	// The previous values are put back when the test returns.
	savedPrivate, savedUser, savedPass := privateRepos, httpflags.Opt.BasicUser, httpflags.Opt.BasicPass
	privateRepos, httpflags.Opt.BasicUser, httpflags.Opt.BasicPass = true, "test", "password"
	defer func() {
		privateRepos, httpflags.Opt.BasicUser, httpflags.Opt.BasicPass = savedPrivate, savedUser, savedPass
	}()

	srv := newServer(cmd.NewFsSrc([]string{dir}), &httpflags.Opt)

	// expect sends each request to the handler and asserts the status code.
	// The variadic arguments are evaluated before the loop runs, so each group
	// of requests is built and then executed in order.
	expect := func(code int, reqs ...*http.Request) {
		for _, req := range reqs {
			checkRequest(t, srv.handler, req, []wantFunc{wantCode(code)})
		}
	}

	// The user's own repository under /test/ is accessible.
	expect(http.StatusOK,
		newAuthenticatedRequest(t, "POST", "/test/?create=true", nil),
		newAuthenticatedRequest(t, "POST", "/test/config", strings.NewReader("foobar test config")),
		newAuthenticatedRequest(t, "GET", "/test/config", nil),
	)

	// Anything outside /test/ is forbidden.
	expect(http.StatusForbidden,
		newAuthenticatedRequest(t, "GET", "/", nil),
		newAuthenticatedRequest(t, "POST", "/other_user", nil),
		newAuthenticatedRequest(t, "GET", "/other_user/config", nil),
	)
}