package handler import ( "context" "encoding/xml" "errors" "fmt" "sort" "strconv" "strings" "git.frostfs.info/TrueCloudLab/frostfs-http-gw/internal/data" "git.frostfs.info/TrueCloudLab/frostfs-http-gw/internal/logs" "git.frostfs.info/TrueCloudLab/frostfs-http-gw/tokens" "git.frostfs.info/TrueCloudLab/frostfs-http-gw/utils" "git.frostfs.info/TrueCloudLab/frostfs-observability/tracing" qostagging "git.frostfs.info/TrueCloudLab/frostfs-qos/tagging" cid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/container/id" "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object" oid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object/id" "github.com/valyala/fasthttp" "go.uber.org/zap" ) const ( internalIOTag = "internal" corsFilePathTemplate = "/%s.cors" wildcard = "*" ) var errNoCORS = errors.New("no CORS objects found") func (h *Handler) Preflight(c *fasthttp.RequestCtx) { ctx, span := tracing.StartSpanFromContext(utils.GetContextFromRequest(c), "handler.Preflight") defer span.End() ctx = qostagging.ContextWithIOTag(ctx, internalIOTag) cidParam, _ := c.UserValue("cid").(string) reqLog := utils.GetReqLogOrDefault(ctx, h.log) log := reqLog.With(zap.String("cid", cidParam)) origin := c.Request.Header.Peek(fasthttp.HeaderOrigin) if len(origin) == 0 { log.Error(logs.EmptyOriginRequestHeader, logs.TagField(logs.TagDatapath)) ResponseError(c, "Origin request header needed", fasthttp.StatusBadRequest) return } method := c.Request.Header.Peek(fasthttp.HeaderAccessControlRequestMethod) if len(method) == 0 { log.Error(logs.EmptyAccessControlRequestMethodHeader, logs.TagField(logs.TagDatapath)) ResponseError(c, "Access-Control-Request-Method request header needed", fasthttp.StatusBadRequest) return } corsRule := h.config.CORS() if corsRule != nil { setCORSHeadersFromRule(c, corsRule) return } corsConfig, err := h.getCORSConfig(ctx, log, cidParam) if err != nil { log.Error(logs.CouldNotGetCORSConfiguration, zap.Error(err), logs.TagField(logs.TagDatapath)) status := fasthttp.StatusInternalServerError if errors.Is(err, errNoCORS) { status = fasthttp.StatusNotFound } ResponseError(c, "could not get CORS configuration: "+err.Error(), status) return } var headers []string requestHeaders := c.Request.Header.Peek(fasthttp.HeaderAccessControlRequestHeaders) if len(requestHeaders) > 0 { headers = strings.Split(string(requestHeaders), ", ") } for _, rule := range corsConfig.CORSRules { for _, o := range rule.AllowedOrigins { if o == string(origin) || o == wildcard { for _, m := range rule.AllowedMethods { if m == string(method) { if !checkSubslice(rule.AllowedHeaders, headers) { continue } c.Response.Header.Set(fasthttp.HeaderAccessControlAllowOrigin, string(origin)) c.Response.Header.Set(fasthttp.HeaderAccessControlAllowMethods, strings.Join(rule.AllowedMethods, ", ")) if headers != nil { c.Response.Header.Set(fasthttp.HeaderAccessControlAllowHeaders, string(requestHeaders)) } if rule.ExposeHeaders != nil { c.Response.Header.Set(fasthttp.HeaderAccessControlExposeHeaders, strings.Join(rule.ExposeHeaders, ", ")) } if rule.MaxAgeSeconds > 0 || rule.MaxAgeSeconds == -1 { c.Response.Header.Set(fasthttp.HeaderAccessControlMaxAge, strconv.Itoa(rule.MaxAgeSeconds)) } if o != wildcard { c.Response.Header.Set(fasthttp.HeaderAccessControlAllowCredentials, "true") } return } } } } } log.Error(logs.CORSRuleWasNotMatched, logs.TagField(logs.TagDatapath)) ResponseError(c, "Forbidden", fasthttp.StatusForbidden) } func (h *Handler) SetCORSHeaders(c *fasthttp.RequestCtx) { ctx, span := tracing.StartSpanFromContext(utils.GetContextFromRequest(c), "handler.SetCORSHeaders") defer span.End() origin := c.Request.Header.Peek(fasthttp.HeaderOrigin) if len(origin) == 0 { return } ctx = qostagging.ContextWithIOTag(ctx, internalIOTag) cidParam, _ := c.UserValue("cid").(string) reqLog := utils.GetReqLogOrDefault(ctx, h.log) log := reqLog.With(zap.String("cid", cidParam)) corsRule := h.config.CORS() if corsRule != nil { setCORSHeadersFromRule(c, corsRule) return } corsConfig, err := h.getCORSConfig(ctx, log, cidParam) if err != nil { log.Error(logs.CouldNotGetCORSConfiguration, zap.Error(err), logs.TagField(logs.TagDatapath)) return } var withCredentials bool if tkn, err := tokens.LoadBearerToken(ctx); err == nil && tkn != nil { withCredentials = true } for _, rule := range corsConfig.CORSRules { for _, o := range rule.AllowedOrigins { if o == string(origin) { for _, m := range rule.AllowedMethods { if m == string(c.Method()) { c.Response.Header.Set(fasthttp.HeaderAccessControlAllowOrigin, string(origin)) c.Response.Header.Set(fasthttp.HeaderAccessControlAllowMethods, strings.Join(rule.AllowedMethods, ", ")) c.Response.Header.Set(fasthttp.HeaderAccessControlAllowCredentials, "true") c.Response.Header.Set(fasthttp.HeaderVary, fasthttp.HeaderOrigin) return } } } if o == wildcard { for _, m := range rule.AllowedMethods { if m == string(c.Method()) { if withCredentials { c.Response.Header.Set(fasthttp.HeaderAccessControlAllowOrigin, string(origin)) c.Response.Header.Set(fasthttp.HeaderAccessControlAllowCredentials, "true") c.Response.Header.Set(fasthttp.HeaderVary, fasthttp.HeaderOrigin) } else { c.Response.Header.Set(fasthttp.HeaderAccessControlAllowOrigin, o) } c.Response.Header.Set(fasthttp.HeaderAccessControlAllowMethods, strings.Join(rule.AllowedMethods, ", ")) return } } } } } } func (h *Handler) getCORSConfig(ctx context.Context, log *zap.Logger, cidStr string) (*data.CORSConfiguration, error) { cnrID, err := h.resolveContainer(ctx, cidStr) if err != nil { return nil, fmt.Errorf("resolve container '%s': %w", cidStr, err) } if cors := h.corsCache.Get(*cnrID); cors != nil { return cors, nil } objID, err := h.getLastCORSObject(ctx, *cnrID) if err != nil { return nil, fmt.Errorf("get last cors object: %w", err) } var addr oid.Address addr.SetContainer(h.corsCnrID) addr.SetObject(objID) corsObj, err := h.frostfs.GetObject(ctx, PrmObjectGet{ PrmAuth: PrmAuth{ BearerToken: bearerToken(ctx), }, Address: addr, }) if err != nil { return nil, fmt.Errorf("get cors object '%s': %w", addr.EncodeToString(), err) } corsConfig := &data.CORSConfiguration{} if err = xml.NewDecoder(corsObj.Payload).Decode(corsConfig); err != nil { return nil, fmt.Errorf("decode cors object: %w", err) } if err = h.corsCache.Put(*cnrID, corsConfig); err != nil { log.Warn(logs.CouldntCacheCors, zap.Error(err), logs.TagField(logs.TagDatapath)) } return corsConfig, nil } func (h *Handler) getLastCORSObject(ctx context.Context, cnrID cid.ID) (oid.ID, error) { filters := object.NewSearchFilters() filters.AddRootFilter() filters.AddFilter(object.AttributeFilePath, fmt.Sprintf(corsFilePathTemplate, cnrID), object.MatchStringEqual) prmAuth := PrmAuth{ BearerToken: bearerToken(ctx), } res, err := h.frostfs.SearchObjects(ctx, PrmObjectSearch{ PrmAuth: prmAuth, Container: h.corsCnrID, Filters: filters, }) if err != nil { return oid.ID{}, fmt.Errorf("search cors versions: %w", err) } defer res.Close() var ( addr oid.Address obj *object.Object headErr error objs = make([]*object.Object, 0) ) addr.SetContainer(h.corsCnrID) err = res.Iterate(func(id oid.ID) bool { addr.SetObject(id) obj, headErr = h.frostfs.HeadObject(ctx, PrmObjectHead{ PrmAuth: prmAuth, Address: addr, }) if headErr != nil { headErr = fmt.Errorf("head cors object '%s': %w", addr.EncodeToString(), headErr) return true } objs = append(objs, obj) return false }) if err != nil { return oid.ID{}, fmt.Errorf("iterate cors objects: %w", err) } if headErr != nil { return oid.ID{}, headErr } if len(objs) == 0 { return oid.ID{}, errNoCORS } sort.Slice(objs, func(i, j int) bool { versionID1, _ := objs[i].ID() versionID2, _ := objs[j].ID() timestamp1 := utils.GetAttributeValue(objs[i].Attributes(), object.AttributeTimestamp) timestamp2 := utils.GetAttributeValue(objs[j].Attributes(), object.AttributeTimestamp) if objs[i].CreationEpoch() != objs[j].CreationEpoch() { return objs[i].CreationEpoch() < objs[j].CreationEpoch() } if len(timestamp1) > 0 && len(timestamp2) > 0 && timestamp1 != timestamp2 { unixTime1, err := strconv.ParseInt(timestamp1, 10, 64) if err != nil { return versionID1.EncodeToString() < versionID2.EncodeToString() } unixTime2, err := strconv.ParseInt(timestamp2, 10, 64) if err != nil { return versionID1.EncodeToString() < versionID2.EncodeToString() } return unixTime1 < unixTime2 } return versionID1.EncodeToString() < versionID2.EncodeToString() }) objID, _ := objs[len(objs)-1].ID() return objID, nil } func setCORSHeadersFromRule(c *fasthttp.RequestCtx, cors *data.CORSRule) { c.Response.Header.Set(fasthttp.HeaderAccessControlMaxAge, strconv.Itoa(cors.MaxAgeSeconds)) if len(cors.AllowedOrigins) != 0 { c.Response.Header.Set(fasthttp.HeaderAccessControlAllowOrigin, cors.AllowedOrigins[0]) } if len(cors.AllowedMethods) != 0 { c.Response.Header.Set(fasthttp.HeaderAccessControlAllowMethods, strings.Join(cors.AllowedMethods, ", ")) } if len(cors.AllowedHeaders) != 0 { c.Response.Header.Set(fasthttp.HeaderAccessControlAllowHeaders, strings.Join(cors.AllowedHeaders, ", ")) } if len(cors.ExposeHeaders) != 0 { c.Response.Header.Set(fasthttp.HeaderAccessControlExposeHeaders, strings.Join(cors.ExposeHeaders, ", ")) } if cors.AllowedCredentials { c.Response.Header.Set(fasthttp.HeaderAccessControlAllowCredentials, "true") } } func checkSubslice(slice []string, subSlice []string) bool { if sliceContains(slice, wildcard) { return true } if len(subSlice) > len(slice) { return false } for _, r := range subSlice { if !sliceContains(slice, r) { return false } } return true } func sliceContains(slice []string, str string) bool { for _, s := range slice { if s == str { return true } } return false }