package main

import (
	"context"
	"fmt"
	"net/http"
	"os"
	"os/signal"
	"runtime/debug"
	"strconv"
	"strings"
	"sync"
	"syscall"
	"time"

	v2container "git.frostfs.info/TrueCloudLab/frostfs-api-go/v2/container"
	"git.frostfs.info/TrueCloudLab/frostfs-http-gw/internal/cache"
	"git.frostfs.info/TrueCloudLab/frostfs-http-gw/internal/frostfs/services"
	"git.frostfs.info/TrueCloudLab/frostfs-http-gw/internal/handler"
	"git.frostfs.info/TrueCloudLab/frostfs-http-gw/internal/handler/middleware"
	"git.frostfs.info/TrueCloudLab/frostfs-http-gw/internal/logs"
	"git.frostfs.info/TrueCloudLab/frostfs-http-gw/metrics"
	"git.frostfs.info/TrueCloudLab/frostfs-http-gw/resolver"
	"git.frostfs.info/TrueCloudLab/frostfs-http-gw/response"
	"git.frostfs.info/TrueCloudLab/frostfs-http-gw/tokens"
	"git.frostfs.info/TrueCloudLab/frostfs-http-gw/tree"
	"git.frostfs.info/TrueCloudLab/frostfs-http-gw/utils"
	"git.frostfs.info/TrueCloudLab/frostfs-observability/tracing"
	"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/pool"
	treepool "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/pool/tree"
	"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/user"
	"github.com/fasthttp/router"
	"github.com/nspcc-dev/neo-go/cli/flags"
	"github.com/nspcc-dev/neo-go/cli/input"
	"github.com/nspcc-dev/neo-go/pkg/crypto/keys"
	"github.com/nspcc-dev/neo-go/pkg/util"
	"github.com/nspcc-dev/neo-go/pkg/wallet"
	"github.com/spf13/viper"
	"github.com/valyala/fasthttp"
	"go.uber.org/zap"
	"golang.org/x/exp/slices"
)

type (
	// app is the gateway application object returned by newApp; it implements App.
	//
	// It holds the root context (Serve begins shutdown once ctx is done), the
	// logger and its runtime-adjustable level, the FrostFS object and tree pools
	// plus the gateway key and the owner ID derived from it, the Viper config,
	// a single fasthttp server shared by every listener in servers, the webDone
	// channel that Serve closes after shutdown (Wait blocks on it), the container
	// resolver, the metrics wrapper, the auxiliary services started by
	// startServices (pprof, prometheus), and the reloadable settings.
	app struct {
		ctx       context.Context
		log       *zap.Logger
		logLevel  zap.AtomicLevel
		pool      *pool.Pool
		treePool  *treepool.Pool
		key       *keys.PrivateKey
		owner     *user.ID
		cfg       *viper.Viper
		webServer *fasthttp.Server
		webDone   chan struct{}
		resolver  *resolver.ContainerResolver
		metrics   *gateMetrics
		services  []*metrics.Service
		settings  *appSettings
		servers   []Server
	}

	// App is an interface for the main gateway function.
	// Serve runs the gateway until its context is cancelled; Wait blocks until
	// Serve has finished shutting down.
	App interface {
		Wait()
		Serve()
	}

	// Option is an application option applied by newApp before any
	// configuration is read.
	Option func(a *app)

	// gateMetrics wraps the metrics provider so that it can be switched on and
	// off at runtime: mu guards enabled, and every reporting method is a no-op
	// while enabled is false.
	gateMetrics struct {
		logger   *zap.Logger
		provider *metrics.GateMetrics
		mu       sync.RWMutex
		enabled  bool
	}

	// appSettings stores parameters that may change on config reload, so every
	// field is read and written only through getters/setters guarded by mu.
	appSettings struct {
		mu                  sync.RWMutex
		defaultTimestamp    bool
		zipCompression      bool
		clientCut           bool
		bufferMaxSizeForPut uint64
		namespaceHeader     string
		defaultNamespaces   []string
	}
)

// WithLogger returns an Option that installs l as the application logger and
// lvl as its runtime-adjustable level. A nil logger leaves both untouched.
func WithLogger(l *zap.Logger, lvl zap.AtomicLevel) Option {
	return func(a *app) {
		if l != nil {
			a.log, a.logLevel = l, lvl
		}
	}
}

