[#507] Return not implemented by default in bucket router

Signed-off-by: Marina Biryukova <m.biryukova@yadro.com>
This commit is contained in:
Marina Biryukova 2024-10-01 17:43:55 +03:00 committed by Alexey Vanin
parent 346243b159
commit 9e5fb4be95
5 changed files with 116 additions and 16 deletions

View file

@ -28,7 +28,11 @@ const (
QueryPrefix = "prefix" QueryPrefix = "prefix"
QueryDelimiter = "delimiter" QueryDelimiter = "delimiter"
QueryMaxKeys = "max-keys" QueryMaxKeys = "max-keys"
QueryMarker = "marker"
QueryEncodingType = "encoding-type"
amzTagging = "x-amz-tagging" amzTagging = "x-amz-tagging"
unmatchedBucketOperation = "UnmatchedBucketOperation"
) )
// In these operations we don't check resource tags because // In these operations we don't check resource tags because
@ -268,8 +272,17 @@ func determineBucketOperation(r *http.Request) string {
return ListObjectsV2MOperation return ListObjectsV2MOperation
case query.Get(ListTypeQuery) == "2": case query.Get(ListTypeQuery) == "2":
return ListObjectsV2Operation return ListObjectsV2Operation
default: case len(query) == 0 || func() bool {
for key := range query {
if key != QueryDelimiter && key != QueryMaxKeys && key != QueryPrefix && key != QueryMarker && key != QueryEncodingType {
return false
}
}
return true
}():
return ListObjectsV1Operation return ListObjectsV1Operation
default:
return unmatchedBucketOperation
} }
case http.MethodPut: case http.MethodPut:
switch { switch {
@ -291,8 +304,10 @@ func determineBucketOperation(r *http.Request) string {
return PutBucketVersioningOperation return PutBucketVersioningOperation
case query.Has(NotificationQuery): case query.Has(NotificationQuery):
return PutBucketNotificationOperation return PutBucketNotificationOperation
default: case len(query) == 0:
return CreateBucketOperation return CreateBucketOperation
default:
return unmatchedBucketOperation
} }
case http.MethodPost: case http.MethodPost:
switch { switch {
@ -315,12 +330,14 @@ func determineBucketOperation(r *http.Request) string {
return DeleteBucketLifecycleOperation return DeleteBucketLifecycleOperation
case query.Has(EncryptionQuery): case query.Has(EncryptionQuery):
return DeleteBucketEncryptionOperation return DeleteBucketEncryptionOperation
default: case len(query) == 0:
return DeleteBucketOperation return DeleteBucketOperation
default:
return unmatchedBucketOperation
} }
} }
return "UnmatchedBucketOperation" return unmatchedBucketOperation
} }
func determineObjectOperation(r *http.Request) string { func determineObjectOperation(r *http.Request) string {

View file

@ -152,6 +152,12 @@ func TestDetermineBucketOperation(t *testing.T) {
method: http.MethodGet, method: http.MethodGet,
expected: ListObjectsV1Operation, expected: ListObjectsV1Operation,
}, },
{
name: "UnmatchedBucketOperation GET",
method: http.MethodGet,
queryParam: map[string]string{"query": ""},
expected: unmatchedBucketOperation,
},
{ {
name: "PutBucketCorsOperation", name: "PutBucketCorsOperation",
method: http.MethodPut, method: http.MethodPut,
@ -211,6 +217,12 @@ func TestDetermineBucketOperation(t *testing.T) {
method: http.MethodPut, method: http.MethodPut,
expected: CreateBucketOperation, expected: CreateBucketOperation,
}, },
{
name: "UnmatchedBucketOperation PUT",
method: http.MethodPut,
queryParam: map[string]string{"query": ""},
expected: unmatchedBucketOperation,
},
{ {
name: "DeleteMultipleObjectsOperation", name: "DeleteMultipleObjectsOperation",
method: http.MethodPost, method: http.MethodPost,
@ -263,10 +275,16 @@ func TestDetermineBucketOperation(t *testing.T) {
method: http.MethodDelete, method: http.MethodDelete,
expected: DeleteBucketOperation, expected: DeleteBucketOperation,
}, },
{
name: "UnmatchedBucketOperation DELETE",
method: http.MethodDelete,
queryParam: map[string]string{"query": ""},
expected: unmatchedBucketOperation,
},
{ {
name: "UnmatchedBucketOperation", name: "UnmatchedBucketOperation",
method: "invalid-method", method: "invalid-method",
expected: "UnmatchedBucketOperation", expected: unmatchedBucketOperation,
}, },
} { } {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {

View file

@ -226,6 +226,28 @@ func errorResponseHandler(w http.ResponseWriter, r *http.Request) {
} }
} }
func notSupportedHandler() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
reqInfo := s3middleware.GetReqInfo(ctx)
_, wrErr := s3middleware.WriteErrorResponse(w, reqInfo, errors.GetAPIError(errors.ErrNotSupported))
if log := s3middleware.GetReqLog(ctx); log != nil {
fields := []zap.Field{
zap.String("http method", r.Method),
zap.String("url", r.RequestURI),
}
if wrErr != nil {
fields = append(fields, zap.NamedError("write_response_error", wrErr))
}
log.Error(logs.NotSupported, fields...)
}
}
}
// attachErrorHandler set NotFoundHandler and MethodNotAllowedHandler for chi.Router. // attachErrorHandler set NotFoundHandler and MethodNotAllowedHandler for chi.Router.
func attachErrorHandler(api *chi.Mux) { func attachErrorHandler(api *chi.Mux) {
errorHandler := http.HandlerFunc(errorResponseHandler) errorHandler := http.HandlerFunc(errorResponseHandler)
@ -313,7 +335,14 @@ func bucketRouter(h Handler) chi.Router {
Add(NewFilter(). Add(NewFilter().
Queries(s3middleware.VersionsQuery). Queries(s3middleware.VersionsQuery).
Handler(named(s3middleware.ListBucketObjectVersionsOperation, h.ListBucketObjectVersionsHandler))). Handler(named(s3middleware.ListBucketObjectVersionsOperation, h.ListBucketObjectVersionsHandler))).
DefaultHandler(listWrapper(h))) Add(NewFilter().
AllowedQueries(s3middleware.QueryDelimiter, s3middleware.QueryMaxKeys, s3middleware.QueryPrefix,
s3middleware.QueryMarker, s3middleware.QueryEncodingType).
Handler(named(s3middleware.ListObjectsV1Operation, h.ListObjectsV1Handler))).
Add(NewFilter().
NoQueries().
Handler(listWrapper(h))).
DefaultHandler(notSupportedHandler()))
}) })
// PUT method handlers // PUT method handlers
@ -346,7 +375,10 @@ func bucketRouter(h Handler) chi.Router {
Add(NewFilter(). Add(NewFilter().
Queries(s3middleware.NotificationQuery). Queries(s3middleware.NotificationQuery).
Handler(named(s3middleware.PutBucketNotificationOperation, h.PutBucketNotificationHandler))). Handler(named(s3middleware.PutBucketNotificationOperation, h.PutBucketNotificationHandler))).
DefaultHandler(named(s3middleware.CreateBucketOperation, h.CreateBucketHandler))) Add(NewFilter().
NoQueries().
Handler(named(s3middleware.CreateBucketOperation, h.CreateBucketHandler))).
DefaultHandler(notSupportedHandler()))
}) })
// POST method handlers // POST method handlers
@ -380,7 +412,10 @@ func bucketRouter(h Handler) chi.Router {
Add(NewFilter(). Add(NewFilter().
Queries(s3middleware.EncryptionQuery). Queries(s3middleware.EncryptionQuery).
Handler(named(s3middleware.DeleteBucketEncryptionOperation, h.DeleteBucketEncryptionHandler))). Handler(named(s3middleware.DeleteBucketEncryptionOperation, h.DeleteBucketEncryptionHandler))).
DefaultHandler(named(s3middleware.DeleteBucketOperation, h.DeleteBucketHandler))) Add(NewFilter().
NoQueries().
Handler(named(s3middleware.DeleteBucketOperation, h.DeleteBucketHandler))).
DefaultHandler(notSupportedHandler()))
}) })
attachErrorHandler(bktRouter) attachErrorHandler(bktRouter)

