This repository has been archived on 2024-09-11. You can view files and clone it, but cannot push or open issues or pull requests.
frostfs-rest-gw/cmd/neofs-rest-gw/config.go
Denis Kirillov 6f789f149e [#71] Validate config for unknown keys
Signed-off-by: Denis Kirillov <denis@nspcc.ru>
2022-11-10 11:30:06 +03:00

502 lines
14 KiB
Go

package main
import (
"context"
"fmt"
"os"
"runtime"
"sort"
"strconv"
"strings"
"time"
"github.com/nspcc-dev/neo-go/cli/flags"
"github.com/nspcc-dev/neo-go/cli/input"
"github.com/nspcc-dev/neo-go/pkg/crypto/keys"
"github.com/nspcc-dev/neo-go/pkg/util"
"github.com/nspcc-dev/neo-go/pkg/wallet"
"github.com/nspcc-dev/neofs-rest-gw/gen/restapi"
"github.com/nspcc-dev/neofs-rest-gw/handlers"
"github.com/nspcc-dev/neofs-rest-gw/metrics"
"github.com/nspcc-dev/neofs-sdk-go/pool"
"github.com/spf13/pflag"
"github.com/spf13/viper"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
const (
defaultConnectTimeout = 10 * time.Second
defaultHealthcheckTimeout = 15 * time.Second
defaultRebalanceTimer = 60 * time.Second
defaultShutdownTimeout = 15 * time.Second
defaultPoolErrorThreshold uint32 = 100
// Pool config.
cmdNodeDialTimeout = "node-dial-timeout"
cfgNodeDialTimeout = "pool." + cmdNodeDialTimeout
cmdHealthcheckTimeout = "healthcheck-timeout"
cfgHealthcheckTimeout = "pool." + cmdHealthcheckTimeout
cmdRebalance = "rebalance-timer"
cfgRebalance = "pool." + cmdRebalance
cfgPoolErrorThreshold = "pool.error-threshold"
cmdPeers = "peers"
cfgPeers = "pool." + cmdPeers
cfgPeerAddress = "address"
cfgPeerPriority = "priority"
cfgPeerWeight = "weight"
// Metrics / Profiler.
cfgPrometheusEnabled = "prometheus.enabled"
cfgPrometheusAddress = "prometheus.address"
cfgPprofEnabled = "pprof.enabled"
cfgPprofAddress = "pprof.address"
// Logger.
cfgLoggerLevel = "logger.level"
// Wallet.
cfgWalletPath = "wallet.path"
cfgWalletAddress = "wallet.address"
cfgWalletPassphrase = "wallet.passphrase"
// Config section for autogenerated flags.
cfgServerSection = "server."
// Command line args.
cmdHelp = "help"
cmdVersion = "version"
cmdPprof = "pprof"
cmdMetrics = "metrics"
cmdWallet = "wallet"
cmdAddress = "address"
cmdConfig = "config"
)
var ignore = map[string]struct{}{
cmdPeers: {},
cmdHelp: {},
cmdVersion: {},
}
// Prefix is a prefix used for environment variables containing gateway
// configuration.
const Prefix = "REST_GW"
var (
// Version is gateway version.
Version = "dev"
)
func config() *viper.Viper {
v := viper.New()
v.AutomaticEnv()
v.SetEnvPrefix(Prefix)
v.AllowEmptyEnv(true)
v.SetConfigType("yaml")
v.SetEnvKeyReplacer(strings.NewReplacer(".", "_", "-", "_"))
// flags setup:
flagSet := pflag.NewFlagSet("commandline", pflag.ExitOnError)
flagSet.SetOutput(os.Stdout)
flagSet.SortFlags = false
flagSet.Bool(cmdPprof, false, "enable pprof")
flagSet.Bool(cmdMetrics, false, "enable prometheus")
help := flagSet.BoolP(cmdHelp, "h", false, "show help")
version := flagSet.BoolP(cmdVersion, "v", false, "show version")
flagSet.StringP(cmdWallet, "w", "", `path to the wallet`)
flagSet.String(cmdAddress, "", `address of wallet account`)
configFlag := flagSet.String(cmdConfig, "", "config path")
flagSet.Duration(cmdNodeDialTimeout, defaultConnectTimeout, "gRPC node connect timeout")
flagSet.Duration(cmdHealthcheckTimeout, defaultHealthcheckTimeout, "gRPC healthcheck timeout")
flagSet.Duration(cmdRebalance, defaultRebalanceTimer, "gRPC connection rebalance timer")
peers := flagSet.StringArrayP(cmdPeers, "p", nil, "NeoFS nodes")
// init server flags
restapi.BindDefaultFlags(flagSet)
// set defaults:
// pool
v.SetDefault(cfgPoolErrorThreshold, defaultPoolErrorThreshold)
// metrics
v.SetDefault(cfgPprofAddress, "localhost:8091")
v.SetDefault(cfgPrometheusAddress, "localhost:8092")
// logger:
v.SetDefault(cfgLoggerLevel, "debug")
// Bind flags
if err := bindFlags(v, flagSet); err != nil {
panic(fmt.Errorf("bind flags: %w", err))
}
if err := flagSet.Parse(os.Args); err != nil {
panic(err)
}
switch {
case help != nil && *help:
fmt.Printf("NeoFS REST Gateway %s\n", Version)
flagSet.PrintDefaults()
fmt.Println()
fmt.Println("Default environments:")
fmt.Println()
cmdKeys := v.AllKeys()
sort.Strings(cmdKeys)
for i := range cmdKeys {
if _, ok := ignore[cmdKeys[i]]; ok {
continue
}
k := strings.Replace(cmdKeys[i], ".", "_", -1)
fmt.Printf("%s_%s = %v\n", Prefix, strings.ToUpper(k), v.Get(cmdKeys[i]))
}
os.Exit(0)
case version != nil && *version:
fmt.Printf("NeoFS REST Gateway\nVersion: %s\nGoVersion: %s\n", Version, runtime.Version())
os.Exit(0)
case configFlag != nil && *configFlag != "":
if cfgFile, err := os.Open(*configFlag); err != nil {
panic(err)
} else if err = v.ReadConfig(cfgFile); err != nil {
panic(err)
}
}
if peers != nil && len(*peers) > 0 {
for i := range *peers {
v.SetDefault(cfgPeers+"."+strconv.Itoa(i)+"."+cfgPeerAddress, (*peers)[i])
v.SetDefault(cfgPeers+"."+strconv.Itoa(i)+"."+cfgPeerWeight, 1)
v.SetDefault(cfgPeers+"."+strconv.Itoa(i)+"."+cfgPeerWeight, 1)
}
}
return v
}
func bindFlags(v *viper.Viper, flagSet *pflag.FlagSet) error {
if err := v.BindPFlag(cfgPprofEnabled, flagSet.Lookup(cmdPprof)); err != nil {
return err
}
if err := v.BindPFlag(cfgPrometheusEnabled, flagSet.Lookup(cmdMetrics)); err != nil {
return err
}
if err := v.BindPFlag(cfgNodeDialTimeout, flagSet.Lookup(cmdNodeDialTimeout)); err != nil {
return err
}
if err := v.BindPFlag(cfgHealthcheckTimeout, flagSet.Lookup(cmdHealthcheckTimeout)); err != nil {
return err
}
if err := v.BindPFlag(cfgRebalance, flagSet.Lookup(cmdRebalance)); err != nil {
return err
}
if err := v.BindPFlag(cfgWalletPath, flagSet.Lookup(cmdWallet)); err != nil {
return err
}
if err := v.BindPFlag(cfgWalletAddress, flagSet.Lookup(cmdAddress)); err != nil {
return err
}
if err := restapi.BindFlagsToConfig(v, flagSet, cfgServerSection); err != nil {
return err
}
return nil
}
var knownConfigParams = map[string]struct{}{
cfgWalletAddress: {},
cfgWalletPath: {},
cfgWalletPassphrase: {},
cfgRebalance: {},
cfgHealthcheckTimeout: {},
cfgNodeDialTimeout: {},
cfgPoolErrorThreshold: {},
cfgLoggerLevel: {},
cfgPrometheusEnabled: {},
cfgPrometheusAddress: {},
cfgPprofEnabled: {},
cfgPprofAddress: {},
cfgServerSection + restapi.FlagScheme: {},
cfgServerSection + restapi.FlagCleanupTimeout: {},
cfgServerSection + restapi.FlagGracefulTimeout: {},
cfgServerSection + restapi.FlagMaxHeaderSize: {},
cfgServerSection + restapi.FlagListenAddress: {},
cfgServerSection + restapi.FlagListenLimit: {},
cfgServerSection + restapi.FlagKeepAlive: {},
cfgServerSection + restapi.FlagReadTimeout: {},
cfgServerSection + restapi.FlagWriteTimeout: {},
cfgServerSection + restapi.FlagTLSListenAddress: {},
cfgServerSection + restapi.FlagTLSCertificate: {},
cfgServerSection + restapi.FlagTLSKey: {},
cfgServerSection + restapi.FlagTLSCa: {},
cfgServerSection + restapi.FlagTLSListenLimit: {},
cfgServerSection + restapi.FlagTLSKeepAlive: {},
cfgServerSection + restapi.FlagTLSReadTimeout: {},
cfgServerSection + restapi.FlagTLSWriteTimeout: {},
}
func validateConfig(cfg *viper.Viper, logger *zap.Logger) {
peerNumsMap := make(map[int]struct{})
for _, providedKey := range cfg.AllKeys() {
if !strings.HasPrefix(providedKey, cfgPeers) {
if _, ok := knownConfigParams[providedKey]; !ok {
logger.Warn("unknown config parameter", zap.String("key", providedKey))
}
continue
}
num, ok := isValidPeerKey(providedKey)
if !ok {
logger.Warn("unknown config parameter", zap.String("key", providedKey))
} else {
peerNumsMap[num] = struct{}{}
}
}
peerNums := make([]int, 0, len(peerNumsMap))
for num := range peerNumsMap {
peerNums = append(peerNums, num)
}
sort.Ints(peerNums)
for i, num := range peerNums {
if i != num {
logger.Warn("invalid config parameter, peer indexes must be consecutive starting from 0", zap.String("key", cfgPeers+"."+strconv.Itoa(num)))
}
}
}
func isValidPeerKey(key string) (int, bool) {
trimmed := strings.TrimPrefix(key, cfgPeers)
split := strings.Split(trimmed, ".")
if len(split) != 3 {
return 0, false
}
if split[2] != cfgPeerAddress && split[2] != cfgPeerPriority && split[2] != cfgPeerWeight {
return 0, false
}
num, err := strconv.Atoi(split[1])
if err != nil || num < 0 {
return 0, false
}
return num, true
}
func getNeoFSKey(logger *zap.Logger, cfg *viper.Viper) (*keys.PrivateKey, error) {
walletPath := cfg.GetString(cmdWallet)
if len(walletPath) == 0 {
walletPath = cfg.GetString(cfgWalletPath)
}
if len(walletPath) == 0 {
logger.Info("no wallet path specified, creating ephemeral key automatically for this run")
return keys.NewPrivateKey()
}
w, err := wallet.NewWalletFromFile(walletPath)
if err != nil {
return nil, err
}
var password *string
if cfg.IsSet(cfgWalletPassphrase) {
pwd := cfg.GetString(cfgWalletPassphrase)
password = &pwd
}
address := cfg.GetString(cmdAddress)
if len(address) == 0 {
address = cfg.GetString(cfgWalletAddress)
}
return getKeyFromWallet(w, address, password)
}
func getKeyFromWallet(w *wallet.Wallet, addrStr string, password *string) (*keys.PrivateKey, error) {
var addr util.Uint160
var err error
if addrStr == "" {
addr = w.GetChangeAddress()
} else {
addr, err = flags.ParseAddress(addrStr)
if err != nil {
return nil, fmt.Errorf("invalid address")
}
}
acc := w.GetAccount(addr)
if acc == nil {
return nil, fmt.Errorf("couldn't find wallet account for %s", addrStr)
}
if password == nil {
pwd, err := input.ReadPassword("Enter password > ")
if err != nil {
return nil, fmt.Errorf("couldn't read password")
}
password = &pwd
}
if err := acc.Decrypt(*password, w.Scrypt); err != nil {
return nil, fmt.Errorf("couldn't decrypt account: %w", err)
}
return acc.PrivateKey(), nil
}
// newLogger constructs a zap.Logger instance for current application.
// Panics on failure.
//
// Logger is built from zap's production logging configuration with:
// - parameterized level (debug by default)
// - console encoding
// - ISO8601 time encoding
//
// Logger records a stack trace for all messages at or above fatal level.
//
// See also zapcore.Level, zap.NewProductionConfig, zap.AddStacktrace.
func newLogger(v *viper.Viper) *zap.Logger {
var lvl zapcore.Level
lvlStr := v.GetString(cfgLoggerLevel)
err := lvl.UnmarshalText([]byte(lvlStr))
if err != nil {
panic(fmt.Sprintf("incorrect logger level configuration %s (%v), "+
"value should be one of %v", lvlStr, err, [...]zapcore.Level{
zapcore.DebugLevel,
zapcore.InfoLevel,
zapcore.WarnLevel,
zapcore.ErrorLevel,
zapcore.DPanicLevel,
zapcore.PanicLevel,
zapcore.FatalLevel,
}))
}
c := zap.NewProductionConfig()
c.Level = zap.NewAtomicLevelAt(lvl)
c.Encoding = "console"
c.EncoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder
l, err := c.Build(
zap.AddStacktrace(zap.NewAtomicLevelAt(zap.FatalLevel)),
)
if err != nil {
panic(fmt.Sprintf("build zap logger instance: %v", err))
}
return l
}
func serverConfig(v *viper.Viper) *restapi.ServerConfig {
return &restapi.ServerConfig{
EnabledListeners: v.GetStringSlice(cfgServerSection + restapi.FlagScheme),
CleanupTimeout: v.GetDuration(cfgServerSection + restapi.FlagCleanupTimeout),
GracefulTimeout: v.GetDuration(cfgServerSection + restapi.FlagGracefulTimeout),
MaxHeaderSize: v.GetInt(cfgServerSection + restapi.FlagMaxHeaderSize),
ListenAddress: v.GetString(cfgServerSection + restapi.FlagListenAddress),
ListenLimit: v.GetInt(cfgServerSection + restapi.FlagListenLimit),
KeepAlive: v.GetDuration(cfgServerSection + restapi.FlagKeepAlive),
ReadTimeout: v.GetDuration(cfgServerSection + restapi.FlagReadTimeout),
WriteTimeout: v.GetDuration(cfgServerSection + restapi.FlagWriteTimeout),
TLSListenAddress: v.GetString(cfgServerSection + restapi.FlagTLSListenAddress),
TLSListenLimit: v.GetInt(cfgServerSection + restapi.FlagTLSListenLimit),
TLSKeepAlive: v.GetDuration(cfgServerSection + restapi.FlagTLSKeepAlive),
TLSReadTimeout: v.GetDuration(cfgServerSection + restapi.FlagTLSReadTimeout),
TLSWriteTimeout: v.GetDuration(cfgServerSection + restapi.FlagTLSWriteTimeout),
}
}
func newNeofsAPI(ctx context.Context, logger *zap.Logger, v *viper.Viper) (*handlers.API, error) {
key, err := getNeoFSKey(logger, v)
if err != nil {
return nil, err
}
var prm pool.InitParameters
prm.SetKey(&key.PrivateKey)
prm.SetNodeDialTimeout(v.GetDuration(cfgNodeDialTimeout))
prm.SetHealthcheckTimeout(v.GetDuration(cfgHealthcheckTimeout))
prm.SetClientRebalanceInterval(v.GetDuration(cfgRebalance))
prm.SetErrorThreshold(v.GetUint32(cfgPoolErrorThreshold))
for _, peer := range fetchPeers(logger, v) {
prm.AddNode(peer)
}
p, err := pool.NewPool(prm)
if err != nil {
return nil, err
}
if err = p.Dial(ctx); err != nil {
return nil, err
}
var apiPrm handlers.PrmAPI
apiPrm.Pool = p
apiPrm.Key = key
apiPrm.Logger = logger
pprofConfig := metrics.Config{Enabled: v.GetBool(cfgPprofEnabled), Address: v.GetString(cfgPprofAddress)}
apiPrm.PprofService = metrics.NewPprofService(logger, pprofConfig)
prometheusConfig := metrics.Config{Enabled: v.GetBool(cfgPrometheusEnabled), Address: v.GetString(cfgPrometheusAddress)}
apiPrm.PrometheusService = metrics.NewPrometheusService(logger, prometheusConfig)
if prometheusConfig.Enabled {
apiPrm.GateMetric = metrics.NewGateMetrics(p)
}
apiPrm.ServiceShutdownTimeout = defaultShutdownTimeout
return handlers.New(&apiPrm), nil
}
func fetchPeers(l *zap.Logger, v *viper.Viper) []pool.NodeParam {
var nodes []pool.NodeParam
for i := 0; ; i++ {
key := cfgPeers + "." + strconv.Itoa(i) + "."
address := v.GetString(key + cfgPeerAddress)
weight := v.GetFloat64(key + cfgPeerWeight)
priority := v.GetInt(key + cfgPeerPriority)
if address == "" {
break
}
if weight <= 0 { // unspecified or wrong
weight = 1
}
if priority <= 0 { // unspecified or wrong
priority = 1
}
nodes = append(nodes, pool.NewNodeParam(priority, address, weight))
l.Info("added connection peer",
zap.String("address", address),
zap.Int("priority", priority),
zap.Float64("weight", weight),
)
}
return nodes
}