package innerring

import (
	"context"
	"errors"
	"fmt"
	"io"

	"git.frostfs.info/TrueCloudLab/frostfs-node/internal/logs"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/innerring/config"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/innerring/processors/governance"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/innerring/processors/netmap"
	timerEvent "git.frostfs.info/TrueCloudLab/frostfs-node/pkg/innerring/timers"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/metrics"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/morph/client"
	balanceClient "git.frostfs.info/TrueCloudLab/frostfs-node/pkg/morph/client/balance"
	nmClient "git.frostfs.info/TrueCloudLab/frostfs-node/pkg/morph/client/netmap"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/morph/event"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/morph/subscriber"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/morph/timer"
	control "git.frostfs.info/TrueCloudLab/frostfs-node/pkg/services/control/ir"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/util/logger"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/util/precision"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/util/state"
	"github.com/nspcc-dev/neo-go/pkg/core/block"
	"github.com/nspcc-dev/neo-go/pkg/core/transaction"
	"github.com/nspcc-dev/neo-go/pkg/crypto/keys"
	"github.com/nspcc-dev/neo-go/pkg/encoding/address"
	"github.com/nspcc-dev/neo-go/pkg/util"
	"github.com/spf13/viper"
	"go.uber.org/atomic"
	"go.uber.org/zap"
)

type (
	// Server is the inner ring application structure that contains all event
	// processors, shared variables and event handlers.
	Server struct {
		log *logger.Logger

		// event producers
		//
		// morphListener and mainnetListener deliver side chain and main chain
		// events respectively; Start launches both and forwards their errors.
		// NOTE(review): blockTimers/epochTimer are presumably ticked from the
		// side chain block handler via tickTimers — confirm in timers code.
		morphListener   event.Listener
		mainnetListener event.Listener
		blockTimers     []*timer.BlockTimer
		epochTimer      *timer.BlockTimer

		// global state
		//
		// epochCounter and epochDuration are read via netmapClient when the
		// server starts; epochDuration is measured in side chain blocks.
		// precision holds the precision read via balanceClient. persistate
		// records the index of the last processed block of each chain.
		morphClient   *client.Client
		mainnetClient *client.Client
		epochCounter  atomic.Uint64
		epochDuration atomic.Uint64
		statusIndex   *innerRingIndexer
		precision     precision.Fixed8Converter
		healthStatus  atomic.Value
		balanceClient *balanceClient.Client
		netmapClient  *nmClient.Client
		persistate    *state.PersistentStorage

		// metrics
		metrics *metrics.InnerRingServiceMetrics

		// notary configuration
		//
		// A notary config marked as disabled makes Start skip the notary
		// deposit initialization for the corresponding chain.
		feeConfig        *config.FeeConfig
		mainNotaryConfig *notaryConfig
		sideNotaryConfig *notaryConfig

		// internal variables
		//
		// pubKey is the serialized public key of key. predefinedValidators
		// are parsed from the "morph.validators" configuration key and voted
		// for on Start. initialEpochTickDelta is the number of side chain
		// blocks left until the first epoch tick. withoutMainNet skips the
		// registration of the main chain block handler.
		key                   *keys.PrivateKey
		pubKey                []byte
		contracts             *contracts
		predefinedValidators  keys.PublicKeys
		initialEpochTickDelta uint32
		withoutMainNet        bool

		// runtime processors
		netmapProcessor *netmap.Processor

		// Background routines; each one is started in its own goroutine
		// with the context passed to Start.
		workers []func(context.Context)

		// Set of local resources that must be
		// initialized at the very beginning of
		// Server's work (e.g. opening files).
		//
		// If any starter returns an error, Server's
		// starting fails immediately.
		starters []func() error

		// Set of local resources that must be
		// released at Server's work completion
		// (e.g. closing files).
		//
		// Closer's wrong outcome shouldn't be critical.
		//
		// Errors are logged.
		closers []func() error

		// Set of component runners which
		// should report start errors
		// to the application. Each runner receives
		// the error channel passed to Start.
		runners []func(chan<- error) error
	}

	// chainParams groups the settings needed to create a client and an event
	// listener for a single chain. name is used as the configuration section
	// prefix of that chain and is attached to the listener's log entries;
	// from is the block the event subscriber starts from.
	chainParams struct {
		log  *logger.Logger
		cfg  *viper.Viper
		key  *keys.PrivateKey
		name string
		sgn  *transaction.Signer
		from uint32 // block height
	}
)

