package control

import (
	"crypto/ecdsa"
	"errors"
	"fmt"
	"time"

	"git.frostfs.info/TrueCloudLab/frostfs-node/cmd/frostfs-cli/internal/common"
	"git.frostfs.info/TrueCloudLab/frostfs-node/cmd/frostfs-cli/internal/commonflags"
	"git.frostfs.info/TrueCloudLab/frostfs-node/cmd/frostfs-cli/internal/key"
	commonCmd "git.frostfs.info/TrueCloudLab/frostfs-node/cmd/internal/common"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/services/control"
	rawclient "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/api/rpc/client"
	"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/client"
	"github.com/spf13/cobra"
)

// Flag name and the status keywords accepted by the "set-status" command.
const (
	netmapStatusFlag = "status"

	netmapStatusOnline      = "online"
	netmapStatusOffline     = "offline"
	netmapStatusMaintenance = "maintenance"
)

// Polling parameters used when the command waits for the status change (--await).
const (
	// maxSetStatusMaxWaitTime bounds the total time spent polling. The text of
	// errNetmapStatusAwaitFailed mentions this duration; keep them in sync.
	maxSetStatusMaxWaitTime = 30 * time.Minute

	// setStatusWaitTimeout is the pause between two consecutive status polls.
	setStatusWaitTimeout = 30 * time.Second
)

// errNetmapStatusAwaitFailed is reported when polling gives up because the
// node did not reach the target status within maxSetStatusMaxWaitTime.
var errNetmapStatusAwaitFailed = errors.New("netmap status hasn't changed for 30 minutes")

// setNetmapStatusCmd is the "set-status" control subcommand; its flags are
// registered by initControlSetNetmapStatusCmd and it runs setNetmapStatus.
var setNetmapStatusCmd = &cobra.Command{
	Use:   "set-status",
	Run:   setNetmapStatus,
	Short: "Set status of the storage node in FrostFS network map",
	Long:  "Set status of the storage node in FrostFS network map",
}

// initControlSetNetmapStatusCmd registers the common control flags plus the
// command-specific --status (required), --force and --await flags on
// setNetmapStatusCmd.
func initControlSetNetmapStatusCmd() {
	initControlFlags(setNetmapStatusCmd)

	statusUsage := fmt.Sprintf("New netmap status keyword ('%s', '%s', '%s')",
		netmapStatusOnline, netmapStatusOffline, netmapStatusMaintenance)

	fs := setNetmapStatusCmd.Flags()
	fs.String(netmapStatusFlag, "", statusUsage)

	// The flag was registered on the line above, so marking it required
	// cannot fail; the error is intentionally discarded.
	_ = setNetmapStatusCmd.MarkFlagRequired(netmapStatusFlag)

	fs.BoolP(commonflags.ForceFlag, commonflags.ForceFlagShorthand, false,
		"Force turning to local maintenance")
	fs.Bool(commonflags.AwaitFlag, false, commonflags.AwaitFlagUsage)
}

