From d3e5ee22313087b16dfdf99eef28533e4c9704f9 Mon Sep 17 00:00:00 2001 From: Vladimir Domnich Date: Wed, 24 Aug 2022 15:41:11 +0400 Subject: [PATCH] Implement basic version of ssh shell Signed-off-by: Vladimir Domnich --- .gitignore | 3 - .pre-commit-config.yaml | 11 + README.md | 10 + pyproject.toml | 8 + reporter/__init__.py | 6 +- reporter/allure_reporter.py | 2 +- reporter/dummy_reporter.py | 2 +- reporter/interfaces.py | 5 +- requirements.txt | 4 + shell/interfaces.py | 6 +- shell/local_shell.py | 50 ++-- shell/ssh_shell.py | 239 ++++++++++++++++++ tests/helpers.py | 25 +- ...ell_interactive.py => test_local_shell.py} | 62 ++++- tests/test_local_shell_non_interactive.py | 46 ---- tests/test_ssh_shell.py | 138 ++++++++++ 16 files changed, 525 insertions(+), 92 deletions(-) create mode 100644 .pre-commit-config.yaml create mode 100644 pyproject.toml create mode 100644 shell/ssh_shell.py rename tests/{test_local_shell_interactive.py => test_local_shell.py} (53%) delete mode 100644 tests/test_local_shell_non_interactive.py create mode 100644 tests/test_ssh_shell.py diff --git a/.gitignore b/.gitignore index e61db30..743b23b 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,3 @@ # ignore caches under any path **/__pycache__ - -# ignore virtual environments -venv*/* diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..ad9846a --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,11 @@ +repos: + - repo: https://github.com/psf/black + rev: 22.8.0 + hooks: + - id: black + language_version: python3.9 + - repo: https://github.com/pycqa/isort + rev: 5.10.1 + hooks: + - id: isort + name: isort (python) diff --git a/README.md b/README.md index 736ec88..e493042 100644 --- a/README.md +++ b/README.md @@ -25,6 +25,16 @@ $ source venv/bin/activate $ pip install -r requirements.txt ``` +3. Setup pre-commit hooks to run code formatters on staged files before you run a `git commit` command: + +``` +pre-commit install +``` + +Optionally you might want to integrate code formatters with your code editor to apply formatters to code files as you go: +* isort is supported by [PyCharm](https://plugins.jetbrains.com/plugin/15434-isortconnect), [VS Code](https://cereblanco.medium.com/setup-black-and-isort-in-vscode-514804590bf9). Plugins exist for other IDEs/editors as well. +* black can be integrated with multiple editors, please, instructions are available [here](https://black.readthedocs.io/en/stable/integrations/editors.html). + ### Unit Tests Before submitting any changes to the library, please, make sure that all unit tests are passing. To run the tests, please, use the following command: ``` diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..bd0087b --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,8 @@ +[tool.isort] +profile = "black" +src_paths = ["reporter", "shell", "tests"] +line_length = 100 + +[tool.black] +line-length = 100 +target-version = ["py39"] diff --git a/reporter/__init__.py b/reporter/__init__.py index d312fcc..31bbdf7 100644 --- a/reporter/__init__.py +++ b/reporter/__init__.py @@ -1,8 +1,8 @@ import os -from .allure_reporter import AllureReporter -from .interfaces import Reporter -from .dummy_reporter import DummyReporter +from reporter.allure_reporter import AllureReporter +from reporter.dummy_reporter import DummyReporter +from reporter.interfaces import Reporter def get_reporter() -> Reporter: diff --git a/reporter/allure_reporter.py b/reporter/allure_reporter.py index 6522859..0277214 100644 --- a/reporter/allure_reporter.py +++ b/reporter/allure_reporter.py @@ -6,7 +6,7 @@ from typing import Any import allure from allure import attachment_type -from .interfaces import Reporter +from reporter.interfaces import Reporter class AllureReporter(Reporter): diff --git a/reporter/dummy_reporter.py b/reporter/dummy_reporter.py index e559193..9061101 100644 --- a/reporter/dummy_reporter.py +++ b/reporter/dummy_reporter.py @@ -1,7 +1,7 @@ from contextlib import AbstractContextManager, contextmanager from typing import Any -from .interfaces import Reporter +from reporter.interfaces import Reporter @contextmanager diff --git a/reporter/interfaces.py b/reporter/interfaces.py index de7bcb7..347f71f 100644 --- a/reporter/interfaces.py +++ b/reporter/interfaces.py @@ -16,14 +16,13 @@ class Reporter(ABC): :param str name: Name of the step :return: step context """ - pass @abstractmethod def attach(self, content: Any, file_name: str) -> None: """ Attach specified content with given file name to the test report. - :param any name: content to attach. If not a string, it will be converted to a string. + :param any content: content to attach. If content value is not a string, it will be + converted to a string. :param str file_name: file name of attachment. """ - pass diff --git a/requirements.txt b/requirements.txt index babc3a7..5e62371 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,6 @@ allure-python-commons==2.9.45 +black==22.8.0 +isort==5.10.1 +paramiko==2.10.3 pexpect==4.8.0 +pre-commit==2.20.0 diff --git a/shell/interfaces.py b/shell/interfaces.py index 1e194dd..97ba7cc 100644 --- a/shell/interfaces.py +++ b/shell/interfaces.py @@ -11,6 +11,7 @@ class InteractiveInput: :attr str prompt_pattern: regular expression that defines expected prompt from the command. :attr str input: user input that should be supplied to the command in response to the prompt. """ + prompt_pattern: str input: str @@ -21,11 +22,12 @@ class CommandOptions: Options that control command execution. :attr list interactive_inputs: user inputs that should be interactively supplied to - the command during its' execution. + the command during execution. :attr int timeout: timeout for command execution (in seconds). :attr bool check: controls whether to check return code of the command. Set to False to ignore non-zero return codes. """ + interactive_inputs: Optional[list[InteractiveInput]] = None timeout: int = 30 check: bool = True @@ -36,6 +38,7 @@ class CommandResult: """ Represents a result of a command executed via shell. """ + stdout: str stderr: str return_code: int @@ -56,4 +59,3 @@ class Shell(ABC): :param CommandOptions options: options that control command execution. :return command result. """ - pass diff --git a/shell/local_shell.py b/shell/local_shell.py index 2345ede..f542cc4 100644 --- a/shell/local_shell.py +++ b/shell/local_shell.py @@ -9,12 +9,15 @@ import pexpect from reporter import get_reporter from shell.interfaces import CommandOptions, CommandResult, Shell - logger = logging.getLogger("neofs.testlib.shell") reporter = get_reporter() class LocalShell(Shell): + """ + Implements command shell on a local machine. + """ + def exec(self, command: str, options: Optional[CommandOptions] = None) -> CommandResult: # If no options were provided, use default options options = options or CommandOptions() @@ -41,12 +44,16 @@ class LocalShell(Shell): result = self._get_pexpect_process_result(command_process, command) if options.check and result.return_code != 0: - raise RuntimeError(f"Command: {command}\nreturn code: {result.return_code}\nOutput: {result.stdout}") + raise RuntimeError( + f"Command: {command}\nreturn code: {result.return_code}\nOutput: {result.stdout}" + ) return result except pexpect.ExceptionPexpect as exc: result = self._get_pexpect_process_result(command_process, command) - message = f"Command: {command}\nreturn code: {result.return_code}\nOutput: {result.stdout}" + message = ( + f"Command: {command}\nreturn code: {result.return_code}\nOutput: {result.stdout}" + ) if options.check: raise RuntimeError(message) from exc else: @@ -54,7 +61,9 @@ class LocalShell(Shell): return result except OSError as exc: result = self._get_pexpect_process_result(command_process, command) - message = f"Command: {command}\nreturn code: {result.return_code}\nOutput: {exc.strerror}" + message = ( + f"Command: {command}\nreturn code: {result.return_code}\nOutput: {exc.strerror}" + ) if options.check: raise RuntimeError(message) from exc else: @@ -80,7 +89,7 @@ class LocalShell(Shell): stdout=subprocess.PIPE, stderr=subprocess.STDOUT, timeout=options.timeout, - shell=True + shell=True, ) result = CommandResult( @@ -92,9 +101,11 @@ class LocalShell(Shell): except subprocess.CalledProcessError as exc: # TODO: always set check flag to false and capture command result normally result = self._get_failing_command_result(command) - raise RuntimeError(f"Command: {command}\nError:\n" - f"return code: {exc.returncode}\n" - f"output: {exc.output}") from exc + raise RuntimeError( + f"Command: {command}\nError:\n" + f"return code: {exc.returncode}\n" + f"output: {exc.output}" + ) from exc except OSError as exc: raise RuntimeError(f"Command: {command}\nOutput: {exc.strerror}") from exc except Exception as exc: @@ -106,14 +117,11 @@ class LocalShell(Shell): def _get_failing_command_result(self, command: str) -> CommandResult: return_code, cmd_output = subprocess.getstatusoutput(command) - return CommandResult( - stdout=cmd_output, - stderr="", - return_code=return_code - ) + return CommandResult(stdout=cmd_output, stderr="", return_code=return_code) - def _get_pexpect_process_result(self, command_process: Optional[pexpect.spawn], - command: str) -> CommandResult: + def _get_pexpect_process_result( + self, command_process: Optional[pexpect.spawn], command: str + ) -> CommandResult: """ If command process is not None, captures output of this process. If command process is None, then command fails when we attempt to start it, in this case @@ -137,14 +145,20 @@ class LocalShell(Shell): return CommandResult(stdout=output, stderr="", return_code=return_code) - def _report_command_result(self, command: str, start_time: datetime, end_time: datetime, - result: Optional[CommandResult]) -> None: + def _report_command_result( + self, + command: str, + start_time: datetime, + end_time: datetime, + result: Optional[CommandResult], + ) -> None: # TODO: increase logging level if return code is non 0, should be warning at least logger.info( f"Command: {command}\n" f"{'Success:' if result and result.return_code == 0 else 'Error:'}\n" f"return code: {result.return_code if result else ''} " - f"\nOutput: {result.stdout if result else ''}") + f"\nOutput: {result.stdout if result else ''}" + ) if result: elapsed_time = end_time - start_time diff --git a/shell/ssh_shell.py b/shell/ssh_shell.py new file mode 100644 index 0000000..5a272a3 --- /dev/null +++ b/shell/ssh_shell.py @@ -0,0 +1,239 @@ +import logging +import socket +import textwrap +from datetime import datetime +from functools import lru_cache, wraps +from time import sleep +from typing import ClassVar, Optional + +from paramiko import ( + AutoAddPolicy, + ECDSAKey, + Ed25519Key, + PKey, + RSAKey, + SSHClient, + SSHException, + ssh_exception, +) +from paramiko.ssh_exception import AuthenticationException + +from reporter import get_reporter +from shell.interfaces import CommandOptions, CommandResult, Shell + +logger = logging.getLogger("neofs.testlib.shell") +reporter = get_reporter() + + +class HostIsNotAvailable(Exception): + """Raised when host is not reachable via SSH connection""" + + def __init__(self, host: str = None): + msg = f"Host {host} is not available" + super().__init__(msg) + + +def log_command(func): + @wraps(func) + def wrapper(shell: "SSHShell", command: str, *args, **kwargs) -> CommandResult: + command_info = command.removeprefix("$ProgressPreference='SilentlyContinue'\n") + with reporter.step(command_info): + logging.info(f'Execute command "{command}" on "{shell.host}"') + + start_time = datetime.utcnow() + result = func(shell, command, *args, **kwargs) + end_time = datetime.utcnow() + + elapsed_time = end_time - start_time + log_message = ( + f"HOST: {shell.host}\n" + f"COMMAND:\n{textwrap.indent(command, ' ')}\n" + f"RC:\n {result.return_code}\n" + f"STDOUT:\n{textwrap.indent(result.stdout, ' ')}\n" + f"STDERR:\n{textwrap.indent(result.stderr, ' ')}\n" + f"Start / End / Elapsed\t {start_time.time()} / {end_time.time()} / {elapsed_time}" + ) + + logger.info(log_message) + reporter.attach(log_message, "SSH command.txt") + return result + + return wrapper + + +@lru_cache +def _load_private_key(file_path: str, password: Optional[str]) -> PKey: + """ + Loads private key from specified file. + + We support several type formats, however paramiko doesn't provide functionality to determine + key type in advance. So we attempt to load file with each of the supported formats and then + cache the result so that we don't need to figure out type again on subsequent calls. + """ + logger.debug(f"Loading ssh key from {file_path}") + for key_type in (Ed25519Key, ECDSAKey, RSAKey): + try: + return key_type.from_private_key_file(file_path, password) + except SSHException as ex: + logger.warn(f"SSH key {file_path} can't be loaded with {key_type}: {ex}") + continue + raise SSHException(f"SSH key {file_path} is not supported") + + +class SSHShell(Shell): + """ + Implements command shell on a remote machine via SSH connection. + """ + + # Time in seconds to delay after remote command has completed. The delay is required + # to allow remote command to flush its output buffer + DELAY_AFTER_EXIT = 0.2 + + SSH_CONNECTION_ATTEMPTS: ClassVar[int] = 3 + CONNECTION_TIMEOUT = 90 + + def __init__( + self, + host: str, + login: str, + password: Optional[str] = None, + private_key_path: Optional[str] = None, + private_key_passphrase: Optional[str] = None, + port: str = "22", + ) -> None: + self.host = host + self.port = port + self.login = login + self.password = password + self.private_key_path = private_key_path + self.private_key_passphrase = private_key_passphrase + self.__connection: Optional[SSHClient] = None + + @property + def _connection(self): + if not self.__connection: + self.__connection = self._create_connection() + return self.__connection + + def drop(self): + self._reset_connection() + + def exec(self, command: str, options: Optional[CommandOptions] = None) -> CommandResult: + options = options or CommandOptions() + + if options.interactive_inputs: + result = self._exec_interactive(command, options) + else: + result = self._exec_non_interactive(command, options) + + if options.check and result.return_code != 0: + raise RuntimeError( + f"Command: {command}\nreturn code: {result.return_code}" + f"\nOutput: {result.stdout}" + ) + return result + + @log_command + def _exec_interactive(self, command: str, options: CommandOptions) -> CommandResult: + stdin, stdout, stderr = self._connection.exec_command(command, timeout=options.timeout) + for interactive_input in options.interactive_inputs: + input = interactive_input.input + if not input.endswith("\n"): + input = f"{input}\n" + try: + stdin.write(input) + except OSError: + logger.exception(f"Error while feeding {input} into command {command}") + # stdin.close() + + # Wait for command to complete and flush its buffer before we attempt to read output + sleep(self.DELAY_AFTER_EXIT) + return_code = stdout.channel.recv_exit_status() + sleep(self.DELAY_AFTER_EXIT) + + result = CommandResult( + stdout=stdout.read().decode(errors="ignore"), + stderr=stderr.read().decode(errors="ignore"), + return_code=return_code, + ) + return result + + @log_command + def _exec_non_interactive(self, command: str, options: CommandOptions) -> CommandResult: + try: + _, stdout, stderr = self._connection.exec_command(command, timeout=options.timeout) + + # Wait for command to complete and flush its buffer before we attempt to read output + return_code = stdout.channel.recv_exit_status() + sleep(self.DELAY_AFTER_EXIT) + + return CommandResult( + stdout=stdout.read().decode(errors="ignore"), + stderr=stderr.read().decode(errors="ignore"), + return_code=return_code, + ) + except ( + SSHException, + TimeoutError, + ssh_exception.NoValidConnectionsError, + ConnectionResetError, + AttributeError, + socket.timeout, + ) as exc: + logger.exception(f"Can't execute command {command} on host: {self.host}") + self._reset_connection() + raise HostIsNotAvailable(self.host) from exc + + def _create_connection(self, attempts: int = SSH_CONNECTION_ATTEMPTS) -> SSHClient: + for attempt in range(attempts): + connection = SSHClient() + connection.set_missing_host_key_policy(AutoAddPolicy()) + try: + if self.private_key_path: + logging.info( + f"Trying to connect to host {self.host} as {self.login} using SSH key " + f"{self.private_key_path} (attempt {attempt})" + ) + connection.connect( + hostname=self.host, + port=self.port, + username=self.login, + pkey=_load_private_key(self.private_key_path, self.private_key_passphrase), + timeout=self.CONNECTION_TIMEOUT, + ) + else: + logging.info( + f"Trying to connect to host {self.host} as {self.login} using password " + f"(attempt {attempt})" + ) + connection.connect( + hostname=self.host, + port=self.port, + username=self.login, + password=self.password, + timeout=self.CONNECTION_TIMEOUT, + ) + return connection + except AuthenticationException: + connection.close() + logger.exception(f"Can't connect to host {self.host}") + raise + except ( + SSHException, + ssh_exception.NoValidConnectionsError, + AttributeError, + socket.timeout, + OSError, + ) as exc: + connection.close() + can_retry = attempt + 1 < attempts + if can_retry: + logger.warn(f"Can't connect to host {self.host}, will retry. Error: {exc}") + continue + logger.exception(f"Can't connect to host {self.host}") + raise HostIsNotAvailable(self.host) from exc + + def _reset_connection(self) -> None: + if self.__connection: + self.__connection.close() + self.__connection = None diff --git a/tests/helpers.py b/tests/helpers.py index 6035651..b80be61 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -1,9 +1,30 @@ import traceback +from shell.interfaces import CommandResult + def format_error_details(error: Exception) -> str: - return "".join(traceback.format_exception( + """ + Converts specified exception instance into a string that includes error message + and full stack trace. + + :param Exception error: exception to convert. + :return: string containing exception details. + """ + detail_lines = traceback.format_exception( etype=type(error), value=error, - tb=error.__traceback__) + tb=error.__traceback__, ) + return "".join(detail_lines) + + +def get_output_lines(result: CommandResult) -> list[str]: + """ + Converts output of specified command result into separate lines trimmed from whitespaces. + Empty lines are excluded. + + :param CommandResult result: result which output should be converted. + :return: list of lines extracted from the output. + """ + return [line.strip() for line in result.stdout.split("\n") if line.strip()] diff --git a/tests/test_local_shell_interactive.py b/tests/test_local_shell.py similarity index 53% rename from tests/test_local_shell_interactive.py rename to tests/test_local_shell.py index 278d3b1..52e3861 100644 --- a/tests/test_local_shell_interactive.py +++ b/tests/test_local_shell.py @@ -2,7 +2,7 @@ from unittest import TestCase from shell.interfaces import CommandOptions, InteractiveInput from shell.local_shell import LocalShell -from tests.helpers import format_error_details +from tests.helpers import format_error_details, get_output_lines class TestLocalShellInteractive(TestCase): @@ -15,12 +15,11 @@ class TestLocalShellInteractive(TestCase): inputs = [InteractiveInput(prompt_pattern="Password", input="test")] result = self.shell.exec( - f"python -c \"{script}\"", - CommandOptions(interactive_inputs=inputs) + f'python3 -c "{script}"', CommandOptions(interactive_inputs=inputs) ) self.assertEqual(0, result.return_code) - self.assertOutputLines(["Password: test", "test"], result.stdout) + self.assertEqual(["Password: test", "test"], get_output_lines(result)) self.assertEqual("", result.stderr) def test_command_with_several_prompts(self): @@ -34,12 +33,13 @@ class TestLocalShellInteractive(TestCase): ] result = self.shell.exec( - f"python -c \"{script}\"", - CommandOptions(interactive_inputs=inputs) + f'python3 -c "{script}"', CommandOptions(interactive_inputs=inputs) ) self.assertEqual(0, result.return_code) - self.assertOutputLines(["Input1: test1", "test1", "Input2: test2", "test2"], result.stdout) + self.assertEqual( + ["Input1: test1", "test1", "Input2: test2", "test2"], get_output_lines(result) + ) self.assertEqual("", result.stderr) def test_failed_command_with_check(self): @@ -47,7 +47,7 @@ class TestLocalShellInteractive(TestCase): inputs = [InteractiveInput(prompt_pattern=".*", input="test")] with self.assertRaises(RuntimeError) as exc: - self.shell.exec(f"python -c \"{script}\"", CommandOptions(interactive_inputs=inputs)) + self.shell.exec(f'python3 -c "{script}"', CommandOptions(interactive_inputs=inputs)) error = format_error_details(exc.exception) self.assertIn("Error", error) @@ -59,7 +59,7 @@ class TestLocalShellInteractive(TestCase): inputs = [InteractiveInput(prompt_pattern=".*", input="test")] result = self.shell.exec( - f"python -c \"{script}\"", + f'python3 -c "{script}"', CommandOptions(interactive_inputs=inputs, check=False), ) self.assertEqual(1, result.return_code) @@ -71,8 +71,44 @@ class TestLocalShellInteractive(TestCase): self.shell.exec("not-a-command", CommandOptions(interactive_inputs=inputs)) error = format_error_details(exc.exception) - self.assertIn("command was not found or was not executable", error) + self.assertIn("return code: 127", error) - def assertOutputLines(self, expected_lines: list[str], output: str) -> None: - output_lines = [line.strip() for line in output.split("\n") if line.strip()] - self.assertEqual(expected_lines, output_lines) + +class TestLocalShellNonInteractive(TestCase): + @classmethod + def setUpClass(cls): + cls.shell = LocalShell() + + def test_successful_command(self): + script = "print('test')" + + result = self.shell.exec(f'python3 -c "{script}"') + + self.assertEqual(0, result.return_code) + self.assertEqual("test", result.stdout.strip()) + self.assertEqual("", result.stderr) + + def test_invalid_command_with_check(self): + script = "invalid script" + + with self.assertRaises(RuntimeError) as exc: + self.shell.exec(f'python3 -c "{script}"') + + error = format_error_details(exc.exception) + self.assertIn("Error", error) + self.assertIn("return code: 1", error) + + def test_invalid_command_without_check(self): + script = "invalid script" + + result = self.shell.exec(f'python3 -c "{script}"', CommandOptions(check=False)) + + self.assertEqual(1, result.return_code) + self.assertIn("Error", result.stdout) + + def test_non_existing_binary(self): + with self.assertRaises(RuntimeError) as exc: + self.shell.exec("not-a-command") + + error = format_error_details(exc.exception) + self.assertIn("return code: 127", error) diff --git a/tests/test_local_shell_non_interactive.py b/tests/test_local_shell_non_interactive.py deleted file mode 100644 index 3fe04f5..0000000 --- a/tests/test_local_shell_non_interactive.py +++ /dev/null @@ -1,46 +0,0 @@ -from unittest import TestCase - -from shell.interfaces import CommandOptions -from shell.local_shell import LocalShell -from tests.helpers import format_error_details - - -class TestLocalShellNonInteractive(TestCase): - @classmethod - def setUpClass(cls): - cls.shell = LocalShell() - - def test_successful_command(self): - script = "print('test')" - - result = self.shell.exec(f"python -c \"{script}\"") - - self.assertEqual(0, result.return_code) - self.assertEqual("test", result.stdout.strip()) - self.assertEqual("", result.stderr) - - def test_failed_command_with_check(self): - script = "invalid script" - - with self.assertRaises(RuntimeError) as exc: - self.shell.exec(f"python -c \"{script}\"") - - error = format_error_details(exc.exception) - self.assertIn("Error", error) - self.assertIn("return code: 1", error) - - def test_failed_command_without_check(self): - script = "invalid script" - - result = self.shell.exec(f"python -c \"{script}\"", CommandOptions(check=False)) - - self.assertEqual(1, result.return_code) - self.assertIn("Error", result.stdout) - - def test_non_existing_binary(self): - with self.assertRaises(RuntimeError) as exc: - self.shell.exec(f"not-a-command") - - error = format_error_details(exc.exception) - self.assertIn("Error", error) - self.assertIn("return code: 127", error) diff --git a/tests/test_ssh_shell.py b/tests/test_ssh_shell.py new file mode 100644 index 0000000..213b7cf --- /dev/null +++ b/tests/test_ssh_shell.py @@ -0,0 +1,138 @@ +import os +from unittest import SkipTest, TestCase + +from shell.interfaces import CommandOptions, InteractiveInput +from shell.ssh_shell import SSHShell +from tests.helpers import format_error_details, get_output_lines + + +def init_shell() -> SSHShell: + host = os.getenv("SSH_SHELL_HOST") + port = os.getenv("SSH_SHELL_PORT", "22") + login = os.getenv("SSH_SHELL_LOGIN") + private_key_path = os.getenv("SSH_SHELL_PRIVATE_KEY_PATH") + private_key_passphrase = os.getenv("SSH_SHELL_PRIVATE_KEY_PASSPHRASE") + + if not all([host, login, private_key_path, private_key_passphrase]): + # TODO: in the future we might use https://pypi.org/project/mock-ssh-server, + # at the moment it is not suitable for us because of its issues with stdin + raise SkipTest("SSH connection is not configured") + + return SSHShell( + host=host, + port=port, + login=login, + private_key_path=private_key_path, + private_key_passphrase=private_key_passphrase, + ) + + +class TestSSHShellInteractive(TestCase): + @classmethod + def setUpClass(cls): + cls.shell = init_shell() + + def test_command_with_one_prompt(self): + script = "password = input('Password: '); print('\\n' + password)" + + inputs = [InteractiveInput(prompt_pattern="Password", input="test")] + result = self.shell.exec( + f'python3 -c "{script}"', CommandOptions(interactive_inputs=inputs) + ) + + # TODO: we have inconsistency with local shell here, ssh does not echo input into stdout + self.assertEqual(0, result.return_code) + self.assertEqual(["Password:", "test"], get_output_lines(result)) + self.assertEqual("", result.stderr) + + def test_command_with_several_prompts(self): + script = ( + "input1 = input('Input1: '); print('\\n' + input1); " + "input2 = input('Input2: '); print('\\n' + input2)" + ) + inputs = [ + InteractiveInput(prompt_pattern="Input1", input="test1"), + InteractiveInput(prompt_pattern="Input2", input="test2"), + ] + + result = self.shell.exec( + f'python3 -c "{script}"', CommandOptions(interactive_inputs=inputs) + ) + + # TODO: we have inconsistency with local shell here, ssh does not echo input into stdout + self.assertEqual(0, result.return_code) + self.assertEqual(["Input1:", "test1", "Input2:", "test2"], get_output_lines(result)) + self.assertEqual("", result.stderr) + + def test_invalid_command_with_check(self): + script = "invalid script" + inputs = [InteractiveInput(prompt_pattern=".*", input="test")] + + with self.assertRaises(RuntimeError) as raised: + self.shell.exec(f'python3 -c "{script}"', CommandOptions(interactive_inputs=inputs)) + + error = format_error_details(raised.exception) + self.assertIn("Error", error) + self.assertIn("return code: 1", error) + + def test_invalid_command_without_check(self): + script = "invalid script" + inputs = [InteractiveInput(prompt_pattern=".*", input="test")] + + result = self.shell.exec( + f'python3 -c "{script}"', + CommandOptions(interactive_inputs=inputs, check=False), + ) + self.assertEqual(1, result.return_code) + + def test_non_existing_binary(self): + inputs = [InteractiveInput(prompt_pattern=".*", input="test")] + + with self.assertRaises(RuntimeError) as raised: + self.shell.exec("not-a-command", CommandOptions(interactive_inputs=inputs)) + + error = format_error_details(raised.exception) + self.assertIn("return code: 127", error) + + +class TestSSHShellNonInteractive(TestCase): + @classmethod + def setUpClass(cls): + cls.shell = init_shell() + + def test_correct_command(self): + script = "print('test')" + + result = self.shell.exec(f'python3 -c "{script}"') + + self.assertEqual(0, result.return_code) + self.assertEqual("test", result.stdout.strip()) + self.assertEqual("", result.stderr) + + def test_invalid_command_with_check(self): + script = "invalid script" + + with self.assertRaises(RuntimeError) as raised: + self.shell.exec(f'python3 -c "{script}"') + + error = format_error_details(raised.exception) + self.assertIn("Error", error) + self.assertIn("return code: 1", error) + + def test_invalid_command_without_check(self): + script = "invalid script" + + result = self.shell.exec(f'python3 -c "{script}"', CommandOptions(check=False)) + + self.assertEqual(1, result.return_code) + # TODO: we have inconsistency with local shell here, the local shell captures error info + # in stdout while ssh shell captures it in stderr + self.assertIn("Error", result.stderr) + + def test_non_existing_binary(self): + with self.assertRaises(RuntimeError) as exc: + self.shell.exec("not-a-command") + + error = format_error_details(exc.exception) + self.assertIn("Error", error) + self.assertIn("return code: 127", error)