[#226] Improve CORS validation
All checks were successful
/ DCO (pull_request) Successful in 33s
/ Vulncheck (pull_request) Successful in 52s
/ OCI image (pull_request) Successful in 1m14s
/ Lint (pull_request) Successful in 2m25s
/ Tests (pull_request) Successful in 1m8s
/ Integration tests (pull_request) Successful in 5m46s
/ Builds (pull_request) Successful in 43s

Signed-off-by: Marina Biryukova <m.biryukova@yadro.com>
This commit is contained in:
Marina Biryukova 2025-04-30 16:06:04 +03:00
parent 96a22d98f2
commit 871ae5d763
3 changed files with 115 additions and 0 deletions

View file

@ -79,6 +79,10 @@ func (h *Handler) Preflight(req *fasthttp.RequestCtx) {
}
for _, rule := range corsConfig.CORSRules {
if err = checkCORSRuleWildcards(rule); err != nil {
log.Error(logs.InvalidCorsRule, zap.Error(err), logs.TagField(logs.TagDatapath))
continue
}
for _, o := range rule.AllowedOrigins {
if o == string(origin) || o == wildcard || (strings.Contains(o, "*") && match(o, string(origin))) {
for _, m := range rule.AllowedMethods {
@ -147,6 +151,10 @@ func (h *Handler) SetCORSHeaders(req *fasthttp.RequestCtx) {
}
for _, rule := range corsConfig.CORSRules {
if err = checkCORSRuleWildcards(rule); err != nil {
log.Error(logs.InvalidCorsRule, zap.Error(err), logs.TagField(logs.TagDatapath))
continue
}
for _, o := range rule.AllowedOrigins {
if o == string(origin) || (strings.Contains(o, "*") && len(o) > 1 && match(o, string(origin))) {
for _, m := range rule.AllowedMethods {
@ -178,6 +186,22 @@ func (h *Handler) SetCORSHeaders(req *fasthttp.RequestCtx) {
}
}
func checkCORSRuleWildcards(rule data.CORSRule) error {
for _, origin := range rule.AllowedOrigins {
if strings.Count(origin, wildcard) > 1 {
return fmt.Errorf("invalid allowed origin: %s", origin)
}
}
for _, header := range rule.AllowedHeaders {
if strings.Count(header, wildcard) > 1 {
return fmt.Errorf("invalid allowed header: %s", header)
}
}
return nil
}
func (h *Handler) getCORSConfig(ctx context.Context, log *zap.Logger, cidStr string) (*data.CORSConfiguration, error) {
cnrID, err := h.resolveContainer(ctx, cidStr)
if err != nil {