diff --git a/app.go b/app.go index 39dfa9f..9b34267 100644 --- a/app.go +++ b/app.go @@ -3,7 +3,10 @@ package main import ( "context" "crypto/ecdsa" + "crypto/tls" + "errors" "fmt" + "net" "os" "os/signal" "strconv" @@ -45,8 +48,9 @@ type ( } appSettings struct { - Uploader *uploader.Settings - Downloader *downloader.Settings + Uploader *uploader.Settings + Downloader *downloader.Settings + TLSProvider *certProvider } // App is an interface for the main gateway function. @@ -174,8 +178,9 @@ func newApp(ctx context.Context, opt ...Option) App { func (a *app) initAppSettings() { a.settings = &appSettings{ - Uploader: &uploader.Settings{}, - Downloader: &downloader.Settings{}, + Uploader: &uploader.Settings{}, + Downloader: &downloader.Settings{}, + TLSProvider: &certProvider{Enabled: a.cfg.IsSet(cfgTLSCertificate) || a.cfg.IsSet(cfgTLSKey)}, } a.updateSettings() @@ -335,6 +340,43 @@ func (a *app) setHealthStatus() { a.metrics.SetHealth(1) } +type certProvider struct { + Enabled bool + + mu sync.RWMutex + certPath string + keyPath string + cert *tls.Certificate +} + +func (p *certProvider) GetCertificate(*tls.ClientHelloInfo) (*tls.Certificate, error) { + if !p.Enabled { + return nil, errors.New("cert provider: disabled") + } + + p.mu.RLock() + defer p.mu.RUnlock() + return p.cert, nil +} + +func (p *certProvider) UpdateCert(certPath, keyPath string) error { + if !p.Enabled { + return fmt.Errorf("tls disabled") + } + + cert, err := tls.LoadX509KeyPair(certPath, keyPath) + if err != nil { + return fmt.Errorf("cannot load TLS key pair from certFile '%s' and keyFile '%s': %w", certPath, keyPath, err) + } + + p.mu.Lock() + p.certPath = certPath + p.keyPath = keyPath + p.cert = &cert + p.mu.Unlock() + return nil +} + func (a *app) Serve(ctx context.Context) { uploadRoutes := uploader.New(ctx, a.AppParams(), a.settings.Uploader) downloadRoutes := downloader.New(ctx, a.AppParams(), a.settings.Downloader) @@ -344,21 +386,34 @@ func (a *app) Serve(ctx context.Context) { a.startServices() - bind := a.cfg.GetString(cfgListenAddress) - tlsCertPath := a.cfg.GetString(cfgTLSCertificate) - tlsKeyPath := a.cfg.GetString(cfgTLSKey) - go func() { var err error - if tlsCertPath == "" && tlsKeyPath == "" { + defer func() { + if err != nil { + a.log.Fatal("could not start server", zap.Error(err)) + } + }() + + bind := a.cfg.GetString(cfgListenAddress) + + if a.settings.TLSProvider.Enabled { + if err = a.settings.TLSProvider.UpdateCert(a.cfg.GetString(cfgTLSCertificate), a.cfg.GetString(cfgTLSKey)); err != nil { + return + } + + var ln net.Listener + if ln, err = net.Listen("tcp4", bind); err != nil { + return + } + lnTLS := tls.NewListener(ln, &tls.Config{ + GetCertificate: a.settings.TLSProvider.GetCertificate, + }) + + a.log.Info("running web server (TLS-enabled)", zap.String("address", bind)) + err = a.webServer.Serve(lnTLS) + } else { a.log.Info("running web server", zap.String("address", bind)) err = a.webServer.ListenAndServe(bind) - } else { - a.log.Info("running web server (TLS-enabled)", zap.String("address", bind)) - err = a.webServer.ListenAndServeTLS(bind, tlsCertPath, tlsKeyPath) - } - if err != nil { - a.log.Fatal("could not start server", zap.Error(err)) } }() @@ -417,6 +472,10 @@ func (a *app) configReload() { func (a *app) updateSettings() { a.settings.Uploader.SetDefaultTimestamp(a.cfg.GetBool(cfgUploaderHeaderEnableDefaultTimestamp)) a.settings.Downloader.SetZipCompression(a.cfg.GetBool(cfgZipCompression)) + + if err := a.settings.TLSProvider.UpdateCert(a.cfg.GetString(cfgTLSCertificate), a.cfg.GetString(cfgTLSKey)); err != nil { + a.log.Warn("failed to reload TLS certs", zap.Error(err)) + } } func (a *app) startServices() {