[#71] Validate config for unknown keys

Signed-off-by: Denis Kirillov <denis@nspcc.ru>
This commit is contained in:
Denis Kirillov 2022-10-21 16:53:30 +03:00 committed by Alex Vanin
parent 3f05207530
commit 6f789f149e
7 changed files with 328 additions and 100 deletions

View file

@ -44,6 +44,9 @@ const (
cfgPoolErrorThreshold = "pool.error-threshold"
cmdPeers = "peers"
cfgPeers = "pool." + cmdPeers
cfgPeerAddress = "address"
cfgPeerPriority = "priority"
cfgPeerWeight = "weight"
// Metrics / Profiler.
cfgPrometheusEnabled = "prometheus.enabled"
@ -59,6 +62,9 @@ const (
cfgWalletAddress = "wallet.address"
cfgWalletPassphrase = "wallet.passphrase"
// Config section for autogenerated flags.
cfgServerSection = "server."
// Command line args.
cmdHelp = "help"
cmdVersion = "version"
@ -70,7 +76,7 @@ const (
)
var ignore = map[string]struct{}{
cfgPeers: {},
cmdPeers: {},
cmdHelp: {},
cmdVersion: {},
}
@ -105,7 +111,8 @@ func config() *viper.Viper {
flagSet.StringP(cmdWallet, "w", "", `path to the wallet`)
flagSet.String(cmdAddress, "", `address of wallet account`)
config := flagSet.String(cmdConfig, "", "config path")
configFlag := flagSet.String(cmdConfig, "", "config path")
flagSet.Duration(cmdNodeDialTimeout, defaultConnectTimeout, "gRPC node connect timeout")
flagSet.Duration(cmdHealthcheckTimeout, defaultHealthcheckTimeout, "gRPC healthcheck timeout")
flagSet.Duration(cmdRebalance, defaultRebalanceTimer, "gRPC connection rebalance timer")
@ -127,24 +134,8 @@ func config() *viper.Viper {
v.SetDefault(cfgLoggerLevel, "debug")
// Bind flags
if err := v.BindPFlag(cfgPprofEnabled, flagSet.Lookup(cmdPprof)); err != nil {
panic(err)
}
if err := v.BindPFlag(cfgPrometheusEnabled, flagSet.Lookup(cmdMetrics)); err != nil {
panic(err)
}
if err := v.BindPFlag(cfgNodeDialTimeout, flagSet.Lookup(cmdNodeDialTimeout)); err != nil {
panic(err)
}
if err := v.BindPFlag(cfgHealthcheckTimeout, flagSet.Lookup(cmdHealthcheckTimeout)); err != nil {
panic(err)
}
if err := v.BindPFlag(cfgRebalance, flagSet.Lookup(cmdRebalance)); err != nil {
panic(err)
}
if err := v.BindPFlags(flagSet); err != nil {
panic(err)
if err := bindFlags(v, flagSet); err != nil {
panic(fmt.Errorf("bind flags: %w", err))
}
if err := flagSet.Parse(os.Args); err != nil {
@ -175,27 +166,140 @@ func config() *viper.Viper {
case version != nil && *version:
fmt.Printf("NeoFS REST Gateway\nVersion: %s\nGoVersion: %s\n", Version, runtime.Version())
os.Exit(0)
}
if v.IsSet(cmdConfig) {
if cfgFile, err := os.Open(*config); err != nil {
case configFlag != nil && *configFlag != "":
if cfgFile, err := os.Open(*configFlag); err != nil {
panic(err)
} else if err := v.ReadConfig(cfgFile); err != nil {
} else if err = v.ReadConfig(cfgFile); err != nil {
panic(err)
}
}
if peers != nil && len(*peers) > 0 {
for i := range *peers {
v.SetDefault(cfgPeers+"."+strconv.Itoa(i)+".address", (*peers)[i])
v.SetDefault(cfgPeers+"."+strconv.Itoa(i)+".weight", 1)
v.SetDefault(cfgPeers+"."+strconv.Itoa(i)+".priority", 1)
v.SetDefault(cfgPeers+"."+strconv.Itoa(i)+"."+cfgPeerAddress, (*peers)[i])
v.SetDefault(cfgPeers+"."+strconv.Itoa(i)+"."+cfgPeerWeight, 1)
v.SetDefault(cfgPeers+"."+strconv.Itoa(i)+"."+cfgPeerWeight, 1)
}
}
return v
}
func bindFlags(v *viper.Viper, flagSet *pflag.FlagSet) error {
if err := v.BindPFlag(cfgPprofEnabled, flagSet.Lookup(cmdPprof)); err != nil {
return err
}
if err := v.BindPFlag(cfgPrometheusEnabled, flagSet.Lookup(cmdMetrics)); err != nil {
return err
}
if err := v.BindPFlag(cfgNodeDialTimeout, flagSet.Lookup(cmdNodeDialTimeout)); err != nil {
return err
}
if err := v.BindPFlag(cfgHealthcheckTimeout, flagSet.Lookup(cmdHealthcheckTimeout)); err != nil {
return err
}
if err := v.BindPFlag(cfgRebalance, flagSet.Lookup(cmdRebalance)); err != nil {
return err
}
if err := v.BindPFlag(cfgWalletPath, flagSet.Lookup(cmdWallet)); err != nil {
return err
}
if err := v.BindPFlag(cfgWalletAddress, flagSet.Lookup(cmdAddress)); err != nil {
return err
}
if err := restapi.BindFlagsToConfig(v, flagSet, cfgServerSection); err != nil {
return err
}
return nil
}
var knownConfigParams = map[string]struct{}{
cfgWalletAddress: {},
cfgWalletPath: {},
cfgWalletPassphrase: {},
cfgRebalance: {},
cfgHealthcheckTimeout: {},
cfgNodeDialTimeout: {},
cfgPoolErrorThreshold: {},
cfgLoggerLevel: {},
cfgPrometheusEnabled: {},
cfgPrometheusAddress: {},
cfgPprofEnabled: {},
cfgPprofAddress: {},
cfgServerSection + restapi.FlagScheme: {},
cfgServerSection + restapi.FlagCleanupTimeout: {},
cfgServerSection + restapi.FlagGracefulTimeout: {},
cfgServerSection + restapi.FlagMaxHeaderSize: {},
cfgServerSection + restapi.FlagListenAddress: {},
cfgServerSection + restapi.FlagListenLimit: {},
cfgServerSection + restapi.FlagKeepAlive: {},
cfgServerSection + restapi.FlagReadTimeout: {},
cfgServerSection + restapi.FlagWriteTimeout: {},
cfgServerSection + restapi.FlagTLSListenAddress: {},
cfgServerSection + restapi.FlagTLSCertificate: {},
cfgServerSection + restapi.FlagTLSKey: {},
cfgServerSection + restapi.FlagTLSCa: {},
cfgServerSection + restapi.FlagTLSListenLimit: {},
cfgServerSection + restapi.FlagTLSKeepAlive: {},
cfgServerSection + restapi.FlagTLSReadTimeout: {},
cfgServerSection + restapi.FlagTLSWriteTimeout: {},
}
func validateConfig(cfg *viper.Viper, logger *zap.Logger) {
peerNumsMap := make(map[int]struct{})
for _, providedKey := range cfg.AllKeys() {
if !strings.HasPrefix(providedKey, cfgPeers) {
if _, ok := knownConfigParams[providedKey]; !ok {
logger.Warn("unknown config parameter", zap.String("key", providedKey))
}
continue
}
num, ok := isValidPeerKey(providedKey)
if !ok {
logger.Warn("unknown config parameter", zap.String("key", providedKey))
} else {
peerNumsMap[num] = struct{}{}
}
}
peerNums := make([]int, 0, len(peerNumsMap))
for num := range peerNumsMap {
peerNums = append(peerNums, num)
}
sort.Ints(peerNums)
for i, num := range peerNums {
if i != num {
logger.Warn("invalid config parameter, peer indexes must be consecutive starting from 0", zap.String("key", cfgPeers+"."+strconv.Itoa(num)))
}
}
}
func isValidPeerKey(key string) (int, bool) {
trimmed := strings.TrimPrefix(key, cfgPeers)
split := strings.Split(trimmed, ".")
if len(split) != 3 {
return 0, false
}
if split[2] != cfgPeerAddress && split[2] != cfgPeerPriority && split[2] != cfgPeerWeight {
return 0, false
}
num, err := strconv.Atoi(split[1])
if err != nil || num < 0 {
return 0, false
}
return num, true
}
func getNeoFSKey(logger *zap.Logger, cfg *viper.Viper) (*keys.PrivateKey, error) {
walletPath := cfg.GetString(cmdWallet)
if len(walletPath) == 0 {
@ -303,22 +407,22 @@ func newLogger(v *viper.Viper) *zap.Logger {
func serverConfig(v *viper.Viper) *restapi.ServerConfig {
return &restapi.ServerConfig{
EnabledListeners: v.GetStringSlice(restapi.FlagScheme),
CleanupTimeout: v.GetDuration(restapi.FlagCleanupTimeout),
GracefulTimeout: v.GetDuration(restapi.FlagGracefulTimeout),
MaxHeaderSize: v.GetInt(restapi.FlagMaxHeaderSize),
EnabledListeners: v.GetStringSlice(cfgServerSection + restapi.FlagScheme),
CleanupTimeout: v.GetDuration(cfgServerSection + restapi.FlagCleanupTimeout),
GracefulTimeout: v.GetDuration(cfgServerSection + restapi.FlagGracefulTimeout),
MaxHeaderSize: v.GetInt(cfgServerSection + restapi.FlagMaxHeaderSize),
ListenAddress: v.GetString(restapi.FlagListenAddress),
ListenLimit: v.GetInt(restapi.FlagListenLimit),
KeepAlive: v.GetDuration(restapi.FlagKeepAlive),
ReadTimeout: v.GetDuration(restapi.FlagReadTimeout),
WriteTimeout: v.GetDuration(restapi.FlagWriteTimeout),
ListenAddress: v.GetString(cfgServerSection + restapi.FlagListenAddress),
ListenLimit: v.GetInt(cfgServerSection + restapi.FlagListenLimit),
KeepAlive: v.GetDuration(cfgServerSection + restapi.FlagKeepAlive),
ReadTimeout: v.GetDuration(cfgServerSection + restapi.FlagReadTimeout),
WriteTimeout: v.GetDuration(cfgServerSection + restapi.FlagWriteTimeout),
TLSListenAddress: v.GetString(restapi.FlagTLSListenAddress),
TLSListenLimit: v.GetInt(restapi.FlagTLSListenLimit),
TLSKeepAlive: v.GetDuration(restapi.FlagTLSKeepAlive),
TLSReadTimeout: v.GetDuration(restapi.FlagTLSReadTimeout),
TLSWriteTimeout: v.GetDuration(restapi.FlagTLSWriteTimeout),
TLSListenAddress: v.GetString(cfgServerSection + restapi.FlagTLSListenAddress),
TLSListenLimit: v.GetInt(cfgServerSection + restapi.FlagTLSListenLimit),
TLSKeepAlive: v.GetDuration(cfgServerSection + restapi.FlagTLSKeepAlive),
TLSReadTimeout: v.GetDuration(cfgServerSection + restapi.FlagTLSReadTimeout),
TLSWriteTimeout: v.GetDuration(cfgServerSection + restapi.FlagTLSWriteTimeout),
}
}
@ -371,9 +475,9 @@ func fetchPeers(l *zap.Logger, v *viper.Viper) []pool.NodeParam {
var nodes []pool.NodeParam
for i := 0; ; i++ {
key := cfgPeers + "." + strconv.Itoa(i) + "."
address := v.GetString(key + "address")
weight := v.GetFloat64(key + "weight")
priority := v.GetInt(key + "priority")
address := v.GetString(key + cfgPeerAddress)
weight := v.GetFloat64(key + cfgPeerWeight)
priority := v.GetInt(key + cfgPeerPriority)
if address == "" {
break