All checks were successful
/ DCO (pull_request) Successful in 31s
/ Vulncheck (pull_request) Successful in 45s
/ Builds (pull_request) Successful in 1m2s
/ OCI image (pull_request) Successful in 1m25s
/ Lint (pull_request) Successful in 2m23s
/ Tests (pull_request) Successful in 53s
/ Integration tests (pull_request) Successful in 5m24s
/ Vulncheck (push) Successful in 47s
/ Builds (push) Successful in 1m2s
/ OCI image (push) Successful in 1m21s
/ Lint (push) Successful in 1m56s
/ Tests (push) Successful in 59s
/ Integration tests (push) Successful in 5m32s
Signed-off-by: Marina Biryukova <m.biryukova@yadro.com>
342 lines
10 KiB
Go
342 lines
10 KiB
Go
package handler
|
|
|
|
import (
|
|
"context"
|
|
"encoding/xml"
|
|
"errors"
|
|
"fmt"
|
|
"sort"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"git.frostfs.info/TrueCloudLab/frostfs-http-gw/internal/data"
|
|
"git.frostfs.info/TrueCloudLab/frostfs-http-gw/internal/logs"
|
|
"git.frostfs.info/TrueCloudLab/frostfs-http-gw/tokens"
|
|
"git.frostfs.info/TrueCloudLab/frostfs-http-gw/utils"
|
|
"git.frostfs.info/TrueCloudLab/frostfs-observability/tracing"
|
|
qostagging "git.frostfs.info/TrueCloudLab/frostfs-qos/tagging"
|
|
cid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/container/id"
|
|
"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object"
|
|
oid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object/id"
|
|
"github.com/valyala/fasthttp"
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
// Shared constants of the CORS handlers.
const (
	// internalIOTag is the QoS IO tag attached to the request context before
	// CORS configuration lookups.
	internalIOTag = "internal"

	// corsFilePathTemplate produces the FilePath attribute value searched for
	// in the CORS container; the single %s verb receives the container ID.
	corsFilePathTemplate = "/%s.cors"

	// wildcard is the entry that matches any origin (AllowedOrigins) or any
	// header (AllowedHeaders) in a CORS rule.
	wildcard = "*"
)

// errNoCORS reports that the CORS container holds no CORS object for the
// requested container.
var errNoCORS = errors.New("no CORS objects found")
|
|
|
|
func (h *Handler) Preflight(c *fasthttp.RequestCtx) {
|
|
ctx, span := tracing.StartSpanFromContext(utils.GetContextFromRequest(c), "handler.Preflight")
|
|
defer span.End()
|
|
|
|
ctx = qostagging.ContextWithIOTag(ctx, internalIOTag)
|
|
cidParam, _ := c.UserValue("cid").(string)
|
|
reqLog := utils.GetReqLogOrDefault(ctx, h.log)
|
|
log := reqLog.With(zap.String("cid", cidParam))
|
|
|
|
origin := c.Request.Header.Peek(fasthttp.HeaderOrigin)
|
|
if len(origin) == 0 {
|
|
log.Error(logs.EmptyOriginRequestHeader, logs.TagField(logs.TagDatapath))
|
|
ResponseError(c, "Origin request header needed", fasthttp.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
method := c.Request.Header.Peek(fasthttp.HeaderAccessControlRequestMethod)
|
|
if len(method) == 0 {
|
|
log.Error(logs.EmptyAccessControlRequestMethodHeader, logs.TagField(logs.TagDatapath))
|
|
ResponseError(c, "Access-Control-Request-Method request header needed", fasthttp.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
corsRule := h.config.CORS()
|
|
if corsRule != nil {
|
|
setCORSHeadersFromRule(c, corsRule)
|
|
return
|
|
}
|
|
|
|
corsConfig, err := h.getCORSConfig(ctx, log, cidParam)
|
|
if err != nil {
|
|
log.Error(logs.CouldNotGetCORSConfiguration, zap.Error(err), logs.TagField(logs.TagDatapath))
|
|
status := fasthttp.StatusInternalServerError
|
|
if errors.Is(err, errNoCORS) {
|
|
status = fasthttp.StatusNotFound
|
|
}
|
|
ResponseError(c, "could not get CORS configuration: "+err.Error(), status)
|
|
return
|
|
}
|
|
|
|
var headers []string
|
|
requestHeaders := c.Request.Header.Peek(fasthttp.HeaderAccessControlRequestHeaders)
|
|
if len(requestHeaders) > 0 {
|
|
headers = strings.Split(string(requestHeaders), ", ")
|
|
}
|
|
|
|
for _, rule := range corsConfig.CORSRules {
|
|
for _, o := range rule.AllowedOrigins {
|
|
if o == string(origin) || o == wildcard {
|
|
for _, m := range rule.AllowedMethods {
|
|
if m == string(method) {
|
|
if !checkSubslice(rule.AllowedHeaders, headers) {
|
|
continue
|
|
}
|
|
c.Response.Header.Set(fasthttp.HeaderAccessControlAllowOrigin, string(origin))
|
|
c.Response.Header.Set(fasthttp.HeaderAccessControlAllowMethods, strings.Join(rule.AllowedMethods, ", "))
|
|
if headers != nil {
|
|
c.Response.Header.Set(fasthttp.HeaderAccessControlAllowHeaders, string(requestHeaders))
|
|
}
|
|
if rule.ExposeHeaders != nil {
|
|
c.Response.Header.Set(fasthttp.HeaderAccessControlExposeHeaders, strings.Join(rule.ExposeHeaders, ", "))
|
|
}
|
|
if rule.MaxAgeSeconds > 0 || rule.MaxAgeSeconds == -1 {
|
|
c.Response.Header.Set(fasthttp.HeaderAccessControlMaxAge, strconv.Itoa(rule.MaxAgeSeconds))
|
|
}
|
|
if o != wildcard {
|
|
c.Response.Header.Set(fasthttp.HeaderAccessControlAllowCredentials, "true")
|
|
}
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
log.Error(logs.CORSRuleWasNotMatched, logs.TagField(logs.TagDatapath))
|
|
ResponseError(c, "Forbidden", fasthttp.StatusForbidden)
|
|
}
|
|
|
|
func (h *Handler) SetCORSHeaders(c *fasthttp.RequestCtx) {
|
|
ctx, span := tracing.StartSpanFromContext(utils.GetContextFromRequest(c), "handler.SetCORSHeaders")
|
|
defer span.End()
|
|
|
|
origin := c.Request.Header.Peek(fasthttp.HeaderOrigin)
|
|
if len(origin) == 0 {
|
|
return
|
|
}
|
|
|
|
ctx = qostagging.ContextWithIOTag(ctx, internalIOTag)
|
|
cidParam, _ := c.UserValue("cid").(string)
|
|
reqLog := utils.GetReqLogOrDefault(ctx, h.log)
|
|
log := reqLog.With(zap.String("cid", cidParam))
|
|
|
|
corsRule := h.config.CORS()
|
|
if corsRule != nil {
|
|
setCORSHeadersFromRule(c, corsRule)
|
|
return
|
|
}
|
|
|
|
corsConfig, err := h.getCORSConfig(ctx, log, cidParam)
|
|
if err != nil {
|
|
log.Error(logs.CouldNotGetCORSConfiguration, zap.Error(err), logs.TagField(logs.TagDatapath))
|
|
return
|
|
}
|
|
|
|
var withCredentials bool
|
|
if tkn, err := tokens.LoadBearerToken(ctx); err == nil && tkn != nil {
|
|
withCredentials = true
|
|
}
|
|
|
|
for _, rule := range corsConfig.CORSRules {
|
|
for _, o := range rule.AllowedOrigins {
|
|
if o == string(origin) {
|
|
for _, m := range rule.AllowedMethods {
|
|
if m == string(c.Method()) {
|
|
c.Response.Header.Set(fasthttp.HeaderAccessControlAllowOrigin, string(origin))
|
|
c.Response.Header.Set(fasthttp.HeaderAccessControlAllowMethods, strings.Join(rule.AllowedMethods, ", "))
|
|
c.Response.Header.Set(fasthttp.HeaderAccessControlAllowCredentials, "true")
|
|
c.Response.Header.Set(fasthttp.HeaderVary, fasthttp.HeaderOrigin)
|
|
return
|
|
}
|
|
}
|
|
}
|
|
if o == wildcard {
|
|
for _, m := range rule.AllowedMethods {
|
|
if m == string(c.Method()) {
|
|
if withCredentials {
|
|
c.Response.Header.Set(fasthttp.HeaderAccessControlAllowOrigin, string(origin))
|
|
c.Response.Header.Set(fasthttp.HeaderAccessControlAllowCredentials, "true")
|
|
c.Response.Header.Set(fasthttp.HeaderVary, fasthttp.HeaderOrigin)
|
|
} else {
|
|
c.Response.Header.Set(fasthttp.HeaderAccessControlAllowOrigin, o)
|
|
}
|
|
c.Response.Header.Set(fasthttp.HeaderAccessControlAllowMethods, strings.Join(rule.AllowedMethods, ", "))
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (h *Handler) getCORSConfig(ctx context.Context, log *zap.Logger, cidStr string) (*data.CORSConfiguration, error) {
|
|
cnrID, err := h.resolveContainer(ctx, cidStr)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("resolve container '%s': %w", cidStr, err)
|
|
}
|
|
|
|
if cors := h.corsCache.Get(*cnrID); cors != nil {
|
|
return cors, nil
|
|
}
|
|
|
|
objID, err := h.getLastCORSObject(ctx, *cnrID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("get last cors object: %w", err)
|
|
}
|
|
|
|
var addr oid.Address
|
|
addr.SetContainer(h.corsCnrID)
|
|
addr.SetObject(objID)
|
|
corsObj, err := h.frostfs.GetObject(ctx, PrmObjectGet{
|
|
PrmAuth: PrmAuth{
|
|
BearerToken: bearerToken(ctx),
|
|
},
|
|
Address: addr,
|
|
})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("get cors object '%s': %w", addr.EncodeToString(), err)
|
|
}
|
|
|
|
corsConfig := &data.CORSConfiguration{}
|
|
if err = xml.NewDecoder(corsObj.Payload).Decode(corsConfig); err != nil {
|
|
return nil, fmt.Errorf("decode cors object: %w", err)
|
|
}
|
|
|
|
if err = h.corsCache.Put(*cnrID, corsConfig); err != nil {
|
|
log.Warn(logs.CouldntCacheCors, zap.Error(err), logs.TagField(logs.TagDatapath))
|
|
}
|
|
|
|
return corsConfig, nil
|
|
}
|
|
|
|
func (h *Handler) getLastCORSObject(ctx context.Context, cnrID cid.ID) (oid.ID, error) {
|
|
filters := object.NewSearchFilters()
|
|
filters.AddRootFilter()
|
|
filters.AddFilter(object.AttributeFilePath, fmt.Sprintf(corsFilePathTemplate, cnrID), object.MatchStringEqual)
|
|
|
|
prmAuth := PrmAuth{
|
|
BearerToken: bearerToken(ctx),
|
|
}
|
|
res, err := h.frostfs.SearchObjects(ctx, PrmObjectSearch{
|
|
PrmAuth: prmAuth,
|
|
Container: h.corsCnrID,
|
|
Filters: filters,
|
|
})
|
|
if err != nil {
|
|
return oid.ID{}, fmt.Errorf("search cors versions: %w", err)
|
|
}
|
|
defer res.Close()
|
|
|
|
var (
|
|
addr oid.Address
|
|
obj *object.Object
|
|
headErr error
|
|
objs = make([]*object.Object, 0)
|
|
)
|
|
addr.SetContainer(h.corsCnrID)
|
|
err = res.Iterate(func(id oid.ID) bool {
|
|
addr.SetObject(id)
|
|
obj, headErr = h.frostfs.HeadObject(ctx, PrmObjectHead{
|
|
PrmAuth: prmAuth,
|
|
Address: addr,
|
|
})
|
|
if headErr != nil {
|
|
headErr = fmt.Errorf("head cors object '%s': %w", addr.EncodeToString(), headErr)
|
|
return true
|
|
}
|
|
|
|
objs = append(objs, obj)
|
|
return false
|
|
})
|
|
if err != nil {
|
|
return oid.ID{}, fmt.Errorf("iterate cors objects: %w", err)
|
|
}
|
|
|
|
if headErr != nil {
|
|
return oid.ID{}, headErr
|
|
}
|
|
|
|
if len(objs) == 0 {
|
|
return oid.ID{}, errNoCORS
|
|
}
|
|
|
|
sort.Slice(objs, func(i, j int) bool {
|
|
versionID1, _ := objs[i].ID()
|
|
versionID2, _ := objs[j].ID()
|
|
timestamp1 := utils.GetAttributeValue(objs[i].Attributes(), object.AttributeTimestamp)
|
|
timestamp2 := utils.GetAttributeValue(objs[j].Attributes(), object.AttributeTimestamp)
|
|
|
|
if objs[i].CreationEpoch() != objs[j].CreationEpoch() {
|
|
return objs[i].CreationEpoch() < objs[j].CreationEpoch()
|
|
}
|
|
|
|
if len(timestamp1) > 0 && len(timestamp2) > 0 && timestamp1 != timestamp2 {
|
|
unixTime1, err := strconv.ParseInt(timestamp1, 10, 64)
|
|
if err != nil {
|
|
return versionID1.EncodeToString() < versionID2.EncodeToString()
|
|
}
|
|
|
|
unixTime2, err := strconv.ParseInt(timestamp2, 10, 64)
|
|
if err != nil {
|
|
return versionID1.EncodeToString() < versionID2.EncodeToString()
|
|
}
|
|
|
|
return unixTime1 < unixTime2
|
|
}
|
|
|
|
return versionID1.EncodeToString() < versionID2.EncodeToString()
|
|
})
|
|
|
|
objID, _ := objs[len(objs)-1].ID()
|
|
return objID, nil
|
|
}
|
|
|
|
func setCORSHeadersFromRule(c *fasthttp.RequestCtx, cors *data.CORSRule) {
|
|
c.Response.Header.Set(fasthttp.HeaderAccessControlMaxAge, strconv.Itoa(cors.MaxAgeSeconds))
|
|
|
|
if len(cors.AllowedOrigins) != 0 {
|
|
c.Response.Header.Set(fasthttp.HeaderAccessControlAllowOrigin, cors.AllowedOrigins[0])
|
|
}
|
|
|
|
if len(cors.AllowedMethods) != 0 {
|
|
c.Response.Header.Set(fasthttp.HeaderAccessControlAllowMethods, strings.Join(cors.AllowedMethods, ", "))
|
|
}
|
|
|
|
if len(cors.AllowedHeaders) != 0 {
|
|
c.Response.Header.Set(fasthttp.HeaderAccessControlAllowHeaders, strings.Join(cors.AllowedHeaders, ", "))
|
|
}
|
|
|
|
if len(cors.ExposeHeaders) != 0 {
|
|
c.Response.Header.Set(fasthttp.HeaderAccessControlExposeHeaders, strings.Join(cors.ExposeHeaders, ", "))
|
|
}
|
|
|
|
if cors.AllowedCredentials {
|
|
c.Response.Header.Set(fasthttp.HeaderAccessControlAllowCredentials, "true")
|
|
}
|
|
}
|
|
|
|
func checkSubslice(slice []string, subSlice []string) bool {
|
|
if sliceContains(slice, wildcard) {
|
|
return true
|
|
}
|
|
if len(subSlice) > len(slice) {
|
|
return false
|
|
}
|
|
for _, r := range subSlice {
|
|
if !sliceContains(slice, r) {
|
|
return false
|
|
}
|
|
}
|
|
return true
|
|
}
|
|
|
|
// sliceContains reports whether slice holds an element exactly equal to str
// (case-sensitive). A nil or empty slice contains nothing.
func sliceContains(slice []string, str string) bool {
	found := false
	for i := 0; i < len(slice) && !found; i++ {
		found = slice[i] == str
	}
	return found
}
|