From 5acbe60b789a51b451e10601081651bb4f2334ee Mon Sep 17 00:00:00 2001 From: Denis Kirillov Date: Thu, 12 Jan 2023 10:53:46 +0300 Subject: [PATCH] [#XX] Use go-chi mux Signed-off-by: Denis Kirillov --- api/max_clients.go | 60 ---- api/reqinfo.go | 25 -- api/router.go | 750 +++++++++++++++++++++++++++------------------ cmd/s3-gw/app.go | 52 ++-- go.mod | 2 + go.sum | 5 + 6 files changed, 496 insertions(+), 398 deletions(-) delete mode 100644 api/max_clients.go diff --git a/api/max_clients.go b/api/max_clients.go deleted file mode 100644 index cd34b535..00000000 --- a/api/max_clients.go +++ /dev/null @@ -1,60 +0,0 @@ -package api - -import ( - "net/http" - "time" - - "github.com/TrueCloudLab/frostfs-s3-gw/api/errors" -) - -type ( - // MaxClients provides HTTP handler wrapper with the client limit. - MaxClients interface { - Handle(http.HandlerFunc) http.HandlerFunc - } - - maxClients struct { - pool chan struct{} - timeout time.Duration - } -) - -const defaultRequestDeadline = time.Second * 30 - -// NewMaxClientsMiddleware returns MaxClients interface with handler wrapper based on -// the provided count and the timeout limits. -func NewMaxClientsMiddleware(count int, timeout time.Duration) MaxClients { - if timeout <= 0 { - timeout = defaultRequestDeadline - } - - return &maxClients{ - pool: make(chan struct{}, count), - timeout: timeout, - } -} - -// Handler wraps HTTP handler function with logic limiting access to it. -func (m *maxClients) Handle(f http.HandlerFunc) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - if m.pool == nil { - f.ServeHTTP(w, r) - return - } - - deadline := time.NewTimer(m.timeout) - defer deadline.Stop() - - select { - case m.pool <- struct{}{}: - defer func() { <-m.pool }() - f.ServeHTTP(w, r) - case <-deadline.C: - // Send a http timeout message - WriteErrorResponse(w, GetReqInfo(r.Context()), errors.GetAPIError(errors.ErrOperationTimedOut)) - return - case <-r.Context().Done(): - return - } - } -} diff --git a/api/reqinfo.go b/api/reqinfo.go index b57b20aa..d46342aa 100644 --- a/api/reqinfo.go +++ b/api/reqinfo.go @@ -8,8 +8,6 @@ import ( "regexp" "strings" "sync" - - "github.com/gorilla/mux" ) type ( @@ -104,29 +102,6 @@ func GetSourceIP(r *http.Request) string { return addr } -func prepareContext(w http.ResponseWriter, r *http.Request) context.Context { - vars := mux.Vars(r) - bucket := vars["bucket"] - object, err := url.PathUnescape(vars["object"]) - if err != nil { - object = vars["object"] - } - prefix, err := url.QueryUnescape(vars["prefix"]) - if err != nil { - prefix = vars["prefix"] - } - if prefix != "" { - object = prefix - } - return SetReqInfo(r.Context(), - // prepare request info - NewReqInfo(w, r, ObjectRequest{ - Bucket: bucket, - Object: object, - Method: mux.CurrentRoute(r).GetName(), - })) -} - // NewReqInfo returns new ReqInfo based on parameters. func NewReqInfo(w http.ResponseWriter, r *http.Request, req ObjectRequest) *ReqInfo { return &ReqInfo{ diff --git a/api/router.go b/api/router.go index 7e49d670..65bd0c2f 100644 --- a/api/router.go +++ b/api/router.go @@ -2,11 +2,17 @@ package api import ( "context" + "fmt" "net/http" + "net/url" "sync" "github.com/TrueCloudLab/frostfs-s3-gw/api/auth" + "github.com/TrueCloudLab/frostfs-s3-gw/api/errors" "github.com/TrueCloudLab/frostfs-s3-gw/api/metrics" + "github.com/go-chi/chi/v5" + "github.com/go-chi/chi/v5/middleware" + "github.com/go-chi/hostrouter" "github.com/google/uuid" "github.com/gorilla/mux" "go.uber.org/zap" @@ -129,13 +135,46 @@ func setRequestID(h http.Handler) http.Handler { )) // set request info into context - r = r.WithContext(prepareContext(w, r)) + // bucket name and object will be set in reqInfo later (limitation of go-chi) + r = r.WithContext(SetReqInfo(r.Context(), NewReqInfo(w, r, ObjectRequest{}))) // continue execution h.ServeHTTP(w, r) }) } +// addBucketName adds bucket name to ReqInfo from context. +func addBucketName(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + reqInfo := GetReqInfo(r.Context()) + reqInfo.BucketName = chi.URLParam(r, "bucket") + h.ServeHTTP(w, r) + }) +} + +// addObjectName adds objects name to ReqInfo from context. +func addObjectName(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + obj := chi.URLParam(r, "object") + object, err := url.PathUnescape(obj) + if err != nil { + object = obj + } + prefix, err := url.QueryUnescape(chi.URLParam(r, "prefix")) + if err != nil { + prefix = chi.URLParam(r, "prefix") + } + if prefix != "" { + object = prefix + } + + reqInfo := GetReqInfo(r.Context()) + reqInfo.ObjectName = object + + h.ServeHTTP(w, r) + }) +} + func appendCORS(handler Handler) mux.MiddlewareFunc { return func(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -151,9 +190,14 @@ func logErrorResponse(l *zap.Logger) mux.MiddlewareFunc { lw := &logResponseWriter{ResponseWriter: w} reqInfo := GetReqInfo(r.Context()) + // here reqInfo doesn't contain bucket name and object name + // pass execution: h.ServeHTTP(lw, r) + // here reqInfo contains bucket name and object name because of + // addBucketName and addObjectName middlewares + // Ignore >400 status codes if lw.statusCode >= http.StatusBadRequest { return @@ -163,7 +207,7 @@ func logErrorResponse(l *zap.Logger) mux.MiddlewareFunc { zap.Int("status", lw.statusCode), zap.String("host", r.Host), zap.String("request_id", GetRequestID(r.Context())), - zap.String("method", mux.CurrentRoute(r).GetName()), + zap.String("method", reqInfo.API), zap.String("bucket", reqInfo.BucketName), zap.String("object", reqInfo.ObjectName), zap.String("description", http.StatusText(lw.statusCode))) @@ -183,307 +227,425 @@ func GetRequestID(v interface{}) string { } } -// Attach adds S3 API handlers from h to r for domains with m client limit using -// center authentication and log logger. -func Attach(r *mux.Router, domains []string, m MaxClients, h Handler, center auth.Center, log *zap.Logger) { - api := r.PathPrefix(SlashSeparator).Subrouter() +func authMiddleware(center auth.Center, log *zap.Logger) func(h http.Handler) http.Handler { + return func(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var ctx context.Context + box, err := center.Authenticate(r) + if err != nil { + if err == auth.ErrNoAuthorizationHeader { + log.Debug("couldn't receive access box for gate key, random key will be used") + ctx = r.Context() + } else { + log.Error("failed to pass authentication", zap.Error(err)) + if _, ok := err.(errors.Error); !ok { + err = errors.GetAPIError(errors.ErrAccessDenied) + } + WriteErrorResponse(w, GetReqInfo(r.Context()), err) + return + } + } else { + ctx = context.WithValue(r.Context(), BoxData, box.AccessBox) + if !box.ClientTime.IsZero() { + ctx = context.WithValue(ctx, ClientTime, box.ClientTime) + } + } + h.ServeHTTP(w, r.WithContext(ctx)) + }) + } +} + +func AttachChi(api *chi.Mux, domains []string, throttle middleware.ThrottleOpts, h Handler, center auth.Center, log *zap.Logger) { api.Use( - // -- prepare request + middleware.CleanPath, setRequestID, - - // -- logging error requests logErrorResponse(log), + middleware.ThrottleWithOpts(throttle), + middleware.Recoverer, + authMiddleware(center, log), ) - // Attach user authentication for all S3 routes. - AttachUserAuth(api, center, log) - - buckets := make([]*mux.Router, 0, len(domains)+1) - buckets = append(buckets, api.PathPrefix("/{bucket}").Subrouter()) - + // todo reconsider host routing + hr := hostrouter.New() for _, domain := range domains { - buckets = append(buckets, api.Host("{bucket:.+}."+domain).Subrouter()) + hr.Map("*."+domain, bucketRouter(h, log)) } - for _, bucket := range buckets { - // Object operations - // HeadObject - bucket.Use( - // -- append CORS headers to a response for - appendCORS(h), - ) - bucket.Methods(http.MethodOptions).HandlerFunc(m.Handle(metrics.APIStats("preflight", h.Preflight))).Name("Options") - bucket.Methods(http.MethodHead).Path("/{object:.+}").HandlerFunc( - m.Handle(metrics.APIStats("headobject", h.HeadObjectHandler))).Name("HeadObject") - // CopyObjectPart - bucket.Methods(http.MethodPut).Path("/{object:.+}").Headers(hdrAmzCopySource, "").HandlerFunc(m.Handle(metrics.APIStats("uploadpartcopy", h.UploadPartCopy))).Queries("partNumber", "{partNumber:[0-9]+}", "uploadId", "{uploadId:.*}"). - Name("UploadPartCopy") - // PutObjectPart - bucket.Methods(http.MethodPut).Path("/{object:.+}").HandlerFunc( - m.Handle(metrics.APIStats("uploadpart", h.UploadPartHandler))).Queries("partNumber", "{partNumber:[0-9]+}", "uploadId", "{uploadId:.*}"). - Name("UploadPart") - // ListParts - bucket.Methods(http.MethodGet).Path("/{object:.+}").HandlerFunc( - m.Handle(metrics.APIStats("listobjectparts", h.ListPartsHandler))).Queries("uploadId", "{uploadId:.*}"). - Name("ListObjectParts") - // CompleteMultipartUpload - bucket.Methods(http.MethodPost).Path("/{object:.+}").HandlerFunc( - m.Handle(metrics.APIStats("completemutipartupload", h.CompleteMultipartUploadHandler))).Queries("uploadId", "{uploadId:.*}"). - Name("CompleteMultipartUpload") - // CreateMultipartUpload - bucket.Methods(http.MethodPost).Path("/{object:.+}").HandlerFunc( - m.Handle(metrics.APIStats("createmultipartupload", h.CreateMultipartUploadHandler))).Queries("uploads", ""). - Name("CreateMultipartUpload") - // AbortMultipartUpload - bucket.Methods(http.MethodDelete).Path("/{object:.+}").HandlerFunc( - m.Handle(metrics.APIStats("abortmultipartupload", h.AbortMultipartUploadHandler))).Queries("uploadId", "{uploadId:.*}"). - Name("AbortMultipartUpload") - // ListMultipartUploads - bucket.Methods(http.MethodGet).HandlerFunc( - m.Handle(metrics.APIStats("listmultipartuploads", h.ListMultipartUploadsHandler))).Queries("uploads", ""). - Name("ListMultipartUploads") - // GetObjectACL -- this is a dummy call. - bucket.Methods(http.MethodGet).Path("/{object:.+}").HandlerFunc( - m.Handle(metrics.APIStats("getobjectacl", h.GetObjectACLHandler))).Queries("acl", ""). - Name("GetObjectACL") - // PutObjectACL -- this is a dummy call. - bucket.Methods(http.MethodPut).Path("/{object:.+}").HandlerFunc( - m.Handle(metrics.APIStats("putobjectacl", h.PutObjectACLHandler))).Queries("acl", ""). - Name("PutObjectACL") - // GetObjectTagging - bucket.Methods(http.MethodGet).Path("/{object:.+}").HandlerFunc( - m.Handle(metrics.APIStats("getobjecttagging", h.GetObjectTaggingHandler))).Queries("tagging", ""). - Name("GetObjectTagging") - // PutObjectTagging - bucket.Methods(http.MethodPut).Path("/{object:.+}").HandlerFunc( - m.Handle(metrics.APIStats("putobjecttagging", h.PutObjectTaggingHandler))).Queries("tagging", ""). - Name("PutObjectTagging") - // DeleteObjectTagging - bucket.Methods(http.MethodDelete).Path("/{object:.+}").HandlerFunc( - m.Handle(metrics.APIStats("deleteobjecttagging", h.DeleteObjectTaggingHandler))).Queries("tagging", ""). - Name("DeleteObjectTagging") - // SelectObjectContent - bucket.Methods(http.MethodPost).Path("/{object:.+}").HandlerFunc( - m.Handle(metrics.APIStats("selectobjectcontent", h.SelectObjectContentHandler))).Queries("select", "").Queries("select-type", "2"). - Name("SelectObjectContent") - // GetObjectRetention - bucket.Methods(http.MethodGet).Path("/{object:.+}").HandlerFunc( - m.Handle(metrics.APIStats("getobjectretention", h.GetObjectRetentionHandler))).Queries("retention", ""). - Name("GetObjectRetention") - // GetObjectLegalHold - bucket.Methods(http.MethodGet).Path("/{object:.+}").HandlerFunc( - m.Handle(metrics.APIStats("getobjectlegalhold", h.GetObjectLegalHoldHandler))).Queries("legal-hold", ""). - Name("GetObjectLegalHold") - // GetObjectAttributes - bucket.Methods(http.MethodGet).Path("/{object:.+}").HandlerFunc( - m.Handle(metrics.APIStats("getobjectattributes", h.GetObjectAttributesHandler))).Queries("attributes", ""). - Name("GetObjectAttributes") - // GetObject - bucket.Methods(http.MethodGet).Path("/{object:.+}").HandlerFunc( - m.Handle(metrics.APIStats("getobject", h.GetObjectHandler))). - Name("GetObject") - // CopyObject - bucket.Methods(http.MethodPut).Path("/{object:.+}").Headers(hdrAmzCopySource, "").HandlerFunc(m.Handle(metrics.APIStats("copyobject", h.CopyObjectHandler))). - Name("CopyObject") - // PutObjectRetention - bucket.Methods(http.MethodPut).Path("/{object:.+}").HandlerFunc( - m.Handle(metrics.APIStats("putobjectretention", h.PutObjectRetentionHandler))).Queries("retention", ""). - Name("PutObjectRetention") - // PutObjectLegalHold - bucket.Methods(http.MethodPut).Path("/{object:.+}").HandlerFunc( - m.Handle(metrics.APIStats("putobjectlegalhold", h.PutObjectLegalHoldHandler))).Queries("legal-hold", ""). - Name("PutObjectLegalHold") - - // PutObject - bucket.Methods(http.MethodPut).Path("/{object:.+}").HandlerFunc( - m.Handle(metrics.APIStats("putobject", h.PutObjectHandler))). - Name("PutObject") - // DeleteObject - bucket.Methods(http.MethodDelete).Path("/{object:.+}").HandlerFunc( - m.Handle(metrics.APIStats("deleteobject", h.DeleteObjectHandler))). - Name("DeleteObject") - - // Bucket operations - // GetBucketLocation - bucket.Methods(http.MethodGet).HandlerFunc( - m.Handle(metrics.APIStats("getbucketlocation", h.GetBucketLocationHandler))).Queries("location", ""). - Name("GetBucketLocation") - // GetBucketPolicy - bucket.Methods(http.MethodGet).HandlerFunc( - m.Handle(metrics.APIStats("getbucketpolicy", h.GetBucketPolicyHandler))).Queries("policy", ""). - Name("GetBucketPolicy") - // GetBucketLifecycle - bucket.Methods(http.MethodGet).HandlerFunc( - m.Handle(metrics.APIStats("getbucketlifecycle", h.GetBucketLifecycleHandler))).Queries("lifecycle", ""). - Name("GetBucketLifecycle") - // GetBucketEncryption - bucket.Methods(http.MethodGet).HandlerFunc( - m.Handle(metrics.APIStats("getbucketencryption", h.GetBucketEncryptionHandler))).Queries("encryption", ""). - Name("GetBucketEncryption") - bucket.Methods(http.MethodGet).HandlerFunc( - m.Handle(metrics.APIStats("getbucketcors", h.GetBucketCorsHandler))).Queries("cors", ""). - Name("GetBucketCors") - bucket.Methods(http.MethodPut).HandlerFunc( - m.Handle(metrics.APIStats("putbucketcors", h.PutBucketCorsHandler))).Queries("cors", ""). - Name("PutBucketCors") - bucket.Methods(http.MethodDelete).HandlerFunc( - m.Handle(metrics.APIStats("deletebucketcors", h.DeleteBucketCorsHandler))).Queries("cors", ""). - Name("DeleteBucketCors") - // Dummy Bucket Calls - // GetBucketACL -- this is a dummy call. - bucket.Methods(http.MethodGet).HandlerFunc( - m.Handle(metrics.APIStats("getbucketacl", h.GetBucketACLHandler))).Queries("acl", ""). - Name("GetBucketACL") - // PutBucketACL -- this is a dummy call. - bucket.Methods(http.MethodPut).HandlerFunc( - m.Handle(metrics.APIStats("putbucketacl", h.PutBucketACLHandler))).Queries("acl", ""). - Name("PutBucketACL") - // GetBucketWebsiteHandler -- this is a dummy call. - bucket.Methods(http.MethodGet).HandlerFunc( - m.Handle(metrics.APIStats("getbucketwebsite", h.GetBucketWebsiteHandler))).Queries("website", ""). - Name("GetBucketWebsite") - // GetBucketAccelerateHandler -- this is a dummy call. - bucket.Methods(http.MethodGet).HandlerFunc( - m.Handle(metrics.APIStats("getbucketaccelerate", h.GetBucketAccelerateHandler))).Queries("accelerate", ""). - Name("GetBucketAccelerate") - // GetBucketRequestPaymentHandler -- this is a dummy call. - bucket.Methods(http.MethodGet).HandlerFunc( - m.Handle(metrics.APIStats("getbucketrequestpayment", h.GetBucketRequestPaymentHandler))).Queries("requestPayment", ""). - Name("GetBucketRequestPayment") - // GetBucketLoggingHandler -- this is a dummy call. - bucket.Methods(http.MethodGet).HandlerFunc( - m.Handle(metrics.APIStats("getbucketlogging", h.GetBucketLoggingHandler))).Queries("logging", ""). - Name("GetBucketLogging") - // GetBucketLifecycleHandler -- this is a dummy call. - bucket.Methods(http.MethodGet).HandlerFunc( - m.Handle(metrics.APIStats("getbucketlifecycle", h.GetBucketLifecycleHandler))).Queries("lifecycle", ""). - Name("GetBucketLifecycle") - // GetBucketReplicationHandler -- this is a dummy call. - bucket.Methods(http.MethodGet).HandlerFunc( - m.Handle(metrics.APIStats("getbucketreplication", h.GetBucketReplicationHandler))).Queries("replication", ""). - Name("GetBucketReplication") - // GetBucketTaggingHandler - bucket.Methods(http.MethodGet).HandlerFunc( - m.Handle(metrics.APIStats("getbuckettagging", h.GetBucketTaggingHandler))).Queries("tagging", ""). - Name("GetBucketTagging") - // DeleteBucketWebsiteHandler - bucket.Methods(http.MethodDelete).HandlerFunc( - m.Handle(metrics.APIStats("deletebucketwebsite", h.DeleteBucketWebsiteHandler))).Queries("website", ""). - Name("DeleteBucketWebsite") - // DeleteBucketTaggingHandler - bucket.Methods(http.MethodDelete).HandlerFunc( - m.Handle(metrics.APIStats("deletebuckettagging", h.DeleteBucketTaggingHandler))).Queries("tagging", ""). - Name("DeleteBucketTagging") - - // GetBucketObjectLockConfig - bucket.Methods(http.MethodGet).HandlerFunc( - m.Handle(metrics.APIStats("getbucketobjectlockconfiguration", h.GetBucketObjectLockConfigHandler))).Queries("object-lock", ""). - Name("GetBucketObjectLockConfig") - // GetBucketVersioning - bucket.Methods(http.MethodGet).HandlerFunc( - m.Handle(metrics.APIStats("getbucketversioning", h.GetBucketVersioningHandler))).Queries("versioning", ""). - Name("GetBucketVersioning") - // GetBucketNotification - bucket.Methods(http.MethodGet).HandlerFunc( - m.Handle(metrics.APIStats("getbucketnotification", h.GetBucketNotificationHandler))).Queries("notification", ""). - Name("GetBucketNotification") - // ListenBucketNotification - bucket.Methods(http.MethodGet).HandlerFunc(metrics.APIStats("listenbucketnotification", h.ListenBucketNotificationHandler)).Queries("events", "{events:.*}"). - Name("ListenBucketNotification") - // ListObjectsV2M - bucket.Methods(http.MethodGet).HandlerFunc( - m.Handle(metrics.APIStats("listobjectsv2M", h.ListObjectsV2MHandler))).Queries("list-type", "2", "metadata", "true"). - Name("ListObjectsV2M") - // ListObjectsV2 - bucket.Methods(http.MethodGet).HandlerFunc( - m.Handle(metrics.APIStats("listobjectsv2", h.ListObjectsV2Handler))).Queries("list-type", "2"). - Name("ListObjectsV2") - // ListBucketVersions - bucket.Methods(http.MethodGet).HandlerFunc( - m.Handle(metrics.APIStats("listbucketversions", h.ListBucketObjectVersionsHandler))).Queries("versions", ""). - Name("ListBucketVersions") - // ListObjectsV1 (Legacy) - bucket.Methods(http.MethodGet).HandlerFunc( - m.Handle(metrics.APIStats("listobjectsv1", h.ListObjectsV1Handler))). - Name("ListObjectsV1") - // PutBucketLifecycle - bucket.Methods(http.MethodPut).HandlerFunc( - m.Handle(metrics.APIStats("putbucketlifecycle", h.PutBucketLifecycleHandler))).Queries("lifecycle", ""). - Name("PutBucketLifecycle") - // PutBucketEncryption - bucket.Methods(http.MethodPut).HandlerFunc( - m.Handle(metrics.APIStats("putbucketencryption", h.PutBucketEncryptionHandler))).Queries("encryption", ""). - Name("PutBucketEncryption") - - // PutBucketPolicy - bucket.Methods(http.MethodPut).HandlerFunc( - m.Handle(metrics.APIStats("putbucketpolicy", h.PutBucketPolicyHandler))).Queries("policy", ""). - Name("PutBucketPolicy") - - // PutBucketObjectLockConfig - bucket.Methods(http.MethodPut).HandlerFunc( - m.Handle(metrics.APIStats("putbucketobjectlockconfig", h.PutBucketObjectLockConfigHandler))).Queries("object-lock", ""). - Name("PutBucketObjectLockConfig") - // PutBucketTaggingHandler - bucket.Methods(http.MethodPut).HandlerFunc( - m.Handle(metrics.APIStats("putbuckettagging", h.PutBucketTaggingHandler))).Queries("tagging", ""). - Name("PutBucketTagging") - // PutBucketVersioning - bucket.Methods(http.MethodPut).HandlerFunc( - m.Handle(metrics.APIStats("putbucketversioning", h.PutBucketVersioningHandler))).Queries("versioning", ""). - Name("PutBucketVersioning") - // PutBucketNotification - bucket.Methods(http.MethodPut).HandlerFunc( - m.Handle(metrics.APIStats("putbucketnotification", h.PutBucketNotificationHandler))).Queries("notification", ""). - Name("PutBucketNotification") - // CreateBucket - bucket.Methods(http.MethodPut).HandlerFunc( - m.Handle(metrics.APIStats("createbucket", h.CreateBucketHandler))). - Name("CreateBucket") - // HeadBucket - bucket.Methods(http.MethodHead).HandlerFunc( - m.Handle(metrics.APIStats("headbucket", h.HeadBucketHandler))). - Name("HeadBucket") - // PostPolicy - bucket.Methods(http.MethodPost).HeadersRegexp(hdrContentType, "multipart/form-data*").HandlerFunc( - m.Handle(metrics.APIStats("postobject", h.PostObject))). - Name("PostObject") - // DeleteMultipleObjects - bucket.Methods(http.MethodPost).HandlerFunc( - m.Handle(metrics.APIStats("deletemultipleobjects", h.DeleteMultipleObjectsHandler))).Queries("delete", ""). - Name("DeleteMultipleObjects") - // DeleteBucketPolicy - bucket.Methods(http.MethodDelete).HandlerFunc( - m.Handle(metrics.APIStats("deletebucketpolicy", h.DeleteBucketPolicyHandler))).Queries("policy", ""). - Name("DeleteBucketPolicy") - // DeleteBucketLifecycle - bucket.Methods(http.MethodDelete).HandlerFunc( - m.Handle(metrics.APIStats("deletebucketlifecycle", h.DeleteBucketLifecycleHandler))).Queries("lifecycle", ""). - Name("DeleteBucketLifecycle") - // DeleteBucketEncryption - bucket.Methods(http.MethodDelete).HandlerFunc( - m.Handle(metrics.APIStats("deletebucketencryption", h.DeleteBucketEncryptionHandler))).Queries("encryption", ""). - Name("DeleteBucketEncryption") - // DeleteBucket - bucket.Methods(http.MethodDelete).HandlerFunc( - m.Handle(metrics.APIStats("deletebucket", h.DeleteBucketHandler))). - Name("DeleteBucket") - } - // Root operation - - // ListBuckets - api.Methods(http.MethodGet).Path(SlashSeparator).HandlerFunc( - m.Handle(metrics.APIStats("listbuckets", h.ListBucketsHandler))). - Name("ListBuckets") - - // S3 browser with signature v4 adds '//' for ListBuckets request, so rather - // than failing with UnknownAPIRequest we simply handle it for now. - api.Methods(http.MethodGet).Path(SlashSeparator + SlashSeparator).HandlerFunc( - m.Handle(metrics.APIStats("listbuckets", h.ListBucketsHandler))). - Name("ListBuckets") + api.Mount("/", hr) + api.Mount("/{bucket}", bucketRouter(h, log)) + api.Get("/", h.ListBucketsHandler) // If none of the routes match, add default error handler routes - api.NotFoundHandler = metrics.APIStats("notfound", errorResponseHandler) - api.MethodNotAllowedHandler = metrics.APIStats("methodnotallowed", errorResponseHandler) + api.NotFound(metrics.APIStats("notfound", errorResponseHandler)) + api.MethodNotAllowed(metrics.APIStats("methodnotallowed", errorResponseHandler)) +} + +func Named(name string, handlerFunc http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + reqInfo := GetReqInfo(r.Context()) + reqInfo.API = name + handlerFunc.ServeHTTP(w, r) + } +} + +func bucketRouter(h Handler, log *zap.Logger) chi.Router { + bktRouter := chi.NewRouter() + bktRouter.Use( + addBucketName, + appendCORS(h), + ) + + bktRouter.Mount("/{object}", objectRouter(h, log)) + + bktRouter.Options("/", h.Preflight) + + bktRouter.Head("/", Named("HeadBucket", h.HeadBucketHandler)) + + // GET method handlers + bktRouter.Group(func(r chi.Router) { + r.Method(http.MethodGet, "/", NewHandlerFilter(). + Add(NewFilter(). + Queries("upload"). + Handler(Named("ListMultipartUploads", h.ListMultipartUploadsHandler))). + Add(NewFilter(). + Queries("location"). + Handler(Named("GetBucketLocation", h.GetBucketLocationHandler))). + Add(NewFilter(). + Queries("policy"). + Handler(Named("GetBucketPolicy", h.GetBucketPolicyHandler))). + Add(NewFilter(). + Queries("lifecycle"). + Handler(Named("GetBucketLifecycle", h.GetBucketLifecycleHandler))). + Add(NewFilter(). + Queries("encryption"). + Handler(Named("GetBucketEncryption", h.GetBucketEncryptionHandler))). + Add(NewFilter(). + Queries("cors"). + Handler(Named("GetBucketCors", h.GetBucketCorsHandler))). + Add(NewFilter(). + Queries("acl"). + Handler(Named("GetBucketACL", h.GetBucketACLHandler))). + Add(NewFilter(). + Queries("website"). + Handler(Named("GetBucketWebsite", h.GetBucketWebsiteHandler))). + Add(NewFilter(). + Queries("accelerate"). + Handler(Named("GetBucketAccelerate", h.GetBucketAccelerateHandler))). + Add(NewFilter(). + Queries("requestPayment"). + Handler(Named("GetBucketRequestPayment", h.GetBucketRequestPaymentHandler))). + Add(NewFilter(). + Queries("logging"). + Handler(Named("GetBucketLogging", h.GetBucketLoggingHandler))). + Add(NewFilter(). + Queries("replication"). + Handler(Named("GetBucketReplication", h.GetBucketReplicationHandler))). + Add(NewFilter(). + Queries("tagging"). + Handler(Named("GetBucketTagging", h.GetBucketTaggingHandler))). + Add(NewFilter(). + Queries("object-lock"). + Handler(Named("GetBucketObjectLockConfig", h.GetBucketObjectLockConfigHandler))). + Add(NewFilter(). + Queries("versioning"). + Handler(Named("GetBucketVersioning", h.GetBucketVersioningHandler))). + Add(NewFilter(). + Queries("notification"). + Handler(Named("GetBucketNotification", h.GetBucketNotificationHandler))). + Add(NewFilter(). + Queries("events"). + Handler(Named("ListenBucketNotification", h.ListenBucketNotificationHandler))). + Add(NewFilter(). + QueriesMatch("list-type", "2", "metadata", "true"). + Handler(Named("ListObjectsV2M", h.ListObjectsV2MHandler))). + Add(NewFilter(). + QueriesMatch("list-type", "2"). + Handler(Named("ListObjectsV2", h.ListObjectsV2Handler))). + Add(NewFilter(). + Queries("versions"). + Handler(Named("ListBucketObjectVersions", h.ListBucketObjectVersionsHandler))). + DefaultHandler(Named("ListObjectsV1", h.ListObjectsV1Handler))) + }) + + // PUT method handlers + bktRouter.Group(func(r chi.Router) { + r.Method(http.MethodPut, "/", NewHandlerFilter(). + Add(NewFilter(). + Queries("cors"). + Handler(Named("PutBucketCors", h.PutBucketCorsHandler))). + Add(NewFilter(). + Queries("acl"). + Handler(Named("PutBucketACL", h.PutBucketACLHandler))). + Add(NewFilter(). + Queries("lifecycle"). + Handler(Named("PutBucketLifecycle", h.PutBucketLifecycleHandler))). + Add(NewFilter(). + Queries("encryption"). + Handler(Named("PutBucketEncryption", h.PutBucketEncryptionHandler))). + Add(NewFilter(). + Queries("policy"). + Handler(Named("PutBucketPolicy", h.PutBucketPolicyHandler))). + Add(NewFilter(). + Queries("object-lock"). + Handler(Named("PutBucketObjectLockConfig", h.PutBucketObjectLockConfigHandler))). + Add(NewFilter(). + Queries("tagging"). + Handler(Named("PutBucketTagging", h.PutBucketTaggingHandler))). + Add(NewFilter(). + Queries("versioning"). + Handler(Named("PutBucketVersioning", h.PutBucketVersioningHandler))). + Add(NewFilter(). + Queries("notification"). + Handler(Named("PutBucketNotification", h.PutBucketNotificationHandler))). + DefaultHandler(Named("CreateBucket", h.CreateBucketHandler))) + }) + + // POST method handlers + bktRouter.Group(func(r chi.Router) { + r.Method(http.MethodPost, "/", NewHandlerFilter(). + Add(NewFilter(). + Queries("delete"). + Handler(Named("DeleteMultipleObjects", h.DeleteMultipleObjectsHandler))). + // todo consider add filter to match header for defaultHandler: hdrContentType, "multipart/form-data*" + DefaultHandler(Named("PostObject", h.PostObject))) + }) + + // DELETE method handlers + bktRouter.Group(func(r chi.Router) { + r.Method(http.MethodDelete, "/", NewHandlerFilter(). + Add(NewFilter(). + Queries("cors"). + Handler(Named("DeleteBucketCors", h.DeleteBucketCorsHandler))). + Add(NewFilter(). + Queries("website"). + Handler(Named("DeleteBucketWebsite", h.DeleteBucketWebsiteHandler))). + Add(NewFilter(). + Queries("tagging"). + Handler(Named("DeleteBucketTagging", h.DeleteBucketTaggingHandler))). + Add(NewFilter(). + Queries("policy"). + Handler(Named("PutBucketPolicy", h.PutBucketPolicyHandler))). + Add(NewFilter(). + Queries("lifecycle"). + Handler(Named("PutBucketLifecycle", h.PutBucketLifecycleHandler))). + Add(NewFilter(). + Queries("encryption"). + Handler(Named("DeleteBucketEncryption", h.DeleteBucketEncryptionHandler))). + DefaultHandler(Named("DeleteBucket", h.DeleteBucketHandler))) + }) + + return bktRouter +} + +func objectRouter(h Handler, log *zap.Logger) chi.Router { + objRouter := chi.NewRouter() + objRouter.Use(addObjectName) + + objRouter.Head("/", Named("HeadObject", h.HeadObjectHandler)) + + // GET method handlers + objRouter.Group(func(r chi.Router) { + r.Method(http.MethodGet, "/", NewHandlerFilter(). + Add(NewFilter(). + Queries("uploadId"). + Handler(Named("ListParts", h.ListPartsHandler))). + Add(NewFilter(). + Queries("acl"). + Handler(Named("GetObjectACL", h.GetObjectACLHandler))). + Add(NewFilter(). + Queries("tagging"). + Handler(Named("GetObjectTagging", h.GetObjectTaggingHandler))). + Add(NewFilter(). + Queries("retention"). + Handler(Named("GetObjectRetention", h.GetObjectRetentionHandler))). + Add(NewFilter(). + Queries("legal-hold"). + Handler(Named("GetObjectLegalHold", h.GetObjectLegalHoldHandler))). + Add(NewFilter(). + Queries("attributes"). + Handler(Named("GetObjectAttributes", h.GetObjectAttributesHandler))). + DefaultHandler(Named("GetObject", h.GetObjectHandler))) + }) + + // PUT method handlers + objRouter.Group(func(r chi.Router) { + r.Method(http.MethodPut, "/", NewHandlerFilter(). + Add(NewFilter(). + Headers(hdrAmzCopySource). + Queries("partNumber", "uploadId"). + Handler(Named("UploadPartCopy", h.UploadPartCopy))). + Add(NewFilter(). + Queries("partNumber", "uploadId"). + Handler(Named("UploadPart", h.UploadPartHandler))). + Add(NewFilter(). + Queries("acl"). + Handler(Named("PutObjectACL", h.PutObjectACLHandler))). + Add(NewFilter(). + Queries("tagging"). + Handler(Named("PutObjectTagging", h.PutObjectTaggingHandler))). + Add(NewFilter(). + Headers(hdrAmzCopySource). + Handler(Named("CopyObject", h.CopyObjectHandler))). + Add(NewFilter(). + Queries("retention"). + Handler(Named("PutObjectRetention", h.PutObjectRetentionHandler))). + Add(NewFilter(). + Queries("legal-hold"). + Handler(Named("PutObjectLegalHold", h.PutObjectLegalHoldHandler))). + DefaultHandler(Named("PutObject", h.PutObjectHandler))) + }) + + // POST method handlers + objRouter.Group(func(r chi.Router) { + r.Method(http.MethodPost, "/", NewHandlerFilter(). + Add(NewFilter(). + Queries("uploadId"). + Handler(Named("CompleteMultipartUpload", h.CompleteMultipartUploadHandler))). + Add(NewFilter(). + Queries("uploads"). + Handler(Named("CreateMultipartUpload", h.CreateMultipartUploadHandler))). + DefaultHandler(Named("SelectObjectContent", h.SelectObjectContentHandler))) + }) + + // DELETE method handlers + objRouter.Group(func(r chi.Router) { + r.Method(http.MethodDelete, "/", NewHandlerFilter(). + Add(NewFilter(). + Queries("uploadId"). + Handler(Named("AbortMultipartUpload", h.AbortMultipartUploadHandler))). + Add(NewFilter(). + Queries("tagging"). + Handler(Named("DeleteObjectTagging", h.DeleteObjectTaggingHandler))). + DefaultHandler(Named("DeleteObject", h.DeleteObjectHandler))) + }) + + return objRouter +} + +type HandlerFilters struct { + filters []Filter + defaultHandler http.Handler +} + +type Filter struct { + queries []Pair + headers []Pair + h http.Handler +} + +type Pair struct { + Key string + Value string +} + +func NewHandlerFilter() *HandlerFilters { + return &HandlerFilters{} +} + +func NewFilter() *Filter { + return &Filter{} +} + +func (hf *HandlerFilters) Add(filter *Filter) *HandlerFilters { + hf.filters = append(hf.filters, *filter) + return hf +} + +// HeadersMatch adds a matcher for header values. +// It accepts a sequence of key/value pairs. Values may define variables. +// Panics if number of parameters is not even. +// Supports only exact matching. +// If the value is an empty string, it will match any value if the key is set. +func (f *Filter) HeadersMatch(pairs ...string) *Filter { + length := len(pairs) + if length%2 != 0 { + panic(fmt.Errorf("filter headers: number of parameters must be multiple of 2, got %v", pairs)) + } + + for i := 0; i < length; i += 2 { + f.headers = append(f.headers, Pair{ + Key: pairs[i], + Value: pairs[i+1], + }) + } + + return f +} + +// Headers is similar to HeadersMatch but accept only header keys, set value to empty string internally. +func (f *Filter) Headers(headers ...string) *Filter { + for _, header := range headers { + f.headers = append(f.headers, Pair{ + Key: header, + Value: "", + }) + } + + return f +} + +func (f *Filter) Handler(handler http.HandlerFunc) *Filter { + f.h = handler + return f +} + +// QueriesMatch adds a matcher for URL query values. +// It accepts a sequence of key/value pairs. Values may define variables. +// Panics if number of parameters is not even. +// Supports only exact matching. +// If the value is an empty string, it will match any value if the key is set. +func (f *Filter) QueriesMatch(pairs ...string) *Filter { + length := len(pairs) + if length%2 != 0 { + panic(fmt.Errorf("filter headers: number of parameters must be multiple of 2, got %v", pairs)) + } + + for i := 0; i < length; i += 2 { + f.queries = append(f.queries, Pair{ + Key: pairs[i], + Value: pairs[i+1], + }) + } + + return f +} + +// Queries is similar to QueriesMatch but accept only query keys, set value to empty string internally. +func (f *Filter) Queries(queries ...string) *Filter { + for _, query := range queries { + f.queries = append(f.queries, Pair{ + Key: query, + Value: "", + }) + } + + return f +} + +func (hf *HandlerFilters) DefaultHandler(handler http.HandlerFunc) *HandlerFilters { + hf.defaultHandler = handler + return hf +} + +func (hf *HandlerFilters) ServeHTTP(w http.ResponseWriter, r *http.Request) { +LOOP: + for _, filter := range hf.filters { + for _, header := range filter.headers { + hdrVals := r.Header.Values(header.Key) + if len(hdrVals) == 0 || header.Value != "" && header.Value != hdrVals[0] { + continue LOOP + } + } + for _, query := range filter.queries { + queryVal := r.URL.Query().Get(query.Key) + if !r.URL.Query().Has(query.Key) || queryVal != "" && query.Value != queryVal { + continue LOOP + } + } + filter.h.ServeHTTP(w, r) + return + } + + hf.defaultHandler.ServeHTTP(w, r) } diff --git a/cmd/s3-gw/app.go b/cmd/s3-gw/app.go index 476d67d9..ecc5036c 100644 --- a/cmd/s3-gw/app.go +++ b/cmd/s3-gw/app.go @@ -5,6 +5,7 @@ import ( "encoding/hex" "encoding/json" "fmt" + "github.com/go-chi/chi/v5/middleware" "net/http" "os" "os/signal" @@ -25,7 +26,7 @@ import ( "github.com/TrueCloudLab/frostfs-s3-gw/internal/wallet" "github.com/TrueCloudLab/frostfs-sdk-go/netmap" "github.com/TrueCloudLab/frostfs-sdk-go/pool" - "github.com/gorilla/mux" + "github.com/go-chi/chi/v5" "github.com/nspcc-dev/neo-go/pkg/crypto/keys" "github.com/spf13/viper" "go.uber.org/zap" @@ -49,15 +50,20 @@ type ( bucketResolver *resolver.BucketResolver services []*Service settings *appSettings - maxClients api.MaxClients webDone chan struct{} wrkDone chan struct{} } appSettings struct { - logLevel zap.AtomicLevel - policies *placementPolicy + logLevel zap.AtomicLevel + policies *placementPolicy + maxClient maxClientsConfig + } + + maxClientsConfig struct { + deadline time.Duration + count int } Logger struct { @@ -100,8 +106,7 @@ func newApp(ctx context.Context, log *Logger, v *viper.Viper) *App { webDone: make(chan struct{}, 1), wrkDone: make(chan struct{}, 1), - maxClients: newMaxClients(v), - settings: newAppSettings(log, v), + settings: newAppSettings(log, v), } app.init(ctx) @@ -163,8 +168,9 @@ func newAppSettings(log *Logger, v *viper.Viper) *appSettings { } return &appSettings{ - logLevel: log.lvl, - policies: policies, + logLevel: log.lvl, + policies: policies, + maxClient: newMaxClients(v), } } @@ -214,18 +220,20 @@ func (a *App) getResolverConfig() ([]string, *resolver.Config) { return order, resolveCfg } -func newMaxClients(cfg *viper.Viper) api.MaxClients { - maxClientsCount := cfg.GetInt(cfgMaxClientsCount) - if maxClientsCount <= 0 { - maxClientsCount = defaultMaxClientsCount +func newMaxClients(cfg *viper.Viper) maxClientsConfig { + config := maxClientsConfig{} + + config.count = cfg.GetInt(cfgMaxClientsCount) + if config.count <= 0 { + config.count = defaultMaxClientsCount } - maxClientsDeadline := cfg.GetDuration(cfgMaxClientsDeadline) - if maxClientsDeadline <= 0 { - maxClientsDeadline = defaultMaxClientsDeadline + config.deadline = cfg.GetDuration(cfgMaxClientsDeadline) + if config.deadline <= 0 { + config.deadline = defaultMaxClientsDeadline } - return api.NewMaxClientsMiddleware(maxClientsCount, maxClientsDeadline) + return config } func getPool(ctx context.Context, logger *zap.Logger, cfg *viper.Viper) (*pool.Pool, *keys.PrivateKey) { @@ -420,12 +428,18 @@ func (a *App) Serve(ctx context.Context) { // Attach S3 API: domains := a.cfg.GetStringSlice(cfgListenDomains) a.log.Info("fetch domains, prepare to use API", zap.Strings("domains", domains)) - router := mux.NewRouter().SkipClean(true).UseEncodedPath() - api.Attach(router, domains, a.maxClients, a.api, a.ctr, a.log) + + throttleOps := middleware.ThrottleOpts{ + Limit: a.settings.maxClient.count, + BacklogTimeout: a.settings.maxClient.deadline, + } + + chiRouter := chi.NewRouter() + api.AttachChi(chiRouter, domains, throttleOps, a.api, a.ctr, a.log) // Use mux.Router as http.Handler srv := new(http.Server) - srv.Handler = router + srv.Handler = chiRouter srv.ErrorLog = zap.NewStdLog(a.log) a.startServices() diff --git a/go.mod b/go.mod index 0f9394d1..c2710d10 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,8 @@ require ( github.com/TrueCloudLab/frostfs-sdk-go v0.0.0-20221214065929-4c779423f556 github.com/aws/aws-sdk-go v1.44.6 github.com/bluele/gcache v0.0.2 + github.com/go-chi/chi/v5 v5.0.8 + github.com/go-chi/hostrouter v0.2.0 github.com/google/uuid v1.3.0 github.com/gorilla/mux v1.8.0 github.com/minio/sio v0.3.0 diff --git a/go.sum b/go.sum index d1aaa97e..64bbadb2 100644 --- a/go.sum +++ b/go.sum @@ -148,6 +148,11 @@ github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMo github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= +github.com/go-chi/chi/v5 v5.0.0/go.mod h1:BBug9lr0cqtdAhsu6R4AAdvufI0/XBzAQSsUqJpoZOs= +github.com/go-chi/chi/v5 v5.0.8 h1:lD+NLqFcAi1ovnVZpsnObHGW4xb4J8lNmoYVfECH1Y0= +github.com/go-chi/chi/v5 v5.0.8/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= +github.com/go-chi/hostrouter v0.2.0 h1:GwC7TZz8+SlJN/tV/aeJgx4F+mI5+sp+5H1PelQUjHM= +github.com/go-chi/hostrouter v0.2.0/go.mod h1:pJ49vWVmtsKRKZivQx0YMYv4h0aX+Gcn6V23Np9Wf1s= github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=