forked from TrueCloudLab/rclone
http servers: allow CORS to be set with --allow-origin flag - fixes #5078
Some changes about test cases: Because MiddlewareCORS will return early on OPTIONS request, this middleware should only be used once at NewServer function. Test cases should pass AllowOrigin config instead of adding this middleware again. A new test case was added to test CORS preflight request with an authenticator. Preflight request should always return 200 OK regardless of autentications. Co-authored-by: yuudi <yuudi@users.noreply.github.com>
This commit is contained in:
parent
3ed4a2e963
commit
6c8148ef39
8 changed files with 95 additions and 75 deletions
31
fs/rc/rc.go
31
fs/rc/rc.go
|
@ -18,22 +18,21 @@ import (
|
||||||
|
|
||||||
// Options contains options for the remote control server
|
// Options contains options for the remote control server
|
||||||
type Options struct {
|
type Options struct {
|
||||||
HTTP libhttp.Config
|
HTTP libhttp.Config
|
||||||
Auth libhttp.AuthConfig
|
Auth libhttp.AuthConfig
|
||||||
Template libhttp.TemplateConfig
|
Template libhttp.TemplateConfig
|
||||||
Enabled bool // set to enable the server
|
Enabled bool // set to enable the server
|
||||||
Serve bool // set to serve files from remotes
|
Serve bool // set to serve files from remotes
|
||||||
Files string // set to enable serving files locally
|
Files string // set to enable serving files locally
|
||||||
NoAuth bool // set to disable auth checks on AuthRequired methods
|
NoAuth bool // set to disable auth checks on AuthRequired methods
|
||||||
WebUI bool // set to launch the web ui
|
WebUI bool // set to launch the web ui
|
||||||
WebGUIUpdate bool // set to check new update
|
WebGUIUpdate bool // set to check new update
|
||||||
WebGUIForceUpdate bool // set to force download new update
|
WebGUIForceUpdate bool // set to force download new update
|
||||||
WebGUINoOpenBrowser bool // set to disable auto opening browser
|
WebGUINoOpenBrowser bool // set to disable auto opening browser
|
||||||
WebGUIFetchURL string // set the default url for fetching webgui
|
WebGUIFetchURL string // set the default url for fetching webgui
|
||||||
AccessControlAllowOrigin string // set the access control for CORS configuration
|
EnableMetrics bool // set to disable prometheus metrics on /metrics
|
||||||
EnableMetrics bool // set to disable prometheus metrics on /metrics
|
JobExpireDuration time.Duration
|
||||||
JobExpireDuration time.Duration
|
JobExpireInterval time.Duration
|
||||||
JobExpireInterval time.Duration
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// DefaultOpt is the default values used for Options
|
// DefaultOpt is the default values used for Options
|
||||||
|
|
|
@ -27,7 +27,6 @@ func AddFlags(flagSet *pflag.FlagSet) {
|
||||||
flags.BoolVarP(flagSet, &Opt.WebGUIForceUpdate, "rc-web-gui-force-update", "", false, "Force update to latest version of web gui")
|
flags.BoolVarP(flagSet, &Opt.WebGUIForceUpdate, "rc-web-gui-force-update", "", false, "Force update to latest version of web gui")
|
||||||
flags.BoolVarP(flagSet, &Opt.WebGUINoOpenBrowser, "rc-web-gui-no-open-browser", "", false, "Don't open the browser automatically")
|
flags.BoolVarP(flagSet, &Opt.WebGUINoOpenBrowser, "rc-web-gui-no-open-browser", "", false, "Don't open the browser automatically")
|
||||||
flags.StringVarP(flagSet, &Opt.WebGUIFetchURL, "rc-web-fetch-url", "", "https://api.github.com/repos/rclone/rclone-webui-react/releases/latest", "URL to fetch the releases for webgui")
|
flags.StringVarP(flagSet, &Opt.WebGUIFetchURL, "rc-web-fetch-url", "", "https://api.github.com/repos/rclone/rclone-webui-react/releases/latest", "URL to fetch the releases for webgui")
|
||||||
flags.StringVarP(flagSet, &Opt.AccessControlAllowOrigin, "rc-allow-origin", "", "", "Set the allowed origin for CORS")
|
|
||||||
flags.BoolVarP(flagSet, &Opt.EnableMetrics, "rc-enable-metrics", "", false, "Enable prometheus metrics on /metrics")
|
flags.BoolVarP(flagSet, &Opt.EnableMetrics, "rc-enable-metrics", "", false, "Enable prometheus metrics on /metrics")
|
||||||
flags.DurationVarP(flagSet, &Opt.JobExpireDuration, "rc-job-expire-duration", "", Opt.JobExpireDuration, "Expire finished async jobs older than this value")
|
flags.DurationVarP(flagSet, &Opt.JobExpireDuration, "rc-job-expire-duration", "", Opt.JobExpireDuration, "Expire finished async jobs older than this value")
|
||||||
flags.DurationVarP(flagSet, &Opt.JobExpireInterval, "rc-job-expire-interval", "", Opt.JobExpireInterval, "Interval to check for expired async jobs")
|
flags.DurationVarP(flagSet, &Opt.JobExpireInterval, "rc-job-expire-interval", "", Opt.JobExpireInterval, "Interval to check for expired async jobs")
|
||||||
|
|
|
@ -15,7 +15,6 @@ import (
|
||||||
"regexp"
|
"regexp"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/go-chi/chi/v5/middleware"
|
"github.com/go-chi/chi/v5/middleware"
|
||||||
|
@ -38,7 +37,6 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
var promHandler http.Handler
|
var promHandler http.Handler
|
||||||
var onlyOnceWarningAllowOrigin sync.Once
|
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
rcloneCollector := accounting.NewRcloneCollector(context.Background())
|
rcloneCollector := accounting.NewRcloneCollector(context.Background())
|
||||||
|
@ -214,23 +212,6 @@ func writeError(path string, in rc.Params, w http.ResponseWriter, err error, sta
|
||||||
func (s *Server) handler(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) handler(w http.ResponseWriter, r *http.Request) {
|
||||||
path := strings.TrimLeft(r.URL.Path, "/")
|
path := strings.TrimLeft(r.URL.Path, "/")
|
||||||
|
|
||||||
allowOrigin := rcflags.Opt.AccessControlAllowOrigin
|
|
||||||
if allowOrigin != "" {
|
|
||||||
onlyOnceWarningAllowOrigin.Do(func() {
|
|
||||||
if allowOrigin == "*" {
|
|
||||||
fs.Logf(nil, "Warning: Allow origin set to *. This can cause serious security problems.")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
w.Header().Add("Access-Control-Allow-Origin", allowOrigin)
|
|
||||||
} else {
|
|
||||||
urls := s.server.URLs()
|
|
||||||
if len(urls) == 1 {
|
|
||||||
w.Header().Add("Access-Control-Allow-Origin", urls[0])
|
|
||||||
} else {
|
|
||||||
fs.Errorf(nil, "Warning, need exactly 1 URL for Access-Control-Allow-Origin, got %d %q", len(urls), urls)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// echo back access control headers client needs
|
// echo back access control headers client needs
|
||||||
//reqAccessHeaders := r.Header.Get("Access-Control-Request-Headers")
|
//reqAccessHeaders := r.Header.Get("Access-Control-Request-Headers")
|
||||||
w.Header().Add("Access-Control-Request-Method", "POST, OPTIONS, GET, HEAD")
|
w.Header().Add("Access-Control-Request-Method", "POST, OPTIONS, GET, HEAD")
|
||||||
|
|
|
@ -552,32 +552,6 @@ Unknown command
|
||||||
testServer(t, tests, &opt)
|
testServer(t, tests, &opt)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMethods(t *testing.T) {
|
|
||||||
tests := []testRun{{
|
|
||||||
Name: "options",
|
|
||||||
URL: "",
|
|
||||||
Method: "OPTIONS",
|
|
||||||
Status: http.StatusOK,
|
|
||||||
Expected: "",
|
|
||||||
Headers: map[string]string{
|
|
||||||
"Access-Control-Allow-Origin": "testURL",
|
|
||||||
"Access-Control-Request-Method": "POST, OPTIONS, GET, HEAD",
|
|
||||||
"Access-Control-Allow-Headers": "authorization, Content-Type",
|
|
||||||
},
|
|
||||||
}, {
|
|
||||||
Name: "bad",
|
|
||||||
URL: "",
|
|
||||||
Method: "POTATO",
|
|
||||||
Status: http.StatusMethodNotAllowed,
|
|
||||||
Expected: `Method Not Allowed
|
|
||||||
`,
|
|
||||||
}}
|
|
||||||
opt := newTestOpt()
|
|
||||||
opt.Serve = true
|
|
||||||
opt.Files = testFs
|
|
||||||
testServer(t, tests, &opt)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMetrics(t *testing.T) {
|
func TestMetrics(t *testing.T) {
|
||||||
stats := accounting.GlobalStats()
|
stats := accounting.GlobalStats()
|
||||||
tests := makeMetricsTestCases(stats)
|
tests := makeMetricsTestCases(stats)
|
||||||
|
|
|
@ -181,6 +181,13 @@ func MiddlewareCORS(allowOrigin string) Middleware {
|
||||||
w.Header().Add("Access-Control-Request-Method", "POST, OPTIONS, GET, HEAD")
|
w.Header().Add("Access-Control-Request-Method", "POST, OPTIONS, GET, HEAD")
|
||||||
w.Header().Add("Access-Control-Allow-Headers", "authorization, Content-Type")
|
w.Header().Add("Access-Control-Allow-Headers", "authorization, Content-Type")
|
||||||
|
|
||||||
|
if r.Method == "OPTIONS" {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
return
|
||||||
|
// Because CORS preflight OPTIONS requests are not authenticated,
|
||||||
|
// and require a 200 OK response, we will return early here.
|
||||||
|
}
|
||||||
|
|
||||||
next.ServeHTTP(w, r)
|
next.ServeHTTP(w, r)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -329,23 +329,22 @@ var _testCORSHeaderKeys = []string{
|
||||||
|
|
||||||
func TestMiddlewareCORS(t *testing.T) {
|
func TestMiddlewareCORS(t *testing.T) {
|
||||||
servers := []struct {
|
servers := []struct {
|
||||||
name string
|
name string
|
||||||
http Config
|
http Config
|
||||||
origin string
|
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "EmptyOrigin",
|
name: "EmptyOrigin",
|
||||||
http: Config{
|
http: Config{
|
||||||
ListenAddr: []string{"127.0.0.1:0"},
|
ListenAddr: []string{"127.0.0.1:0"},
|
||||||
|
AllowOrigin: "",
|
||||||
},
|
},
|
||||||
origin: "",
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "CustomOrigin",
|
name: "CustomOrigin",
|
||||||
http: Config{
|
http: Config{
|
||||||
ListenAddr: []string{"127.0.0.1:0"},
|
ListenAddr: []string{"127.0.0.1:0"},
|
||||||
|
AllowOrigin: "http://test.rclone.org",
|
||||||
},
|
},
|
||||||
origin: "http://test.rclone.org",
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -357,8 +356,6 @@ func TestMiddlewareCORS(t *testing.T) {
|
||||||
require.NoError(t, s.Shutdown())
|
require.NoError(t, s.Shutdown())
|
||||||
}()
|
}()
|
||||||
|
|
||||||
s.Router().Use(MiddlewareCORS(ss.origin))
|
|
||||||
|
|
||||||
expected := []byte("data")
|
expected := []byte("data")
|
||||||
s.Router().Mount("/", testEchoHandler(expected))
|
s.Router().Mount("/", testEchoHandler(expected))
|
||||||
s.Serve()
|
s.Serve()
|
||||||
|
@ -384,8 +381,69 @@ func TestMiddlewareCORS(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
expectedOrigin := url
|
expectedOrigin := url
|
||||||
if ss.origin != "" {
|
if ss.http.AllowOrigin != "" {
|
||||||
expectedOrigin = ss.origin
|
expectedOrigin = ss.http.AllowOrigin
|
||||||
|
}
|
||||||
|
require.Equal(t, expectedOrigin, resp.Header.Get("Access-Control-Allow-Origin"), "allow origin should match")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMiddlewareCORSWithAuth(t *testing.T) {
|
||||||
|
authServers := []struct {
|
||||||
|
name string
|
||||||
|
http Config
|
||||||
|
auth AuthConfig
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "ServerWithAuth",
|
||||||
|
http: Config{
|
||||||
|
ListenAddr: []string{"127.0.0.1:0"},
|
||||||
|
AllowOrigin: "http://test.rclone.org",
|
||||||
|
},
|
||||||
|
auth: AuthConfig{
|
||||||
|
Realm: "test",
|
||||||
|
BasicUser: "test_user",
|
||||||
|
BasicPass: "test_pass",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, ss := range authServers {
|
||||||
|
t.Run(ss.name, func(t *testing.T) {
|
||||||
|
s, err := NewServer(context.Background(), WithConfig(ss.http))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
require.NoError(t, s.Shutdown())
|
||||||
|
}()
|
||||||
|
|
||||||
|
expected := []byte("data")
|
||||||
|
s.Router().Mount("/", testEchoHandler(expected))
|
||||||
|
s.Serve()
|
||||||
|
|
||||||
|
url := testGetServerURL(t, s)
|
||||||
|
|
||||||
|
client := &http.Client{}
|
||||||
|
req, err := http.NewRequest("OPTIONS", url, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, resp.StatusCode, "OPTIONS should return ok even if not authenticated")
|
||||||
|
|
||||||
|
testExpectRespBody(t, resp, []byte{})
|
||||||
|
|
||||||
|
for _, key := range _testCORSHeaderKeys {
|
||||||
|
require.Contains(t, resp.Header, key, "CORS headers should be sent even if not authenticated")
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedOrigin := url
|
||||||
|
if ss.http.AllowOrigin != "" {
|
||||||
|
expectedOrigin = ss.http.AllowOrigin
|
||||||
}
|
}
|
||||||
require.Equal(t, expectedOrigin, resp.Header.Get("Access-Control-Allow-Origin"), "allow origin should match")
|
require.Equal(t, expectedOrigin, resp.Header.Get("Access-Control-Allow-Origin"), "allow origin should match")
|
||||||
})
|
})
|
||||||
|
|
|
@ -109,6 +109,7 @@ type Config struct {
|
||||||
TLSKeyBody []byte // TLS PEM Private key body, ignores TLSKey
|
TLSKeyBody []byte // TLS PEM Private key body, ignores TLSKey
|
||||||
ClientCA string // Client certificate authority to verify clients with
|
ClientCA string // Client certificate authority to verify clients with
|
||||||
MinTLSVersion string // MinTLSVersion contains the minimum TLS version that is acceptable.
|
MinTLSVersion string // MinTLSVersion contains the minimum TLS version that is acceptable.
|
||||||
|
AllowOrigin string // AllowOrigin sets the Access-Control-Allow-Origin header
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddFlagsPrefix adds flags for the httplib
|
// AddFlagsPrefix adds flags for the httplib
|
||||||
|
@ -122,6 +123,7 @@ func (cfg *Config) AddFlagsPrefix(flagSet *pflag.FlagSet, prefix string) {
|
||||||
flags.StringVarP(flagSet, &cfg.ClientCA, prefix+"client-ca", "", cfg.ClientCA, "Client certificate authority to verify clients with")
|
flags.StringVarP(flagSet, &cfg.ClientCA, prefix+"client-ca", "", cfg.ClientCA, "Client certificate authority to verify clients with")
|
||||||
flags.StringVarP(flagSet, &cfg.BaseURL, prefix+"baseurl", "", cfg.BaseURL, "Prefix for URLs - leave blank for root")
|
flags.StringVarP(flagSet, &cfg.BaseURL, prefix+"baseurl", "", cfg.BaseURL, "Prefix for URLs - leave blank for root")
|
||||||
flags.StringVarP(flagSet, &cfg.MinTLSVersion, prefix+"min-tls-version", "", cfg.MinTLSVersion, "Minimum TLS version that is acceptable")
|
flags.StringVarP(flagSet, &cfg.MinTLSVersion, prefix+"min-tls-version", "", cfg.MinTLSVersion, "Minimum TLS version that is acceptable")
|
||||||
|
flags.StringVarP(flagSet, &cfg.AllowOrigin, prefix+"allow-origin", "", cfg.AllowOrigin, "Origin which cross-domain request (CORS) can be executed from")
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddHTTPFlagsPrefix adds flags for the httplib
|
// AddHTTPFlagsPrefix adds flags for the httplib
|
||||||
|
@ -236,6 +238,8 @@ func NewServer(ctx context.Context, options ...Option) (*Server, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.mux.Use(MiddlewareCORS(s.cfg.AllowOrigin))
|
||||||
|
|
||||||
s.initAuth()
|
s.initAuth()
|
||||||
|
|
||||||
for _, addr := range s.cfg.ListenAddr {
|
for _, addr := range s.cfg.ListenAddr {
|
||||||
|
|
|
@ -82,8 +82,6 @@ func TestNewServerUnix(t *testing.T) {
|
||||||
|
|
||||||
require.Empty(t, s.URLs(), "unix socket should not appear in URLs")
|
require.Empty(t, s.URLs(), "unix socket should not appear in URLs")
|
||||||
|
|
||||||
s.Router().Use(MiddlewareCORS(""))
|
|
||||||
|
|
||||||
expected := []byte("hello world")
|
expected := []byte("hello world")
|
||||||
s.Router().Mount("/", testEchoHandler(expected))
|
s.Router().Mount("/", testEchoHandler(expected))
|
||||||
s.Serve()
|
s.Serve()
|
||||||
|
|
Loading…
Add table
Reference in a new issue