diff --git a/app.go b/app.go index 8e72555..6eb0945 100644 --- a/app.go +++ b/app.go @@ -3,10 +3,8 @@ package main import ( "context" "crypto/ecdsa" - "crypto/tls" - "errors" "fmt" - "net" + "net/http" "os" "os/signal" "strconv" @@ -45,12 +43,12 @@ type ( metrics *gateMetrics services []*metrics.Service settings *appSettings + servers []Server } appSettings struct { - Uploader *uploader.Settings - Downloader *downloader.Settings - TLSProvider *certProvider + Uploader *uploader.Settings + Downloader *downloader.Settings } // App is an interface for the main gateway function. @@ -179,9 +177,8 @@ func newApp(ctx context.Context, opt ...Option) App { func (a *app) initAppSettings() { a.settings = &appSettings{ - Uploader: &uploader.Settings{}, - Downloader: &downloader.Settings{}, - TLSProvider: &certProvider{Enabled: a.cfg.IsSet(cfgTLSCertificate) || a.cfg.IsSet(cfgTLSKey)}, + Uploader: &uploader.Settings{}, + Downloader: &downloader.Settings{}, } a.updateSettings() @@ -341,43 +338,6 @@ func (a *app) setHealthStatus() { a.metrics.SetHealth(1) } -type certProvider struct { - Enabled bool - - mu sync.RWMutex - certPath string - keyPath string - cert *tls.Certificate -} - -func (p *certProvider) GetCertificate(*tls.ClientHelloInfo) (*tls.Certificate, error) { - if !p.Enabled { - return nil, errors.New("cert provider: disabled") - } - - p.mu.RLock() - defer p.mu.RUnlock() - return p.cert, nil -} - -func (p *certProvider) UpdateCert(certPath, keyPath string) error { - if !p.Enabled { - return fmt.Errorf("tls disabled") - } - - cert, err := tls.LoadX509KeyPair(certPath, keyPath) - if err != nil { - return fmt.Errorf("cannot load TLS key pair from certFile '%s' and keyFile '%s': %w", certPath, keyPath, err) - } - - p.mu.Lock() - p.certPath = certPath - p.keyPath = keyPath - p.cert = &cert - p.mu.Unlock() - return nil -} - func (a *app) Serve(ctx context.Context) { uploadRoutes := uploader.New(ctx, a.AppParams(), a.settings.Uploader) downloadRoutes := downloader.New(ctx, a.AppParams(), a.settings.Downloader) @@ -386,38 +346,16 @@ func (a *app) Serve(ctx context.Context) { a.configureRouter(uploadRoutes, downloadRoutes) a.startServices() + a.initServers(ctx) - go func() { - var err error - defer func() { - if err != nil { - a.log.Fatal("could not start server", zap.Error(err)) + for i := range a.servers { + go func(i int) { + a.log.Info("starting server", zap.String("address", a.servers[i].Address())) + if err := a.webServer.Serve(a.servers[i].Listener()); err != nil && err != http.ErrServerClosed { + a.log.Fatal("listen and serve", zap.Error(err)) } - }() - - bind := a.cfg.GetString(cfgListenAddress) - - if a.settings.TLSProvider.Enabled { - if err = a.settings.TLSProvider.UpdateCert(a.cfg.GetString(cfgTLSCertificate), a.cfg.GetString(cfgTLSKey)); err != nil { - return - } - - var lnConf net.ListenConfig - var ln net.Listener - if ln, err = lnConf.Listen(ctx, "tcp4", bind); err != nil { - return - } - lnTLS := tls.NewListener(ln, &tls.Config{ - GetCertificate: a.settings.TLSProvider.GetCertificate, - }) - - a.log.Info("running web server (TLS-enabled)", zap.String("address", bind)) - err = a.webServer.Serve(lnTLS) - } else { - a.log.Info("running web server", zap.String("address", bind)) - err = a.webServer.ListenAndServe(bind) - } - }() + }(i) + } sigs := make(chan os.Signal, 1) signal.Notify(sigs, syscall.SIGHUP) @@ -460,6 +398,10 @@ func (a *app) configReload() { a.log.Warn("failed to update resolvers", zap.Error(err)) } + if err := a.updateServers(); err != nil { + a.log.Warn("failed to reload server parameters", zap.Error(err)) + } + a.stopServices() a.startServices() @@ -474,10 +416,6 @@ func (a *app) configReload() { func (a *app) updateSettings() { a.settings.Uploader.SetDefaultTimestamp(a.cfg.GetBool(cfgUploaderHeaderEnableDefaultTimestamp)) a.settings.Downloader.SetZipCompression(a.cfg.GetBool(cfgZipCompression)) - - if err := a.settings.TLSProvider.UpdateCert(a.cfg.GetString(cfgTLSCertificate), a.cfg.GetString(cfgTLSKey)); err != nil { - a.log.Warn("failed to reload TLS certs", zap.Error(err)) - } } func (a *app) startServices() { @@ -543,3 +481,37 @@ func (a *app) AppParams() *utils.AppParams { Resolver: a.resolver, } } + +func (a *app) initServers(ctx context.Context) { + serversInfo := fetchServers(a.cfg) + + a.servers = make([]Server, len(serversInfo)) + for i, serverInfo := range serversInfo { + a.log.Info("added server", + zap.String("address", serverInfo.Address), zap.Bool("tls enabled", serverInfo.TLS.Enabled), + zap.String("tls cert", serverInfo.TLS.CertFile), zap.String("tls key", serverInfo.TLS.KeyFile)) + a.servers[i] = newServer(ctx, serverInfo, a.log) + } +} + +func (a *app) updateServers() error { + serversInfo := fetchServers(a.cfg) + + if len(serversInfo) != len(a.servers) { + return fmt.Errorf("invalid servers configuration: length mismatch: old '%d', new '%d", len(a.servers), len(serversInfo)) + } + + for i, serverInfo := range serversInfo { + if serverInfo.Address != a.servers[i].Address() { + return fmt.Errorf("invalid servers configuration: addresses mismatch: old '%s', new '%s", a.servers[i].Address(), serverInfo.Address) + } + + if serverInfo.TLS.Enabled { + if err := a.servers[i].UpdateCert(serverInfo.TLS.CertFile, serverInfo.TLS.KeyFile); err != nil { + return fmt.Errorf("failed to update tls certs: %w", err) + } + } + } + + return nil +} diff --git a/config/config.env b/config/config.env index c92306a..fbf7327 100644 --- a/config/config.env +++ b/config/config.env @@ -17,12 +17,14 @@ HTTP_GW_PROMETHEUS_ADDRESS=localhost:8084 # Log level. HTTP_GW_LOGGER_LEVEL=debug -# Address to bind. -HTTP_GW_LISTEN_ADDRESS=0.0.0.0:443 -# Provide cert to enable TLS. -HTTP_GW_TLS_CERTIFICATE=/path/to/tls/cert -# Provide key to enable TLS. -HTTP_GW_TLS_KEY=/path/to/tls/key +HTTP_GW_SERVER_0_ADDRESS=0.0.0.0:443 +HTTP_GW_SERVER_0_TLS_ENABLED=false +HTTP_GW_SERVER_0_TLS_CERT_FILE=/path/to/tls/cert +HTTP_GW_SERVER_0_TLS_KEY_FILE=/path/to/tls/key +HTTP_GW_SERVER_1_ADDRESS=0.0.0.0:444 +HTTP_GW_SERVER_1_TLS_ENABLED=true +HTTP_GW_SERVER_1_TLS_CERT_FILE=/path/to/tls/cert +HTTP_GW_SERVER_1_TLS_KEY_FILE=/path/to/tls/key # Nodes configuration. # This configuration make the gateway use the first node (grpc://s01.neofs.devenv:8080) diff --git a/config/config.yaml b/config/config.yaml index 38a9ff0..fc2fada 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -13,9 +13,17 @@ prometheus: logger: level: debug # Log level. -listen_address: 0.0.0.0:443 # Address to bind. -tls_certificate: /path/to/tls/cert # Provide cert to enable TLS. -tls_key: /path/to/tls/key # Provide key to enable TLS. +server: + - address: 0.0.0.0:8080 + tls: + enabled: false + cert_file: /path/to/cert + key_file: /path/to/key + - address: 0.0.0.0:8081 + tls: + enabled: false + cert_file: /path/to/cert + key_file: /path/to/key # Nodes configuration. # This configuration make the gateway use the first node (grpc://s01.neofs.devenv:8080) diff --git a/integration_test.go b/integration_test.go index b8b1e57..4607ef6 100644 --- a/integration_test.go +++ b/integration_test.go @@ -335,7 +335,7 @@ func getDefaultConfig() *viper.Viper { v.SetDefault(cfgPeers+".0.priority", 1) v.SetDefault(cfgRPCEndpoint, "http://localhost:30333") - v.SetDefault(cfgListenAddress, testListenAddress) + v.SetDefault("server.0.address", testListenAddress) return v } diff --git a/server.go b/server.go new file mode 100644 index 0000000..3843081 --- /dev/null +++ b/server.go @@ -0,0 +1,124 @@ +package main + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "net" + "sync" + + "go.uber.org/zap" +) + +type ( + ServerInfo struct { + Address string + TLS ServerTLSInfo + } + + ServerTLSInfo struct { + Enabled bool + CertFile string + KeyFile string + } + + Server interface { + Address() string + Listener() net.Listener + UpdateCert(certFile, keyFile string) error + } + + server struct { + address string + listener net.Listener + tlsProvider *certProvider + } + + certProvider struct { + Enabled bool + + mu sync.RWMutex + certPath string + keyPath string + cert *tls.Certificate + } +) + +func (s *server) Address() string { + return s.address +} + +func (s *server) Listener() net.Listener { + return s.listener +} + +func (s *server) UpdateCert(certFile, keyFile string) error { + return s.tlsProvider.UpdateCert(certFile, keyFile) +} + +func newServer(ctx context.Context, serverInfo ServerInfo, logger *zap.Logger) *server { + var lic net.ListenConfig + ln, err := lic.Listen(ctx, "tcp", serverInfo.Address) + if err != nil { + logger.Fatal("could not prepare listener", zap.String("address", serverInfo.Address), zap.Error(err)) + } + + tlsProvider := &certProvider{ + Enabled: serverInfo.TLS.Enabled, + } + + if serverInfo.TLS.Enabled { + if err = tlsProvider.UpdateCert(serverInfo.TLS.CertFile, serverInfo.TLS.KeyFile); err != nil { + logger.Fatal("failed to update cert", zap.Error(err)) + } + + ln = tls.NewListener(ln, &tls.Config{ + GetCertificate: tlsProvider.GetCertificate, + }) + } + + return &server{ + address: serverInfo.Address, + listener: ln, + tlsProvider: tlsProvider, + } +} + +func (p *certProvider) GetCertificate(*tls.ClientHelloInfo) (*tls.Certificate, error) { + if !p.Enabled { + return nil, errors.New("cert provider: disabled") + } + + p.mu.RLock() + defer p.mu.RUnlock() + return p.cert, nil +} + +func (p *certProvider) UpdateCert(certPath, keyPath string) error { + if !p.Enabled { + return fmt.Errorf("tls disabled") + } + + cert, err := tls.LoadX509KeyPair(certPath, keyPath) + if err != nil { + return fmt.Errorf("cannot load TLS key pair from certFile '%s' and keyFile '%s': %w", certPath, keyPath, err) + } + + p.mu.Lock() + p.certPath = certPath + p.keyPath = keyPath + p.cert = &cert + p.mu.Unlock() + return nil +} + +func (p *certProvider) FilePaths() (string, string) { + if !p.Enabled { + return "", "" + } + + p.mu.RLock() + defer p.mu.RUnlock() + return p.certPath, p.keyPath +} diff --git a/settings.go b/settings.go index b042ab5..bb841d0 100644 --- a/settings.go +++ b/settings.go @@ -27,9 +27,10 @@ const ( defaultPoolErrorThreshold uint32 = 100 - cfgListenAddress = "listen_address" - cfgTLSCertificate = "tls_certificate" - cfgTLSKey = "tls_key" + cfgServer = "server" + cfgTLSEnabled = "tls.enabled" + cfgTLSCertFile = "tls.cert_file" + cfgTLSKeyFile = "tls.key_file" // Web. cfgWebReadBufferSize = "web.read_buffer_size" @@ -76,13 +77,14 @@ const ( cfgZipCompression = "zip.compression" // Command line args. - cmdHelp = "help" - cmdVersion = "version" - cmdPprof = "pprof" - cmdMetrics = "metrics" - cmdWallet = "wallet" - cmdAddress = "address" - cmdConfig = "config" + cmdHelp = "help" + cmdVersion = "version" + cmdPprof = "pprof" + cmdMetrics = "metrics" + cmdWallet = "wallet" + cmdAddress = "address" + cmdConfig = "config" + cmdListenAddress = "listen_address" ) var ignore = map[string]struct{}{ @@ -118,9 +120,9 @@ func settings() *viper.Viper { flags.Duration(cfgReqTimeout, defaultRequestTimeout, "gRPC request timeout") flags.Duration(cfgRebalance, defaultRebalanceTimer, "gRPC connection rebalance timer") - flags.String(cfgListenAddress, "0.0.0.0:8082", "address to listen") - flags.String(cfgTLSCertificate, "", "TLS certificate path") - flags.String(cfgTLSKey, "", "TLS key path") + flags.String(cmdListenAddress, "0.0.0.0:8080", "addresses to listen") + flags.String(cfgTLSCertFile, "", "TLS certificate path") + flags.String(cfgTLSKeyFile, "", "TLS key path") peers := flags.StringArrayP(cfgPeers, "p", nil, "NeoFS nodes") resolveMethods := flags.StringSlice(cfgResolveOrder, []string{resolver.NNSResolver, resolver.DNSResolver}, "set container name resolve order") @@ -171,10 +173,24 @@ func settings() *viper.Viper { panic(err) } + if err := v.BindPFlag(cfgServer+".0.address", flags.Lookup(cmdListenAddress)); err != nil { + panic(err) + } + if err := v.BindPFlag(cfgServer+".0."+cfgTLSKeyFile, flags.Lookup(cfgTLSKeyFile)); err != nil { + panic(err) + } + if err := v.BindPFlag(cfgServer+".0."+cfgTLSCertFile, flags.Lookup(cfgTLSCertFile)); err != nil { + panic(err) + } + if err := flags.Parse(os.Args); err != nil { panic(err) } + if v.IsSet(cfgServer+".0."+cfgTLSKeyFile) && v.IsSet(cfgServer+".0."+cfgTLSCertFile) { + v.Set(cfgServer+".0."+cfgTLSEnabled, true) + } + if resolveMethods != nil { v.SetDefault(cfgResolveOrder, *resolveMethods) } @@ -297,3 +313,25 @@ func getLogLevel(v *viper.Viper) (zapcore.Level, error) { } return lvl, nil } + +func fetchServers(v *viper.Viper) []ServerInfo { + var servers []ServerInfo + + for i := 0; ; i++ { + key := cfgServer + "." + strconv.Itoa(i) + "." + + var serverInfo ServerInfo + serverInfo.Address = v.GetString(key + "address") + serverInfo.TLS.Enabled = v.GetBool(key + cfgTLSEnabled) + serverInfo.TLS.KeyFile = v.GetString(key + cfgTLSKeyFile) + serverInfo.TLS.CertFile = v.GetString(key + cfgTLSCertFile) + + if serverInfo.Address == "" { + break + } + + servers = append(servers, serverInfo) + } + + return servers +}