diff --git a/api/handler/api.go b/api/handler/api.go index e6aea4a..4835cab 100644 --- a/api/handler/api.go +++ b/api/handler/api.go @@ -21,7 +21,7 @@ type ( log *zap.Logger obj layer.Client notificator Notificator - cfg *Config + cfg Config } Notificator interface { @@ -30,29 +30,17 @@ type ( } // Config contains data which handler needs to keep. - Config struct { - Policy PlacementPolicy - XMLDecoder XMLDecoderProvider - DefaultMaxAge int - NotificatorEnabled bool - ResolveZoneList []string - IsResolveListAllow bool // True if ResolveZoneList contains allowed zones - CompleteMultipartKeepalive time.Duration - Kludge KludgeSettings - } - - PlacementPolicy interface { + Config interface { DefaultPlacementPolicy() netmap.PlacementPolicy PlacementPolicy(string) (netmap.PlacementPolicy, bool) CopiesNumbers(string) ([]uint32, bool) DefaultCopiesNumbers() []uint32 - } - - XMLDecoderProvider interface { NewCompleteMultipartDecoder(io.Reader) *xml.Decoder - } - - KludgeSettings interface { + DefaultMaxAge() int + NotificatorEnabled() bool + ResolveZoneList() []string + IsResolveListAllow() bool + CompleteMultipartKeepalive() time.Duration BypassContentEncodingInChunks() bool } ) @@ -60,7 +48,7 @@ type ( var _ api.Handler = (*handler)(nil) // New creates new api.Handler using given logger and client. -func New(log *zap.Logger, obj layer.Client, notificator Notificator, cfg *Config) (api.Handler, error) { +func New(log *zap.Logger, obj layer.Client, notificator Notificator, cfg Config) (api.Handler, error) { switch { case obj == nil: return nil, errors.New("empty FrostFS Object Layer") @@ -68,7 +56,7 @@ func New(log *zap.Logger, obj layer.Client, notificator Notificator, cfg *Config return nil, errors.New("empty logger") } - if !cfg.NotificatorEnabled { + if !cfg.NotificatorEnabled() { log.Warn(logs.NotificatorIsDisabledS3WontProduceNotificationEvents) } else if notificator == nil { return nil, errors.New("empty notificator") @@ -96,12 +84,12 @@ func (h *handler) pickCopiesNumbers(metadata map[string]string, locationConstrai return result, nil } - copiesNumbers, ok := h.cfg.Policy.CopiesNumbers(locationConstraint) + copiesNumbers, ok := h.cfg.CopiesNumbers(locationConstraint) if ok { return copiesNumbers, nil } - return h.cfg.Policy.DefaultCopiesNumbers(), nil + return h.cfg.DefaultCopiesNumbers(), nil } func parseCopiesNumbers(copiesNumbersStr string) ([]uint32, error) { diff --git a/api/handler/api_test.go b/api/handler/api_test.go index 7b51e28..d6d68ab 100644 --- a/api/handler/api_test.go +++ b/api/handler/api_test.go @@ -12,11 +12,9 @@ func TestCopiesNumberPicker(t *testing.T) { locationConstraint2 := "two" locationConstraints[locationConstraint1] = []uint32{2, 3, 4} - config := &Config{ - Policy: &placementPolicyMock{ - copiesNumbers: locationConstraints, - defaultCopiesNumbers: []uint32{1}, - }, + config := &configMock{ + copiesNumbers: locationConstraints, + defaultCopiesNumbers: []uint32{1}, } h := handler{ cfg: config, diff --git a/api/handler/cors.go b/api/handler/cors.go index fd0b7b6..7a623b2 100644 --- a/api/handler/cors.go +++ b/api/handler/cors.go @@ -194,7 +194,7 @@ func (h *handler) Preflight(w http.ResponseWriter, r *http.Request) { if rule.MaxAgeSeconds > 0 || rule.MaxAgeSeconds == -1 { w.Header().Set(api.AccessControlMaxAge, strconv.Itoa(rule.MaxAgeSeconds)) } else { - w.Header().Set(api.AccessControlMaxAge, strconv.Itoa(h.cfg.DefaultMaxAge)) + w.Header().Set(api.AccessControlMaxAge, strconv.Itoa(h.cfg.DefaultMaxAge())) } if o != wildcard { w.Header().Set(api.AccessControlAllowCredentials, "true") diff --git a/api/handler/handlers_test.go b/api/handler/handlers_test.go index c864f5e..2302303 100644 --- a/api/handler/handlers_test.go +++ b/api/handler/handlers_test.go @@ -37,7 +37,7 @@ type handlerContext struct { tp *layer.TestFrostFS tree *tree.Tree context context.Context - kludge *kludgeSettingsMock + config *configMock layerFeatures *layer.FeatureSettingsMock } @@ -58,41 +58,56 @@ func (hc *handlerContext) Context() context.Context { return hc.context } -type placementPolicyMock struct { - defaultPolicy netmap.PlacementPolicy - copiesNumbers map[string][]uint32 - defaultCopiesNumbers []uint32 -} - -func (p *placementPolicyMock) DefaultPlacementPolicy() netmap.PlacementPolicy { - return p.defaultPolicy -} - -func (p *placementPolicyMock) PlacementPolicy(string) (netmap.PlacementPolicy, bool) { - return netmap.PlacementPolicy{}, false -} - -func (p *placementPolicyMock) CopiesNumbers(locationConstraint string) ([]uint32, bool) { - result, ok := p.copiesNumbers[locationConstraint] - return result, ok -} - -func (p *placementPolicyMock) DefaultCopiesNumbers() []uint32 { - return p.defaultCopiesNumbers -} - -type xmlDecoderProviderMock struct{} - -func (p *xmlDecoderProviderMock) NewCompleteMultipartDecoder(r io.Reader) *xml.Decoder { - return xml.NewDecoder(r) -} - -type kludgeSettingsMock struct { +type configMock struct { + defaultPolicy netmap.PlacementPolicy + copiesNumbers map[string][]uint32 + defaultCopiesNumbers []uint32 bypassContentEncodingInChunks bool } -func (k *kludgeSettingsMock) BypassContentEncodingInChunks() bool { - return k.bypassContentEncodingInChunks +func (c *configMock) DefaultPlacementPolicy() netmap.PlacementPolicy { + return c.defaultPolicy +} + +func (c *configMock) PlacementPolicy(string) (netmap.PlacementPolicy, bool) { + return netmap.PlacementPolicy{}, false +} + +func (c *configMock) CopiesNumbers(locationConstraint string) ([]uint32, bool) { + result, ok := c.copiesNumbers[locationConstraint] + return result, ok +} + +func (c *configMock) DefaultCopiesNumbers() []uint32 { + return c.defaultCopiesNumbers +} + +func (c *configMock) NewCompleteMultipartDecoder(r io.Reader) *xml.Decoder { + return xml.NewDecoder(r) +} + +func (c *configMock) BypassContentEncodingInChunks() bool { + return c.bypassContentEncodingInChunks +} + +func (c *configMock) DefaultMaxAge() int { + return 0 +} + +func (c *configMock) NotificatorEnabled() bool { + return false +} + +func (c *configMock) ResolveZoneList() []string { + return []string{} +} + +func (c *configMock) IsResolveListAllow() bool { + return false +} + +func (c *configMock) CompleteMultipartKeepalive() time.Duration { + return time.Duration(0) } func prepareHandlerContext(t *testing.T) *handlerContext { @@ -139,16 +154,13 @@ func prepareHandlerContextBase(t *testing.T, minCache bool) *handlerContext { err = pp.DecodeString("REP 1") require.NoError(t, err) - kludge := &kludgeSettingsMock{} - + cfg := &configMock{ + defaultPolicy: pp, + } h := &handler{ log: l, obj: layer.NewLayer(l, tp, layerCfg), - cfg: &Config{ - Policy: &placementPolicyMock{defaultPolicy: pp}, - XMLDecoder: &xmlDecoderProviderMock{}, - Kludge: kludge, - }, + cfg: cfg, } return &handlerContext{ @@ -158,7 +170,7 @@ func prepareHandlerContextBase(t *testing.T, minCache bool) *handlerContext { tp: tp, tree: treeMock, context: middleware.SetBoxData(context.Background(), newTestAccessBox(t, key)), - kludge: kludge, + config: cfg, layerFeatures: features, } diff --git a/api/handler/head.go b/api/handler/head.go index a2f1f14..1561791 100644 --- a/api/handler/head.go +++ b/api/handler/head.go @@ -135,7 +135,7 @@ func (h *handler) HeadBucketHandler(w http.ResponseWriter, r *http.Request) { w.Header().Set(api.ContainerID, bktInfo.CID.EncodeToString()) w.Header().Set(api.AmzBucketRegion, bktInfo.LocationConstraint) - if isAvailableToResolve(bktInfo.Zone, h.cfg.ResolveZoneList, h.cfg.IsResolveListAllow) { + if isAvailableToResolve(bktInfo.Zone, h.cfg.ResolveZoneList(), h.cfg.IsResolveListAllow()) { w.Header().Set(api.ContainerName, bktInfo.Name) w.Header().Set(api.ContainerZone, bktInfo.Zone) } diff --git a/api/handler/multipart_upload.go b/api/handler/multipart_upload.go index 04ec8b2..ce237bc 100644 --- a/api/handler/multipart_upload.go +++ b/api/handler/multipart_upload.go @@ -406,7 +406,7 @@ func (h *handler) CompleteMultipartUploadHandler(w http.ResponseWriter, r *http. ) reqBody := new(CompleteMultipartUpload) - if err = h.cfg.XMLDecoder.NewCompleteMultipartDecoder(r.Body).Decode(reqBody); err != nil { + if err = h.cfg.NewCompleteMultipartDecoder(r.Body).Decode(reqBody); err != nil { h.logAndSendError(w, "could not read complete multipart upload xml", reqInfo, errors.GetAPIError(errors.ErrMalformedXML), additional...) return @@ -424,7 +424,7 @@ func (h *handler) CompleteMultipartUploadHandler(w http.ResponseWriter, r *http. // Next operations might take some time, so we want to keep client's // connection alive. To do so, gateway sends periodic white spaces // back to the client the same way as Amazon S3 service does. - stopPeriodicResponseWriter := periodicXMLWriter(w, h.cfg.CompleteMultipartKeepalive) + stopPeriodicResponseWriter := periodicXMLWriter(w, h.cfg.CompleteMultipartKeepalive()) // Start complete multipart upload which may take some time to fetch object // and re-upload it part by part. diff --git a/api/handler/notifications.go b/api/handler/notifications.go index d52d0c8..e0e7b49 100644 --- a/api/handler/notifications.go +++ b/api/handler/notifications.go @@ -155,7 +155,7 @@ func (h *handler) GetBucketNotificationHandler(w http.ResponseWriter, r *http.Re } func (h *handler) sendNotifications(ctx context.Context, p *SendNotificationParams) error { - if !h.cfg.NotificatorEnabled { + if !h.cfg.NotificatorEnabled() { return nil } @@ -198,7 +198,7 @@ func (h *handler) checkBucketConfiguration(ctx context.Context, conf *data.Notif return } - if h.cfg.NotificatorEnabled { + if h.cfg.NotificatorEnabled() { if err = h.notificator.SendTestNotification(q.QueueArn, r.BucketName, r.RequestID, r.Host, layer.TimeNow(ctx)); err != nil { return } diff --git a/api/handler/put.go b/api/handler/put.go index 5c02078..ae3f804 100644 --- a/api/handler/put.go +++ b/api/handler/put.go @@ -348,7 +348,7 @@ func (h *handler) getBodyReader(r *http.Request) (io.ReadCloser, error) { } r.Header.Set(api.ContentEncoding, strings.Join(resultContentEncoding, ",")) - if !chunkedEncoding && !h.cfg.Kludge.BypassContentEncodingInChunks() { + if !chunkedEncoding && !h.cfg.BypassContentEncodingInChunks() { return nil, fmt.Errorf("%w: request is not chunk encoded, encodings '%s'", errors.GetAPIError(errors.ErrInvalidEncodingMethod), strings.Join(encodings, ",")) } @@ -797,7 +797,7 @@ func (h *handler) CreateBucketHandler(w http.ResponseWriter, r *http.Request) { } func (h handler) setPolicy(prm *layer.CreateBucketParams, locationConstraint string, userPolicies []*accessbox.ContainerPolicy) error { - prm.Policy = h.cfg.Policy.DefaultPlacementPolicy() + prm.Policy = h.cfg.DefaultPlacementPolicy() prm.LocationConstraint = locationConstraint if locationConstraint == "" { @@ -811,7 +811,7 @@ func (h handler) setPolicy(prm *layer.CreateBucketParams, locationConstraint str } } - if policy, ok := h.cfg.Policy.PlacementPolicy(locationConstraint); ok { + if policy, ok := h.cfg.PlacementPolicy(locationConstraint); ok { prm.Policy = policy return nil } diff --git a/api/handler/put_test.go b/api/handler/put_test.go index 1e21047..30ddf5c 100644 --- a/api/handler/put_test.go +++ b/api/handler/put_test.go @@ -230,7 +230,7 @@ func TestPutChunkedTestContentEncoding(t *testing.T) { hc.Handler().PutObjectHandler(w, req) assertS3Error(t, w, s3errors.GetAPIError(s3errors.ErrInvalidEncodingMethod)) - hc.kludge.bypassContentEncodingInChunks = true + hc.config.bypassContentEncodingInChunks = true w, req, _ = getChunkedRequest(hc.context, t, bktName, objName) req.Header.Set(api.ContentEncoding, "gzip") hc.Handler().PutObjectHandler(w, req) diff --git a/cmd/s3-gw/app.go b/cmd/s3-gw/app.go index a84ae16..daf5338 100644 --- a/cmd/s3-gw/app.go +++ b/cmd/s3-gw/app.go @@ -3,13 +3,14 @@ package main import ( "context" "encoding/hex" + "encoding/xml" "fmt" + "io" "net/http" "os" "os/signal" "runtime/debug" "sync" - "sync/atomic" "syscall" "time" @@ -27,7 +28,6 @@ import ( "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/internal/logs" "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/internal/version" "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/internal/wallet" - "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/internal/xml" "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/metrics" "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/pkg/service/tree" "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/netmap" @@ -42,6 +42,8 @@ import ( "google.golang.org/grpc" ) +const awsDefaultNamespace = "http://s3.amazonaws.com/doc/2006-03-01/" + type ( // App is the main application structure. App struct { @@ -67,12 +69,22 @@ type ( } appSettings struct { - logLevel zap.AtomicLevel - policies *placementPolicy - xmlDecoder *xml.DecoderProvider - maxClient maxClientsConfig - bypassContentEncodingInChunks atomic.Bool - clientCut atomic.Bool + logLevel zap.AtomicLevel + maxClient maxClientsConfig + defaultMaxAge int + notificatorEnabled bool + resolveZoneList []string + isResolveListAllow bool // True if ResolveZoneList contains allowed zones + completeMultipartKeepalive time.Duration + + mu sync.RWMutex + defaultPolicy netmap.PlacementPolicy + regionMap map[string]netmap.PlacementPolicy + copiesNumbers map[string][]uint32 + defaultCopiesNumbers []uint32 + defaultXMLNSForCompleteMultipart bool + bypassContentEncodingInChunks bool + clientCut bool } maxClientsConfig struct { @@ -84,14 +96,6 @@ type ( logger *zap.Logger lvl zap.AtomicLevel } - - placementPolicy struct { - mu sync.RWMutex - defaultPolicy netmap.PlacementPolicy - regionMap map[string]netmap.PlacementPolicy - copiesNumbers map[string][]uint32 - defaultCopiesNumbers []uint32 - } ) func newApp(ctx context.Context, log *Logger, v *viper.Viper) *App { @@ -168,32 +172,130 @@ func (a *App) initLayer(ctx context.Context) { func newAppSettings(log *Logger, v *viper.Viper) *appSettings { settings := &appSettings{ - logLevel: log.lvl, - policies: newPlacementPolicy(log.logger, v), - xmlDecoder: xml.NewDecoderProvider(v.GetBool(cfgKludgeUseDefaultXMLNSForCompleteMultipartUpload)), - maxClient: newMaxClients(v), + logLevel: log.lvl, + maxClient: newMaxClients(v), + defaultXMLNSForCompleteMultipart: v.GetBool(cfgKludgeUseDefaultXMLNSForCompleteMultipartUpload), + defaultMaxAge: fetchDefaultMaxAge(v, log.logger), + notificatorEnabled: v.GetBool(cfgEnableNATS), + completeMultipartKeepalive: v.GetDuration(cfgKludgeCompleteMultipartUploadKeepalive), + } + + settings.resolveZoneList = v.GetStringSlice(cfgResolveBucketAllow) + settings.isResolveListAllow = len(settings.resolveZoneList) > 0 + if !settings.isResolveListAllow { + settings.resolveZoneList = v.GetStringSlice(cfgResolveBucketDeny) } settings.setBypassContentEncodingInChunks(v.GetBool(cfgKludgeBypassContentEncodingCheckInChunks)) settings.setClientCut(v.GetBool(cfgClientCut)) + settings.initPlacementPolicy(log.logger, v) return settings } func (s *appSettings) BypassContentEncodingInChunks() bool { - return s.bypassContentEncodingInChunks.Load() + s.mu.RLock() + defer s.mu.RUnlock() + return s.bypassContentEncodingInChunks } func (s *appSettings) setBypassContentEncodingInChunks(bypass bool) { - s.bypassContentEncodingInChunks.Store(bypass) + s.mu.Lock() + s.bypassContentEncodingInChunks = bypass + s.mu.Unlock() } func (s *appSettings) ClientCut() bool { - return s.clientCut.Load() + s.mu.RLock() + defer s.mu.RUnlock() + return s.clientCut } func (s *appSettings) setClientCut(clientCut bool) { - s.clientCut.Store(clientCut) + s.mu.Lock() + s.clientCut = clientCut + s.mu.Unlock() +} + +func (s *appSettings) initPlacementPolicy(l *zap.Logger, v *viper.Viper) { + defaultPolicy := fetchDefaultPolicy(l, v) + regionMap := fetchRegionMappingPolicies(l, v) + defaultCopies := fetchDefaultCopiesNumbers(l, v) + copiesNumbers := fetchCopiesNumbers(l, v) + + s.mu.Lock() + defer s.mu.Unlock() + + s.defaultPolicy = defaultPolicy + s.regionMap = regionMap + s.defaultCopiesNumbers = defaultCopies + s.copiesNumbers = copiesNumbers +} + +func (s *appSettings) DefaultPlacementPolicy() netmap.PlacementPolicy { + s.mu.RLock() + defer s.mu.RUnlock() + return s.defaultPolicy +} + +func (s *appSettings) PlacementPolicy(name string) (netmap.PlacementPolicy, bool) { + s.mu.RLock() + policy, ok := s.regionMap[name] + s.mu.RUnlock() + + return policy, ok +} + +func (s *appSettings) CopiesNumbers(locationConstraint string) ([]uint32, bool) { + s.mu.RLock() + copiesNumbers, ok := s.copiesNumbers[locationConstraint] + s.mu.RUnlock() + + return copiesNumbers, ok +} + +func (s *appSettings) DefaultCopiesNumbers() []uint32 { + s.mu.RLock() + defer s.mu.RUnlock() + return s.defaultCopiesNumbers +} + +func (s *appSettings) NewCompleteMultipartDecoder(r io.Reader) *xml.Decoder { + dec := xml.NewDecoder(r) + + s.mu.RLock() + if s.defaultXMLNSForCompleteMultipart { + dec.DefaultSpace = awsDefaultNamespace + } + s.mu.RUnlock() + + return dec +} + +func (s *appSettings) useDefaultNamespaceForCompleteMultipart(useDefaultNamespace bool) { + s.mu.Lock() + s.defaultXMLNSForCompleteMultipart = useDefaultNamespace + s.mu.Unlock() +} + +func (s *appSettings) DefaultMaxAge() int { + return s.defaultMaxAge +} + +func (s *appSettings) NotificatorEnabled() bool { + return s.notificatorEnabled +} + +func (s *appSettings) ResolveZoneList() []string { + return s.resolveZoneList +} + +func (s *appSettings) IsResolveListAllow() bool { + return s.isResolveListAllow +} + +func (s *appSettings) CompleteMultipartKeepalive() time.Duration { + return s.completeMultipartKeepalive } func (a *App) initAPI(ctx context.Context) { @@ -348,55 +450,6 @@ func getPools(ctx context.Context, logger *zap.Logger, cfg *viper.Viper) (*pool. return p, treePool, key } -func newPlacementPolicy(l *zap.Logger, v *viper.Viper) *placementPolicy { - var policies placementPolicy - policies.update(l, v) - return &policies -} - -func (p *placementPolicy) DefaultPlacementPolicy() netmap.PlacementPolicy { - p.mu.RLock() - defer p.mu.RUnlock() - return p.defaultPolicy -} - -func (p *placementPolicy) PlacementPolicy(name string) (netmap.PlacementPolicy, bool) { - p.mu.RLock() - policy, ok := p.regionMap[name] - p.mu.RUnlock() - - return policy, ok -} - -func (p *placementPolicy) CopiesNumbers(locationConstraint string) ([]uint32, bool) { - p.mu.RLock() - copiesNumbers, ok := p.copiesNumbers[locationConstraint] - p.mu.RUnlock() - - return copiesNumbers, ok -} - -func (p *placementPolicy) DefaultCopiesNumbers() []uint32 { - p.mu.RLock() - defer p.mu.RUnlock() - return p.defaultCopiesNumbers -} - -func (p *placementPolicy) update(l *zap.Logger, v *viper.Viper) { - defaultPolicy := fetchDefaultPolicy(l, v) - regionMap := fetchRegionMappingPolicies(l, v) - defaultCopies := fetchDefaultCopiesNumbers(l, v) - copiesNumbers := fetchCopiesNumbers(l, v) - - p.mu.Lock() - defer p.mu.Unlock() - - p.defaultPolicy = defaultPolicy - p.regionMap = regionMap - p.defaultCopiesNumbers = defaultCopies - p.copiesNumbers = copiesNumbers -} - func remove(list []string, element string) []string { for i, item := range list { if item == element { @@ -531,9 +584,9 @@ func (a *App) updateSettings() { a.settings.logLevel.SetLevel(lvl) } - a.settings.policies.update(a.log, a.cfg) + a.settings.initPlacementPolicy(a.log, a.cfg) - a.settings.xmlDecoder.UseDefaultNamespaceForCompleteMultipart(a.cfg.GetBool(cfgKludgeUseDefaultXMLNSForCompleteMultipartUpload)) + a.settings.useDefaultNamespaceForCompleteMultipart(a.cfg.GetBool(cfgKludgeUseDefaultXMLNSForCompleteMultipartUpload)) a.settings.setBypassContentEncodingInChunks(a.cfg.GetBool(cfgKludgeBypassContentEncodingCheckInChunks)) a.settings.setClientCut(a.cfg.GetBool(cfgClientCut)) } @@ -664,24 +717,8 @@ func getAccessBoxCacheConfig(v *viper.Viper, l *zap.Logger) *cache.Config { } func (a *App) initHandler() { - cfg := &handler.Config{ - Policy: a.settings.policies, - DefaultMaxAge: fetchDefaultMaxAge(a.cfg, a.log), - NotificatorEnabled: a.cfg.GetBool(cfgEnableNATS), - XMLDecoder: a.settings.xmlDecoder, - } - - cfg.ResolveZoneList = a.cfg.GetStringSlice(cfgResolveBucketAllow) - cfg.IsResolveListAllow = len(cfg.ResolveZoneList) > 0 - if !cfg.IsResolveListAllow { - cfg.ResolveZoneList = a.cfg.GetStringSlice(cfgResolveBucketDeny) - } - - cfg.CompleteMultipartKeepalive = a.cfg.GetDuration(cfgKludgeCompleteMultipartUploadKeepalive) - cfg.Kludge = a.settings - var err error - a.api, err = handler.New(a.log, a.obj, a.nc, cfg) + a.api, err = handler.New(a.log, a.obj, a.nc, a.settings) if err != nil { a.log.Fatal(logs.CouldNotInitializeAPIHandler, zap.Error(err)) } diff --git a/internal/xml/decoder_test.go b/cmd/s3-gw/decoder_test.go similarity index 55% rename from internal/xml/decoder_test.go rename to cmd/s3-gw/decoder_test.go index 9832692..d9022d0 100644 --- a/internal/xml/decoder_test.go +++ b/cmd/s3-gw/decoder_test.go @@ -1,4 +1,4 @@ -package xml +package main import ( "bytes" @@ -35,44 +35,56 @@ func TestDefaultNamespace(t *testing.T) { ` for _, tc := range []struct { - provider *DecoderProvider + settings *appSettings input string err bool }{ { - provider: NewDecoderProvider(false), - input: xmlBodyWithNamespace, - err: false, + settings: &appSettings{ + defaultXMLNSForCompleteMultipart: false, + }, + input: xmlBodyWithNamespace, + err: false, }, { - provider: NewDecoderProvider(false), - input: xmlBody, - err: true, + settings: &appSettings{ + defaultXMLNSForCompleteMultipart: false, + }, + input: xmlBody, + err: true, }, { - provider: NewDecoderProvider(false), - input: xmlBodyWithInvalidNamespace, - err: true, + settings: &appSettings{ + defaultXMLNSForCompleteMultipart: false, + }, + input: xmlBodyWithInvalidNamespace, + err: true, }, { - provider: NewDecoderProvider(true), - input: xmlBodyWithNamespace, - err: false, + settings: &appSettings{ + defaultXMLNSForCompleteMultipart: true, + }, + input: xmlBodyWithNamespace, + err: false, }, { - provider: NewDecoderProvider(true), - input: xmlBody, - err: false, + settings: &appSettings{ + defaultXMLNSForCompleteMultipart: true, + }, + input: xmlBody, + err: false, }, { - provider: NewDecoderProvider(true), - input: xmlBodyWithInvalidNamespace, - err: true, + settings: &appSettings{ + defaultXMLNSForCompleteMultipart: true, + }, + input: xmlBodyWithInvalidNamespace, + err: true, }, } { t.Run("", func(t *testing.T) { model := new(handler.CompleteMultipartUpload) - err := tc.provider.NewCompleteMultipartDecoder(bytes.NewBufferString(tc.input)).Decode(model) + err := tc.settings.NewCompleteMultipartDecoder(bytes.NewBufferString(tc.input)).Decode(model) if tc.err { require.Error(t, err) } else { diff --git a/internal/xml/decoder.go b/internal/xml/decoder.go deleted file mode 100644 index 0066c08..0000000 --- a/internal/xml/decoder.go +++ /dev/null @@ -1,38 +0,0 @@ -package xml - -import ( - "encoding/xml" - "io" - "sync" -) - -const awsDefaultNamespace = "http://s3.amazonaws.com/doc/2006-03-01/" - -type DecoderProvider struct { - mu sync.RWMutex - defaultXMLNSForCompleteMultipart bool -} - -func NewDecoderProvider(defaultNamespace bool) *DecoderProvider { - return &DecoderProvider{ - defaultXMLNSForCompleteMultipart: defaultNamespace, - } -} - -func (d *DecoderProvider) NewCompleteMultipartDecoder(r io.Reader) *xml.Decoder { - dec := xml.NewDecoder(r) - - d.mu.RLock() - if d.defaultXMLNSForCompleteMultipart { - dec.DefaultSpace = awsDefaultNamespace - } - d.mu.RUnlock() - - return dec -} - -func (d *DecoderProvider) UseDefaultNamespaceForCompleteMultipart(useDefaultNamespace bool) { - d.mu.Lock() - d.defaultXMLNSForCompleteMultipart = useDefaultNamespace - d.mu.Unlock() -}