package handler

import (
	"net/http"
	"strings"
	"testing"

	"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api"
	"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/middleware"
	"github.com/stretchr/testify/require"
)

// TestCORSOriginWildcard creates a public-read bucket, stores a CORS
// configuration whose single rule allows GET from any origin ("*"), and
// checks that the stored configuration can be read back with 200 OK.
func TestCORSOriginWildcard(t *testing.T) {
	body := `
<CORSConfiguration>
	<CORSRule>
		<AllowedMethod>GET</AllowedMethod>
		<AllowedOrigin>*</AllowedOrigin>
	</CORSRule>
</CORSConfiguration>
`
	hc := prepareHandlerContext(t)

	bktName := "bucket-for-cors"
	box, _ := createAccessBox(t)
	w, r := prepareTestRequest(hc, bktName, "", nil)
	ctx := middleware.SetBoxData(r.Context(), box)
	r = r.WithContext(ctx)
	r.Header.Add(api.AmzACL, "public-read")
	hc.Handler().CreateBucketHandler(w, r)
	assertStatus(t, w, http.StatusOK)

	w, r = prepareTestPayloadRequest(hc, bktName, "", strings.NewReader(body))
	ctx = middleware.SetBoxData(r.Context(), box)
	r = r.WithContext(ctx)
	hc.Handler().PutBucketCorsHandler(w, r)
	assertStatus(t, w, http.StatusOK)

	// The read-back request is sent without attaching the access box.
	w, r = prepareTestPayloadRequest(hc, bktName, "", nil)
	hc.Handler().GetBucketCorsHandler(w, r)
	assertStatus(t, w, http.StatusOK)
}

// TestPreflight stores a CORS rule for one explicit origin and sends
// preflight requests against it.
//
// A matching request must echo the requested origin, method and headers.
// It must also return the rule's expose headers, Allow-Credentials "true"
// and a max-age of 600. A request with an empty Origin or
// Access-Control-Request-Method gets 400. A request for a method or header
// the rule does not allow gets 403.
func TestPreflight(t *testing.T) {
	body := `
<CORSConfiguration>
	<CORSRule>
		<AllowedMethod>GET</AllowedMethod>
		<AllowedOrigin>http://www.example.com</AllowedOrigin>
		<AllowedHeader>Authorization</AllowedHeader>
		<ExposeHeader>x-amz-*</ExposeHeader>
		<ExposeHeader>X-Amz-*</ExposeHeader>
		<MaxAgeSeconds>600</MaxAgeSeconds>
	</CORSRule>
</CORSConfiguration>
`
	hc := prepareHandlerContext(t)

	bktName := "bucket-preflight-test"
	box, _ := createAccessBox(t)
	w, r := prepareTestRequest(hc, bktName, "", nil)
	ctx := middleware.SetBoxData(r.Context(), box)
	r = r.WithContext(ctx)
	hc.Handler().CreateBucketHandler(w, r)
	assertStatus(t, w, http.StatusOK)

	w, r = prepareTestPayloadRequest(hc, bktName, "", strings.NewReader(body))
	ctx = middleware.SetBoxData(r.Context(), box)
	r = r.WithContext(ctx)
	hc.Handler().PutBucketCorsHandler(w, r)
	assertStatus(t, w, http.StatusOK)

	for _, tc := range []struct {
		name           string
		origin         string
		method         string
		headers        string
		expectedStatus int
	}{
		{
			name:           "Valid",
			origin:         "http://www.example.com",
			method:         "GET",
			headers:        "Authorization",
			expectedStatus: http.StatusOK,
		},
		{
			name:           "Empty origin",
			method:         "GET",
			headers:        "Authorization",
			expectedStatus: http.StatusBadRequest,
		},
		{
			name:           "Empty request method",
			origin:         "http://www.example.com",
			headers:        "Authorization",
			expectedStatus: http.StatusBadRequest,
		},
		{
			name:           "Not allowed method",
			origin:         "http://www.example.com",
			method:         "PUT",
			headers:        "Authorization",
			expectedStatus: http.StatusForbidden,
		},
		{
			name:           "Not allowed headers",
			origin:         "http://www.example.com",
			method:         "GET",
			headers:        "Authorization, Last-Modified",
			expectedStatus: http.StatusForbidden,
		},
	} {
		t.Run(tc.name, func(t *testing.T) {
			// Use subtest-local recorder/request instead of reassigning the
			// enclosing test's w and r from inside the closure.
			w, r := prepareTestPayloadRequest(hc, bktName, "", nil)
			r.Header.Set(api.Origin, tc.origin)
			r.Header.Set(api.AccessControlRequestMethod, tc.method)
			r.Header.Set(api.AccessControlRequestHeaders, tc.headers)
			hc.Handler().Preflight(w, r)
			assertStatus(t, w, tc.expectedStatus)

			if tc.expectedStatus == http.StatusOK {
				require.Equal(t, tc.origin, w.Header().Get(api.AccessControlAllowOrigin))
				require.Equal(t, tc.method, w.Header().Get(api.AccessControlAllowMethods))
				require.Equal(t, tc.headers, w.Header().Get(api.AccessControlAllowHeaders))
				require.Equal(t, "x-amz-*, X-Amz-*", w.Header().Get(api.AccessControlExposeHeaders))
				require.Equal(t, "true", w.Header().Get(api.AccessControlAllowCredentials))
				require.Equal(t, "600", w.Header().Get(api.AccessControlMaxAge))
			}
		})
	}
}

