diff --git a/api/handler/object_list.go b/api/handler/object_list.go index 5119707..502b8d2 100644 --- a/api/handler/object_list.go +++ b/api/handler/object_list.go @@ -4,6 +4,7 @@ import ( "net/http" "net/url" "strconv" + "strings" "time" "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api" @@ -153,6 +154,10 @@ func parseListObjectArgs(reqInfo *middleware.ReqInfo) (*layer.ListObjectsParamsC res.Delimiter = queryValues.Get("delimiter") res.Encode = queryValues.Get("encoding-type") + if res.Encode != "" && strings.ToLower(res.Encode) != urlEncodingType { + return nil, errors.GetAPIError(errors.ErrInvalidEncodingMethod) + } + if queryValues.Get("max-keys") == "" { res.MaxKeys = maxObjectList } else if res.MaxKeys, err = strconv.Atoi(queryValues.Get("max-keys")); err != nil || res.MaxKeys < 0 { @@ -257,6 +262,10 @@ func parseListObjectVersionsRequest(reqInfo *middleware.ReqInfo) (*layer.ListObj res.Encode = queryValues.Get("encoding-type") res.VersionIDMarker = queryValues.Get("version-id-marker") + if res.Encode != "" && strings.ToLower(res.Encode) != urlEncodingType { + return nil, errors.GetAPIError(errors.ErrInvalidEncodingMethod) + } + if res.VersionIDMarker != "" && res.KeyMarker == "" { return nil, errors.GetAPIError(errors.VersionIDMarkerWithoutKeyMarker) } diff --git a/api/handler/object_list_test.go b/api/handler/object_list_test.go index 6c9605e..5a13953 100644 --- a/api/handler/object_list_test.go +++ b/api/handler/object_list_test.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net/http" + "net/http/httptest" "net/url" "sort" "strconv" @@ -14,6 +15,7 @@ import ( "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api" "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/cache" "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/data" + apierr "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/errors" "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/layer" "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/layer/encryption" "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/internal/logs" @@ -755,6 +757,16 @@ func TestListObjectVersionsEncoding(t *testing.T) { require.Equal(t, 3, listResponse.MaxKeys) } +func TestListingsWithInvalidEncodingType(t *testing.T) { + hc := prepareHandlerContext(t) + bktName := "bucket-for-listing-invalid-encoding" + createTestBucket(hc, bktName) + + listObjectsVersionsErr(hc, bktName, "invalid", apierr.GetAPIError(apierr.ErrInvalidEncodingMethod)) + listObjectsV2Err(hc, bktName, "invalid", apierr.GetAPIError(apierr.ErrInvalidEncodingMethod)) + listObjectsV1Err(hc, bktName, "invalid", apierr.GetAPIError(apierr.ErrInvalidEncodingMethod)) +} + func checkVersionsNames(t *testing.T, versions *ListObjectsVersionsResponse, names []string) { for i, v := range versions.Version { require.Equal(t, names[i], v.Key) @@ -762,10 +774,19 @@ func checkVersionsNames(t *testing.T, versions *ListObjectsVersionsResponse, nam } func listObjectsV2(hc *handlerContext, bktName, prefix, delimiter, startAfter, continuationToken string, maxKeys int) *ListObjectsV2Response { - return listObjectsV2Ext(hc, bktName, prefix, delimiter, startAfter, continuationToken, "", maxKeys) + w := listObjectsV2Base(hc, bktName, prefix, delimiter, startAfter, continuationToken, "", maxKeys) + assertStatus(hc.t, w, http.StatusOK) + res := &ListObjectsV2Response{} + parseTestResponse(hc.t, w, res) + return res } -func listObjectsV2Ext(hc *handlerContext, bktName, prefix, delimiter, startAfter, continuationToken, encodingType string, maxKeys int) *ListObjectsV2Response { +func listObjectsV2Err(hc *handlerContext, bktName, encoding string, err apierr.Error) { + w := listObjectsV2Base(hc, bktName, "", "", "", "", encoding, -1) + assertS3Error(hc.t, w, err) +} + +func listObjectsV2Base(hc *handlerContext, bktName, prefix, delimiter, startAfter, continuationToken, encodingType string, maxKeys int) *httptest.ResponseRecorder { query := prepareCommonListObjectsQuery(prefix, delimiter, maxKeys) query.Add("fetch-owner", "true") if len(startAfter) != 0 { @@ -780,10 +801,7 @@ func listObjectsV2Ext(hc *handlerContext, bktName, prefix, delimiter, startAfter w, r := prepareTestFullRequest(hc, bktName, "", query, nil) hc.Handler().ListObjectsV2Handler(w, r) - assertStatus(hc.t, w, http.StatusOK) - res := &ListObjectsV2Response{} - parseTestResponse(hc.t, w, res) - return res + return w } func validateListV1(t *testing.T, tc *handlerContext, bktName, prefix, delimiter, marker string, maxKeys int, @@ -843,28 +861,54 @@ func prepareCommonListObjectsQuery(prefix, delimiter string, maxKeys int) url.Va } func listObjectsV1(hc *handlerContext, bktName, prefix, delimiter, marker string, maxKeys int) *ListObjectsV1Response { - query := prepareCommonListObjectsQuery(prefix, delimiter, maxKeys) - if len(marker) != 0 { - query.Add("marker", marker) - } - - w, r := prepareTestFullRequest(hc, bktName, "", query, nil) - hc.Handler().ListObjectsV1Handler(w, r) + w := listObjectsV1Base(hc, bktName, prefix, delimiter, marker, "", maxKeys) assertStatus(hc.t, w, http.StatusOK) res := &ListObjectsV1Response{} parseTestResponse(hc.t, w, res) return res } +func listObjectsV1Err(hc *handlerContext, bktName, encoding string, err apierr.Error) { + w := listObjectsV1Base(hc, bktName, "", "", "", encoding, -1) + assertS3Error(hc.t, w, err) +} + +func listObjectsV1Base(hc *handlerContext, bktName, prefix, delimiter, marker, encoding string, maxKeys int) *httptest.ResponseRecorder { + query := prepareCommonListObjectsQuery(prefix, delimiter, maxKeys) + if len(marker) != 0 { + query.Add("marker", marker) + } + if len(encoding) != 0 { + query.Add("encoding-type", encoding) + } + + w, r := prepareTestFullRequest(hc, bktName, "", query, nil) + hc.Handler().ListObjectsV1Handler(w, r) + return w +} + func listObjectsVersions(hc *handlerContext, bktName, prefix, delimiter, keyMarker, versionIDMarker string, maxKeys int) *ListObjectsVersionsResponse { - return listObjectsVersionsBase(hc, bktName, prefix, delimiter, keyMarker, versionIDMarker, maxKeys, false) + w := listObjectsVersionsBase(hc, bktName, prefix, delimiter, keyMarker, versionIDMarker, "", maxKeys) + assertStatus(hc.t, w, http.StatusOK) + res := &ListObjectsVersionsResponse{} + parseTestResponse(hc.t, w, res) + return res } func listObjectsVersionsURL(hc *handlerContext, bktName, prefix, delimiter, keyMarker, versionIDMarker string, maxKeys int) *ListObjectsVersionsResponse { - return listObjectsVersionsBase(hc, bktName, prefix, delimiter, keyMarker, versionIDMarker, maxKeys, true) + w := listObjectsVersionsBase(hc, bktName, prefix, delimiter, keyMarker, versionIDMarker, urlEncodingType, maxKeys) + assertStatus(hc.t, w, http.StatusOK) + res := &ListObjectsVersionsResponse{} + parseTestResponse(hc.t, w, res) + return res } -func listObjectsVersionsBase(hc *handlerContext, bktName, prefix, delimiter, keyMarker, versionIDMarker string, maxKeys int, encode bool) *ListObjectsVersionsResponse { +func listObjectsVersionsErr(hc *handlerContext, bktName, encoding string, err apierr.Error) { + w := listObjectsVersionsBase(hc, bktName, "", "", "", "", encoding, -1) + assertS3Error(hc.t, w, err) +} + +func listObjectsVersionsBase(hc *handlerContext, bktName, prefix, delimiter, keyMarker, versionIDMarker, encoding string, maxKeys int) *httptest.ResponseRecorder { query := prepareCommonListObjectsQuery(prefix, delimiter, maxKeys) if len(keyMarker) != 0 { query.Add("key-marker", keyMarker) @@ -872,14 +916,11 @@ func listObjectsVersionsBase(hc *handlerContext, bktName, prefix, delimiter, key if len(versionIDMarker) != 0 { query.Add("version-id-marker", versionIDMarker) } - if encode { - query.Add("encoding-type", "url") + if len(encoding) != 0 { + query.Add("encoding-type", encoding) } w, r := prepareTestFullRequest(hc, bktName, "", query, nil) hc.Handler().ListBucketObjectVersionsHandler(w, r) - assertStatus(hc.t, w, http.StatusOK) - res := &ListObjectsVersionsResponse{} - parseTestResponse(hc.t, w, res) - return res + return w }