diff --git a/app.go b/app.go index ee979c4..6a86726 100644 --- a/app.go +++ b/app.go @@ -4,7 +4,11 @@ import ( "context" "crypto/ecdsa" "fmt" + "os" + "os/signal" "strconv" + "sync" + "syscall" "github.com/fasthttp/router" "github.com/nspcc-dev/neo-go/cli/flags" @@ -28,13 +32,15 @@ import ( type ( app struct { log *zap.Logger + logLevel zap.AtomicLevel pool *pool.Pool owner *user.ID cfg *viper.Viper webServer *fasthttp.Server webDone chan struct{} resolver *resolver.ContainerResolver - metrics GateMetricsProvider + metrics *gateMetrics + services []*metrics.Service } // App is an interface for the main gateway function. @@ -46,18 +52,26 @@ type ( // Option is an application option. Option func(a *app) + gateMetrics struct { + logger *zap.Logger + provider GateMetricsProvider + mu sync.RWMutex + enabled bool + } + GateMetricsProvider interface { SetHealth(int32) } ) // WithLogger returns Option to set a specific logger. -func WithLogger(l *zap.Logger) Option { +func WithLogger(l *zap.Logger, lvl zap.AtomicLevel) Option { return func(a *app) { if l == nil { return } a.log = l + a.logLevel = lvl } } @@ -164,13 +178,47 @@ func newApp(ctx context.Context, opt ...Option) App { a.log.Info("container resolver is disabled") } - if a.cfg.GetBool(cfgPrometheusEnabled) { - a.metrics = metrics.NewGateMetrics(a.pool) - } + a.initMetrics() return a } +func (a *app) initMetrics() { + gateMetricsProvider := metrics.NewGateMetrics(a.pool) + a.metrics = newGateMetrics(a.log, gateMetricsProvider, a.cfg.GetBool(cfgPrometheusEnabled)) +} + +func newGateMetrics(logger *zap.Logger, provider GateMetricsProvider, enabled bool) *gateMetrics { + if !enabled { + logger.Warn("metrics are disabled") + } + return &gateMetrics{ + logger: logger, + provider: provider, + } +} + +func (m *gateMetrics) SetEnabled(enabled bool) { + if !enabled { + m.logger.Warn("metrics are disabled") + } + + m.mu.Lock() + m.enabled = enabled + m.mu.Unlock() +} + +func (m *gateMetrics) SetHealth(status int32) { + m.mu.RLock() + if !m.enabled { + m.mu.RUnlock() + return + } + m.mu.RUnlock() + + m.provider.SetHealth(status) +} + func remove(list []string, element string) []string { for i, item := range list { if item == element { @@ -242,19 +290,110 @@ func getKeyFromWallet(w *wallet.Wallet, addrStr string, password *string) (*ecds func (a *app) Wait() { a.log.Info("starting application", zap.String("app_name", "neofs-http-gw"), zap.String("version", Version)) - if a.metrics != nil { - a.metrics.SetHealth(1) - } + + a.setHealthStatus() <-a.webDone // wait for web-server to be stopped } +func (a *app) setHealthStatus() { + a.metrics.SetHealth(1) +} + func (a *app) Serve(ctx context.Context) { edts := a.cfg.GetBool(cfgUploaderHeaderEnableDefaultTimestamp) uploadRoutes := uploader.New(ctx, a.AppParams(), edts) downloadSettings := downloader.Settings{ZipCompression: a.cfg.GetBool(cfgZipCompression)} downloadRoutes := downloader.New(ctx, a.AppParams(), downloadSettings) + // Configure router. + a.configureRouter(uploadRoutes, downloadRoutes) + + a.startServices() + + bind := a.cfg.GetString(cfgListenAddress) + tlsCertPath := a.cfg.GetString(cfgTLSCertificate) + tlsKeyPath := a.cfg.GetString(cfgTLSKey) + + go func() { + var err error + if tlsCertPath == "" && tlsKeyPath == "" { + a.log.Info("running web server", zap.String("address", bind)) + err = a.webServer.ListenAndServe(bind) + } else { + a.log.Info("running web server (TLS-enabled)", zap.String("address", bind)) + err = a.webServer.ListenAndServeTLS(bind, tlsCertPath, tlsKeyPath) + } + if err != nil { + a.log.Fatal("could not start server", zap.Error(err)) + } + }() + + sigs := make(chan os.Signal, 1) + signal.Notify(sigs, syscall.SIGHUP) + +LOOP: + for { + select { + case <-ctx.Done(): + break LOOP + case <-sigs: + a.configReload() + } + } + + a.log.Info("shutting down web server", zap.Error(a.webServer.Shutdown())) + + a.stopServices() + + close(a.webDone) +} + +func (a *app) configReload() { + a.log.Info("SIGHUP config reload") + if !a.cfg.IsSet(cmdConfig) { + a.log.Warn("failed to reload config because it's missed") + return + } + if err := readConfig(a.cfg); err != nil { + a.log.Warn("failed to reload config", zap.Error(err)) + return + } + if lvl, err := getLogLevel(a.cfg); err != nil { + a.log.Warn("log level won't be updated", zap.Error(err)) + } else { + a.logLevel.SetLevel(lvl) + } + + a.stopServices() + a.startServices() + + a.metrics.SetEnabled(a.cfg.GetBool(cfgPrometheusEnabled)) + a.setHealthStatus() +} + +func (a *app) startServices() { + pprofConfig := metrics.Config{Enabled: a.cfg.GetBool(cfgPprofEnabled), Address: a.cfg.GetString(cfgPprofAddress)} + pprofService := metrics.NewPprofService(a.log, pprofConfig) + a.services = append(a.services, pprofService) + go pprofService.Start() + + prometheusConfig := metrics.Config{Enabled: a.cfg.GetBool(cfgPrometheusEnabled), Address: a.cfg.GetString(cfgPrometheusAddress)} + prometheusService := metrics.NewPrometheusService(a.log, prometheusConfig) + a.services = append(a.services, prometheusService) + go prometheusService.Start() +} + +func (a *app) stopServices() { + ctx, cancel := context.WithTimeout(context.Background(), defaultShutdownTimeout) + defer cancel() + + for _, svc := range a.services { + svc.ShutDown(ctx) + } +} + +func (a *app) configureRouter(uploadRoutes *uploader.Uploader, downloadRoutes *downloader.Downloader) { r := router.New() r.RedirectTrailingSlash = true r.NotFound = func(r *fasthttp.RequestCtx) { @@ -274,55 +413,18 @@ func (a *app) Serve(ctx context.Context) { r.GET("/zip/{cid}/{prefix:*}", a.logger(downloadRoutes.DownloadZipped)) a.log.Info("added path /zip/{cid}/{prefix}") - pprofConfig := metrics.Config{Enabled: a.cfg.GetBool(cfgPprofEnabled), Address: a.cfg.GetString(cfgPprofAddress)} - pprof := metrics.NewPprofService(a.log, pprofConfig) - prometheusConfig := metrics.Config{Enabled: a.cfg.GetBool(cfgPrometheusEnabled), Address: a.cfg.GetString(cfgPrometheusAddress)} - prometheus := metrics.NewPrometheusService(a.log, prometheusConfig) - - bind := a.cfg.GetString(cfgListenAddress) - tlsCertPath := a.cfg.GetString(cfgTLSCertificate) - tlsKeyPath := a.cfg.GetString(cfgTLSKey) - a.webServer.Handler = r.Handler - - go pprof.Start() - go prometheus.Start() - - go func() { - var err error - if tlsCertPath == "" && tlsKeyPath == "" { - a.log.Info("running web server", zap.String("address", bind)) - err = a.webServer.ListenAndServe(bind) - } else { - a.log.Info("running web server (TLS-enabled)", zap.String("address", bind)) - err = a.webServer.ListenAndServeTLS(bind, tlsCertPath, tlsKeyPath) - } - if err != nil { - a.log.Fatal("could not start server", zap.Error(err)) - } - }() - - <-ctx.Done() - a.log.Info("shutting down web server", zap.Error(a.webServer.Shutdown())) - - ctx, cancel := context.WithTimeout(context.Background(), defaultShutdownTimeout) - defer cancel() - - pprof.ShutDown(ctx) - prometheus.ShutDown(ctx) - - close(a.webDone) } func (a *app) logger(h fasthttp.RequestHandler) fasthttp.RequestHandler { - return fasthttp.RequestHandler(func(ctx *fasthttp.RequestCtx) { + return func(ctx *fasthttp.RequestCtx) { a.log.Info("request", zap.String("remote", ctx.RemoteAddr().String()), zap.ByteString("method", ctx.Method()), zap.ByteString("path", ctx.Path()), zap.ByteString("query", ctx.QueryArgs().QueryString()), zap.Uint64("id", ctx.ID())) h(ctx) - }) + } } func (a *app) AppParams() *utils.AppParams { diff --git a/integration_test.go b/integration_test.go index 9f55486..2f4cd12 100644 --- a/integration_test.go +++ b/integration_test.go @@ -82,8 +82,8 @@ func runServer() context.CancelFunc { cancelCtx, cancel := context.WithCancel(context.Background()) v := getDefaultConfig() - l := newLogger(v) - application := newApp(cancelCtx, WithConfig(v), WithLogger(l)) + l, lvl := newLogger(v) + application := newApp(cancelCtx, WithConfig(v), WithLogger(l, lvl)) go application.Serve(cancelCtx) return cancel diff --git a/main.go b/main.go index 5b4d188..f997955 100644 --- a/main.go +++ b/main.go @@ -2,65 +2,16 @@ package main import ( "context" - "fmt" "os/signal" "syscall" - - "github.com/spf13/viper" - "go.uber.org/zap" - "go.uber.org/zap/zapcore" ) func main() { - var ( - v = settings() - l = newLogger(v) - ) - globalContext, _ := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP) - app := newApp(globalContext, WithLogger(l), WithConfig(v)) - go app.Serve(globalContext) - app.Wait() -} - -// newLogger constructs a zap.Logger instance for current application. -// Panics on failure. -// -// Logger is built from zap's production logging configuration with: -// - parameterized level (debug by default) -// - console encoding -// - ISO8601 time encoding -// -// Logger records a stack trace for all messages at or above fatal level. -// -// See also zapcore.Level, zap.NewProductionConfig, zap.AddStacktrace. -func newLogger(v *viper.Viper) *zap.Logger { - var lvl zapcore.Level - lvlStr := v.GetString(cfgLoggerLevel) - err := lvl.UnmarshalText([]byte(lvlStr)) - if err != nil { - panic(fmt.Sprintf("incorrect logger level configuration %s (%v), "+ - "value should be one of %v", lvlStr, err, [...]zapcore.Level{ - zapcore.DebugLevel, - zapcore.InfoLevel, - zapcore.WarnLevel, - zapcore.ErrorLevel, - zapcore.DPanicLevel, - zapcore.PanicLevel, - zapcore.FatalLevel, - })) - } - - c := zap.NewProductionConfig() - c.Level = zap.NewAtomicLevelAt(lvl) - c.Encoding = "console" - c.EncoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder - - l, err := c.Build( - zap.AddStacktrace(zap.NewAtomicLevelAt(zap.FatalLevel)), - ) - if err != nil { - panic(fmt.Sprintf("build zap logger instance: %v", err)) - } - - return l + globalContext, _ := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) + v := settings() + logger, atomicLevel := newLogger(v) + + application := newApp(globalContext, WithLogger(logger, atomicLevel), WithConfig(v)) + go application.Serve(globalContext) + application.Wait() } diff --git a/settings.go b/settings.go index c9e6648..e687141 100644 --- a/settings.go +++ b/settings.go @@ -13,6 +13,8 @@ import ( "github.com/spf13/pflag" "github.com/spf13/viper" "github.com/valyala/fasthttp" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" ) const ( @@ -108,7 +110,7 @@ func settings() *viper.Viper { flags.StringP(cmdWallet, "w", "", `path to the wallet`) flags.String(cmdAddress, "", `address of wallet account`) - config := flags.String(cmdConfig, "", "config path") + flags.String(cmdConfig, "", "config path") flags.Duration(cfgConTimeout, defaultConnectTimeout, "gRPC connect timeout") flags.Duration(cfgReqTimeout, defaultRequestTimeout, "gRPC request timeout") flags.Duration(cfgRebalance, defaultRebalanceTimer, "gRPC connection rebalance timer") @@ -213,9 +215,7 @@ func settings() *viper.Viper { } if v.IsSet(cmdConfig) { - if cfgFile, err := os.Open(*config); err != nil { - panic(err) - } else if err := v.ReadConfig(cfgFile); err != nil { + if err := readConfig(v); err != nil { panic(err) } } @@ -230,3 +230,67 @@ func settings() *viper.Viper { return v } + +func readConfig(v *viper.Viper) error { + cfgFileName := v.GetString(cmdConfig) + cfgFile, err := os.Open(cfgFileName) + if err != nil { + return err + } + if err = v.ReadConfig(cfgFile); err != nil { + return err + } + + return cfgFile.Close() +} + +// newLogger constructs a zap.Logger instance for current application. +// Panics on failure. +// +// Logger is built from zap's production logging configuration with: +// - parameterized level (debug by default) +// - console encoding +// - ISO8601 time encoding +// +// Logger records a stack trace for all messages at or above fatal level. +// +// See also zapcore.Level, zap.NewProductionConfig, zap.AddStacktrace. +func newLogger(v *viper.Viper) (*zap.Logger, zap.AtomicLevel) { + lvl, err := getLogLevel(v) + if err != nil { + panic(err) + } + + c := zap.NewProductionConfig() + c.Level = zap.NewAtomicLevelAt(lvl) + c.Encoding = "console" + c.EncoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder + + l, err := c.Build( + zap.AddStacktrace(zap.NewAtomicLevelAt(zap.FatalLevel)), + ) + if err != nil { + panic(fmt.Sprintf("build zap logger instance: %v", err)) + } + + return l, c.Level +} + +func getLogLevel(v *viper.Viper) (zapcore.Level, error) { + var lvl zapcore.Level + lvlStr := v.GetString(cfgLoggerLevel) + err := lvl.UnmarshalText([]byte(lvlStr)) + if err != nil { + return lvl, fmt.Errorf("incorrect logger level configuration %s (%v), "+ + "value should be one of %v", lvlStr, err, [...]zapcore.Level{ + zapcore.DebugLevel, + zapcore.InfoLevel, + zapcore.WarnLevel, + zapcore.ErrorLevel, + zapcore.DPanicLevel, + zapcore.PanicLevel, + zapcore.FatalLevel, + }) + } + return lvl, nil +}