forked from TrueCloudLab/rclone
http servers: allow CORS to be set with --allow-origin flag - fixes #5078
Some changes about test cases: Because MiddlewareCORS will return early on OPTIONS request, this middleware should only be used once at NewServer function. Test cases should pass AllowOrigin config instead of adding this middleware again. A new test case was added to test CORS preflight request with an authenticator. Preflight request should always return 200 OK regardless of autentications. Co-authored-by: yuudi <yuudi@users.noreply.github.com>
This commit is contained in:
parent
3ed4a2e963
commit
6c8148ef39
8 changed files with 95 additions and 75 deletions
31
fs/rc/rc.go
31
fs/rc/rc.go
|
@ -18,22 +18,21 @@ import (
|
|||
|
||||
// Options contains options for the remote control server
|
||||
type Options struct {
|
||||
HTTP libhttp.Config
|
||||
Auth libhttp.AuthConfig
|
||||
Template libhttp.TemplateConfig
|
||||
Enabled bool // set to enable the server
|
||||
Serve bool // set to serve files from remotes
|
||||
Files string // set to enable serving files locally
|
||||
NoAuth bool // set to disable auth checks on AuthRequired methods
|
||||
WebUI bool // set to launch the web ui
|
||||
WebGUIUpdate bool // set to check new update
|
||||
WebGUIForceUpdate bool // set to force download new update
|
||||
WebGUINoOpenBrowser bool // set to disable auto opening browser
|
||||
WebGUIFetchURL string // set the default url for fetching webgui
|
||||
AccessControlAllowOrigin string // set the access control for CORS configuration
|
||||
EnableMetrics bool // set to disable prometheus metrics on /metrics
|
||||
JobExpireDuration time.Duration
|
||||
JobExpireInterval time.Duration
|
||||
HTTP libhttp.Config
|
||||
Auth libhttp.AuthConfig
|
||||
Template libhttp.TemplateConfig
|
||||
Enabled bool // set to enable the server
|
||||
Serve bool // set to serve files from remotes
|
||||
Files string // set to enable serving files locally
|
||||
NoAuth bool // set to disable auth checks on AuthRequired methods
|
||||
WebUI bool // set to launch the web ui
|
||||
WebGUIUpdate bool // set to check new update
|
||||
WebGUIForceUpdate bool // set to force download new update
|
||||
WebGUINoOpenBrowser bool // set to disable auto opening browser
|
||||
WebGUIFetchURL string // set the default url for fetching webgui
|
||||
EnableMetrics bool // set to disable prometheus metrics on /metrics
|
||||
JobExpireDuration time.Duration
|
||||
JobExpireInterval time.Duration
|
||||
}
|
||||
|
||||
// DefaultOpt is the default values used for Options
|
||||
|
|
|
@ -27,7 +27,6 @@ func AddFlags(flagSet *pflag.FlagSet) {
|
|||
flags.BoolVarP(flagSet, &Opt.WebGUIForceUpdate, "rc-web-gui-force-update", "", false, "Force update to latest version of web gui")
|
||||
flags.BoolVarP(flagSet, &Opt.WebGUINoOpenBrowser, "rc-web-gui-no-open-browser", "", false, "Don't open the browser automatically")
|
||||
flags.StringVarP(flagSet, &Opt.WebGUIFetchURL, "rc-web-fetch-url", "", "https://api.github.com/repos/rclone/rclone-webui-react/releases/latest", "URL to fetch the releases for webgui")
|
||||
flags.StringVarP(flagSet, &Opt.AccessControlAllowOrigin, "rc-allow-origin", "", "", "Set the allowed origin for CORS")
|
||||
flags.BoolVarP(flagSet, &Opt.EnableMetrics, "rc-enable-metrics", "", false, "Enable prometheus metrics on /metrics")
|
||||
flags.DurationVarP(flagSet, &Opt.JobExpireDuration, "rc-job-expire-duration", "", Opt.JobExpireDuration, "Expire finished async jobs older than this value")
|
||||
flags.DurationVarP(flagSet, &Opt.JobExpireInterval, "rc-job-expire-interval", "", Opt.JobExpireInterval, "Interval to check for expired async jobs")
|
||||
|
|
|
@ -15,7 +15,6 @@ import (
|
|||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
|
@ -38,7 +37,6 @@ import (
|
|||
)
|
||||
|
||||
var promHandler http.Handler
|
||||
var onlyOnceWarningAllowOrigin sync.Once
|
||||
|
||||
func init() {
|
||||
rcloneCollector := accounting.NewRcloneCollector(context.Background())
|
||||
|
@ -214,23 +212,6 @@ func writeError(path string, in rc.Params, w http.ResponseWriter, err error, sta
|
|||
func (s *Server) handler(w http.ResponseWriter, r *http.Request) {
|
||||
path := strings.TrimLeft(r.URL.Path, "/")
|
||||
|
||||
allowOrigin := rcflags.Opt.AccessControlAllowOrigin
|
||||
if allowOrigin != "" {
|
||||
onlyOnceWarningAllowOrigin.Do(func() {
|
||||
if allowOrigin == "*" {
|
||||
fs.Logf(nil, "Warning: Allow origin set to *. This can cause serious security problems.")
|
||||
}
|
||||
})
|
||||
w.Header().Add("Access-Control-Allow-Origin", allowOrigin)
|
||||
} else {
|
||||
urls := s.server.URLs()
|
||||
if len(urls) == 1 {
|
||||
w.Header().Add("Access-Control-Allow-Origin", urls[0])
|
||||
} else {
|
||||
fs.Errorf(nil, "Warning, need exactly 1 URL for Access-Control-Allow-Origin, got %d %q", len(urls), urls)
|
||||
}
|
||||
}
|
||||
|
||||
// echo back access control headers client needs
|
||||
//reqAccessHeaders := r.Header.Get("Access-Control-Request-Headers")
|
||||
w.Header().Add("Access-Control-Request-Method", "POST, OPTIONS, GET, HEAD")
|
||||
|
|
|
@ -552,32 +552,6 @@ Unknown command
|
|||
testServer(t, tests, &opt)
|
||||
}
|
||||
|
||||
func TestMethods(t *testing.T) {
|
||||
tests := []testRun{{
|
||||
Name: "options",
|
||||
URL: "",
|
||||
Method: "OPTIONS",
|
||||
Status: http.StatusOK,
|
||||
Expected: "",
|
||||
Headers: map[string]string{
|
||||
"Access-Control-Allow-Origin": "testURL",
|
||||
"Access-Control-Request-Method": "POST, OPTIONS, GET, HEAD",
|
||||
"Access-Control-Allow-Headers": "authorization, Content-Type",
|
||||
},
|
||||
}, {
|
||||
Name: "bad",
|
||||
URL: "",
|
||||
Method: "POTATO",
|
||||
Status: http.StatusMethodNotAllowed,
|
||||
Expected: `Method Not Allowed
|
||||
`,
|
||||
}}
|
||||
opt := newTestOpt()
|
||||
opt.Serve = true
|
||||
opt.Files = testFs
|
||||
testServer(t, tests, &opt)
|
||||
}
|
||||
|
||||
func TestMetrics(t *testing.T) {
|
||||
stats := accounting.GlobalStats()
|
||||
tests := makeMetricsTestCases(stats)
|
||||
|
|
|
@ -181,6 +181,13 @@ func MiddlewareCORS(allowOrigin string) Middleware {
|
|||
w.Header().Add("Access-Control-Request-Method", "POST, OPTIONS, GET, HEAD")
|
||||
w.Header().Add("Access-Control-Allow-Headers", "authorization, Content-Type")
|
||||
|
||||
if r.Method == "OPTIONS" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
// Because CORS preflight OPTIONS requests are not authenticated,
|
||||
// and require a 200 OK response, we will return early here.
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -329,23 +329,22 @@ var _testCORSHeaderKeys = []string{
|
|||
|
||||
func TestMiddlewareCORS(t *testing.T) {
|
||||
servers := []struct {
|
||||
name string
|
||||
http Config
|
||||
origin string
|
||||
name string
|
||||
http Config
|
||||
}{
|
||||
{
|
||||
name: "EmptyOrigin",
|
||||
http: Config{
|
||||
ListenAddr: []string{"127.0.0.1:0"},
|
||||
ListenAddr: []string{"127.0.0.1:0"},
|
||||
AllowOrigin: "",
|
||||
},
|
||||
origin: "",
|
||||
},
|
||||
{
|
||||
name: "CustomOrigin",
|
||||
http: Config{
|
||||
ListenAddr: []string{"127.0.0.1:0"},
|
||||
ListenAddr: []string{"127.0.0.1:0"},
|
||||
AllowOrigin: "http://test.rclone.org",
|
||||
},
|
||||
origin: "http://test.rclone.org",
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -357,8 +356,6 @@ func TestMiddlewareCORS(t *testing.T) {
|
|||
require.NoError(t, s.Shutdown())
|
||||
}()
|
||||
|
||||
s.Router().Use(MiddlewareCORS(ss.origin))
|
||||
|
||||
expected := []byte("data")
|
||||
s.Router().Mount("/", testEchoHandler(expected))
|
||||
s.Serve()
|
||||
|
@ -384,8 +381,69 @@ func TestMiddlewareCORS(t *testing.T) {
|
|||
}
|
||||
|
||||
expectedOrigin := url
|
||||
if ss.origin != "" {
|
||||
expectedOrigin = ss.origin
|
||||
if ss.http.AllowOrigin != "" {
|
||||
expectedOrigin = ss.http.AllowOrigin
|
||||
}
|
||||
require.Equal(t, expectedOrigin, resp.Header.Get("Access-Control-Allow-Origin"), "allow origin should match")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiddlewareCORSWithAuth(t *testing.T) {
|
||||
authServers := []struct {
|
||||
name string
|
||||
http Config
|
||||
auth AuthConfig
|
||||
}{
|
||||
{
|
||||
name: "ServerWithAuth",
|
||||
http: Config{
|
||||
ListenAddr: []string{"127.0.0.1:0"},
|
||||
AllowOrigin: "http://test.rclone.org",
|
||||
},
|
||||
auth: AuthConfig{
|
||||
Realm: "test",
|
||||
BasicUser: "test_user",
|
||||
BasicPass: "test_pass",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, ss := range authServers {
|
||||
t.Run(ss.name, func(t *testing.T) {
|
||||
s, err := NewServer(context.Background(), WithConfig(ss.http))
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, s.Shutdown())
|
||||
}()
|
||||
|
||||
expected := []byte("data")
|
||||
s.Router().Mount("/", testEchoHandler(expected))
|
||||
s.Serve()
|
||||
|
||||
url := testGetServerURL(t, s)
|
||||
|
||||
client := &http.Client{}
|
||||
req, err := http.NewRequest("OPTIONS", url, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
_ = resp.Body.Close()
|
||||
}()
|
||||
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode, "OPTIONS should return ok even if not authenticated")
|
||||
|
||||
testExpectRespBody(t, resp, []byte{})
|
||||
|
||||
for _, key := range _testCORSHeaderKeys {
|
||||
require.Contains(t, resp.Header, key, "CORS headers should be sent even if not authenticated")
|
||||
}
|
||||
|
||||
expectedOrigin := url
|
||||
if ss.http.AllowOrigin != "" {
|
||||
expectedOrigin = ss.http.AllowOrigin
|
||||
}
|
||||
require.Equal(t, expectedOrigin, resp.Header.Get("Access-Control-Allow-Origin"), "allow origin should match")
|
||||
})
|
||||
|
|
|
@ -109,6 +109,7 @@ type Config struct {
|
|||
TLSKeyBody []byte // TLS PEM Private key body, ignores TLSKey
|
||||
ClientCA string // Client certificate authority to verify clients with
|
||||
MinTLSVersion string // MinTLSVersion contains the minimum TLS version that is acceptable.
|
||||
AllowOrigin string // AllowOrigin sets the Access-Control-Allow-Origin header
|
||||
}
|
||||
|
||||
// AddFlagsPrefix adds flags for the httplib
|
||||
|
@ -122,6 +123,7 @@ func (cfg *Config) AddFlagsPrefix(flagSet *pflag.FlagSet, prefix string) {
|
|||
flags.StringVarP(flagSet, &cfg.ClientCA, prefix+"client-ca", "", cfg.ClientCA, "Client certificate authority to verify clients with")
|
||||
flags.StringVarP(flagSet, &cfg.BaseURL, prefix+"baseurl", "", cfg.BaseURL, "Prefix for URLs - leave blank for root")
|
||||
flags.StringVarP(flagSet, &cfg.MinTLSVersion, prefix+"min-tls-version", "", cfg.MinTLSVersion, "Minimum TLS version that is acceptable")
|
||||
flags.StringVarP(flagSet, &cfg.AllowOrigin, prefix+"allow-origin", "", cfg.AllowOrigin, "Origin which cross-domain request (CORS) can be executed from")
|
||||
}
|
||||
|
||||
// AddHTTPFlagsPrefix adds flags for the httplib
|
||||
|
@ -236,6 +238,8 @@ func NewServer(ctx context.Context, options ...Option) (*Server, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
s.mux.Use(MiddlewareCORS(s.cfg.AllowOrigin))
|
||||
|
||||
s.initAuth()
|
||||
|
||||
for _, addr := range s.cfg.ListenAddr {
|
||||
|
|
|
@ -82,8 +82,6 @@ func TestNewServerUnix(t *testing.T) {
|
|||
|
||||
require.Empty(t, s.URLs(), "unix socket should not appear in URLs")
|
||||
|
||||
s.Router().Use(MiddlewareCORS(""))
|
||||
|
||||
expected := []byte("hello world")
|
||||
s.Router().Mount("/", testEchoHandler(expected))
|
||||
s.Serve()
|
||||
|
|
Loading…
Reference in a new issue