From 9e5fb4be95094bd44f6e0f8ca98cb23fad1eb236 Mon Sep 17 00:00:00 2001 From: Marina Biryukova Date: Tue, 1 Oct 2024 17:43:55 +0300 Subject: [PATCH] [#507] Return not implemented by default in bucket router Signed-off-by: Marina Biryukova --- api/middleware/policy.go | 35 ++++++++++++++++++++++-------- api/middleware/policy_test.go | 20 ++++++++++++++++- api/router.go | 41 ++++++++++++++++++++++++++++++++--- api/router_filter.go | 35 +++++++++++++++++++++++++++--- internal/logs/logs.go | 1 + 5 files changed, 116 insertions(+), 16 deletions(-) diff --git a/api/middleware/policy.go b/api/middleware/policy.go index 2c07c3f..9e36f0b 100644 --- a/api/middleware/policy.go +++ b/api/middleware/policy.go @@ -24,11 +24,15 @@ import ( ) const ( - QueryVersionID = "versionId" - QueryPrefix = "prefix" - QueryDelimiter = "delimiter" - QueryMaxKeys = "max-keys" - amzTagging = "x-amz-tagging" + QueryVersionID = "versionId" + QueryPrefix = "prefix" + QueryDelimiter = "delimiter" + QueryMaxKeys = "max-keys" + QueryMarker = "marker" + QueryEncodingType = "encoding-type" + amzTagging = "x-amz-tagging" + + unmatchedBucketOperation = "UnmatchedBucketOperation" ) // In these operations we don't check resource tags because @@ -268,8 +272,17 @@ func determineBucketOperation(r *http.Request) string { return ListObjectsV2MOperation case query.Get(ListTypeQuery) == "2": return ListObjectsV2Operation - default: + case len(query) == 0 || func() bool { + for key := range query { + if key != QueryDelimiter && key != QueryMaxKeys && key != QueryPrefix && key != QueryMarker && key != QueryEncodingType { + return false + } + } + return true + }(): return ListObjectsV1Operation + default: + return unmatchedBucketOperation } case http.MethodPut: switch { @@ -291,8 +304,10 @@ func determineBucketOperation(r *http.Request) string { return PutBucketVersioningOperation case query.Has(NotificationQuery): return PutBucketNotificationOperation - default: + case len(query) == 0: return CreateBucketOperation + default: + return unmatchedBucketOperation } case http.MethodPost: switch { @@ -315,12 +330,14 @@ func determineBucketOperation(r *http.Request) string { return DeleteBucketLifecycleOperation case query.Has(EncryptionQuery): return DeleteBucketEncryptionOperation - default: + case len(query) == 0: return DeleteBucketOperation + default: + return unmatchedBucketOperation } } - return "UnmatchedBucketOperation" + return unmatchedBucketOperation } func determineObjectOperation(r *http.Request) string { diff --git a/api/middleware/policy_test.go b/api/middleware/policy_test.go index 7147ae4..6a308eb 100644 --- a/api/middleware/policy_test.go +++ b/api/middleware/policy_test.go @@ -152,6 +152,12 @@ func TestDetermineBucketOperation(t *testing.T) { method: http.MethodGet, expected: ListObjectsV1Operation, }, + { + name: "UnmatchedBucketOperation GET", + method: http.MethodGet, + queryParam: map[string]string{"query": ""}, + expected: unmatchedBucketOperation, + }, { name: "PutBucketCorsOperation", method: http.MethodPut, @@ -211,6 +217,12 @@ func TestDetermineBucketOperation(t *testing.T) { method: http.MethodPut, expected: CreateBucketOperation, }, + { + name: "UnmatchedBucketOperation PUT", + method: http.MethodPut, + queryParam: map[string]string{"query": ""}, + expected: unmatchedBucketOperation, + }, { name: "DeleteMultipleObjectsOperation", method: http.MethodPost, @@ -263,10 +275,16 @@ func TestDetermineBucketOperation(t *testing.T) { method: http.MethodDelete, expected: DeleteBucketOperation, }, + { + name: "UnmatchedBucketOperation DELETE", + method: http.MethodDelete, + queryParam: map[string]string{"query": ""}, + expected: unmatchedBucketOperation, + }, { name: "UnmatchedBucketOperation", method: "invalid-method", - expected: "UnmatchedBucketOperation", + expected: unmatchedBucketOperation, }, } { t.Run(tc.name, func(t *testing.T) { diff --git a/api/router.go b/api/router.go index 11c16fe..ec8c27b 100644 --- a/api/router.go +++ b/api/router.go @@ -226,6 +226,28 @@ func errorResponseHandler(w http.ResponseWriter, r *http.Request) { } } +func notSupportedHandler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqInfo := s3middleware.GetReqInfo(ctx) + + _, wrErr := s3middleware.WriteErrorResponse(w, reqInfo, errors.GetAPIError(errors.ErrNotSupported)) + + if log := s3middleware.GetReqLog(ctx); log != nil { + fields := []zap.Field{ + zap.String("http method", r.Method), + zap.String("url", r.RequestURI), + } + + if wrErr != nil { + fields = append(fields, zap.NamedError("write_response_error", wrErr)) + } + + log.Error(logs.NotSupported, fields...) + } + } +} + // attachErrorHandler set NotFoundHandler and MethodNotAllowedHandler for chi.Router. func attachErrorHandler(api *chi.Mux) { errorHandler := http.HandlerFunc(errorResponseHandler) @@ -313,7 +335,14 @@ func bucketRouter(h Handler) chi.Router { Add(NewFilter(). Queries(s3middleware.VersionsQuery). Handler(named(s3middleware.ListBucketObjectVersionsOperation, h.ListBucketObjectVersionsHandler))). - DefaultHandler(listWrapper(h))) + Add(NewFilter(). + AllowedQueries(s3middleware.QueryDelimiter, s3middleware.QueryMaxKeys, s3middleware.QueryPrefix, + s3middleware.QueryMarker, s3middleware.QueryEncodingType). + Handler(named(s3middleware.ListObjectsV1Operation, h.ListObjectsV1Handler))). + Add(NewFilter(). + NoQueries(). + Handler(listWrapper(h))). + DefaultHandler(notSupportedHandler())) }) // PUT method handlers @@ -346,7 +375,10 @@ func bucketRouter(h Handler) chi.Router { Add(NewFilter(). Queries(s3middleware.NotificationQuery). Handler(named(s3middleware.PutBucketNotificationOperation, h.PutBucketNotificationHandler))). - DefaultHandler(named(s3middleware.CreateBucketOperation, h.CreateBucketHandler))) + Add(NewFilter(). + NoQueries(). + Handler(named(s3middleware.CreateBucketOperation, h.CreateBucketHandler))). + DefaultHandler(notSupportedHandler())) }) // POST method handlers @@ -380,7 +412,10 @@ func bucketRouter(h Handler) chi.Router { Add(NewFilter(). Queries(s3middleware.EncryptionQuery). Handler(named(s3middleware.DeleteBucketEncryptionOperation, h.DeleteBucketEncryptionHandler))). - DefaultHandler(named(s3middleware.DeleteBucketOperation, h.DeleteBucketHandler))) + Add(NewFilter(). + NoQueries(). + Handler(named(s3middleware.DeleteBucketOperation, h.DeleteBucketHandler))). + DefaultHandler(notSupportedHandler())) }) attachErrorHandler(bktRouter) diff --git a/api/router_filter.go b/api/router_filter.go index cdd87b3..7742ce5 100644 --- a/api/router_filter.go +++ b/api/router_filter.go @@ -11,9 +11,11 @@ type HandlerFilters struct { } type Filter struct { - queries []Pair - headers []Pair - h http.Handler + queries []Pair + headers []Pair + allowedQueries map[string]struct{} + noQueries bool + h http.Handler } type Pair struct { @@ -105,6 +107,22 @@ func (f *Filter) Queries(queries ...string) *Filter { return f } +// NoQueries sets flag indicating that request shouldn't have query parameters. +func (f *Filter) NoQueries() *Filter { + f.noQueries = true + return f +} + +// AllowedQueries adds query parameter keys that may be present in request. +func (f *Filter) AllowedQueries(queries ...string) *Filter { + f.allowedQueries = make(map[string]struct{}, len(queries)) + for _, query := range queries { + f.allowedQueries[query] = struct{}{} + } + + return f +} + func (hf *HandlerFilters) DefaultHandler(handler http.HandlerFunc) *HandlerFilters { hf.defaultHandler = handler return hf @@ -122,6 +140,17 @@ func (hf *HandlerFilters) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (hf *HandlerFilters) match(r *http.Request) http.Handler { LOOP: for _, filter := range hf.filters { + if filter.noQueries && len(r.URL.Query()) > 0 { + continue + } + if len(filter.allowedQueries) > 0 { + queries := r.URL.Query() + for key := range queries { + if _, ok := filter.allowedQueries[key]; !ok { + continue LOOP + } + } + } for _, header := range filter.headers { hdrVals := r.Header.Values(header.Key) if len(hdrVals) == 0 || header.Value != "" && header.Value != hdrVals[0] { diff --git a/internal/logs/logs.go b/internal/logs/logs.go index cbcbfbb..dd7a659 100644 --- a/internal/logs/logs.go +++ b/internal/logs/logs.go @@ -170,4 +170,5 @@ const ( WarnDomainContainsInvalidPlaceholder = "the domain contains an invalid placeholder, domain skipped" FailedToRemoveOldPartNode = "failed to remove old part node" CouldntCacheNetworkInfo = "couldn't cache network info" + NotSupported = "not supported" )