From 2a41929be30b46c58d1ffc03b486a7fbefb0ae8e Mon Sep 17 00:00:00 2001 From: Denis Kirillov Date: Mon, 12 Sep 2022 16:46:55 +0300 Subject: [PATCH] [#702] Reload resolvers and TLS certs on SIGHUP Signed-off-by: Denis Kirillov --- api/handler/handlers_test.go | 2 +- api/layer/layer.go | 9 +- api/resolver/resolver.go | 134 +++++++++++++++++------- cmd/s3-gw/app.go | 198 ++++++++++++++++++++++++----------- 4 files changed, 239 insertions(+), 104 deletions(-) diff --git a/api/handler/handlers_test.go b/api/handler/handlers_test.go index 67e3f2530..f6a53bd89 100644 --- a/api/handler/handlers_test.go +++ b/api/handler/handlers_test.go @@ -56,7 +56,7 @@ func prepareHandlerContext(t *testing.T) *handlerContext { l := zap.NewExample() tp := layer.NewTestNeoFS() - testResolver := &resolver.BucketResolver{Name: "test_resolver"} + testResolver := &resolver.Resolver{Name: "test_resolver"} testResolver.SetResolveFunc(func(_ context.Context, name string) (cid.ID, error) { return tp.ContainerID(name) }) diff --git a/api/layer/layer.go b/api/layer/layer.go index 91752a0d4..877a99b84 100644 --- a/api/layer/layer.go +++ b/api/layer/layer.go @@ -18,7 +18,6 @@ import ( "github.com/nspcc-dev/neofs-s3-gw/api/data" "github.com/nspcc-dev/neofs-s3-gw/api/errors" "github.com/nspcc-dev/neofs-s3-gw/api/layer/encryption" - "github.com/nspcc-dev/neofs-s3-gw/api/resolver" "github.com/nspcc-dev/neofs-s3-gw/creds/accessbox" "github.com/nspcc-dev/neofs-sdk-go/bearer" cid "github.com/nspcc-dev/neofs-sdk-go/container/id" @@ -42,11 +41,15 @@ type ( MsgHandlerFunc func(context.Context, *nats.Msg) error + BucketResolver interface { + Resolve(ctx context.Context, name string) (cid.ID, error) + } + layer struct { neoFS NeoFS log *zap.Logger anonKey AnonymousKey - resolver *resolver.BucketResolver + resolver BucketResolver ncontroller EventListener listsCache *cache.ObjectsListCache objCache *cache.ObjectsCache @@ -60,7 +63,7 @@ type ( ChainAddress string Caches *CachesConfig AnonKey AnonymousKey - Resolver *resolver.BucketResolver + Resolver BucketResolver TreeService TreeService } diff --git a/api/resolver/resolver.go b/api/resolver/resolver.go index 91bf14cd7..40d100eb5 100644 --- a/api/resolver/resolver.go +++ b/api/resolver/resolver.go @@ -2,7 +2,9 @@ package resolver import ( "context" + "errors" "fmt" + "sync" cid "github.com/nspcc-dev/neofs-sdk-go/container/id" "github.com/nspcc-dev/neofs-sdk-go/ns" @@ -13,6 +15,9 @@ const ( DNSResolver = "dns" ) +// ErrNoResolvers returns when trying to resolve container without any resolver. +var ErrNoResolvers = errors.New("no resolvers") + // NeoFS represents virtual connection to the NeoFS network. type NeoFS interface { // SystemDNS reads system DNS network parameters of the NeoFS. @@ -28,62 +33,115 @@ type Config struct { } type BucketResolver struct { - Name string - resolve func(context.Context, string) (cid.ID, error) - - next *BucketResolver + mu sync.RWMutex + resolvers []*Resolver } -func (r *BucketResolver) SetResolveFunc(fn func(context.Context, string) (cid.ID, error)) { +type Resolver struct { + Name string + resolve func(context.Context, string) (cid.ID, error) +} + +func (r *Resolver) SetResolveFunc(fn func(context.Context, string) (cid.ID, error)) { r.resolve = fn } -func (r *BucketResolver) Resolve(ctx context.Context, name string) (cid.ID, error) { - cnrID, err := r.resolve(ctx, name) - if err != nil { - if r.next != nil { - return r.next.Resolve(ctx, name) - } - return cid.ID{}, fmt.Errorf("failed resolve: %w", err) - } - return cnrID, nil +func (r *Resolver) Resolve(ctx context.Context, name string) (cid.ID, error) { + return r.resolve(ctx, name) } -func NewResolver(order []string, cfg *Config) (*BucketResolver, error) { - if len(order) == 0 { - return nil, fmt.Errorf("resolving order must not be empty") - } - - bucketResolver, err := newResolver(order[len(order)-1], cfg, nil) +func NewBucketResolver(resolverNames []string, cfg *Config) (*BucketResolver, error) { + resolvers, err := createResolvers(resolverNames, cfg) if err != nil { - return nil, fmt.Errorf("create resolver: %w", err) + return nil, err } - for i := len(order) - 2; i >= 0; i-- { - resolverName := order[i] - next := bucketResolver + return &BucketResolver{ + resolvers: resolvers, + }, nil +} - bucketResolver, err = newResolver(resolverName, cfg, next) +func createResolvers(resolverNames []string, cfg *Config) ([]*Resolver, error) { + resolvers := make([]*Resolver, len(resolverNames)) + for i, name := range resolverNames { + cnrResolver, err := newResolver(name, cfg) if err != nil { - return nil, fmt.Errorf("create resolver: %w", err) + return nil, err } + resolvers[i] = cnrResolver } - return bucketResolver, nil + return resolvers, nil } -func newResolver(name string, cfg *Config, next *BucketResolver) (*BucketResolver, error) { +func (r *BucketResolver) Resolve(ctx context.Context, bktName string) (cnrID cid.ID, err error) { + r.mu.RLock() + defer r.mu.RUnlock() + + for _, resolver := range r.resolvers { + cnrID, resolverErr := resolver.Resolve(ctx, bktName) + if resolverErr != nil { + resolverErr = fmt.Errorf("%s: %w", resolver.Name, resolverErr) + if err == nil { + err = resolverErr + } else { + err = fmt.Errorf("%s: %w", err.Error(), resolverErr) + } + continue + } + return cnrID, nil + } + + if err != nil { + return cnrID, err + } + + return cnrID, ErrNoResolvers +} + +func (r *BucketResolver) UpdateResolvers(resolverNames []string, cfg *Config) error { + r.mu.Lock() + defer r.mu.Unlock() + + if r.equals(resolverNames) { + return nil + } + + resolvers, err := createResolvers(resolverNames, cfg) + if err != nil { + return err + } + + r.resolvers = resolvers + + return nil +} + +func (r *BucketResolver) equals(resolverNames []string) bool { + if len(r.resolvers) != len(resolverNames) { + return false + } + + for i := 0; i < len(resolverNames); i++ { + if r.resolvers[i].Name != resolverNames[i] { + return false + } + } + return true +} + +func newResolver(name string, cfg *Config) (*Resolver, error) { switch name { case DNSResolver: - return NewDNSResolver(cfg.NeoFS, next) + return NewDNSResolver(cfg.NeoFS) case NNSResolver: - return NewNNSResolver(cfg.RPCAddress, next) + return NewNNSResolver(cfg.RPCAddress) default: return nil, fmt.Errorf("unknown resolver: %s", name) } } -func NewDNSResolver(neoFS NeoFS, next *BucketResolver) (*BucketResolver, error) { +func NewDNSResolver(neoFS NeoFS) (*Resolver, error) { if neoFS == nil { return nil, fmt.Errorf("pool must not be nil for DNS resolver") } @@ -104,15 +162,13 @@ func NewDNSResolver(neoFS NeoFS, next *BucketResolver) (*BucketResolver, error) return cnrID, nil } - return &BucketResolver{ - Name: DNSResolver, - + return &Resolver{ + Name: DNSResolver, resolve: resolveFunc, - next: next, }, nil } -func NewNNSResolver(address string, next *BucketResolver) (*BucketResolver, error) { +func NewNNSResolver(address string) (*Resolver, error) { if address == "" { return nil, fmt.Errorf("rpc address must not be empty for NNS resolver") } @@ -131,10 +187,8 @@ func NewNNSResolver(address string, next *BucketResolver) (*BucketResolver, erro return cnrID, nil } - return &BucketResolver{ - Name: NNSResolver, - + return &Resolver{ + Name: NNSResolver, resolve: resolveFunc, - next: next, }, nil } diff --git a/cmd/s3-gw/app.go b/cmd/s3-gw/app.go index 47cd22f43..d8e241332 100644 --- a/cmd/s3-gw/app.go +++ b/cmd/s3-gw/app.go @@ -2,7 +2,9 @@ package main import ( "context" + "crypto/tls" "encoding/hex" + "errors" "fmt" "net" "net/http" @@ -33,14 +35,16 @@ import ( type ( // App is the main application structure. App struct { - ctr auth.Center - log *zap.Logger - cfg *viper.Viper - tls *tlsConfig - obj layer.Client - api api.Handler + ctr auth.Center + log *zap.Logger + cfg *viper.Viper + pool *pool.Pool + obj layer.Client + api api.Handler - metrics *appMetrics + metrics *appMetrics + bucketResolver *resolver.BucketResolver + tlsProvider *certProvider maxClients api.MaxClients @@ -60,9 +64,13 @@ type ( lvl zap.AtomicLevel } - tlsConfig struct { - KeyFile string - CertFile string + certProvider struct { + Enabled bool + + mu sync.RWMutex + certPath string + keyPath string + cert *tls.Certificate } appMetrics struct { @@ -84,7 +92,6 @@ func newApp(ctx context.Context, log *Logger, v *viper.Viper) *App { var ( key *keys.PrivateKey err error - tls *tlsConfig caller api.Handler ctr auth.Center obj layer.Client @@ -130,15 +137,7 @@ func newApp(ctx context.Context, log *Logger, v *viper.Viper) *App { l.Fatal("could not load NeoFS private key", zap.Error(err)) } - if v.IsSet(cfgTLSKeyFile) && v.IsSet(cfgTLSCertFile) { - tls = &tlsConfig{ - KeyFile: v.GetString(cfgTLSKeyFile), - CertFile: v.GetString(cfgTLSCertFile), - } - } - - l.Info("using credentials", - zap.String("NeoFS", hex.EncodeToString(key.PublicKey().Bytes()))) + l.Info("using credentials", zap.String("NeoFS", hex.EncodeToString(key.PublicKey().Bytes()))) prmPool.SetKey(&key.PrivateKey) prmPool.SetNodeDialTimeout(conTimeout) @@ -175,7 +174,7 @@ func newApp(ctx context.Context, log *Logger, v *viper.Viper) *App { l.Warn(fmt.Sprintf("resolver '%s' won't be used since '%s' isn't provided", resolver.NNSResolver, cfgRPCEndpoint)) } - bucketResolver, err := resolver.NewResolver(order, resolveCfg) + bucketResolver, err := resolver.NewBucketResolver(order, resolveCfg) if err != nil { l.Fatal("failed to form resolver", zap.Error(err)) } @@ -220,12 +219,12 @@ func newApp(ctx context.Context, log *Logger, v *viper.Viper) *App { } app := &App{ - ctr: ctr, - log: l, - cfg: v, - obj: obj, - tls: tls, - api: caller, + ctr: ctr, + log: l, + cfg: v, + pool: conns, + obj: obj, + api: caller, webDone: make(chan struct{}, 1), wrkDone: make(chan struct{}, 1), @@ -233,16 +232,51 @@ func newApp(ctx context.Context, log *Logger, v *viper.Viper) *App { maxClients: api.NewMaxClientsMiddleware(maxClientsCount, maxClientsDeadline), } - app.initMetrics(neofs.NewPoolStatistic(conns)) + app.initMetrics() + app.initResolver() + app.initTLSProvider() return app } -func (a *App) initMetrics(scraper StatisticScraper) { - gateMetricsProvider := newGateMetrics(scraper) +func (a *App) initMetrics() { + gateMetricsProvider := newGateMetrics(neofs.NewPoolStatistic(a.pool)) a.metrics = newAppMetrics(a.log, gateMetricsProvider, a.cfg.GetBool(cfgPrometheusEnabled)) } +func (a *App) initResolver() { + var err error + a.bucketResolver, err = resolver.NewBucketResolver(a.getResolverConfig()) + if err != nil { + a.log.Fatal("failed to create resolver", zap.Error(err)) + } +} + +func (a *App) initTLSProvider() { + a.tlsProvider = &certProvider{ + Enabled: a.cfg.IsSet(cfgTLSCertFile) || a.cfg.IsSet(cfgTLSKeyFile), + } +} + +func (a *App) getResolverConfig() ([]string, *resolver.Config) { + resolveCfg := &resolver.Config{ + NeoFS: neofs.NewResolverNeoFS(a.pool), + RPCAddress: a.cfg.GetString(cfgRPCEndpoint), + } + + order := a.cfg.GetStringSlice(cfgResolveOrder) + if resolveCfg.RPCAddress == "" { + order = remove(order, resolver.NNSResolver) + a.log.Warn(fmt.Sprintf("resolver '%s' won't be used since '%s' isn't provided", resolver.NNSResolver, cfgRPCEndpoint)) + } + + if len(order) == 0 { + a.log.Info("container resolver will be disabled because of resolvers 'resolver_order' is empty") + } + + return order, resolveCfg +} + func newAppMetrics(logger *zap.Logger, provider GateMetricsCollector, enabled bool) *appMetrics { if !enabled { logger.Warn("metrics are disabled") @@ -284,6 +318,44 @@ func (m *appMetrics) Shutdown() { m.mu.Unlock() } +func (p *certProvider) GetCertificate(*tls.ClientHelloInfo) (*tls.Certificate, error) { + if !p.Enabled { + return nil, errors.New("cert provider: disabled") + } + + p.mu.RLock() + defer p.mu.RUnlock() + return p.cert, nil +} + +func (p *certProvider) UpdateCert(certPath, keyPath string) error { + if !p.Enabled { + return fmt.Errorf("tls disabled") + } + + cert, err := tls.LoadX509KeyPair(certPath, keyPath) + if err != nil { + return fmt.Errorf("cannot load TLS key pair from certFile '%s' and keyFile '%s': %w", certPath, keyPath, err) + } + + p.mu.Lock() + p.certPath = certPath + p.keyPath = keyPath + p.cert = &cert + p.mu.Unlock() + return nil +} + +func (p *certProvider) FilePaths() (string, string) { + if !p.Enabled { + return "", "" + } + + p.mu.RLock() + defer p.mu.RUnlock() + return p.certPath, p.keyPath +} + func remove(list []string, element string) []string { for i, item := range list { if item == element { @@ -317,50 +389,48 @@ func (a *App) setHealthStatus() { // Serve runs HTTP server to handle S3 API requests. func (a *App) Serve(ctx context.Context) { - var ( - err error - lis net.Listener - lic net.ListenConfig - srv = new(http.Server) - addr = a.cfg.GetString(cfgListenAddress) - ) - - if lis, err = lic.Listen(ctx, "tcp", addr); err != nil { - a.log.Fatal("could not prepare listener", - zap.Error(err)) - } - - router := mux.NewRouter().SkipClean(true).UseEncodedPath() // Attach S3 API: domains := a.cfg.GetStringSlice(cfgListenDomains) - a.log.Info("fetch domains, prepare to use API", - zap.Strings("domains", domains)) + a.log.Info("fetch domains, prepare to use API", zap.Strings("domains", domains)) + router := mux.NewRouter().SkipClean(true).UseEncodedPath() api.Attach(router, domains, a.maxClients, a.api, a.ctr, a.log) // Use mux.Router as http.Handler + srv := new(http.Server) srv.Handler = router srv.ErrorLog = zap.NewStdLog(a.log) a.startServices() go func() { - a.log.Info("starting server", - zap.String("bind", addr)) + addr := a.cfg.GetString(cfgListenAddress) + a.log.Info("starting server", zap.String("bind", addr)) - switch a.tls { - case nil: - if err = srv.Serve(lis); err != nil && err != http.ErrServerClosed { - a.log.Fatal("listen and serve", - zap.Error(err)) + var lic net.ListenConfig + ln, err := lic.Listen(ctx, "tcp", addr) + if err != nil { + a.log.Fatal("could not prepare listener", zap.Error(err)) + } + + if a.tlsProvider.Enabled { + certFile := a.cfg.GetString(cfgTLSCertFile) + keyFile := a.cfg.GetString(cfgTLSKeyFile) + + a.log.Info("using certificate", zap.String("cert", certFile), zap.String("key", keyFile)) + if err = a.tlsProvider.UpdateCert(certFile, keyFile); err != nil { + a.log.Fatal("failed to update cert", zap.Error(err)) } - default: - a.log.Info("using certificate", - zap.String("key", a.tls.KeyFile), - zap.String("cert", a.tls.CertFile)) - if err = srv.ServeTLS(lis, a.tls.CertFile, a.tls.KeyFile); err != nil && err != http.ErrServerClosed { - a.log.Fatal("listen and serve", - zap.Error(err)) + lnTLS := tls.NewListener(ln, &tls.Config{ + GetCertificate: a.tlsProvider.GetCertificate, + }) + + if err = srv.ServeTLS(lnTLS, certFile, keyFile); err != nil && err != http.ErrServerClosed { + a.log.Fatal("listen and serve", zap.Error(err)) + } + } else { + if err = srv.Serve(ln); err != nil && err != http.ErrServerClosed { + a.log.Fatal("listen and serve", zap.Error(err)) } } }() @@ -405,6 +475,14 @@ func (a *App) configReload() { return } + if err := a.bucketResolver.UpdateResolvers(a.getResolverConfig()); err != nil { + a.log.Warn("failed to reload resolvers", zap.Error(err)) + } + + if err := a.tlsProvider.UpdateCert(a.cfg.GetString(cfgTLSCertFile), a.cfg.GetString(cfgTLSKeyFile)); err != nil { + a.log.Warn("failed to reload TLS certs", zap.Error(err)) + } + a.stopServices() a.startServices()