diff --git a/api/handler/cors.go b/api/handler/cors.go index 69684fa6..0ddec96a 100644 --- a/api/handler/cors.go +++ b/api/handler/cors.go @@ -187,8 +187,8 @@ func (h *handler) Preflight(w http.ResponseWriter, r *http.Request) { if !checkSubslice(rule.AllowedHeaders, headers) { continue } - w.Header().Set(api.AccessControlAllowOrigin, o) - w.Header().Set(api.AccessControlAllowMethods, strings.Join(rule.AllowedMethods, ", ")) + w.Header().Set(api.AccessControlAllowOrigin, origin) + w.Header().Set(api.AccessControlAllowMethods, method) if headers != nil { w.Header().Set(api.AccessControlAllowHeaders, requestHeaders) } diff --git a/api/handler/cors_test.go b/api/handler/cors_test.go index 1c4bd9ed..42008d76 100644 --- a/api/handler/cors_test.go +++ b/api/handler/cors_test.go @@ -7,6 +7,7 @@ import ( "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api" "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/middleware" + "github.com/stretchr/testify/require" ) func TestCORSOriginWildcard(t *testing.T) { @@ -39,3 +40,181 @@ func TestCORSOriginWildcard(t *testing.T) { hc.Handler().GetBucketCorsHandler(w, r) assertStatus(t, w, http.StatusOK) } + +func TestPreflight(t *testing.T) { + body := ` + + + GET + http://www.example.com + Authorization + x-amz-* + X-Amz-* + 600 + + +` + hc := prepareHandlerContext(t) + + bktName := "bucket-preflight-test" + box, _ := createAccessBox(t) + w, r := prepareTestRequest(hc, bktName, "", nil) + ctx := middleware.SetBox(r.Context(), &middleware.Box{AccessBox: box}) + r = r.WithContext(ctx) + hc.Handler().CreateBucketHandler(w, r) + assertStatus(t, w, http.StatusOK) + + w, r = prepareTestPayloadRequest(hc, bktName, "", strings.NewReader(body)) + ctx = middleware.SetBox(r.Context(), &middleware.Box{AccessBox: box}) + r = r.WithContext(ctx) + hc.Handler().PutBucketCorsHandler(w, r) + assertStatus(t, w, http.StatusOK) + + for _, tc := range []struct { + name string + origin string + method string + headers string + expectedStatus int + }{ + { + name: "Valid", + origin: "http://www.example.com", + method: "GET", + headers: "Authorization", + expectedStatus: http.StatusOK, + }, + { + name: "Empty origin", + method: "GET", + headers: "Authorization", + expectedStatus: http.StatusBadRequest, + }, + { + name: "Empty request method", + origin: "http://www.example.com", + headers: "Authorization", + expectedStatus: http.StatusBadRequest, + }, + { + name: "Not allowed method", + origin: "http://www.example.com", + method: "PUT", + headers: "Authorization", + expectedStatus: http.StatusForbidden, + }, + { + name: "Not allowed headers", + origin: "http://www.example.com", + method: "GET", + headers: "Authorization, Last-Modified", + expectedStatus: http.StatusForbidden, + }, + } { + t.Run(tc.name, func(t *testing.T) { + w, r = prepareTestPayloadRequest(hc, bktName, "", nil) + r.Header.Set(api.Origin, tc.origin) + r.Header.Set(api.AccessControlRequestMethod, tc.method) + r.Header.Set(api.AccessControlRequestHeaders, tc.headers) + hc.Handler().Preflight(w, r) + assertStatus(t, w, tc.expectedStatus) + + if tc.expectedStatus == http.StatusOK { + require.Equal(t, tc.origin, w.Header().Get(api.AccessControlAllowOrigin)) + require.Equal(t, tc.method, w.Header().Get(api.AccessControlAllowMethods)) + require.Equal(t, tc.headers, w.Header().Get(api.AccessControlAllowHeaders)) + require.Equal(t, "x-amz-*, X-Amz-*", w.Header().Get(api.AccessControlExposeHeaders)) + require.Equal(t, "true", w.Header().Get(api.AccessControlAllowCredentials)) + require.Equal(t, "600", w.Header().Get(api.AccessControlMaxAge)) + } + }) + } +} + +func TestPreflightWildcardOrigin(t *testing.T) { + body := ` + + + GET + PUT + * + * + + +` + hc := prepareHandlerContext(t) + + bktName := "bucket-preflight-wildcard-test" + box, _ := createAccessBox(t) + w, r := prepareTestRequest(hc, bktName, "", nil) + ctx := middleware.SetBox(r.Context(), &middleware.Box{AccessBox: box}) + r = r.WithContext(ctx) + hc.Handler().CreateBucketHandler(w, r) + assertStatus(t, w, http.StatusOK) + + w, r = prepareTestPayloadRequest(hc, bktName, "", strings.NewReader(body)) + ctx = middleware.SetBox(r.Context(), &middleware.Box{AccessBox: box}) + r = r.WithContext(ctx) + hc.Handler().PutBucketCorsHandler(w, r) + assertStatus(t, w, http.StatusOK) + + for _, tc := range []struct { + name string + origin string + method string + headers string + expectedStatus int + }{ + { + name: "Valid get", + origin: "http://www.example.com", + method: "GET", + headers: "Authorization, Last-Modified", + expectedStatus: http.StatusOK, + }, + { + name: "Valid put", + origin: "http://example.com", + method: "PUT", + headers: "Authorization, Content-Type", + expectedStatus: http.StatusOK, + }, + { + name: "Empty origin", + method: "GET", + headers: "Authorization, Last-Modified", + expectedStatus: http.StatusBadRequest, + }, + { + name: "Empty request method", + origin: "http://www.example.com", + headers: "Authorization, Last-Modified", + expectedStatus: http.StatusBadRequest, + }, + { + name: "Not allowed method", + origin: "http://www.example.com", + method: "DELETE", + headers: "Authorization, Last-Modified", + expectedStatus: http.StatusForbidden, + }, + } { + t.Run(tc.name, func(t *testing.T) { + w, r = prepareTestPayloadRequest(hc, bktName, "", nil) + r.Header.Set(api.Origin, tc.origin) + r.Header.Set(api.AccessControlRequestMethod, tc.method) + r.Header.Set(api.AccessControlRequestHeaders, tc.headers) + hc.Handler().Preflight(w, r) + assertStatus(t, w, tc.expectedStatus) + + if tc.expectedStatus == http.StatusOK { + require.Equal(t, tc.origin, w.Header().Get(api.AccessControlAllowOrigin)) + require.Equal(t, tc.method, w.Header().Get(api.AccessControlAllowMethods)) + require.Equal(t, tc.headers, w.Header().Get(api.AccessControlAllowHeaders)) + require.Empty(t, w.Header().Get(api.AccessControlExposeHeaders)) + require.Empty(t, w.Header().Get(api.AccessControlAllowCredentials)) + require.Equal(t, "0", w.Header().Get(api.AccessControlMaxAge)) + } + }) + } +} diff --git a/api/middleware/constants.go b/api/middleware/constants.go index 47f65324..a52b93a8 100644 --- a/api/middleware/constants.go +++ b/api/middleware/constants.go @@ -5,7 +5,7 @@ const ( // bucket operations. - OptionsOperation = "Options" + OptionsBucketOperation = "OptionsBucket" HeadBucketOperation = "HeadBucket" ListMultipartUploadsOperation = "ListMultipartUploads" GetBucketLocationOperation = "GetBucketLocation" @@ -51,6 +51,7 @@ const ( // object operations. + OptionsObjectOperation = "OptionsObject" HeadObjectOperation = "HeadObject" ListPartsOperation = "ListParts" GetObjectACLOperation = "GetObjectACL" diff --git a/api/middleware/metrics.go b/api/middleware/metrics.go index c72c59d0..fca113a3 100644 --- a/api/middleware/metrics.go +++ b/api/middleware/metrics.go @@ -103,7 +103,7 @@ func stats(f http.HandlerFunc, resolveCID cidResolveFunc, appMetrics *metrics.Ap func requestTypeFromAPI(api string) metrics.RequestType { switch api { - case OptionsOperation, HeadObjectOperation, HeadBucketOperation: + case OptionsBucketOperation, OptionsObjectOperation, HeadObjectOperation, HeadBucketOperation: return metrics.HEADRequest case CreateMultipartUploadOperation, UploadPartCopyOperation, UploadPartOperation, CompleteMultipartUploadOperation, PutObjectACLOperation, PutObjectTaggingOperation, CopyObjectOperation, PutObjectRetentionOperation, PutObjectLegalHoldOperation, diff --git a/api/middleware/policy.go b/api/middleware/policy.go index f1c1f320..5a7142a2 100644 --- a/api/middleware/policy.go +++ b/api/middleware/policy.go @@ -253,7 +253,7 @@ func determineBucketOperation(r *http.Request) string { query := r.URL.Query() switch r.Method { case http.MethodOptions: - return OptionsOperation + return OptionsBucketOperation case http.MethodHead: return HeadBucketOperation case http.MethodGet: @@ -356,6 +356,8 @@ func determineBucketOperation(r *http.Request) string { func determineObjectOperation(r *http.Request) string { query := r.URL.Query() switch r.Method { + case http.MethodOptions: + return OptionsObjectOperation case http.MethodHead: return HeadObjectOperation case http.MethodGet: diff --git a/api/middleware/policy_test.go b/api/middleware/policy_test.go index 34d3a9c5..0c6f1282 100644 --- a/api/middleware/policy_test.go +++ b/api/middleware/policy_test.go @@ -91,9 +91,9 @@ func TestDetermineBucketOperation(t *testing.T) { expected string }{ { - name: "OptionsOperation", + name: "OptionsBucketOperation", method: http.MethodOptions, - expected: OptionsOperation, + expected: OptionsBucketOperation, }, { name: "HeadBucketOperation", @@ -367,6 +367,11 @@ func TestDetermineObjectOperation(t *testing.T) { headerKeys []string expected string }{ + { + name: "OptionsObjectOperation", + method: http.MethodOptions, + expected: OptionsObjectOperation, + }, { name: "HeadObjectOperation", method: http.MethodHead, diff --git a/api/router.go b/api/router.go index a6dfa0a5..0f86e2e5 100644 --- a/api/router.go +++ b/api/router.go @@ -223,7 +223,7 @@ func bucketRouter(h Handler, log *zap.Logger) chi.Router { bktRouter.Mount("/", objectRouter(h, log)) - bktRouter.Options("/", h.Preflight) + bktRouter.Options("/", named(s3middleware.OptionsBucketOperation, h.Preflight)) bktRouter.Head("/", named(s3middleware.HeadBucketOperation, h.HeadBucketHandler)) @@ -372,6 +372,8 @@ func objectRouter(h Handler, l *zap.Logger) chi.Router { objRouter := chi.NewRouter() objRouter.Use(s3middleware.AddObjectName(l)) + objRouter.Options("/*", named(s3middleware.OptionsObjectOperation, h.Preflight)) + objRouter.Head("/*", named(s3middleware.HeadObjectOperation, h.HeadObjectHandler)) // GET method handlers