diff --git a/api/middleware/policy.go b/api/middleware/policy.go index 9faf069..0734083 100644 --- a/api/middleware/policy.go +++ b/api/middleware/policy.go @@ -24,13 +24,16 @@ import ( ) const ( - QueryVersionID = "versionId" - QueryPrefix = "prefix" - QueryDelimiter = "delimiter" - QueryMaxKeys = "max-keys" - QueryMarker = "marker" - QueryEncodingType = "encoding-type" - amzTagging = "x-amz-tagging" + QueryVersionID = "versionId" + QueryPrefix = "prefix" + QueryDelimiter = "delimiter" + QueryMaxKeys = "max-keys" + QueryMarker = "marker" + QueryEncodingType = "encoding-type" + QueryMaxBuckets = "max-buckets" + QueryContinuationToken = "continuation-token" + QueryBucketRegion = "bucket-region" + amzTagging = "x-amz-tagging" unmatchedBucketOperation = "UnmatchedBucketOperation" ) diff --git a/api/router.go b/api/router.go index 765bec3..49f348f 100644 --- a/api/router.go +++ b/api/router.go @@ -159,7 +159,7 @@ func NewRouter(cfg Config) *chi.Mux { defaultRouter.Get("/", named(s3middleware.ListBucketsOperation, cfg.Handler.ListBucketsHandler)) attachErrorHandler(defaultRouter) - vhsRouter := bucketRouter(cfg.Handler) + vhsRouter := newDomainRouter(cfg.Handler) router := newGlobalRouter(defaultRouter, vhsRouter) api.Mount("/", router) @@ -169,12 +169,43 @@ func NewRouter(cfg Config) *chi.Mux { return api } -type globalRouter struct { - pathStyleRouter chi.Router - vhsRouter chi.Router +type domainRouter struct { + bucketRouter chi.Router + defaultRouter chi.Router } -func newGlobalRouter(pathStyleRouter, vhsRouter chi.Router) *globalRouter { +func newDomainRouter(handler Handler) *domainRouter { + defaultRouter := chi.NewRouter() + defaultRouter.Group(func(r chi.Router) { + r.Method(http.MethodGet, "/", NewHandlerFilter(). + Add(NewFilter(). + AllowedQueries(s3middleware.QueryMaxBuckets, s3middleware.QueryPrefix, + s3middleware.QueryContinuationToken, s3middleware.QueryBucketRegion). + Handler(named(s3middleware.ListBucketsOperation, handler.ListBucketsHandler))). + DefaultHandler(notSupportedHandler())) + }) + attachErrorHandler(defaultRouter) + + return &domainRouter{ + bucketRouter: bucketRouter(handler), + defaultRouter: defaultRouter, + } +} + +func (g *domainRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if reqInfo := s3middleware.GetReqInfo(r.Context()); reqInfo.BucketName != "" { + g.bucketRouter.ServeHTTP(w, r) + } else { + g.defaultRouter.ServeHTTP(w, r) + } +} + +type globalRouter struct { + pathStyleRouter chi.Router + vhsRouter *domainRouter +} + +func newGlobalRouter(pathStyleRouter chi.Router, vhsRouter *domainRouter) *globalRouter { return &globalRouter{ pathStyleRouter: pathStyleRouter, vhsRouter: vhsRouter, @@ -182,12 +213,11 @@ func newGlobalRouter(pathStyleRouter, vhsRouter chi.Router) *globalRouter { } func (g *globalRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) { - router := g.pathStyleRouter if reqInfo := s3middleware.GetReqInfo(r.Context()); reqInfo.RequestVHSEnabled { - router = g.vhsRouter + g.vhsRouter.ServeHTTP(w, r) + } else { + g.pathStyleRouter.ServeHTTP(w, r) } - - router.ServeHTTP(w, r) } func named(name string, handlerFunc http.HandlerFunc) http.HandlerFunc { @@ -338,9 +368,6 @@ func bucketRouter(h Handler) chi.Router { AllowedQueries(s3middleware.QueryDelimiter, s3middleware.QueryMaxKeys, s3middleware.QueryPrefix, s3middleware.QueryMarker, s3middleware.QueryEncodingType). Handler(named(s3middleware.ListObjectsV1Operation, h.ListObjectsV1Handler))). - Add(NewFilter(). - NoQueries(). - Handler(listWrapper(h))). DefaultHandler(notSupportedHandler())) }) @@ -422,18 +449,6 @@ func bucketRouter(h Handler) chi.Router { return bktRouter } -func listWrapper(h Handler) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - if reqInfo := s3middleware.GetReqInfo(r.Context()); reqInfo.BucketName == "" { - reqInfo.API = s3middleware.ListBucketsOperation - h.ListBucketsHandler(w, r) - } else { - reqInfo.API = s3middleware.ListObjectsV1Operation - h.ListObjectsV1Handler(w, r) - } - } -} - func objectRouter(h Handler) chi.Router { objRouter := chi.NewRouter() diff --git a/api/router_filter.go b/api/router_filter.go index 7742ce5..06282be 100644 --- a/api/router_filter.go +++ b/api/router_filter.go @@ -138,13 +138,14 @@ func (hf *HandlerFilters) ServeHTTP(w http.ResponseWriter, r *http.Request) { } func (hf *HandlerFilters) match(r *http.Request) http.Handler { + queries := r.URL.Query() + LOOP: for _, filter := range hf.filters { - if filter.noQueries && len(r.URL.Query()) > 0 { + if filter.noQueries && len(queries) > 0 { continue } if len(filter.allowedQueries) > 0 { - queries := r.URL.Query() for key := range queries { if _, ok := filter.allowedQueries[key]; !ok { continue LOOP @@ -158,8 +159,8 @@ LOOP: } } for _, query := range filter.queries { - queryVal := r.URL.Query().Get(query.Key) - if !r.URL.Query().Has(query.Key) || query.Value != "" && query.Value != queryVal { + queryVal := queries.Get(query.Key) + if !queries.Has(query.Key) || query.Value != "" && query.Value != queryVal { continue LOOP } } diff --git a/api/router_test.go b/api/router_test.go index 18fe336..981b024 100644 --- a/api/router_test.go +++ b/api/router_test.go @@ -903,6 +903,78 @@ func TestRouterListObjectsV2Domains(t *testing.T) { require.Equal(t, s3middleware.ListObjectsV2Operation, resp.Method) } +func TestRouterListingVHS(t *testing.T) { + baseDomain := "domain.com" + baseDomainWithBkt := "bucket.domain.com" + chiRouter := prepareRouter(t, enableVHSDomains(baseDomain)) + chiRouter.handler.buckets["bucket"] = &data.BucketInfo{} + + for _, tc := range []struct { + name string + host string + queries string + expectedOperation string + notSupported bool + }{ + { + name: "list-object-v1 without query params", + host: baseDomainWithBkt, + expectedOperation: s3middleware.ListObjectsV1Operation, + }, + { + name: "list-buckets without query params", + host: baseDomain, + expectedOperation: s3middleware.ListBucketsOperation, + }, + { + name: "list-objects-v1 with prefix param", + host: baseDomainWithBkt, + queries: func() string { + query := make(url.Values) + query.Set(s3middleware.QueryPrefix, "prefix") + return query.Encode() + }(), + expectedOperation: s3middleware.ListObjectsV1Operation, + }, + { + name: "list-buckets with prefix param", + host: baseDomain, + queries: func() string { + query := make(url.Values) + query.Set(s3middleware.QueryPrefix, "prefix") + return query.Encode() + }(), + expectedOperation: s3middleware.ListBucketsOperation, + }, + { + name: "not supported operation", + host: baseDomain, + queries: func() string { + query := make(url.Values) + query.Set("invalid", "invalid") + return query.Encode() + }(), + notSupported: true, + }, + } { + t.Run(tc.name, func(t *testing.T) { + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "/", nil) + r.URL.RawQuery = tc.queries + r.Host = tc.host + chiRouter.ServeHTTP(w, r) + + if tc.notSupported { + assertAPIError(t, w, apierr.ErrNotSupported) + return + } + + resp := readResponse(t, w) + require.Equal(t, tc.expectedOperation, resp.Method) + }) + } +} + func enableVHSDomains(domains ...string) option { return func(cfg *Config) { setting := cfg.MiddlewareSettings.(*middlewareSettingsMock)