124 lines
2.3 KiB
Go
124 lines
2.3 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"sync"
|
|
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
type (
|
|
ServerInfo struct {
|
|
Address string
|
|
TLS ServerTLSInfo
|
|
}
|
|
|
|
ServerTLSInfo struct {
|
|
Enabled bool
|
|
CertFile string
|
|
KeyFile string
|
|
}
|
|
|
|
Server interface {
|
|
Address() string
|
|
Listener() net.Listener
|
|
UpdateCert(certFile, keyFile string) error
|
|
}
|
|
|
|
server struct {
|
|
address string
|
|
listener net.Listener
|
|
tlsProvider *certProvider
|
|
}
|
|
|
|
certProvider struct {
|
|
Enabled bool
|
|
|
|
mu sync.RWMutex
|
|
certPath string
|
|
keyPath string
|
|
cert *tls.Certificate
|
|
}
|
|
)
|
|
|
|
func (s *server) Address() string {
|
|
return s.address
|
|
}
|
|
|
|
func (s *server) Listener() net.Listener {
|
|
return s.listener
|
|
}
|
|
|
|
func (s *server) UpdateCert(certFile, keyFile string) error {
|
|
return s.tlsProvider.UpdateCert(certFile, keyFile)
|
|
}
|
|
|
|
func newServer(ctx context.Context, serverInfo ServerInfo, logger *zap.Logger) *server {
|
|
var lic net.ListenConfig
|
|
ln, err := lic.Listen(ctx, "tcp", serverInfo.Address)
|
|
if err != nil {
|
|
logger.Fatal("could not prepare listener", zap.String("address", serverInfo.Address), zap.Error(err))
|
|
}
|
|
|
|
tlsProvider := &certProvider{
|
|
Enabled: serverInfo.TLS.Enabled,
|
|
}
|
|
|
|
if serverInfo.TLS.Enabled {
|
|
if err = tlsProvider.UpdateCert(serverInfo.TLS.CertFile, serverInfo.TLS.KeyFile); err != nil {
|
|
logger.Fatal("failed to update cert", zap.Error(err))
|
|
}
|
|
|
|
ln = tls.NewListener(ln, &tls.Config{
|
|
GetCertificate: tlsProvider.GetCertificate,
|
|
})
|
|
}
|
|
|
|
return &server{
|
|
address: serverInfo.Address,
|
|
listener: ln,
|
|
tlsProvider: tlsProvider,
|
|
}
|
|
}
|
|
|
|
func (p *certProvider) GetCertificate(*tls.ClientHelloInfo) (*tls.Certificate, error) {
|
|
if !p.Enabled {
|
|
return nil, errors.New("cert provider: disabled")
|
|
}
|
|
|
|
p.mu.RLock()
|
|
defer p.mu.RUnlock()
|
|
return p.cert, nil
|
|
}
|
|
|
|
func (p *certProvider) UpdateCert(certPath, keyPath string) error {
|
|
if !p.Enabled {
|
|
return fmt.Errorf("tls disabled")
|
|
}
|
|
|
|
cert, err := tls.LoadX509KeyPair(certPath, keyPath)
|
|
if err != nil {
|
|
return fmt.Errorf("cannot load TLS key pair from certFile '%s' and keyFile '%s': %w", certPath, keyPath, err)
|
|
}
|
|
|
|
p.mu.Lock()
|
|
p.certPath = certPath
|
|
p.keyPath = keyPath
|
|
p.cert = &cert
|
|
p.mu.Unlock()
|
|
return nil
|
|
}
|
|
|
|
func (p *certProvider) FilePaths() (string, string) {
|
|
if !p.Enabled {
|
|
return "", ""
|
|
}
|
|
|
|
p.mu.RLock()
|
|
defer p.mu.RUnlock()
|
|
return p.certPath, p.keyPath
|
|
}
|