package main

import (
	"context"
	"fmt"
	"os"
	"runtime"
	"sort"
	"strconv"
	"strings"
	"time"

	"github.com/TrueCloudLab/frostfs-rest-gw/gen/restapi"
	"github.com/TrueCloudLab/frostfs-rest-gw/handlers"
	"github.com/TrueCloudLab/frostfs-rest-gw/metrics"
	"github.com/TrueCloudLab/frostfs-sdk-go/pool"
	"github.com/nspcc-dev/neo-go/cli/flags"
	"github.com/nspcc-dev/neo-go/cli/input"
	"github.com/nspcc-dev/neo-go/pkg/crypto/keys"
	"github.com/nspcc-dev/neo-go/pkg/util"
	"github.com/nspcc-dev/neo-go/pkg/wallet"
	"github.com/spf13/pflag"
	"github.com/spf13/viper"
	"go.uber.org/zap"
	"go.uber.org/zap/zapcore"
)

const (
	// Defaults for the pool related command line flags registered in config().
	defaultConnectTimeout     = 10 * time.Second
	defaultHealthcheckTimeout = 15 * time.Second
	defaultRebalanceTimer     = 60 * time.Second

	// Shutdown timeout handed to the API services in newNeofsAPI.
	defaultShutdownTimeout = 15 * time.Second

	// Default value of the cfgPoolErrorThreshold config parameter.
	defaultPoolErrorThreshold uint32 = 100

	// Pool config.
	// cmd* constants are command line flag names, cfg* constants are viper
	// config keys. Per-peer settings live under
	// cfgPeers + ".<index>." + one of cfgPeerAddress, cfgPeerPriority,
	// cfgPeerWeight.
	cmdNodeDialTimeout    = "node-dial-timeout"
	cfgNodeDialTimeout    = "pool." + cmdNodeDialTimeout
	cmdHealthcheckTimeout = "healthcheck-timeout"
	cfgHealthcheckTimeout = "pool." + cmdHealthcheckTimeout
	cmdRebalance          = "rebalance-timer"
	cfgRebalance          = "pool." + cmdRebalance
	cfgPoolErrorThreshold = "pool.error-threshold"
	cmdPeers              = "peers"
	cfgPeers              = "pool." + cmdPeers
	cfgPeerAddress        = "address"
	cfgPeerPriority       = "priority"
	cfgPeerWeight         = "weight"

	// Metrics / Profiler.
	cfgPrometheusEnabled = "prometheus.enabled"
	cfgPrometheusAddress = "prometheus.address"
	cfgPprofEnabled      = "pprof.enabled"
	cfgPprofAddress      = "pprof.address"

	// Logger.
	cfgLoggerLevel = "logger.level"

	// Wallet.
	cfgWalletPath       = "wallet.path"
	cfgWalletAddress    = "wallet.address"
	cfgWalletPassphrase = "wallet.passphrase"

	// Config section for autogenerated flags.
	// It is used as a key prefix, hence the trailing dot.
	cfgServerSection = "server."

	// Command line args.
	cmdHelp    = "help"
	cmdVersion = "version"
	cmdPprof   = "pprof"
	cmdMetrics = "metrics"
	cmdWallet  = "wallet"
	cmdAddress = "address"
	cmdConfig  = "config"
)

// ignore is the set of keys that config() skips when it prints the default
// environment variables as part of the help output.
//
// NOTE(review): the lookup in config() uses full viper keys, while the entries
// here are bare flag names that are not listed in bindings — confirm whether
// this filter can ever match.
var ignore = map[string]struct{}{
	cmdPeers:   {},
	cmdHelp:    {},
	cmdVersion: {},
}

// Prefix is the prefix of environment variables that carry gateway
// configuration. It is passed to viper's SetEnvPrefix in config() and printed
// in front of every variable name in the help output.
const Prefix = "REST_GW"

