From 273459e0904b97929143a0237d59222419bc59ee Mon Sep 17 00:00:00 2001 From: Marina Biryukova Date: Mon, 7 Apr 2025 16:50:48 +0300 Subject: [PATCH] [#225] Support wildcard in allowed origins and headers Signed-off-by: Marina Biryukova --- internal/handler/cors.go | 29 +- internal/handler/cors_test.go | 490 ++++++++++++++++++++++++++++++++++ 2 files changed, 510 insertions(+), 9 deletions(-) diff --git a/internal/handler/cors.go b/internal/handler/cors.go index d77ae02..bbfce1e 100644 --- a/internal/handler/cors.go +++ b/internal/handler/cors.go @@ -5,6 +5,8 @@ import ( "encoding/xml" "errors" "fmt" + "regexp" + "slices" "sort" "strconv" "strings" @@ -78,7 +80,7 @@ func (h *Handler) Preflight(req *fasthttp.RequestCtx) { for _, rule := range corsConfig.CORSRules { for _, o := range rule.AllowedOrigins { - if o == string(origin) || o == wildcard { + if o == string(origin) || o == wildcard || (strings.Contains(o, "*") && match(o, string(origin))) { for _, m := range rule.AllowedMethods { if m == string(method) { if !checkSubslice(rule.AllowedHeaders, headers) { @@ -117,6 +119,11 @@ func (h *Handler) SetCORSHeaders(req *fasthttp.RequestCtx) { return } + method := req.Request.Header.Peek(fasthttp.HeaderAccessControlRequestMethod) + if len(method) == 0 { + method = req.Method() + } + ctx = qostagging.ContextWithIOTag(ctx, internalIOTag) cidParam, _ := req.UserValue("cid").(string) reqLog := h.reqLogger(ctx) @@ -141,9 +148,9 @@ func (h *Handler) SetCORSHeaders(req *fasthttp.RequestCtx) { for _, rule := range corsConfig.CORSRules { for _, o := range rule.AllowedOrigins { - if o == string(origin) { + if o == string(origin) || (strings.Contains(o, "*") && len(o) > 1 && match(o, string(origin))) { for _, m := range rule.AllowedMethods { - if m == string(req.Method()) { + if m == string(method) { req.Response.Header.Set(fasthttp.HeaderAccessControlAllowOrigin, string(origin)) req.Response.Header.Set(fasthttp.HeaderAccessControlAllowMethods, strings.Join(rule.AllowedMethods, ", ")) req.Response.Header.Set(fasthttp.HeaderAccessControlAllowCredentials, "true") @@ -154,7 +161,7 @@ func (h *Handler) SetCORSHeaders(req *fasthttp.RequestCtx) { } if o == wildcard { for _, m := range rule.AllowedMethods { - if m == string(req.Method()) { + if m == string(method) { if withCredentials { req.Response.Header.Set(fasthttp.HeaderAccessControlAllowOrigin, string(origin)) req.Response.Header.Set(fasthttp.HeaderAccessControlAllowCredentials, "true") @@ -318,12 +325,9 @@ func setCORSHeadersFromRule(c *fasthttp.RequestCtx, cors *data.CORSRule) { } func checkSubslice(slice []string, subSlice []string) bool { - if sliceContains(slice, wildcard) { + if slices.Contains(slice, wildcard) { return true } - if len(subSlice) > len(slice) { - return false - } for _, r := range subSlice { if !sliceContains(slice, r) { return false @@ -334,9 +338,16 @@ func checkSubslice(slice []string, subSlice []string) bool { func sliceContains(slice []string, str string) bool { for _, s := range slice { - if s == str { + if s == str || (strings.Contains(s, "*") && match(s, str)) { return true } } return false } + +func match(tmpl, str string) bool { + regexpStr := "^" + regexp.QuoteMeta(tmpl) + "$" + regexpStr = regexpStr[:strings.Index(regexpStr, "*")-1] + "." + regexpStr[strings.Index(regexpStr, "*"):] + reg := regexp.MustCompile(regexpStr) + return reg.Match([]byte(str)) +} diff --git a/internal/handler/cors_test.go b/internal/handler/cors_test.go index 7cd7b0d..1ac07d7 100644 --- a/internal/handler/cors_test.go +++ b/internal/handler/cors_test.go @@ -4,6 +4,7 @@ import ( "encoding/base64" "encoding/xml" "fmt" + "net/http" "testing" "git.frostfs.info/TrueCloudLab/frostfs-http-gw/internal/data" @@ -407,6 +408,12 @@ func TestCheckSubslice(t *testing.T) { actual: []string{"str1", "str5"}, expected: false, }, + { + name: "wildcard in allowed", + allowed: []string{"str*"}, + actual: []string{"str", "str5"}, + expected: true, + }, } { t.Run(tc.name, func(t *testing.T) { require.Equal(t, tc.expected, checkSubslice(tc.allowed, tc.actual)) @@ -414,6 +421,489 @@ func TestCheckSubslice(t *testing.T) { } } +func TestAllowedOriginWildcards(t *testing.T) { + hc := prepareHandlerContext(t) + bktName := "bucket-allowed-origin-wildcards" + cnrID, cnr, err := hc.prepareContainer(bktName, acl.Private) + require.NoError(t, err) + hc.frostfs.SetContainer(cnrID, cnr) + + cfg := &data.CORSConfiguration{ + CORSRules: []data.CORSRule{ + { + AllowedOrigins: []string{"*suffix.example"}, + AllowedMethods: []string{"GET"}, + }, + { + AllowedOrigins: []string{"https://*example"}, + AllowedMethods: []string{"GET"}, + }, + { + AllowedOrigins: []string{"prefix.example*"}, + AllowedMethods: []string{"GET"}, + }, + }, + } + setCORSObject(t, hc, cnrID, cfg, 1) + + for _, tc := range []struct { + name string + handler func(*fasthttp.RequestCtx) + requestHeaders map[string]string + expectedHeaders map[string]string + expectedStatus int + }{ + { + name: "set cors headers, empty request cors headers", + handler: hc.Handler().SetCORSHeaders, + expectedHeaders: map[string]string{ + fasthttp.HeaderAccessControlAllowOrigin: "", + fasthttp.HeaderAccessControlAllowMethods: "", + }, + }, + { + name: "set cors headers, invalid origin", + handler: hc.Handler().SetCORSHeaders, + requestHeaders: map[string]string{ + fasthttp.HeaderOrigin: "https://origin.com", + }, + expectedHeaders: map[string]string{ + fasthttp.HeaderAccessControlAllowOrigin: "", + fasthttp.HeaderAccessControlAllowMethods: "", + }, + }, + { + name: "set cors headers, first rule, no symbols in place of wildcard", + handler: hc.Handler().SetCORSHeaders, + requestHeaders: map[string]string{ + fasthttp.HeaderOrigin: "suffix.example", + }, + expectedHeaders: map[string]string{ + fasthttp.HeaderAccessControlAllowOrigin: "suffix.example", + fasthttp.HeaderAccessControlAllowMethods: "GET", + }, + }, + { + name: "set cors headers, first rule, valid origin", + handler: hc.Handler().SetCORSHeaders, + requestHeaders: map[string]string{ + fasthttp.HeaderOrigin: "http://suffix.example", + }, + expectedHeaders: map[string]string{ + fasthttp.HeaderAccessControlAllowOrigin: "http://suffix.example", + fasthttp.HeaderAccessControlAllowMethods: "GET", + }, + }, + { + name: "set cors headers, first rule, invalid origin", + handler: hc.Handler().SetCORSHeaders, + requestHeaders: map[string]string{ + fasthttp.HeaderOrigin: "http://suffix-example", + }, + expectedHeaders: map[string]string{ + fasthttp.HeaderAccessControlAllowOrigin: "", + fasthttp.HeaderAccessControlAllowMethods: "", + }, + }, + { + name: "set cors headers, second rule, no symbols in place of wildcard", + handler: hc.Handler().SetCORSHeaders, + requestHeaders: map[string]string{ + fasthttp.HeaderOrigin: "https://example", + }, + expectedHeaders: map[string]string{ + fasthttp.HeaderAccessControlAllowOrigin: "https://example", + fasthttp.HeaderAccessControlAllowMethods: "GET", + }, + }, + { + name: "set cors headers, second rule, valid origin", + handler: hc.Handler().SetCORSHeaders, + requestHeaders: map[string]string{ + fasthttp.HeaderOrigin: "https://www.example", + }, + expectedHeaders: map[string]string{ + fasthttp.HeaderAccessControlAllowOrigin: "https://www.example", + fasthttp.HeaderAccessControlAllowMethods: "GET", + }, + }, + { + name: "set cors headers, second rule, invalid origin", + handler: hc.Handler().SetCORSHeaders, + requestHeaders: map[string]string{ + fasthttp.HeaderOrigin: "https://www.example.com", + }, + expectedHeaders: map[string]string{ + fasthttp.HeaderAccessControlAllowOrigin: "", + fasthttp.HeaderAccessControlAllowMethods: "", + }, + }, + { + name: "set cors headers, third rule, no symbols in place of wildcard", + handler: hc.Handler().SetCORSHeaders, + requestHeaders: map[string]string{ + fasthttp.HeaderOrigin: "prefix.example", + }, + expectedHeaders: map[string]string{ + fasthttp.HeaderAccessControlAllowOrigin: "prefix.example", + fasthttp.HeaderAccessControlAllowMethods: "GET", + }, + }, + { + name: "set cors headers, third rule, valid origin", + handler: hc.Handler().SetCORSHeaders, + requestHeaders: map[string]string{ + fasthttp.HeaderOrigin: "prefix.example.com", + }, + expectedHeaders: map[string]string{ + fasthttp.HeaderAccessControlAllowOrigin: "prefix.example.com", + fasthttp.HeaderAccessControlAllowMethods: "GET", + }, + }, + { + name: "set cors headers, third rule, invalid origin", + handler: hc.Handler().SetCORSHeaders, + requestHeaders: map[string]string{ + fasthttp.HeaderOrigin: "www.prefix.example", + }, + expectedHeaders: map[string]string{ + fasthttp.HeaderAccessControlAllowOrigin: "", + fasthttp.HeaderAccessControlAllowMethods: "", + }, + }, + { + name: "set cors headers, third rule, invalid request method in header", + handler: hc.Handler().SetCORSHeaders, + requestHeaders: map[string]string{ + fasthttp.HeaderOrigin: "prefix.example.com", + fasthttp.HeaderAccessControlRequestMethod: "PUT", + }, + expectedHeaders: map[string]string{ + fasthttp.HeaderAccessControlAllowOrigin: "", + fasthttp.HeaderAccessControlAllowMethods: "", + }, + }, + { + name: "set cors headers, third rule, valid request method in header", + handler: hc.Handler().SetCORSHeaders, + requestHeaders: map[string]string{ + fasthttp.HeaderOrigin: "prefix.example.com", + fasthttp.HeaderAccessControlRequestMethod: "GET", + }, + expectedHeaders: map[string]string{ + fasthttp.HeaderAccessControlAllowOrigin: "prefix.example.com", + fasthttp.HeaderAccessControlAllowMethods: "GET", + }, + }, + { + name: "preflight, empty request cors headers", + handler: hc.Handler().Preflight, + expectedHeaders: map[string]string{ + fasthttp.HeaderAccessControlAllowOrigin: "", + fasthttp.HeaderAccessControlAllowMethods: "", + }, + expectedStatus: http.StatusBadRequest, + }, + { + name: "preflight, invalid origin", + handler: hc.Handler().Preflight, + requestHeaders: map[string]string{ + fasthttp.HeaderOrigin: "https://origin.com", + fasthttp.HeaderAccessControlRequestMethod: "GET", + }, + expectedHeaders: map[string]string{ + fasthttp.HeaderAccessControlAllowOrigin: "", + fasthttp.HeaderAccessControlAllowMethods: "", + }, + expectedStatus: http.StatusForbidden, + }, + { + name: "preflight, first rule, no symbols in place of wildcard", + handler: hc.Handler().Preflight, + requestHeaders: map[string]string{ + fasthttp.HeaderOrigin: "suffix.example", + fasthttp.HeaderAccessControlRequestMethod: "GET", + }, + expectedHeaders: map[string]string{ + fasthttp.HeaderAccessControlAllowOrigin: "suffix.example", + fasthttp.HeaderAccessControlAllowMethods: "GET", + }, + }, + { + name: "prelight, first rule, valid origin", + handler: hc.Handler().Preflight, + requestHeaders: map[string]string{ + fasthttp.HeaderOrigin: "http://suffix.example", + fasthttp.HeaderAccessControlRequestMethod: "GET", + }, + expectedHeaders: map[string]string{ + fasthttp.HeaderAccessControlAllowOrigin: "http://suffix.example", + fasthttp.HeaderAccessControlAllowMethods: "GET", + }, + }, + { + name: "preflight, first rule, invalid origin", + handler: hc.Handler().Preflight, + requestHeaders: map[string]string{ + fasthttp.HeaderOrigin: "http://suffix-example", + fasthttp.HeaderAccessControlRequestMethod: "GET", + }, + expectedHeaders: map[string]string{ + fasthttp.HeaderAccessControlAllowOrigin: "", + fasthttp.HeaderAccessControlAllowMethods: "", + }, + expectedStatus: http.StatusForbidden, + }, + { + name: "preflight, second rule, no symbols in place of wildcard", + handler: hc.Handler().Preflight, + requestHeaders: map[string]string{ + fasthttp.HeaderOrigin: "https://example", + fasthttp.HeaderAccessControlRequestMethod: "GET", + }, + expectedHeaders: map[string]string{ + fasthttp.HeaderAccessControlAllowOrigin: "https://example", + fasthttp.HeaderAccessControlAllowMethods: "GET", + }, + }, + { + name: "preflight, second rule, valid origin", + handler: hc.Handler().Preflight, + requestHeaders: map[string]string{ + fasthttp.HeaderOrigin: "https://www.example", + fasthttp.HeaderAccessControlRequestMethod: "GET", + }, + expectedHeaders: map[string]string{ + fasthttp.HeaderAccessControlAllowOrigin: "https://www.example", + fasthttp.HeaderAccessControlAllowMethods: "GET", + }, + }, + { + name: "preflight, second rule, invalid origin", + handler: hc.Handler().Preflight, + requestHeaders: map[string]string{ + fasthttp.HeaderOrigin: "https://www.example.com", + fasthttp.HeaderAccessControlRequestMethod: "GET", + }, + expectedHeaders: map[string]string{ + fasthttp.HeaderAccessControlAllowOrigin: "", + fasthttp.HeaderAccessControlAllowMethods: "", + }, + expectedStatus: http.StatusForbidden, + }, + { + name: "preflight, third rule, no symbols in place of wildcard", + handler: hc.Handler().Preflight, + requestHeaders: map[string]string{ + fasthttp.HeaderOrigin: "prefix.example", + fasthttp.HeaderAccessControlRequestMethod: "GET", + }, + expectedHeaders: map[string]string{ + fasthttp.HeaderAccessControlAllowOrigin: "prefix.example", + fasthttp.HeaderAccessControlAllowMethods: "GET", + }, + }, + { + name: "preflight, third rule, valid origin", + handler: hc.Handler().Preflight, + requestHeaders: map[string]string{ + fasthttp.HeaderOrigin: "prefix.example.com", + fasthttp.HeaderAccessControlRequestMethod: "GET", + }, + expectedHeaders: map[string]string{ + fasthttp.HeaderAccessControlAllowOrigin: "prefix.example.com", + fasthttp.HeaderAccessControlAllowMethods: "GET", + }, + }, + { + name: "preflight, third rule, invalid origin", + handler: hc.Handler().Preflight, + requestHeaders: map[string]string{ + fasthttp.HeaderOrigin: "www.prefix.example", + fasthttp.HeaderAccessControlRequestMethod: "GET", + }, + expectedHeaders: map[string]string{ + fasthttp.HeaderAccessControlAllowOrigin: "", + fasthttp.HeaderAccessControlAllowMethods: "", + }, + expectedStatus: http.StatusForbidden, + }, + { + name: "preflight, third rule, invalid request method in header", + handler: hc.Handler().Preflight, + requestHeaders: map[string]string{ + fasthttp.HeaderOrigin: "prefix.example.com", + fasthttp.HeaderAccessControlRequestMethod: "PUT", + }, + expectedHeaders: map[string]string{ + fasthttp.HeaderAccessControlAllowOrigin: "", + fasthttp.HeaderAccessControlAllowMethods: "", + }, + expectedStatus: http.StatusForbidden, + }, + } { + t.Run(tc.name, func(t *testing.T) { + r := prepareCORSRequest(t, bktName, tc.requestHeaders) + tc.handler(r) + + expectedStatus := fasthttp.StatusOK + if tc.expectedStatus != 0 { + expectedStatus = tc.expectedStatus + } + require.Equal(t, expectedStatus, r.Response.StatusCode()) + for k, v := range tc.expectedHeaders { + require.Equal(t, v, string(r.Response.Header.Peek(k))) + } + }) + } +} + +func TestAllowedHeaderWildcards(t *testing.T) { + hc := prepareHandlerContext(t) + bktName := "bucket-allowed-header-wildcards" + cnrID, cnr, err := hc.prepareContainer(bktName, acl.Private) + require.NoError(t, err) + hc.frostfs.SetContainer(cnrID, cnr) + + cfg := &data.CORSConfiguration{ + CORSRules: []data.CORSRule{ + { + AllowedOrigins: []string{"https://www.example.com"}, + AllowedMethods: []string{"HEAD"}, + AllowedHeaders: []string{"*-suffix"}, + }, + { + AllowedOrigins: []string{"https://www.example.com"}, + AllowedMethods: []string{"HEAD"}, + AllowedHeaders: []string{"start-*-end"}, + }, + { + AllowedOrigins: []string{"https://www.example.com"}, + AllowedMethods: []string{"HEAD"}, + AllowedHeaders: []string{"X-Amz-*"}, + }, + }, + } + setCORSObject(t, hc, cnrID, cfg, 1) + + for _, tc := range []struct { + name string + requestHeaders map[string]string + expectedHeaders map[string]string + expectedStatus int + }{ + { + name: "first rule, valid headers", + requestHeaders: map[string]string{ + fasthttp.HeaderOrigin: "https://www.example.com", + fasthttp.HeaderAccessControlRequestMethod: "HEAD", + fasthttp.HeaderAccessControlRequestHeaders: "header-suffix, -suffix", + }, + expectedHeaders: map[string]string{ + fasthttp.HeaderAccessControlAllowOrigin: "https://www.example.com", + fasthttp.HeaderAccessControlAllowMethods: "HEAD", + fasthttp.HeaderAccessControlAllowHeaders: "header-suffix, -suffix", + }, + }, + { + name: "first rule, invalid headers", + requestHeaders: map[string]string{ + fasthttp.HeaderOrigin: "https://www.example.com", + fasthttp.HeaderAccessControlRequestMethod: "HEAD", + fasthttp.HeaderAccessControlRequestHeaders: "header-suffix-*", + }, + expectedHeaders: map[string]string{ + fasthttp.HeaderAccessControlAllowOrigin: "", + fasthttp.HeaderAccessControlAllowMethods: "", + fasthttp.HeaderAccessControlAllowHeaders: "", + }, + expectedStatus: http.StatusForbidden, + }, + { + name: "second rule, valid headers", + requestHeaders: map[string]string{ + fasthttp.HeaderOrigin: "https://www.example.com", + fasthttp.HeaderAccessControlRequestMethod: "HEAD", + fasthttp.HeaderAccessControlRequestHeaders: "start--end, start-header-end", + }, + expectedHeaders: map[string]string{ + fasthttp.HeaderAccessControlAllowOrigin: "https://www.example.com", + fasthttp.HeaderAccessControlAllowMethods: "HEAD", + fasthttp.HeaderAccessControlAllowHeaders: "start--end, start-header-end", + }, + }, + { + name: "second rule, invalid header ending", + requestHeaders: map[string]string{ + fasthttp.HeaderOrigin: "https://www.example.com", + fasthttp.HeaderAccessControlRequestMethod: "HEAD", + fasthttp.HeaderAccessControlRequestHeaders: "start-header-end-*", + }, + expectedHeaders: map[string]string{ + fasthttp.HeaderAccessControlAllowOrigin: "", + fasthttp.HeaderAccessControlAllowMethods: "", + fasthttp.HeaderAccessControlAllowHeaders: "", + }, + expectedStatus: http.StatusForbidden, + }, + { + name: "second rule, invalid header beginning", + requestHeaders: map[string]string{ + fasthttp.HeaderOrigin: "https://www.example.com", + fasthttp.HeaderAccessControlRequestMethod: "HEAD", + fasthttp.HeaderAccessControlRequestHeaders: "*-start-header-end", + }, + expectedHeaders: map[string]string{ + fasthttp.HeaderAccessControlAllowOrigin: "", + fasthttp.HeaderAccessControlAllowMethods: "", + fasthttp.HeaderAccessControlAllowHeaders: "", + }, + expectedStatus: http.StatusForbidden, + }, + { + name: "third rule, valid headers", + requestHeaders: map[string]string{ + fasthttp.HeaderOrigin: "https://www.example.com", + fasthttp.HeaderAccessControlRequestMethod: "HEAD", + fasthttp.HeaderAccessControlRequestHeaders: "X-Amz-Date, X-Amz-Content-Sha256", + }, + expectedHeaders: map[string]string{ + fasthttp.HeaderAccessControlAllowOrigin: "https://www.example.com", + fasthttp.HeaderAccessControlAllowMethods: "HEAD", + fasthttp.HeaderAccessControlAllowHeaders: "X-Amz-Date, X-Amz-Content-Sha256", + }, + }, + { + name: "third rule, invalid headers", + requestHeaders: map[string]string{ + fasthttp.HeaderOrigin: "https://www.example.com", + fasthttp.HeaderAccessControlRequestMethod: "HEAD", + fasthttp.HeaderAccessControlRequestHeaders: "Authorization", + }, + expectedHeaders: map[string]string{ + fasthttp.HeaderAccessControlAllowOrigin: "", + fasthttp.HeaderAccessControlAllowMethods: "", + fasthttp.HeaderAccessControlAllowHeaders: "", + }, + expectedStatus: http.StatusForbidden, + }, + } { + t.Run(tc.name, func(t *testing.T) { + r := prepareCORSRequest(t, bktName, tc.requestHeaders) + hc.Handler().Preflight(r) + + expectedStatus := http.StatusOK + if tc.expectedStatus != 0 { + expectedStatus = tc.expectedStatus + } + require.Equal(t, expectedStatus, r.Response.StatusCode()) + for k, v := range tc.expectedHeaders { + require.Equal(t, v, string(r.Response.Header.Peek(k))) + } + }) + } +} + func setCORSObject(t *testing.T, hc *handlerContext, cnrID cid.ID, corsConfig *data.CORSConfiguration, epoch uint64) { payload, err := xml.Marshal(corsConfig) require.NoError(t, err)