From f3d58e4ef00de8e9a0eeda408fd25e02b21570e8 Mon Sep 17 00:00:00 2001
From: Pavel Pogodaev
Date: Thu, 24 Nov 2022 17:49:21 +0300
Subject: [PATCH] [#228] add support of multiple sockets
Signed-off-by: Pavel Pogodaev
---
app.go | 132 +++++++++++++++++---------------------------
config/config.env | 14 +++--
config/config.yaml | 14 ++++-
integration_test.go | 2 +-
server.go | 124 +++++++++++++++++++++++++++++++++++++++++
settings.go | 64 ++++++++++++++++-----
6 files changed, 247 insertions(+), 103 deletions(-)
create mode 100644 server.go
diff --git a/app.go b/app.go
index 8e72555..6eb0945 100644
--- a/app.go
+++ b/app.go
@@ -3,10 +3,8 @@ package main
import (
"context"
"crypto/ecdsa"
- "crypto/tls"
- "errors"
"fmt"
- "net"
+ "net/http"
"os"
"os/signal"
"strconv"
@@ -45,12 +43,12 @@ type (
metrics *gateMetrics
services []*metrics.Service
settings *appSettings
+ servers []Server
}
appSettings struct {
- Uploader *uploader.Settings
- Downloader *downloader.Settings
- TLSProvider *certProvider
+ Uploader *uploader.Settings
+ Downloader *downloader.Settings
}
// App is an interface for the main gateway function.
@@ -179,9 +177,8 @@ func newApp(ctx context.Context, opt ...Option) App {
func (a *app) initAppSettings() {
a.settings = &appSettings{
- Uploader: &uploader.Settings{},
- Downloader: &downloader.Settings{},
- TLSProvider: &certProvider{Enabled: a.cfg.IsSet(cfgTLSCertificate) || a.cfg.IsSet(cfgTLSKey)},
+ Uploader: &uploader.Settings{},
+ Downloader: &downloader.Settings{},
}
a.updateSettings()
@@ -341,43 +338,6 @@ func (a *app) setHealthStatus() {
a.metrics.SetHealth(1)
}
-type certProvider struct {
- Enabled bool
-
- mu sync.RWMutex
- certPath string
- keyPath string
- cert *tls.Certificate
-}
-
-func (p *certProvider) GetCertificate(*tls.ClientHelloInfo) (*tls.Certificate, error) {
- if !p.Enabled {
- return nil, errors.New("cert provider: disabled")
- }
-
- p.mu.RLock()
- defer p.mu.RUnlock()
- return p.cert, nil
-}
-
-func (p *certProvider) UpdateCert(certPath, keyPath string) error {
- if !p.Enabled {
- return fmt.Errorf("tls disabled")
- }
-
- cert, err := tls.LoadX509KeyPair(certPath, keyPath)
- if err != nil {
- return fmt.Errorf("cannot load TLS key pair from certFile '%s' and keyFile '%s': %w", certPath, keyPath, err)
- }
-
- p.mu.Lock()
- p.certPath = certPath
- p.keyPath = keyPath
- p.cert = &cert
- p.mu.Unlock()
- return nil
-}
-
func (a *app) Serve(ctx context.Context) {
uploadRoutes := uploader.New(ctx, a.AppParams(), a.settings.Uploader)
downloadRoutes := downloader.New(ctx, a.AppParams(), a.settings.Downloader)
@@ -386,38 +346,16 @@ func (a *app) Serve(ctx context.Context) {
a.configureRouter(uploadRoutes, downloadRoutes)
a.startServices()
+ a.initServers(ctx)
- go func() {
- var err error
- defer func() {
- if err != nil {
- a.log.Fatal("could not start server", zap.Error(err))
+ for i := range a.servers {
+ go func(i int) {
+ a.log.Info("starting server", zap.String("address", a.servers[i].Address()))
+ if err := a.webServer.Serve(a.servers[i].Listener()); err != nil && err != http.ErrServerClosed {
+ a.log.Fatal("listen and serve", zap.Error(err))
}
- }()
-
- bind := a.cfg.GetString(cfgListenAddress)
-
- if a.settings.TLSProvider.Enabled {
- if err = a.settings.TLSProvider.UpdateCert(a.cfg.GetString(cfgTLSCertificate), a.cfg.GetString(cfgTLSKey)); err != nil {
- return
- }
-
- var lnConf net.ListenConfig
- var ln net.Listener
- if ln, err = lnConf.Listen(ctx, "tcp4", bind); err != nil {
- return
- }
- lnTLS := tls.NewListener(ln, &tls.Config{
- GetCertificate: a.settings.TLSProvider.GetCertificate,
- })
-
- a.log.Info("running web server (TLS-enabled)", zap.String("address", bind))
- err = a.webServer.Serve(lnTLS)
- } else {
- a.log.Info("running web server", zap.String("address", bind))
- err = a.webServer.ListenAndServe(bind)
- }
- }()
+ }(i)
+ }
sigs := make(chan os.Signal, 1)
signal.Notify(sigs, syscall.SIGHUP)
@@ -460,6 +398,10 @@ func (a *app) configReload() {
a.log.Warn("failed to update resolvers", zap.Error(err))
}
+ if err := a.updateServers(); err != nil {
+ a.log.Warn("failed to reload server parameters", zap.Error(err))
+ }
+
a.stopServices()
a.startServices()
@@ -474,10 +416,6 @@ func (a *app) configReload() {
func (a *app) updateSettings() {
a.settings.Uploader.SetDefaultTimestamp(a.cfg.GetBool(cfgUploaderHeaderEnableDefaultTimestamp))
a.settings.Downloader.SetZipCompression(a.cfg.GetBool(cfgZipCompression))
-
- if err := a.settings.TLSProvider.UpdateCert(a.cfg.GetString(cfgTLSCertificate), a.cfg.GetString(cfgTLSKey)); err != nil {
- a.log.Warn("failed to reload TLS certs", zap.Error(err))
- }
}
func (a *app) startServices() {
@@ -543,3 +481,37 @@ func (a *app) AppParams() *utils.AppParams {
Resolver: a.resolver,
}
}
+
+func (a *app) initServers(ctx context.Context) {
+ serversInfo := fetchServers(a.cfg)
+
+ a.servers = make([]Server, len(serversInfo))
+ for i, serverInfo := range serversInfo {
+ a.log.Info("added server",
+ zap.String("address", serverInfo.Address), zap.Bool("tls enabled", serverInfo.TLS.Enabled),
+ zap.String("tls cert", serverInfo.TLS.CertFile), zap.String("tls key", serverInfo.TLS.KeyFile))
+ a.servers[i] = newServer(ctx, serverInfo, a.log)
+ }
+}
+
+func (a *app) updateServers() error {
+ serversInfo := fetchServers(a.cfg)
+
+ if len(serversInfo) != len(a.servers) {
+ return fmt.Errorf("invalid servers configuration: length mismatch: old '%d', new '%d", len(a.servers), len(serversInfo))
+ }
+
+ for i, serverInfo := range serversInfo {
+ if serverInfo.Address != a.servers[i].Address() {
+ return fmt.Errorf("invalid servers configuration: addresses mismatch: old '%s', new '%s", a.servers[i].Address(), serverInfo.Address)
+ }
+
+ if serverInfo.TLS.Enabled {
+ if err := a.servers[i].UpdateCert(serverInfo.TLS.CertFile, serverInfo.TLS.KeyFile); err != nil {
+ return fmt.Errorf("failed to update tls certs: %w", err)
+ }
+ }
+ }
+
+ return nil
+}
diff --git a/config/config.env b/config/config.env
index c92306a..fbf7327 100644
--- a/config/config.env
+++ b/config/config.env
@@ -17,12 +17,14 @@ HTTP_GW_PROMETHEUS_ADDRESS=localhost:8084
# Log level.
HTTP_GW_LOGGER_LEVEL=debug
-# Address to bind.
-HTTP_GW_LISTEN_ADDRESS=0.0.0.0:443
-# Provide cert to enable TLS.
-HTTP_GW_TLS_CERTIFICATE=/path/to/tls/cert
-# Provide key to enable TLS.
-HTTP_GW_TLS_KEY=/path/to/tls/key
+HTTP_GW_SERVER_0_ADDRESS=0.0.0.0:443
+HTTP_GW_SERVER_0_TLS_ENABLED=false
+HTTP_GW_SERVER_0_TLS_CERT_FILE=/path/to/tls/cert
+HTTP_GW_SERVER_0_TLS_KEY_FILE=/path/to/tls/key
+HTTP_GW_SERVER_1_ADDRESS=0.0.0.0:444
+HTTP_GW_SERVER_1_TLS_ENABLED=true
+HTTP_GW_SERVER_1_TLS_CERT_FILE=/path/to/tls/cert
+HTTP_GW_SERVER_1_TLS_KEY_FILE=/path/to/tls/key
# Nodes configuration.
# This configuration make the gateway use the first node (grpc://s01.neofs.devenv:8080)
diff --git a/config/config.yaml b/config/config.yaml
index 38a9ff0..fc2fada 100644
--- a/config/config.yaml
+++ b/config/config.yaml
@@ -13,9 +13,17 @@ prometheus:
logger:
level: debug # Log level.
-listen_address: 0.0.0.0:443 # Address to bind.
-tls_certificate: /path/to/tls/cert # Provide cert to enable TLS.
-tls_key: /path/to/tls/key # Provide key to enable TLS.
+server:
+ - address: 0.0.0.0:8080
+ tls:
+ enabled: false
+ cert_file: /path/to/cert
+ key_file: /path/to/key
+ - address: 0.0.0.0:8081
+ tls:
+ enabled: false
+ cert_file: /path/to/cert
+ key_file: /path/to/key
# Nodes configuration.
# This configuration make the gateway use the first node (grpc://s01.neofs.devenv:8080)
diff --git a/integration_test.go b/integration_test.go
index b8b1e57..4607ef6 100644
--- a/integration_test.go
+++ b/integration_test.go
@@ -335,7 +335,7 @@ func getDefaultConfig() *viper.Viper {
v.SetDefault(cfgPeers+".0.priority", 1)
v.SetDefault(cfgRPCEndpoint, "http://localhost:30333")
- v.SetDefault(cfgListenAddress, testListenAddress)
+ v.SetDefault("server.0.address", testListenAddress)
return v
}
diff --git a/server.go b/server.go
new file mode 100644
index 0000000..3843081
--- /dev/null
+++ b/server.go
@@ -0,0 +1,124 @@
+package main
+
+import (
+ "context"
+ "crypto/tls"
+ "errors"
+ "fmt"
+ "net"
+ "sync"
+
+ "go.uber.org/zap"
+)
+
+type (
+ ServerInfo struct {
+ Address string
+ TLS ServerTLSInfo
+ }
+
+ ServerTLSInfo struct {
+ Enabled bool
+ CertFile string
+ KeyFile string
+ }
+
+ Server interface {
+ Address() string
+ Listener() net.Listener
+ UpdateCert(certFile, keyFile string) error
+ }
+
+ server struct {
+ address string
+ listener net.Listener
+ tlsProvider *certProvider
+ }
+
+ certProvider struct {
+ Enabled bool
+
+ mu sync.RWMutex
+ certPath string
+ keyPath string
+ cert *tls.Certificate
+ }
+)
+
+func (s *server) Address() string {
+ return s.address
+}
+
+func (s *server) Listener() net.Listener {
+ return s.listener
+}
+
+func (s *server) UpdateCert(certFile, keyFile string) error {
+ return s.tlsProvider.UpdateCert(certFile, keyFile)
+}
+
+func newServer(ctx context.Context, serverInfo ServerInfo, logger *zap.Logger) *server {
+ var lic net.ListenConfig
+ ln, err := lic.Listen(ctx, "tcp", serverInfo.Address)
+ if err != nil {
+ logger.Fatal("could not prepare listener", zap.String("address", serverInfo.Address), zap.Error(err))
+ }
+
+ tlsProvider := &certProvider{
+ Enabled: serverInfo.TLS.Enabled,
+ }
+
+ if serverInfo.TLS.Enabled {
+ if err = tlsProvider.UpdateCert(serverInfo.TLS.CertFile, serverInfo.TLS.KeyFile); err != nil {
+ logger.Fatal("failed to update cert", zap.Error(err))
+ }
+
+ ln = tls.NewListener(ln, &tls.Config{
+ GetCertificate: tlsProvider.GetCertificate,
+ })
+ }
+
+ return &server{
+ address: serverInfo.Address,
+ listener: ln,
+ tlsProvider: tlsProvider,
+ }
+}
+
+func (p *certProvider) GetCertificate(*tls.ClientHelloInfo) (*tls.Certificate, error) {
+ if !p.Enabled {
+ return nil, errors.New("cert provider: disabled")
+ }
+
+ p.mu.RLock()
+ defer p.mu.RUnlock()
+ return p.cert, nil
+}
+
+func (p *certProvider) UpdateCert(certPath, keyPath string) error {
+ if !p.Enabled {
+ return fmt.Errorf("tls disabled")
+ }
+
+ cert, err := tls.LoadX509KeyPair(certPath, keyPath)
+ if err != nil {
+ return fmt.Errorf("cannot load TLS key pair from certFile '%s' and keyFile '%s': %w", certPath, keyPath, err)
+ }
+
+ p.mu.Lock()
+ p.certPath = certPath
+ p.keyPath = keyPath
+ p.cert = &cert
+ p.mu.Unlock()
+ return nil
+}
+
+func (p *certProvider) FilePaths() (string, string) {
+ if !p.Enabled {
+ return "", ""
+ }
+
+ p.mu.RLock()
+ defer p.mu.RUnlock()
+ return p.certPath, p.keyPath
+}
diff --git a/settings.go b/settings.go
index b042ab5..bb841d0 100644
--- a/settings.go
+++ b/settings.go
@@ -27,9 +27,10 @@ const (
defaultPoolErrorThreshold uint32 = 100
- cfgListenAddress = "listen_address"
- cfgTLSCertificate = "tls_certificate"
- cfgTLSKey = "tls_key"
+ cfgServer = "server"
+ cfgTLSEnabled = "tls.enabled"
+ cfgTLSCertFile = "tls.cert_file"
+ cfgTLSKeyFile = "tls.key_file"
// Web.
cfgWebReadBufferSize = "web.read_buffer_size"
@@ -76,13 +77,14 @@ const (
cfgZipCompression = "zip.compression"
// Command line args.
- cmdHelp = "help"
- cmdVersion = "version"
- cmdPprof = "pprof"
- cmdMetrics = "metrics"
- cmdWallet = "wallet"
- cmdAddress = "address"
- cmdConfig = "config"
+ cmdHelp = "help"
+ cmdVersion = "version"
+ cmdPprof = "pprof"
+ cmdMetrics = "metrics"
+ cmdWallet = "wallet"
+ cmdAddress = "address"
+ cmdConfig = "config"
+ cmdListenAddress = "listen_address"
)
var ignore = map[string]struct{}{
@@ -118,9 +120,9 @@ func settings() *viper.Viper {
flags.Duration(cfgReqTimeout, defaultRequestTimeout, "gRPC request timeout")
flags.Duration(cfgRebalance, defaultRebalanceTimer, "gRPC connection rebalance timer")
- flags.String(cfgListenAddress, "0.0.0.0:8082", "address to listen")
- flags.String(cfgTLSCertificate, "", "TLS certificate path")
- flags.String(cfgTLSKey, "", "TLS key path")
+ flags.String(cmdListenAddress, "0.0.0.0:8080", "addresses to listen")
+ flags.String(cfgTLSCertFile, "", "TLS certificate path")
+ flags.String(cfgTLSKeyFile, "", "TLS key path")
peers := flags.StringArrayP(cfgPeers, "p", nil, "NeoFS nodes")
resolveMethods := flags.StringSlice(cfgResolveOrder, []string{resolver.NNSResolver, resolver.DNSResolver}, "set container name resolve order")
@@ -171,10 +173,24 @@ func settings() *viper.Viper {
panic(err)
}
+ if err := v.BindPFlag(cfgServer+".0.address", flags.Lookup(cmdListenAddress)); err != nil {
+ panic(err)
+ }
+ if err := v.BindPFlag(cfgServer+".0."+cfgTLSKeyFile, flags.Lookup(cfgTLSKeyFile)); err != nil {
+ panic(err)
+ }
+ if err := v.BindPFlag(cfgServer+".0."+cfgTLSCertFile, flags.Lookup(cfgTLSCertFile)); err != nil {
+ panic(err)
+ }
+
if err := flags.Parse(os.Args); err != nil {
panic(err)
}
+ if v.IsSet(cfgServer+".0."+cfgTLSKeyFile) && v.IsSet(cfgServer+".0."+cfgTLSCertFile) {
+ v.Set(cfgServer+".0."+cfgTLSEnabled, true)
+ }
+
if resolveMethods != nil {
v.SetDefault(cfgResolveOrder, *resolveMethods)
}
@@ -297,3 +313,25 @@ func getLogLevel(v *viper.Viper) (zapcore.Level, error) {
}
return lvl, nil
}
+
+func fetchServers(v *viper.Viper) []ServerInfo {
+ var servers []ServerInfo
+
+ for i := 0; ; i++ {
+ key := cfgServer + "." + strconv.Itoa(i) + "."
+
+ var serverInfo ServerInfo
+ serverInfo.Address = v.GetString(key + "address")
+ serverInfo.TLS.Enabled = v.GetBool(key + cfgTLSEnabled)
+ serverInfo.TLS.KeyFile = v.GetString(key + cfgTLSKeyFile)
+ serverInfo.TLS.CertFile = v.GetString(key + cfgTLSCertFile)
+
+ if serverInfo.Address == "" {
+ break
+ }
+
+ servers = append(servers, serverInfo)
+ }
+
+ return servers
+}