From e45c1a218836bf920ea3b2d228c2e2c50f890b64 Mon Sep 17 00:00:00 2001 From: Marina Biryukova Date: Mon, 31 Mar 2025 14:42:59 +0300 Subject: [PATCH] [#672] Support wildcard in allowed origins and headers Signed-off-by: Marina Biryukova --- api/errors/errors.go | 16 +- api/handler/cors.go | 28 +- api/handler/cors_test.go | 612 ++++++++++++++++++++++++++++++++++++++- api/layer/cors.go | 13 +- api/layer/cors_test.go | 69 ++++- api/router.go | 4 +- 6 files changed, 724 insertions(+), 18 deletions(-) diff --git a/api/errors/errors.go b/api/errors/errors.go index ca514ee1..aee4b7b4 100644 --- a/api/errors/errors.go +++ b/api/errors/errors.go @@ -290,6 +290,8 @@ const ( //CORS configuration errors. ErrCORSUnsupportedMethod ErrCORSWildcardExposeHeaders + ErrCORSWildcardsAllowedOrigins + ErrCORSWildcardsAllowedHeaders // Limits errors. ErrLimitExceeded @@ -1740,7 +1742,7 @@ var errorCodes = errorCodeMap{ ErrCORSWildcardExposeHeaders: { ErrCode: ErrCORSWildcardExposeHeaders, Code: "InvalidRequest", - Description: "ExposeHeader \"*\" contains wildcard. We currently do not support wildcard for ExposeHeader", + Description: "ExposeHeader contains wildcard. We currently do not support wildcard for ExposeHeader", HTTPStatusCode: http.StatusBadRequest, }, ErrInvalidPartNumber: { @@ -1781,6 +1783,18 @@ var errorCodes = errorCodeMap{ Description: "The TagSet does not exist", HTTPStatusCode: http.StatusNotFound, }, + ErrCORSWildcardsAllowedOrigins: { + ErrCode: ErrCORSWildcardsAllowedOrigins, + Code: "InvalidRequest", + Description: "AllowedOrigin can not have more than one wildcard.", + HTTPStatusCode: http.StatusBadRequest, + }, + ErrCORSWildcardsAllowedHeaders: { + ErrCode: ErrCORSWildcardsAllowedHeaders, + Code: "InvalidRequest", + Description: "AllowedHeader can not have more than one wildcard.", + HTTPStatusCode: http.StatusBadRequest, + }, // Add your error structure here. } diff --git a/api/handler/cors.go b/api/handler/cors.go index 0f7e08e1..0f9d3f71 100644 --- a/api/handler/cors.go +++ b/api/handler/cors.go @@ -2,6 +2,8 @@ package handler import ( "net/http" + "regexp" + "slices" "strconv" "strings" @@ -110,6 +112,10 @@ func (h *handler) AppendCORSHeaders(w http.ResponseWriter, r *http.Request) { if origin == "" { return } + method := r.Header.Get(api.AccessControlRequestMethod) + if method == "" { + method = r.Method + } ctx = qostagging.ContextWithIOTag(ctx, util.InternalIOTag) reqInfo := middleware.GetReqInfo(ctx) @@ -132,9 +138,9 @@ func (h *handler) AppendCORSHeaders(w http.ResponseWriter, r *http.Request) { for _, rule := range cors.CORSRules { for _, o := range rule.AllowedOrigins { - if o == origin { + if o == origin || (strings.Contains(o, "*") && len(o) > 1 && match(o, origin)) { for _, m := range rule.AllowedMethods { - if m == r.Method { + if m == method { w.Header().Set(api.AccessControlAllowOrigin, origin) w.Header().Set(api.AccessControlAllowMethods, strings.Join(rule.AllowedMethods, ", ")) w.Header().Set(api.AccessControlAllowCredentials, "true") @@ -145,7 +151,7 @@ func (h *handler) AppendCORSHeaders(w http.ResponseWriter, r *http.Request) { } if o == wildcard { for _, m := range rule.AllowedMethods { - if m == r.Method { + if m == method { if withCredentials { w.Header().Set(api.AccessControlAllowOrigin, origin) w.Header().Set(api.AccessControlAllowCredentials, "true") @@ -199,7 +205,7 @@ func (h *handler) Preflight(w http.ResponseWriter, r *http.Request) { for _, rule := range cors.CORSRules { for _, o := range rule.AllowedOrigins { - if o == origin || o == wildcard { + if o == origin || o == wildcard || (strings.Contains(o, "*") && match(o, origin)) { for _, m := range rule.AllowedMethods { if m == method { if !checkSubslice(rule.AllowedHeaders, headers) { @@ -235,12 +241,9 @@ func (h *handler) Preflight(w http.ResponseWriter, r *http.Request) { } func checkSubslice(slice []string, subSlice []string) bool { - if sliceContains(slice, wildcard) { + if slices.Contains(slice, wildcard) { return true } - if len(subSlice) > len(slice) { - return false - } for _, r := range subSlice { if !sliceContains(slice, r) { return false @@ -251,9 +254,16 @@ func checkSubslice(slice []string, subSlice []string) bool { func sliceContains(slice []string, str string) bool { for _, s := range slice { - if s == str { + if s == str || (strings.Contains(s, "*") && match(s, str)) { return true } } return false } + +func match(tmpl, str string) bool { + regexpStr := "^" + regexp.QuoteMeta(tmpl) + "$" + regexpStr = regexpStr[:strings.Index(regexpStr, "*")-1] + "." + regexpStr[strings.Index(regexpStr, "*"):] + reg := regexp.MustCompile(regexpStr) + return reg.Match([]byte(str)) +} diff --git a/api/handler/cors_test.go b/api/handler/cors_test.go index 44e84ab0..32547b85 100644 --- a/api/handler/cors_test.go +++ b/api/handler/cors_test.go @@ -63,8 +63,8 @@ func TestPreflight(t *testing.T) { GET http://www.example.com Authorization - x-amz-* - X-Amz-* + x-amz-request-id + X-Amz-Request-Id 600 @@ -138,7 +138,7 @@ func TestPreflight(t *testing.T) { require.Equal(t, tc.origin, w.Header().Get(api.AccessControlAllowOrigin)) require.Equal(t, tc.method, w.Header().Get(api.AccessControlAllowMethods)) require.Equal(t, tc.headers, w.Header().Get(api.AccessControlAllowHeaders)) - require.Equal(t, "x-amz-*, X-Amz-*", w.Header().Get(api.AccessControlExposeHeaders)) + require.Equal(t, "x-amz-request-id, X-Amz-Request-Id", w.Header().Get(api.AccessControlExposeHeaders)) require.Equal(t, "true", w.Header().Get(api.AccessControlAllowCredentials)) require.Equal(t, "600", w.Header().Get(api.AccessControlMaxAge)) } @@ -230,6 +230,109 @@ func TestPreflightWildcardOrigin(t *testing.T) { } } +func TestAppendCORSHeadersWildcardOrigin(t *testing.T) { + body := ` + + + GET + PUT + * + + +` + hc := prepareHandlerContext(t) + + bktName := "bucket-append-cors-headers-wildcard-test" + box, _ := createAccessBox(t) + w, r := prepareTestRequest(hc, bktName, "", nil) + ctx := middleware.SetBox(r.Context(), &middleware.Box{AccessBox: box}) + r = r.WithContext(ctx) + hc.Handler().CreateBucketHandler(w, r) + assertStatus(t, w, http.StatusOK) + + putBucketCORS(hc, bktName, body) + + for _, tc := range []struct { + name string + requestHeaders map[string]string + expectedHeaders map[string]string + }{ + { + name: "Valid get", + requestHeaders: map[string]string{ + api.Origin: "http://www.example.com", + api.AccessControlRequestMethod: "GET", + }, + expectedHeaders: map[string]string{ + api.AccessControlAllowOrigin: "*", + api.AccessControlAllowCredentials: "", + api.Vary: "", + api.AccessControlAllowMethods: "GET, PUT", + }, + }, + { + name: "Valid get with Authorization", + requestHeaders: map[string]string{ + api.Origin: "http://www.example.com", + api.AccessControlRequestMethod: "GET", + api.Authorization: "value", + }, + expectedHeaders: map[string]string{ + api.AccessControlAllowOrigin: "http://www.example.com", + api.AccessControlAllowCredentials: "true", + api.Vary: api.Origin, + api.AccessControlAllowMethods: "GET, PUT", + }, + }, + { + name: "Empty origin", + expectedHeaders: map[string]string{ + api.AccessControlAllowOrigin: "", + api.AccessControlAllowCredentials: "", + api.Vary: "", + api.AccessControlAllowMethods: "", + }, + }, + { + name: "Empty request method", + requestHeaders: map[string]string{ + api.Origin: "http://www.example.com", + }, + expectedHeaders: map[string]string{ + api.AccessControlAllowOrigin: "*", + api.AccessControlAllowCredentials: "", + api.Vary: "", + api.AccessControlAllowMethods: "GET, PUT", + }, + }, + { + name: "Not allowed method", + requestHeaders: map[string]string{ + api.Origin: "http://www.example.com", + api.AccessControlRequestMethod: "DELETE", + }, + expectedHeaders: map[string]string{ + api.AccessControlAllowOrigin: "", + api.AccessControlAllowCredentials: "", + api.Vary: "", + api.AccessControlAllowMethods: "", + }, + }, + } { + t.Run(tc.name, func(t *testing.T) { + w, r = prepareTestPayloadRequest(hc, bktName, "", nil) + for k, v := range tc.requestHeaders { + r.Header.Set(k, v) + } + hc.Handler().AppendCORSHeaders(w, r) + + for k, v := range tc.expectedHeaders { + require.Equal(t, v, w.Header().Get(k)) + } + }) + } +} + func TestGetLatestCORSVersion(t *testing.T) { bodyTree := ` @@ -346,6 +449,509 @@ func TestDeleteCORSInDeleteBucket(t *testing.T) { require.Len(t, hc.tp.Objects(), 1) // CORS object in bucket container is not deleted } +func TestAllowedOriginWildcards(t *testing.T) { + hc := prepareHandlerContext(t) + bktName := "bucket-allowed-origin-wildcards" + createBucket(hc, bktName) + + cfg := &data.CORSConfiguration{ + CORSRules: []data.CORSRule{ + { + AllowedOrigins: []string{"*suffix.example"}, + AllowedMethods: []string{"PUT"}, + }, + { + AllowedOrigins: []string{"https://*example"}, + AllowedMethods: []string{"PUT"}, + }, + { + AllowedOrigins: []string{"prefix.example*"}, + AllowedMethods: []string{"PUT"}, + }, + }, + } + body, err := xml.Marshal(cfg) + require.NoError(t, err) + putBucketCORS(hc, bktName, string(body)) + + for _, tc := range []struct { + name string + handler func(w http.ResponseWriter, r *http.Request) + requestHeaders map[string]string + expectedHeaders map[string]string + expectedStatus int + }{ + { + name: "append cors headers, empty request cors headers", + handler: hc.Handler().AppendCORSHeaders, + expectedHeaders: map[string]string{ + api.AccessControlAllowOrigin: "", + api.AccessControlAllowMethods: "", + }, + }, + { + name: "append cors headers, invalid origin", + handler: hc.Handler().AppendCORSHeaders, + requestHeaders: map[string]string{ + api.Origin: "https://origin.com", + }, + expectedHeaders: map[string]string{ + api.AccessControlAllowOrigin: "", + api.AccessControlAllowMethods: "", + }, + }, + { + name: "append cors headers, first rule, no symbols in place of wildcard", + handler: hc.Handler().AppendCORSHeaders, + requestHeaders: map[string]string{ + api.Origin: "suffix.example", + }, + expectedHeaders: map[string]string{ + api.AccessControlAllowOrigin: "suffix.example", + api.AccessControlAllowMethods: "PUT", + }, + }, + { + name: "append cors headers, first rule, valid origin", + handler: hc.Handler().AppendCORSHeaders, + requestHeaders: map[string]string{ + api.Origin: "http://suffix.example", + }, + expectedHeaders: map[string]string{ + api.AccessControlAllowOrigin: "http://suffix.example", + api.AccessControlAllowMethods: "PUT", + }, + }, + { + name: "append cors headers, first rule, invalid origin", + handler: hc.Handler().AppendCORSHeaders, + requestHeaders: map[string]string{ + api.Origin: "http://suffix-example", + }, + expectedHeaders: map[string]string{ + api.AccessControlAllowOrigin: "", + api.AccessControlAllowMethods: "", + }, + }, + { + name: "append cors headers, second rule, no symbols in place of wildcard", + handler: hc.Handler().AppendCORSHeaders, + requestHeaders: map[string]string{ + api.Origin: "https://example", + }, + expectedHeaders: map[string]string{ + api.AccessControlAllowOrigin: "https://example", + api.AccessControlAllowMethods: "PUT", + }, + }, + { + name: "append cors headers, second rule, valid origin", + handler: hc.Handler().AppendCORSHeaders, + requestHeaders: map[string]string{ + api.Origin: "https://www.example", + }, + expectedHeaders: map[string]string{ + api.AccessControlAllowOrigin: "https://www.example", + api.AccessControlAllowMethods: "PUT", + }, + }, + { + name: "append cors headers, second rule, invalid origin", + handler: hc.Handler().AppendCORSHeaders, + requestHeaders: map[string]string{ + api.Origin: "https://www.example.com", + }, + expectedHeaders: map[string]string{ + api.AccessControlAllowOrigin: "", + api.AccessControlAllowMethods: "", + }, + }, + { + name: "append cors headers, third rule, no symbols in place of wildcard", + handler: hc.Handler().AppendCORSHeaders, + requestHeaders: map[string]string{ + api.Origin: "prefix.example", + }, + expectedHeaders: map[string]string{ + api.AccessControlAllowOrigin: "prefix.example", + api.AccessControlAllowMethods: "PUT", + }, + }, + { + name: "append cors headers, third rule, valid origin", + handler: hc.Handler().AppendCORSHeaders, + requestHeaders: map[string]string{ + api.Origin: "prefix.example.com", + }, + expectedHeaders: map[string]string{ + api.AccessControlAllowOrigin: "prefix.example.com", + api.AccessControlAllowMethods: "PUT", + }, + }, + { + name: "append cors headers, third rule, invalid origin", + handler: hc.Handler().AppendCORSHeaders, + requestHeaders: map[string]string{ + api.Origin: "www.prefix.example", + }, + expectedHeaders: map[string]string{ + api.AccessControlAllowOrigin: "", + api.AccessControlAllowMethods: "", + }, + }, + { + name: "append cors headers, third rule, invalid request method in header", + handler: hc.Handler().AppendCORSHeaders, + requestHeaders: map[string]string{ + api.Origin: "prefix.example.com", + api.AccessControlRequestMethod: "GET", + }, + expectedHeaders: map[string]string{ + api.AccessControlAllowOrigin: "", + api.AccessControlAllowMethods: "", + }, + }, + { + name: "append cors headers, third rule, valid request method in header", + handler: hc.Handler().AppendCORSHeaders, + requestHeaders: map[string]string{ + api.Origin: "prefix.example.com", + api.AccessControlRequestMethod: "PUT", + }, + expectedHeaders: map[string]string{ + api.AccessControlAllowOrigin: "prefix.example.com", + api.AccessControlAllowMethods: "PUT", + }, + }, + { + name: "preflight, empty request cors headers", + handler: hc.Handler().Preflight, + expectedHeaders: map[string]string{ + api.AccessControlAllowOrigin: "", + api.AccessControlAllowMethods: "", + }, + expectedStatus: http.StatusBadRequest, + }, + { + name: "preflight, invalid origin", + handler: hc.Handler().Preflight, + requestHeaders: map[string]string{ + api.Origin: "https://origin.com", + api.AccessControlRequestMethod: "PUT", + }, + expectedHeaders: map[string]string{ + api.AccessControlAllowOrigin: "", + api.AccessControlAllowMethods: "", + }, + expectedStatus: http.StatusForbidden, + }, + { + name: "preflight, first rule, no symbols in place of wildcard", + handler: hc.Handler().Preflight, + requestHeaders: map[string]string{ + api.Origin: "suffix.example", + api.AccessControlRequestMethod: "PUT", + }, + expectedHeaders: map[string]string{ + api.AccessControlAllowOrigin: "suffix.example", + api.AccessControlAllowMethods: "PUT", + }, + }, + { + name: "prelight, first rule, valid origin", + handler: hc.Handler().Preflight, + requestHeaders: map[string]string{ + api.Origin: "http://suffix.example", + api.AccessControlRequestMethod: "PUT", + }, + expectedHeaders: map[string]string{ + api.AccessControlAllowOrigin: "http://suffix.example", + api.AccessControlAllowMethods: "PUT", + }, + }, + { + name: "preflight, first rule, invalid origin", + handler: hc.Handler().Preflight, + requestHeaders: map[string]string{ + api.Origin: "http://suffix-example", + api.AccessControlRequestMethod: "PUT", + }, + expectedHeaders: map[string]string{ + api.AccessControlAllowOrigin: "", + api.AccessControlAllowMethods: "", + }, + expectedStatus: http.StatusForbidden, + }, + { + name: "preflight, second rule, no symbols in place of wildcard", + handler: hc.Handler().Preflight, + requestHeaders: map[string]string{ + api.Origin: "https://example", + api.AccessControlRequestMethod: "PUT", + }, + expectedHeaders: map[string]string{ + api.AccessControlAllowOrigin: "https://example", + api.AccessControlAllowMethods: "PUT", + }, + }, + { + name: "preflight, second rule, valid origin", + handler: hc.Handler().Preflight, + requestHeaders: map[string]string{ + api.Origin: "https://www.example", + api.AccessControlRequestMethod: "PUT", + }, + expectedHeaders: map[string]string{ + api.AccessControlAllowOrigin: "https://www.example", + api.AccessControlAllowMethods: "PUT", + }, + }, + { + name: "preflight, second rule, invalid origin", + handler: hc.Handler().Preflight, + requestHeaders: map[string]string{ + api.Origin: "https://www.example.com", + api.AccessControlRequestMethod: "PUT", + }, + expectedHeaders: map[string]string{ + api.AccessControlAllowOrigin: "", + api.AccessControlAllowMethods: "", + }, + expectedStatus: http.StatusForbidden, + }, + { + name: "preflight, third rule, no symbols in place of wildcard", + handler: hc.Handler().Preflight, + requestHeaders: map[string]string{ + api.Origin: "prefix.example", + api.AccessControlRequestMethod: "PUT", + }, + expectedHeaders: map[string]string{ + api.AccessControlAllowOrigin: "prefix.example", + api.AccessControlAllowMethods: "PUT", + }, + }, + { + name: "preflight, third rule, valid origin", + handler: hc.Handler().Preflight, + requestHeaders: map[string]string{ + api.Origin: "prefix.example.com", + api.AccessControlRequestMethod: "PUT", + }, + expectedHeaders: map[string]string{ + api.AccessControlAllowOrigin: "prefix.example.com", + api.AccessControlAllowMethods: "PUT", + }, + }, + { + name: "preflight, third rule, invalid origin", + handler: hc.Handler().Preflight, + requestHeaders: map[string]string{ + api.Origin: "www.prefix.example", + api.AccessControlRequestMethod: "PUT", + }, + expectedHeaders: map[string]string{ + api.AccessControlAllowOrigin: "", + api.AccessControlAllowMethods: "", + }, + expectedStatus: http.StatusForbidden, + }, + { + name: "preflight, third rule, invalid request method in header", + handler: hc.Handler().Preflight, + requestHeaders: map[string]string{ + api.Origin: "prefix.example.com", + api.AccessControlRequestMethod: "GET", + }, + expectedHeaders: map[string]string{ + api.AccessControlAllowOrigin: "", + api.AccessControlAllowMethods: "", + }, + expectedStatus: http.StatusForbidden, + }, + { + name: "preflight, third rule, valid request method in header", + handler: hc.Handler().Preflight, + requestHeaders: map[string]string{ + api.Origin: "prefix.example.com", + api.AccessControlRequestMethod: "PUT", + }, + expectedHeaders: map[string]string{ + api.AccessControlAllowOrigin: "prefix.example.com", + api.AccessControlAllowMethods: "PUT", + }, + }, + } { + t.Run(tc.name, func(t *testing.T) { + w, r := prepareTestRequest(hc, bktName, "", nil) + for k, v := range tc.requestHeaders { + r.Header.Set(k, v) + } + + tc.handler(w, r) + + expectedStatus := http.StatusOK + if tc.expectedStatus != 0 { + expectedStatus = tc.expectedStatus + } + require.Equal(t, expectedStatus, w.Code) + for k, v := range tc.expectedHeaders { + require.Equal(t, v, w.Header().Get(k)) + } + }) + } +} + +func TestAllowedHeaderWildcards(t *testing.T) { + hc := prepareHandlerContext(t) + bktName := "bucket-allowed-header-wildcards" + createBucket(hc, bktName) + + cfg := &data.CORSConfiguration{ + CORSRules: []data.CORSRule{ + { + AllowedOrigins: []string{"https://www.example.com"}, + AllowedMethods: []string{"HEAD"}, + AllowedHeaders: []string{"*-suffix"}, + }, + { + AllowedOrigins: []string{"https://www.example.com"}, + AllowedMethods: []string{"HEAD"}, + AllowedHeaders: []string{"start-*-end"}, + }, + { + AllowedOrigins: []string{"https://www.example.com"}, + AllowedMethods: []string{"HEAD"}, + AllowedHeaders: []string{"X-Amz-*"}, + }, + }, + } + body, err := xml.Marshal(cfg) + require.NoError(t, err) + putBucketCORS(hc, bktName, string(body)) + + for _, tc := range []struct { + name string + requestHeaders map[string]string + expectedHeaders map[string]string + expectedStatus int + }{ + { + name: "first rule, valid headers", + requestHeaders: map[string]string{ + api.Origin: "https://www.example.com", + api.AccessControlRequestMethod: "HEAD", + api.AccessControlRequestHeaders: "header-suffix, -suffix", + }, + expectedHeaders: map[string]string{ + api.AccessControlAllowOrigin: "https://www.example.com", + api.AccessControlAllowMethods: "HEAD", + api.AccessControlAllowHeaders: "header-suffix, -suffix", + }, + }, + { + name: "first rule, invalid headers", + requestHeaders: map[string]string{ + api.Origin: "https://www.example.com", + api.AccessControlRequestMethod: "HEAD", + api.AccessControlRequestHeaders: "header-suffix-*", + }, + expectedHeaders: map[string]string{ + api.AccessControlAllowOrigin: "", + api.AccessControlAllowMethods: "", + api.AccessControlAllowHeaders: "", + }, + expectedStatus: http.StatusForbidden, + }, + { + name: "second rule, valid headers", + requestHeaders: map[string]string{ + api.Origin: "https://www.example.com", + api.AccessControlRequestMethod: "HEAD", + api.AccessControlRequestHeaders: "start--end, start-header-end", + }, + expectedHeaders: map[string]string{ + api.AccessControlAllowOrigin: "https://www.example.com", + api.AccessControlAllowMethods: "HEAD", + api.AccessControlAllowHeaders: "start--end, start-header-end", + }, + }, + { + name: "second rule, invalid header ending", + requestHeaders: map[string]string{ + api.Origin: "https://www.example.com", + api.AccessControlRequestMethod: "HEAD", + api.AccessControlRequestHeaders: "start-header-end-*", + }, + expectedHeaders: map[string]string{ + api.AccessControlAllowOrigin: "", + api.AccessControlAllowMethods: "", + api.AccessControlAllowHeaders: "", + }, + expectedStatus: http.StatusForbidden, + }, + { + name: "second rule, invalid header beginning", + requestHeaders: map[string]string{ + api.Origin: "https://www.example.com", + api.AccessControlRequestMethod: "HEAD", + api.AccessControlRequestHeaders: "*-start-header-end", + }, + expectedHeaders: map[string]string{ + api.AccessControlAllowOrigin: "", + api.AccessControlAllowMethods: "", + api.AccessControlAllowHeaders: "", + }, + expectedStatus: http.StatusForbidden, + }, + { + name: "third rule, valid headers", + requestHeaders: map[string]string{ + api.Origin: "https://www.example.com", + api.AccessControlRequestMethod: "HEAD", + api.AccessControlRequestHeaders: "X-Amz-Date, X-Amz-Content-Sha256", + }, + expectedHeaders: map[string]string{ + api.AccessControlAllowOrigin: "https://www.example.com", + api.AccessControlAllowMethods: "HEAD", + api.AccessControlAllowHeaders: "X-Amz-Date, X-Amz-Content-Sha256", + }, + }, + { + name: "third rule, invalid headers", + requestHeaders: map[string]string{ + api.Origin: "https://www.example.com", + api.AccessControlRequestMethod: "HEAD", + api.AccessControlRequestHeaders: "Authorization", + }, + expectedHeaders: map[string]string{ + api.AccessControlAllowOrigin: "", + api.AccessControlAllowMethods: "", + api.AccessControlAllowHeaders: "", + }, + expectedStatus: http.StatusForbidden, + }, + } { + t.Run(tc.name, func(t *testing.T) { + w, r := prepareTestRequest(hc, bktName, "", nil) + for k, v := range tc.requestHeaders { + r.Header.Set(k, v) + } + + hc.Handler().Preflight(w, r) + + expectedStatus := http.StatusOK + if tc.expectedStatus != 0 { + expectedStatus = tc.expectedStatus + } + require.Equal(t, expectedStatus, w.Code) + for k, v := range tc.expectedHeaders { + require.Equal(t, v, w.Header().Get(k)) + } + }) + } +} + func addCORSToTree(hc *handlerContext, cors string, bkt *data.BucketInfo, corsCnrID cid.ID) { var addr oid.Address addr.SetContainer(corsCnrID) diff --git a/api/layer/cors.go b/api/layer/cors.go index b2a95c4e..e93c1af8 100644 --- a/api/layer/cors.go +++ b/api/layer/cors.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "io" + "strings" "git.frostfs.info/TrueCloudLab/frostfs-observability/tracing" "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/data" @@ -173,13 +174,23 @@ func (n *Layer) deleteCORSVersions(ctx context.Context, bktInfo *data.BucketInfo func checkCORS(cors *data.CORSConfiguration) error { for _, r := range cors.CORSRules { + for _, o := range r.AllowedOrigins { + if strings.Count(o, "*") > 1 { + return apierr.GetAPIError(apierr.ErrCORSWildcardsAllowedOrigins) + } + } + for _, h := range r.AllowedHeaders { + if strings.Count(h, "*") > 1 { + return apierr.GetAPIError(apierr.ErrCORSWildcardsAllowedHeaders) + } + } for _, m := range r.AllowedMethods { if _, ok := supportedMethods[m]; !ok { return apierr.GetAPIErrorWithError(apierr.ErrCORSUnsupportedMethod, fmt.Errorf("unsupported method is %s", m)) } } for _, h := range r.ExposeHeaders { - if h == wildcard { + if strings.Contains(h, wildcard) { return apierr.GetAPIError(apierr.ErrCORSWildcardExposeHeaders) } } diff --git a/api/layer/cors_test.go b/api/layer/cors_test.go index d2e96a95..5fbc8d67 100644 --- a/api/layer/cors_test.go +++ b/api/layer/cors_test.go @@ -6,6 +6,8 @@ import ( "strings" "testing" + "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/data" + apierr "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/errors" "github.com/stretchr/testify/require" ) @@ -17,7 +19,7 @@ func TestCorsCopiesNumber(t *testing.T) { GET http://www.example.com Authorization - x-amz-* + x-amz-request-id ` @@ -39,6 +41,71 @@ func TestCorsCopiesNumber(t *testing.T) { require.EqualValues(t, copies, tc.testFrostFS.CopiesNumbers(addrFromObject(objs[0]).EncodeToString())) } +func TestCheckCORS(t *testing.T) { + for _, tc := range []struct { + name string + cfg *data.CORSConfiguration + expectedCode apierr.ErrorCode + }{ + { + name: "allowed origin wildcards", + cfg: &data.CORSConfiguration{ + CORSRules: []data.CORSRule{ + { + AllowedOrigins: []string{"https://*.example.*"}, + AllowedMethods: []string{"GET"}, + }, + }, + }, + expectedCode: apierr.ErrCORSWildcardsAllowedOrigins, + }, + { + name: "allowed header wildcards", + cfg: &data.CORSConfiguration{ + CORSRules: []data.CORSRule{ + { + AllowedOrigins: []string{"https://*.example.com"}, + AllowedMethods: []string{"GET"}, + AllowedHeaders: []string{"x-amz-*-*"}, + }, + }, + }, + expectedCode: apierr.ErrCORSWildcardsAllowedHeaders, + }, + { + name: "invalid allowed method", + cfg: &data.CORSConfiguration{ + CORSRules: []data.CORSRule{ + { + AllowedOrigins: []string{"https://*.example.com"}, + AllowedMethods: []string{"INVALID"}, + AllowedHeaders: []string{"x-amz-*"}, + }, + }, + }, + expectedCode: apierr.ErrCORSUnsupportedMethod, + }, + { + name: "expose header wildcard", + cfg: &data.CORSConfiguration{ + CORSRules: []data.CORSRule{ + { + AllowedOrigins: []string{"https://*.example.com"}, + AllowedMethods: []string{"GET"}, + AllowedHeaders: []string{"x-amz-*"}, + ExposeHeaders: []string{"x-amz-*"}, + }, + }, + }, + expectedCode: apierr.ErrCORSWildcardExposeHeaders, + }, + } { + t.Run(tc.name, func(t *testing.T) { + require.True(t, apierr.IsS3Error(checkCORS(tc.cfg), tc.expectedCode)) + }) + } +} + func NewXMLDecoder(r io.Reader, _ string) *xml.Decoder { dec := xml.NewDecoder(r) diff --git a/api/router.go b/api/router.go index 832f9405..b65ff80f 100644 --- a/api/router.go +++ b/api/router.go @@ -144,6 +144,7 @@ func NewRouter(cfg Config) *chi.Mux { } api.Use(s3middleware.PrepareAddressStyle(cfg.MiddlewareSettings, cfg.Log)) + api.Use(s3middleware.WrapHandler(cfg.Handler.AppendCORSHeaders)) api.Use(s3middleware.PolicyCheck(s3middleware.PolicyConfig{ Storage: cfg.PolicyChecker, FrostfsID: cfg.FrostfsID, @@ -290,9 +291,6 @@ func attachErrorHandler(api *chi.Mux) { func bucketRouter(h Handler) chi.Router { bktRouter := chi.NewRouter() - bktRouter.Use( - s3middleware.WrapHandler(h.AppendCORSHeaders), - ) bktRouter.Mount("/", objectRouter(h))