var (
	// Version is the gateway version printed by the help and version
	// commands. It defaults to "dev".
	// NOTE(review): presumably overridden at build time — confirm against the
	// build scripts.
	Version = "dev"
)

// config assembles the gateway configuration.
//
// Values come from command line flags, environment variables prefixed with
// Prefix, an optional YAML file given by the config flag and built-in
// defaults. Peers passed with the peers flag are stored as defaults under
// cfgPeers + ".<index>" with weight and priority 1.
//
// When the help or version flag is set, the requested information is printed
// to stdout and the process exits with status 0. Panics if a flag cannot be
// bound or parsed, or if the config file cannot be opened or read.
func config() *viper.Viper {
	// One replacer turns config keys into environment variable names, both for
	// viper lookups and for the listing printed by the help command, so the
	// two cannot drift apart.
	envReplacer := strings.NewReplacer(".", "_", "-", "_")

	v := viper.New()
	v.AutomaticEnv()
	v.SetEnvPrefix(Prefix)
	v.AllowEmptyEnv(true)
	v.SetConfigType("yaml")
	v.SetEnvKeyReplacer(envReplacer)

	// flags setup:
	flagSet := pflag.NewFlagSet("commandline", pflag.ExitOnError)
	flagSet.SetOutput(os.Stdout)
	flagSet.SortFlags = false

	flagSet.Bool(cmdPprof, false, "enable pprof")
	flagSet.Bool(cmdMetrics, false, "enable prometheus")

	help := flagSet.BoolP(cmdHelp, "h", false, "show help")
	version := flagSet.BoolP(cmdVersion, "v", false, "show version")

	flagSet.StringP(cmdWallet, "w", "", `path to the wallet`)
	flagSet.String(cmdAddress, "", `address of wallet account`)

	configFlag := flagSet.String(cmdConfig, "", "config path")
	flagSet.Duration(cmdNodeDialTimeout, defaultConnectTimeout, "gRPC node connect timeout")
	flagSet.Duration(cmdHealthcheckTimeout, defaultHealthcheckTimeout, "gRPC healthcheck timeout")
	flagSet.Duration(cmdRebalance, defaultRebalanceTimer, "gRPC connection rebalance timer")

	peers := flagSet.StringArrayP(cmdPeers, "p", nil, "NeoFS nodes")

	// init server flags
	restapi.BindDefaultFlags(flagSet)

	// set defaults:
	// pool
	v.SetDefault(cfgPoolErrorThreshold, defaultPoolErrorThreshold)

	// metrics
	v.SetDefault(cfgPprofAddress, "localhost:8091")
	v.SetDefault(cfgPrometheusAddress, "localhost:8092")

	// logger:
	v.SetDefault(cfgLoggerLevel, "debug")

	// Bind flags
	for cfg, cmd := range bindings {
		if err := v.BindPFlag(cfg, flagSet.Lookup(cmd)); err != nil {
			panic(fmt.Errorf("bind flags: %w", err))
		}
	}

	if err := flagSet.Parse(os.Args); err != nil {
		panic(err)
	}

	switch {
	case *help:
		fmt.Printf("NeoFS REST Gateway %s\n", Version)
		flagSet.PrintDefaults()

		fmt.Println()
		fmt.Println("Default environments:")
		fmt.Println()
		cmdKeys := v.AllKeys()
		sort.Strings(cmdKeys)

		for _, key := range cmdKeys {
			if _, ok := ignore[key]; ok {
				continue
			}

			// Print the name exactly as viper resolves it: both "." and "-"
			// become "_".
			fmt.Printf("%s_%s = %v\n", Prefix, strings.ToUpper(envReplacer.Replace(key)), v.Get(key))
		}

		os.Exit(0)
	case *version:
		fmt.Printf("NeoFS REST Gateway\nVersion: %s\nGoVersion: %s\n", Version, runtime.Version())
		os.Exit(0)
	case *configFlag != "":
		cfgFile, err := os.Open(*configFlag)
		if err != nil {
			panic(err)
		}

		err = v.ReadConfig(cfgFile)
		// The file was opened read-only, so a Close error is not actionable;
		// closing here releases the descriptor before a possible panic.
		_ = cfgFile.Close()
		if err != nil {
			panic(err)
		}
	}

	// Peers given on the command line are registered as defaults, so values
	// from the config file or the environment for the same index win.
	for i, addr := range *peers {
		peerKey := cfgPeers + "." + strconv.Itoa(i) + "."
		v.SetDefault(peerKey+cfgPeerAddress, addr)
		v.SetDefault(peerKey+cfgPeerWeight, 1)
		v.SetDefault(peerKey+cfgPeerPriority, 1)
	}

	return v
}