// WithConfig returns an Option that makes the application read its settings
// from c instead of the global Viper instance. A nil c is ignored.
func WithConfig(c *viper.Viper) Option {
	return func(a *app) {
		if c != nil {
			a.cfg = c
		}
	}
}

// newApp builds the gateway application.
//
// It starts from the global zap logger and global Viper config, applies the
// given options, configures the fasthttp server from the config, obtains the
// FrostFS pools and gateway key via getPools, derives the owner ID from that
// key, and then initializes runtime parameters, reloadable settings, the
// container resolver, metrics and tracing — in that order, since the resolver
// reads a.settings and a.pool.
func newApp(ctx context.Context, opt ...Option) App {
	a := &app{
		ctx:       ctx,
		log:       zap.L(),
		cfg:       viper.GetViper(),
		webServer: new(fasthttp.Server),
		webDone:   make(chan struct{}),
	}
	for i := range opt {
		opt[i](a)
	}

	// -- setup FastHTTP server --
	a.webServer.Name = "frost-http-gw"
	a.webServer.ReadBufferSize = a.cfg.GetInt(cfgWebReadBufferSize)
	a.webServer.WriteBufferSize = a.cfg.GetInt(cfgWebWriteBufferSize)
	a.webServer.ReadTimeout = a.cfg.GetDuration(cfgWebReadTimeout)
	a.webServer.WriteTimeout = a.cfg.GetDuration(cfgWebWriteTimeout)
	a.webServer.DisableHeaderNamesNormalizing = true
	a.webServer.NoDefaultServerHeader = true
	a.webServer.NoDefaultContentType = true
	a.webServer.MaxRequestBodySize = a.cfg.GetInt(cfgWebMaxRequestBodySize)
	a.webServer.DisablePreParseMultipartForm = true
	a.webServer.StreamRequestBody = a.cfg.GetBool(cfgWebStreamRequestBody)
	// -- -- -- -- -- -- -- -- -- -- -- -- -- --
	a.pool, a.treePool, a.key = getPools(ctx, a.log, a.cfg)

	// The owner ID is derived from the gateway key; a.key must be non-nil here.
	var owner user.ID
	user.IDFromKey(&owner, a.key.PrivateKey.PublicKey)
	a.owner = &owner

	a.setRuntimeParameters()

	// Settings must exist before the resolver, whose config references them.
	a.initAppSettings()
	a.initResolver()
	a.initMetrics()
	// NOTE(review): servers are created later in Serve, so tracing gets an
	// empty instance ID here until the next initTracing on config reload.
	a.initTracing(ctx)

	return a
}

// DefaultTimestamp reports the current default-timestamp flag, as last loaded
// from cfgUploaderHeaderEnableDefaultTimestamp.
func (s *appSettings) DefaultTimestamp() bool {
	s.mu.RLock()
	enabled := s.defaultTimestamp
	s.mu.RUnlock()
	return enabled
}

// setDefaultTimestamp updates the default-timestamp flag under the write lock.
func (s *appSettings) setDefaultTimestamp(val bool) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.defaultTimestamp = val
}

// ZipCompression reports the current zip-compression flag, as last loaded
// from cfgZipCompression.
func (s *appSettings) ZipCompression() bool {
	s.mu.RLock()
	enabled := s.zipCompression
	s.mu.RUnlock()
	return enabled
}

// setZipCompression updates the zip-compression flag under the write lock.
func (s *appSettings) setZipCompression(val bool) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.zipCompression = val
}

// ClientCut reports the current client-cut flag, as last loaded from
// cfgClientCut.
func (s *appSettings) ClientCut() bool {
	s.mu.RLock()
	enabled := s.clientCut
	s.mu.RUnlock()
	return enabled
}

// setClientCut updates the client-cut flag under the write lock.
func (s *appSettings) setClientCut(val bool) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.clientCut = val
}

// BufferMaxSizeForPut returns the current put-buffer size limit, as last
// loaded from cfgBufferMaxSizeForPut.
func (s *appSettings) BufferMaxSizeForPut() uint64 {
	s.mu.RLock()
	size := s.bufferMaxSizeForPut
	s.mu.RUnlock()
	return size
}

// setBufferMaxSizeForPut updates the put-buffer size limit under the write lock.
func (s *appSettings) setBufferMaxSizeForPut(val uint64) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.bufferMaxSizeForPut = val
}