// setNetmapStatus is the Run handler of "set-status".
//
// It maps the --status keyword to a control.NetmapStatus. For the maintenance
// status it honours --force by requesting forced local maintenance; for other
// statuses --force is ignored with a verbose notice. It then signs and sends a
// SetNetmapStatus request over the control API and verifies the signed
// response. With --await it blocks until the node reports the requested status.
// Any failure terminates the command via commonCmd.ExitOnErr.
func setNetmapStatus(cmd *cobra.Command, _ []string) {
	pk := key.Get(cmd)
	flags := cmd.Flags()
	force, _ := flags.GetBool(commonflags.ForceFlag)
	await, _ := flags.GetBool(commonflags.AwaitFlag)

	var targetStatus control.NetmapStatus
	switch keyword, _ := flags.GetString(netmapStatusFlag); keyword {
	case netmapStatusOnline:
		targetStatus = control.NetmapStatus_ONLINE
	case netmapStatusOffline:
		targetStatus = control.NetmapStatus_OFFLINE
	case netmapStatusMaintenance:
		targetStatus = control.NetmapStatus_MAINTENANCE
	default:
		commonCmd.ExitOnErr(cmd, "", fmt.Errorf("unsupported status %s", keyword))
	}

	body := new(control.SetNetmapStatusRequest_Body)
	body.SetStatus(targetStatus)

	if force {
		if targetStatus == control.NetmapStatus_MAINTENANCE {
			body.SetForceMaintenance(true)
			common.PrintVerbose(cmd, "Local maintenance will be forced.")
		} else {
			// --force only has a meaning for maintenance.
			common.PrintVerbose(cmd, "Ignore --%s flag for %s state.", commonflags.ForceFlag, targetStatus)
		}
	}

	req := new(control.SetNetmapStatusRequest)
	req.SetBody(body)
	signRequest(cmd, pk, req)

	cli := getClient(cmd, pk)

	var resp *control.SetNetmapStatusResponse
	err := cli.ExecRaw(func(rc *rawclient.Client) error {
		var rpcErr error
		resp, rpcErr = control.SetNetmapStatus(rc, req)
		return rpcErr
	})
	commonCmd.ExitOnErr(cmd, "rpc error: %w", err)

	verifyResponse(cmd, resp.GetSignature(), resp.GetBody())

	cmd.Println("Network status update request successfully sent.")

	if !await {
		return
	}
	awaitSetNetmapStatus(cmd, pk, cli, targetStatus)
}

// awaitSetNetmapStatus polls the node's netmap status over the control API
// until it equals targetStatus.
//
// Between polls it sleeps setStatusWaitTimeout. It reports every observed
// epoch increase and the latest status it received. If the target is not
// reached within maxSetStatusMaxWaitTime, the command is terminated with
// errNetmapStatusAwaitFailed. Any RPC failure also terminates the command
// via commonCmd.ExitOnErr.
func awaitSetNetmapStatus(cmd *cobra.Command, pk *ecdsa.PrivateKey, cli *client.Client, targetStatus control.NetmapStatus) {
	req := &control.GetNetmapStatusRequest{
		Body: &control.GetNetmapStatusRequest_Body{},
	}
	// The request body never changes, so it is signed once and reused for every poll.
	signRequest(cmd, pk, req)

	var (
		epoch uint64
		// epochKnown is tracked separately so that a legitimate epoch 0 is not
		// mistaken for "not yet observed" (which would hide a 0 -> 1 change).
		epochKnown bool
		status     control.NetmapStatus
	)
	startTime := time.Now()
	cmd.Println("Wait until epoch and netmap status change...")
	for {
		var resp *control.GetNetmapStatusResponse
		var err error
		err = cli.ExecRaw(func(rc *rawclient.Client) error {
			resp, err = control.GetNetmapStatus(rc, req)
			return err
		})
		commonCmd.ExitOnErr(cmd, "failed to get current netmap status: %w", err)

		currentEpoch := resp.GetBody().GetEpoch()
		switch {
		case !epochKnown:
			epoch, epochKnown = currentEpoch, true
		case currentEpoch > epoch:
			epoch = currentEpoch
			cmd.Printf("Epoch changed to %d\n", epoch)
		}

		status = resp.GetBody().GetStatus()
		if status == targetStatus {
			break
		}

		if time.Since(startTime) > maxSetStatusMaxWaitTime {
			commonCmd.ExitOnErr(cmd, "failed to wait netmap status: %w", errNetmapStatusAwaitFailed)
			return
		}

		// Report the status just received before sleeping, so the printed
		// value is not up to setStatusWaitTimeout old.
		cmd.Printf("Current netmap status '%s', target status '%s'\n", status.String(), targetStatus.String())
		time.Sleep(setStatusWaitTimeout)
	}
	cmd.Printf("Netmap status changed to '%s' successfully.\n", status.String())
}