// init registers every flag listed in serverFlags under the server config
// section: the resulting key is added to bindings (so config() binds it to
// the flag) and to knownConfigParams (so validateConfig accepts it).
func init() {
	for i := range serverFlags {
		key := cfgServerSection + serverFlags[i]
		knownConfigParams[key] = struct{}{}
		bindings[key] = serverFlags[i]
	}
}

// serverFlags lists the restapi flag names that init exposes as config
// parameters under the cfgServerSection prefix.
var serverFlags = []string{
	restapi.FlagScheme,
	restapi.FlagCleanupTimeout,
	restapi.FlagGracefulTimeout,
	restapi.FlagMaxHeaderSize,
	restapi.FlagListenAddress,
	restapi.FlagListenLimit,
	restapi.FlagKeepAlive,
	restapi.FlagReadTimeout,
	restapi.FlagWriteTimeout,
	restapi.FlagTLSListenAddress,
	restapi.FlagTLSCertificate,
	restapi.FlagTLSKey,
	restapi.FlagTLSCa,
	restapi.FlagTLSListenLimit,
	restapi.FlagTLSKeepAlive,
	restapi.FlagTLSReadTimeout,
	restapi.FlagTLSWriteTimeout,
}

// bindings maps a viper config key to the name of the command line flag that
// config() binds to it. init adds one entry per element of serverFlags.
var bindings = map[string]string{
	cfgPprofEnabled:       cmdPprof,
	cfgPrometheusEnabled:  cmdMetrics,
	cfgNodeDialTimeout:    cmdNodeDialTimeout,
	cfgHealthcheckTimeout: cmdHealthcheckTimeout,
	cfgRebalance:          cmdRebalance,
	cfgWalletPath:         cmdWallet,
	cfgWalletAddress:      cmdAddress,
}

// knownConfigParams is the set of config keys that validateConfig accepts
// without a warning. Keys under cfgPeers are not listed here because
// validateConfig checks them separately; init adds the server section keys.
var knownConfigParams = map[string]struct{}{
	cfgWalletAddress:      {},
	cfgWalletPath:         {},
	cfgWalletPassphrase:   {},
	cfgRebalance:          {},
	cfgHealthcheckTimeout: {},
	cfgNodeDialTimeout:    {},
	cfgPoolErrorThreshold: {},
	cfgLoggerLevel:        {},
	cfgPrometheusEnabled:  {},
	cfgPrometheusAddress:  {},
	cfgPprofEnabled:       {},
	cfgPprofAddress:       {},
}

// validateConfig logs a warning for every suspicious key found in cfg.
//
// Keys that start with cfgPeers are checked with isValidPeerKey; all other
// keys must be present in knownConfigParams. After that the collected peer
// indexes are sorted and every index that differs from its position is
// reported, because peers are expected to be numbered 0, 1, 2, ...
// Nothing is returned and cfg is not modified.
func validateConfig(cfg *viper.Viper, logger *zap.Logger) {
	warnUnknown := func(key string) {
		logger.Warn("unknown config parameter", zap.String("key", key))
	}

	seen := make(map[int]struct{})

	for _, key := range cfg.AllKeys() {
		if strings.HasPrefix(key, cfgPeers) {
			if idx, valid := isValidPeerKey(key); valid {
				seen[idx] = struct{}{}
			} else {
				warnUnknown(key)
			}
			continue
		}

		if _, known := knownConfigParams[key]; !known {
			warnUnknown(key)
		}
	}

	indexes := make([]int, 0, len(seen))
	for idx := range seen {
		indexes = append(indexes, idx)
	}
	sort.Ints(indexes)

	for pos, idx := range indexes {
		if pos == idx {
			continue
		}
		logger.Warn("invalid config parameter, peer indexes must be consecutive starting from 0", zap.String("key", cfgPeers+"."+strconv.Itoa(idx)))
	}
}