// initAppSettings allocates the reloadable settings holder and populates it
// from the current configuration.
func (a *app) initAppSettings() {
	settings := new(appSettings)
	a.settings = settings
	a.updateSettings()
}

// initResolver creates the container resolver from the configured resolve
// order and logs a fatal error if that fails.
func (a *app) initResolver() {
	order, resolverCfg := a.getResolverConfig()
	res, err := resolver.NewContainerResolver(order, resolverCfg)
	a.resolver = res
	if err != nil {
		a.log.Fatal(logs.FailedToCreateResolver, zap.Error(err))
	}
}

// getResolverConfig returns the resolver order from cfgResolveOrder and the
// resolver configuration. The NNS resolver is dropped from the order when no
// RPC endpoint is configured; an empty resulting order is logged because it
// disables container resolution.
func (a *app) getResolverConfig() ([]string, *resolver.Config) {
	cfg := &resolver.Config{
		FrostFS:    resolver.NewFrostFSResolver(a.pool),
		RPCAddress: a.cfg.GetString(cfgRPCEndpoint),
		Settings:   a.settings,
	}

	resolveOrder := a.cfg.GetStringSlice(cfgResolveOrder)

	noRPC := cfg.RPCAddress == ""
	if noRPC {
		resolveOrder = remove(resolveOrder, resolver.NNSResolver)
		a.log.Warn(logs.ResolverNNSWontBeUsedSinceRPCEndpointIsntProvided)
	}

	if len(resolveOrder) == 0 {
		a.log.Info(logs.ContainerResolverWillBeDisabledBecauseOfResolversResolverOrderIsEmpty)
	}

	return resolveOrder, cfg
}

// initMetrics creates the gateway metrics provider and its switchable wrapper,
// then reports the "starting" health status.
func (a *app) initMetrics() {
	provider := metrics.NewGateMetrics(a.pool)
	enabled := a.cfg.GetBool(cfgPrometheusEnabled)

	a.metrics = newGateMetrics(a.log, provider, enabled)
	a.metrics.SetHealth(metrics.HealthStatusStarting)
}

// newGateMetrics wraps provider in a gateMetrics with the given initial
// enabled state, warning when metrics start out disabled.
func newGateMetrics(logger *zap.Logger, provider *metrics.GateMetrics, enabled bool) *gateMetrics {
	m := &gateMetrics{
		logger:   logger,
		provider: provider,
		enabled:  enabled,
	}
	if !enabled {
		logger.Warn(logs.MetricsAreDisabled)
	}
	return m
}

// isEnabled reports, under the read lock, whether metrics reporting is on.
func (m *gateMetrics) isEnabled() bool {
	m.mu.RLock()
	on := m.enabled
	m.mu.RUnlock()
	return on
}

// SetEnabled switches metrics reporting on or off, warning when it is
// switched off.
func (m *gateMetrics) SetEnabled(enabled bool) {
	if !enabled {
		m.logger.Warn(logs.MetricsAreDisabled)
	}

	m.mu.Lock()
	defer m.mu.Unlock()
	m.enabled = enabled
}

// SetHealth forwards the health status to the provider while metrics are enabled.
func (m *gateMetrics) SetHealth(status metrics.HealthStatus) {
	if m.isEnabled() {
		m.provider.SetHealth(status)
	}
}

// SetVersion forwards the application version to the provider while metrics
// are enabled.
func (m *gateMetrics) SetVersion(ver string) {
	if m.isEnabled() {
		m.provider.SetVersion(ver)
	}
}

// Shutdown reports the "shutting down" status (if metrics are enabled),
// permanently disables reporting, and unregisters the provider. The whole
// sequence runs under the write lock.
func (m *gateMetrics) Shutdown() {
	m.mu.Lock()
	defer m.mu.Unlock()

	if m.enabled {
		m.provider.SetHealth(metrics.HealthStatusShuttingDown)
		m.enabled = false
	}
	m.provider.Unregister()
}

// MarkHealthy records endpoint as healthy while metrics are enabled.
func (m *gateMetrics) MarkHealthy(endpoint string) {
	if m.isEnabled() {
		m.provider.MarkHealthy(endpoint)
	}
}

