Implement basic version of ssh shell

Signed-off-by: Vladimir Domnich <v.domnich@yadro.com>
This commit is contained in:
Vladimir Domnich 2022-08-24 15:41:11 +04:00 committed by anastasia prasolova
parent f6ee129354
commit d3e5ee2231
16 changed files with 525 additions and 92 deletions

3
.gitignore vendored
View file

@ -3,6 +3,3 @@
# ignore caches under any path # ignore caches under any path
**/__pycache__ **/__pycache__
# ignore virtual environments
venv*/*

11
.pre-commit-config.yaml Normal file
View file

@ -0,0 +1,11 @@
repos:
- repo: https://github.com/psf/black
rev: 22.8.0
hooks:
- id: black
language_version: python3.9
- repo: https://github.com/pycqa/isort
rev: 5.10.1
hooks:
- id: isort
name: isort (python)

View file

@ -25,6 +25,16 @@ $ source venv/bin/activate
$ pip install -r requirements.txt $ pip install -r requirements.txt
``` ```
3. Setup pre-commit hooks to run code formatters on staged files before you run a `git commit` command:
```
pre-commit install
```
Optionally you might want to integrate code formatters with your code editor to apply formatters to code files as you go:
* isort is supported by [PyCharm](https://plugins.jetbrains.com/plugin/15434-isortconnect), [VS Code](https://cereblanco.medium.com/setup-black-and-isort-in-vscode-514804590bf9). Plugins exist for other IDEs/editors as well.
* black can be integrated with multiple editors, please, instructions are available [here](https://black.readthedocs.io/en/stable/integrations/editors.html).
### Unit Tests ### Unit Tests
Before submitting any changes to the library, please, make sure that all unit tests are passing. To run the tests, please, use the following command: Before submitting any changes to the library, please, make sure that all unit tests are passing. To run the tests, please, use the following command:
``` ```

8
pyproject.toml Normal file
View file

@ -0,0 +1,8 @@
[tool.isort]
profile = "black"
src_paths = ["reporter", "shell", "tests"]
line_length = 100
[tool.black]
line-length = 100
target-version = ["py39"]

View file

@ -1,8 +1,8 @@
import os import os
from .allure_reporter import AllureReporter from reporter.allure_reporter import AllureReporter
from .interfaces import Reporter from reporter.dummy_reporter import DummyReporter
from .dummy_reporter import DummyReporter from reporter.interfaces import Reporter
def get_reporter() -> Reporter: def get_reporter() -> Reporter:

View file

@ -6,7 +6,7 @@ from typing import Any
import allure import allure
from allure import attachment_type from allure import attachment_type
from .interfaces import Reporter from reporter.interfaces import Reporter
class AllureReporter(Reporter): class AllureReporter(Reporter):

View file

@ -1,7 +1,7 @@
from contextlib import AbstractContextManager, contextmanager from contextlib import AbstractContextManager, contextmanager
from typing import Any from typing import Any
from .interfaces import Reporter from reporter.interfaces import Reporter
@contextmanager @contextmanager

View file

@ -16,14 +16,13 @@ class Reporter(ABC):
:param str name: Name of the step :param str name: Name of the step
:return: step context :return: step context
""" """
pass
@abstractmethod @abstractmethod
def attach(self, content: Any, file_name: str) -> None: def attach(self, content: Any, file_name: str) -> None:
""" """
Attach specified content with given file name to the test report. Attach specified content with given file name to the test report.
:param any name: content to attach. If not a string, it will be converted to a string. :param any content: content to attach. If content value is not a string, it will be
converted to a string.
:param str file_name: file name of attachment. :param str file_name: file name of attachment.
""" """
pass

View file

@ -1,2 +1,6 @@
allure-python-commons==2.9.45 allure-python-commons==2.9.45
black==22.8.0
isort==5.10.1
paramiko==2.10.3
pexpect==4.8.0 pexpect==4.8.0
pre-commit==2.20.0

View file

@ -11,6 +11,7 @@ class InteractiveInput:
:attr str prompt_pattern: regular expression that defines expected prompt from the command. :attr str prompt_pattern: regular expression that defines expected prompt from the command.
:attr str input: user input that should be supplied to the command in response to the prompt. :attr str input: user input that should be supplied to the command in response to the prompt.
""" """
prompt_pattern: str prompt_pattern: str
input: str input: str
@ -21,11 +22,12 @@ class CommandOptions:
Options that control command execution. Options that control command execution.
:attr list interactive_inputs: user inputs that should be interactively supplied to :attr list interactive_inputs: user inputs that should be interactively supplied to
the command during its' execution. the command during execution.
:attr int timeout: timeout for command execution (in seconds). :attr int timeout: timeout for command execution (in seconds).
:attr bool check: controls whether to check return code of the command. Set to False to :attr bool check: controls whether to check return code of the command. Set to False to
ignore non-zero return codes. ignore non-zero return codes.
""" """
interactive_inputs: Optional[list[InteractiveInput]] = None interactive_inputs: Optional[list[InteractiveInput]] = None
timeout: int = 30 timeout: int = 30
check: bool = True check: bool = True
@ -36,6 +38,7 @@ class CommandResult:
""" """
Represents a result of a command executed via shell. Represents a result of a command executed via shell.
""" """
stdout: str stdout: str
stderr: str stderr: str
return_code: int return_code: int
@ -56,4 +59,3 @@ class Shell(ABC):
:param CommandOptions options: options that control command execution. :param CommandOptions options: options that control command execution.
:return command result. :return command result.
""" """
pass

View file

@ -9,12 +9,15 @@ import pexpect
from reporter import get_reporter from reporter import get_reporter
from shell.interfaces import CommandOptions, CommandResult, Shell from shell.interfaces import CommandOptions, CommandResult, Shell
logger = logging.getLogger("neofs.testlib.shell") logger = logging.getLogger("neofs.testlib.shell")
reporter = get_reporter() reporter = get_reporter()
class LocalShell(Shell): class LocalShell(Shell):
"""
Implements command shell on a local machine.
"""
def exec(self, command: str, options: Optional[CommandOptions] = None) -> CommandResult: def exec(self, command: str, options: Optional[CommandOptions] = None) -> CommandResult:
# If no options were provided, use default options # If no options were provided, use default options
options = options or CommandOptions() options = options or CommandOptions()
@ -41,12 +44,16 @@ class LocalShell(Shell):
result = self._get_pexpect_process_result(command_process, command) result = self._get_pexpect_process_result(command_process, command)
if options.check and result.return_code != 0: if options.check and result.return_code != 0:
raise RuntimeError(f"Command: {command}\nreturn code: {result.return_code}\nOutput: {result.stdout}") raise RuntimeError(
f"Command: {command}\nreturn code: {result.return_code}\nOutput: {result.stdout}"
)
return result return result
except pexpect.ExceptionPexpect as exc: except pexpect.ExceptionPexpect as exc:
result = self._get_pexpect_process_result(command_process, command) result = self._get_pexpect_process_result(command_process, command)
message = f"Command: {command}\nreturn code: {result.return_code}\nOutput: {result.stdout}" message = (
f"Command: {command}\nreturn code: {result.return_code}\nOutput: {result.stdout}"
)
if options.check: if options.check:
raise RuntimeError(message) from exc raise RuntimeError(message) from exc
else: else:
@ -54,7 +61,9 @@ class LocalShell(Shell):
return result return result
except OSError as exc: except OSError as exc:
result = self._get_pexpect_process_result(command_process, command) result = self._get_pexpect_process_result(command_process, command)
message = f"Command: {command}\nreturn code: {result.return_code}\nOutput: {exc.strerror}" message = (
f"Command: {command}\nreturn code: {result.return_code}\nOutput: {exc.strerror}"
)
if options.check: if options.check:
raise RuntimeError(message) from exc raise RuntimeError(message) from exc
else: else:
@ -80,7 +89,7 @@ class LocalShell(Shell):
stdout=subprocess.PIPE, stdout=subprocess.PIPE,
stderr=subprocess.STDOUT, stderr=subprocess.STDOUT,
timeout=options.timeout, timeout=options.timeout,
shell=True shell=True,
) )
result = CommandResult( result = CommandResult(
@ -92,9 +101,11 @@ class LocalShell(Shell):
except subprocess.CalledProcessError as exc: except subprocess.CalledProcessError as exc:
# TODO: always set check flag to false and capture command result normally # TODO: always set check flag to false and capture command result normally
result = self._get_failing_command_result(command) result = self._get_failing_command_result(command)
raise RuntimeError(f"Command: {command}\nError:\n" raise RuntimeError(
f"return code: {exc.returncode}\n" f"Command: {command}\nError:\n"
f"output: {exc.output}") from exc f"return code: {exc.returncode}\n"
f"output: {exc.output}"
) from exc
except OSError as exc: except OSError as exc:
raise RuntimeError(f"Command: {command}\nOutput: {exc.strerror}") from exc raise RuntimeError(f"Command: {command}\nOutput: {exc.strerror}") from exc
except Exception as exc: except Exception as exc:
@ -106,14 +117,11 @@ class LocalShell(Shell):
def _get_failing_command_result(self, command: str) -> CommandResult: def _get_failing_command_result(self, command: str) -> CommandResult:
return_code, cmd_output = subprocess.getstatusoutput(command) return_code, cmd_output = subprocess.getstatusoutput(command)
return CommandResult( return CommandResult(stdout=cmd_output, stderr="", return_code=return_code)
stdout=cmd_output,
stderr="",
return_code=return_code
)
def _get_pexpect_process_result(self, command_process: Optional[pexpect.spawn], def _get_pexpect_process_result(
command: str) -> CommandResult: self, command_process: Optional[pexpect.spawn], command: str
) -> CommandResult:
""" """
If command process is not None, captures output of this process. If command process is not None, captures output of this process.
If command process is None, then command fails when we attempt to start it, in this case If command process is None, then command fails when we attempt to start it, in this case
@ -137,14 +145,20 @@ class LocalShell(Shell):
return CommandResult(stdout=output, stderr="", return_code=return_code) return CommandResult(stdout=output, stderr="", return_code=return_code)
def _report_command_result(self, command: str, start_time: datetime, end_time: datetime, def _report_command_result(
result: Optional[CommandResult]) -> None: self,
command: str,
start_time: datetime,
end_time: datetime,
result: Optional[CommandResult],
) -> None:
# TODO: increase logging level if return code is non 0, should be warning at least # TODO: increase logging level if return code is non 0, should be warning at least
logger.info( logger.info(
f"Command: {command}\n" f"Command: {command}\n"
f"{'Success:' if result and result.return_code == 0 else 'Error:'}\n" f"{'Success:' if result and result.return_code == 0 else 'Error:'}\n"
f"return code: {result.return_code if result else ''} " f"return code: {result.return_code if result else ''} "
f"\nOutput: {result.stdout if result else ''}") f"\nOutput: {result.stdout if result else ''}"
)
if result: if result:
elapsed_time = end_time - start_time elapsed_time = end_time - start_time

239
shell/ssh_shell.py Normal file
View file

@ -0,0 +1,239 @@
import logging
import socket
import textwrap
from datetime import datetime
from functools import lru_cache, wraps
from time import sleep
from typing import ClassVar, Optional
from paramiko import (
AutoAddPolicy,
ECDSAKey,
Ed25519Key,
PKey,
RSAKey,
SSHClient,
SSHException,
ssh_exception,
)
from paramiko.ssh_exception import AuthenticationException
from reporter import get_reporter
from shell.interfaces import CommandOptions, CommandResult, Shell
logger = logging.getLogger("neofs.testlib.shell")
reporter = get_reporter()
class HostIsNotAvailable(Exception):
"""Raised when host is not reachable via SSH connection"""
def __init__(self, host: str = None):
msg = f"Host {host} is not available"
super().__init__(msg)
def log_command(func):
@wraps(func)
def wrapper(shell: "SSHShell", command: str, *args, **kwargs) -> CommandResult:
command_info = command.removeprefix("$ProgressPreference='SilentlyContinue'\n")
with reporter.step(command_info):
logging.info(f'Execute command "{command}" on "{shell.host}"')
start_time = datetime.utcnow()
result = func(shell, command, *args, **kwargs)
end_time = datetime.utcnow()
elapsed_time = end_time - start_time
log_message = (
f"HOST: {shell.host}\n"
f"COMMAND:\n{textwrap.indent(command, ' ')}\n"
f"RC:\n {result.return_code}\n"
f"STDOUT:\n{textwrap.indent(result.stdout, ' ')}\n"
f"STDERR:\n{textwrap.indent(result.stderr, ' ')}\n"
f"Start / End / Elapsed\t {start_time.time()} / {end_time.time()} / {elapsed_time}"
)
logger.info(log_message)
reporter.attach(log_message, "SSH command.txt")
return result
return wrapper
@lru_cache
def _load_private_key(file_path: str, password: Optional[str]) -> PKey:
"""
Loads private key from specified file.
We support several type formats, however paramiko doesn't provide functionality to determine
key type in advance. So we attempt to load file with each of the supported formats and then
cache the result so that we don't need to figure out type again on subsequent calls.
"""
logger.debug(f"Loading ssh key from {file_path}")
for key_type in (Ed25519Key, ECDSAKey, RSAKey):
try:
return key_type.from_private_key_file(file_path, password)
except SSHException as ex:
logger.warn(f"SSH key {file_path} can't be loaded with {key_type}: {ex}")
continue
raise SSHException(f"SSH key {file_path} is not supported")
class SSHShell(Shell):
"""
Implements command shell on a remote machine via SSH connection.
"""
# Time in seconds to delay after remote command has completed. The delay is required
# to allow remote command to flush its output buffer
DELAY_AFTER_EXIT = 0.2
SSH_CONNECTION_ATTEMPTS: ClassVar[int] = 3
CONNECTION_TIMEOUT = 90
def __init__(
self,
host: str,
login: str,
password: Optional[str] = None,
private_key_path: Optional[str] = None,
private_key_passphrase: Optional[str] = None,
port: str = "22",
) -> None:
self.host = host
self.port = port
self.login = login
self.password = password
self.private_key_path = private_key_path
self.private_key_passphrase = private_key_passphrase
self.__connection: Optional[SSHClient] = None
@property
def _connection(self):
if not self.__connection:
self.__connection = self._create_connection()
return self.__connection
def drop(self):
self._reset_connection()
def exec(self, command: str, options: Optional[CommandOptions] = None) -> CommandResult:
options = options or CommandOptions()
if options.interactive_inputs:
result = self._exec_interactive(command, options)
else:
result = self._exec_non_interactive(command, options)
if options.check and result.return_code != 0:
raise RuntimeError(
f"Command: {command}\nreturn code: {result.return_code}"
f"\nOutput: {result.stdout}"
)
return result
@log_command
def _exec_interactive(self, command: str, options: CommandOptions) -> CommandResult:
stdin, stdout, stderr = self._connection.exec_command(command, timeout=options.timeout)
for interactive_input in options.interactive_inputs:
input = interactive_input.input
if not input.endswith("\n"):
input = f"{input}\n"
try:
stdin.write(input)
except OSError:
logger.exception(f"Error while feeding {input} into command {command}")
# stdin.close()
# Wait for command to complete and flush its buffer before we attempt to read output
sleep(self.DELAY_AFTER_EXIT)
return_code = stdout.channel.recv_exit_status()
sleep(self.DELAY_AFTER_EXIT)
result = CommandResult(
stdout=stdout.read().decode(errors="ignore"),
stderr=stderr.read().decode(errors="ignore"),
return_code=return_code,
)
return result
@log_command
def _exec_non_interactive(self, command: str, options: CommandOptions) -> CommandResult:
try:
_, stdout, stderr = self._connection.exec_command(command, timeout=options.timeout)
# Wait for command to complete and flush its buffer before we attempt to read output
return_code = stdout.channel.recv_exit_status()
sleep(self.DELAY_AFTER_EXIT)
return CommandResult(
stdout=stdout.read().decode(errors="ignore"),
stderr=stderr.read().decode(errors="ignore"),
return_code=return_code,
)
except (
SSHException,
TimeoutError,
ssh_exception.NoValidConnectionsError,
ConnectionResetError,
AttributeError,
socket.timeout,
) as exc:
logger.exception(f"Can't execute command {command} on host: {self.host}")
self._reset_connection()
raise HostIsNotAvailable(self.host) from exc
def _create_connection(self, attempts: int = SSH_CONNECTION_ATTEMPTS) -> SSHClient:
for attempt in range(attempts):
connection = SSHClient()
connection.set_missing_host_key_policy(AutoAddPolicy())
try:
if self.private_key_path:
logging.info(
f"Trying to connect to host {self.host} as {self.login} using SSH key "
f"{self.private_key_path} (attempt {attempt})"
)
connection.connect(
hostname=self.host,
port=self.port,
username=self.login,
pkey=_load_private_key(self.private_key_path, self.private_key_passphrase),
timeout=self.CONNECTION_TIMEOUT,
)
else:
logging.info(
f"Trying to connect to host {self.host} as {self.login} using password "
f"(attempt {attempt})"
)
connection.connect(
hostname=self.host,
port=self.port,
username=self.login,
password=self.password,
timeout=self.CONNECTION_TIMEOUT,
)
return connection
except AuthenticationException:
connection.close()
logger.exception(f"Can't connect to host {self.host}")
raise
except (
SSHException,
ssh_exception.NoValidConnectionsError,
AttributeError,
socket.timeout,
OSError,
) as exc:
connection.close()
can_retry = attempt + 1 < attempts
if can_retry:
logger.warn(f"Can't connect to host {self.host}, will retry. Error: {exc}")
continue
logger.exception(f"Can't connect to host {self.host}")
raise HostIsNotAvailable(self.host) from exc
def _reset_connection(self) -> None:
if self.__connection:
self.__connection.close()
self.__connection = None

View file

@ -1,9 +1,30 @@
import traceback import traceback
from shell.interfaces import CommandResult
def format_error_details(error: Exception) -> str: def format_error_details(error: Exception) -> str:
return "".join(traceback.format_exception( """
Converts specified exception instance into a string that includes error message
and full stack trace.
:param Exception error: exception to convert.
:return: string containing exception details.
"""
detail_lines = traceback.format_exception(
etype=type(error), etype=type(error),
value=error, value=error,
tb=error.__traceback__) tb=error.__traceback__,
) )
return "".join(detail_lines)
def get_output_lines(result: CommandResult) -> list[str]:
"""
Converts output of specified command result into separate lines trimmed from whitespaces.
Empty lines are excluded.
:param CommandResult result: result which output should be converted.
:return: list of lines extracted from the output.
"""
return [line.strip() for line in result.stdout.split("\n") if line.strip()]

View file

@ -2,7 +2,7 @@ from unittest import TestCase
from shell.interfaces import CommandOptions, InteractiveInput from shell.interfaces import CommandOptions, InteractiveInput
from shell.local_shell import LocalShell from shell.local_shell import LocalShell
from tests.helpers import format_error_details from tests.helpers import format_error_details, get_output_lines
class TestLocalShellInteractive(TestCase): class TestLocalShellInteractive(TestCase):
@ -15,12 +15,11 @@ class TestLocalShellInteractive(TestCase):
inputs = [InteractiveInput(prompt_pattern="Password", input="test")] inputs = [InteractiveInput(prompt_pattern="Password", input="test")]
result = self.shell.exec( result = self.shell.exec(
f"python -c \"{script}\"", f'python3 -c "{script}"', CommandOptions(interactive_inputs=inputs)
CommandOptions(interactive_inputs=inputs)
) )
self.assertEqual(0, result.return_code) self.assertEqual(0, result.return_code)
self.assertOutputLines(["Password: test", "test"], result.stdout) self.assertEqual(["Password: test", "test"], get_output_lines(result))
self.assertEqual("", result.stderr) self.assertEqual("", result.stderr)
def test_command_with_several_prompts(self): def test_command_with_several_prompts(self):
@ -34,12 +33,13 @@ class TestLocalShellInteractive(TestCase):
] ]
result = self.shell.exec( result = self.shell.exec(
f"python -c \"{script}\"", f'python3 -c "{script}"', CommandOptions(interactive_inputs=inputs)
CommandOptions(interactive_inputs=inputs)
) )
self.assertEqual(0, result.return_code) self.assertEqual(0, result.return_code)
self.assertOutputLines(["Input1: test1", "test1", "Input2: test2", "test2"], result.stdout) self.assertEqual(
["Input1: test1", "test1", "Input2: test2", "test2"], get_output_lines(result)
)
self.assertEqual("", result.stderr) self.assertEqual("", result.stderr)
def test_failed_command_with_check(self): def test_failed_command_with_check(self):
@ -47,7 +47,7 @@ class TestLocalShellInteractive(TestCase):
inputs = [InteractiveInput(prompt_pattern=".*", input="test")] inputs = [InteractiveInput(prompt_pattern=".*", input="test")]
with self.assertRaises(RuntimeError) as exc: with self.assertRaises(RuntimeError) as exc:
self.shell.exec(f"python -c \"{script}\"", CommandOptions(interactive_inputs=inputs)) self.shell.exec(f'python3 -c "{script}"', CommandOptions(interactive_inputs=inputs))
error = format_error_details(exc.exception) error = format_error_details(exc.exception)
self.assertIn("Error", error) self.assertIn("Error", error)
@ -59,7 +59,7 @@ class TestLocalShellInteractive(TestCase):
inputs = [InteractiveInput(prompt_pattern=".*", input="test")] inputs = [InteractiveInput(prompt_pattern=".*", input="test")]
result = self.shell.exec( result = self.shell.exec(
f"python -c \"{script}\"", f'python3 -c "{script}"',
CommandOptions(interactive_inputs=inputs, check=False), CommandOptions(interactive_inputs=inputs, check=False),
) )
self.assertEqual(1, result.return_code) self.assertEqual(1, result.return_code)
@ -71,8 +71,44 @@ class TestLocalShellInteractive(TestCase):
self.shell.exec("not-a-command", CommandOptions(interactive_inputs=inputs)) self.shell.exec("not-a-command", CommandOptions(interactive_inputs=inputs))
error = format_error_details(exc.exception) error = format_error_details(exc.exception)
self.assertIn("command was not found or was not executable", error) self.assertIn("return code: 127", error)
def assertOutputLines(self, expected_lines: list[str], output: str) -> None:
output_lines = [line.strip() for line in output.split("\n") if line.strip()] class TestLocalShellNonInteractive(TestCase):
self.assertEqual(expected_lines, output_lines) @classmethod
def setUpClass(cls):
cls.shell = LocalShell()
def test_successful_command(self):
script = "print('test')"
result = self.shell.exec(f'python3 -c "{script}"')
self.assertEqual(0, result.return_code)
self.assertEqual("test", result.stdout.strip())
self.assertEqual("", result.stderr)
def test_invalid_command_with_check(self):
script = "invalid script"
with self.assertRaises(RuntimeError) as exc:
self.shell.exec(f'python3 -c "{script}"')
error = format_error_details(exc.exception)
self.assertIn("Error", error)
self.assertIn("return code: 1", error)
def test_invalid_command_without_check(self):
script = "invalid script"
result = self.shell.exec(f'python3 -c "{script}"', CommandOptions(check=False))
self.assertEqual(1, result.return_code)
self.assertIn("Error", result.stdout)
def test_non_existing_binary(self):
with self.assertRaises(RuntimeError) as exc:
self.shell.exec("not-a-command")
error = format_error_details(exc.exception)
self.assertIn("return code: 127", error)

View file

@ -1,46 +0,0 @@
from unittest import TestCase
from shell.interfaces import CommandOptions
from shell.local_shell import LocalShell
from tests.helpers import format_error_details
class TestLocalShellNonInteractive(TestCase):
@classmethod
def setUpClass(cls):
cls.shell = LocalShell()
def test_successful_command(self):
script = "print('test')"
result = self.shell.exec(f"python -c \"{script}\"")
self.assertEqual(0, result.return_code)
self.assertEqual("test", result.stdout.strip())
self.assertEqual("", result.stderr)
def test_failed_command_with_check(self):
script = "invalid script"
with self.assertRaises(RuntimeError) as exc:
self.shell.exec(f"python -c \"{script}\"")
error = format_error_details(exc.exception)
self.assertIn("Error", error)
self.assertIn("return code: 1", error)
def test_failed_command_without_check(self):
script = "invalid script"
result = self.shell.exec(f"python -c \"{script}\"", CommandOptions(check=False))
self.assertEqual(1, result.return_code)
self.assertIn("Error", result.stdout)
def test_non_existing_binary(self):
with self.assertRaises(RuntimeError) as exc:
self.shell.exec(f"not-a-command")
error = format_error_details(exc.exception)
self.assertIn("Error", error)
self.assertIn("return code: 127", error)

138
tests/test_ssh_shell.py Normal file
View file

@ -0,0 +1,138 @@
import os
from unittest import SkipTest, TestCase
from shell.interfaces import CommandOptions, InteractiveInput
from shell.ssh_shell import SSHShell
from tests.helpers import format_error_details, get_output_lines
def init_shell() -> SSHShell:
host = os.getenv("SSH_SHELL_HOST")
port = os.getenv("SSH_SHELL_PORT", "22")
login = os.getenv("SSH_SHELL_LOGIN")
private_key_path = os.getenv("SSH_SHELL_PRIVATE_KEY_PATH")
private_key_passphrase = os.getenv("SSH_SHELL_PRIVATE_KEY_PASSPHRASE")
if not all([host, login, private_key_path, private_key_passphrase]):
# TODO: in the future we might use https://pypi.org/project/mock-ssh-server,
# at the moment it is not suitable for us because of its issues with stdin
raise SkipTest("SSH connection is not configured")
return SSHShell(
host=host,
port=port,
login=login,
private_key_path=private_key_path,
private_key_passphrase=private_key_passphrase,
)
class TestSSHShellInteractive(TestCase):
@classmethod
def setUpClass(cls):
cls.shell = init_shell()
def test_command_with_one_prompt(self):
script = "password = input('Password: '); print('\\n' + password)"
inputs = [InteractiveInput(prompt_pattern="Password", input="test")]
result = self.shell.exec(
f'python3 -c "{script}"', CommandOptions(interactive_inputs=inputs)
)
# TODO: we have inconsistency with local shell here, ssh does not echo input into stdout
self.assertEqual(0, result.return_code)
self.assertEqual(["Password:", "test"], get_output_lines(result))
self.assertEqual("", result.stderr)
def test_command_with_several_prompts(self):
script = (
"input1 = input('Input1: '); print('\\n' + input1); "
"input2 = input('Input2: '); print('\\n' + input2)"
)
inputs = [
InteractiveInput(prompt_pattern="Input1", input="test1"),
InteractiveInput(prompt_pattern="Input2", input="test2"),
]
result = self.shell.exec(
f'python3 -c "{script}"', CommandOptions(interactive_inputs=inputs)
)
# TODO: we have inconsistency with local shell here, ssh does not echo input into stdout
self.assertEqual(0, result.return_code)
self.assertEqual(["Input1:", "test1", "Input2:", "test2"], get_output_lines(result))
self.assertEqual("", result.stderr)
def test_invalid_command_with_check(self):
script = "invalid script"
inputs = [InteractiveInput(prompt_pattern=".*", input="test")]
with self.assertRaises(RuntimeError) as raised:
self.shell.exec(f'python3 -c "{script}"', CommandOptions(interactive_inputs=inputs))
error = format_error_details(raised.exception)
self.assertIn("Error", error)
self.assertIn("return code: 1", error)
def test_invalid_command_without_check(self):
script = "invalid script"
inputs = [InteractiveInput(prompt_pattern=".*", input="test")]
result = self.shell.exec(
f'python3 -c "{script}"',
CommandOptions(interactive_inputs=inputs, check=False),
)
self.assertEqual(1, result.return_code)
def test_non_existing_binary(self):
inputs = [InteractiveInput(prompt_pattern=".*", input="test")]
with self.assertRaises(RuntimeError) as raised:
self.shell.exec("not-a-command", CommandOptions(interactive_inputs=inputs))
error = format_error_details(raised.exception)
self.assertIn("return code: 127", error)
class TestSSHShellNonInteractive(TestCase):
@classmethod
def setUpClass(cls):
cls.shell = init_shell()
def test_correct_command(self):
script = "print('test')"
result = self.shell.exec(f'python3 -c "{script}"')
self.assertEqual(0, result.return_code)
self.assertEqual("test", result.stdout.strip())
self.assertEqual("", result.stderr)
def test_invalid_command_with_check(self):
script = "invalid script"
with self.assertRaises(RuntimeError) as raised:
self.shell.exec(f'python3 -c "{script}"')
error = format_error_details(raised.exception)
self.assertIn("Error", error)
self.assertIn("return code: 1", error)
def test_invalid_command_without_check(self):
script = "invalid script"
result = self.shell.exec(f'python3 -c "{script}"', CommandOptions(check=False))
self.assertEqual(1, result.return_code)
# TODO: we have inconsistency with local shell here, the local shell captures error info
# in stdout while ssh shell captures it in stderr
self.assertIn("Error", result.stderr)
def test_non_existing_binary(self):
with self.assertRaises(RuntimeError) as exc:
self.shell.exec("not-a-command")
error = format_error_details(exc.exception)
self.assertIn("Error", error)
self.assertIn("return code: 127", error)