// isValidPeerKey reports whether key has the form
// cfgPeers + ".<index>.<param>", where <index> is a non-negative integer and
// <param> is one of cfgPeerAddress, cfgPeerPriority or cfgPeerWeight.
//
// On success it returns the parsed index and true; otherwise 0 and false.
func isValidPeerKey(key string) (int, bool) {
	trimmed := strings.TrimPrefix(key, cfgPeers)
	split := strings.Split(trimmed, ".")

	if len(split) != 3 {
		return 0, false
	}

	// The remainder must start with the "." separator. A non-empty first
	// segment means key only shares a textual prefix with cfgPeers
	// (e.g. "pool.peersx.0.address") and belongs to some other section.
	if split[0] != "" {
		return 0, false
	}

	if split[2] != cfgPeerAddress && split[2] != cfgPeerPriority && split[2] != cfgPeerWeight {
		return 0, false
	}

	num, err := strconv.Atoi(split[1])
	if err != nil || num < 0 {
		return 0, false
	}

	return num, true
}

// getNeoFSKey returns the private key the gateway should use.
//
// The wallet path is read from the cmdWallet key and, if that is empty, from
// cfgWalletPath. When no path is configured a fresh key is generated with
// keys.NewPrivateKey and an info message is logged. Otherwise the wallet file
// is loaded and the key is extracted with getKeyFromWallet, using the account
// address from cmdAddress (falling back to cfgWalletAddress) and the
// passphrase from cfgWalletPassphrase if that key is set; a nil passphrase is
// passed otherwise.
func getNeoFSKey(logger *zap.Logger, cfg *viper.Viper) (*keys.PrivateKey, error) {
	// lookup returns the value of primary, or of fallback when primary is empty.
	lookup := func(primary, fallback string) string {
		if val := cfg.GetString(primary); val != "" {
			return val
		}
		return cfg.GetString(fallback)
	}

	path := lookup(cmdWallet, cfgWalletPath)
	if path == "" {
		logger.Info("no wallet path specified, creating ephemeral key automatically for this run")
		return keys.NewPrivateKey()
	}

	w, err := wallet.NewWalletFromFile(path)
	if err != nil {
		return nil, err
	}

	var passphrase *string
	if cfg.IsSet(cfgWalletPassphrase) {
		value := cfg.GetString(cfgWalletPassphrase)
		passphrase = &value
	}

	return getKeyFromWallet(w, lookup(cmdAddress, cfgWalletAddress), passphrase)
}

// getKeyFromWallet decrypts a wallet account and returns its private key.
//
// addrStr selects the account; when it is empty the address returned by
// w.GetChangeAddress is used. password is the account passphrase; when it is
// nil the passphrase is requested through input.ReadPassword.
//
// An error is returned if addrStr cannot be parsed, the wallet has no account
// for the address, the password cannot be read or decryption fails. The
// underlying cause is wrapped where one exists, so callers can inspect it
// with errors.Is / errors.As.
func getKeyFromWallet(w *wallet.Wallet, addrStr string, password *string) (*keys.PrivateKey, error) {
	var addr util.Uint160

	if addrStr == "" {
		addr = w.GetChangeAddress()
	} else {
		var err error
		addr, err = flags.ParseAddress(addrStr)
		if err != nil {
			return nil, fmt.Errorf("invalid address: %w", err)
		}
	}

	acc := w.GetAccount(addr)
	if acc == nil {
		return nil, fmt.Errorf("couldn't find wallet account for %s", addrStr)
	}

	if password == nil {
		pwd, err := input.ReadPassword("Enter password > ")
		if err != nil {
			return nil, fmt.Errorf("couldn't read password: %w", err)
		}
		password = &pwd
	}

	if err := acc.Decrypt(*password, w.Scrypt); err != nil {
		return nil, fmt.Errorf("couldn't decrypt account: %w", err)
	}

	return acc.PrivateKey(), nil
}

