[#200] Reload certs on SIGHUP

Signed-off-by: Denis Kirillov <denis@nspcc.ru>
This commit is contained in:
Denis Kirillov 2022-09-09 19:00:04 +03:00 committed by Kirillov Denis
parent 82eba97505
commit ad2c7ca671

81
app.go
View file

@ -3,7 +3,10 @@ package main
import ( import (
"context" "context"
"crypto/ecdsa" "crypto/ecdsa"
"crypto/tls"
"errors"
"fmt" "fmt"
"net"
"os" "os"
"os/signal" "os/signal"
"strconv" "strconv"
@ -47,6 +50,7 @@ type (
appSettings struct { appSettings struct {
Uploader *uploader.Settings Uploader *uploader.Settings
Downloader *downloader.Settings Downloader *downloader.Settings
TLSProvider *certProvider
} }
// App is an interface for the main gateway function. // App is an interface for the main gateway function.
@ -176,6 +180,7 @@ func (a *app) initAppSettings() {
a.settings = &appSettings{ a.settings = &appSettings{
Uploader: &uploader.Settings{}, Uploader: &uploader.Settings{},
Downloader: &downloader.Settings{}, Downloader: &downloader.Settings{},
TLSProvider: &certProvider{Enabled: a.cfg.IsSet(cfgTLSCertificate) || a.cfg.IsSet(cfgTLSKey)},
} }
a.updateSettings() a.updateSettings()
@ -335,6 +340,43 @@ func (a *app) setHealthStatus() {
a.metrics.SetHealth(1) a.metrics.SetHealth(1)
} }
type certProvider struct {
Enabled bool
mu sync.RWMutex
certPath string
keyPath string
cert *tls.Certificate
}
func (p *certProvider) GetCertificate(*tls.ClientHelloInfo) (*tls.Certificate, error) {
if !p.Enabled {
return nil, errors.New("cert provider: disabled")
}
p.mu.RLock()
defer p.mu.RUnlock()
return p.cert, nil
}
func (p *certProvider) UpdateCert(certPath, keyPath string) error {
if !p.Enabled {
return fmt.Errorf("tls disabled")
}
cert, err := tls.LoadX509KeyPair(certPath, keyPath)
if err != nil {
return fmt.Errorf("cannot load TLS key pair from certFile '%s' and keyFile '%s': %w", certPath, keyPath, err)
}
p.mu.Lock()
p.certPath = certPath
p.keyPath = keyPath
p.cert = &cert
p.mu.Unlock()
return nil
}
func (a *app) Serve(ctx context.Context) { func (a *app) Serve(ctx context.Context) {
uploadRoutes := uploader.New(ctx, a.AppParams(), a.settings.Uploader) uploadRoutes := uploader.New(ctx, a.AppParams(), a.settings.Uploader)
downloadRoutes := downloader.New(ctx, a.AppParams(), a.settings.Downloader) downloadRoutes := downloader.New(ctx, a.AppParams(), a.settings.Downloader)
@ -344,24 +386,37 @@ func (a *app) Serve(ctx context.Context) {
a.startServices() a.startServices()
bind := a.cfg.GetString(cfgListenAddress)
tlsCertPath := a.cfg.GetString(cfgTLSCertificate)
tlsKeyPath := a.cfg.GetString(cfgTLSKey)
go func() { go func() {
var err error var err error
if tlsCertPath == "" && tlsKeyPath == "" { defer func() {
a.log.Info("running web server", zap.String("address", bind))
err = a.webServer.ListenAndServe(bind)
} else {
a.log.Info("running web server (TLS-enabled)", zap.String("address", bind))
err = a.webServer.ListenAndServeTLS(bind, tlsCertPath, tlsKeyPath)
}
if err != nil { if err != nil {
a.log.Fatal("could not start server", zap.Error(err)) a.log.Fatal("could not start server", zap.Error(err))
} }
}() }()
bind := a.cfg.GetString(cfgListenAddress)
if a.settings.TLSProvider.Enabled {
if err = a.settings.TLSProvider.UpdateCert(a.cfg.GetString(cfgTLSCertificate), a.cfg.GetString(cfgTLSKey)); err != nil {
return
}
var ln net.Listener
if ln, err = net.Listen("tcp4", bind); err != nil {
return
}
lnTLS := tls.NewListener(ln, &tls.Config{
GetCertificate: a.settings.TLSProvider.GetCertificate,
})
a.log.Info("running web server (TLS-enabled)", zap.String("address", bind))
err = a.webServer.Serve(lnTLS)
} else {
a.log.Info("running web server", zap.String("address", bind))
err = a.webServer.ListenAndServe(bind)
}
}()
sigs := make(chan os.Signal, 1) sigs := make(chan os.Signal, 1)
signal.Notify(sigs, syscall.SIGHUP) signal.Notify(sigs, syscall.SIGHUP)
@ -417,6 +472,10 @@ func (a *app) configReload() {
func (a *app) updateSettings() { func (a *app) updateSettings() {
a.settings.Uploader.SetDefaultTimestamp(a.cfg.GetBool(cfgUploaderHeaderEnableDefaultTimestamp)) a.settings.Uploader.SetDefaultTimestamp(a.cfg.GetBool(cfgUploaderHeaderEnableDefaultTimestamp))
a.settings.Downloader.SetZipCompression(a.cfg.GetBool(cfgZipCompression)) a.settings.Downloader.SetZipCompression(a.cfg.GetBool(cfgZipCompression))
if err := a.settings.TLSProvider.UpdateCert(a.cfg.GetString(cfgTLSCertificate), a.cfg.GetString(cfgTLSKey)); err != nil {
a.log.Warn("failed to reload TLS certs", zap.Error(err))
}
} }
func (a *app) startServices() { func (a *app) startServices() {