const (
	// Chain identifiers. NOTE(review): presumably used as chain names and
	// configuration section prefixes — confirm in initMorph/initMainnet.
	morphPrefix   = "morph"
	mainnetPrefix = "mainnet"

	// extra blocks to overlap two deposits, we do that to make sure that
	// there won't be any blocks without deposited assets in notary contract;
	// make sure it is bigger than any extra rounding value in notary client.
	notaryExtraBlocks = 300
	// number of attempts before waiting for a notary deposit times out.
	notaryDepositTimeout = 100
)

var (
	// Notary deposit initialization errors: the deposit transaction did not
	// appear in the network within the allowed attempts, or it faulted.
	errDepositTimeout = errors.New("notary deposit didn't appear in the network")
	errDepositFail    = errors.New("notary tx has faulted")
)

// Start runs all event providers.
//
// It runs the registered starters and reads the initial configuration from
// the blockchain. Alphabet nodes then initialize notary deposits. Start votes
// for the predefined side chain validators (a failure there is only logged)
// and schedules the initial epoch tick. Finally it starts the runners, both
// chain listeners, the block timers and the workers.
//
// The first error reported by either chain listener is forwarded to intError,
// prefixed with the chain it came from. The forwarding goroutine gives up as
// soon as ctx is cancelled, so it never blocks forever on intError.
//
// The health status is STARTING while Start runs and becomes READY only if
// Start returns a nil error.
func (s *Server) Start(ctx context.Context, intError chan<- error) (err error) {
	s.setHealthStatus(control.HealthStatus_STARTING)
	defer func() {
		if err == nil {
			s.setHealthStatus(control.HealthStatus_READY)
		}
	}()

	if err = s.launchStarters(); err != nil {
		return err
	}

	if err = s.initConfigFromBlockchain(); err != nil {
		return err
	}

	if s.IsAlphabet() {
		if err = s.initMainNotary(ctx); err != nil {
			return err
		}

		if err = s.initSideNotary(ctx); err != nil {
			return err
		}
	}

	prm := governance.VoteValidatorPrm{}
	prm.Validators = s.predefinedValidators

	// vote for sidechain validator if it is prepared in config
	if voteErr := s.voteForSidechainValidator(prm); voteErr != nil {
		// we don't stop inner ring execution on this error
		s.log.Warn(logs.InnerringCantVoteForPreparedValidators,
			zap.String("error", voteErr.Error()))
	}

	s.tickInitialExpoch()

	morphErr := make(chan error)
	mainnetErr := make(chan error)

	// multiplex listener error channels: forward the first error to intError;
	// the send itself also honors ctx so the goroutine cannot leak when nobody
	// reads intError any more
	go func() {
		var chainErr error
		select {
		case <-ctx.Done():
			return
		case e := <-morphErr:
			chainErr = fmt.Errorf("sidechain: %w", e)
		case e := <-mainnetErr:
			chainErr = fmt.Errorf("mainnet: %w", e)
		}

		select {
		case intError <- chainErr:
		case <-ctx.Done():
		}
	}()

	s.registerMorphNewBlockEventHandler()
	s.registerMainnetNewBlockEventHandler()

	if err = s.startRunners(intError); err != nil {
		return err
	}

	go s.morphListener.ListenWithError(ctx, morphErr)     // listen for neo:morph events
	go s.mainnetListener.ListenWithError(ctx, mainnetErr) // listen for neo:mainnet events

	if err = s.startBlockTimers(); err != nil {
		return fmt.Errorf("could not start block timers: %w", err)
	}

	s.startWorkers(ctx)

	return nil
}

// registerMorphNewBlockEventHandler registers a side chain block handler that,
// for every new block, logs its index, stores the index in the persistent
// state and passes it to tickTimers. A failure to persist the index is logged
// together with its cause; tickTimers is called regardless.
func (s *Server) registerMorphNewBlockEventHandler() {
	s.morphListener.RegisterBlockHandler(func(b *block.Block) {
		s.log.Debug(logs.InnerringNewBlock,
			zap.Uint32("index", b.Index),
		)

		err := s.persistate.SetUInt32(persistateSideChainLastBlockKey, b.Index)
		if err != nil {
			s.log.Warn(logs.InnerringCantUpdatePersistentState,
				zap.String("chain", "side"),
				zap.Uint32("block_index", b.Index),
				zap.String("error", err.Error()))
		}

		s.tickTimers(b.Index)
	})
}

