import itertools
import logging
import os
import random

import allure
import pytest
from frostfs_testlib import reporter
from frostfs_testlib.resources.wellknown_acl import PUBLIC_ACL
from frostfs_testlib.steps.cli.container import StorageContainer, StorageContainerInfo, create_container
from frostfs_testlib.steps.cli.object import get_object, get_object_nodes, put_object
from frostfs_testlib.steps.node_management import check_node_in_map, check_node_not_in_map
from frostfs_testlib.storage.cluster import ClusterNode, StorageNode
from frostfs_testlib.storage.controllers import ClusterStateController
from frostfs_testlib.storage.dataclasses.object_size import ObjectSize
from frostfs_testlib.storage.dataclasses.storage_object_info import StorageObjectInfo
from frostfs_testlib.storage.dataclasses.wallet import WalletInfo
from frostfs_testlib.testing.cluster_test_base import ClusterTestBase
from frostfs_testlib.testing.parallel import parallel, parallel_workers_limit
from frostfs_testlib.testing.test_control import wait_for_success
from frostfs_testlib.utils.file_utils import get_file_hash
from pytest import FixtureRequest

# Module-level logger. The "NeoLogger" name is presumably shared with the rest of
# the test suite's logging configuration — NOTE(review): confirm before renaming.
logger = logging.getLogger("NeoLogger")


