diff --git a/api/handler/cors.go b/api/handler/cors.go
index 69684fa..0ddec96 100644
--- a/api/handler/cors.go
+++ b/api/handler/cors.go
@@ -187,8 +187,8 @@ func (h *handler) Preflight(w http.ResponseWriter, r *http.Request) {
if !checkSubslice(rule.AllowedHeaders, headers) {
continue
}
- w.Header().Set(api.AccessControlAllowOrigin, o)
- w.Header().Set(api.AccessControlAllowMethods, strings.Join(rule.AllowedMethods, ", "))
+ w.Header().Set(api.AccessControlAllowOrigin, origin)
+ w.Header().Set(api.AccessControlAllowMethods, method)
if headers != nil {
w.Header().Set(api.AccessControlAllowHeaders, requestHeaders)
}
diff --git a/api/handler/cors_test.go b/api/handler/cors_test.go
index 1c4bd9e..42008d7 100644
--- a/api/handler/cors_test.go
+++ b/api/handler/cors_test.go
@@ -7,6 +7,7 @@ import (
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/middleware"
+ "github.com/stretchr/testify/require"
)
func TestCORSOriginWildcard(t *testing.T) {
@@ -39,3 +40,181 @@ func TestCORSOriginWildcard(t *testing.T) {
hc.Handler().GetBucketCorsHandler(w, r)
assertStatus(t, w, http.StatusOK)
}
+
+func TestPreflight(t *testing.T) {
+ body := `
+
+
+ GET
+ http://www.example.com
+ Authorization
+ x-amz-*
+ X-Amz-*
+ 600
+
+
+`
+ hc := prepareHandlerContext(t)
+
+ bktName := "bucket-preflight-test"
+ box, _ := createAccessBox(t)
+ w, r := prepareTestRequest(hc, bktName, "", nil)
+ ctx := middleware.SetBox(r.Context(), &middleware.Box{AccessBox: box})
+ r = r.WithContext(ctx)
+ hc.Handler().CreateBucketHandler(w, r)
+ assertStatus(t, w, http.StatusOK)
+
+ w, r = prepareTestPayloadRequest(hc, bktName, "", strings.NewReader(body))
+ ctx = middleware.SetBox(r.Context(), &middleware.Box{AccessBox: box})
+ r = r.WithContext(ctx)
+ hc.Handler().PutBucketCorsHandler(w, r)
+ assertStatus(t, w, http.StatusOK)
+
+ for _, tc := range []struct {
+ name string
+ origin string
+ method string
+ headers string
+ expectedStatus int
+ }{
+ {
+ name: "Valid",
+ origin: "http://www.example.com",
+ method: "GET",
+ headers: "Authorization",
+ expectedStatus: http.StatusOK,
+ },
+ {
+ name: "Empty origin",
+ method: "GET",
+ headers: "Authorization",
+ expectedStatus: http.StatusBadRequest,
+ },
+ {
+ name: "Empty request method",
+ origin: "http://www.example.com",
+ headers: "Authorization",
+ expectedStatus: http.StatusBadRequest,
+ },
+ {
+ name: "Not allowed method",
+ origin: "http://www.example.com",
+ method: "PUT",
+ headers: "Authorization",
+ expectedStatus: http.StatusForbidden,
+ },
+ {
+ name: "Not allowed headers",
+ origin: "http://www.example.com",
+ method: "GET",
+ headers: "Authorization, Last-Modified",
+ expectedStatus: http.StatusForbidden,
+ },
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ w, r = prepareTestPayloadRequest(hc, bktName, "", nil)
+ r.Header.Set(api.Origin, tc.origin)
+ r.Header.Set(api.AccessControlRequestMethod, tc.method)
+ r.Header.Set(api.AccessControlRequestHeaders, tc.headers)
+ hc.Handler().Preflight(w, r)
+ assertStatus(t, w, tc.expectedStatus)
+
+ if tc.expectedStatus == http.StatusOK {
+ require.Equal(t, tc.origin, w.Header().Get(api.AccessControlAllowOrigin))
+ require.Equal(t, tc.method, w.Header().Get(api.AccessControlAllowMethods))
+ require.Equal(t, tc.headers, w.Header().Get(api.AccessControlAllowHeaders))
+ require.Equal(t, "x-amz-*, X-Amz-*", w.Header().Get(api.AccessControlExposeHeaders))
+ require.Equal(t, "true", w.Header().Get(api.AccessControlAllowCredentials))
+ require.Equal(t, "600", w.Header().Get(api.AccessControlMaxAge))
+ }
+ })
+ }
+}
+
+func TestPreflightWildcardOrigin(t *testing.T) {
+ body := `
+
+
+ GET
+ PUT
+ *
+ *
+
+
+`
+ hc := prepareHandlerContext(t)
+
+ bktName := "bucket-preflight-wildcard-test"
+ box, _ := createAccessBox(t)
+ w, r := prepareTestRequest(hc, bktName, "", nil)
+ ctx := middleware.SetBox(r.Context(), &middleware.Box{AccessBox: box})
+ r = r.WithContext(ctx)
+ hc.Handler().CreateBucketHandler(w, r)
+ assertStatus(t, w, http.StatusOK)
+
+ w, r = prepareTestPayloadRequest(hc, bktName, "", strings.NewReader(body))
+ ctx = middleware.SetBox(r.Context(), &middleware.Box{AccessBox: box})
+ r = r.WithContext(ctx)
+ hc.Handler().PutBucketCorsHandler(w, r)
+ assertStatus(t, w, http.StatusOK)
+
+ for _, tc := range []struct {
+ name string
+ origin string
+ method string
+ headers string
+ expectedStatus int
+ }{
+ {
+ name: "Valid get",
+ origin: "http://www.example.com",
+ method: "GET",
+ headers: "Authorization, Last-Modified",
+ expectedStatus: http.StatusOK,
+ },
+ {
+ name: "Valid put",
+ origin: "http://example.com",
+ method: "PUT",
+ headers: "Authorization, Content-Type",
+ expectedStatus: http.StatusOK,
+ },
+ {
+ name: "Empty origin",
+ method: "GET",
+ headers: "Authorization, Last-Modified",
+ expectedStatus: http.StatusBadRequest,
+ },
+ {
+ name: "Empty request method",
+ origin: "http://www.example.com",
+ headers: "Authorization, Last-Modified",
+ expectedStatus: http.StatusBadRequest,
+ },
+ {
+ name: "Not allowed method",
+ origin: "http://www.example.com",
+ method: "DELETE",
+ headers: "Authorization, Last-Modified",
+ expectedStatus: http.StatusForbidden,
+ },
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ w, r = prepareTestPayloadRequest(hc, bktName, "", nil)
+ r.Header.Set(api.Origin, tc.origin)
+ r.Header.Set(api.AccessControlRequestMethod, tc.method)
+ r.Header.Set(api.AccessControlRequestHeaders, tc.headers)
+ hc.Handler().Preflight(w, r)
+ assertStatus(t, w, tc.expectedStatus)
+
+ if tc.expectedStatus == http.StatusOK {
+ require.Equal(t, tc.origin, w.Header().Get(api.AccessControlAllowOrigin))
+ require.Equal(t, tc.method, w.Header().Get(api.AccessControlAllowMethods))
+ require.Equal(t, tc.headers, w.Header().Get(api.AccessControlAllowHeaders))
+ require.Empty(t, w.Header().Get(api.AccessControlExposeHeaders))
+ require.Empty(t, w.Header().Get(api.AccessControlAllowCredentials))
+ require.Equal(t, "0", w.Header().Get(api.AccessControlMaxAge))
+ }
+ })
+ }
+}
diff --git a/api/middleware/constants.go b/api/middleware/constants.go
index 47f6532..a52b93a 100644
--- a/api/middleware/constants.go
+++ b/api/middleware/constants.go
@@ -5,7 +5,7 @@ const (
// bucket operations.
- OptionsOperation = "Options"
+ OptionsBucketOperation = "OptionsBucket"
HeadBucketOperation = "HeadBucket"
ListMultipartUploadsOperation = "ListMultipartUploads"
GetBucketLocationOperation = "GetBucketLocation"
@@ -51,6 +51,7 @@ const (
// object operations.
+ OptionsObjectOperation = "OptionsObject"
HeadObjectOperation = "HeadObject"
ListPartsOperation = "ListParts"
GetObjectACLOperation = "GetObjectACL"
diff --git a/api/middleware/metrics.go b/api/middleware/metrics.go
index c72c59d..fca113a 100644
--- a/api/middleware/metrics.go
+++ b/api/middleware/metrics.go
@@ -103,7 +103,7 @@ func stats(f http.HandlerFunc, resolveCID cidResolveFunc, appMetrics *metrics.Ap
func requestTypeFromAPI(api string) metrics.RequestType {
switch api {
- case OptionsOperation, HeadObjectOperation, HeadBucketOperation:
+ case OptionsBucketOperation, OptionsObjectOperation, HeadObjectOperation, HeadBucketOperation:
return metrics.HEADRequest
case CreateMultipartUploadOperation, UploadPartCopyOperation, UploadPartOperation, CompleteMultipartUploadOperation,
PutObjectACLOperation, PutObjectTaggingOperation, CopyObjectOperation, PutObjectRetentionOperation, PutObjectLegalHoldOperation,
diff --git a/api/middleware/policy.go b/api/middleware/policy.go
index f1c1f32..5a7142a 100644
--- a/api/middleware/policy.go
+++ b/api/middleware/policy.go
@@ -253,7 +253,7 @@ func determineBucketOperation(r *http.Request) string {
query := r.URL.Query()
switch r.Method {
case http.MethodOptions:
- return OptionsOperation
+ return OptionsBucketOperation
case http.MethodHead:
return HeadBucketOperation
case http.MethodGet:
@@ -356,6 +356,8 @@ func determineBucketOperation(r *http.Request) string {
func determineObjectOperation(r *http.Request) string {
query := r.URL.Query()
switch r.Method {
+ case http.MethodOptions:
+ return OptionsObjectOperation
case http.MethodHead:
return HeadObjectOperation
case http.MethodGet:
diff --git a/api/middleware/policy_test.go b/api/middleware/policy_test.go
index 34d3a9c..0c6f128 100644
--- a/api/middleware/policy_test.go
+++ b/api/middleware/policy_test.go
@@ -91,9 +91,9 @@ func TestDetermineBucketOperation(t *testing.T) {
expected string
}{
{
- name: "OptionsOperation",
+ name: "OptionsBucketOperation",
method: http.MethodOptions,
- expected: OptionsOperation,
+ expected: OptionsBucketOperation,
},
{
name: "HeadBucketOperation",
@@ -367,6 +367,11 @@ func TestDetermineObjectOperation(t *testing.T) {
headerKeys []string
expected string
}{
+ {
+ name: "OptionsObjectOperation",
+ method: http.MethodOptions,
+ expected: OptionsObjectOperation,
+ },
{
name: "HeadObjectOperation",
method: http.MethodHead,
diff --git a/api/router.go b/api/router.go
index a6dfa0a..0f86e2e 100644
--- a/api/router.go
+++ b/api/router.go
@@ -223,7 +223,7 @@ func bucketRouter(h Handler, log *zap.Logger) chi.Router {
bktRouter.Mount("/", objectRouter(h, log))
- bktRouter.Options("/", h.Preflight)
+ bktRouter.Options("/", named(s3middleware.OptionsBucketOperation, h.Preflight))
bktRouter.Head("/", named(s3middleware.HeadBucketOperation, h.HeadBucketHandler))
@@ -372,6 +372,8 @@ func objectRouter(h Handler, l *zap.Logger) chi.Router {
objRouter := chi.NewRouter()
objRouter.Use(s3middleware.AddObjectName(l))
+ objRouter.Options("/*", named(s3middleware.OptionsObjectOperation, h.Preflight))
+
objRouter.Head("/*", named(s3middleware.HeadObjectOperation, h.HeadObjectHandler))
// GET method handlers