[#742] Add multiple listeners

Signed-off-by: Denis Kirillov <denis@nspcc.ru>
This commit is contained in:
Denis Kirillov 2022-11-09 13:07:18 +03:00 committed by Alex Vanin
parent 556374e3b0
commit dd4f66712c
15 changed files with 296 additions and 135 deletions

View file

@ -28,7 +28,6 @@ type (
Policy PlacementPolicy
DefaultMaxAge int
NotificatorEnabled bool
TLSEnabled bool
CopiesNumber uint32
}

View file

@ -94,7 +94,7 @@ func (h *handler) GetObjectAttributesHandler(w http.ResponseWriter, r *http.Requ
}
info := extendedInfo.ObjectInfo
encryptionParams, err := h.formEncryptionParams(r.Header)
encryptionParams, err := formEncryptionParams(r)
if err != nil {
h.logAndSendError(w, "invalid sse headers", reqInfo, err)
return

View file

@ -142,7 +142,7 @@ func (h *handler) CopyObjectHandler(w http.ResponseWriter, r *http.Request) {
}
}
encryptionParams, err := h.formEncryptionParams(r.Header)
encryptionParams, err := formEncryptionParams(r)
if err != nil {
h.logAndSendError(w, "invalid sse headers", reqInfo, err)
return

View file

@ -3,6 +3,7 @@ package handler
import (
"bytes"
"crypto/rand"
"crypto/tls"
"fmt"
"io"
"net/http"
@ -284,6 +285,7 @@ func getEncryptedObjectRange(t *testing.T, tc *handlerContext, bktName, objName
}
func setEncryptHeaders(r *http.Request) {
r.TLS = &tls.ConnectionState{}
r.Header.Set(api.AmzServerSideEncryptionCustomerAlgorithm, layer.AESEncryptionAlgorithm)
r.Header.Set(api.AmzServerSideEncryptionCustomerKey, aes256Key)
r.Header.Set(api.AmzServerSideEncryptionCustomerKeyMD5, aes256KeyMD5)

View file

@ -150,7 +150,7 @@ func (h *handler) GetObjectHandler(w http.ResponseWriter, r *http.Request) {
return
}
encryptionParams, err := h.formEncryptionParams(r.Header)
encryptionParams, err := formEncryptionParams(r)
if err != nil {
h.logAndSendError(w, "invalid sse headers", reqInfo, err)
return

View file

@ -93,7 +93,6 @@ func prepareHandlerContext(t *testing.T) *handlerContext {
log: l,
obj: layer.NewLayer(l, tp, layerCfg),
cfg: &Config{
TLSEnabled: true,
Policy: &placementPolicyMock{defaultPolicy: pp},
},
}

View file

@ -53,7 +53,7 @@ func (h *handler) HeadObjectHandler(w http.ResponseWriter, r *http.Request) {
}
info := extendedInfo.ObjectInfo
encryptionParams, err := h.formEncryptionParams(r.Header)
encryptionParams, err := formEncryptionParams(r)
if err != nil {
h.logAndSendError(w, "invalid sse headers", reqInfo, err)
return

View file

@ -137,7 +137,7 @@ func (h *handler) CreateMultipartUploadHandler(w http.ResponseWriter, r *http.Re
}
}
p.Info.Encryption, err = h.formEncryptionParams(r.Header)
p.Info.Encryption, err = formEncryptionParams(r)
if err != nil {
h.logAndSendError(w, "invalid sse headers", reqInfo, err)
return
@ -226,7 +226,7 @@ func (h *handler) UploadPartHandler(w http.ResponseWriter, r *http.Request) {
Reader: r.Body,
}
p.Info.Encryption, err = h.formEncryptionParams(r.Header)
p.Info.Encryption, err = formEncryptionParams(r)
if err != nil {
h.logAndSendError(w, "invalid sse headers", reqInfo, err)
return
@ -331,7 +331,7 @@ func (h *handler) UploadPartCopy(w http.ResponseWriter, r *http.Request) {
Range: srcRange,
}
p.Info.Encryption, err = h.formEncryptionParams(r.Header)
p.Info.Encryption, err = formEncryptionParams(r)
if err != nil {
h.logAndSendError(w, "invalid sse headers", reqInfo, err)
return
@ -573,7 +573,7 @@ func (h *handler) ListPartsHandler(w http.ResponseWriter, r *http.Request) {
PartNumberMarker: partNumberMarker,
}
p.Info.Encryption, err = h.formEncryptionParams(r.Header)
p.Info.Encryption, err = formEncryptionParams(r)
if err != nil {
h.logAndSendError(w, "invalid sse headers", reqInfo, err)
return
@ -608,7 +608,7 @@ func (h *handler) AbortMultipartUploadHandler(w http.ResponseWriter, r *http.Req
Key: reqInfo.ObjectName,
}
p.Encryption, err = h.formEncryptionParams(r.Header)
p.Encryption, err = formEncryptionParams(r)
if err != nil {
h.logAndSendError(w, "invalid sse headers", reqInfo, err)
return

View file

@ -218,7 +218,7 @@ func (h *handler) PutObjectHandler(w http.ResponseWriter, r *http.Request) {
return
}
encryption, err := h.formEncryptionParams(r.Header)
encryptionParams, err := formEncryptionParams(r)
if err != nil {
h.logAndSendError(w, "invalid sse headers", reqInfo, err)
return
@ -230,7 +230,7 @@ func (h *handler) PutObjectHandler(w http.ResponseWriter, r *http.Request) {
Reader: r.Body,
Size: r.ContentLength,
Header: metadata,
Encryption: encryption,
Encryption: encryptionParams,
CopiesNumber: copiesNumber,
}
@ -304,7 +304,7 @@ func (h *handler) PutObjectHandler(w http.ResponseWriter, r *http.Request) {
if settings.VersioningEnabled() {
w.Header().Set(api.AmzVersionID, objInfo.VersionID())
}
if encryption.Enabled() {
if encryptionParams.Enabled() {
addSSECHeaders(w.Header(), r.Header)
}
@ -326,16 +326,16 @@ func getCopiesNumberOrDefault(metadata map[string]string, defaultCopiesNumber ui
return uint32(copiesNumber), nil
}
func (h handler) formEncryptionParams(header http.Header) (enc encryption.Params, err error) {
sseCustomerAlgorithm := header.Get(api.AmzServerSideEncryptionCustomerAlgorithm)
sseCustomerKey := header.Get(api.AmzServerSideEncryptionCustomerKey)
sseCustomerKeyMD5 := header.Get(api.AmzServerSideEncryptionCustomerKeyMD5)
func formEncryptionParams(r *http.Request) (enc encryption.Params, err error) {
sseCustomerAlgorithm := r.Header.Get(api.AmzServerSideEncryptionCustomerAlgorithm)
sseCustomerKey := r.Header.Get(api.AmzServerSideEncryptionCustomerKey)
sseCustomerKeyMD5 := r.Header.Get(api.AmzServerSideEncryptionCustomerKeyMD5)
if len(sseCustomerAlgorithm) == 0 && len(sseCustomerKey) == 0 && len(sseCustomerKeyMD5) == 0 {
return
}
if !h.cfg.TLSEnabled {
if r.TLS == nil {
return enc, errorsStd.New("encryption available only when TLS is enabled")
}

View file

@ -161,6 +161,7 @@ func logErrorResponse(l *zap.Logger) mux.MiddlewareFunc {
l.Info("call method",
zap.Int("status", lw.statusCode),
zap.String("host", r.Host),
zap.String("request_id", GetRequestID(r.Context())),
zap.String("method", mux.CurrentRoute(r).GetName()),
zap.String("bucket", reqInfo.BucketName),

View file

@ -2,12 +2,9 @@ package main
import (
"context"
"crypto/tls"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"net"
"net/http"
"os"
"os/signal"
@ -46,9 +43,10 @@ type (
obj layer.Client
api api.Handler
servers []Server
metrics *appMetrics
bucketResolver *resolver.BucketResolver
tlsProvider *certProvider
services []*Service
settings *appSettings
maxClients api.MaxClients
@ -67,15 +65,6 @@ type (
lvl zap.AtomicLevel
}
certProvider struct {
Enabled bool
mu sync.RWMutex
certPath string
keyPath string
cert *tls.Certificate
}
appMetrics struct {
logger *zap.Logger
provider GateMetricsCollector
@ -123,7 +112,7 @@ func newApp(ctx context.Context, log *Logger, v *viper.Viper) *App {
func (a *App) init(ctx context.Context) {
a.initAPI(ctx)
a.initMetrics()
a.initTLSProvider()
a.initServers(ctx)
}
func (a *App) initLayer(ctx context.Context) {
@ -206,12 +195,6 @@ func (a *App) initResolver() {
}
}
func (a *App) initTLSProvider() {
a.tlsProvider = &certProvider{
Enabled: a.cfg.IsSet(cfgTLSCertFile) || a.cfg.IsSet(cfgTLSKeyFile),
}
}
func (a *App) getResolverConfig() ([]string, *resolver.Config) {
resolveCfg := &resolver.Config{
NeoFS: neofs.NewResolverNeoFS(a.pool),
@ -401,44 +384,6 @@ func (m *appMetrics) Shutdown() {
m.mu.Unlock()
}
func (p *certProvider) GetCertificate(*tls.ClientHelloInfo) (*tls.Certificate, error) {
if !p.Enabled {
return nil, errors.New("cert provider: disabled")
}
p.mu.RLock()
defer p.mu.RUnlock()
return p.cert, nil
}
func (p *certProvider) UpdateCert(certPath, keyPath string) error {
if !p.Enabled {
return fmt.Errorf("tls disabled")
}
cert, err := tls.LoadX509KeyPair(certPath, keyPath)
if err != nil {
return fmt.Errorf("cannot load TLS key pair from certFile '%s' and keyFile '%s': %w", certPath, keyPath, err)
}
p.mu.Lock()
p.certPath = certPath
p.keyPath = keyPath
p.cert = &cert
p.mu.Unlock()
return nil
}
func (p *certProvider) FilePaths() (string, string) {
if !p.Enabled {
return "", ""
}
p.mu.RLock()
defer p.mu.RUnlock()
return p.certPath, p.keyPath
}
func remove(list []string, element string) []string {
for i, item := range list {
if item == element {
@ -485,34 +430,15 @@ func (a *App) Serve(ctx context.Context) {
a.startServices()
go func() {
addr := a.cfg.GetString(cfgListenAddress)
a.log.Info("starting server", zap.String("bind", addr))
for i := range a.servers {
go func(i int) {
a.log.Info("starting server", zap.String("address", a.servers[i].Address()))
var lic net.ListenConfig
ln, err := lic.Listen(ctx, "tcp", addr)
if err != nil {
a.log.Fatal("could not prepare listener", zap.Error(err))
}
if a.tlsProvider.Enabled {
certFile := a.cfg.GetString(cfgTLSCertFile)
keyFile := a.cfg.GetString(cfgTLSKeyFile)
a.log.Info("using certificate", zap.String("cert", certFile), zap.String("key", keyFile))
if err = a.tlsProvider.UpdateCert(certFile, keyFile); err != nil {
a.log.Fatal("failed to update cert", zap.Error(err))
}
ln = tls.NewListener(ln, &tls.Config{
GetCertificate: a.tlsProvider.GetCertificate,
})
}
if err = srv.Serve(ln); err != nil && err != http.ErrServerClosed {
if err := srv.Serve(a.servers[i].Listener()); err != nil && err != http.ErrServerClosed {
a.log.Fatal("listen and serve", zap.Error(err))
}
}()
}(i)
}
sigs := make(chan os.Signal, 1)
signal.Notify(sigs, syscall.SIGHUP)
@ -558,8 +484,8 @@ func (a *App) configReload() {
a.log.Warn("failed to reload resolvers", zap.Error(err))
}
if err := a.tlsProvider.UpdateCert(a.cfg.GetString(cfgTLSCertFile), a.cfg.GetString(cfgTLSKeyFile)); err != nil {
a.log.Warn("failed to reload TLS certs", zap.Error(err))
if err := a.updateServers(); err != nil {
a.log.Warn("failed to reload server parameters", zap.Error(err))
}
a.stopServices()
@ -586,6 +512,8 @@ func (a *App) updateSettings() {
}
func (a *App) startServices() {
a.services = a.services[:0]
pprofService := NewPprofService(a.cfg, a.log)
a.services = append(a.services, pprofService)
go pprofService.Start()
@ -595,6 +523,40 @@ func (a *App) startServices() {
go prometheusService.Start()
}
func (a *App) initServers(ctx context.Context) {
serversInfo := fetchServers(a.cfg)
a.servers = make([]Server, len(serversInfo))
for i, serverInfo := range serversInfo {
a.log.Info("added server",
zap.String("address", serverInfo.Address), zap.Bool("tls enabled", serverInfo.TLS.Enabled),
zap.String("tls cert", serverInfo.TLS.CertFile), zap.String("tls key", serverInfo.TLS.KeyFile))
a.servers[i] = newServer(ctx, serverInfo, a.log)
}
}
func (a *App) updateServers() error {
serversInfo := fetchServers(a.cfg)
if len(serversInfo) != len(a.servers) {
return fmt.Errorf("invalid servers configuration: amount mismatch: old '%d', new '%d", len(a.servers), len(serversInfo))
}
for i, serverInfo := range serversInfo {
if serverInfo.Address != a.servers[i].Address() {
return fmt.Errorf("invalid servers configuration: addresses mismatch: old '%s', new '%s", a.servers[i].Address(), serverInfo.Address)
}
if serverInfo.TLS.Enabled {
if err := a.servers[i].UpdateCert(serverInfo.TLS.CertFile, serverInfo.TLS.KeyFile); err != nil {
return fmt.Errorf("failed to update tls certs: %w", err)
}
}
}
return nil
}
func (a *App) stopServices() {
ctx, cancel := shutdownContext()
defer cancel()
@ -690,7 +652,6 @@ func (a *App) initHandler() {
Policy: a.settings.policies,
DefaultMaxAge: handler.DefaultMaxAge,
NotificatorEnabled: a.cfg.GetBool(cfgEnableNATS),
TLSEnabled: a.cfg.IsSet(cfgTLSKeyFile) && a.cfg.IsSet(cfgTLSCertFile),
CopiesNumber: handler.DefaultCopiesNumber,
}

View file

@ -42,7 +42,9 @@ const ( // Settings.
cmdWallet = "wallet"
cmdAddress = "address"
// HTTPS/TLS.
// Server.
cfgServer = "server"
cfgTLSEnabled = "tls.enabled"
cfgTLSKeyFile = "tls.key_file"
cfgTLSCertFile = "tls.cert_file"
@ -94,7 +96,6 @@ const ( // Settings.
cfgPProfEnabled = "pprof.enabled"
cfgPProfAddress = "pprof.address"
cfgListenAddress = "listen_address"
cfgListenDomains = "listen_domains"
// Peers.
@ -118,6 +119,8 @@ const ( // Settings.
cmdPProf = "pprof"
cmdMetrics = "metrics"
cmdListenAddress = "listen_address"
// Configuration of parameters of requests to NeoFS.
// Number of the object copies to consider PUT to NeoFS successful.
cfgSetCopiesNumber = "neofs.set_copies_number"
@ -167,6 +170,28 @@ func fetchPeers(l *zap.Logger, v *viper.Viper) []pool.NodeParam {
return nodes
}
func fetchServers(v *viper.Viper) []ServerInfo {
var servers []ServerInfo
for i := 0; ; i++ {
key := cfgServer + "." + strconv.Itoa(i) + "."
var serverInfo ServerInfo
serverInfo.Address = v.GetString(key + "address")
serverInfo.TLS.Enabled = v.GetBool(key + cfgTLSEnabled)
serverInfo.TLS.KeyFile = v.GetString(key + cfgTLSKeyFile)
serverInfo.TLS.CertFile = v.GetString(key + cfgTLSCertFile)
if serverInfo.Address == "" {
break
}
servers = append(servers, serverInfo)
}
return servers
}
func newSettings() *viper.Viper {
v := viper.New()
@ -198,7 +223,7 @@ func newSettings() *viper.Viper {
flags.Int(cfgMaxClientsCount, defaultMaxClientsCount, "set max-clients count")
flags.Duration(cfgMaxClientsDeadline, defaultMaxClientsDeadline, "set max-clients deadline")
flags.String(cfgListenAddress, "0.0.0.0:8080", "set address to listen")
flags.String(cmdListenAddress, "0.0.0.0:8080", "set address to listen")
flags.String(cfgTLSCertFile, "", "TLS certificate file to use")
flags.String(cfgTLSKeyFile, "", "TLS key file to use")
@ -221,29 +246,19 @@ func newSettings() *viper.Viper {
v.SetDefault(cfgPProfAddress, "localhost:8085")
v.SetDefault(cfgPrometheusAddress, "localhost:8086")
// Binding flags
if err := v.BindPFlag(cfgPProfEnabled, flags.Lookup(cmdPProf)); err != nil {
panic(err)
}
if err := v.BindPFlag(cfgPrometheusEnabled, flags.Lookup(cmdMetrics)); err != nil {
panic(err)
}
if err := v.BindPFlags(flags); err != nil {
panic(err)
}
if err := v.BindPFlag(cfgWalletPath, flags.Lookup(cmdWallet)); err != nil {
panic(err)
}
if err := v.BindPFlag(cfgWalletAddress, flags.Lookup(cmdAddress)); err != nil {
panic(err)
// Bind flags
if err := bindFlags(v, flags); err != nil {
panic(fmt.Errorf("bind flags: %w", err))
}
if err := flags.Parse(os.Args); err != nil {
panic(err)
}
if v.IsSet(cfgServer+".0."+cfgTLSKeyFile) && v.IsSet(cfgServer+".0."+cfgTLSCertFile) {
v.Set(cfgServer+".0."+cfgTLSEnabled, true)
}
if resolveMethods != nil {
v.SetDefault(cfgResolveOrder, *resolveMethods)
}
@ -252,6 +267,7 @@ func newSettings() *viper.Viper {
for i := range *peers {
v.SetDefault(cfgPeers+"."+strconv.Itoa(i)+".address", (*peers)[i])
v.SetDefault(cfgPeers+"."+strconv.Itoa(i)+".weight", 1)
v.SetDefault(cfgPeers+"."+strconv.Itoa(i)+".priority", 1)
}
}
@ -306,6 +322,54 @@ func newSettings() *viper.Viper {
return v
}
func bindFlags(v *viper.Viper, flags *pflag.FlagSet) error {
if err := v.BindPFlag(cfgPProfEnabled, flags.Lookup(cmdPProf)); err != nil {
return err
}
if err := v.BindPFlag(cfgPrometheusEnabled, flags.Lookup(cmdMetrics)); err != nil {
return err
}
if err := v.BindPFlag(cmdConfig, flags.Lookup(cmdConfig)); err != nil {
return err
}
if err := v.BindPFlag(cfgWalletPath, flags.Lookup(cmdWallet)); err != nil {
return err
}
if err := v.BindPFlag(cfgWalletAddress, flags.Lookup(cmdAddress)); err != nil {
return err
}
if err := v.BindPFlag(cfgHealthcheckTimeout, flags.Lookup(cfgHealthcheckTimeout)); err != nil {
return err
}
if err := v.BindPFlag(cfgConnectTimeout, flags.Lookup(cfgConnectTimeout)); err != nil {
return err
}
if err := v.BindPFlag(cfgRebalanceInterval, flags.Lookup(cfgRebalanceInterval)); err != nil {
return err
}
if err := v.BindPFlag(cfgMaxClientsCount, flags.Lookup(cfgMaxClientsCount)); err != nil {
return err
}
if err := v.BindPFlag(cfgMaxClientsDeadline, flags.Lookup(cfgMaxClientsDeadline)); err != nil {
return err
}
if err := v.BindPFlag(cfgRPCEndpoint, flags.Lookup(cfgRPCEndpoint)); err != nil {
return err
}
if err := v.BindPFlag(cfgServer+".0.address", flags.Lookup(cmdListenAddress)); err != nil {
return err
}
if err := v.BindPFlag(cfgServer+".0."+cfgTLSKeyFile, flags.Lookup(cfgTLSKeyFile)); err != nil {
return err
}
if err := v.BindPFlag(cfgServer+".0."+cfgTLSCertFile, flags.Lookup(cfgTLSCertFile)); err != nil {
return err
}
return nil
}
func readConfig(v *viper.Viper) error {
cfgFileName := v.GetString(cmdConfig)
cfgFile, err := os.Open(cfgFileName)

124
cmd/s3-gw/server.go Normal file
View file

@ -0,0 +1,124 @@
package main
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"sync"
"go.uber.org/zap"
)
type (
ServerInfo struct {
Address string
TLS ServerTLSInfo
}
ServerTLSInfo struct {
Enabled bool
CertFile string
KeyFile string
}
Server interface {
Address() string
Listener() net.Listener
UpdateCert(certFile, keyFile string) error
}
server struct {
address string
listener net.Listener
tlsProvider *certProvider
}
certProvider struct {
Enabled bool
mu sync.RWMutex
certPath string
keyPath string
cert *tls.Certificate
}
)
func (s *server) Address() string {
return s.address
}
func (s *server) Listener() net.Listener {
return s.listener
}
func (s *server) UpdateCert(certFile, keyFile string) error {
return s.tlsProvider.UpdateCert(certFile, keyFile)
}
func newServer(ctx context.Context, serverInfo ServerInfo, logger *zap.Logger) *server {
var lic net.ListenConfig
ln, err := lic.Listen(ctx, "tcp", serverInfo.Address)
if err != nil {
logger.Fatal("could not prepare listener", zap.String("address", serverInfo.Address), zap.Error(err))
}
tlsProvider := &certProvider{
Enabled: serverInfo.TLS.Enabled,
}
if serverInfo.TLS.Enabled {
if err = tlsProvider.UpdateCert(serverInfo.TLS.CertFile, serverInfo.TLS.KeyFile); err != nil {
logger.Fatal("failed to update cert", zap.Error(err))
}
ln = tls.NewListener(ln, &tls.Config{
GetCertificate: tlsProvider.GetCertificate,
})
}
return &server{
address: serverInfo.Address,
listener: ln,
tlsProvider: tlsProvider,
}
}
func (p *certProvider) GetCertificate(*tls.ClientHelloInfo) (*tls.Certificate, error) {
if !p.Enabled {
return nil, errors.New("cert provider: disabled")
}
p.mu.RLock()
defer p.mu.RUnlock()
return p.cert, nil
}
func (p *certProvider) UpdateCert(certPath, keyPath string) error {
if !p.Enabled {
return fmt.Errorf("tls disabled")
}
cert, err := tls.LoadX509KeyPair(certPath, keyPath)
if err != nil {
return fmt.Errorf("cannot load TLS key pair from certFile '%s' and keyFile '%s': %w", certPath, keyPath, err)
}
p.mu.Lock()
p.certPath = certPath
p.keyPath = keyPath
p.cert = &cert
p.mu.Unlock()
return nil
}
func (p *certProvider) FilePaths() (string, string) {
if !p.Enabled {
return "", ""
}
p.mu.RLock()
defer p.mu.RUnlock()
return p.certPath, p.keyPath
}

View file

@ -24,9 +24,14 @@ S3_GW_PEERS_2_PRIORITY=2
S3_GW_PEERS_2_WEIGHT=0.9
# Address to listen and TLS
S3_GW_LISTEN_ADDRESS=0.0.0.0:8080
S3_GW_TLS_CERT_FILE=/path/to/tls/cert
S3_GW_TLS_KEY_FILE=/path/to/tls/key
S3_GW_SERVER_0_ADDRESS=0.0.0.0:8080
S3_GW_SERVER_0_TLS_ENABLED=false
S3_GW_SERVER_0_TLS_CERT_FILE=/path/to/tls/cert
S3_GW_SERVER_0_TLS_KEY_FILE=/path/to/tls/key
S3_GW_SERVER_1_ADDRESS=0.0.0.0:8081
S3_GW_SERVER_1_TLS_ENABLED=true
S3_GW_SERVER_1_TLS_CERT_FILE=/path/to/tls/cert
S3_GW_SERVER_1_TLS_KEY_FILE=/path/to/tls/key
# Domains to be able to use virtual-hosted-style access to bucket.
S3_GW_LISTEN_DOMAINS=s3dev.neofs.devenv

View file

@ -25,9 +25,15 @@ peers:
priority: 2
weight: 0.9
# Address to listen and TLS
listen_address: 0.0.0.0:8084
server:
- address: 0.0.0.0:8080
tls:
enabled: false
cert_file: /path/to/cert
key_file: /path/to/key
- address: 0.0.0.0:8081
tls:
enabled: true
cert_file: /path/to/cert
key_file: /path/to/key