// registerMainnetNewBlockEventHandler registers a main chain block handler
// that stores the index of every new main chain block in the persistent
// state. Nothing is registered when the server runs without a main chain.
// A failure to persist the index is logged together with its cause.
func (s *Server) registerMainnetNewBlockEventHandler() {
	if s.withoutMainNet {
		return
	}

	s.mainnetListener.RegisterBlockHandler(func(b *block.Block) {
		err := s.persistate.SetUInt32(persistateMainChainLastBlockKey, b.Index)
		if err != nil {
			s.log.Warn(logs.InnerringCantUpdatePersistentState,
				zap.String("chain", "main"),
				zap.Uint32("block_index", b.Index),
				zap.String("error", err.Error()))
		}
	})
}

// startRunners invokes every registered runner in registration order, handing
// each one errCh. It stops at, and returns, the first runner error.
func (s *Server) startRunners(errCh chan<- error) error {
	for i := range s.runners {
		if runErr := s.runners[i](errCh); runErr != nil {
			return runErr
		}
	}
	return nil
}

// launchStarters calls every registered starter in registration order and
// returns the first error encountered; later starters are then not called.
func (s *Server) launchStarters() error {
	for i := range s.starters {
		if startErr := s.starters[i](); startErr != nil {
			return startErr
		}
	}
	return nil
}

// initMainNotary runs the main chain notary deposit initialization through
// initNotary, unless mainNotaryConfig marks the main notary as disabled.
func (s *Server) initMainNotary(ctx context.Context) error {
	if s.mainNotaryConfig.disabled {
		return nil
	}

	const waitMsg = "waiting to accept main notary deposit"
	return s.initNotary(ctx, s.depositMainNotary, s.awaitMainNotaryDeposit, waitMsg)
}

// initSideNotary runs the side chain notary deposit initialization through
// initNotary, unless sideNotaryConfig marks the side notary as disabled.
func (s *Server) initSideNotary(ctx context.Context) error {
	if s.sideNotaryConfig.disabled {
		return nil
	}

	const waitMsg = "waiting to accept side notary deposit"
	return s.initNotary(ctx, s.depositSideNotary, s.awaitSideNotaryDeposit, waitMsg)
}

// tickInitialExpoch registers a one-shot block timer that, after
// initialEpochTickDelta blocks, hands a new epoch tick to the netmap processor.
func (s *Server) tickInitialExpoch() {
	meter := timer.StaticBlockMeter(s.initialEpochTickDelta)
	onTick := func() {
		s.netmapProcessor.HandleNewEpochTick(timerEvent.NewEpochTick{})
	}

	s.addBlockTimer(timer.NewOneTickTimer(meter, onTick))
}

// startWorkers launches each registered worker in its own goroutine with ctx
// and returns immediately without waiting for them.
func (s *Server) startWorkers(ctx context.Context) {
	for i := range s.workers {
		go s.workers[i](ctx)
	}
}

// Stop closes all subscription channels.
//
// It switches the health status to SHUTTING_DOWN, asks both chain listeners
// to stop in background goroutines (without waiting for them) and then runs
// every registered closer in registration order. Closer errors are logged,
// never returned.
func (s *Server) Stop() {
	s.setHealthStatus(control.HealthStatus_SHUTTING_DOWN)

	go s.morphListener.Stop()
	go s.mainnetListener.Stop()

	for _, closeFn := range s.closers {
		closeErr := closeFn()
		if closeErr == nil {
			continue
		}

		s.log.Warn(logs.InnerringCloserError,
			zap.String("error", closeErr.Error()),
		)
	}
}

// registerNoErrCloser registers c as a closer that cannot fail: the wrapper
// always reports a nil error.
func (s *Server) registerNoErrCloser(c func()) {
	wrapped := func() error {
		c()
		return nil
	}
	s.registerCloser(wrapped)
}

// registerIOCloser registers the Close method of c as a closer.
func (s *Server) registerIOCloser(c io.Closer) {
	closeFn := c.Close
	s.registerCloser(closeFn)
}

// registerCloser appends f to the closers that Stop runs, in registration
// order. An error returned by f is logged rather than propagated.
func (s *Server) registerCloser(f func() error) {
	s.closers = append(s.closers, f)
}

// registerStarter appends f to the starters that Start runs, in registration
// order, before anything else; the first starter error aborts Start.
func (s *Server) registerStarter(f func() error) {
	s.starters = append(s.starters, f)
}

