[#172] Convert handler config to interface

Signed-off-by: Marina Biryukova <m.biryukova@yadro.com>
This commit is contained in:
Marina Biryukova 2023-09-08 14:17:14 +03:00 committed by Alexey Vanin
parent 51e591877b
commit b8c93ed391
12 changed files with 239 additions and 230 deletions

View file

@ -21,7 +21,7 @@ type (
log *zap.Logger
obj layer.Client
notificator Notificator
cfg *Config
cfg Config
}
Notificator interface {
@ -30,29 +30,17 @@ type (
}
// Config contains data which handler needs to keep.
Config struct {
Policy PlacementPolicy
XMLDecoder XMLDecoderProvider
DefaultMaxAge int
NotificatorEnabled bool
ResolveZoneList []string
IsResolveListAllow bool // True if ResolveZoneList contains allowed zones
CompleteMultipartKeepalive time.Duration
Kludge KludgeSettings
}
PlacementPolicy interface {
Config interface {
DefaultPlacementPolicy() netmap.PlacementPolicy
PlacementPolicy(string) (netmap.PlacementPolicy, bool)
CopiesNumbers(string) ([]uint32, bool)
DefaultCopiesNumbers() []uint32
}
XMLDecoderProvider interface {
NewCompleteMultipartDecoder(io.Reader) *xml.Decoder
}
KludgeSettings interface {
DefaultMaxAge() int
NotificatorEnabled() bool
ResolveZoneList() []string
IsResolveListAllow() bool
CompleteMultipartKeepalive() time.Duration
BypassContentEncodingInChunks() bool
}
)
@ -60,7 +48,7 @@ type (
var _ api.Handler = (*handler)(nil)
// New creates new api.Handler using given logger and client.
func New(log *zap.Logger, obj layer.Client, notificator Notificator, cfg *Config) (api.Handler, error) {
func New(log *zap.Logger, obj layer.Client, notificator Notificator, cfg Config) (api.Handler, error) {
switch {
case obj == nil:
return nil, errors.New("empty FrostFS Object Layer")
@ -68,7 +56,7 @@ func New(log *zap.Logger, obj layer.Client, notificator Notificator, cfg *Config
return nil, errors.New("empty logger")
}
if !cfg.NotificatorEnabled {
if !cfg.NotificatorEnabled() {
log.Warn(logs.NotificatorIsDisabledS3WontProduceNotificationEvents)
} else if notificator == nil {
return nil, errors.New("empty notificator")
@ -96,12 +84,12 @@ func (h *handler) pickCopiesNumbers(metadata map[string]string, locationConstrai
return result, nil
}
copiesNumbers, ok := h.cfg.Policy.CopiesNumbers(locationConstraint)
copiesNumbers, ok := h.cfg.CopiesNumbers(locationConstraint)
if ok {
return copiesNumbers, nil
}
return h.cfg.Policy.DefaultCopiesNumbers(), nil
return h.cfg.DefaultCopiesNumbers(), nil
}
func parseCopiesNumbers(copiesNumbersStr string) ([]uint32, error) {

View file

@ -12,11 +12,9 @@ func TestCopiesNumberPicker(t *testing.T) {
locationConstraint2 := "two"
locationConstraints[locationConstraint1] = []uint32{2, 3, 4}
config := &Config{
Policy: &placementPolicyMock{
copiesNumbers: locationConstraints,
defaultCopiesNumbers: []uint32{1},
},
config := &configMock{
copiesNumbers: locationConstraints,
defaultCopiesNumbers: []uint32{1},
}
h := handler{
cfg: config,

View file

@ -194,7 +194,7 @@ func (h *handler) Preflight(w http.ResponseWriter, r *http.Request) {
if rule.MaxAgeSeconds > 0 || rule.MaxAgeSeconds == -1 {
w.Header().Set(api.AccessControlMaxAge, strconv.Itoa(rule.MaxAgeSeconds))
} else {
w.Header().Set(api.AccessControlMaxAge, strconv.Itoa(h.cfg.DefaultMaxAge))
w.Header().Set(api.AccessControlMaxAge, strconv.Itoa(h.cfg.DefaultMaxAge()))
}
if o != wildcard {
w.Header().Set(api.AccessControlAllowCredentials, "true")

View file

@ -37,7 +37,7 @@ type handlerContext struct {
tp *layer.TestFrostFS
tree *tree.Tree
context context.Context
kludge *kludgeSettingsMock
config *configMock
layerFeatures *layer.FeatureSettingsMock
}
@ -58,41 +58,56 @@ func (hc *handlerContext) Context() context.Context {
return hc.context
}
type placementPolicyMock struct {
defaultPolicy netmap.PlacementPolicy
copiesNumbers map[string][]uint32
defaultCopiesNumbers []uint32
}
func (p *placementPolicyMock) DefaultPlacementPolicy() netmap.PlacementPolicy {
return p.defaultPolicy
}
func (p *placementPolicyMock) PlacementPolicy(string) (netmap.PlacementPolicy, bool) {
return netmap.PlacementPolicy{}, false
}
func (p *placementPolicyMock) CopiesNumbers(locationConstraint string) ([]uint32, bool) {
result, ok := p.copiesNumbers[locationConstraint]
return result, ok
}
func (p *placementPolicyMock) DefaultCopiesNumbers() []uint32 {
return p.defaultCopiesNumbers
}
type xmlDecoderProviderMock struct{}
func (p *xmlDecoderProviderMock) NewCompleteMultipartDecoder(r io.Reader) *xml.Decoder {
return xml.NewDecoder(r)
}
type kludgeSettingsMock struct {
type configMock struct {
defaultPolicy netmap.PlacementPolicy
copiesNumbers map[string][]uint32
defaultCopiesNumbers []uint32
bypassContentEncodingInChunks bool
}
func (k *kludgeSettingsMock) BypassContentEncodingInChunks() bool {
return k.bypassContentEncodingInChunks
func (c *configMock) DefaultPlacementPolicy() netmap.PlacementPolicy {
return c.defaultPolicy
}
func (c *configMock) PlacementPolicy(string) (netmap.PlacementPolicy, bool) {
return netmap.PlacementPolicy{}, false
}
func (c *configMock) CopiesNumbers(locationConstraint string) ([]uint32, bool) {
result, ok := c.copiesNumbers[locationConstraint]
return result, ok
}
func (c *configMock) DefaultCopiesNumbers() []uint32 {
return c.defaultCopiesNumbers
}
func (c *configMock) NewCompleteMultipartDecoder(r io.Reader) *xml.Decoder {
return xml.NewDecoder(r)
}
func (c *configMock) BypassContentEncodingInChunks() bool {
return c.bypassContentEncodingInChunks
}
func (c *configMock) DefaultMaxAge() int {
return 0
}
func (c *configMock) NotificatorEnabled() bool {
return false
}
func (c *configMock) ResolveZoneList() []string {
return []string{}
}
func (c *configMock) IsResolveListAllow() bool {
return false
}
func (c *configMock) CompleteMultipartKeepalive() time.Duration {
return time.Duration(0)
}
func prepareHandlerContext(t *testing.T) *handlerContext {
@ -139,16 +154,13 @@ func prepareHandlerContextBase(t *testing.T, minCache bool) *handlerContext {
err = pp.DecodeString("REP 1")
require.NoError(t, err)
kludge := &kludgeSettingsMock{}
cfg := &configMock{
defaultPolicy: pp,
}
h := &handler{
log: l,
obj: layer.NewLayer(l, tp, layerCfg),
cfg: &Config{
Policy: &placementPolicyMock{defaultPolicy: pp},
XMLDecoder: &xmlDecoderProviderMock{},
Kludge: kludge,
},
cfg: cfg,
}
return &handlerContext{
@ -158,7 +170,7 @@ func prepareHandlerContextBase(t *testing.T, minCache bool) *handlerContext {
tp: tp,
tree: treeMock,
context: middleware.SetBoxData(context.Background(), newTestAccessBox(t, key)),
kludge: kludge,
config: cfg,
layerFeatures: features,
}

View file

@ -135,7 +135,7 @@ func (h *handler) HeadBucketHandler(w http.ResponseWriter, r *http.Request) {
w.Header().Set(api.ContainerID, bktInfo.CID.EncodeToString())
w.Header().Set(api.AmzBucketRegion, bktInfo.LocationConstraint)
if isAvailableToResolve(bktInfo.Zone, h.cfg.ResolveZoneList, h.cfg.IsResolveListAllow) {
if isAvailableToResolve(bktInfo.Zone, h.cfg.ResolveZoneList(), h.cfg.IsResolveListAllow()) {
w.Header().Set(api.ContainerName, bktInfo.Name)
w.Header().Set(api.ContainerZone, bktInfo.Zone)
}

View file

@ -406,7 +406,7 @@ func (h *handler) CompleteMultipartUploadHandler(w http.ResponseWriter, r *http.
)
reqBody := new(CompleteMultipartUpload)
if err = h.cfg.XMLDecoder.NewCompleteMultipartDecoder(r.Body).Decode(reqBody); err != nil {
if err = h.cfg.NewCompleteMultipartDecoder(r.Body).Decode(reqBody); err != nil {
h.logAndSendError(w, "could not read complete multipart upload xml", reqInfo,
errors.GetAPIError(errors.ErrMalformedXML), additional...)
return
@ -424,7 +424,7 @@ func (h *handler) CompleteMultipartUploadHandler(w http.ResponseWriter, r *http.
// Next operations might take some time, so we want to keep client's
// connection alive. To do so, gateway sends periodic white spaces
// back to the client the same way as Amazon S3 service does.
stopPeriodicResponseWriter := periodicXMLWriter(w, h.cfg.CompleteMultipartKeepalive)
stopPeriodicResponseWriter := periodicXMLWriter(w, h.cfg.CompleteMultipartKeepalive())
// Start complete multipart upload which may take some time to fetch object
// and re-upload it part by part.

View file

@ -155,7 +155,7 @@ func (h *handler) GetBucketNotificationHandler(w http.ResponseWriter, r *http.Re
}
func (h *handler) sendNotifications(ctx context.Context, p *SendNotificationParams) error {
if !h.cfg.NotificatorEnabled {
if !h.cfg.NotificatorEnabled() {
return nil
}
@ -198,7 +198,7 @@ func (h *handler) checkBucketConfiguration(ctx context.Context, conf *data.Notif
return
}
if h.cfg.NotificatorEnabled {
if h.cfg.NotificatorEnabled() {
if err = h.notificator.SendTestNotification(q.QueueArn, r.BucketName, r.RequestID, r.Host, layer.TimeNow(ctx)); err != nil {
return
}

View file

@ -348,7 +348,7 @@ func (h *handler) getBodyReader(r *http.Request) (io.ReadCloser, error) {
}
r.Header.Set(api.ContentEncoding, strings.Join(resultContentEncoding, ","))
if !chunkedEncoding && !h.cfg.Kludge.BypassContentEncodingInChunks() {
if !chunkedEncoding && !h.cfg.BypassContentEncodingInChunks() {
return nil, fmt.Errorf("%w: request is not chunk encoded, encodings '%s'",
errors.GetAPIError(errors.ErrInvalidEncodingMethod), strings.Join(encodings, ","))
}
@ -797,7 +797,7 @@ func (h *handler) CreateBucketHandler(w http.ResponseWriter, r *http.Request) {
}
func (h handler) setPolicy(prm *layer.CreateBucketParams, locationConstraint string, userPolicies []*accessbox.ContainerPolicy) error {
prm.Policy = h.cfg.Policy.DefaultPlacementPolicy()
prm.Policy = h.cfg.DefaultPlacementPolicy()
prm.LocationConstraint = locationConstraint
if locationConstraint == "" {
@ -811,7 +811,7 @@ func (h handler) setPolicy(prm *layer.CreateBucketParams, locationConstraint str
}
}
if policy, ok := h.cfg.Policy.PlacementPolicy(locationConstraint); ok {
if policy, ok := h.cfg.PlacementPolicy(locationConstraint); ok {
prm.Policy = policy
return nil
}

View file

@ -230,7 +230,7 @@ func TestPutChunkedTestContentEncoding(t *testing.T) {
hc.Handler().PutObjectHandler(w, req)
assertS3Error(t, w, s3errors.GetAPIError(s3errors.ErrInvalidEncodingMethod))
hc.kludge.bypassContentEncodingInChunks = true
hc.config.bypassContentEncodingInChunks = true
w, req, _ = getChunkedRequest(hc.context, t, bktName, objName)
req.Header.Set(api.ContentEncoding, "gzip")
hc.Handler().PutObjectHandler(w, req)

View file

@ -3,13 +3,14 @@ package main
import (
"context"
"encoding/hex"
"encoding/xml"
"fmt"
"io"
"net/http"
"os"
"os/signal"
"runtime/debug"
"sync"
"sync/atomic"
"syscall"
"time"
@ -27,7 +28,6 @@ import (
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/internal/logs"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/internal/version"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/internal/wallet"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/internal/xml"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/metrics"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/pkg/service/tree"
"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/netmap"
@ -42,6 +42,8 @@ import (
"google.golang.org/grpc"
)
const awsDefaultNamespace = "http://s3.amazonaws.com/doc/2006-03-01/"
type (
// App is the main application structure.
App struct {
@ -67,12 +69,22 @@ type (
}
appSettings struct {
logLevel zap.AtomicLevel
policies *placementPolicy
xmlDecoder *xml.DecoderProvider
maxClient maxClientsConfig
bypassContentEncodingInChunks atomic.Bool
clientCut atomic.Bool
logLevel zap.AtomicLevel
maxClient maxClientsConfig
defaultMaxAge int
notificatorEnabled bool
resolveZoneList []string
isResolveListAllow bool // True if ResolveZoneList contains allowed zones
completeMultipartKeepalive time.Duration
mu sync.RWMutex
defaultPolicy netmap.PlacementPolicy
regionMap map[string]netmap.PlacementPolicy
copiesNumbers map[string][]uint32
defaultCopiesNumbers []uint32
defaultXMLNSForCompleteMultipart bool
bypassContentEncodingInChunks bool
clientCut bool
}
maxClientsConfig struct {
@ -84,14 +96,6 @@ type (
logger *zap.Logger
lvl zap.AtomicLevel
}
placementPolicy struct {
mu sync.RWMutex
defaultPolicy netmap.PlacementPolicy
regionMap map[string]netmap.PlacementPolicy
copiesNumbers map[string][]uint32
defaultCopiesNumbers []uint32
}
)
func newApp(ctx context.Context, log *Logger, v *viper.Viper) *App {
@ -168,32 +172,130 @@ func (a *App) initLayer(ctx context.Context) {
func newAppSettings(log *Logger, v *viper.Viper) *appSettings {
settings := &appSettings{
logLevel: log.lvl,
policies: newPlacementPolicy(log.logger, v),
xmlDecoder: xml.NewDecoderProvider(v.GetBool(cfgKludgeUseDefaultXMLNSForCompleteMultipartUpload)),
maxClient: newMaxClients(v),
logLevel: log.lvl,
maxClient: newMaxClients(v),
defaultXMLNSForCompleteMultipart: v.GetBool(cfgKludgeUseDefaultXMLNSForCompleteMultipartUpload),
defaultMaxAge: fetchDefaultMaxAge(v, log.logger),
notificatorEnabled: v.GetBool(cfgEnableNATS),
completeMultipartKeepalive: v.GetDuration(cfgKludgeCompleteMultipartUploadKeepalive),
}
settings.resolveZoneList = v.GetStringSlice(cfgResolveBucketAllow)
settings.isResolveListAllow = len(settings.resolveZoneList) > 0
if !settings.isResolveListAllow {
settings.resolveZoneList = v.GetStringSlice(cfgResolveBucketDeny)
}
settings.setBypassContentEncodingInChunks(v.GetBool(cfgKludgeBypassContentEncodingCheckInChunks))
settings.setClientCut(v.GetBool(cfgClientCut))
settings.initPlacementPolicy(log.logger, v)
return settings
}
func (s *appSettings) BypassContentEncodingInChunks() bool {
return s.bypassContentEncodingInChunks.Load()
s.mu.RLock()
defer s.mu.RUnlock()
return s.bypassContentEncodingInChunks
}
func (s *appSettings) setBypassContentEncodingInChunks(bypass bool) {
s.bypassContentEncodingInChunks.Store(bypass)
s.mu.Lock()
s.bypassContentEncodingInChunks = bypass
s.mu.Unlock()
}
func (s *appSettings) ClientCut() bool {
return s.clientCut.Load()
s.mu.RLock()
defer s.mu.RUnlock()
return s.clientCut
}
func (s *appSettings) setClientCut(clientCut bool) {
s.clientCut.Store(clientCut)
s.mu.Lock()
s.clientCut = clientCut
s.mu.Unlock()
}
func (s *appSettings) initPlacementPolicy(l *zap.Logger, v *viper.Viper) {
defaultPolicy := fetchDefaultPolicy(l, v)
regionMap := fetchRegionMappingPolicies(l, v)
defaultCopies := fetchDefaultCopiesNumbers(l, v)
copiesNumbers := fetchCopiesNumbers(l, v)
s.mu.Lock()
defer s.mu.Unlock()
s.defaultPolicy = defaultPolicy
s.regionMap = regionMap
s.defaultCopiesNumbers = defaultCopies
s.copiesNumbers = copiesNumbers
}
func (s *appSettings) DefaultPlacementPolicy() netmap.PlacementPolicy {
s.mu.RLock()
defer s.mu.RUnlock()
return s.defaultPolicy
}
func (s *appSettings) PlacementPolicy(name string) (netmap.PlacementPolicy, bool) {
s.mu.RLock()
policy, ok := s.regionMap[name]
s.mu.RUnlock()
return policy, ok
}
func (s *appSettings) CopiesNumbers(locationConstraint string) ([]uint32, bool) {
s.mu.RLock()
copiesNumbers, ok := s.copiesNumbers[locationConstraint]
s.mu.RUnlock()
return copiesNumbers, ok
}
func (s *appSettings) DefaultCopiesNumbers() []uint32 {
s.mu.RLock()
defer s.mu.RUnlock()
return s.defaultCopiesNumbers
}
func (s *appSettings) NewCompleteMultipartDecoder(r io.Reader) *xml.Decoder {
dec := xml.NewDecoder(r)
s.mu.RLock()
if s.defaultXMLNSForCompleteMultipart {
dec.DefaultSpace = awsDefaultNamespace
}
s.mu.RUnlock()
return dec
}
func (s *appSettings) useDefaultNamespaceForCompleteMultipart(useDefaultNamespace bool) {
s.mu.Lock()
s.defaultXMLNSForCompleteMultipart = useDefaultNamespace
s.mu.Unlock()
}
func (s *appSettings) DefaultMaxAge() int {
return s.defaultMaxAge
}
func (s *appSettings) NotificatorEnabled() bool {
return s.notificatorEnabled
}
func (s *appSettings) ResolveZoneList() []string {
return s.resolveZoneList
}
func (s *appSettings) IsResolveListAllow() bool {
return s.isResolveListAllow
}
func (s *appSettings) CompleteMultipartKeepalive() time.Duration {
return s.completeMultipartKeepalive
}
func (a *App) initAPI(ctx context.Context) {
@ -348,55 +450,6 @@ func getPools(ctx context.Context, logger *zap.Logger, cfg *viper.Viper) (*pool.
return p, treePool, key
}
func newPlacementPolicy(l *zap.Logger, v *viper.Viper) *placementPolicy {
var policies placementPolicy
policies.update(l, v)
return &policies
}
func (p *placementPolicy) DefaultPlacementPolicy() netmap.PlacementPolicy {
p.mu.RLock()
defer p.mu.RUnlock()
return p.defaultPolicy
}
func (p *placementPolicy) PlacementPolicy(name string) (netmap.PlacementPolicy, bool) {
p.mu.RLock()
policy, ok := p.regionMap[name]
p.mu.RUnlock()
return policy, ok
}
func (p *placementPolicy) CopiesNumbers(locationConstraint string) ([]uint32, bool) {
p.mu.RLock()
copiesNumbers, ok := p.copiesNumbers[locationConstraint]
p.mu.RUnlock()
return copiesNumbers, ok
}
func (p *placementPolicy) DefaultCopiesNumbers() []uint32 {
p.mu.RLock()
defer p.mu.RUnlock()
return p.defaultCopiesNumbers
}
func (p *placementPolicy) update(l *zap.Logger, v *viper.Viper) {
defaultPolicy := fetchDefaultPolicy(l, v)
regionMap := fetchRegionMappingPolicies(l, v)
defaultCopies := fetchDefaultCopiesNumbers(l, v)
copiesNumbers := fetchCopiesNumbers(l, v)
p.mu.Lock()
defer p.mu.Unlock()
p.defaultPolicy = defaultPolicy
p.regionMap = regionMap
p.defaultCopiesNumbers = defaultCopies
p.copiesNumbers = copiesNumbers
}
func remove(list []string, element string) []string {
for i, item := range list {
if item == element {
@ -531,9 +584,9 @@ func (a *App) updateSettings() {
a.settings.logLevel.SetLevel(lvl)
}
a.settings.policies.update(a.log, a.cfg)
a.settings.initPlacementPolicy(a.log, a.cfg)
a.settings.xmlDecoder.UseDefaultNamespaceForCompleteMultipart(a.cfg.GetBool(cfgKludgeUseDefaultXMLNSForCompleteMultipartUpload))
a.settings.useDefaultNamespaceForCompleteMultipart(a.cfg.GetBool(cfgKludgeUseDefaultXMLNSForCompleteMultipartUpload))
a.settings.setBypassContentEncodingInChunks(a.cfg.GetBool(cfgKludgeBypassContentEncodingCheckInChunks))
a.settings.setClientCut(a.cfg.GetBool(cfgClientCut))
}
@ -664,24 +717,8 @@ func getAccessBoxCacheConfig(v *viper.Viper, l *zap.Logger) *cache.Config {
}
func (a *App) initHandler() {
cfg := &handler.Config{
Policy: a.settings.policies,
DefaultMaxAge: fetchDefaultMaxAge(a.cfg, a.log),
NotificatorEnabled: a.cfg.GetBool(cfgEnableNATS),
XMLDecoder: a.settings.xmlDecoder,
}
cfg.ResolveZoneList = a.cfg.GetStringSlice(cfgResolveBucketAllow)
cfg.IsResolveListAllow = len(cfg.ResolveZoneList) > 0
if !cfg.IsResolveListAllow {
cfg.ResolveZoneList = a.cfg.GetStringSlice(cfgResolveBucketDeny)
}
cfg.CompleteMultipartKeepalive = a.cfg.GetDuration(cfgKludgeCompleteMultipartUploadKeepalive)
cfg.Kludge = a.settings
var err error
a.api, err = handler.New(a.log, a.obj, a.nc, cfg)
a.api, err = handler.New(a.log, a.obj, a.nc, a.settings)
if err != nil {
a.log.Fatal(logs.CouldNotInitializeAPIHandler, zap.Error(err))
}

View file

@ -1,4 +1,4 @@
package xml
package main
import (
"bytes"
@ -35,44 +35,56 @@ func TestDefaultNamespace(t *testing.T) {
`
for _, tc := range []struct {
provider *DecoderProvider
settings *appSettings
input string
err bool
}{
{
provider: NewDecoderProvider(false),
input: xmlBodyWithNamespace,
err: false,
settings: &appSettings{
defaultXMLNSForCompleteMultipart: false,
},
input: xmlBodyWithNamespace,
err: false,
},
{
provider: NewDecoderProvider(false),
input: xmlBody,
err: true,
settings: &appSettings{
defaultXMLNSForCompleteMultipart: false,
},
input: xmlBody,
err: true,
},
{
provider: NewDecoderProvider(false),
input: xmlBodyWithInvalidNamespace,
err: true,
settings: &appSettings{
defaultXMLNSForCompleteMultipart: false,
},
input: xmlBodyWithInvalidNamespace,
err: true,
},
{
provider: NewDecoderProvider(true),
input: xmlBodyWithNamespace,
err: false,
settings: &appSettings{
defaultXMLNSForCompleteMultipart: true,
},
input: xmlBodyWithNamespace,
err: false,
},
{
provider: NewDecoderProvider(true),
input: xmlBody,
err: false,
settings: &appSettings{
defaultXMLNSForCompleteMultipart: true,
},
input: xmlBody,
err: false,
},
{
provider: NewDecoderProvider(true),
input: xmlBodyWithInvalidNamespace,
err: true,
settings: &appSettings{
defaultXMLNSForCompleteMultipart: true,
},
input: xmlBodyWithInvalidNamespace,
err: true,
},
} {
t.Run("", func(t *testing.T) {
model := new(handler.CompleteMultipartUpload)
err := tc.provider.NewCompleteMultipartDecoder(bytes.NewBufferString(tc.input)).Decode(model)
err := tc.settings.NewCompleteMultipartDecoder(bytes.NewBufferString(tc.input)).Decode(model)
if tc.err {
require.Error(t, err)
} else {

View file

@ -1,38 +0,0 @@
package xml
import (
"encoding/xml"
"io"
"sync"
)
const awsDefaultNamespace = "http://s3.amazonaws.com/doc/2006-03-01/"
type DecoderProvider struct {
mu sync.RWMutex
defaultXMLNSForCompleteMultipart bool
}
func NewDecoderProvider(defaultNamespace bool) *DecoderProvider {
return &DecoderProvider{
defaultXMLNSForCompleteMultipart: defaultNamespace,
}
}
func (d *DecoderProvider) NewCompleteMultipartDecoder(r io.Reader) *xml.Decoder {
dec := xml.NewDecoder(r)
d.mu.RLock()
if d.defaultXMLNSForCompleteMultipart {
dec.DefaultSpace = awsDefaultNamespace
}
d.mu.RUnlock()
return dec
}
func (d *DecoderProvider) UseDefaultNamespaceForCompleteMultipart(useDefaultNamespace bool) {
d.mu.Lock()
d.defaultXMLNSForCompleteMultipart = useDefaultNamespace
d.mu.Unlock()
}