frostfs-s3-gw/api/handler/cors.go
2025-01-16 12:55:53 +00:00

246 lines
6.7 KiB
Go

package handler
import (
"net/http"
"strconv"
"strings"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/errors"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/layer"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/middleware"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/internal/logs"
"go.uber.org/zap"
)
const (
// DefaultMaxAge is the default value (in seconds) of Access-Control-Max-Age
// used when a CORS rule does not set its own max age.
DefaultMaxAge = 600
// wildcard is the "match any" token; in this file it is compared against the
// allowed origins and allowed headers of CORS rules.
wildcard = "*"
)
// GetBucketCorsHandler writes the CORS configuration of the requested bucket
// to the response.
//
// The bucket is resolved through getBucketAndCheckOwner, the configuration is
// loaded from the object layer and then encoded into the response body. Any
// failure along the way is reported via logAndSendError and ends the request.
func (h *handler) GetBucketCorsHandler(w http.ResponseWriter, r *http.Request) {
	var (
		ctx  = r.Context()
		info = middleware.GetReqInfo(ctx)
	)

	bucket, err := h.getBucketAndCheckOwner(r, info.BucketName)
	if err != nil {
		h.logAndSendError(ctx, w, "could not get bucket info", info, err)
		return
	}

	configuration, err := h.obj.GetBucketCORS(ctx, bucket, h.cfg.NewXMLDecoder)
	if err != nil {
		h.logAndSendError(ctx, w, "could not get cors", info, err)
		return
	}

	if encodeErr := middleware.EncodeToResponse(w, configuration); encodeErr != nil {
		h.logAndSendError(ctx, w, "could not encode cors to response", info, encodeErr)
	}
}
// PutBucketCorsHandler stores a CORS configuration for the requested bucket.
//
// The configuration is read by the object layer from the request body using
// the configured XML decoder. The copies numbers are picked from the request
// metadata, the request namespace and the bucket location constraint. On
// success only response headers are written; every failure is reported via
// logAndSendError.
func (h *handler) PutBucketCorsHandler(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	info := middleware.GetReqInfo(ctx)

	bucket, err := h.getBucketAndCheckOwner(r, info.BucketName)
	if err != nil {
		h.logAndSendError(ctx, w, "could not get bucket info", info, err)
		return
	}

	copies, err := h.pickCopiesNumbers(parseMetadata(r), info.Namespace, bucket.LocationConstraint)
	if err != nil {
		h.logAndSendError(ctx, w, "invalid copies number", info, err)
		return
	}

	params := &layer.PutCORSParams{
		BktInfo:       bucket,
		Reader:        r.Body,
		NewDecoder:    h.cfg.NewXMLDecoder,
		UserAgent:     r.UserAgent(),
		CopiesNumbers: copies,
	}

	if err = h.obj.PutBucketCORS(ctx, params); err != nil {
		h.logAndSendError(ctx, w, "could not put cors configuration", info, err)
		return
	}

	if err = middleware.WriteSuccessResponseHeadersOnly(w); err != nil {
		h.logAndSendError(ctx, w, "write response", info, err)
	}
}
// DeleteBucketCorsHandler removes the CORS configuration of the requested
// bucket.
//
// The bucket is resolved through getBucketAndCheckOwner. On success the
// handler replies with 204 No Content; on any failure the error is reported
// via logAndSendError and nothing else is written to the response.
func (h *handler) DeleteBucketCorsHandler(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	reqInfo := middleware.GetReqInfo(ctx)

	bktInfo, err := h.getBucketAndCheckOwner(r, reqInfo.BucketName)
	if err != nil {
		h.logAndSendError(ctx, w, "could not get bucket info", reqInfo, err)
		return
	}

	if err = h.obj.DeleteBucketCORS(ctx, bktInfo); err != nil {
		h.logAndSendError(ctx, w, "could not delete cors", reqInfo, err)
		// The error response has already been sent; falling through would
		// issue a second, superfluous WriteHeader with a success status.
		return
	}

	w.WriteHeader(http.StatusNoContent)
}
// AppendCORSHeaders adds CORS response headers to a regular (non-OPTIONS)
// request according to the CORS configuration of the addressed bucket.
//
// Nothing is done for OPTIONS requests, for requests without an Origin header
// or for requests that do not address a bucket. Failures to load the bucket
// info or its CORS configuration are logged as warnings and leave the
// response untouched.
//
// The first rule that allows the request method and lists either the exact
// origin or the wildcard decides the headers:
//   - exact origin: the origin is echoed and credentials are allowed;
//   - wildcard with an Authorization header: same as the exact origin case;
//   - wildcard without an Authorization header: "*" is returned as the origin.
func (h *handler) AppendCORSHeaders(w http.ResponseWriter, r *http.Request) {
	origin := r.Header.Get(api.Origin)
	if r.Method == http.MethodOptions || origin == "" {
		return
	}

	ctx := r.Context()
	bucketName := middleware.GetReqInfo(ctx).BucketName
	if bucketName == "" {
		return
	}

	bktInfo, err := h.getBucketInfo(ctx, bucketName)
	if err != nil {
		h.reqLogger(ctx).Warn(logs.GetBucketInfo, zap.Error(err))
		return
	}

	config, err := h.obj.GetBucketCORS(ctx, bktInfo, h.cfg.NewXMLDecoder)
	if err != nil {
		h.reqLogger(ctx).Warn(logs.GetBucketCors, zap.Error(err))
		return
	}

	authorized := r.Header.Get(api.Authorization) != ""

	for _, rule := range config.CORSRules {
		// A rule that does not allow the request method can never match,
		// whatever origins it lists.
		if !sliceContains(rule.AllowedMethods, r.Method) {
			continue
		}

		for _, allowed := range rule.AllowedOrigins {
			switch {
			case allowed == origin:
				hdr := w.Header()
				hdr.Set(api.AccessControlAllowOrigin, origin)
				hdr.Set(api.AccessControlAllowMethods, strings.Join(rule.AllowedMethods, ", "))
				hdr.Set(api.AccessControlAllowCredentials, "true")
				hdr.Set(api.Vary, api.Origin)
				return
			case allowed == wildcard:
				hdr := w.Header()
				if authorized {
					hdr.Set(api.AccessControlAllowOrigin, origin)
					hdr.Set(api.AccessControlAllowCredentials, "true")
					hdr.Set(api.Vary, api.Origin)
				} else {
					hdr.Set(api.AccessControlAllowOrigin, allowed)
				}
				hdr.Set(api.AccessControlAllowMethods, strings.Join(rule.AllowedMethods, ", "))
				return
			}
		}
	}
}
// Preflight answers a CORS preflight request for a bucket.
//
// The request must carry the Origin and Access-Control-Request-Method
// headers; otherwise a bad-request error is sent. The bucket's CORS rules are
// then scanned in order, and the first rule that lists the origin (or the
// wildcard), allows the requested method and allows every requested header
// produces the success response with the Access-Control-* headers set. If no
// rule matches, an access-denied error is sent.
//
// Access-Control-Max-Age is taken from the rule when it is positive or -1,
// and from the configured default otherwise. Credentials are allowed only
// when the rule matched on a non-wildcard origin entry.
func (h *handler) Preflight(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	reqInfo := middleware.GetReqInfo(ctx)

	bktInfo, err := h.getBucketInfo(ctx, reqInfo.BucketName)
	if err != nil {
		h.logAndSendError(ctx, w, "could not get bucket info", reqInfo, err)
		return
	}

	origin := r.Header.Get(api.Origin)
	if origin == "" {
		h.logAndSendError(ctx, w, "origin request header needed", reqInfo, errors.GetAPIError(errors.ErrBadRequest))
		// The error response is already sent; continuing could write a
		// second response for the same request.
		return
	}

	method := r.Header.Get(api.AccessControlRequestMethod)
	if method == "" {
		h.logAndSendError(ctx, w, "Access-Control-Request-Method request header needed", reqInfo, errors.GetAPIError(errors.ErrBadRequest))
		return
	}

	// Clients may separate the requested header names with "," or ", ", so
	// split on the comma alone and trim the surrounding whitespace.
	var headers []string
	requestHeaders := r.Header.Get(api.AccessControlRequestHeaders)
	for _, name := range strings.Split(requestHeaders, ",") {
		if name = strings.TrimSpace(name); name != "" {
			headers = append(headers, name)
		}
	}

	cors, err := h.obj.GetBucketCORS(ctx, bktInfo, h.cfg.NewXMLDecoder)
	if err != nil {
		h.logAndSendError(ctx, w, "could not get cors", reqInfo, err)
		return
	}

	for _, rule := range cors.CORSRules {
		for _, o := range rule.AllowedOrigins {
			if o != origin && o != wildcard {
				continue
			}
			if !sliceContains(rule.AllowedMethods, method) || !checkSubslice(rule.AllowedHeaders, headers) {
				continue
			}

			w.Header().Set(api.AccessControlAllowOrigin, origin)
			w.Header().Set(api.AccessControlAllowMethods, method)
			if len(headers) > 0 {
				w.Header().Set(api.AccessControlAllowHeaders, requestHeaders)
			}
			if rule.ExposeHeaders != nil {
				w.Header().Set(api.AccessControlExposeHeaders, strings.Join(rule.ExposeHeaders, ", "))
			}
			if rule.MaxAgeSeconds > 0 || rule.MaxAgeSeconds == -1 {
				w.Header().Set(api.AccessControlMaxAge, strconv.Itoa(rule.MaxAgeSeconds))
			} else {
				w.Header().Set(api.AccessControlMaxAge, strconv.Itoa(h.cfg.DefaultMaxAge()))
			}
			// Credentials are only allowed when the rule matched on a
			// non-wildcard origin entry.
			if o != wildcard {
				w.Header().Set(api.AccessControlAllowCredentials, "true")
			}

			if err = middleware.WriteSuccessResponseHeadersOnly(w); err != nil {
				h.logAndSendError(ctx, w, "write response", reqInfo, err)
			}
			return
		}
	}

	h.logAndSendError(ctx, w, "Forbidden", reqInfo, errors.GetAPIError(errors.ErrAccessDenied))
}
// checkSubslice reports whether every element of subSlice is present in
// slice. A slice that contains the wildcard accepts any subSlice, and an
// empty subSlice is always accepted.
//
// Membership is checked per element, so repeated entries in subSlice are
// accepted as long as the entry itself is present in slice.
func checkSubslice(slice []string, subSlice []string) bool {
	if sliceContains(slice, wildcard) {
		return true
	}
	for _, s := range subSlice {
		if !sliceContains(slice, s) {
			return false
		}
	}
	return true
}
// sliceContains reports whether str is equal to at least one element of
// slice. It returns false for a nil or empty slice.
func sliceContains(slice []string, str string) bool {
	for i := range slice {
		if slice[i] == str {
			return true
		}
	}
	return false
}