// New creates an instance of the inner ring server structure.
//
// It reads the fee configuration and the node key, opens the persistent state
// storage and initializes, in order:
//   - the side chain and main chain connections (errChan is handed to both);
//   - notary configuration, contracts and notary support;
//   - the predefined validators from "morph.validators";
//   - morph clients, processors and timers;
//   - the control gRPC server and metrics.
//
// If any step fails, every closer registered so far (e.g. the persistent
// state storage) is run before the error is returned. On failure the caller
// receives no Server and therefore could never call Stop to release them.
func New(ctx context.Context, log *logger.Logger, cfg *viper.Viper, errChan chan<- error) (_ *Server, err error) {
	server := &Server{log: log}

	server.setHealthStatus(control.HealthStatus_HEALTH_STATUS_UNDEFINED)

	// release resources acquired by the steps that succeeded before the
	// failure; closer errors are logged the same way Stop logs them
	defer func() {
		if err == nil {
			return
		}
		for _, c := range server.closers {
			if cerr := c(); cerr != nil {
				server.log.Warn(logs.InnerringCloserError,
					zap.String("error", cerr.Error()),
				)
			}
		}
	}()

	// parse notary support
	server.feeConfig = config.NewFeeConfig(cfg)

	if err = server.initKey(cfg); err != nil {
		return nil, err
	}

	server.persistate, err = initPersistentStateStorage(cfg)
	if err != nil {
		return nil, err
	}
	server.registerCloser(server.persistate.Close)

	morphChain, err := server.initMorph(ctx, cfg, errChan)
	if err != nil {
		return nil, err
	}

	if err = server.initMainnet(ctx, cfg, morphChain, errChan); err != nil {
		return nil, err
	}

	server.initNotaryConfig()

	if err = server.initContracts(cfg); err != nil {
		return nil, err
	}

	if err = server.enableNotarySupport(); err != nil {
		return nil, err
	}

	// parse default validators
	server.predefinedValidators, err = parsePredefinedValidators(cfg)
	if err != nil {
		return nil, fmt.Errorf("ir: can't parse predefined validators list: %w", err)
	}

	server.pubKey = server.key.PublicKey().Bytes()

	morphClients, err := server.initClientsFromMorph()
	if err != nil {
		return nil, err
	}

	processors, err := server.initProcessors(cfg, morphClients)
	if err != nil {
		return nil, err
	}

	server.initTimers(cfg, processors, morphClients)

	if err = server.initGRPCServer(cfg); err != nil {
		return nil, err
	}

	server.initMetrics(cfg)

	return server, nil
}

// createListener builds an event listener for the chain described by p.
// The listener is fed by a subscriber over cli that starts from block p.from,
// and its log entries carry a "chain" field set to p.name.
func createListener(ctx context.Context, cli *client.Client, p *chainParams) (event.Listener, error) {
	sub, err := subscriber.New(ctx, &subscriber.Params{
		Log:            p.log,
		StartFromBlock: p.from,
		Client:         cli,
	})
	if err != nil {
		return nil, err
	}

	chainLog := &logger.Logger{Logger: p.log.With(zap.String("chain", p.name))}

	listener, err := event.NewListener(event.ListenerParams{
		Logger:     chainLog,
		Subscriber: sub,
	})
	if err != nil {
		return nil, err
	}

	return listener, nil
}

// createClient creates a client for the chain described by p.
//
// Endpoints are read from the "<name>.endpoint.client.<i>.address" and
// "<name>.endpoint.client.<i>.priority" configuration keys for i = 0, 1, ...
// until the first empty address. A missing or non-positive priority falls
// back to defaultPriority. The dial timeout and switch interval come from
// "<name>.dial_timeout" and "<name>.switch_interval". When the client reports
// a lost connection, an error naming the chain is sent to errChan.
func createClient(ctx context.Context, p *chainParams, errChan chan<- error) (*client.Client, error) {
	// config name left unchanged for compatibility; it may be better to rename it to "endpoints" or "clients"
	var endpoints []client.Endpoint

	// defaultPriority is a default endpoint priority
	const defaultPriority = 1

	section := p.name + ".endpoint.client"
	for i := 0; ; i++ {
		// every endpoint is a list item carrying its own address and priority
		endpointKey := fmt.Sprintf("%s.%d", section, i)

		addr := p.cfg.GetString(endpointKey + ".address")
		if addr == "" {
			break
		}

		priority := p.cfg.GetInt(endpointKey + ".priority")
		if priority <= 0 {
			priority = defaultPriority
		}

		endpoints = append(endpoints, client.Endpoint{
			Address:  addr,
			Priority: priority,
		})
	}

	if len(endpoints) == 0 {
		return nil, fmt.Errorf("%s chain client endpoints not provided", p.name)
	}

	return client.New(
		ctx,
		p.key,
		client.WithLogger(p.log),
		client.WithDialTimeout(p.cfg.GetDuration(p.name+".dial_timeout")),
		client.WithSigner(p.sgn),
		client.WithEndpoints(endpoints...),
		client.WithConnLostCallback(func() {
			errChan <- fmt.Errorf("%s chain connection has been lost", p.name)
		}),
		client.WithSwitchInterval(p.cfg.GetDuration(p.name+".switch_interval")),
	)
}

