[#742] Add multiple listeners

Signed-off-by: Denis Kirillov <denis@nspcc.ru>
This commit is contained in:
Denis Kirillov 2022-11-09 13:07:18 +03:00 committed by Alex Vanin
parent 556374e3b0
commit dd4f66712c
15 changed files with 296 additions and 135 deletions

View file

@ -28,7 +28,6 @@ type (
Policy PlacementPolicy Policy PlacementPolicy
DefaultMaxAge int DefaultMaxAge int
NotificatorEnabled bool NotificatorEnabled bool
TLSEnabled bool
CopiesNumber uint32 CopiesNumber uint32
} }

View file

@ -94,7 +94,7 @@ func (h *handler) GetObjectAttributesHandler(w http.ResponseWriter, r *http.Requ
} }
info := extendedInfo.ObjectInfo info := extendedInfo.ObjectInfo
encryptionParams, err := h.formEncryptionParams(r.Header) encryptionParams, err := formEncryptionParams(r)
if err != nil { if err != nil {
h.logAndSendError(w, "invalid sse headers", reqInfo, err) h.logAndSendError(w, "invalid sse headers", reqInfo, err)
return return

View file

@ -142,7 +142,7 @@ func (h *handler) CopyObjectHandler(w http.ResponseWriter, r *http.Request) {
} }
} }
encryptionParams, err := h.formEncryptionParams(r.Header) encryptionParams, err := formEncryptionParams(r)
if err != nil { if err != nil {
h.logAndSendError(w, "invalid sse headers", reqInfo, err) h.logAndSendError(w, "invalid sse headers", reqInfo, err)
return return

View file

@ -3,6 +3,7 @@ package handler
import ( import (
"bytes" "bytes"
"crypto/rand" "crypto/rand"
"crypto/tls"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
@ -284,6 +285,7 @@ func getEncryptedObjectRange(t *testing.T, tc *handlerContext, bktName, objName
} }
func setEncryptHeaders(r *http.Request) { func setEncryptHeaders(r *http.Request) {
r.TLS = &tls.ConnectionState{}
r.Header.Set(api.AmzServerSideEncryptionCustomerAlgorithm, layer.AESEncryptionAlgorithm) r.Header.Set(api.AmzServerSideEncryptionCustomerAlgorithm, layer.AESEncryptionAlgorithm)
r.Header.Set(api.AmzServerSideEncryptionCustomerKey, aes256Key) r.Header.Set(api.AmzServerSideEncryptionCustomerKey, aes256Key)
r.Header.Set(api.AmzServerSideEncryptionCustomerKeyMD5, aes256KeyMD5) r.Header.Set(api.AmzServerSideEncryptionCustomerKeyMD5, aes256KeyMD5)

View file

@ -150,7 +150,7 @@ func (h *handler) GetObjectHandler(w http.ResponseWriter, r *http.Request) {
return return
} }
encryptionParams, err := h.formEncryptionParams(r.Header) encryptionParams, err := formEncryptionParams(r)
if err != nil { if err != nil {
h.logAndSendError(w, "invalid sse headers", reqInfo, err) h.logAndSendError(w, "invalid sse headers", reqInfo, err)
return return

View file

@ -93,8 +93,7 @@ func prepareHandlerContext(t *testing.T) *handlerContext {
log: l, log: l,
obj: layer.NewLayer(l, tp, layerCfg), obj: layer.NewLayer(l, tp, layerCfg),
cfg: &Config{ cfg: &Config{
TLSEnabled: true, Policy: &placementPolicyMock{defaultPolicy: pp},
Policy: &placementPolicyMock{defaultPolicy: pp},
}, },
} }

View file

@ -53,7 +53,7 @@ func (h *handler) HeadObjectHandler(w http.ResponseWriter, r *http.Request) {
} }
info := extendedInfo.ObjectInfo info := extendedInfo.ObjectInfo
encryptionParams, err := h.formEncryptionParams(r.Header) encryptionParams, err := formEncryptionParams(r)
if err != nil { if err != nil {
h.logAndSendError(w, "invalid sse headers", reqInfo, err) h.logAndSendError(w, "invalid sse headers", reqInfo, err)
return return

View file

@ -137,7 +137,7 @@ func (h *handler) CreateMultipartUploadHandler(w http.ResponseWriter, r *http.Re
} }
} }
p.Info.Encryption, err = h.formEncryptionParams(r.Header) p.Info.Encryption, err = formEncryptionParams(r)
if err != nil { if err != nil {
h.logAndSendError(w, "invalid sse headers", reqInfo, err) h.logAndSendError(w, "invalid sse headers", reqInfo, err)
return return
@ -226,7 +226,7 @@ func (h *handler) UploadPartHandler(w http.ResponseWriter, r *http.Request) {
Reader: r.Body, Reader: r.Body,
} }
p.Info.Encryption, err = h.formEncryptionParams(r.Header) p.Info.Encryption, err = formEncryptionParams(r)
if err != nil { if err != nil {
h.logAndSendError(w, "invalid sse headers", reqInfo, err) h.logAndSendError(w, "invalid sse headers", reqInfo, err)
return return
@ -331,7 +331,7 @@ func (h *handler) UploadPartCopy(w http.ResponseWriter, r *http.Request) {
Range: srcRange, Range: srcRange,
} }
p.Info.Encryption, err = h.formEncryptionParams(r.Header) p.Info.Encryption, err = formEncryptionParams(r)
if err != nil { if err != nil {
h.logAndSendError(w, "invalid sse headers", reqInfo, err) h.logAndSendError(w, "invalid sse headers", reqInfo, err)
return return
@ -573,7 +573,7 @@ func (h *handler) ListPartsHandler(w http.ResponseWriter, r *http.Request) {
PartNumberMarker: partNumberMarker, PartNumberMarker: partNumberMarker,
} }
p.Info.Encryption, err = h.formEncryptionParams(r.Header) p.Info.Encryption, err = formEncryptionParams(r)
if err != nil { if err != nil {
h.logAndSendError(w, "invalid sse headers", reqInfo, err) h.logAndSendError(w, "invalid sse headers", reqInfo, err)
return return
@ -608,7 +608,7 @@ func (h *handler) AbortMultipartUploadHandler(w http.ResponseWriter, r *http.Req
Key: reqInfo.ObjectName, Key: reqInfo.ObjectName,
} }
p.Encryption, err = h.formEncryptionParams(r.Header) p.Encryption, err = formEncryptionParams(r)
if err != nil { if err != nil {
h.logAndSendError(w, "invalid sse headers", reqInfo, err) h.logAndSendError(w, "invalid sse headers", reqInfo, err)
return return

View file

@ -218,7 +218,7 @@ func (h *handler) PutObjectHandler(w http.ResponseWriter, r *http.Request) {
return return
} }
encryption, err := h.formEncryptionParams(r.Header) encryptionParams, err := formEncryptionParams(r)
if err != nil { if err != nil {
h.logAndSendError(w, "invalid sse headers", reqInfo, err) h.logAndSendError(w, "invalid sse headers", reqInfo, err)
return return
@ -230,7 +230,7 @@ func (h *handler) PutObjectHandler(w http.ResponseWriter, r *http.Request) {
Reader: r.Body, Reader: r.Body,
Size: r.ContentLength, Size: r.ContentLength,
Header: metadata, Header: metadata,
Encryption: encryption, Encryption: encryptionParams,
CopiesNumber: copiesNumber, CopiesNumber: copiesNumber,
} }
@ -304,7 +304,7 @@ func (h *handler) PutObjectHandler(w http.ResponseWriter, r *http.Request) {
if settings.VersioningEnabled() { if settings.VersioningEnabled() {
w.Header().Set(api.AmzVersionID, objInfo.VersionID()) w.Header().Set(api.AmzVersionID, objInfo.VersionID())
} }
if encryption.Enabled() { if encryptionParams.Enabled() {
addSSECHeaders(w.Header(), r.Header) addSSECHeaders(w.Header(), r.Header)
} }
@ -326,16 +326,16 @@ func getCopiesNumberOrDefault(metadata map[string]string, defaultCopiesNumber ui
return uint32(copiesNumber), nil return uint32(copiesNumber), nil
} }
func (h handler) formEncryptionParams(header http.Header) (enc encryption.Params, err error) { func formEncryptionParams(r *http.Request) (enc encryption.Params, err error) {
sseCustomerAlgorithm := header.Get(api.AmzServerSideEncryptionCustomerAlgorithm) sseCustomerAlgorithm := r.Header.Get(api.AmzServerSideEncryptionCustomerAlgorithm)
sseCustomerKey := header.Get(api.AmzServerSideEncryptionCustomerKey) sseCustomerKey := r.Header.Get(api.AmzServerSideEncryptionCustomerKey)
sseCustomerKeyMD5 := header.Get(api.AmzServerSideEncryptionCustomerKeyMD5) sseCustomerKeyMD5 := r.Header.Get(api.AmzServerSideEncryptionCustomerKeyMD5)
if len(sseCustomerAlgorithm) == 0 && len(sseCustomerKey) == 0 && len(sseCustomerKeyMD5) == 0 { if len(sseCustomerAlgorithm) == 0 && len(sseCustomerKey) == 0 && len(sseCustomerKeyMD5) == 0 {
return return
} }
if !h.cfg.TLSEnabled { if r.TLS == nil {
return enc, errorsStd.New("encryption available only when TLS is enabled") return enc, errorsStd.New("encryption available only when TLS is enabled")
} }

View file

@ -161,6 +161,7 @@ func logErrorResponse(l *zap.Logger) mux.MiddlewareFunc {
l.Info("call method", l.Info("call method",
zap.Int("status", lw.statusCode), zap.Int("status", lw.statusCode),
zap.String("host", r.Host),
zap.String("request_id", GetRequestID(r.Context())), zap.String("request_id", GetRequestID(r.Context())),
zap.String("method", mux.CurrentRoute(r).GetName()), zap.String("method", mux.CurrentRoute(r).GetName()),
zap.String("bucket", reqInfo.BucketName), zap.String("bucket", reqInfo.BucketName),

View file

@ -2,12 +2,9 @@ package main
import ( import (
"context" "context"
"crypto/tls"
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"net"
"net/http" "net/http"
"os" "os"
"os/signal" "os/signal"
@ -46,9 +43,10 @@ type (
obj layer.Client obj layer.Client
api api.Handler api api.Handler
servers []Server
metrics *appMetrics metrics *appMetrics
bucketResolver *resolver.BucketResolver bucketResolver *resolver.BucketResolver
tlsProvider *certProvider
services []*Service services []*Service
settings *appSettings settings *appSettings
maxClients api.MaxClients maxClients api.MaxClients
@ -67,15 +65,6 @@ type (
lvl zap.AtomicLevel lvl zap.AtomicLevel
} }
certProvider struct {
Enabled bool
mu sync.RWMutex
certPath string
keyPath string
cert *tls.Certificate
}
appMetrics struct { appMetrics struct {
logger *zap.Logger logger *zap.Logger
provider GateMetricsCollector provider GateMetricsCollector
@ -123,7 +112,7 @@ func newApp(ctx context.Context, log *Logger, v *viper.Viper) *App {
func (a *App) init(ctx context.Context) { func (a *App) init(ctx context.Context) {
a.initAPI(ctx) a.initAPI(ctx)
a.initMetrics() a.initMetrics()
a.initTLSProvider() a.initServers(ctx)
} }
func (a *App) initLayer(ctx context.Context) { func (a *App) initLayer(ctx context.Context) {
@ -206,12 +195,6 @@ func (a *App) initResolver() {
} }
} }
func (a *App) initTLSProvider() {
a.tlsProvider = &certProvider{
Enabled: a.cfg.IsSet(cfgTLSCertFile) || a.cfg.IsSet(cfgTLSKeyFile),
}
}
func (a *App) getResolverConfig() ([]string, *resolver.Config) { func (a *App) getResolverConfig() ([]string, *resolver.Config) {
resolveCfg := &resolver.Config{ resolveCfg := &resolver.Config{
NeoFS: neofs.NewResolverNeoFS(a.pool), NeoFS: neofs.NewResolverNeoFS(a.pool),
@ -401,44 +384,6 @@ func (m *appMetrics) Shutdown() {
m.mu.Unlock() m.mu.Unlock()
} }
func (p *certProvider) GetCertificate(*tls.ClientHelloInfo) (*tls.Certificate, error) {
if !p.Enabled {
return nil, errors.New("cert provider: disabled")
}
p.mu.RLock()
defer p.mu.RUnlock()
return p.cert, nil
}
func (p *certProvider) UpdateCert(certPath, keyPath string) error {
if !p.Enabled {
return fmt.Errorf("tls disabled")
}
cert, err := tls.LoadX509KeyPair(certPath, keyPath)
if err != nil {
return fmt.Errorf("cannot load TLS key pair from certFile '%s' and keyFile '%s': %w", certPath, keyPath, err)
}
p.mu.Lock()
p.certPath = certPath
p.keyPath = keyPath
p.cert = &cert
p.mu.Unlock()
return nil
}
func (p *certProvider) FilePaths() (string, string) {
if !p.Enabled {
return "", ""
}
p.mu.RLock()
defer p.mu.RUnlock()
return p.certPath, p.keyPath
}
func remove(list []string, element string) []string { func remove(list []string, element string) []string {
for i, item := range list { for i, item := range list {
if item == element { if item == element {
@ -485,34 +430,15 @@ func (a *App) Serve(ctx context.Context) {
a.startServices() a.startServices()
go func() { for i := range a.servers {
addr := a.cfg.GetString(cfgListenAddress) go func(i int) {
a.log.Info("starting server", zap.String("bind", addr)) a.log.Info("starting server", zap.String("address", a.servers[i].Address()))
var lic net.ListenConfig if err := srv.Serve(a.servers[i].Listener()); err != nil && err != http.ErrServerClosed {
ln, err := lic.Listen(ctx, "tcp", addr) a.log.Fatal("listen and serve", zap.Error(err))
if err != nil {
a.log.Fatal("could not prepare listener", zap.Error(err))
}
if a.tlsProvider.Enabled {
certFile := a.cfg.GetString(cfgTLSCertFile)
keyFile := a.cfg.GetString(cfgTLSKeyFile)
a.log.Info("using certificate", zap.String("cert", certFile), zap.String("key", keyFile))
if err = a.tlsProvider.UpdateCert(certFile, keyFile); err != nil {
a.log.Fatal("failed to update cert", zap.Error(err))
} }
}(i)
ln = tls.NewListener(ln, &tls.Config{ }
GetCertificate: a.tlsProvider.GetCertificate,
})
}
if err = srv.Serve(ln); err != nil && err != http.ErrServerClosed {
a.log.Fatal("listen and serve", zap.Error(err))
}
}()
sigs := make(chan os.Signal, 1) sigs := make(chan os.Signal, 1)
signal.Notify(sigs, syscall.SIGHUP) signal.Notify(sigs, syscall.SIGHUP)
@ -558,8 +484,8 @@ func (a *App) configReload() {
a.log.Warn("failed to reload resolvers", zap.Error(err)) a.log.Warn("failed to reload resolvers", zap.Error(err))
} }
if err := a.tlsProvider.UpdateCert(a.cfg.GetString(cfgTLSCertFile), a.cfg.GetString(cfgTLSKeyFile)); err != nil { if err := a.updateServers(); err != nil {
a.log.Warn("failed to reload TLS certs", zap.Error(err)) a.log.Warn("failed to reload server parameters", zap.Error(err))
} }
a.stopServices() a.stopServices()
@ -586,6 +512,8 @@ func (a *App) updateSettings() {
} }
func (a *App) startServices() { func (a *App) startServices() {
a.services = a.services[:0]
pprofService := NewPprofService(a.cfg, a.log) pprofService := NewPprofService(a.cfg, a.log)
a.services = append(a.services, pprofService) a.services = append(a.services, pprofService)
go pprofService.Start() go pprofService.Start()
@ -595,6 +523,40 @@ func (a *App) startServices() {
go prometheusService.Start() go prometheusService.Start()
} }
func (a *App) initServers(ctx context.Context) {
serversInfo := fetchServers(a.cfg)
a.servers = make([]Server, len(serversInfo))
for i, serverInfo := range serversInfo {
a.log.Info("added server",
zap.String("address", serverInfo.Address), zap.Bool("tls enabled", serverInfo.TLS.Enabled),
zap.String("tls cert", serverInfo.TLS.CertFile), zap.String("tls key", serverInfo.TLS.KeyFile))
a.servers[i] = newServer(ctx, serverInfo, a.log)
}
}
func (a *App) updateServers() error {
serversInfo := fetchServers(a.cfg)
if len(serversInfo) != len(a.servers) {
return fmt.Errorf("invalid servers configuration: amount mismatch: old '%d', new '%d", len(a.servers), len(serversInfo))
}
for i, serverInfo := range serversInfo {
if serverInfo.Address != a.servers[i].Address() {
return fmt.Errorf("invalid servers configuration: addresses mismatch: old '%s', new '%s", a.servers[i].Address(), serverInfo.Address)
}
if serverInfo.TLS.Enabled {
if err := a.servers[i].UpdateCert(serverInfo.TLS.CertFile, serverInfo.TLS.KeyFile); err != nil {
return fmt.Errorf("failed to update tls certs: %w", err)
}
}
}
return nil
}
func (a *App) stopServices() { func (a *App) stopServices() {
ctx, cancel := shutdownContext() ctx, cancel := shutdownContext()
defer cancel() defer cancel()
@ -690,7 +652,6 @@ func (a *App) initHandler() {
Policy: a.settings.policies, Policy: a.settings.policies,
DefaultMaxAge: handler.DefaultMaxAge, DefaultMaxAge: handler.DefaultMaxAge,
NotificatorEnabled: a.cfg.GetBool(cfgEnableNATS), NotificatorEnabled: a.cfg.GetBool(cfgEnableNATS),
TLSEnabled: a.cfg.IsSet(cfgTLSKeyFile) && a.cfg.IsSet(cfgTLSCertFile),
CopiesNumber: handler.DefaultCopiesNumber, CopiesNumber: handler.DefaultCopiesNumber,
} }

View file

@ -42,7 +42,9 @@ const ( // Settings.
cmdWallet = "wallet" cmdWallet = "wallet"
cmdAddress = "address" cmdAddress = "address"
// HTTPS/TLS. // Server.
cfgServer = "server"
cfgTLSEnabled = "tls.enabled"
cfgTLSKeyFile = "tls.key_file" cfgTLSKeyFile = "tls.key_file"
cfgTLSCertFile = "tls.cert_file" cfgTLSCertFile = "tls.cert_file"
@ -94,7 +96,6 @@ const ( // Settings.
cfgPProfEnabled = "pprof.enabled" cfgPProfEnabled = "pprof.enabled"
cfgPProfAddress = "pprof.address" cfgPProfAddress = "pprof.address"
cfgListenAddress = "listen_address"
cfgListenDomains = "listen_domains" cfgListenDomains = "listen_domains"
// Peers. // Peers.
@ -118,6 +119,8 @@ const ( // Settings.
cmdPProf = "pprof" cmdPProf = "pprof"
cmdMetrics = "metrics" cmdMetrics = "metrics"
cmdListenAddress = "listen_address"
// Configuration of parameters of requests to NeoFS. // Configuration of parameters of requests to NeoFS.
// Number of the object copies to consider PUT to NeoFS successful. // Number of the object copies to consider PUT to NeoFS successful.
cfgSetCopiesNumber = "neofs.set_copies_number" cfgSetCopiesNumber = "neofs.set_copies_number"
@ -167,6 +170,28 @@ func fetchPeers(l *zap.Logger, v *viper.Viper) []pool.NodeParam {
return nodes return nodes
} }
func fetchServers(v *viper.Viper) []ServerInfo {
var servers []ServerInfo
for i := 0; ; i++ {
key := cfgServer + "." + strconv.Itoa(i) + "."
var serverInfo ServerInfo
serverInfo.Address = v.GetString(key + "address")
serverInfo.TLS.Enabled = v.GetBool(key + cfgTLSEnabled)
serverInfo.TLS.KeyFile = v.GetString(key + cfgTLSKeyFile)
serverInfo.TLS.CertFile = v.GetString(key + cfgTLSCertFile)
if serverInfo.Address == "" {
break
}
servers = append(servers, serverInfo)
}
return servers
}
func newSettings() *viper.Viper { func newSettings() *viper.Viper {
v := viper.New() v := viper.New()
@ -198,7 +223,7 @@ func newSettings() *viper.Viper {
flags.Int(cfgMaxClientsCount, defaultMaxClientsCount, "set max-clients count") flags.Int(cfgMaxClientsCount, defaultMaxClientsCount, "set max-clients count")
flags.Duration(cfgMaxClientsDeadline, defaultMaxClientsDeadline, "set max-clients deadline") flags.Duration(cfgMaxClientsDeadline, defaultMaxClientsDeadline, "set max-clients deadline")
flags.String(cfgListenAddress, "0.0.0.0:8080", "set address to listen") flags.String(cmdListenAddress, "0.0.0.0:8080", "set address to listen")
flags.String(cfgTLSCertFile, "", "TLS certificate file to use") flags.String(cfgTLSCertFile, "", "TLS certificate file to use")
flags.String(cfgTLSKeyFile, "", "TLS key file to use") flags.String(cfgTLSKeyFile, "", "TLS key file to use")
@ -221,29 +246,19 @@ func newSettings() *viper.Viper {
v.SetDefault(cfgPProfAddress, "localhost:8085") v.SetDefault(cfgPProfAddress, "localhost:8085")
v.SetDefault(cfgPrometheusAddress, "localhost:8086") v.SetDefault(cfgPrometheusAddress, "localhost:8086")
// Binding flags // Bind flags
if err := v.BindPFlag(cfgPProfEnabled, flags.Lookup(cmdPProf)); err != nil { if err := bindFlags(v, flags); err != nil {
panic(err) panic(fmt.Errorf("bind flags: %w", err))
}
if err := v.BindPFlag(cfgPrometheusEnabled, flags.Lookup(cmdMetrics)); err != nil {
panic(err)
}
if err := v.BindPFlags(flags); err != nil {
panic(err)
}
if err := v.BindPFlag(cfgWalletPath, flags.Lookup(cmdWallet)); err != nil {
panic(err)
}
if err := v.BindPFlag(cfgWalletAddress, flags.Lookup(cmdAddress)); err != nil {
panic(err)
} }
if err := flags.Parse(os.Args); err != nil { if err := flags.Parse(os.Args); err != nil {
panic(err) panic(err)
} }
if v.IsSet(cfgServer+".0."+cfgTLSKeyFile) && v.IsSet(cfgServer+".0."+cfgTLSCertFile) {
v.Set(cfgServer+".0."+cfgTLSEnabled, true)
}
if resolveMethods != nil { if resolveMethods != nil {
v.SetDefault(cfgResolveOrder, *resolveMethods) v.SetDefault(cfgResolveOrder, *resolveMethods)
} }
@ -252,6 +267,7 @@ func newSettings() *viper.Viper {
for i := range *peers { for i := range *peers {
v.SetDefault(cfgPeers+"."+strconv.Itoa(i)+".address", (*peers)[i]) v.SetDefault(cfgPeers+"."+strconv.Itoa(i)+".address", (*peers)[i])
v.SetDefault(cfgPeers+"."+strconv.Itoa(i)+".weight", 1) v.SetDefault(cfgPeers+"."+strconv.Itoa(i)+".weight", 1)
v.SetDefault(cfgPeers+"."+strconv.Itoa(i)+".priority", 1)
} }
} }
@ -306,6 +322,54 @@ func newSettings() *viper.Viper {
return v return v
} }
func bindFlags(v *viper.Viper, flags *pflag.FlagSet) error {
if err := v.BindPFlag(cfgPProfEnabled, flags.Lookup(cmdPProf)); err != nil {
return err
}
if err := v.BindPFlag(cfgPrometheusEnabled, flags.Lookup(cmdMetrics)); err != nil {
return err
}
if err := v.BindPFlag(cmdConfig, flags.Lookup(cmdConfig)); err != nil {
return err
}
if err := v.BindPFlag(cfgWalletPath, flags.Lookup(cmdWallet)); err != nil {
return err
}
if err := v.BindPFlag(cfgWalletAddress, flags.Lookup(cmdAddress)); err != nil {
return err
}
if err := v.BindPFlag(cfgHealthcheckTimeout, flags.Lookup(cfgHealthcheckTimeout)); err != nil {
return err
}
if err := v.BindPFlag(cfgConnectTimeout, flags.Lookup(cfgConnectTimeout)); err != nil {
return err
}
if err := v.BindPFlag(cfgRebalanceInterval, flags.Lookup(cfgRebalanceInterval)); err != nil {
return err
}
if err := v.BindPFlag(cfgMaxClientsCount, flags.Lookup(cfgMaxClientsCount)); err != nil {
return err
}
if err := v.BindPFlag(cfgMaxClientsDeadline, flags.Lookup(cfgMaxClientsDeadline)); err != nil {
return err
}
if err := v.BindPFlag(cfgRPCEndpoint, flags.Lookup(cfgRPCEndpoint)); err != nil {
return err
}
if err := v.BindPFlag(cfgServer+".0.address", flags.Lookup(cmdListenAddress)); err != nil {
return err
}
if err := v.BindPFlag(cfgServer+".0."+cfgTLSKeyFile, flags.Lookup(cfgTLSKeyFile)); err != nil {
return err
}
if err := v.BindPFlag(cfgServer+".0."+cfgTLSCertFile, flags.Lookup(cfgTLSCertFile)); err != nil {
return err
}
return nil
}
func readConfig(v *viper.Viper) error { func readConfig(v *viper.Viper) error {
cfgFileName := v.GetString(cmdConfig) cfgFileName := v.GetString(cmdConfig)
cfgFile, err := os.Open(cfgFileName) cfgFile, err := os.Open(cfgFileName)

124
cmd/s3-gw/server.go Normal file
View file

@ -0,0 +1,124 @@
package main
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"sync"
"go.uber.org/zap"
)
type (
ServerInfo struct {
Address string
TLS ServerTLSInfo
}
ServerTLSInfo struct {
Enabled bool
CertFile string
KeyFile string
}
Server interface {
Address() string
Listener() net.Listener
UpdateCert(certFile, keyFile string) error
}
server struct {
address string
listener net.Listener
tlsProvider *certProvider
}
certProvider struct {
Enabled bool
mu sync.RWMutex
certPath string
keyPath string
cert *tls.Certificate
}
)
func (s *server) Address() string {
return s.address
}
func (s *server) Listener() net.Listener {
return s.listener
}
func (s *server) UpdateCert(certFile, keyFile string) error {
return s.tlsProvider.UpdateCert(certFile, keyFile)
}
func newServer(ctx context.Context, serverInfo ServerInfo, logger *zap.Logger) *server {
var lic net.ListenConfig
ln, err := lic.Listen(ctx, "tcp", serverInfo.Address)
if err != nil {
logger.Fatal("could not prepare listener", zap.String("address", serverInfo.Address), zap.Error(err))
}
tlsProvider := &certProvider{
Enabled: serverInfo.TLS.Enabled,
}
if serverInfo.TLS.Enabled {
if err = tlsProvider.UpdateCert(serverInfo.TLS.CertFile, serverInfo.TLS.KeyFile); err != nil {
logger.Fatal("failed to update cert", zap.Error(err))
}
ln = tls.NewListener(ln, &tls.Config{
GetCertificate: tlsProvider.GetCertificate,
})
}
return &server{
address: serverInfo.Address,
listener: ln,
tlsProvider: tlsProvider,
}
}
func (p *certProvider) GetCertificate(*tls.ClientHelloInfo) (*tls.Certificate, error) {
if !p.Enabled {
return nil, errors.New("cert provider: disabled")
}
p.mu.RLock()
defer p.mu.RUnlock()
return p.cert, nil
}
func (p *certProvider) UpdateCert(certPath, keyPath string) error {
if !p.Enabled {
return fmt.Errorf("tls disabled")
}
cert, err := tls.LoadX509KeyPair(certPath, keyPath)
if err != nil {
return fmt.Errorf("cannot load TLS key pair from certFile '%s' and keyFile '%s': %w", certPath, keyPath, err)
}
p.mu.Lock()
p.certPath = certPath
p.keyPath = keyPath
p.cert = &cert
p.mu.Unlock()
return nil
}
func (p *certProvider) FilePaths() (string, string) {
if !p.Enabled {
return "", ""
}
p.mu.RLock()
defer p.mu.RUnlock()
return p.certPath, p.keyPath
}

View file

@ -24,9 +24,14 @@ S3_GW_PEERS_2_PRIORITY=2
S3_GW_PEERS_2_WEIGHT=0.9 S3_GW_PEERS_2_WEIGHT=0.9
# Address to listen and TLS # Address to listen and TLS
S3_GW_LISTEN_ADDRESS=0.0.0.0:8080 S3_GW_SERVER_0_ADDRESS=0.0.0.0:8080
S3_GW_TLS_CERT_FILE=/path/to/tls/cert S3_GW_SERVER_0_TLS_ENABLED=false
S3_GW_TLS_KEY_FILE=/path/to/tls/key S3_GW_SERVER_0_TLS_CERT_FILE=/path/to/tls/cert
S3_GW_SERVER_0_TLS_KEY_FILE=/path/to/tls/key
S3_GW_SERVER_1_ADDRESS=0.0.0.0:8081
S3_GW_SERVER_1_TLS_ENABLED=true
S3_GW_SERVER_1_TLS_CERT_FILE=/path/to/tls/cert
S3_GW_SERVER_1_TLS_KEY_FILE=/path/to/tls/key
# Domains to be able to use virtual-hosted-style access to bucket. # Domains to be able to use virtual-hosted-style access to bucket.
S3_GW_LISTEN_DOMAINS=s3dev.neofs.devenv S3_GW_LISTEN_DOMAINS=s3dev.neofs.devenv

View file

@ -25,11 +25,17 @@ peers:
priority: 2 priority: 2
weight: 0.9 weight: 0.9
# Address to listen and TLS server:
listen_address: 0.0.0.0:8084 - address: 0.0.0.0:8080
tls: tls:
cert_file: /path/to/cert enabled: false
key_file: /path/to/key cert_file: /path/to/cert
key_file: /path/to/key
- address: 0.0.0.0:8081
tls:
enabled: true
cert_file: /path/to/cert
key_file: /path/to/key
# Domains to be able to use virtual-hosted-style access to bucket. # Domains to be able to use virtual-hosted-style access to bucket.
listen_domains: listen_domains: