[#71] Validate config for unknown keys

Signed-off-by: Denis Kirillov <denis@nspcc.ru>
This commit is contained in:
Denis Kirillov 2022-10-21 16:53:30 +03:00 committed by Alex Vanin
parent 3f05207530
commit 6f789f149e
7 changed files with 328 additions and 100 deletions

View file

@ -44,6 +44,9 @@ const (
cfgPoolErrorThreshold = "pool.error-threshold" cfgPoolErrorThreshold = "pool.error-threshold"
cmdPeers = "peers" cmdPeers = "peers"
cfgPeers = "pool." + cmdPeers cfgPeers = "pool." + cmdPeers
cfgPeerAddress = "address"
cfgPeerPriority = "priority"
cfgPeerWeight = "weight"
// Metrics / Profiler. // Metrics / Profiler.
cfgPrometheusEnabled = "prometheus.enabled" cfgPrometheusEnabled = "prometheus.enabled"
@ -59,6 +62,9 @@ const (
cfgWalletAddress = "wallet.address" cfgWalletAddress = "wallet.address"
cfgWalletPassphrase = "wallet.passphrase" cfgWalletPassphrase = "wallet.passphrase"
// Config section for autogenerated flags.
cfgServerSection = "server."
// Command line args. // Command line args.
cmdHelp = "help" cmdHelp = "help"
cmdVersion = "version" cmdVersion = "version"
@ -70,7 +76,7 @@ const (
) )
var ignore = map[string]struct{}{ var ignore = map[string]struct{}{
cfgPeers: {}, cmdPeers: {},
cmdHelp: {}, cmdHelp: {},
cmdVersion: {}, cmdVersion: {},
} }
@ -105,7 +111,8 @@ func config() *viper.Viper {
flagSet.StringP(cmdWallet, "w", "", `path to the wallet`) flagSet.StringP(cmdWallet, "w", "", `path to the wallet`)
flagSet.String(cmdAddress, "", `address of wallet account`) flagSet.String(cmdAddress, "", `address of wallet account`)
config := flagSet.String(cmdConfig, "", "config path")
configFlag := flagSet.String(cmdConfig, "", "config path")
flagSet.Duration(cmdNodeDialTimeout, defaultConnectTimeout, "gRPC node connect timeout") flagSet.Duration(cmdNodeDialTimeout, defaultConnectTimeout, "gRPC node connect timeout")
flagSet.Duration(cmdHealthcheckTimeout, defaultHealthcheckTimeout, "gRPC healthcheck timeout") flagSet.Duration(cmdHealthcheckTimeout, defaultHealthcheckTimeout, "gRPC healthcheck timeout")
flagSet.Duration(cmdRebalance, defaultRebalanceTimer, "gRPC connection rebalance timer") flagSet.Duration(cmdRebalance, defaultRebalanceTimer, "gRPC connection rebalance timer")
@ -127,24 +134,8 @@ func config() *viper.Viper {
v.SetDefault(cfgLoggerLevel, "debug") v.SetDefault(cfgLoggerLevel, "debug")
// Bind flags // Bind flags
if err := v.BindPFlag(cfgPprofEnabled, flagSet.Lookup(cmdPprof)); err != nil { if err := bindFlags(v, flagSet); err != nil {
panic(err) panic(fmt.Errorf("bind flags: %w", err))
}
if err := v.BindPFlag(cfgPrometheusEnabled, flagSet.Lookup(cmdMetrics)); err != nil {
panic(err)
}
if err := v.BindPFlag(cfgNodeDialTimeout, flagSet.Lookup(cmdNodeDialTimeout)); err != nil {
panic(err)
}
if err := v.BindPFlag(cfgHealthcheckTimeout, flagSet.Lookup(cmdHealthcheckTimeout)); err != nil {
panic(err)
}
if err := v.BindPFlag(cfgRebalance, flagSet.Lookup(cmdRebalance)); err != nil {
panic(err)
}
if err := v.BindPFlags(flagSet); err != nil {
panic(err)
} }
if err := flagSet.Parse(os.Args); err != nil { if err := flagSet.Parse(os.Args); err != nil {
@ -175,27 +166,140 @@ func config() *viper.Viper {
case version != nil && *version: case version != nil && *version:
fmt.Printf("NeoFS REST Gateway\nVersion: %s\nGoVersion: %s\n", Version, runtime.Version()) fmt.Printf("NeoFS REST Gateway\nVersion: %s\nGoVersion: %s\n", Version, runtime.Version())
os.Exit(0) os.Exit(0)
} case configFlag != nil && *configFlag != "":
if cfgFile, err := os.Open(*configFlag); err != nil {
if v.IsSet(cmdConfig) {
if cfgFile, err := os.Open(*config); err != nil {
panic(err) panic(err)
} else if err := v.ReadConfig(cfgFile); err != nil { } else if err = v.ReadConfig(cfgFile); err != nil {
panic(err) panic(err)
} }
} }
if peers != nil && len(*peers) > 0 { if peers != nil && len(*peers) > 0 {
for i := range *peers { for i := range *peers {
v.SetDefault(cfgPeers+"."+strconv.Itoa(i)+".address", (*peers)[i]) v.SetDefault(cfgPeers+"."+strconv.Itoa(i)+"."+cfgPeerAddress, (*peers)[i])
v.SetDefault(cfgPeers+"."+strconv.Itoa(i)+".weight", 1) v.SetDefault(cfgPeers+"."+strconv.Itoa(i)+"."+cfgPeerWeight, 1)
v.SetDefault(cfgPeers+"."+strconv.Itoa(i)+".priority", 1) v.SetDefault(cfgPeers+"."+strconv.Itoa(i)+"."+cfgPeerWeight, 1)
} }
} }
return v return v
} }
func bindFlags(v *viper.Viper, flagSet *pflag.FlagSet) error {
if err := v.BindPFlag(cfgPprofEnabled, flagSet.Lookup(cmdPprof)); err != nil {
return err
}
if err := v.BindPFlag(cfgPrometheusEnabled, flagSet.Lookup(cmdMetrics)); err != nil {
return err
}
if err := v.BindPFlag(cfgNodeDialTimeout, flagSet.Lookup(cmdNodeDialTimeout)); err != nil {
return err
}
if err := v.BindPFlag(cfgHealthcheckTimeout, flagSet.Lookup(cmdHealthcheckTimeout)); err != nil {
return err
}
if err := v.BindPFlag(cfgRebalance, flagSet.Lookup(cmdRebalance)); err != nil {
return err
}
if err := v.BindPFlag(cfgWalletPath, flagSet.Lookup(cmdWallet)); err != nil {
return err
}
if err := v.BindPFlag(cfgWalletAddress, flagSet.Lookup(cmdAddress)); err != nil {
return err
}
if err := restapi.BindFlagsToConfig(v, flagSet, cfgServerSection); err != nil {
return err
}
return nil
}
var knownConfigParams = map[string]struct{}{
cfgWalletAddress: {},
cfgWalletPath: {},
cfgWalletPassphrase: {},
cfgRebalance: {},
cfgHealthcheckTimeout: {},
cfgNodeDialTimeout: {},
cfgPoolErrorThreshold: {},
cfgLoggerLevel: {},
cfgPrometheusEnabled: {},
cfgPrometheusAddress: {},
cfgPprofEnabled: {},
cfgPprofAddress: {},
cfgServerSection + restapi.FlagScheme: {},
cfgServerSection + restapi.FlagCleanupTimeout: {},
cfgServerSection + restapi.FlagGracefulTimeout: {},
cfgServerSection + restapi.FlagMaxHeaderSize: {},
cfgServerSection + restapi.FlagListenAddress: {},
cfgServerSection + restapi.FlagListenLimit: {},
cfgServerSection + restapi.FlagKeepAlive: {},
cfgServerSection + restapi.FlagReadTimeout: {},
cfgServerSection + restapi.FlagWriteTimeout: {},
cfgServerSection + restapi.FlagTLSListenAddress: {},
cfgServerSection + restapi.FlagTLSCertificate: {},
cfgServerSection + restapi.FlagTLSKey: {},
cfgServerSection + restapi.FlagTLSCa: {},
cfgServerSection + restapi.FlagTLSListenLimit: {},
cfgServerSection + restapi.FlagTLSKeepAlive: {},
cfgServerSection + restapi.FlagTLSReadTimeout: {},
cfgServerSection + restapi.FlagTLSWriteTimeout: {},
}
func validateConfig(cfg *viper.Viper, logger *zap.Logger) {
peerNumsMap := make(map[int]struct{})
for _, providedKey := range cfg.AllKeys() {
if !strings.HasPrefix(providedKey, cfgPeers) {
if _, ok := knownConfigParams[providedKey]; !ok {
logger.Warn("unknown config parameter", zap.String("key", providedKey))
}
continue
}
num, ok := isValidPeerKey(providedKey)
if !ok {
logger.Warn("unknown config parameter", zap.String("key", providedKey))
} else {
peerNumsMap[num] = struct{}{}
}
}
peerNums := make([]int, 0, len(peerNumsMap))
for num := range peerNumsMap {
peerNums = append(peerNums, num)
}
sort.Ints(peerNums)
for i, num := range peerNums {
if i != num {
logger.Warn("invalid config parameter, peer indexes must be consecutive starting from 0", zap.String("key", cfgPeers+"."+strconv.Itoa(num)))
}
}
}
func isValidPeerKey(key string) (int, bool) {
trimmed := strings.TrimPrefix(key, cfgPeers)
split := strings.Split(trimmed, ".")
if len(split) != 3 {
return 0, false
}
if split[2] != cfgPeerAddress && split[2] != cfgPeerPriority && split[2] != cfgPeerWeight {
return 0, false
}
num, err := strconv.Atoi(split[1])
if err != nil || num < 0 {
return 0, false
}
return num, true
}
func getNeoFSKey(logger *zap.Logger, cfg *viper.Viper) (*keys.PrivateKey, error) { func getNeoFSKey(logger *zap.Logger, cfg *viper.Viper) (*keys.PrivateKey, error) {
walletPath := cfg.GetString(cmdWallet) walletPath := cfg.GetString(cmdWallet)
if len(walletPath) == 0 { if len(walletPath) == 0 {
@ -303,22 +407,22 @@ func newLogger(v *viper.Viper) *zap.Logger {
func serverConfig(v *viper.Viper) *restapi.ServerConfig { func serverConfig(v *viper.Viper) *restapi.ServerConfig {
return &restapi.ServerConfig{ return &restapi.ServerConfig{
EnabledListeners: v.GetStringSlice(restapi.FlagScheme), EnabledListeners: v.GetStringSlice(cfgServerSection + restapi.FlagScheme),
CleanupTimeout: v.GetDuration(restapi.FlagCleanupTimeout), CleanupTimeout: v.GetDuration(cfgServerSection + restapi.FlagCleanupTimeout),
GracefulTimeout: v.GetDuration(restapi.FlagGracefulTimeout), GracefulTimeout: v.GetDuration(cfgServerSection + restapi.FlagGracefulTimeout),
MaxHeaderSize: v.GetInt(restapi.FlagMaxHeaderSize), MaxHeaderSize: v.GetInt(cfgServerSection + restapi.FlagMaxHeaderSize),
ListenAddress: v.GetString(restapi.FlagListenAddress), ListenAddress: v.GetString(cfgServerSection + restapi.FlagListenAddress),
ListenLimit: v.GetInt(restapi.FlagListenLimit), ListenLimit: v.GetInt(cfgServerSection + restapi.FlagListenLimit),
KeepAlive: v.GetDuration(restapi.FlagKeepAlive), KeepAlive: v.GetDuration(cfgServerSection + restapi.FlagKeepAlive),
ReadTimeout: v.GetDuration(restapi.FlagReadTimeout), ReadTimeout: v.GetDuration(cfgServerSection + restapi.FlagReadTimeout),
WriteTimeout: v.GetDuration(restapi.FlagWriteTimeout), WriteTimeout: v.GetDuration(cfgServerSection + restapi.FlagWriteTimeout),
TLSListenAddress: v.GetString(restapi.FlagTLSListenAddress), TLSListenAddress: v.GetString(cfgServerSection + restapi.FlagTLSListenAddress),
TLSListenLimit: v.GetInt(restapi.FlagTLSListenLimit), TLSListenLimit: v.GetInt(cfgServerSection + restapi.FlagTLSListenLimit),
TLSKeepAlive: v.GetDuration(restapi.FlagTLSKeepAlive), TLSKeepAlive: v.GetDuration(cfgServerSection + restapi.FlagTLSKeepAlive),
TLSReadTimeout: v.GetDuration(restapi.FlagTLSReadTimeout), TLSReadTimeout: v.GetDuration(cfgServerSection + restapi.FlagTLSReadTimeout),
TLSWriteTimeout: v.GetDuration(restapi.FlagTLSWriteTimeout), TLSWriteTimeout: v.GetDuration(cfgServerSection + restapi.FlagTLSWriteTimeout),
} }
} }
@ -371,9 +475,9 @@ func fetchPeers(l *zap.Logger, v *viper.Viper) []pool.NodeParam {
var nodes []pool.NodeParam var nodes []pool.NodeParam
for i := 0; ; i++ { for i := 0; ; i++ {
key := cfgPeers + "." + strconv.Itoa(i) + "." key := cfgPeers + "." + strconv.Itoa(i) + "."
address := v.GetString(key + "address") address := v.GetString(key + cfgPeerAddress)
weight := v.GetFloat64(key + "weight") weight := v.GetFloat64(key + cfgPeerWeight)
priority := v.GetInt(key + "priority") priority := v.GetInt(key + cfgPeerPriority)
if address == "" { if address == "" {
break break

View file

@ -192,8 +192,8 @@ func getDefaultConfig(node string) *viper.Viper {
v.SetDefault(cfgPeers+".0.address", node) v.SetDefault(cfgPeers+".0.address", node)
v.SetDefault(cfgPeers+".0.weight", 1) v.SetDefault(cfgPeers+".0.weight", 1)
v.SetDefault(cfgPeers+".0.priority", 1) v.SetDefault(cfgPeers+".0.priority", 1)
v.SetDefault(restapi.FlagListenAddress, testListenAddress) v.SetDefault(cfgServerSection+restapi.FlagListenAddress, testListenAddress)
v.SetDefault(restapi.FlagWriteTimeout, 60*time.Second) v.SetDefault(cfgServerSection+restapi.FlagWriteTimeout, 60*time.Second)
return v return v
} }

View file

@ -16,6 +16,7 @@ func main() {
v := config() v := config()
logger := newLogger(v) logger := newLogger(v)
validateConfig(v, logger)
neofsAPI, err := newNeofsAPI(ctx, logger, v) neofsAPI, err := newNeofsAPI(ctx, logger, v)
if err != nil { if err != nil {

View file

@ -46,39 +46,39 @@ REST_GW_POOL_REBALANCE_TIMER=60s
REST_GW_POOL_ERROR_THRESHOLD=100 REST_GW_POOL_ERROR_THRESHOLD=100
# Grace period for which to wait before killing idle connections # Grace period for which to wait before killing idle connections
REST_GW_CLEANUP_TIMEOUT=10s REST_GW_SERVER_CLEANUP_TIMEOUT=10s
# Grace period for which to wait before shutting down the server # Grace period for which to wait before shutting down the server
REST_GW_GRACEFUL_TIMEOUT=15s REST_GW_SERVER_GRACEFUL_TIMEOUT=15s
# Controls the maximum number of bytes the server will read parsing the request header's keys and values, # Controls the maximum number of bytes the server will read parsing the request header's keys and values,
# including the request line. It does not limit the size of the request body. # including the request line. It does not limit the size of the request body.
REST_GW_MAX_HEADER_SIZE=1000000 REST_GW_SERVER_MAX_HEADER_SIZE=1000000
# The IP and port to listen on. # The IP and port to listen on.
REST_GW_LISTEN_ADDRESS=localhost:8080 REST_GW_SERVER_LISTEN_ADDRESS=localhost:8080
# Limit the number of outstanding requests. # Limit the number of outstanding requests.
REST_GW_LISTEN_LIMIT=0 REST_GW_SERVER_LISTEN_LIMIT=0
# Sets the TCP keep-alive timeouts on accepted connections. # Sets the TCP keep-alive timeouts on accepted connections.
# It prunes dead TCP connections ( e.g. closing laptop mid-download). # It prunes dead TCP connections ( e.g. closing laptop mid-download).
REST_GW_KEEP_ALIVE=3m REST_GW_SERVER_KEEP_ALIVE=3m
# Maximum duration before timing out read of the request. # Maximum duration before timing out read of the request.
REST_GW_READ_TIMEOUT=30s REST_GW_SERVER_READ_TIMEOUT=30s
# Maximum duration before timing out write of the response. # Maximum duration before timing out write of the response.
REST_GW_WRITE_TIMEOUT=30s REST_GW_SERVER_WRITE_TIMEOUT=30s
# The IP and port to listen on. # The IP and port to listen on.
REST_GW_TLS_LISTEN_ADDRESS=localhost:8081 REST_GW_SERVER_TLS_LISTEN_ADDRESS=localhost:8081
# The certificate file to use for secure connections. # The certificate file to use for secure connections.
REST_GW_TLS_CERTIFICATE=/path/to/tls/cert REST_GW_SERVER_TLS_CERTIFICATE=/path/to/tls/cert
# The private key file to use for secure connections (without passphrase). # The private key file to use for secure connections (without passphrase).
REST_GW_TLS_KEY=/path/to/tls/key REST_GW_SERVER_TLS_KEY=/path/to/tls/key
# The certificate authority certificate file to be used with mutual tls auth. # The certificate authority certificate file to be used with mutual tls auth.
REST_GW_TLS_CA=/path/to/tls/ca REST_GW_SERVER_TLS_CA=/path/to/tls/ca
# Limit the number of outstanding requests. # Limit the number of outstanding requests.
REST_GW_TLS_LISTEN_LIMIT=0 REST_GW_SERVER_TLS_LISTEN_LIMIT=0
# Sets the TCP keep-alive timeouts on accepted connections. # Sets the TCP keep-alive timeouts on accepted connections.
# It prunes dead TCP connections ( e.g. closing laptop mid-download). # It prunes dead TCP connections ( e.g. closing laptop mid-download).
REST_GW_TLS_KEEP_ALIVE=3m REST_GW_SERVER_TLS_KEEP_ALIVE=3m
# Maximum duration before timing out read of the request. # Maximum duration before timing out read of the request.
REST_GW_TLS_READ_TIMEOUT=30s REST_GW_SERVER_TLS_READ_TIMEOUT=30s
# Maximum duration before timing out write of the response. # Maximum duration before timing out write of the response.
REST_GW_TLS_WRITE_TIMEOUT=30s REST_GW_SERVER_TLS_WRITE_TIMEOUT=30s

View file

@ -50,42 +50,43 @@ pool:
priority: 2 priority: 2
weight: 9 weight: 9
# The listeners to enable, this can be repeated and defaults to the schemes in the swagger spec. server:
scheme: [ http ] # The listeners to enable, this can be repeated and defaults to the schemes in the swagger spec.
# Grace period for which to wait before killing idle connections scheme: [ http ]
cleanup-timeout: 10s # Grace period for which to wait before killing idle connections
# Grace period for which to wait before shutting down the server cleanup-timeout: 10s
graceful-timeout: 15s # Grace period for which to wait before shutting down the server
# Controls the maximum number of bytes the server will read parsing the request header's keys and values, graceful-timeout: 15s
# including the request line. It does not limit the size of the request body. # Controls the maximum number of bytes the server will read parsing the request header's keys and values,
max-header-size: 1000000 # including the request line. It does not limit the size of the request body.
max-header-size: 1000000
# The IP and port to listen on. # The IP and port to listen on.
listen-address: localhost:8080 listen-address: localhost:8080
# Limit the number of outstanding requests. # Limit the number of outstanding requests.
listen-limit: 0 listen-limit: 0
# Sets the TCP keep-alive timeouts on accepted connections. # Sets the TCP keep-alive timeouts on accepted connections.
# It prunes dead TCP connections ( e.g. closing laptop mid-download). # It prunes dead TCP connections ( e.g. closing laptop mid-download).
keep-alive: 3m keep-alive: 3m
# Maximum duration before timing out read of the request. # Maximum duration before timing out read of the request.
read-timeout: 30s read-timeout: 30s
# Maximum duration before timing out write of the response. # Maximum duration before timing out write of the response.
write-timeout: 30s write-timeout: 30s
# The IP and port to listen on. # The IP and port to listen on.
tls-listen-address: localhost:8081 tls-listen-address: localhost:8081
# The certificate file to use for secure connections. # The certificate file to use for secure connections.
tls-certificate: /path/to/tls/cert tls-certificate: /path/to/tls/cert
# The private key file to use for secure connections (without passphrase). # The private key file to use for secure connections (without passphrase).
tls-key: /path/to/tls/key tls-key: /path/to/tls/key
# The certificate authority certificate file to be used with mutual tls auth. # The certificate authority certificate file to be used with mutual tls auth.
tls-ca: /path/to/tls/ca tls-ca: /path/to/tls/ca
# Limit the number of outstanding requests. # Limit the number of outstanding requests.
tls-listen-limit: 0 tls-listen-limit: 0
# Sets the TCP keep-alive timeouts on accepted connections. # Sets the TCP keep-alive timeouts on accepted connections.
# It prunes dead TCP connections ( e.g. closing laptop mid-download). # It prunes dead TCP connections ( e.g. closing laptop mid-download).
tls-keep-alive: 3m tls-keep-alive: 3m
# Maximum duration before timing out read of the request. # Maximum duration before timing out read of the request.
tls-read-timeout: 30s tls-read-timeout: 30s
# Maximum duration before timing out write of the response. # Maximum duration before timing out write of the response.
tls-write-timeout: 30s tls-write-timeout: 30s

View file

@ -6,6 +6,7 @@ import (
"time" "time"
"github.com/spf13/pflag" "github.com/spf13/pflag"
"github.com/spf13/viper"
) )
const ( const (
@ -51,3 +52,63 @@ func BindDefaultFlags(flagSet *pflag.FlagSet) {
flagSet.Duration(FlagTLSReadTimeout, 30*time.Second, "maximum duration before timing out read of the request") flagSet.Duration(FlagTLSReadTimeout, 30*time.Second, "maximum duration before timing out read of the request")
flagSet.Duration(FlagTLSWriteTimeout, 30*time.Second, "maximum duration before timing out write of the response") flagSet.Duration(FlagTLSWriteTimeout, 30*time.Second, "maximum duration before timing out write of the response")
} }
// BindFlagsToConfig maps flags to viper config in specific section.
func BindFlagsToConfig(v *viper.Viper, flagSet *pflag.FlagSet, section string) error {
if err := v.BindPFlag(section+FlagScheme, flagSet.Lookup(FlagScheme)); err != nil {
return err
}
if err := v.BindPFlag(section+FlagCleanupTimeout, flagSet.Lookup(FlagCleanupTimeout)); err != nil {
return err
}
if err := v.BindPFlag(section+FlagGracefulTimeout, flagSet.Lookup(FlagGracefulTimeout)); err != nil {
return err
}
if err := v.BindPFlag(section+FlagMaxHeaderSize, flagSet.Lookup(FlagMaxHeaderSize)); err != nil {
return err
}
if err := v.BindPFlag(section+FlagListenAddress, flagSet.Lookup(FlagListenAddress)); err != nil {
return err
}
if err := v.BindPFlag(section+FlagListenLimit, flagSet.Lookup(FlagListenLimit)); err != nil {
return err
}
if err := v.BindPFlag(section+FlagKeepAlive, flagSet.Lookup(FlagKeepAlive)); err != nil {
return err
}
if err := v.BindPFlag(section+FlagReadTimeout, flagSet.Lookup(FlagReadTimeout)); err != nil {
return err
}
if err := v.BindPFlag(section+FlagWriteTimeout, flagSet.Lookup(FlagWriteTimeout)); err != nil {
return err
}
if err := v.BindPFlag(section+FlagTLSListenAddress, flagSet.Lookup(FlagTLSListenAddress)); err != nil {
return err
}
if err := v.BindPFlag(section+FlagTLSCertificate, flagSet.Lookup(FlagTLSCertificate)); err != nil {
return err
}
if err := v.BindPFlag(section+FlagTLSKey, flagSet.Lookup(FlagTLSKey)); err != nil {
return err
}
if err := v.BindPFlag(section+FlagTLSCa, flagSet.Lookup(FlagTLSCa)); err != nil {
return err
}
if err := v.BindPFlag(section+FlagTLSListenLimit, flagSet.Lookup(FlagTLSListenLimit)); err != nil {
return err
}
if err := v.BindPFlag(section+FlagTLSKeepAlive, flagSet.Lookup(FlagTLSKeepAlive)); err != nil {
return err
}
if err := v.BindPFlag(section+FlagTLSReadTimeout, flagSet.Lookup(FlagTLSReadTimeout)); err != nil {
return err
}
if err := v.BindPFlag(section+FlagTLSWriteTimeout, flagSet.Lookup(FlagTLSWriteTimeout)); err != nil {
return err
}
return nil
}

View file

@ -10,6 +10,7 @@ import (
"time" "time"
"github.com/spf13/pflag" "github.com/spf13/pflag"
"github.com/spf13/viper"
) )
const ( const (
@ -55,3 +56,63 @@ func BindDefaultFlags(flagSet *pflag.FlagSet) {
flagSet.Duration(FlagTLSReadTimeout, 30*time.Second, "maximum duration before timing out read of the request") flagSet.Duration(FlagTLSReadTimeout, 30*time.Second, "maximum duration before timing out read of the request")
flagSet.Duration(FlagTLSWriteTimeout, 30*time.Second, "maximum duration before timing out write of the response") flagSet.Duration(FlagTLSWriteTimeout, 30*time.Second, "maximum duration before timing out write of the response")
} }
// BindFlagsToConfig maps flags to viper config in specific section.
func BindFlagsToConfig(v *viper.Viper, flagSet *pflag.FlagSet, section string) error {
if err := v.BindPFlag(section+FlagScheme, flagSet.Lookup(FlagScheme)); err != nil {
return err
}
if err := v.BindPFlag(section+FlagCleanupTimeout, flagSet.Lookup(FlagCleanupTimeout)); err != nil {
return err
}
if err := v.BindPFlag(section+FlagGracefulTimeout, flagSet.Lookup(FlagGracefulTimeout)); err != nil {
return err
}
if err := v.BindPFlag(section+FlagMaxHeaderSize, flagSet.Lookup(FlagMaxHeaderSize)); err != nil {
return err
}
if err := v.BindPFlag(section+FlagListenAddress, flagSet.Lookup(FlagListenAddress)); err != nil {
return err
}
if err := v.BindPFlag(section+FlagListenLimit, flagSet.Lookup(FlagListenLimit)); err != nil {
return err
}
if err := v.BindPFlag(section+FlagKeepAlive, flagSet.Lookup(FlagKeepAlive)); err != nil {
return err
}
if err := v.BindPFlag(section+FlagReadTimeout, flagSet.Lookup(FlagReadTimeout)); err != nil {
return err
}
if err := v.BindPFlag(section+FlagWriteTimeout, flagSet.Lookup(FlagWriteTimeout)); err != nil {
return err
}
if err := v.BindPFlag(section+FlagTLSListenAddress, flagSet.Lookup(FlagTLSListenAddress)); err != nil {
return err
}
if err := v.BindPFlag(section+FlagTLSCertificate, flagSet.Lookup(FlagTLSCertificate)); err != nil {
return err
}
if err := v.BindPFlag(section+FlagTLSKey, flagSet.Lookup(FlagTLSKey)); err != nil {
return err
}
if err := v.BindPFlag(section+FlagTLSCa, flagSet.Lookup(FlagTLSCa)); err != nil {
return err
}
if err := v.BindPFlag(section+FlagTLSListenLimit, flagSet.Lookup(FlagTLSListenLimit)); err != nil {
return err
}
if err := v.BindPFlag(section+FlagTLSKeepAlive, flagSet.Lookup(FlagTLSKeepAlive)); err != nil {
return err
}
if err := v.BindPFlag(section+FlagTLSReadTimeout, flagSet.Lookup(FlagTLSReadTimeout)); err != nil {
return err
}
if err := v.BindPFlag(section+FlagTLSWriteTimeout, flagSet.Lookup(FlagTLSWriteTimeout)); err != nil {
return err
}
return nil
}