// MarkUnhealthy records endpoint as unhealthy while metrics are enabled.
func (m *gateMetrics) MarkUnhealthy(endpoint string) {
	if m.isEnabled() {
		m.provider.MarkUnhealthy(endpoint)
	}
}

// remove returns a new slice holding every item of list except those equal to
// element (all occurrences are dropped).
//
// The input slice is never modified. Callers pass slices obtained from the
// configuration (e.g. viper's GetStringSlice, which may hand back the stored
// slice itself), so filtering in place would corrupt the stored value for
// later reads such as a config reload.
func remove(list []string, element string) []string {
	result := make([]string, 0, len(list))
	for _, item := range list {
		if item != element {
			result = append(result, item)
		}
	}
	return result
}

// getFrostFSKey returns the gateway private key.
//
// When no wallet path is configured, a fresh ephemeral key is generated for
// this run. Otherwise the wallet file is opened and the key of the account
// selected by cfgWalletAddress is decrypted; the passphrase comes from
// cfgWalletPassphrase when it is set, or is prompted for interactively by
// getKeyFromWallet.
func getFrostFSKey(cfg *viper.Viper, log *zap.Logger) (*keys.PrivateKey, error) {
	path := cfg.GetString(cfgWalletPath)
	if path == "" {
		log.Info(logs.NoWalletPathSpecifiedCreatingEphemeralKeyAutomaticallyForThisRun)
		ephemeral, err := keys.NewPrivateKey()
		if err != nil {
			return nil, err
		}
		return ephemeral, nil
	}

	w, err := wallet.NewWalletFromFile(path)
	if err != nil {
		return nil, err
	}

	var password *string
	if cfg.IsSet(cfgWalletPassphrase) {
		password = new(string)
		*password = cfg.GetString(cfgWalletPassphrase)
	}

	return getKeyFromWallet(w, cfg.GetString(cfgWalletAddress), password)
}

// getKeyFromWallet decrypts and returns the private key of one wallet account.
//
// addrStr selects the account by address; when empty, the wallet's change
// address (w.GetChangeAddress) is used. password is the account passphrase;
// when nil, the user is prompted for it via input.ReadPassword.
//
// An error is returned when addrStr cannot be parsed, no matching account
// exists, the password prompt fails, or decryption fails. Underlying errors
// are wrapped with %w so callers can inspect them.
func getKeyFromWallet(w *wallet.Wallet, addrStr string, password *string) (*keys.PrivateKey, error) {
	var addr util.Uint160
	var err error

	if addrStr == "" {
		addr = w.GetChangeAddress()
	} else {
		addr, err = flags.ParseAddress(addrStr)
		if err != nil {
			return nil, fmt.Errorf("invalid address: %w", err)
		}
	}

	acc := w.GetAccount(addr)
	if acc == nil {
		if addrStr == "" {
			// Without this branch the message would end in an empty address.
			return nil, fmt.Errorf("couldn't find wallet account for default change address")
		}
		return nil, fmt.Errorf("couldn't find wallet account for %s", addrStr)
	}

	if password == nil {
		pwd, err := input.ReadPassword("Enter password > ")
		if err != nil {
			return nil, fmt.Errorf("couldn't read password: %w", err)
		}
		password = &pwd
	}

	if err := acc.Decrypt(*password, w.Scrypt); err != nil {
		return nil, fmt.Errorf("couldn't decrypt account: %w", err)
	}

	return acc.PrivateKey(), nil
}

// Wait logs the application start, publishes the version and the "ready"
// health status, and then blocks until Serve closes webDone at the end of
// its shutdown sequence.
func (a *app) Wait() {
	a.log.Info(logs.StartingApplication,
		zap.String("app_name", "frostfs-http-gw"),
		zap.String("version", Version),
	)

	a.metrics.SetVersion(Version)
	a.setHealthStatus()

	<-a.webDone
}

// setHealthStatus reports the gateway as ready through the metrics wrapper.
func (a *app) setHealthStatus() {
	status := metrics.HealthStatusReady
	a.metrics.SetHealth(status)
}