// newLogger builds the application's zap.Logger. It panics if the configured
// level is not recognised or if the logger cannot be built.
//
// The logger starts from zap.NewProductionConfig and overrides:
//   - the level, taken from the cfgLoggerLevel config key
//   - the encoding, set to "console"
//   - the time encoder, set to ISO8601
//
// Stack traces are attached only to messages at fatal level or above.
//
// See also zapcore.Level, zap.NewProductionConfig, zap.AddStacktrace.
func newLogger(v *viper.Viper) *zap.Logger {
	levelName := v.GetString(cfgLoggerLevel)

	var level zapcore.Level
	if err := level.UnmarshalText([]byte(levelName)); err != nil {
		supported := [...]zapcore.Level{
			zapcore.DebugLevel,
			zapcore.InfoLevel,
			zapcore.WarnLevel,
			zapcore.ErrorLevel,
			zapcore.DPanicLevel,
			zapcore.PanicLevel,
			zapcore.FatalLevel,
		}
		panic(fmt.Sprintf("incorrect logger level configuration %s (%v), "+
			"value should be one of %v", levelName, err, supported))
	}

	logCfg := zap.NewProductionConfig()
	logCfg.Level = zap.NewAtomicLevelAt(level)
	logCfg.Encoding = "console"
	logCfg.EncoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder

	stacktraceFrom := zap.AddStacktrace(zap.NewAtomicLevelAt(zap.FatalLevel))

	logger, err := logCfg.Build(stacktraceFrom)
	if err != nil {
		panic(fmt.Sprintf("build zap logger instance: %v", err))
	}

	return logger
}

// serverConfig reads the server section of v into a restapi.ServerConfig.
// Every field is taken from the key cfgServerSection + <restapi flag name>.
//
// NOTE(review): serverFlags also registers restapi.FlagTLSCertificate,
// restapi.FlagTLSKey and restapi.FlagTLSCa, but none of them is read here —
// confirm whether restapi.ServerConfig should receive them.
func serverConfig(v *viper.Viper) *restapi.ServerConfig {
	// key returns the full config key of a server flag.
	key := func(flag string) string { return cfgServerSection + flag }

	sc := new(restapi.ServerConfig)

	// General settings.
	sc.EnabledListeners = v.GetStringSlice(key(restapi.FlagScheme))
	sc.CleanupTimeout = v.GetDuration(key(restapi.FlagCleanupTimeout))
	sc.GracefulTimeout = v.GetDuration(key(restapi.FlagGracefulTimeout))
	sc.MaxHeaderSize = v.GetInt(key(restapi.FlagMaxHeaderSize))

	// Plain listener.
	sc.ListenAddress = v.GetString(key(restapi.FlagListenAddress))
	sc.ListenLimit = v.GetInt(key(restapi.FlagListenLimit))
	sc.KeepAlive = v.GetDuration(key(restapi.FlagKeepAlive))
	sc.ReadTimeout = v.GetDuration(key(restapi.FlagReadTimeout))
	sc.WriteTimeout = v.GetDuration(key(restapi.FlagWriteTimeout))

	// TLS listener.
	sc.TLSListenAddress = v.GetString(key(restapi.FlagTLSListenAddress))
	sc.TLSListenLimit = v.GetInt(key(restapi.FlagTLSListenLimit))
	sc.TLSKeepAlive = v.GetDuration(key(restapi.FlagTLSKeepAlive))
	sc.TLSReadTimeout = v.GetDuration(key(restapi.FlagTLSReadTimeout))
	sc.TLSWriteTimeout = v.GetDuration(key(restapi.FlagTLSWriteTimeout))

	return sc
}

