diff --git a/api/handler/cors.go b/api/handler/cors.go index 6a671a5..dcdaa52 100644 --- a/api/handler/cors.go +++ b/api/handler/cors.go @@ -29,7 +29,7 @@ func (h *handler) GetBucketCorsHandler(w http.ResponseWriter, r *http.Request) { return } - cors, err := h.obj.GetBucketCORS(ctx, bktInfo) + cors, err := h.obj.GetBucketCORS(ctx, bktInfo, h.cfg.NewXMLDecoder) if err != nil { h.logAndSendError(ctx, w, "could not get cors", reqInfo, err) return @@ -112,7 +112,7 @@ func (h *handler) AppendCORSHeaders(w http.ResponseWriter, r *http.Request) { return } - cors, err := h.obj.GetBucketCORS(ctx, bktInfo) + cors, err := h.obj.GetBucketCORS(ctx, bktInfo, h.cfg.NewXMLDecoder) if err != nil { h.reqLogger(ctx).Warn(logs.GetBucketCors, zap.Error(err)) return @@ -178,7 +178,7 @@ func (h *handler) Preflight(w http.ResponseWriter, r *http.Request) { headers = strings.Split(requestHeaders, ", ") } - cors, err := h.obj.GetBucketCORS(ctx, bktInfo) + cors, err := h.obj.GetBucketCORS(ctx, bktInfo, h.cfg.NewXMLDecoder) if err != nil { h.logAndSendError(ctx, w, "could not get cors", reqInfo, err) return diff --git a/api/handler/cors_test.go b/api/handler/cors_test.go index 42008d7..595bff7 100644 --- a/api/handler/cors_test.go +++ b/api/handler/cors_test.go @@ -19,7 +19,14 @@ func TestCORSOriginWildcard(t *testing.T) { ` - hc := prepareHandlerContext(t) + bodyNoXmlns := ` + + + GET + * + +` + hc := prepareHandlerContextWithMinCache(t) bktName := "bucket-for-cors" box, _ := createAccessBox(t) @@ -39,6 +46,17 @@ func TestCORSOriginWildcard(t *testing.T) { w, r = prepareTestPayloadRequest(hc, bktName, "", nil) hc.Handler().GetBucketCorsHandler(w, r) assertStatus(t, w, http.StatusOK) + + hc.config.useDefaultXMLNS = true + w, r = prepareTestPayloadRequest(hc, bktName, "", strings.NewReader(bodyNoXmlns)) + ctx = middleware.SetBox(r.Context(), &middleware.Box{AccessBox: box}) + r = r.WithContext(ctx) + hc.Handler().PutBucketCorsHandler(w, r) + assertStatus(t, w, http.StatusOK) + + w, r = prepareTestPayloadRequest(hc, bktName, "", nil) + hc.Handler().GetBucketCorsHandler(w, r) + assertStatus(t, w, http.StatusOK) } func TestPreflight(t *testing.T) { diff --git a/api/handler/handlers_test.go b/api/handler/handlers_test.go index dae87fc..7b73e51 100644 --- a/api/handler/handlers_test.go +++ b/api/handler/handlers_test.go @@ -79,6 +79,7 @@ type configMock struct { bypassContentEncodingInChunks bool md5Enabled bool tlsTerminationHeader string + useDefaultXMLNS bool } func (c *configMock) DefaultPlacementPolicy(_ string) netmap.PlacementPolicy { @@ -99,7 +100,11 @@ func (c *configMock) DefaultCopiesNumbers(_ string) []uint32 { } func (c *configMock) NewXMLDecoder(r io.Reader, _ string) *xml.Decoder { - return xml.NewDecoder(r) + dec := xml.NewDecoder(r) + if c.useDefaultXMLNS { + dec.DefaultSpace = "http://s3.amazonaws.com/doc/2006-03-01/" + } + return dec } func (c *configMock) BypassContentEncodingInChunks(_ string) bool { diff --git a/api/layer/cors.go b/api/layer/cors.go index ebf0a19..925f8ed 100644 --- a/api/layer/cors.go +++ b/api/layer/cors.go @@ -3,6 +3,7 @@ package layer import ( "bytes" "context" + "encoding/xml" "errors" "fmt" "io" @@ -95,8 +96,8 @@ func (n *Layer) deleteCORSObject(ctx context.Context, bktInfo *data.BucketInfo, } } -func (n *Layer) GetBucketCORS(ctx context.Context, bktInfo *data.BucketInfo) (*data.CORSConfiguration, error) { - cors, err := n.getCORS(ctx, bktInfo) +func (n *Layer) GetBucketCORS(ctx context.Context, bktInfo *data.BucketInfo, decoder func(io.Reader, string) *xml.Decoder) (*data.CORSConfiguration, error) { + cors, err := n.getCORS(ctx, bktInfo, decoder) if err != nil { return nil, err } diff --git a/api/layer/system_object.go b/api/layer/system_object.go index 19e58f0..4f7f40e 100644 --- a/api/layer/system_object.go +++ b/api/layer/system_object.go @@ -5,6 +5,7 @@ import ( "encoding/xml" "errors" "fmt" + "io" "math" "strconv" "time" @@ -161,7 +162,7 @@ func (n *Layer) GetLockInfo(ctx context.Context, objVersion *data.ObjectVersion) return lockInfo, nil } -func (n *Layer) getCORS(ctx context.Context, bkt *data.BucketInfo) (*data.CORSConfiguration, error) { +func (n *Layer) getCORS(ctx context.Context, bkt *data.BucketInfo, decoder func(io.Reader, string) *xml.Decoder) (*data.CORSConfiguration, error) { owner := n.BearerOwner(ctx) if cors := n.cache.GetCORS(owner, bkt); cors != nil { return cors, nil @@ -190,7 +191,7 @@ func (n *Layer) getCORS(ctx context.Context, bkt *data.BucketInfo) (*data.CORSCo } cors := &data.CORSConfiguration{} - if err = xml.NewDecoder(obj.Payload).Decode(&cors); err != nil { + if err = decoder(obj.Payload, "").Decode(&cors); err != nil { return nil, fmt.Errorf("unmarshal cors: %w", err) }