// Serve builds the request handler and router, starts the auxiliary services
// and one fasthttp serving goroutine per configured listener, and then blocks.
// Each SIGHUP triggers configReload. Cancellation of the application context
// starts the shutdown sequence: web server, metrics, auxiliary services,
// tracing. webDone is closed last so that Wait can return.
//
// A listener that fails with an error other than http.ErrServerClosed marks
// its address unhealthy and is logged at Fatal level.
func (a *app) Serve() {
	// Named h so the local variable does not shadow the imported handler package.
	h := handler.New(a.AppParams(), a.settings, tree.NewTree(services.NewPoolWrapper(a.treePool)))

	// Configure router.
	a.configureRouter(h)

	a.startServices()
	a.initServers(a.ctx)

	for i := range a.servers {
		go func(i int) {
			a.log.Info(logs.StartingServer, zap.String("address", a.servers[i].Address()))
			if err := a.webServer.Serve(a.servers[i].Listener()); err != nil && err != http.ErrServerClosed {
				a.metrics.MarkUnhealthy(a.servers[i].Address())
				a.log.Fatal(logs.ListenAndServe, zap.Error(err))
			}
		}(i)
	}

	sigs := make(chan os.Signal, 1)
	signal.Notify(sigs, syscall.SIGHUP)
	// Release the registration when Serve returns; otherwise the signal
	// package keeps relaying SIGHUP to a channel nobody reads any more.
	defer signal.Stop(sigs)

LOOP:
	for {
		select {
		case <-a.ctx.Done():
			break LOOP
		case <-sigs:
			a.configReload(a.ctx)
		}
	}

	a.log.Info(logs.ShuttingDownWebServer, zap.Error(a.webServer.Shutdown()))

	a.metrics.Shutdown()
	a.stopServices()
	a.shutdownTracing()

	close(a.webDone)
}

// shutdownTracing flushes and stops the tracing subsystem, allowing it at most
// five seconds. A failure is logged as a warning and otherwise ignored.
func (a *app) shutdownTracing() {
	const tracingShutdownTimeout = 5 * time.Second

	ctx, cancel := context.WithTimeout(context.Background(), tracingShutdownTimeout)
	defer cancel()

	err := tracing.Shutdown(ctx)
	if err == nil {
		return
	}
	a.log.Warn(logs.FailedToShutdownTracing, zap.Error(err))
}

// configReload re-reads the configuration and applies every reloadable
// parameter in place. Serve calls it on SIGHUP.
//
// If neither a config file nor a config directory was given, or re-reading
// fails, the reload is abandoned with a warning and nothing changes. After a
// successful read, failures of individual steps (log level, resolvers, server
// TLS certificates, tracing) are logged and the remaining steps still run.
func (a *app) configReload(ctx context.Context) {
	a.log.Info(logs.SIGHUPConfigReloadStarted)
	if !a.cfg.IsSet(cmdConfig) && !a.cfg.IsSet(cmdConfigDir) {
		a.log.Warn(logs.FailedToReloadConfigBecauseItsMissed)
		return
	}
	if err := readInConfig(a.cfg); err != nil {
		a.log.Warn(logs.FailedToReloadConfig, zap.Error(err))
		return
	}

	// The logger level is atomic, so it can be changed without rebuilding the logger.
	if lvl, err := getLogLevel(a.cfg); err != nil {
		a.log.Warn(logs.LogLevelWontBeUpdated, zap.Error(err))
	} else {
		a.logLevel.SetLevel(lvl)
	}

	if err := a.resolver.UpdateResolvers(a.getResolverConfig()); err != nil {
		a.log.Warn(logs.FailedToUpdateResolvers, zap.Error(err))
	}

	// Only TLS certificates of already-running listeners are refreshed;
	// listeners are not added or removed on reload.
	if err := a.updateServers(); err != nil {
		a.log.Warn(logs.FailedToReloadServerParameters, zap.Error(err))
	}

	a.setRuntimeParameters()

	// Restart pprof/prometheus so they pick up new enabled flags and addresses.
	a.stopServices()
	a.startServices()

	a.updateSettings()

	a.metrics.SetEnabled(a.cfg.GetBool(cfgPrometheusEnabled))
	a.initTracing(ctx)
	// Re-publish "ready" in case metrics were just re-enabled.
	a.setHealthStatus()

	a.log.Info(logs.SIGHUPConfigReloadCompleted)
}