// parsePredefinedValidators decodes the hex-encoded public keys listed under
// the "morph.validators" configuration key.
func parsePredefinedValidators(cfg *viper.Viper) (keys.PublicKeys, error) {
	return ParsePublicKeysFromStrings(cfg.GetStringSlice("morph.validators"))
}

// ParsePublicKeysFromStrings decodes each hex-encoded string in pubKeys into a
// neo public key, preserving order. The first decoding failure aborts parsing
// and is returned wrapped; an empty input yields an empty, non-nil result.
func ParsePublicKeysFromStrings(pubKeys []string) (keys.PublicKeys, error) {
	result := make(keys.PublicKeys, len(pubKeys))

	for i, encoded := range pubKeys {
		pk, err := keys.NewPublicKeyFromString(encoded)
		if err != nil {
			return nil, fmt.Errorf("can't decode public key: %w", err)
		}
		result[i] = pk
	}

	return result, nil
}

// parseWalletAddressesFromStrings converts neo wallet address strings into
// script hashes, preserving order. An empty input yields (nil, nil); the first
// invalid address aborts the conversion and its error is returned as is.
func parseWalletAddressesFromStrings(wallets []string) ([]util.Uint160, error) {
	if len(wallets) == 0 {
		return nil, nil
	}

	hashes := make([]util.Uint160, 0, len(wallets))
	for _, addr := range wallets {
		h, err := address.StringToUint160(addr)
		if err != nil {
			return nil, err
		}
		hashes = append(hashes, h)
	}

	return hashes, nil
}

// initConfigFromBlockchain reads the current epoch and the epoch duration
// through netmapClient and the balance precision through balanceClient. It
// caches them in the server, computes how many blocks remain until the first
// epoch tick and logs the resulting configuration at debug level.
func (s *Server) initConfigFromBlockchain() error {
	epoch, err := s.netmapClient.Epoch()
	if err != nil {
		return fmt.Errorf("can't read epoch number: %w", err)
	}

	duration, err := s.netmapClient.EpochDuration()
	if err != nil {
		return fmt.Errorf("can't read epoch duration: %w", err)
	}

	decimals, err := s.balanceClient.Decimals()
	if err != nil {
		return fmt.Errorf("can't read balance contract precision: %w", err)
	}

	// cache everything first: nextEpochBlockDelta reads epochDuration back
	s.epochCounter.Store(epoch)
	s.epochDuration.Store(duration)
	s.precision.SetBalancePrecision(decimals)

	delta, err := s.nextEpochBlockDelta()
	s.initialEpochTickDelta = delta
	if err != nil {
		return err
	}

	s.log.Debug(logs.InnerringReadConfigFromBlockchain,
		zap.Bool("active", s.IsActive()),
		zap.Bool("alphabet", s.IsAlphabet()),
		zap.Uint64("epoch", epoch),
		zap.Uint32("precision", decimals),
		zap.Uint32("init_epoch_tick_delta", delta),
	)

	return nil
}

// nextEpochBlockDelta returns how many side chain blocks remain until the next
// epoch: lastEpochBlock + epochDuration - currentBlockCount, or zero when that
// block has already been reached or passed.
func (s *Server) nextEpochBlockDelta() (uint32, error) {
	lastEpochBlock, err := s.netmapClient.LastEpochBlock()
	if err != nil {
		return 0, fmt.Errorf("can't read last epoch block: %w", err)
	}

	height, err := s.morphClient.BlockCount()
	if err != nil {
		return 0, fmt.Errorf("can't get side chain height: %w", err)
	}

	nextEpochBlock := lastEpochBlock + uint32(s.epochDuration.Load())
	if nextEpochBlock >= height {
		return nextEpochBlock - height, nil
	}

	return 0, nil
}

// onlyAlphabetEventHandler wraps f so that it is invoked only when the node is
// an alphabet node at the moment an event arrives; otherwise the event is
// silently dropped.
func (s *Server) onlyAlphabetEventHandler(f event.Handler) event.Handler {
	return func(ev event.Event) {
		if !s.IsAlphabet() {
			return
		}
		f(ev)
	}
}

// newEpochTickHandlers returns the handlers to run on a new epoch tick; at the
// moment that is a single handler forwarding the tick to the netmap processor.
func (s *Server) newEpochTickHandlers() []newEpochHandler {
	return []newEpochHandler{
		func() { s.netmapProcessor.HandleNewEpochTick(timerEvent.NewEpochTick{}) },
	}
}