// TestPreflightWildcardOrigin stores a CORS rule that allows GET and PUT
// from any origin with any request headers, then sends preflight requests
// against it.
//
// A matching request must echo the requested origin, method and headers.
// It must carry no expose headers and no Allow-Credentials, and must report
// a max-age of "0". A request with an empty Origin or
// Access-Control-Request-Method gets 400. A DELETE request gets 403.
func TestPreflightWildcardOrigin(t *testing.T) {
	body := `
<CORSConfiguration>
	<CORSRule>
		<AllowedMethod>GET</AllowedMethod>
		<AllowedMethod>PUT</AllowedMethod>
		<AllowedOrigin>*</AllowedOrigin>
		<AllowedHeader>*</AllowedHeader>
	</CORSRule>
</CORSConfiguration>
`
	hc := prepareHandlerContext(t)

	bktName := "bucket-preflight-wildcard-test"
	box, _ := createAccessBox(t)
	w, r := prepareTestRequest(hc, bktName, "", nil)
	ctx := middleware.SetBoxData(r.Context(), box)
	r = r.WithContext(ctx)
	hc.Handler().CreateBucketHandler(w, r)
	assertStatus(t, w, http.StatusOK)

	w, r = prepareTestPayloadRequest(hc, bktName, "", strings.NewReader(body))
	ctx = middleware.SetBoxData(r.Context(), box)
	r = r.WithContext(ctx)
	hc.Handler().PutBucketCorsHandler(w, r)
	assertStatus(t, w, http.StatusOK)

	for _, tc := range []struct {
		name           string
		origin         string
		method         string
		headers        string
		expectedStatus int
	}{
		{
			name:           "Valid get",
			origin:         "http://www.example.com",
			method:         "GET",
			headers:        "Authorization, Last-Modified",
			expectedStatus: http.StatusOK,
		},
		{
			name:           "Valid put",
			origin:         "http://example.com",
			method:         "PUT",
			headers:        "Authorization, Content-Type",
			expectedStatus: http.StatusOK,
		},
		{
			name:           "Empty origin",
			method:         "GET",
			headers:        "Authorization, Last-Modified",
			expectedStatus: http.StatusBadRequest,
		},
		{
			name:           "Empty request method",
			origin:         "http://www.example.com",
			headers:        "Authorization, Last-Modified",
			expectedStatus: http.StatusBadRequest,
		},
		{
			name:           "Not allowed method",
			origin:         "http://www.example.com",
			method:         "DELETE",
			headers:        "Authorization, Last-Modified",
			expectedStatus: http.StatusForbidden,
		},
	} {
		t.Run(tc.name, func(t *testing.T) {
			// Use subtest-local recorder/request instead of reassigning the
			// enclosing test's w and r from inside the closure.
			w, r := prepareTestPayloadRequest(hc, bktName, "", nil)
			r.Header.Set(api.Origin, tc.origin)
			r.Header.Set(api.AccessControlRequestMethod, tc.method)
			r.Header.Set(api.AccessControlRequestHeaders, tc.headers)
			hc.Handler().Preflight(w, r)
			assertStatus(t, w, tc.expectedStatus)

			if tc.expectedStatus == http.StatusOK {
				require.Equal(t, tc.origin, w.Header().Get(api.AccessControlAllowOrigin))
				require.Equal(t, tc.method, w.Header().Get(api.AccessControlAllowMethods))
				require.Equal(t, tc.headers, w.Header().Get(api.AccessControlAllowHeaders))
				require.Empty(t, w.Header().Get(api.AccessControlExposeHeaders))
				require.Empty(t, w.Header().Get(api.AccessControlAllowCredentials))
				require.Equal(t, "0", w.Header().Get(api.AccessControlMaxAge))
			}
		})
	}
}