// newNeofsAPI wires together everything handlers.New needs and returns the
// resulting API.
//
// Steps, in order: obtain the gateway key (getNeoFSKey), configure a
// connection pool from the pool.* settings and the peers returned by
// fetchPeers, create and dial the pool, then create the pprof and prometheus
// services. Gate metrics are created only when prometheus is enabled.
// Any error from key loading, pool creation or dialing is returned as is.
func newNeofsAPI(ctx context.Context, logger *zap.Logger, v *viper.Viper) (*handlers.API, error) {
	key, err := getNeoFSKey(logger, v)
	if err != nil {
		return nil, err
	}

	var poolPrm pool.InitParameters
	poolPrm.SetKey(&key.PrivateKey)
	poolPrm.SetNodeDialTimeout(v.GetDuration(cfgNodeDialTimeout))
	poolPrm.SetHealthcheckTimeout(v.GetDuration(cfgHealthcheckTimeout))
	poolPrm.SetClientRebalanceInterval(v.GetDuration(cfgRebalance))
	poolPrm.SetErrorThreshold(v.GetUint32(cfgPoolErrorThreshold))

	peers := fetchPeers(logger, v)
	for i := range peers {
		poolPrm.AddNode(peers[i])
	}

	connPool, err := pool.NewPool(poolPrm)
	if err != nil {
		return nil, err
	}

	if dialErr := connPool.Dial(ctx); dialErr != nil {
		return nil, dialErr
	}

	var apiPrm handlers.PrmAPI
	apiPrm.Pool = connPool
	apiPrm.Key = key
	apiPrm.Logger = logger

	pprofCfg := metrics.Config{
		Enabled: v.GetBool(cfgPprofEnabled),
		Address: v.GetString(cfgPprofAddress),
	}
	apiPrm.PprofService = metrics.NewPprofService(logger, pprofCfg)

	promCfg := metrics.Config{
		Enabled: v.GetBool(cfgPrometheusEnabled),
		Address: v.GetString(cfgPrometheusAddress),
	}
	apiPrm.PrometheusService = metrics.NewPrometheusService(logger, promCfg)
	if promCfg.Enabled {
		apiPrm.GateMetric = metrics.NewGateMetrics(connPool)
	}

	apiPrm.ServiceShutdownTimeout = defaultShutdownTimeout

	return handlers.New(&apiPrm), nil
}

// fetchPeers collects the pool nodes configured under cfgPeers.
//
// Indexes are probed starting from 0 and the scan stops at the first index
// without an address, so peers after a gap are not picked up. A weight or
// priority that is missing or not positive is replaced with 1. Every accepted
// peer is logged at info level. The result is nil when no peer is configured.
func fetchPeers(l *zap.Logger, v *viper.Viper) []pool.NodeParam {
	var result []pool.NodeParam

	for idx := 0; ; idx++ {
		prefix := cfgPeers + "." + strconv.Itoa(idx) + "."

		addr := v.GetString(prefix + cfgPeerAddress)
		if addr == "" {
			return result
		}

		peerWeight := v.GetFloat64(prefix + cfgPeerWeight)
		if peerWeight <= 0 { // unspecified or wrong
			peerWeight = 1
		}

		peerPriority := v.GetInt(prefix + cfgPeerPriority)
		if peerPriority <= 0 { // unspecified or wrong
			peerPriority = 1
		}

		result = append(result, pool.NewNodeParam(peerPriority, addr, peerWeight))

		l.Info("added connection peer",
			zap.String("address", addr),
			zap.Int("priority", peerPriority),
			zap.Float64("weight", peerWeight),
		)
	}
}