import os

import pytest

from frostfs_testlib.shell.interfaces import CommandOptions, InteractiveInput
from frostfs_testlib.shell.ssh_shell import SshConnectionProvider, SSHShell
from helpers import format_error_details, get_output_lines


def get_shell(host: str):
    port = os.getenv("SSH_SHELL_PORT", "22")
    login = os.getenv("SSH_SHELL_LOGIN")

    password = os.getenv("SSH_SHELL_PASSWORD", "")
    private_key_path = os.getenv("SSH_SHELL_PRIVATE_KEY_PATH", "")
    private_key_passphrase = os.getenv("SSH_SHELL_PRIVATE_KEY_PASSPHRASE", "")

    if not all([host, login, private_key_path, private_key_passphrase]):
        # TODO: in the future we might use https://pypi.org/project/mock-ssh-server,
        # at the moment it is not suitable for us because of its issues with stdin
        pytest.skip("SSH connection is not configured")

    return SSHShell(
        host=host,
        port=port,
        login=login,
        password=password,
        private_key_path=private_key_path,
        private_key_passphrase=private_key_passphrase,
    )


@pytest.fixture(scope="module")
def shell() -> SSHShell:
    """Module-wide SSH shell bound to the host named by ``SSH_SHELL_HOST``.

    Built via ``get_shell``, which skips when the connection is not configured.
    """
    primary_host = os.getenv("SSH_SHELL_HOST")
    return get_shell(host=primary_host)


@pytest.fixture(scope="module")
def shell_same_host() -> SSHShell:
    """A second, independent SSH shell aimed at the same ``SSH_SHELL_HOST``.

    Used to check how connections are shared between shells for one host.
    """
    same_host = os.getenv("SSH_SHELL_HOST")
    return get_shell(host=same_host)


@pytest.fixture(scope="module")
def shell_another_host() -> SSHShell:
    """SSH shell for a different host, taken from ``SSH_SHELL_HOST_2``.

    Skipped (through ``get_shell``) when that variable is not set.
    """
    secondary_host = os.getenv("SSH_SHELL_HOST_2")
    return get_shell(host=secondary_host)


@pytest.fixture(scope="function", autouse=True)
def reset_connection():
    """Before every test, drop all connections held by ``SshConnectionProvider``.

    The connection tests below rely on this to start from zero open connections.
    """
    SshConnectionProvider().drop_all()


class TestSSHShellInteractive:
    """Commands whose prompts are answered via ``CommandOptions.interactive_inputs``."""

    @staticmethod
    def _python(snippet: str) -> str:
        """Wrap a Python one-liner into a ``python3 -c "..."`` command line."""
        return f'python3 -c "{snippet}"'

    def test_command_with_one_prompt(self, shell: SSHShell):
        """A single prompt is answered and the answer is echoed on its own line."""
        code = "password = input('Password: '); print('\\n' + password)"
        answers = [InteractiveInput(prompt_pattern="Password", input="test")]

        outcome = shell.exec(self._python(code), CommandOptions(interactive_inputs=answers))

        assert outcome.return_code == 0
        assert get_output_lines(outcome) == ["Password: test", "test"]
        assert not outcome.stderr

    def test_command_with_several_prompts(self, shell: SSHShell):
        """Each prompt is matched by its own pattern and answered in order."""
        code = (
            "input1 = input('Input1: '); print('\\n' + input1); "
            "input2 = input('Input2: '); print('\\n' + input2)"
        )
        answers = [
            InteractiveInput(prompt_pattern="Input1", input="test1"),
            InteractiveInput(prompt_pattern="Input2", input="test2"),
        ]

        outcome = shell.exec(self._python(code), CommandOptions(interactive_inputs=answers))

        expected_lines = ["Input1: test1", "test1", "Input2: test2", "test2"]
        assert outcome.return_code == 0
        assert get_output_lines(outcome) == expected_lines
        assert not outcome.stderr

    def test_invalid_command_with_check(self, shell: SSHShell):
        """Without ``check=False``, a failing interactive command raises ``RuntimeError``."""
        answers = [InteractiveInput(prompt_pattern=".*", input="test")]

        with pytest.raises(RuntimeError) as excinfo:
            shell.exec(self._python("invalid script"), CommandOptions(interactive_inputs=answers))

        details = format_error_details(excinfo.value)
        assert "SyntaxError" in details
        assert "return code: 1" in details

    def test_invalid_command_without_check(self, shell: SSHShell):
        """With ``check=False`` the error text is reported in stdout of the result."""
        answers = [InteractiveInput(prompt_pattern=".*", input="test")]
        options = CommandOptions(interactive_inputs=answers, check=False)

        outcome = shell.exec(self._python("invalid script"), options)

        assert "SyntaxError" in outcome.stdout
        assert outcome.return_code == 1

    def test_non_existing_binary(self, shell: SSHShell):
        """An unknown binary fails with exit code 127 and raises ``RuntimeError``."""
        answers = [InteractiveInput(prompt_pattern=".*", input="test")]

        with pytest.raises(RuntimeError) as excinfo:
            shell.exec("not-a-command", CommandOptions(interactive_inputs=answers))

        assert "return code: 127" in format_error_details(excinfo.value)


