[#672] Support wildcard in allowed origins and headers

Signed-off-by: Marina Biryukova <m.biryukova@yadro.com>
This commit is contained in:
Marina Biryukova 2025-03-31 14:42:59 +03:00
parent 2ad2531d3a
commit e45c1a2188
6 changed files with 724 additions and 18 deletions

View file

@ -2,6 +2,8 @@ package handler
import (
"net/http"
"regexp"
"slices"
"strconv"
"strings"
@ -110,6 +112,10 @@ func (h *handler) AppendCORSHeaders(w http.ResponseWriter, r *http.Request) {
if origin == "" {
return
}
method := r.Header.Get(api.AccessControlRequestMethod)
if method == "" {
method = r.Method
}
ctx = qostagging.ContextWithIOTag(ctx, util.InternalIOTag)
reqInfo := middleware.GetReqInfo(ctx)
@ -132,9 +138,9 @@ func (h *handler) AppendCORSHeaders(w http.ResponseWriter, r *http.Request) {
for _, rule := range cors.CORSRules {
for _, o := range rule.AllowedOrigins {
if o == origin {
if o == origin || (strings.Contains(o, "*") && len(o) > 1 && match(o, origin)) {
for _, m := range rule.AllowedMethods {
if m == r.Method {
if m == method {
w.Header().Set(api.AccessControlAllowOrigin, origin)
w.Header().Set(api.AccessControlAllowMethods, strings.Join(rule.AllowedMethods, ", "))
w.Header().Set(api.AccessControlAllowCredentials, "true")
@ -145,7 +151,7 @@ func (h *handler) AppendCORSHeaders(w http.ResponseWriter, r *http.Request) {
}
if o == wildcard {
for _, m := range rule.AllowedMethods {
if m == r.Method {
if m == method {
if withCredentials {
w.Header().Set(api.AccessControlAllowOrigin, origin)
w.Header().Set(api.AccessControlAllowCredentials, "true")
@ -199,7 +205,7 @@ func (h *handler) Preflight(w http.ResponseWriter, r *http.Request) {
for _, rule := range cors.CORSRules {
for _, o := range rule.AllowedOrigins {
if o == origin || o == wildcard {
if o == origin || o == wildcard || (strings.Contains(o, "*") && match(o, origin)) {
for _, m := range rule.AllowedMethods {
if m == method {
if !checkSubslice(rule.AllowedHeaders, headers) {
@ -235,12 +241,9 @@ func (h *handler) Preflight(w http.ResponseWriter, r *http.Request) {
}
func checkSubslice(slice []string, subSlice []string) bool {
if sliceContains(slice, wildcard) {
if slices.Contains(slice, wildcard) {
return true
}
if len(subSlice) > len(slice) {
return false
}
for _, r := range subSlice {
if !sliceContains(slice, r) {
return false
@ -251,9 +254,16 @@ func checkSubslice(slice []string, subSlice []string) bool {
func sliceContains(slice []string, str string) bool {
for _, s := range slice {
if s == str {
if s == str || (strings.Contains(s, "*") && match(s, str)) {
return true
}
}
return false
}
func match(tmpl, str string) bool {
regexpStr := "^" + regexp.QuoteMeta(tmpl) + "$"
regexpStr = regexpStr[:strings.Index(regexpStr, "*")-1] + "." + regexpStr[strings.Index(regexpStr, "*"):]
reg := regexp.MustCompile(regexpStr)
return reg.Match([]byte(str))
}