[#702] Reload resolvers and TLS certs on SIGHUP

Signed-off-by: Denis Kirillov <denis@nspcc.ru>
This commit is contained in:
Denis Kirillov 2022-09-12 16:46:55 +03:00 committed by Alex Vanin
parent 42893ec046
commit 2a41929be3
4 changed files with 239 additions and 104 deletions

View file

@ -56,7 +56,7 @@ func prepareHandlerContext(t *testing.T) *handlerContext {
l := zap.NewExample() l := zap.NewExample()
tp := layer.NewTestNeoFS() tp := layer.NewTestNeoFS()
testResolver := &resolver.BucketResolver{Name: "test_resolver"} testResolver := &resolver.Resolver{Name: "test_resolver"}
testResolver.SetResolveFunc(func(_ context.Context, name string) (cid.ID, error) { testResolver.SetResolveFunc(func(_ context.Context, name string) (cid.ID, error) {
return tp.ContainerID(name) return tp.ContainerID(name)
}) })

View file

@ -18,7 +18,6 @@ import (
"github.com/nspcc-dev/neofs-s3-gw/api/data" "github.com/nspcc-dev/neofs-s3-gw/api/data"
"github.com/nspcc-dev/neofs-s3-gw/api/errors" "github.com/nspcc-dev/neofs-s3-gw/api/errors"
"github.com/nspcc-dev/neofs-s3-gw/api/layer/encryption" "github.com/nspcc-dev/neofs-s3-gw/api/layer/encryption"
"github.com/nspcc-dev/neofs-s3-gw/api/resolver"
"github.com/nspcc-dev/neofs-s3-gw/creds/accessbox" "github.com/nspcc-dev/neofs-s3-gw/creds/accessbox"
"github.com/nspcc-dev/neofs-sdk-go/bearer" "github.com/nspcc-dev/neofs-sdk-go/bearer"
cid "github.com/nspcc-dev/neofs-sdk-go/container/id" cid "github.com/nspcc-dev/neofs-sdk-go/container/id"
@ -42,11 +41,15 @@ type (
MsgHandlerFunc func(context.Context, *nats.Msg) error MsgHandlerFunc func(context.Context, *nats.Msg) error
BucketResolver interface {
Resolve(ctx context.Context, name string) (cid.ID, error)
}
layer struct { layer struct {
neoFS NeoFS neoFS NeoFS
log *zap.Logger log *zap.Logger
anonKey AnonymousKey anonKey AnonymousKey
resolver *resolver.BucketResolver resolver BucketResolver
ncontroller EventListener ncontroller EventListener
listsCache *cache.ObjectsListCache listsCache *cache.ObjectsListCache
objCache *cache.ObjectsCache objCache *cache.ObjectsCache
@ -60,7 +63,7 @@ type (
ChainAddress string ChainAddress string
Caches *CachesConfig Caches *CachesConfig
AnonKey AnonymousKey AnonKey AnonymousKey
Resolver *resolver.BucketResolver Resolver BucketResolver
TreeService TreeService TreeService TreeService
} }

View file

@ -2,7 +2,9 @@ package resolver
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"sync"
cid "github.com/nspcc-dev/neofs-sdk-go/container/id" cid "github.com/nspcc-dev/neofs-sdk-go/container/id"
"github.com/nspcc-dev/neofs-sdk-go/ns" "github.com/nspcc-dev/neofs-sdk-go/ns"
@ -13,6 +15,9 @@ const (
DNSResolver = "dns" DNSResolver = "dns"
) )
// ErrNoResolvers returns when trying to resolve container without any resolver.
var ErrNoResolvers = errors.New("no resolvers")
// NeoFS represents virtual connection to the NeoFS network. // NeoFS represents virtual connection to the NeoFS network.
type NeoFS interface { type NeoFS interface {
// SystemDNS reads system DNS network parameters of the NeoFS. // SystemDNS reads system DNS network parameters of the NeoFS.
@ -28,62 +33,115 @@ type Config struct {
} }
type BucketResolver struct { type BucketResolver struct {
Name string mu sync.RWMutex
resolve func(context.Context, string) (cid.ID, error) resolvers []*Resolver
next *BucketResolver
} }
func (r *BucketResolver) SetResolveFunc(fn func(context.Context, string) (cid.ID, error)) { type Resolver struct {
Name string
resolve func(context.Context, string) (cid.ID, error)
}
func (r *Resolver) SetResolveFunc(fn func(context.Context, string) (cid.ID, error)) {
r.resolve = fn r.resolve = fn
} }
func (r *BucketResolver) Resolve(ctx context.Context, name string) (cid.ID, error) { func (r *Resolver) Resolve(ctx context.Context, name string) (cid.ID, error) {
cnrID, err := r.resolve(ctx, name) return r.resolve(ctx, name)
}
func NewBucketResolver(resolverNames []string, cfg *Config) (*BucketResolver, error) {
resolvers, err := createResolvers(resolverNames, cfg)
if err != nil { if err != nil {
if r.next != nil { return nil, err
return r.next.Resolve(ctx, name)
} }
return cid.ID{}, fmt.Errorf("failed resolve: %w", err)
return &BucketResolver{
resolvers: resolvers,
}, nil
}
func createResolvers(resolverNames []string, cfg *Config) ([]*Resolver, error) {
resolvers := make([]*Resolver, len(resolverNames))
for i, name := range resolverNames {
cnrResolver, err := newResolver(name, cfg)
if err != nil {
return nil, err
}
resolvers[i] = cnrResolver
}
return resolvers, nil
}
func (r *BucketResolver) Resolve(ctx context.Context, bktName string) (cnrID cid.ID, err error) {
r.mu.RLock()
defer r.mu.RUnlock()
for _, resolver := range r.resolvers {
cnrID, resolverErr := resolver.Resolve(ctx, bktName)
if resolverErr != nil {
resolverErr = fmt.Errorf("%s: %w", resolver.Name, resolverErr)
if err == nil {
err = resolverErr
} else {
err = fmt.Errorf("%s: %w", err.Error(), resolverErr)
}
continue
} }
return cnrID, nil return cnrID, nil
}
if err != nil {
return cnrID, err
}
return cnrID, ErrNoResolvers
} }
func NewResolver(order []string, cfg *Config) (*BucketResolver, error) { func (r *BucketResolver) UpdateResolvers(resolverNames []string, cfg *Config) error {
if len(order) == 0 { r.mu.Lock()
return nil, fmt.Errorf("resolving order must not be empty") defer r.mu.Unlock()
if r.equals(resolverNames) {
return nil
} }
bucketResolver, err := newResolver(order[len(order)-1], cfg, nil) resolvers, err := createResolvers(resolverNames, cfg)
if err != nil { if err != nil {
return nil, fmt.Errorf("create resolver: %w", err) return err
} }
for i := len(order) - 2; i >= 0; i-- { r.resolvers = resolvers
resolverName := order[i]
next := bucketResolver
bucketResolver, err = newResolver(resolverName, cfg, next) return nil
if err != nil {
return nil, fmt.Errorf("create resolver: %w", err)
}
}
return bucketResolver, nil
} }
func newResolver(name string, cfg *Config, next *BucketResolver) (*BucketResolver, error) { func (r *BucketResolver) equals(resolverNames []string) bool {
if len(r.resolvers) != len(resolverNames) {
return false
}
for i := 0; i < len(resolverNames); i++ {
if r.resolvers[i].Name != resolverNames[i] {
return false
}
}
return true
}
func newResolver(name string, cfg *Config) (*Resolver, error) {
switch name { switch name {
case DNSResolver: case DNSResolver:
return NewDNSResolver(cfg.NeoFS, next) return NewDNSResolver(cfg.NeoFS)
case NNSResolver: case NNSResolver:
return NewNNSResolver(cfg.RPCAddress, next) return NewNNSResolver(cfg.RPCAddress)
default: default:
return nil, fmt.Errorf("unknown resolver: %s", name) return nil, fmt.Errorf("unknown resolver: %s", name)
} }
} }
func NewDNSResolver(neoFS NeoFS, next *BucketResolver) (*BucketResolver, error) { func NewDNSResolver(neoFS NeoFS) (*Resolver, error) {
if neoFS == nil { if neoFS == nil {
return nil, fmt.Errorf("pool must not be nil for DNS resolver") return nil, fmt.Errorf("pool must not be nil for DNS resolver")
} }
@ -104,15 +162,13 @@ func NewDNSResolver(neoFS NeoFS, next *BucketResolver) (*BucketResolver, error)
return cnrID, nil return cnrID, nil
} }
return &BucketResolver{ return &Resolver{
Name: DNSResolver, Name: DNSResolver,
resolve: resolveFunc, resolve: resolveFunc,
next: next,
}, nil }, nil
} }
func NewNNSResolver(address string, next *BucketResolver) (*BucketResolver, error) { func NewNNSResolver(address string) (*Resolver, error) {
if address == "" { if address == "" {
return nil, fmt.Errorf("rpc address must not be empty for NNS resolver") return nil, fmt.Errorf("rpc address must not be empty for NNS resolver")
} }
@ -131,10 +187,8 @@ func NewNNSResolver(address string, next *BucketResolver) (*BucketResolver, erro
return cnrID, nil return cnrID, nil
} }
return &BucketResolver{ return &Resolver{
Name: NNSResolver, Name: NNSResolver,
resolve: resolveFunc, resolve: resolveFunc,
next: next,
}, nil }, nil
} }

View file

@ -2,7 +2,9 @@ package main
import ( import (
"context" "context"
"crypto/tls"
"encoding/hex" "encoding/hex"
"errors"
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
@ -36,11 +38,13 @@ type (
ctr auth.Center ctr auth.Center
log *zap.Logger log *zap.Logger
cfg *viper.Viper cfg *viper.Viper
tls *tlsConfig pool *pool.Pool
obj layer.Client obj layer.Client
api api.Handler api api.Handler
metrics *appMetrics metrics *appMetrics
bucketResolver *resolver.BucketResolver
tlsProvider *certProvider
maxClients api.MaxClients maxClients api.MaxClients
@ -60,9 +64,13 @@ type (
lvl zap.AtomicLevel lvl zap.AtomicLevel
} }
tlsConfig struct { certProvider struct {
KeyFile string Enabled bool
CertFile string
mu sync.RWMutex
certPath string
keyPath string
cert *tls.Certificate
} }
appMetrics struct { appMetrics struct {
@ -84,7 +92,6 @@ func newApp(ctx context.Context, log *Logger, v *viper.Viper) *App {
var ( var (
key *keys.PrivateKey key *keys.PrivateKey
err error err error
tls *tlsConfig
caller api.Handler caller api.Handler
ctr auth.Center ctr auth.Center
obj layer.Client obj layer.Client
@ -130,15 +137,7 @@ func newApp(ctx context.Context, log *Logger, v *viper.Viper) *App {
l.Fatal("could not load NeoFS private key", zap.Error(err)) l.Fatal("could not load NeoFS private key", zap.Error(err))
} }
if v.IsSet(cfgTLSKeyFile) && v.IsSet(cfgTLSCertFile) { l.Info("using credentials", zap.String("NeoFS", hex.EncodeToString(key.PublicKey().Bytes())))
tls = &tlsConfig{
KeyFile: v.GetString(cfgTLSKeyFile),
CertFile: v.GetString(cfgTLSCertFile),
}
}
l.Info("using credentials",
zap.String("NeoFS", hex.EncodeToString(key.PublicKey().Bytes())))
prmPool.SetKey(&key.PrivateKey) prmPool.SetKey(&key.PrivateKey)
prmPool.SetNodeDialTimeout(conTimeout) prmPool.SetNodeDialTimeout(conTimeout)
@ -175,7 +174,7 @@ func newApp(ctx context.Context, log *Logger, v *viper.Viper) *App {
l.Warn(fmt.Sprintf("resolver '%s' won't be used since '%s' isn't provided", resolver.NNSResolver, cfgRPCEndpoint)) l.Warn(fmt.Sprintf("resolver '%s' won't be used since '%s' isn't provided", resolver.NNSResolver, cfgRPCEndpoint))
} }
bucketResolver, err := resolver.NewResolver(order, resolveCfg) bucketResolver, err := resolver.NewBucketResolver(order, resolveCfg)
if err != nil { if err != nil {
l.Fatal("failed to form resolver", zap.Error(err)) l.Fatal("failed to form resolver", zap.Error(err))
} }
@ -223,8 +222,8 @@ func newApp(ctx context.Context, log *Logger, v *viper.Viper) *App {
ctr: ctr, ctr: ctr,
log: l, log: l,
cfg: v, cfg: v,
pool: conns,
obj: obj, obj: obj,
tls: tls,
api: caller, api: caller,
webDone: make(chan struct{}, 1), webDone: make(chan struct{}, 1),
@ -233,16 +232,51 @@ func newApp(ctx context.Context, log *Logger, v *viper.Viper) *App {
maxClients: api.NewMaxClientsMiddleware(maxClientsCount, maxClientsDeadline), maxClients: api.NewMaxClientsMiddleware(maxClientsCount, maxClientsDeadline),
} }
app.initMetrics(neofs.NewPoolStatistic(conns)) app.initMetrics()
app.initResolver()
app.initTLSProvider()
return app return app
} }
func (a *App) initMetrics(scraper StatisticScraper) { func (a *App) initMetrics() {
gateMetricsProvider := newGateMetrics(scraper) gateMetricsProvider := newGateMetrics(neofs.NewPoolStatistic(a.pool))
a.metrics = newAppMetrics(a.log, gateMetricsProvider, a.cfg.GetBool(cfgPrometheusEnabled)) a.metrics = newAppMetrics(a.log, gateMetricsProvider, a.cfg.GetBool(cfgPrometheusEnabled))
} }
func (a *App) initResolver() {
var err error
a.bucketResolver, err = resolver.NewBucketResolver(a.getResolverConfig())
if err != nil {
a.log.Fatal("failed to create resolver", zap.Error(err))
}
}
func (a *App) initTLSProvider() {
a.tlsProvider = &certProvider{
Enabled: a.cfg.IsSet(cfgTLSCertFile) || a.cfg.IsSet(cfgTLSKeyFile),
}
}
func (a *App) getResolverConfig() ([]string, *resolver.Config) {
resolveCfg := &resolver.Config{
NeoFS: neofs.NewResolverNeoFS(a.pool),
RPCAddress: a.cfg.GetString(cfgRPCEndpoint),
}
order := a.cfg.GetStringSlice(cfgResolveOrder)
if resolveCfg.RPCAddress == "" {
order = remove(order, resolver.NNSResolver)
a.log.Warn(fmt.Sprintf("resolver '%s' won't be used since '%s' isn't provided", resolver.NNSResolver, cfgRPCEndpoint))
}
if len(order) == 0 {
a.log.Info("container resolver will be disabled because of resolvers 'resolver_order' is empty")
}
return order, resolveCfg
}
func newAppMetrics(logger *zap.Logger, provider GateMetricsCollector, enabled bool) *appMetrics { func newAppMetrics(logger *zap.Logger, provider GateMetricsCollector, enabled bool) *appMetrics {
if !enabled { if !enabled {
logger.Warn("metrics are disabled") logger.Warn("metrics are disabled")
@ -284,6 +318,44 @@ func (m *appMetrics) Shutdown() {
m.mu.Unlock() m.mu.Unlock()
} }
func (p *certProvider) GetCertificate(*tls.ClientHelloInfo) (*tls.Certificate, error) {
if !p.Enabled {
return nil, errors.New("cert provider: disabled")
}
p.mu.RLock()
defer p.mu.RUnlock()
return p.cert, nil
}
func (p *certProvider) UpdateCert(certPath, keyPath string) error {
if !p.Enabled {
return fmt.Errorf("tls disabled")
}
cert, err := tls.LoadX509KeyPair(certPath, keyPath)
if err != nil {
return fmt.Errorf("cannot load TLS key pair from certFile '%s' and keyFile '%s': %w", certPath, keyPath, err)
}
p.mu.Lock()
p.certPath = certPath
p.keyPath = keyPath
p.cert = &cert
p.mu.Unlock()
return nil
}
func (p *certProvider) FilePaths() (string, string) {
if !p.Enabled {
return "", ""
}
p.mu.RLock()
defer p.mu.RUnlock()
return p.certPath, p.keyPath
}
func remove(list []string, element string) []string { func remove(list []string, element string) []string {
for i, item := range list { for i, item := range list {
if item == element { if item == element {
@ -317,50 +389,48 @@ func (a *App) setHealthStatus() {
// Serve runs HTTP server to handle S3 API requests. // Serve runs HTTP server to handle S3 API requests.
func (a *App) Serve(ctx context.Context) { func (a *App) Serve(ctx context.Context) {
var (
err error
lis net.Listener
lic net.ListenConfig
srv = new(http.Server)
addr = a.cfg.GetString(cfgListenAddress)
)
if lis, err = lic.Listen(ctx, "tcp", addr); err != nil {
a.log.Fatal("could not prepare listener",
zap.Error(err))
}
router := mux.NewRouter().SkipClean(true).UseEncodedPath()
// Attach S3 API: // Attach S3 API:
domains := a.cfg.GetStringSlice(cfgListenDomains) domains := a.cfg.GetStringSlice(cfgListenDomains)
a.log.Info("fetch domains, prepare to use API", a.log.Info("fetch domains, prepare to use API", zap.Strings("domains", domains))
zap.Strings("domains", domains)) router := mux.NewRouter().SkipClean(true).UseEncodedPath()
api.Attach(router, domains, a.maxClients, a.api, a.ctr, a.log) api.Attach(router, domains, a.maxClients, a.api, a.ctr, a.log)
// Use mux.Router as http.Handler // Use mux.Router as http.Handler
srv := new(http.Server)
srv.Handler = router srv.Handler = router
srv.ErrorLog = zap.NewStdLog(a.log) srv.ErrorLog = zap.NewStdLog(a.log)
a.startServices() a.startServices()
go func() { go func() {
a.log.Info("starting server", addr := a.cfg.GetString(cfgListenAddress)
zap.String("bind", addr)) a.log.Info("starting server", zap.String("bind", addr))
switch a.tls { var lic net.ListenConfig
case nil: ln, err := lic.Listen(ctx, "tcp", addr)
if err = srv.Serve(lis); err != nil && err != http.ErrServerClosed { if err != nil {
a.log.Fatal("listen and serve", a.log.Fatal("could not prepare listener", zap.Error(err))
zap.Error(err))
} }
default:
a.log.Info("using certificate",
zap.String("key", a.tls.KeyFile),
zap.String("cert", a.tls.CertFile))
if err = srv.ServeTLS(lis, a.tls.CertFile, a.tls.KeyFile); err != nil && err != http.ErrServerClosed { if a.tlsProvider.Enabled {
a.log.Fatal("listen and serve", certFile := a.cfg.GetString(cfgTLSCertFile)
zap.Error(err)) keyFile := a.cfg.GetString(cfgTLSKeyFile)
a.log.Info("using certificate", zap.String("cert", certFile), zap.String("key", keyFile))
if err = a.tlsProvider.UpdateCert(certFile, keyFile); err != nil {
a.log.Fatal("failed to update cert", zap.Error(err))
}
lnTLS := tls.NewListener(ln, &tls.Config{
GetCertificate: a.tlsProvider.GetCertificate,
})
if err = srv.ServeTLS(lnTLS, certFile, keyFile); err != nil && err != http.ErrServerClosed {
a.log.Fatal("listen and serve", zap.Error(err))
}
} else {
if err = srv.Serve(ln); err != nil && err != http.ErrServerClosed {
a.log.Fatal("listen and serve", zap.Error(err))
} }
} }
}() }()
@ -405,6 +475,14 @@ func (a *App) configReload() {
return return
} }
if err := a.bucketResolver.UpdateResolvers(a.getResolverConfig()); err != nil {
a.log.Warn("failed to reload resolvers", zap.Error(err))
}
if err := a.tlsProvider.UpdateCert(a.cfg.GetString(cfgTLSCertFile), a.cfg.GetString(cfgTLSKeyFile)); err != nil {
a.log.Warn("failed to reload TLS certs", zap.Error(err))
}
a.stopServices() a.stopServices()
a.startServices() a.startServices()