// updateSettings copies every reloadable parameter from the current config
// into the shared settings holder.
func (a *app) updateSettings() {
	cfg, s := a.cfg, a.settings

	s.setDefaultTimestamp(cfg.GetBool(cfgUploaderHeaderEnableDefaultTimestamp))
	s.setZipCompression(cfg.GetBool(cfgZipCompression))
	s.setClientCut(cfg.GetBool(cfgClientCut))
	s.setBufferMaxSizeForPut(cfg.GetUint64(cfgBufferMaxSizeForPut))
	s.setNamespaceHeader(cfg.GetString(cfgResolveNamespaceHeader))
	s.setDefaultNamespaces(cfg.GetStringSlice(cfgResolveDefaultNamespaces))
}

// startServices creates the pprof and prometheus services from the current
// config, records them in a.services and starts each in its own goroutine.
func (a *app) startServices() {
	pprofSvc := metrics.NewPprofService(a.log, metrics.Config{
		Enabled: a.cfg.GetBool(cfgPprofEnabled),
		Address: a.cfg.GetString(cfgPprofAddress),
	})
	a.services = append(a.services, pprofSvc)
	go pprofSvc.Start()

	promSvc := metrics.NewPrometheusService(a.log, metrics.Config{
		Enabled: a.cfg.GetBool(cfgPrometheusEnabled),
		Address: a.cfg.GetString(cfgPrometheusAddress),
	})
	a.services = append(a.services, promSvc)
	go promSvc.Start()
}

// stopServices shuts down every auxiliary service started by startServices,
// using one shared deadline of defaultShutdownTimeout, and then forgets them.
//
// Clearing a.services matters because config reload calls stopServices
// followed by startServices, which appends. Without the reset, every reload
// would keep already-stopped services in the slice, shut them down again on
// the next call, and retain them for the life of the process.
func (a *app) stopServices() {
	ctx, cancel := context.WithTimeout(context.Background(), defaultShutdownTimeout)
	defer cancel()

	for _, svc := range a.services {
		svc.ShutDown(ctx)
	}
	a.services = nil
}

// configureRouter registers the gateway routes (upload, get by address or
// bucket name, get by attribute, zip download) and installs the router as the
// web server's handler. Every route goes through the same middleware chain;
// unknown paths and unsupported methods get plain error responses.
func (a *app) configureRouter(h *handler.Handler) {
	r := router.New()
	r.RedirectTrailingSlash = true
	r.NotFound = func(ctx *fasthttp.RequestCtx) {
		response.Error(ctx, "Not found", fasthttp.StatusNotFound)
	}
	r.MethodNotAllowed = func(ctx *fasthttp.RequestCtx) {
		response.Error(ctx, "Method Not Allowed", fasthttp.StatusMethodNotAllowed)
	}

	// wrap applies the common middleware; the outermost runs first:
	// logger -> tokenizer -> tracer -> reqNamespace -> endpoint.
	wrap := func(endpoint fasthttp.RequestHandler) fasthttp.RequestHandler {
		return a.logger(a.tokenizer(a.tracer(a.reqNamespace(endpoint))))
	}

	r.POST("/upload/{cid}", wrap(h.Upload))
	a.log.Info(logs.AddedPathUploadCid)

	r.GET("/get/{cid}/{oid:*}", wrap(h.DownloadByAddressOrBucketName))
	r.HEAD("/get/{cid}/{oid:*}", wrap(h.HeadByAddressOrBucketName))
	a.log.Info(logs.AddedPathGetCidOid)

	r.GET("/get_by_attribute/{cid}/{attr_key}/{attr_val:*}", wrap(h.DownloadByAttribute))
	r.HEAD("/get_by_attribute/{cid}/{attr_key}/{attr_val:*}", wrap(h.HeadByAttribute))
	a.log.Info(logs.AddedPathGetByAttributeCidAttrKeyAttrVal)

	r.GET("/zip/{cid}/{prefix:*}", wrap(h.DownloadZipped))
	a.log.Info(logs.AddedPathZipCidPrefix)

	a.webServer.Handler = r.Handler
}

// logger is middleware that logs every incoming request (remote address,
// method, path, query string and fasthttp request ID) before calling h.
func (a *app) logger(h fasthttp.RequestHandler) fasthttp.RequestHandler {
	return func(req *fasthttp.RequestCtx) {
		fields := []zap.Field{
			zap.String("remote", req.RemoteAddr().String()),
			zap.ByteString("method", req.Method()),
			zap.ByteString("path", req.Path()),
			zap.ByteString("query", req.QueryArgs().QueryString()),
			zap.Uint64("id", req.ID()),
		}
		a.log.Info(logs.Request, fields...)
		h(req)
	}
}