class TestSSHShellNonInteractive:
    """Plain command execution over SSH, without interactive inputs."""

    @staticmethod
    def _python(snippet: str) -> str:
        """Wrap a Python one-liner into a ``python3 -c "..."`` command line."""
        return f'python3 -c "{snippet}"'

    def test_correct_command(self, shell: SSHShell):
        """A successful command exits with 0 and writes only to stdout."""
        outcome = shell.exec(self._python("print('test')"))

        assert outcome.return_code == 0
        assert outcome.stdout.strip() == "test"
        assert not outcome.stderr

    def test_invalid_command_with_check(self, shell: SSHShell):
        """Without ``check=False``, a failing command raises ``RuntimeError``."""
        with pytest.raises(RuntimeError) as excinfo:
            shell.exec(self._python("invalid script"))

        details = format_error_details(excinfo.value)
        assert "Error" in details
        assert "return code: 1" in details

    def test_invalid_command_without_check(self, shell: SSHShell):
        """With ``check=False`` the failure is reported through the result object."""
        outcome = shell.exec(self._python("invalid script"), CommandOptions(check=False))

        assert outcome.return_code == 1
        # TODO: we have inconsistency with local shell here, the local shell captures error info
        # in stdout while ssh shell captures it in stderr
        assert "Error" in outcome.stderr

    def test_non_existing_binary(self, shell: SSHShell):
        """An unknown binary fails with exit code 127 and raises ``RuntimeError``."""
        with pytest.raises(RuntimeError) as excinfo:
            shell.exec("not-a-command")

        details = format_error_details(excinfo.value)
        assert "Error" in details
        assert "return code: 127" in details


class TestSSHShellConnection:
    """Checks how ``SshConnectionProvider`` shares and releases SSH connections.

    The autouse ``reset_connection`` fixture calls ``drop_all()`` before each
    test, so every test below starts with zero open connections.
    """

    def test_connection_provider_is_singleton(self):
        """Every ``SshConnectionProvider()`` call yields the very same object."""
        provider = SshConnectionProvider()
        provider2 = SshConnectionProvider()
        # Identity is checked with `is`; comparing id() values is equivalent but unidiomatic.
        assert provider is provider2

    def test_connection_provider_has_creds(self, shell: SSHShell):
        """With the ``shell`` fixture created, one credential entry exists and nothing is connected.

        ``shell`` is requested only for its side effect and is not used in the body.
        """
        provider = SshConnectionProvider()
        assert len(provider.creds) == 1
        assert len(provider.connections) == 0

    def test_connection_provider_has_only_one_connection(self, shell: SSHShell):
        """Repeated commands on one shell reuse a single connection; ``drop()`` releases it."""
        provider = SshConnectionProvider()
        assert len(provider.connections) == 0
        shell.exec("echo 1")
        assert len(provider.connections) == 1
        shell.exec("echo 2")
        assert len(provider.connections) == 1
        shell.drop()
        assert len(provider.connections) == 0

    def test_connection_same_host(self, shell: SSHShell, shell_same_host: SSHShell):
        """Two shells pointing at the same host share one connection."""
        provider = SshConnectionProvider()
        assert len(provider.connections) == 0

        shell.exec("echo 1")
        assert len(provider.connections) == 1

        # A second shell for the same host must not open another connection.
        shell_same_host.exec("echo 2")
        assert len(provider.connections) == 1

        # Dropping through either shell closes the shared connection...
        shell.drop()
        assert len(provider.connections) == 0

        # ...and the next command transparently reconnects.
        shell.exec("echo 3")
        assert len(provider.connections) == 1

    def test_connection_another_host(self, shell: SSHShell, shell_another_host: SSHShell):
        """Shells for different hosts hold separate connections that drop independently."""
        provider = SshConnectionProvider()
        assert len(provider.connections) == 0

        shell.exec("echo 1")
        assert len(provider.connections) == 1

        shell_another_host.exec("echo 2")
        assert len(provider.connections) == 2

        shell.drop()
        assert len(provider.connections) == 1

        shell_another_host.drop()
        assert len(provider.connections) == 0