[#594] Fix unmarshal cors: expected element in name space error

Signed-off-by: Pavel Pogodaev <p.pogodaev@yadro.com>
This commit is contained in:
Pavel Pogodaev 2025-01-15 12:07:39 +03:00 committed by Alexey Vanin
parent 0cab76d01e
commit bc975989de
5 changed files with 34 additions and 9 deletions

View file

@ -29,7 +29,7 @@ func (h *handler) GetBucketCorsHandler(w http.ResponseWriter, r *http.Request) {
return return
} }
cors, err := h.obj.GetBucketCORS(ctx, bktInfo) cors, err := h.obj.GetBucketCORS(ctx, bktInfo, h.cfg.NewXMLDecoder)
if err != nil { if err != nil {
h.logAndSendError(ctx, w, "could not get cors", reqInfo, err) h.logAndSendError(ctx, w, "could not get cors", reqInfo, err)
return return
@ -112,7 +112,7 @@ func (h *handler) AppendCORSHeaders(w http.ResponseWriter, r *http.Request) {
return return
} }
cors, err := h.obj.GetBucketCORS(ctx, bktInfo) cors, err := h.obj.GetBucketCORS(ctx, bktInfo, h.cfg.NewXMLDecoder)
if err != nil { if err != nil {
h.reqLogger(ctx).Warn(logs.GetBucketCors, zap.Error(err)) h.reqLogger(ctx).Warn(logs.GetBucketCors, zap.Error(err))
return return
@ -178,7 +178,7 @@ func (h *handler) Preflight(w http.ResponseWriter, r *http.Request) {
headers = strings.Split(requestHeaders, ", ") headers = strings.Split(requestHeaders, ", ")
} }
cors, err := h.obj.GetBucketCORS(ctx, bktInfo) cors, err := h.obj.GetBucketCORS(ctx, bktInfo, h.cfg.NewXMLDecoder)
if err != nil { if err != nil {
h.logAndSendError(ctx, w, "could not get cors", reqInfo, err) h.logAndSendError(ctx, w, "could not get cors", reqInfo, err)
return return

View file

@ -19,7 +19,14 @@ func TestCORSOriginWildcard(t *testing.T) {
</CORSRule> </CORSRule>
</CORSConfiguration> </CORSConfiguration>
` `
hc := prepareHandlerContext(t) bodyNoXmlns := `
<CORSConfiguration>
<CORSRule>
<AllowedMethod>GET</AllowedMethod>
<AllowedOrigin>*</AllowedOrigin>
</CORSRule>
</CORSConfiguration>`
hc := prepareHandlerContextWithMinCache(t)
bktName := "bucket-for-cors" bktName := "bucket-for-cors"
box, _ := createAccessBox(t) box, _ := createAccessBox(t)
@ -39,6 +46,17 @@ func TestCORSOriginWildcard(t *testing.T) {
w, r = prepareTestPayloadRequest(hc, bktName, "", nil) w, r = prepareTestPayloadRequest(hc, bktName, "", nil)
hc.Handler().GetBucketCorsHandler(w, r) hc.Handler().GetBucketCorsHandler(w, r)
assertStatus(t, w, http.StatusOK) assertStatus(t, w, http.StatusOK)
hc.config.useDefaultXMLNS = true
w, r = prepareTestPayloadRequest(hc, bktName, "", strings.NewReader(bodyNoXmlns))
ctx = middleware.SetBox(r.Context(), &middleware.Box{AccessBox: box})
r = r.WithContext(ctx)
hc.Handler().PutBucketCorsHandler(w, r)
assertStatus(t, w, http.StatusOK)
w, r = prepareTestPayloadRequest(hc, bktName, "", nil)
hc.Handler().GetBucketCorsHandler(w, r)
assertStatus(t, w, http.StatusOK)
} }
func TestPreflight(t *testing.T) { func TestPreflight(t *testing.T) {

View file

@ -79,6 +79,7 @@ type configMock struct {
bypassContentEncodingInChunks bool bypassContentEncodingInChunks bool
md5Enabled bool md5Enabled bool
tlsTerminationHeader string tlsTerminationHeader string
useDefaultXMLNS bool
} }
func (c *configMock) DefaultPlacementPolicy(_ string) netmap.PlacementPolicy { func (c *configMock) DefaultPlacementPolicy(_ string) netmap.PlacementPolicy {
@ -99,7 +100,11 @@ func (c *configMock) DefaultCopiesNumbers(_ string) []uint32 {
} }
func (c *configMock) NewXMLDecoder(r io.Reader, _ string) *xml.Decoder { func (c *configMock) NewXMLDecoder(r io.Reader, _ string) *xml.Decoder {
return xml.NewDecoder(r) dec := xml.NewDecoder(r)
if c.useDefaultXMLNS {
dec.DefaultSpace = "http://s3.amazonaws.com/doc/2006-03-01/"
}
return dec
} }
func (c *configMock) BypassContentEncodingInChunks(_ string) bool { func (c *configMock) BypassContentEncodingInChunks(_ string) bool {

View file

@ -3,6 +3,7 @@ package layer
import ( import (
"bytes" "bytes"
"context" "context"
"encoding/xml"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@ -95,8 +96,8 @@ func (n *Layer) deleteCORSObject(ctx context.Context, bktInfo *data.BucketInfo,
} }
} }
func (n *Layer) GetBucketCORS(ctx context.Context, bktInfo *data.BucketInfo) (*data.CORSConfiguration, error) { func (n *Layer) GetBucketCORS(ctx context.Context, bktInfo *data.BucketInfo, decoder func(io.Reader, string) *xml.Decoder) (*data.CORSConfiguration, error) {
cors, err := n.getCORS(ctx, bktInfo) cors, err := n.getCORS(ctx, bktInfo, decoder)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -5,6 +5,7 @@ import (
"encoding/xml" "encoding/xml"
"errors" "errors"
"fmt" "fmt"
"io"
"math" "math"
"strconv" "strconv"
"time" "time"
@ -161,7 +162,7 @@ func (n *Layer) GetLockInfo(ctx context.Context, objVersion *data.ObjectVersion)
return lockInfo, nil return lockInfo, nil
} }
func (n *Layer) getCORS(ctx context.Context, bkt *data.BucketInfo) (*data.CORSConfiguration, error) { func (n *Layer) getCORS(ctx context.Context, bkt *data.BucketInfo, decoder func(io.Reader, string) *xml.Decoder) (*data.CORSConfiguration, error) {
owner := n.BearerOwner(ctx) owner := n.BearerOwner(ctx)
if cors := n.cache.GetCORS(owner, bkt); cors != nil { if cors := n.cache.GetCORS(owner, bkt); cors != nil {
return cors, nil return cors, nil
@ -190,7 +191,7 @@ func (n *Layer) getCORS(ctx context.Context, bkt *data.BucketInfo) (*data.CORSCo
} }
cors := &data.CORSConfiguration{} cors := &data.CORSConfiguration{}
if err = xml.NewDecoder(obj.Payload).Decode(&cors); err != nil { if err = decoder(obj.Payload, "").Decode(&cors); err != nil {
return nil, fmt.Errorf("unmarshal cors: %w", err) return nil, fmt.Errorf("unmarshal cors: %w", err)
} }