// tokenizer is middleware that extracts the bearer token from the request
// into a context derived from the application context and attaches that
// context to the request before calling h.
//
// If the token cannot be fetched, a 400 response is written and the request
// is terminated: h is not called, so the error response is not overwritten
// and downstream middleware never sees an unusable context.
func (a *app) tokenizer(h fasthttp.RequestHandler) fasthttp.RequestHandler {
	return func(req *fasthttp.RequestCtx) {
		appCtx, err := tokens.StoreBearerTokenAppCtx(a.ctx, req)
		if err != nil {
			a.log.Error(logs.CouldNotFetchAndStoreBearerToken, zap.Error(err))
			response.Error(req, "could not fetch and store bearer token: "+err.Error(), fasthttp.StatusBadRequest)
			return
		}
		utils.SetContextToRequest(appCtx, req)
		h(req)
	}
}

// tracer is middleware that wraps the request in an HTTP server span, tags the
// context with the fasthttp request ID for the tree pool, and records the
// span's HTTP trace info once h has returned.
func (a *app) tracer(h fasthttp.RequestHandler) fasthttp.RequestHandler {
	return func(req *fasthttp.RequestCtx) {
		ctx, span := utils.StartHTTPServerSpan(utils.GetContextFromRequest(req), req, "REQUEST")
		// The closure reads ctx when it runs, i.e. after the request ID has
		// been attached below.
		defer func() {
			utils.SetHTTPTraceInfo(ctx, span, req)
			span.End()
		}()

		requestID := strconv.FormatUint(req.ID(), 10)
		ctx = treepool.SetRequestID(ctx, requestID)

		utils.SetContextToRequest(ctx, req)
		h(req)
	}
}

// reqNamespace is middleware that reads the namespace from the request header
// named by the current NamespaceHeader setting (empty if absent) and stores it
// in the request context before calling h.
func (a *app) reqNamespace(h fasthttp.RequestHandler) fasthttp.RequestHandler {
	return func(req *fasthttp.RequestCtx) {
		ctx := utils.GetContextFromRequest(req)
		namespace := string(req.Request.Header.Peek(a.settings.NamespaceHeader()))
		utils.SetContextToRequest(middleware.SetNamespace(ctx, namespace), req)
		h(req)
	}
}

// AppParams bundles the dependencies the request handler needs. Each call
// creates a new bucket cache configured from the current config.
func (a *app) AppParams() *utils.AppParams {
	bucketCache := cache.NewBucketCache(getCacheOptions(a.cfg, a.log))

	params := utils.AppParams{
		Logger:   a.log,
		Pool:     a.pool,
		Owner:    a.owner,
		Resolver: a.resolver,
		Cache:    bucketCache,
	}
	return &params
}

// initServers creates a listener for every configured server entry. Entries
// that fail are marked unhealthy and skipped with a warning; successful ones
// are marked healthy. If no server could be created, a fatal error is logged.
func (a *app) initServers(ctx context.Context) {
	infos := fetchServers(a.cfg)
	a.servers = make([]Server, 0, len(infos))

	for _, info := range infos {
		fields := []zap.Field{
			zap.String("address", info.Address),
			zap.Bool("tls enabled", info.TLS.Enabled),
			zap.String("tls cert", info.TLS.CertFile),
			zap.String("tls key", info.TLS.KeyFile),
		}

		srv, err := newServer(ctx, info)
		if err != nil {
			a.metrics.MarkUnhealthy(info.Address)
			a.log.Warn(logs.FailedToAddServer, append(fields, zap.Error(err))...)
			continue
		}

		a.metrics.MarkHealthy(info.Address)
		a.servers = append(a.servers, srv)
		a.log.Info(logs.AddServer, fields...)
	}

	if len(a.servers) == 0 {
		a.log.Fatal(logs.NoHealthyServers)
	}
}