@pytest.mark.failover
@pytest.mark.failover_server
class TestFailoverServer(ClusterTestBase):
    @wait_for_success(max_wait_time=120, interval=1)
    def wait_node_not_in_map(self, *check_args, **check_kwargs):
        """Retry ``check_node_not_in_map`` until it passes.

        Retries are driven by ``wait_for_success(max_wait_time=120, interval=1)``.
        All arguments are forwarded unchanged to ``check_node_not_in_map``.
        """
        check_node_not_in_map(*check_args, **check_kwargs)

    @wait_for_success(max_wait_time=120, interval=1)
    def wait_node_in_map(self, *check_args, **check_kwargs):
        """Retry ``check_node_in_map`` until it passes.

        Retries are driven by ``wait_for_success(max_wait_time=120, interval=1)``.
        All arguments are forwarded unchanged to ``check_node_in_map``.
        """
        check_node_in_map(*check_args, **check_kwargs)

    @allure.title("[Test] Create containers")
    @pytest.fixture
    def containers(
        self,
        request: FixtureRequest,
        default_wallet: WalletInfo,
    ) -> list[StorageContainer]:
        """Create ``request.param`` public containers in parallel.

        Intended for indirect parametrization: the parameter is the number of
        containers. Every container uses the rule ``REP 2 CBF 2 SELECT 2 FROM *``
        and is created through the cluster's default RPC endpoint.

        Returns:
            One ``StorageContainer`` per created container, owned by ``default_wallet``.
        """
        # One identical creator callable per requested container.
        creators = [create_container] * request.param
        futures = parallel(
            creators,
            wallet=default_wallet,
            shell=self.shell,
            endpoint=self.cluster.default_rpc_endpoint,
            rule="REP 2 CBF 2 SELECT 2 FROM *",
            basic_acl=PUBLIC_ACL,
        )

        created: list[StorageContainer] = []
        for future in futures:
            info = StorageContainerInfo(future.result(), default_wallet)
            created.append(StorageContainer(info, self.shell, self.cluster))
        return created

    @allure.title("[Test] Create container")
    @pytest.fixture()
    def container(self, default_wallet: WalletInfo) -> StorageContainer:
        """Create one public container whose policy selects every cluster node.

        For N cluster nodes the rule is ``REP N-1 CBF 1 SELECT N FROM *``, i.e. one
        replica fewer than the number of selected nodes.
        """
        node_total = len(self.cluster.cluster_nodes)
        rule = f"REP {node_total - 1} CBF 1 SELECT {node_total} FROM *"
        container_id = create_container(
            default_wallet,
            shell=self.shell,
            endpoint=self.cluster.default_rpc_endpoint,
            rule=rule,
            basic_acl=PUBLIC_ACL,
        )
        return StorageContainer(StorageContainerInfo(container_id, default_wallet), self.shell, self.cluster)

    @allure.title("[Test] Create objects")
    @pytest.fixture()
    def storage_objects(
        self,
        request: FixtureRequest,
        containers: list[StorageContainer],
        simple_object_size: ObjectSize,
        complex_object_size: ObjectSize,
    ) -> list[StorageObjectInfo]:
        """Put ``request.param`` objects into every container from ``containers``.

        Intended for indirect parametrization: the parameter is the number of
        objects per container. The first two sizes are always one simple-size and
        one complex-size object. The remaining sizes are drawn randomly, weighted
        2:1 in favour of the simple size. Every container receives the same list
        of sizes.

        This fixture is function-scoped on purpose. It depends on the
        function-scoped ``containers`` fixture, and a class-scoped fixture
        requesting it would fail with pytest's ``ScopeMismatch``.

        Returns:
            The created objects, grouped container by container.
        """
        object_count = request.param
        sizes_samples = [simple_object_size, complex_object_size]
        samples_count = len(sizes_samples)
        assert object_count >= samples_count, f"Object count is too low, must be >= {samples_count}"

        sizes_weights = [2, 1]
        sizes = sizes_samples + random.choices(sizes_samples, weights=sizes_weights, k=object_count - samples_count)

        # ``size`` is cycled once per callable, so the callables must be laid out
        # container-by-container for call i to get sizes[i % len(sizes)]. That way
        # every container receives the full ``sizes`` list, including at least one
        # simple and one complex object.
        results = parallel(
            [container.generate_object for container in containers for _ in sizes],
            size=itertools.cycle([size.value for size in sizes]),
        )

        return [result.result() for result in results]

    @allure.title("[Test] Create objects and get nodes with object")
    @pytest.fixture()
    def object_and_nodes(self, simple_object_size: ObjectSize, container: StorageContainer) -> tuple[StorageObjectInfo, list[ClusterNode]]:
        """Put one simple-size object into ``container``.

        Returns:
            The object info together with the nodes reported by ``get_object_nodes``,
            which is queried through the first cluster node.
        """
        created = container.generate_object(simple_object_size.value)
        holders = get_object_nodes(
            self.cluster,
            created.cid,
            created.oid,
            self.cluster.cluster_nodes[0],
        )
        return created, holders

    def _verify_object(self, storage_object: StorageObjectInfo, node: StorageNode):
        """Download ``storage_object`` through ``node`` and assert its file hash is unchanged.

        The download uses the object's own wallet, the node's RPC endpoint and a
        60s timeout.
        """
        with reporter.step(f"Verify object {storage_object.oid} from node {node}"):
            downloaded_path = get_object(
                storage_object.wallet,
                storage_object.cid,
                storage_object.oid,
                endpoint=node.get_rpc_endpoint(),
                shell=self.shell,
                timeout="60s",
            )

            expected_hash = storage_object.file_hash
            assert expected_hash == get_file_hash(downloaded_path)

    @reporter.step("Verify objects")
    def verify_objects(self, nodes: list[StorageNode], storage_objects: list[StorageObjectInfo]) -> None:
        """Verify every object in ``storage_objects`` through every node in ``nodes``.

        Each (object, node) pair is checked exactly once with ``_verify_object``,
        which downloads the object and compares its hash. Concurrency is capped by
        the ``PARALLEL_CUSTOM_LIMIT`` environment variable (default 50).

        Raises:
            Whatever a failed ``_verify_object`` call raised, re-raised from its future.
        """
        workers_count = os.environ.get("PARALLEL_CUSTOM_LIMIT", 50)

        # Repeat each object len(nodes) times in a row and cycle the nodes alongside.
        # Item i is then (storage_objects[i // len(nodes)], nodes[i % len(nodes)]),
        # which covers every pair. Repeating the whole list
        # (``storage_objects * len(nodes)``) covers every pair only when the two
        # lengths are coprime.
        repeated_objects = [storage_object for storage_object in storage_objects for _ in nodes]

        with parallel_workers_limit(int(workers_count)):
            futures = parallel(self._verify_object, repeated_objects, node=itertools.cycle(nodes))

        # Surface any worker failure explicitly so a failed verification cannot go unnoticed.
        for future in futures:
            future.result()

    @allure.title("Full shutdown node")
    @pytest.mark.parametrize("containers, storage_objects", [(5, 10)], indirect=True)
    def test_complete_node_shutdown(
        self,
        storage_objects: list[StorageObjectInfo],
        node_under_test: ClusterNode,
        cluster_state_controller: ClusterStateController,
    ):
        """Hard-stop one node and check objects stay readable from the remaining nodes.

        Flow:
        1. Right after the stop, objects must be intact and the stopped node must
           still be in the network map.
        2. Tick (netmap cleaner threshold + 4) epochs. The stopped node must then
           be gone from the network map, and objects must still be intact on the
           remaining nodes.
        """
        with reporter.step("Remove one node from the list of nodes"):
            alive_nodes = list(set(self.cluster.cluster_nodes) - {node_under_test})

        storage_nodes = [cluster_node.storage_node for cluster_node in alive_nodes]

        with reporter.step("Tick 2 epochs and wait for 2 blocks"):
            self.tick_epochs(2, storage_nodes[0], wait_block=2)

        with reporter.step("Stop node"):
            cluster_state_controller.stop_node_host(node_under_test, "hard")

        with reporter.step("Verify that there are no corrupted objects"):
            self.verify_objects(storage_nodes, storage_objects)

        with reporter.step("Check node still in map"):
            self.wait_node_in_map(node_under_test.storage_node, self.shell, alive_node=storage_nodes[0])

        # Tick a few epochs beyond the netmap cleaner threshold read from the IR node.
        count_tick_epoch = int(alive_nodes[0].ir_node.get_netmap_cleaner_threshold()) + 4

        with reporter.step(f"Tick {count_tick_epoch} epochs and wait for 2 blocks"):
            self.tick_epochs(count_tick_epoch, storage_nodes[0], wait_block=2)

        with reporter.step(f"Check node not in map after {count_tick_epoch} epochs"):
            self.wait_node_not_in_map(node_under_test.storage_node, self.shell, alive_node=storage_nodes[0])

        with reporter.step(f"Verify that there are no corrupted objects after {count_tick_epoch} epochs"):
            self.verify_objects(storage_nodes, storage_objects)

    @allure.title("Temporarily disable a node")
    @pytest.mark.parametrize("containers, storage_objects", [(5, 10)], indirect=True)
    def test_temporarily_disable_a_node(
        self,
        storage_objects: list[StorageObjectInfo],
        node_under_test: ClusterNode,
        cluster_state_controller: ClusterStateController,
    ):
        """Hard-stop a node, start it again, and verify objects from the other nodes each time.

        While the node is down it must still be present in the network map.
        """
        with reporter.step("Remove one node from the list"):
            storage_nodes = list(set(self.cluster.storage_nodes) - {node_under_test.storage_node})

        with reporter.step("Tick 2 epochs and wait for 2 blocks"):
            # Any remaining node serves as the reference for epoch ticks and netmap checks.
            reference_node = storage_nodes[0]
            self.tick_epochs(2, reference_node, wait_block=2)

        with reporter.step("Stop node"):
            cluster_state_controller.stop_node_host(node_under_test, "hard")

        with reporter.step("Verify that there are no corrupted objects"):
            self.verify_objects(storage_nodes, storage_objects)

        with reporter.step("Check node still in map"):
            self.wait_node_in_map(node_under_test.storage_node, self.shell, alive_node=reference_node)

        with reporter.step("Start node"):
            cluster_state_controller.start_node_host(node_under_test)

        with reporter.step("Verify that there are no corrupted objects"):
            self.verify_objects(storage_nodes, storage_objects)

    @allure.title("Not enough nodes in the container with policy - 'REP 3 CBF 1 SELECT 4 FROM *'")
    def test_not_enough_nodes_in_container_rep_3(
        self,
        object_and_nodes: tuple[StorageObjectInfo, list[ClusterNode]],
        default_wallet: WalletInfo,
        cluster_state_controller: ClusterStateController,
        simple_file: str,
    ):
        """Keep one holder of the object running; reads must succeed and a new put must fail.

        Every node listed as holding the object, except the first, is hard-stopped.
        The object must then be readable through a node not listed as a holder and
        through the surviving holder. A new put through the surviving holder is
        expected to raise ``RuntimeError``.
        """
        stored_object, holders = object_and_nodes
        non_holders = list(set(self.cluster.cluster_nodes) - set(holders))
        endpoint_without_object = non_holders[0].storage_node.get_rpc_endpoint()
        endpoint_with_object = holders[0].storage_node.get_rpc_endpoint()
        cid, oid = stored_object.cid, stored_object.oid

        with reporter.step("Stop all nodes with object except first one"):
            parallel(cluster_state_controller.stop_node_host, holders[1:], mode="hard")

        with reporter.step("Get object from node without object"):
            get_object(default_wallet, cid, oid, self.shell, endpoint_without_object)

        with reporter.step("Get object from node with object"):
            get_object(default_wallet, cid, oid, self.shell, endpoint_with_object)

        with reporter.step("[Negative] Put operation to node with object"):
            with pytest.raises(RuntimeError):
                put_object(default_wallet, simple_file, cid, self.shell, endpoint_with_object)

    @allure.title("Not enough nodes in the container with policy - 'REP 2 CBF 2 SELECT 4 FROM *'")
    def test_not_enough_nodes_in_container_rep_2(
        self,
        default_wallet: WalletInfo,
        cluster_state_controller: ClusterStateController,
        simple_file: str,
    ):
        """Check put/get and container creation after one object holder is hard-stopped.

        The container rule is derived from the cluster size as
        ``REP <nodes - 2> IN X CBF 2 SELECT <nodes> FROM * AS X``. On a 4-node
        cluster this matches the "REP 2 ... SELECT 4" in the title.

        One randomly chosen node that holds the first object is stopped. Another
        holder is then used to put and get a second object and to create a new
        container with the same rule. The test fails if any of these calls raises.
        """
        with reporter.step("Create container with full network map"):
            node_count = len(self.cluster.cluster_nodes)
            # Replica count scales with cluster size: two fewer replicas than selected nodes.
            placement_rule = f"REP {node_count - 2} IN X CBF 2 SELECT {node_count} FROM * AS X"
            cid = create_container(
                default_wallet,
                self.shell,
                self.cluster.default_rpc_endpoint,
                rule=placement_rule,
                basic_acl=PUBLIC_ACL,
            )

        with reporter.step("Put object"):
            oid = put_object(default_wallet, simple_file, cid, self.shell, self.cluster.default_rpc_endpoint)

        with reporter.step("Search nodes with object"):
            # Node list is queried through the first cluster node.
            object_nodes = get_object_nodes(self.cluster, cid, oid, self.cluster.cluster_nodes[0])

        with reporter.step("Choose node to stop"):
            node_under_test = random.choice(object_nodes)
            # NOTE(review): random.choice raises IndexError if only one node holds the
            # object (e.g. REP 1 on a 3-node cluster) — this test assumes >= 2 holders.
            alive_node_with_object = random.choice(list(set(object_nodes) - {node_under_test}))
            alive_endpoint_with_object = alive_node_with_object.storage_node.get_rpc_endpoint()

        with reporter.step("Stop random node with object"):
            cluster_state_controller.stop_node_host(node_under_test, "hard")

        # The remaining holder must still accept writes and serve reads for the container.
        with reporter.step("Put object to alive node with object"):
            oid_2 = put_object(default_wallet, simple_file, cid, self.shell, alive_endpoint_with_object)

        with reporter.step("Get object from alive node with object"):
            get_object(default_wallet, cid, oid_2, self.shell, alive_endpoint_with_object)

        # Container creation through the same endpoint must also still succeed;
        # the returned container id is not used further in the visible code.
        with reporter.step("Create container on alive node"):
            create_container(
                default_wallet,
                self.shell,
                alive_endpoint_with_object,
                rule=placement_rule,
                basic_acl=PUBLIC_ACL,
            )