View file

@ -13,6 +13,8 @@ type HandlerFilters struct {
type Filter struct { type Filter struct {
queries []Pair queries []Pair
headers []Pair headers []Pair
allowedQueries map[string]struct{}
noQueries bool
h http.Handler h http.Handler
} }
@ -105,6 +107,22 @@ func (f *Filter) Queries(queries ...string) *Filter {
return f return f
} }
// NoQueries sets flag indicating that request shouldn't have query parameters.
func (f *Filter) NoQueries() *Filter {
f.noQueries = true
return f
}
// AllowedQueries adds query parameter keys that may be present in request.
func (f *Filter) AllowedQueries(queries ...string) *Filter {
f.allowedQueries = make(map[string]struct{}, len(queries))
for _, query := range queries {
f.allowedQueries[query] = struct{}{}
}
return f
}
func (hf *HandlerFilters) DefaultHandler(handler http.HandlerFunc) *HandlerFilters { func (hf *HandlerFilters) DefaultHandler(handler http.HandlerFunc) *HandlerFilters {
hf.defaultHandler = handler hf.defaultHandler = handler
return hf return hf
@ -122,6 +140,17 @@ func (hf *HandlerFilters) ServeHTTP(w http.ResponseWriter, r *http.Request) {
func (hf *HandlerFilters) match(r *http.Request) http.Handler { func (hf *HandlerFilters) match(r *http.Request) http.Handler {
LOOP: LOOP:
for _, filter := range hf.filters { for _, filter := range hf.filters {
if filter.noQueries && len(r.URL.Query()) > 0 {
continue
}
if len(filter.allowedQueries) > 0 {
queries := r.URL.Query()
for key := range queries {
if _, ok := filter.allowedQueries[key]; !ok {
continue LOOP
}
}
}
for _, header := range filter.headers { for _, header := range filter.headers {
hdrVals := r.Header.Values(header.Key) hdrVals := r.Header.Values(header.Key)
if len(hdrVals) == 0 || header.Value != "" && header.Value != hdrVals[0] { if len(hdrVals) == 0 || header.Value != "" && header.Value != hdrVals[0] {

View file

@ -170,4 +170,5 @@ const (
WarnDomainContainsInvalidPlaceholder = "the domain contains an invalid placeholder, domain skipped" WarnDomainContainsInvalidPlaceholder = "the domain contains an invalid placeholder, domain skipped"
FailedToRemoveOldPartNode = "failed to remove old part node" FailedToRemoveOldPartNode = "failed to remove old part node"
CouldntCacheNetworkInfo = "couldn't cache network info" CouldntCacheNetworkInfo = "couldn't cache network info"
NotSupported = "not supported"
) )