// updateServers reloads TLS certificates for configured servers that are
// already running (matched by address). Entries with unknown addresses are
// ignored. It fails on the first certificate update error, or when none of
// the configured addresses matches a running server.
func (a *app) updateServers() error {
	found := false

	for _, info := range fetchServers(a.cfg) {
		idx := a.serverIndex(info.Address)
		if idx == -1 {
			continue
		}

		if info.TLS.Enabled {
			err := a.servers[idx].UpdateCert(info.TLS.CertFile, info.TLS.KeyFile)
			if err != nil {
				return fmt.Errorf("failed to update tls certs: %w", err)
			}
		}
		found = true
	}

	if found {
		return nil
	}
	return fmt.Errorf("invalid servers configuration: no known server found")
}

// serverIndex returns the position in a.servers of the server listening on
// address, or -1 if there is none.
func (a *app) serverIndex(address string) int {
	for idx, srv := range a.servers {
		if srv.Address() == address {
			return idx
		}
	}
	return -1
}

// initTracing (re)configures tracing from the current config. The instance ID
// is the first server's address, or empty when no servers exist yet. Setup
// errors are logged as warnings, and an applied change is logged as info.
func (a *app) initTracing(ctx context.Context) {
	var instanceID string
	if len(a.servers) != 0 {
		instanceID = a.servers[0].Address()
	}

	updated, err := tracing.Setup(ctx, tracing.Config{
		Enabled:    a.cfg.GetBool(cfgTracingEnabled),
		Exporter:   tracing.Exporter(a.cfg.GetString(cfgTracingExporter)),
		Endpoint:   a.cfg.GetString(cfgTracingEndpoint),
		Service:    "frostfs-http-gw",
		InstanceID: instanceID,
		Version:    Version,
	})
	if err != nil {
		a.log.Warn(logs.FailedToInitializeTracing, zap.Error(err))
	}
	if updated {
		a.log.Info(logs.TracingConfigUpdated)
	}
}

// setRuntimeParameters applies the configured Go soft memory limit, unless the
// GOMEMLIMIT environment variable is set, in which case it takes precedence
// and a warning is logged. A change of the limit is logged with both values.
func (a *app) setRuntimeParameters() {
	if os.Getenv("GOMEMLIMIT") != "" {
		// default limit < yaml limit < app env limit < GOMEMLIMIT
		a.log.Warn(logs.RuntimeSoftMemoryDefinedWithGOMEMLIMIT)
		return
	}

	limit := fetchSoftMemoryLimit(a.cfg)
	if prev := debug.SetMemoryLimit(limit); prev != limit {
		a.log.Info(logs.RuntimeSoftMemoryLimitUpdated,
			zap.Int64("new_value", limit),
			zap.Int64("old_value", prev))
	}
}

// NamespaceHeader returns the name of the request header that carries the
// namespace, as last loaded from cfgResolveNamespaceHeader.
func (s *appSettings) NamespaceHeader() string {
	s.mu.RLock()
	header := s.namespaceHeader
	s.mu.RUnlock()
	return header
}

// setNamespaceHeader updates the namespace header name under the write lock.
func (s *appSettings) setNamespaceHeader(nsHeader string) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.namespaceHeader = nsHeader
}

// FormContainerZone maps a namespace to its container zone. A namespace listed
// among the default namespaces maps to the system default zone (isDefault is
// true); any other namespace ns maps to "<ns>.ns".
func (s *appSettings) FormContainerZone(ns string) (zone string, isDefault bool) {
	s.mu.RLock()
	defaults := s.defaultNamespaces
	s.mu.RUnlock()

	if !slices.Contains(defaults, ns) {
		return ns + ".ns", false
	}
	return v2container.SysAttributeZoneDefault, true
}

// setDefaultNamespaces stores the list of namespaces treated as the default
// zone by FormContainerZone. Surrounding double quotes are stripped from each
// entry, which allows setting them via an environment variable such as
// `HTTP_GW_RESOLVE_BUCKET_DEFAULT_NAMESPACES="" "root"`.
//
// The trimmed values go into a fresh slice instead of overwriting the
// argument in place. The caller's slice may be backed by configuration
// storage, and FormContainerZone iterates the stored slice outside the lock,
// so the settings must not alias memory the caller can still modify.
func (s *appSettings) setDefaultNamespaces(namespaces []string) {
	trimmed := make([]string, len(namespaces))
	for i, ns := range namespaces {
		trimmed[i] = strings.Trim(ns, "\"")
	}

	s.mu.Lock()
	s.defaultNamespaces = trimmed
	s.mu.Unlock()
}