package main import ( "context" "crypto/tls" "errors" "fmt" "net" "sync" "go.uber.org/zap" ) type ( ServerInfo struct { Address string TLS ServerTLSInfo } ServerTLSInfo struct { Enabled bool CertFile string KeyFile string } Server interface { Address() string Listener() net.Listener UpdateCert(certFile, keyFile string) error } server struct { address string listener net.Listener tlsProvider *certProvider } certProvider struct { Enabled bool mu sync.RWMutex certPath string keyPath string cert *tls.Certificate } ) func (s *server) Address() string { return s.address } func (s *server) Listener() net.Listener { return s.listener } func (s *server) UpdateCert(certFile, keyFile string) error { return s.tlsProvider.UpdateCert(certFile, keyFile) } func newServer(ctx context.Context, serverInfo ServerInfo, logger *zap.Logger) *server { var lic net.ListenConfig ln, err := lic.Listen(ctx, "tcp", serverInfo.Address) if err != nil { logger.Fatal("could not prepare listener", zap.String("address", serverInfo.Address), zap.Error(err)) } tlsProvider := &certProvider{ Enabled: serverInfo.TLS.Enabled, } if serverInfo.TLS.Enabled { if err = tlsProvider.UpdateCert(serverInfo.TLS.CertFile, serverInfo.TLS.KeyFile); err != nil { logger.Fatal("failed to update cert", zap.Error(err)) } ln = tls.NewListener(ln, &tls.Config{ GetCertificate: tlsProvider.GetCertificate, }) } return &server{ address: serverInfo.Address, listener: ln, tlsProvider: tlsProvider, } } func (p *certProvider) GetCertificate(*tls.ClientHelloInfo) (*tls.Certificate, error) { if !p.Enabled { return nil, errors.New("cert provider: disabled") } p.mu.RLock() defer p.mu.RUnlock() return p.cert, nil } func (p *certProvider) UpdateCert(certPath, keyPath string) error { if !p.Enabled { return fmt.Errorf("tls disabled") } cert, err := tls.LoadX509KeyPair(certPath, keyPath) if err != nil { return fmt.Errorf("cannot load TLS key pair from certFile '%s' and keyFile '%s': %w", certPath, keyPath, err) } p.mu.Lock() p.certPath = certPath p.keyPath = keyPath p.cert = &cert p.mu.Unlock() return nil } func (p *certProvider) FilePaths() (string, string) { if !p.Enabled { return "", "" } p.mu.RLock() defer p.mu.RUnlock() return p.certPath, p.keyPath }