diff --git a/api/errors/errors.go b/api/errors/errors.go
index ca514ee1..aee4b7b4 100644
--- a/api/errors/errors.go
+++ b/api/errors/errors.go
@@ -290,6 +290,8 @@ const (
//CORS configuration errors.
ErrCORSUnsupportedMethod
ErrCORSWildcardExposeHeaders
+ ErrCORSWildcardsAllowedOrigins
+ ErrCORSWildcardsAllowedHeaders
// Limits errors.
ErrLimitExceeded
@@ -1740,7 +1742,7 @@ var errorCodes = errorCodeMap{
ErrCORSWildcardExposeHeaders: {
ErrCode: ErrCORSWildcardExposeHeaders,
Code: "InvalidRequest",
- Description: "ExposeHeader \"*\" contains wildcard. We currently do not support wildcard for ExposeHeader",
+ Description: "ExposeHeader contains wildcard. We currently do not support wildcard for ExposeHeader",
HTTPStatusCode: http.StatusBadRequest,
},
ErrInvalidPartNumber: {
@@ -1781,6 +1783,18 @@ var errorCodes = errorCodeMap{
Description: "The TagSet does not exist",
HTTPStatusCode: http.StatusNotFound,
},
+ ErrCORSWildcardsAllowedOrigins: {
+ ErrCode: ErrCORSWildcardsAllowedOrigins,
+ Code: "InvalidRequest",
+ Description: "AllowedOrigin can not have more than one wildcard.",
+ HTTPStatusCode: http.StatusBadRequest,
+ },
+ ErrCORSWildcardsAllowedHeaders: {
+ ErrCode: ErrCORSWildcardsAllowedHeaders,
+ Code: "InvalidRequest",
+ Description: "AllowedHeader can not have more than one wildcard.",
+ HTTPStatusCode: http.StatusBadRequest,
+ },
// Add your error structure here.
}
diff --git a/api/handler/cors.go b/api/handler/cors.go
index 0f7e08e1..0f9d3f71 100644
--- a/api/handler/cors.go
+++ b/api/handler/cors.go
@@ -2,6 +2,8 @@ package handler
import (
"net/http"
+ "regexp"
+ "slices"
"strconv"
"strings"
@@ -110,6 +112,10 @@ func (h *handler) AppendCORSHeaders(w http.ResponseWriter, r *http.Request) {
if origin == "" {
return
}
+ method := r.Header.Get(api.AccessControlRequestMethod)
+ if method == "" {
+ method = r.Method
+ }
ctx = qostagging.ContextWithIOTag(ctx, util.InternalIOTag)
reqInfo := middleware.GetReqInfo(ctx)
@@ -132,9 +138,9 @@ func (h *handler) AppendCORSHeaders(w http.ResponseWriter, r *http.Request) {
for _, rule := range cors.CORSRules {
for _, o := range rule.AllowedOrigins {
- if o == origin {
+ if o == origin || (strings.Contains(o, "*") && len(o) > 1 && match(o, origin)) {
for _, m := range rule.AllowedMethods {
- if m == r.Method {
+ if m == method {
w.Header().Set(api.AccessControlAllowOrigin, origin)
w.Header().Set(api.AccessControlAllowMethods, strings.Join(rule.AllowedMethods, ", "))
w.Header().Set(api.AccessControlAllowCredentials, "true")
@@ -145,7 +151,7 @@ func (h *handler) AppendCORSHeaders(w http.ResponseWriter, r *http.Request) {
}
if o == wildcard {
for _, m := range rule.AllowedMethods {
- if m == r.Method {
+ if m == method {
if withCredentials {
w.Header().Set(api.AccessControlAllowOrigin, origin)
w.Header().Set(api.AccessControlAllowCredentials, "true")
@@ -199,7 +205,7 @@ func (h *handler) Preflight(w http.ResponseWriter, r *http.Request) {
for _, rule := range cors.CORSRules {
for _, o := range rule.AllowedOrigins {
- if o == origin || o == wildcard {
+ if o == origin || o == wildcard || (strings.Contains(o, "*") && match(o, origin)) {
for _, m := range rule.AllowedMethods {
if m == method {
if !checkSubslice(rule.AllowedHeaders, headers) {
@@ -235,12 +241,9 @@ func (h *handler) Preflight(w http.ResponseWriter, r *http.Request) {
}
func checkSubslice(slice []string, subSlice []string) bool {
- if sliceContains(slice, wildcard) {
+ if slices.Contains(slice, wildcard) {
return true
}
- if len(subSlice) > len(slice) {
- return false
- }
for _, r := range subSlice {
if !sliceContains(slice, r) {
return false
@@ -251,9 +254,16 @@ func checkSubslice(slice []string, subSlice []string) bool {
func sliceContains(slice []string, str string) bool {
for _, s := range slice {
- if s == str {
+ if s == str || (strings.Contains(s, "*") && match(s, str)) {
return true
}
}
return false
}
+
+func match(tmpl, str string) bool {
+ regexpStr := "^" + regexp.QuoteMeta(tmpl) + "$"
+ regexpStr = regexpStr[:strings.Index(regexpStr, "*")-1] + "." + regexpStr[strings.Index(regexpStr, "*"):]
+ reg := regexp.MustCompile(regexpStr)
+ return reg.Match([]byte(str))
+}
diff --git a/api/handler/cors_test.go b/api/handler/cors_test.go
index 44e84ab0..32547b85 100644
--- a/api/handler/cors_test.go
+++ b/api/handler/cors_test.go
@@ -63,8 +63,8 @@ func TestPreflight(t *testing.T) {
GET
http://www.example.com
Authorization
- x-amz-*
- X-Amz-*
+ x-amz-request-id
+ X-Amz-Request-Id
600
@@ -138,7 +138,7 @@ func TestPreflight(t *testing.T) {
require.Equal(t, tc.origin, w.Header().Get(api.AccessControlAllowOrigin))
require.Equal(t, tc.method, w.Header().Get(api.AccessControlAllowMethods))
require.Equal(t, tc.headers, w.Header().Get(api.AccessControlAllowHeaders))
- require.Equal(t, "x-amz-*, X-Amz-*", w.Header().Get(api.AccessControlExposeHeaders))
+ require.Equal(t, "x-amz-request-id, X-Amz-Request-Id", w.Header().Get(api.AccessControlExposeHeaders))
require.Equal(t, "true", w.Header().Get(api.AccessControlAllowCredentials))
require.Equal(t, "600", w.Header().Get(api.AccessControlMaxAge))
}
@@ -230,6 +230,109 @@ func TestPreflightWildcardOrigin(t *testing.T) {
}
}
+func TestAppendCORSHeadersWildcardOrigin(t *testing.T) {
+ body := `
+
+
+ GET
+ PUT
+ *
+
+
+`
+ hc := prepareHandlerContext(t)
+
+ bktName := "bucket-append-cors-headers-wildcard-test"
+ box, _ := createAccessBox(t)
+ w, r := prepareTestRequest(hc, bktName, "", nil)
+ ctx := middleware.SetBox(r.Context(), &middleware.Box{AccessBox: box})
+ r = r.WithContext(ctx)
+ hc.Handler().CreateBucketHandler(w, r)
+ assertStatus(t, w, http.StatusOK)
+
+ putBucketCORS(hc, bktName, body)
+
+ for _, tc := range []struct {
+ name string
+ requestHeaders map[string]string
+ expectedHeaders map[string]string
+ }{
+ {
+ name: "Valid get",
+ requestHeaders: map[string]string{
+ api.Origin: "http://www.example.com",
+ api.AccessControlRequestMethod: "GET",
+ },
+ expectedHeaders: map[string]string{
+ api.AccessControlAllowOrigin: "*",
+ api.AccessControlAllowCredentials: "",
+ api.Vary: "",
+ api.AccessControlAllowMethods: "GET, PUT",
+ },
+ },
+ {
+ name: "Valid get with Authorization",
+ requestHeaders: map[string]string{
+ api.Origin: "http://www.example.com",
+ api.AccessControlRequestMethod: "GET",
+ api.Authorization: "value",
+ },
+ expectedHeaders: map[string]string{
+ api.AccessControlAllowOrigin: "http://www.example.com",
+ api.AccessControlAllowCredentials: "true",
+ api.Vary: api.Origin,
+ api.AccessControlAllowMethods: "GET, PUT",
+ },
+ },
+ {
+ name: "Empty origin",
+ expectedHeaders: map[string]string{
+ api.AccessControlAllowOrigin: "",
+ api.AccessControlAllowCredentials: "",
+ api.Vary: "",
+ api.AccessControlAllowMethods: "",
+ },
+ },
+ {
+ name: "Empty request method",
+ requestHeaders: map[string]string{
+ api.Origin: "http://www.example.com",
+ },
+ expectedHeaders: map[string]string{
+ api.AccessControlAllowOrigin: "*",
+ api.AccessControlAllowCredentials: "",
+ api.Vary: "",
+ api.AccessControlAllowMethods: "GET, PUT",
+ },
+ },
+ {
+ name: "Not allowed method",
+ requestHeaders: map[string]string{
+ api.Origin: "http://www.example.com",
+ api.AccessControlRequestMethod: "DELETE",
+ },
+ expectedHeaders: map[string]string{
+ api.AccessControlAllowOrigin: "",
+ api.AccessControlAllowCredentials: "",
+ api.Vary: "",
+ api.AccessControlAllowMethods: "",
+ },
+ },
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ w, r = prepareTestPayloadRequest(hc, bktName, "", nil)
+ for k, v := range tc.requestHeaders {
+ r.Header.Set(k, v)
+ }
+ hc.Handler().AppendCORSHeaders(w, r)
+
+ for k, v := range tc.expectedHeaders {
+ require.Equal(t, v, w.Header().Get(k))
+ }
+ })
+ }
+}
+
func TestGetLatestCORSVersion(t *testing.T) {
bodyTree := `
@@ -346,6 +449,509 @@ func TestDeleteCORSInDeleteBucket(t *testing.T) {
require.Len(t, hc.tp.Objects(), 1) // CORS object in bucket container is not deleted
}
+func TestAllowedOriginWildcards(t *testing.T) {
+ hc := prepareHandlerContext(t)
+ bktName := "bucket-allowed-origin-wildcards"
+ createBucket(hc, bktName)
+
+ cfg := &data.CORSConfiguration{
+ CORSRules: []data.CORSRule{
+ {
+ AllowedOrigins: []string{"*suffix.example"},
+ AllowedMethods: []string{"PUT"},
+ },
+ {
+ AllowedOrigins: []string{"https://*example"},
+ AllowedMethods: []string{"PUT"},
+ },
+ {
+ AllowedOrigins: []string{"prefix.example*"},
+ AllowedMethods: []string{"PUT"},
+ },
+ },
+ }
+ body, err := xml.Marshal(cfg)
+ require.NoError(t, err)
+ putBucketCORS(hc, bktName, string(body))
+
+ for _, tc := range []struct {
+ name string
+ handler func(w http.ResponseWriter, r *http.Request)
+ requestHeaders map[string]string
+ expectedHeaders map[string]string
+ expectedStatus int
+ }{
+ {
+ name: "append cors headers, empty request cors headers",
+ handler: hc.Handler().AppendCORSHeaders,
+ expectedHeaders: map[string]string{
+ api.AccessControlAllowOrigin: "",
+ api.AccessControlAllowMethods: "",
+ },
+ },
+ {
+ name: "append cors headers, invalid origin",
+ handler: hc.Handler().AppendCORSHeaders,
+ requestHeaders: map[string]string{
+ api.Origin: "https://origin.com",
+ },
+ expectedHeaders: map[string]string{
+ api.AccessControlAllowOrigin: "",
+ api.AccessControlAllowMethods: "",
+ },
+ },
+ {
+ name: "append cors headers, first rule, no symbols in place of wildcard",
+ handler: hc.Handler().AppendCORSHeaders,
+ requestHeaders: map[string]string{
+ api.Origin: "suffix.example",
+ },
+ expectedHeaders: map[string]string{
+ api.AccessControlAllowOrigin: "suffix.example",
+ api.AccessControlAllowMethods: "PUT",
+ },
+ },
+ {
+ name: "append cors headers, first rule, valid origin",
+ handler: hc.Handler().AppendCORSHeaders,
+ requestHeaders: map[string]string{
+ api.Origin: "http://suffix.example",
+ },
+ expectedHeaders: map[string]string{
+ api.AccessControlAllowOrigin: "http://suffix.example",
+ api.AccessControlAllowMethods: "PUT",
+ },
+ },
+ {
+ name: "append cors headers, first rule, invalid origin",
+ handler: hc.Handler().AppendCORSHeaders,
+ requestHeaders: map[string]string{
+ api.Origin: "http://suffix-example",
+ },
+ expectedHeaders: map[string]string{
+ api.AccessControlAllowOrigin: "",
+ api.AccessControlAllowMethods: "",
+ },
+ },
+ {
+ name: "append cors headers, second rule, no symbols in place of wildcard",
+ handler: hc.Handler().AppendCORSHeaders,
+ requestHeaders: map[string]string{
+ api.Origin: "https://example",
+ },
+ expectedHeaders: map[string]string{
+ api.AccessControlAllowOrigin: "https://example",
+ api.AccessControlAllowMethods: "PUT",
+ },
+ },
+ {
+ name: "append cors headers, second rule, valid origin",
+ handler: hc.Handler().AppendCORSHeaders,
+ requestHeaders: map[string]string{
+ api.Origin: "https://www.example",
+ },
+ expectedHeaders: map[string]string{
+ api.AccessControlAllowOrigin: "https://www.example",
+ api.AccessControlAllowMethods: "PUT",
+ },
+ },
+ {
+ name: "append cors headers, second rule, invalid origin",
+ handler: hc.Handler().AppendCORSHeaders,
+ requestHeaders: map[string]string{
+ api.Origin: "https://www.example.com",
+ },
+ expectedHeaders: map[string]string{
+ api.AccessControlAllowOrigin: "",
+ api.AccessControlAllowMethods: "",
+ },
+ },
+ {
+ name: "append cors headers, third rule, no symbols in place of wildcard",
+ handler: hc.Handler().AppendCORSHeaders,
+ requestHeaders: map[string]string{
+ api.Origin: "prefix.example",
+ },
+ expectedHeaders: map[string]string{
+ api.AccessControlAllowOrigin: "prefix.example",
+ api.AccessControlAllowMethods: "PUT",
+ },
+ },
+ {
+ name: "append cors headers, third rule, valid origin",
+ handler: hc.Handler().AppendCORSHeaders,
+ requestHeaders: map[string]string{
+ api.Origin: "prefix.example.com",
+ },
+ expectedHeaders: map[string]string{
+ api.AccessControlAllowOrigin: "prefix.example.com",
+ api.AccessControlAllowMethods: "PUT",
+ },
+ },
+ {
+ name: "append cors headers, third rule, invalid origin",
+ handler: hc.Handler().AppendCORSHeaders,
+ requestHeaders: map[string]string{
+ api.Origin: "www.prefix.example",
+ },
+ expectedHeaders: map[string]string{
+ api.AccessControlAllowOrigin: "",
+ api.AccessControlAllowMethods: "",
+ },
+ },
+ {
+ name: "append cors headers, third rule, invalid request method in header",
+ handler: hc.Handler().AppendCORSHeaders,
+ requestHeaders: map[string]string{
+ api.Origin: "prefix.example.com",
+ api.AccessControlRequestMethod: "GET",
+ },
+ expectedHeaders: map[string]string{
+ api.AccessControlAllowOrigin: "",
+ api.AccessControlAllowMethods: "",
+ },
+ },
+ {
+ name: "append cors headers, third rule, valid request method in header",
+ handler: hc.Handler().AppendCORSHeaders,
+ requestHeaders: map[string]string{
+ api.Origin: "prefix.example.com",
+ api.AccessControlRequestMethod: "PUT",
+ },
+ expectedHeaders: map[string]string{
+ api.AccessControlAllowOrigin: "prefix.example.com",
+ api.AccessControlAllowMethods: "PUT",
+ },
+ },
+ {
+ name: "preflight, empty request cors headers",
+ handler: hc.Handler().Preflight,
+ expectedHeaders: map[string]string{
+ api.AccessControlAllowOrigin: "",
+ api.AccessControlAllowMethods: "",
+ },
+ expectedStatus: http.StatusBadRequest,
+ },
+ {
+ name: "preflight, invalid origin",
+ handler: hc.Handler().Preflight,
+ requestHeaders: map[string]string{
+ api.Origin: "https://origin.com",
+ api.AccessControlRequestMethod: "PUT",
+ },
+ expectedHeaders: map[string]string{
+ api.AccessControlAllowOrigin: "",
+ api.AccessControlAllowMethods: "",
+ },
+ expectedStatus: http.StatusForbidden,
+ },
+ {
+ name: "preflight, first rule, no symbols in place of wildcard",
+ handler: hc.Handler().Preflight,
+ requestHeaders: map[string]string{
+ api.Origin: "suffix.example",
+ api.AccessControlRequestMethod: "PUT",
+ },
+ expectedHeaders: map[string]string{
+ api.AccessControlAllowOrigin: "suffix.example",
+ api.AccessControlAllowMethods: "PUT",
+ },
+ },
+ {
+ name: "prelight, first rule, valid origin",
+ handler: hc.Handler().Preflight,
+ requestHeaders: map[string]string{
+ api.Origin: "http://suffix.example",
+ api.AccessControlRequestMethod: "PUT",
+ },
+ expectedHeaders: map[string]string{
+ api.AccessControlAllowOrigin: "http://suffix.example",
+ api.AccessControlAllowMethods: "PUT",
+ },
+ },
+ {
+ name: "preflight, first rule, invalid origin",
+ handler: hc.Handler().Preflight,
+ requestHeaders: map[string]string{
+ api.Origin: "http://suffix-example",
+ api.AccessControlRequestMethod: "PUT",
+ },
+ expectedHeaders: map[string]string{
+ api.AccessControlAllowOrigin: "",
+ api.AccessControlAllowMethods: "",
+ },
+ expectedStatus: http.StatusForbidden,
+ },
+ {
+ name: "preflight, second rule, no symbols in place of wildcard",
+ handler: hc.Handler().Preflight,
+ requestHeaders: map[string]string{
+ api.Origin: "https://example",
+ api.AccessControlRequestMethod: "PUT",
+ },
+ expectedHeaders: map[string]string{
+ api.AccessControlAllowOrigin: "https://example",
+ api.AccessControlAllowMethods: "PUT",
+ },
+ },
+ {
+ name: "preflight, second rule, valid origin",
+ handler: hc.Handler().Preflight,
+ requestHeaders: map[string]string{
+ api.Origin: "https://www.example",
+ api.AccessControlRequestMethod: "PUT",
+ },
+ expectedHeaders: map[string]string{
+ api.AccessControlAllowOrigin: "https://www.example",
+ api.AccessControlAllowMethods: "PUT",
+ },
+ },
+ {
+ name: "preflight, second rule, invalid origin",
+ handler: hc.Handler().Preflight,
+ requestHeaders: map[string]string{
+ api.Origin: "https://www.example.com",
+ api.AccessControlRequestMethod: "PUT",
+ },
+ expectedHeaders: map[string]string{
+ api.AccessControlAllowOrigin: "",
+ api.AccessControlAllowMethods: "",
+ },
+ expectedStatus: http.StatusForbidden,
+ },
+ {
+ name: "preflight, third rule, no symbols in place of wildcard",
+ handler: hc.Handler().Preflight,
+ requestHeaders: map[string]string{
+ api.Origin: "prefix.example",
+ api.AccessControlRequestMethod: "PUT",
+ },
+ expectedHeaders: map[string]string{
+ api.AccessControlAllowOrigin: "prefix.example",
+ api.AccessControlAllowMethods: "PUT",
+ },
+ },
+ {
+ name: "preflight, third rule, valid origin",
+ handler: hc.Handler().Preflight,
+ requestHeaders: map[string]string{
+ api.Origin: "prefix.example.com",
+ api.AccessControlRequestMethod: "PUT",
+ },
+ expectedHeaders: map[string]string{
+ api.AccessControlAllowOrigin: "prefix.example.com",
+ api.AccessControlAllowMethods: "PUT",
+ },
+ },
+ {
+ name: "preflight, third rule, invalid origin",
+ handler: hc.Handler().Preflight,
+ requestHeaders: map[string]string{
+ api.Origin: "www.prefix.example",
+ api.AccessControlRequestMethod: "PUT",
+ },
+ expectedHeaders: map[string]string{
+ api.AccessControlAllowOrigin: "",
+ api.AccessControlAllowMethods: "",
+ },
+ expectedStatus: http.StatusForbidden,
+ },
+ {
+ name: "preflight, third rule, invalid request method in header",
+ handler: hc.Handler().Preflight,
+ requestHeaders: map[string]string{
+ api.Origin: "prefix.example.com",
+ api.AccessControlRequestMethod: "GET",
+ },
+ expectedHeaders: map[string]string{
+ api.AccessControlAllowOrigin: "",
+ api.AccessControlAllowMethods: "",
+ },
+ expectedStatus: http.StatusForbidden,
+ },
+ {
+ name: "preflight, third rule, valid request method in header",
+ handler: hc.Handler().Preflight,
+ requestHeaders: map[string]string{
+ api.Origin: "prefix.example.com",
+ api.AccessControlRequestMethod: "PUT",
+ },
+ expectedHeaders: map[string]string{
+ api.AccessControlAllowOrigin: "prefix.example.com",
+ api.AccessControlAllowMethods: "PUT",
+ },
+ },
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ w, r := prepareTestRequest(hc, bktName, "", nil)
+ for k, v := range tc.requestHeaders {
+ r.Header.Set(k, v)
+ }
+
+ tc.handler(w, r)
+
+ expectedStatus := http.StatusOK
+ if tc.expectedStatus != 0 {
+ expectedStatus = tc.expectedStatus
+ }
+ require.Equal(t, expectedStatus, w.Code)
+ for k, v := range tc.expectedHeaders {
+ require.Equal(t, v, w.Header().Get(k))
+ }
+ })
+ }
+}
+
+func TestAllowedHeaderWildcards(t *testing.T) {
+ hc := prepareHandlerContext(t)
+ bktName := "bucket-allowed-header-wildcards"
+ createBucket(hc, bktName)
+
+ cfg := &data.CORSConfiguration{
+ CORSRules: []data.CORSRule{
+ {
+ AllowedOrigins: []string{"https://www.example.com"},
+ AllowedMethods: []string{"HEAD"},
+ AllowedHeaders: []string{"*-suffix"},
+ },
+ {
+ AllowedOrigins: []string{"https://www.example.com"},
+ AllowedMethods: []string{"HEAD"},
+ AllowedHeaders: []string{"start-*-end"},
+ },
+ {
+ AllowedOrigins: []string{"https://www.example.com"},
+ AllowedMethods: []string{"HEAD"},
+ AllowedHeaders: []string{"X-Amz-*"},
+ },
+ },
+ }
+ body, err := xml.Marshal(cfg)
+ require.NoError(t, err)
+ putBucketCORS(hc, bktName, string(body))
+
+ for _, tc := range []struct {
+ name string
+ requestHeaders map[string]string
+ expectedHeaders map[string]string
+ expectedStatus int
+ }{
+ {
+ name: "first rule, valid headers",
+ requestHeaders: map[string]string{
+ api.Origin: "https://www.example.com",
+ api.AccessControlRequestMethod: "HEAD",
+ api.AccessControlRequestHeaders: "header-suffix, -suffix",
+ },
+ expectedHeaders: map[string]string{
+ api.AccessControlAllowOrigin: "https://www.example.com",
+ api.AccessControlAllowMethods: "HEAD",
+ api.AccessControlAllowHeaders: "header-suffix, -suffix",
+ },
+ },
+ {
+ name: "first rule, invalid headers",
+ requestHeaders: map[string]string{
+ api.Origin: "https://www.example.com",
+ api.AccessControlRequestMethod: "HEAD",
+ api.AccessControlRequestHeaders: "header-suffix-*",
+ },
+ expectedHeaders: map[string]string{
+ api.AccessControlAllowOrigin: "",
+ api.AccessControlAllowMethods: "",
+ api.AccessControlAllowHeaders: "",
+ },
+ expectedStatus: http.StatusForbidden,
+ },
+ {
+ name: "second rule, valid headers",
+ requestHeaders: map[string]string{
+ api.Origin: "https://www.example.com",
+ api.AccessControlRequestMethod: "HEAD",
+ api.AccessControlRequestHeaders: "start--end, start-header-end",
+ },
+ expectedHeaders: map[string]string{
+ api.AccessControlAllowOrigin: "https://www.example.com",
+ api.AccessControlAllowMethods: "HEAD",
+ api.AccessControlAllowHeaders: "start--end, start-header-end",
+ },
+ },
+ {
+ name: "second rule, invalid header ending",
+ requestHeaders: map[string]string{
+ api.Origin: "https://www.example.com",
+ api.AccessControlRequestMethod: "HEAD",
+ api.AccessControlRequestHeaders: "start-header-end-*",
+ },
+ expectedHeaders: map[string]string{
+ api.AccessControlAllowOrigin: "",
+ api.AccessControlAllowMethods: "",
+ api.AccessControlAllowHeaders: "",
+ },
+ expectedStatus: http.StatusForbidden,
+ },
+ {
+ name: "second rule, invalid header beginning",
+ requestHeaders: map[string]string{
+ api.Origin: "https://www.example.com",
+ api.AccessControlRequestMethod: "HEAD",
+ api.AccessControlRequestHeaders: "*-start-header-end",
+ },
+ expectedHeaders: map[string]string{
+ api.AccessControlAllowOrigin: "",
+ api.AccessControlAllowMethods: "",
+ api.AccessControlAllowHeaders: "",
+ },
+ expectedStatus: http.StatusForbidden,
+ },
+ {
+ name: "third rule, valid headers",
+ requestHeaders: map[string]string{
+ api.Origin: "https://www.example.com",
+ api.AccessControlRequestMethod: "HEAD",
+ api.AccessControlRequestHeaders: "X-Amz-Date, X-Amz-Content-Sha256",
+ },
+ expectedHeaders: map[string]string{
+ api.AccessControlAllowOrigin: "https://www.example.com",
+ api.AccessControlAllowMethods: "HEAD",
+ api.AccessControlAllowHeaders: "X-Amz-Date, X-Amz-Content-Sha256",
+ },
+ },
+ {
+ name: "third rule, invalid headers",
+ requestHeaders: map[string]string{
+ api.Origin: "https://www.example.com",
+ api.AccessControlRequestMethod: "HEAD",
+ api.AccessControlRequestHeaders: "Authorization",
+ },
+ expectedHeaders: map[string]string{
+ api.AccessControlAllowOrigin: "",
+ api.AccessControlAllowMethods: "",
+ api.AccessControlAllowHeaders: "",
+ },
+ expectedStatus: http.StatusForbidden,
+ },
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ w, r := prepareTestRequest(hc, bktName, "", nil)
+ for k, v := range tc.requestHeaders {
+ r.Header.Set(k, v)
+ }
+
+ hc.Handler().Preflight(w, r)
+
+ expectedStatus := http.StatusOK
+ if tc.expectedStatus != 0 {
+ expectedStatus = tc.expectedStatus
+ }
+ require.Equal(t, expectedStatus, w.Code)
+ for k, v := range tc.expectedHeaders {
+ require.Equal(t, v, w.Header().Get(k))
+ }
+ })
+ }
+}
+
func addCORSToTree(hc *handlerContext, cors string, bkt *data.BucketInfo, corsCnrID cid.ID) {
var addr oid.Address
addr.SetContainer(corsCnrID)
diff --git a/api/layer/cors.go b/api/layer/cors.go
index b2a95c4e..e93c1af8 100644
--- a/api/layer/cors.go
+++ b/api/layer/cors.go
@@ -7,6 +7,7 @@ import (
"errors"
"fmt"
"io"
+ "strings"
"git.frostfs.info/TrueCloudLab/frostfs-observability/tracing"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/data"
@@ -173,13 +174,23 @@ func (n *Layer) deleteCORSVersions(ctx context.Context, bktInfo *data.BucketInfo
func checkCORS(cors *data.CORSConfiguration) error {
for _, r := range cors.CORSRules {
+ for _, o := range r.AllowedOrigins {
+ if strings.Count(o, "*") > 1 {
+ return apierr.GetAPIError(apierr.ErrCORSWildcardsAllowedOrigins)
+ }
+ }
+ for _, h := range r.AllowedHeaders {
+ if strings.Count(h, "*") > 1 {
+ return apierr.GetAPIError(apierr.ErrCORSWildcardsAllowedHeaders)
+ }
+ }
for _, m := range r.AllowedMethods {
if _, ok := supportedMethods[m]; !ok {
return apierr.GetAPIErrorWithError(apierr.ErrCORSUnsupportedMethod, fmt.Errorf("unsupported method is %s", m))
}
}
for _, h := range r.ExposeHeaders {
- if h == wildcard {
+ if strings.Contains(h, wildcard) {
return apierr.GetAPIError(apierr.ErrCORSWildcardExposeHeaders)
}
}
diff --git a/api/layer/cors_test.go b/api/layer/cors_test.go
index d2e96a95..5fbc8d67 100644
--- a/api/layer/cors_test.go
+++ b/api/layer/cors_test.go
@@ -6,6 +6,8 @@ import (
"strings"
"testing"
+ "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/data"
+ apierr "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/errors"
"github.com/stretchr/testify/require"
)
@@ -17,7 +19,7 @@ func TestCorsCopiesNumber(t *testing.T) {
GET
http://www.example.com
Authorization
- x-amz-*
+ x-amz-request-id
`
@@ -39,6 +41,71 @@ func TestCorsCopiesNumber(t *testing.T) {
require.EqualValues(t, copies, tc.testFrostFS.CopiesNumbers(addrFromObject(objs[0]).EncodeToString()))
}
+func TestCheckCORS(t *testing.T) {
+ for _, tc := range []struct {
+ name string
+ cfg *data.CORSConfiguration
+ expectedCode apierr.ErrorCode
+ }{
+ {
+ name: "allowed origin wildcards",
+ cfg: &data.CORSConfiguration{
+ CORSRules: []data.CORSRule{
+ {
+ AllowedOrigins: []string{"https://*.example.*"},
+ AllowedMethods: []string{"GET"},
+ },
+ },
+ },
+ expectedCode: apierr.ErrCORSWildcardsAllowedOrigins,
+ },
+ {
+ name: "allowed header wildcards",
+ cfg: &data.CORSConfiguration{
+ CORSRules: []data.CORSRule{
+ {
+ AllowedOrigins: []string{"https://*.example.com"},
+ AllowedMethods: []string{"GET"},
+ AllowedHeaders: []string{"x-amz-*-*"},
+ },
+ },
+ },
+ expectedCode: apierr.ErrCORSWildcardsAllowedHeaders,
+ },
+ {
+ name: "invalid allowed method",
+ cfg: &data.CORSConfiguration{
+ CORSRules: []data.CORSRule{
+ {
+ AllowedOrigins: []string{"https://*.example.com"},
+ AllowedMethods: []string{"INVALID"},
+ AllowedHeaders: []string{"x-amz-*"},
+ },
+ },
+ },
+ expectedCode: apierr.ErrCORSUnsupportedMethod,
+ },
+ {
+ name: "expose header wildcard",
+ cfg: &data.CORSConfiguration{
+ CORSRules: []data.CORSRule{
+ {
+ AllowedOrigins: []string{"https://*.example.com"},
+ AllowedMethods: []string{"GET"},
+ AllowedHeaders: []string{"x-amz-*"},
+ ExposeHeaders: []string{"x-amz-*"},
+ },
+ },
+ },
+ expectedCode: apierr.ErrCORSWildcardExposeHeaders,
+ },
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ require.True(t, apierr.IsS3Error(checkCORS(tc.cfg), tc.expectedCode))
+ })
+ }
+}
+
func NewXMLDecoder(r io.Reader, _ string) *xml.Decoder {
dec := xml.NewDecoder(r)
diff --git a/api/router.go b/api/router.go
index 832f9405..b65ff80f 100644
--- a/api/router.go
+++ b/api/router.go
@@ -144,6 +144,7 @@ func NewRouter(cfg Config) *chi.Mux {
}
api.Use(s3middleware.PrepareAddressStyle(cfg.MiddlewareSettings, cfg.Log))
+ api.Use(s3middleware.WrapHandler(cfg.Handler.AppendCORSHeaders))
api.Use(s3middleware.PolicyCheck(s3middleware.PolicyConfig{
Storage: cfg.PolicyChecker,
FrostfsID: cfg.FrostfsID,
@@ -290,9 +291,6 @@ func attachErrorHandler(api *chi.Mux) {
func bucketRouter(h Handler) chi.Router {
bktRouter := chi.NewRouter()
- bktRouter.Use(
- s3middleware.WrapHandler(h.AppendCORSHeaders),
- )
bktRouter.Mount("/", objectRouter(h))