Keep only one ssh connection per host #94
4 changed files with 261 additions and 144 deletions
|
@ -1,3 +1,3 @@
|
||||||
from frostfs_testlib.shell.interfaces import CommandOptions, CommandResult, InteractiveInput, Shell
|
from frostfs_testlib.shell.interfaces import CommandOptions, CommandResult, InteractiveInput, Shell
|
||||||
from frostfs_testlib.shell.local_shell import LocalShell
|
from frostfs_testlib.shell.local_shell import LocalShell
|
||||||
from frostfs_testlib.shell.ssh_shell import SSHShell
|
from frostfs_testlib.shell.ssh_shell import SshConnectionProvider, SSHShell
|
||||||
|
|
|
@ -20,12 +20,117 @@ from paramiko import (
|
||||||
from paramiko.ssh_exception import AuthenticationException
|
from paramiko.ssh_exception import AuthenticationException
|
||||||
|
|
||||||
from frostfs_testlib.reporter import get_reporter
|
from frostfs_testlib.reporter import get_reporter
|
||||||
from frostfs_testlib.shell.interfaces import CommandInspector, CommandOptions, CommandResult, Shell
|
from frostfs_testlib.shell.interfaces import (
|
||||||
|
CommandInspector,
|
||||||
|
CommandOptions,
|
||||||
|
CommandResult,
|
||||||
|
Shell,
|
||||||
|
SshCredentials,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger("frostfs.testlib.shell")
|
logger = logging.getLogger("frostfs.testlib.shell")
|
||||||
reporter = get_reporter()
|
reporter = get_reporter()
|
||||||
|
|
||||||
|
|
||||||
|
class SshConnectionProvider:
|
||||||
|
SSH_CONNECTION_ATTEMPTS: ClassVar[int] = 4
|
||||||
|
SSH_ATTEMPTS_INTERVAL: ClassVar[int] = 10
|
||||||
|
CONNECTION_TIMEOUT = 60
|
||||||
|
|
||||||
|
instance = None
|
||||||
|
connections: dict[str, SSHClient] = {}
|
||||||
|
creds: dict[str, SshCredentials] = {}
|
||||||
|
|
||||||
|
def __new__(cls):
|
||||||
|
if not cls.instance:
|
||||||
|
cls.instance = super(SshConnectionProvider, cls).__new__(cls)
|
||||||
|
return cls.instance
|
||||||
|
|
||||||
|
def store_creds(self, host: str, ssh_creds: SshCredentials):
|
||||||
|
self.creds[host] = ssh_creds
|
||||||
|
|
||||||
|
def provide(self, host: str, port: str) -> SSHClient:
|
||||||
|
if host not in self.creds:
|
||||||
|
raise RuntimeError(f"Please add credentials for host {host}")
|
||||||
|
|
||||||
|
if host in self.connections:
|
||||||
|
client = self.connections[host]
|
||||||
|
if client:
|
||||||
|
return client
|
||||||
|
|
||||||
|
creds = self.creds[host]
|
||||||
|
client = self._create_connection(host, port, creds)
|
||||||
|
self.connections[host] = client
|
||||||
|
return client
|
||||||
|
|
||||||
|
def drop(self, host: str):
|
||||||
|
if host in self.connections:
|
||||||
|
client = self.connections.pop(host)
|
||||||
|
client.close()
|
||||||
|
|
||||||
|
def drop_all(self):
|
||||||
|
hosts = list(self.connections.keys())
|
||||||
|
for host in hosts:
|
||||||
|
self.drop(host)
|
||||||
|
|
||||||
|
def _create_connection(
|
||||||
|
self,
|
||||||
|
host: str,
|
||||||
|
port: str,
|
||||||
|
creds: SshCredentials,
|
||||||
|
) -> SSHClient:
|
||||||
|
for attempt in range(self.SSH_CONNECTION_ATTEMPTS):
|
||||||
|
connection = SSHClient()
|
||||||
|
connection.set_missing_host_key_policy(AutoAddPolicy())
|
||||||
|
try:
|
||||||
|
if creds.ssh_key_path:
|
||||||
|
logger.info(
|
||||||
|
f"Trying to connect to host {host} as {creds.ssh_login} using SSH key "
|
||||||
|
f"{creds.ssh_key_path} (attempt {attempt})"
|
||||||
|
)
|
||||||
|
connection.connect(
|
||||||
|
hostname=host,
|
||||||
|
port=port,
|
||||||
|
username=creds.ssh_login,
|
||||||
|
pkey=_load_private_key(creds.ssh_key_path, creds.ssh_key_passphrase),
|
||||||
|
timeout=self.CONNECTION_TIMEOUT,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.info(
|
||||||
|
f"Trying to connect to host {host} as {creds.ssh_login} using password "
|
||||||
|
f"(attempt {attempt})"
|
||||||
|
)
|
||||||
|
connection.connect(
|
||||||
|
hostname=host,
|
||||||
|
port=port,
|
||||||
|
username=creds.ssh_login,
|
||||||
|
password=creds.ssh_password,
|
||||||
|
timeout=self.CONNECTION_TIMEOUT,
|
||||||
|
)
|
||||||
|
return connection
|
||||||
|
except AuthenticationException:
|
||||||
|
connection.close()
|
||||||
|
logger.exception(f"Can't connect to host {host}")
|
||||||
|
raise
|
||||||
|
except (
|
||||||
|
SSHException,
|
||||||
|
ssh_exception.NoValidConnectionsError,
|
||||||
|
AttributeError,
|
||||||
|
socket.timeout,
|
||||||
|
OSError,
|
||||||
|
) as exc:
|
||||||
|
connection.close()
|
||||||
|
can_retry = attempt + 1 < self.SSH_CONNECTION_ATTEMPTS
|
||||||
|
if can_retry:
|
||||||
|
logger.warn(
|
||||||
|
f"Can't connect to host {host}, will retry after {self.SSH_ATTEMPTS_INTERVAL}s. Error: {exc}"
|
||||||
|
)
|
||||||
|
sleep(self.SSH_ATTEMPTS_INTERVAL)
|
||||||
|
continue
|
||||||
|
logger.exception(f"Can't connect to host {host}")
|
||||||
|
raise HostIsNotAvailable(host) from exc
|
||||||
|
|
||||||
|
|
||||||
class HostIsNotAvailable(Exception):
|
class HostIsNotAvailable(Exception):
|
||||||
"""Raised when host is not reachable via SSH connection."""
|
"""Raised when host is not reachable via SSH connection."""
|
||||||
|
|
||||||
|
@ -91,10 +196,6 @@ class SSHShell(Shell):
|
||||||
# to allow remote command to flush its output buffer
|
# to allow remote command to flush its output buffer
|
||||||
DELAY_AFTER_EXIT = 0.2
|
DELAY_AFTER_EXIT = 0.2
|
||||||
|
|
||||||
SSH_CONNECTION_ATTEMPTS: ClassVar[int] = 4
|
|
||||||
SSH_ATTEMPTS_INTERVAL: ClassVar[int] = 10
|
|
||||||
CONNECTION_TIMEOUT = 60
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
host: str,
|
host: str,
|
||||||
|
@ -106,23 +207,21 @@ class SSHShell(Shell):
|
||||||
command_inspectors: Optional[list[CommandInspector]] = None,
|
command_inspectors: Optional[list[CommandInspector]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.connection_provider = SshConnectionProvider()
|
||||||
|
self.connection_provider.store_creds(
|
||||||
|
host, SshCredentials(login, password, private_key_path, private_key_passphrase)
|
||||||
|
)
|
||||||
self.host = host
|
self.host = host
|
||||||
self.port = port
|
self.port = port
|
||||||
self.login = login
|
|
||||||
self.password = password
|
|
||||||
self.private_key_path = private_key_path
|
|
||||||
self.private_key_passphrase = private_key_passphrase
|
|
||||||
self.command_inspectors = command_inspectors or []
|
self.command_inspectors = command_inspectors or []
|
||||||
self.__connection: Optional[SSHClient] = None
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _connection(self):
|
def _connection(self):
|
||||||
if not self.__connection:
|
return self.connection_provider.provide(self.host, self.port)
|
||||||
self.__connection = self._create_connection()
|
|
||||||
return self.__connection
|
|
||||||
|
|
||||||
def drop(self):
|
def drop(self):
|
||||||
self._reset_connection()
|
self.connection_provider.drop(self.host)
|
||||||
|
|
||||||
def exec(self, command: str, options: Optional[CommandOptions] = None) -> CommandResult:
|
def exec(self, command: str, options: Optional[CommandOptions] = None) -> CommandResult:
|
||||||
options = options or CommandOptions()
|
options = options or CommandOptions()
|
||||||
|
@ -196,7 +295,7 @@ class SSHShell(Shell):
|
||||||
socket.timeout,
|
socket.timeout,
|
||||||
) as exc:
|
) as exc:
|
||||||
logger.exception(f"Can't execute command {command} on host: {self.host}")
|
logger.exception(f"Can't execute command {command} on host: {self.host}")
|
||||||
self._reset_connection()
|
self.drop()
|
||||||
raise HostIsNotAvailable(self.host) from exc
|
raise HostIsNotAvailable(self.host) from exc
|
||||||
|
|
||||||
def _read_channels(
|
def _read_channels(
|
||||||
|
@ -251,62 +350,3 @@ class SSHShell(Shell):
|
||||||
full_stderr = b"".join(stderr_chunks)
|
full_stderr = b"".join(stderr_chunks)
|
||||||
|
|
||||||
return (full_stdout.decode(errors="ignore"), full_stderr.decode(errors="ignore"))
|
return (full_stdout.decode(errors="ignore"), full_stderr.decode(errors="ignore"))
|
||||||
|
|
||||||
def _create_connection(
|
|
||||||
self, attempts: int = SSH_CONNECTION_ATTEMPTS, interval: int = SSH_ATTEMPTS_INTERVAL
|
|
||||||
) -> SSHClient:
|
|
||||||
for attempt in range(attempts):
|
|
||||||
connection = SSHClient()
|
|
||||||
connection.set_missing_host_key_policy(AutoAddPolicy())
|
|
||||||
try:
|
|
||||||
if self.private_key_path:
|
|
||||||
logger.info(
|
|
||||||
f"Trying to connect to host {self.host} as {self.login} using SSH key "
|
|
||||||
f"{self.private_key_path} (attempt {attempt})"
|
|
||||||
)
|
|
||||||
connection.connect(
|
|
||||||
hostname=self.host,
|
|
||||||
port=self.port,
|
|
||||||
username=self.login,
|
|
||||||
pkey=_load_private_key(self.private_key_path, self.private_key_passphrase),
|
|
||||||
timeout=self.CONNECTION_TIMEOUT,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.info(
|
|
||||||
f"Trying to connect to host {self.host} as {self.login} using password "
|
|
||||||
f"(attempt {attempt})"
|
|
||||||
)
|
|
||||||
connection.connect(
|
|
||||||
hostname=self.host,
|
|
||||||
port=self.port,
|
|
||||||
username=self.login,
|
|
||||||
password=self.password,
|
|
||||||
timeout=self.CONNECTION_TIMEOUT,
|
|
||||||
)
|
|
||||||
return connection
|
|
||||||
except AuthenticationException:
|
|
||||||
connection.close()
|
|
||||||
logger.exception(f"Can't connect to host {self.host}")
|
|
||||||
raise
|
|
||||||
except (
|
|
||||||
SSHException,
|
|
||||||
ssh_exception.NoValidConnectionsError,
|
|
||||||
AttributeError,
|
|
||||||
socket.timeout,
|
|
||||||
OSError,
|
|
||||||
) as exc:
|
|
||||||
connection.close()
|
|
||||||
can_retry = attempt + 1 < attempts
|
|
||||||
if can_retry:
|
|
||||||
logger.warn(
|
|
||||||
f"Can't connect to host {self.host}, will retry after {interval}s. Error: {exc}"
|
|
||||||
)
|
|
||||||
sleep(interval)
|
|
||||||
continue
|
|
||||||
logger.exception(f"Can't connect to host {self.host}")
|
|
||||||
raise HostIsNotAvailable(self.host) from exc
|
|
||||||
|
|
||||||
def _reset_connection(self) -> None:
|
|
||||||
if self.__connection:
|
|
||||||
self.__connection.close()
|
|
||||||
self.__connection = None
|
|
||||||
|
|
|
@ -3,7 +3,7 @@ import time
|
||||||
|
|
||||||
import frostfs_testlib.resources.optionals as optionals
|
import frostfs_testlib.resources.optionals as optionals
|
||||||
from frostfs_testlib.reporter import get_reporter
|
from frostfs_testlib.reporter import get_reporter
|
||||||
from frostfs_testlib.shell import CommandOptions, Shell
|
from frostfs_testlib.shell import CommandOptions, Shell, SshConnectionProvider
|
||||||
from frostfs_testlib.steps.network import IfUpDownHelper, IpTablesHelper
|
from frostfs_testlib.steps.network import IfUpDownHelper, IpTablesHelper
|
||||||
from frostfs_testlib.storage.cluster import Cluster, ClusterNode, StorageNode
|
from frostfs_testlib.storage.cluster import Cluster, ClusterNode, StorageNode
|
||||||
from frostfs_testlib.storage.controllers.disk_controller import DiskController
|
from frostfs_testlib.storage.controllers.disk_controller import DiskController
|
||||||
|
@ -37,6 +37,10 @@ class ClusterStateController:
|
||||||
@run_optionally(optionals.OPTIONAL_FAILOVER_ENABLED)
|
@run_optionally(optionals.OPTIONAL_FAILOVER_ENABLED)
|
||||||
@reporter.step_deco("Stop host of node {node}")
|
@reporter.step_deco("Stop host of node {node}")
|
||||||
def stop_node_host(self, node: ClusterNode, mode: str):
|
def stop_node_host(self, node: ClusterNode, mode: str):
|
||||||
|
# Drop ssh connection for this node before shutdown
|
||||||
|
provider = SshConnectionProvider()
|
||||||
|
provider.drop(node.host_ip)
|
||||||
|
|
||||||
with reporter.step(f"Stop host {node.host.config.address}"):
|
with reporter.step(f"Stop host {node.host.config.address}"):
|
||||||
node.host.stop_host(mode=mode)
|
node.host.stop_host(mode=mode)
|
||||||
wait_for_host_offline(self.shell, node.storage_node)
|
wait_for_host_offline(self.shell, node.storage_node)
|
||||||
|
@ -48,6 +52,11 @@ class ClusterStateController:
|
||||||
nodes = (
|
nodes = (
|
||||||
reversed(self.cluster.cluster_nodes) if reversed_order else self.cluster.cluster_nodes
|
reversed(self.cluster.cluster_nodes) if reversed_order else self.cluster.cluster_nodes
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Drop all ssh connections before shutdown
|
||||||
|
provider = SshConnectionProvider()
|
||||||
|
provider.drop_all()
|
||||||
|
|
||||||
for node in nodes:
|
for node in nodes:
|
||||||
with reporter.step(f"Stop host {node.host.config.address}"):
|
with reporter.step(f"Stop host {node.host.config.address}"):
|
||||||
self.stopped_nodes.append(node)
|
self.stopped_nodes.append(node)
|
||||||
|
@ -307,6 +316,10 @@ class ClusterStateController:
|
||||||
options = CommandOptions(close_stdin=True, timeout=1, check=False)
|
options = CommandOptions(close_stdin=True, timeout=1, check=False)
|
||||||
shell.exec('sudo sh -c "echo b > /proc/sysrq-trigger"', options)
|
shell.exec('sudo sh -c "echo b > /proc/sysrq-trigger"', options)
|
||||||
|
|
||||||
|
# Drop ssh connection for this node
|
||||||
|
provider = SshConnectionProvider()
|
||||||
|
provider.drop(node.host_ip)
|
||||||
|
|
||||||
if wait_for_return:
|
if wait_for_return:
|
||||||
# Let the things to be settled
|
# Let the things to be settled
|
||||||
# A little wait here to prevent ssh stuck during panic
|
# A little wait here to prevent ssh stuck during panic
|
||||||
|
|
|
@ -1,50 +1,68 @@
|
||||||
import os
|
import os
|
||||||
from unittest import SkipTest, TestCase
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from frostfs_testlib.shell.interfaces import CommandOptions, InteractiveInput
|
from frostfs_testlib.shell.interfaces import CommandOptions, InteractiveInput
|
||||||
from frostfs_testlib.shell.ssh_shell import SSHShell
|
from frostfs_testlib.shell.ssh_shell import SshConnectionProvider, SSHShell
|
||||||
from helpers import format_error_details, get_output_lines
|
from helpers import format_error_details, get_output_lines
|
||||||
|
|
||||||
|
|
||||||
def init_shell() -> SSHShell:
|
def get_shell(host: str):
|
||||||
host = os.getenv("SSH_SHELL_HOST")
|
|
||||||
port = os.getenv("SSH_SHELL_PORT", "22")
|
port = os.getenv("SSH_SHELL_PORT", "22")
|
||||||
login = os.getenv("SSH_SHELL_LOGIN")
|
login = os.getenv("SSH_SHELL_LOGIN")
|
||||||
private_key_path = os.getenv("SSH_SHELL_PRIVATE_KEY_PATH")
|
|
||||||
private_key_passphrase = os.getenv("SSH_SHELL_PRIVATE_KEY_PASSPHRASE")
|
password = os.getenv("SSH_SHELL_PASSWORD", "")
|
||||||
|
private_key_path = os.getenv("SSH_SHELL_PRIVATE_KEY_PATH", "")
|
||||||
|
private_key_passphrase = os.getenv("SSH_SHELL_PRIVATE_KEY_PASSPHRASE", "")
|
||||||
|
|
||||||
if not all([host, login, private_key_path, private_key_passphrase]):
|
if not all([host, login, private_key_path, private_key_passphrase]):
|
||||||
# TODO: in the future we might use https://pypi.org/project/mock-ssh-server,
|
# TODO: in the future we might use https://pypi.org/project/mock-ssh-server,
|
||||||
# at the moment it is not suitable for us because of its issues with stdin
|
# at the moment it is not suitable for us because of its issues with stdin
|
||||||
raise SkipTest("SSH connection is not configured")
|
pytest.skip("SSH connection is not configured")
|
||||||
|
|
||||||
return SSHShell(
|
return SSHShell(
|
||||||
host=host,
|
host=host,
|
||||||
port=port,
|
port=port,
|
||||||
login=login,
|
login=login,
|
||||||
|
password=password,
|
||||||
private_key_path=private_key_path,
|
private_key_path=private_key_path,
|
||||||
private_key_passphrase=private_key_passphrase,
|
private_key_passphrase=private_key_passphrase,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestSSHShellInteractive(TestCase):
|
@pytest.fixture(scope="module")
|
||||||
@classmethod
|
def shell() -> SSHShell:
|
||||||
def setUpClass(cls):
|
return get_shell(host=os.getenv("SSH_SHELL_HOST"))
|
||||||
cls.shell = init_shell()
|
|
||||||
|
|
||||||
def test_command_with_one_prompt(self):
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def shell_same_host() -> SSHShell:
|
||||||
|
return get_shell(host=os.getenv("SSH_SHELL_HOST"))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def shell_another_host() -> SSHShell:
|
||||||
|
return get_shell(host=os.getenv("SSH_SHELL_HOST_2"))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function", autouse=True)
|
||||||
|
def reset_connection():
|
||||||
|
provider = SshConnectionProvider()
|
||||||
|
provider.drop_all()
|
||||||
|
|
||||||
|
|
||||||
|
class TestSSHShellInteractive:
|
||||||
|
def test_command_with_one_prompt(self, shell: SSHShell):
|
||||||
script = "password = input('Password: '); print('\\n' + password)"
|
script = "password = input('Password: '); print('\\n' + password)"
|
||||||
|
|
||||||
inputs = [InteractiveInput(prompt_pattern="Password", input="test")]
|
inputs = [InteractiveInput(prompt_pattern="Password", input="test")]
|
||||||
result = self.shell.exec(
|
result = shell.exec(f'python3 -c "{script}"', CommandOptions(interactive_inputs=inputs))
|
||||||
f'python3 -c "{script}"', CommandOptions(interactive_inputs=inputs)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertEqual(0, result.return_code)
|
assert result.return_code == 0
|
||||||
self.assertEqual(["Password: test", "test"], get_output_lines(result))
|
assert ["Password: test", "test"] == get_output_lines(result)
|
||||||
self.assertEqual("", result.stderr)
|
assert not result.stderr
|
||||||
|
|
||||||
def test_command_with_several_prompts(self):
|
def test_command_with_several_prompts(self, shell: SSHShell):
|
||||||
script = (
|
script = (
|
||||||
"input1 = input('Input1: '); print('\\n' + input1); "
|
"input1 = input('Input1: '); print('\\n' + input1); "
|
||||||
"input2 = input('Input2: '); print('\\n' + input2)"
|
"input2 = input('Input2: '); print('\\n' + input2)"
|
||||||
|
@ -54,86 +72,132 @@ class TestSSHShellInteractive(TestCase):
|
||||||
InteractiveInput(prompt_pattern="Input2", input="test2"),
|
InteractiveInput(prompt_pattern="Input2", input="test2"),
|
||||||
]
|
]
|
||||||
|
|
||||||
result = self.shell.exec(
|
result = shell.exec(f'python3 -c "{script}"', CommandOptions(interactive_inputs=inputs))
|
||||||
f'python3 -c "{script}"', CommandOptions(interactive_inputs=inputs)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertEqual(0, result.return_code)
|
assert result.return_code == 0
|
||||||
self.assertEqual(
|
assert ["Input1: test1", "test1", "Input2: test2", "test2"] == get_output_lines(result)
|
||||||
["Input1: test1", "test1", "Input2: test2", "test2"], get_output_lines(result)
|
assert not result.stderr
|
||||||
)
|
|
||||||
self.assertEqual("", result.stderr)
|
|
||||||
|
|
||||||
def test_invalid_command_with_check(self):
|
def test_invalid_command_with_check(self, shell: SSHShell):
|
||||||
script = "invalid script"
|
script = "invalid script"
|
||||||
inputs = [InteractiveInput(prompt_pattern=".*", input="test")]
|
inputs = [InteractiveInput(prompt_pattern=".*", input="test")]
|
||||||
|
|
||||||
with self.assertRaises(RuntimeError) as raised:
|
with pytest.raises(RuntimeError) as raised:
|
||||||
self.shell.exec(f'python3 -c "{script}"', CommandOptions(interactive_inputs=inputs))
|
shell.exec(f'python3 -c "{script}"', CommandOptions(interactive_inputs=inputs))
|
||||||
|
|
||||||
error = format_error_details(raised.exception)
|
error = format_error_details(raised.value)
|
||||||
self.assertIn("SyntaxError", error)
|
assert "SyntaxError" in error
|
||||||
self.assertIn("return code: 1", error)
|
assert "return code: 1" in error
|
||||||
|
|
||||||
def test_invalid_command_without_check(self):
|
def test_invalid_command_without_check(self, shell: SSHShell):
|
||||||
script = "invalid script"
|
script = "invalid script"
|
||||||
inputs = [InteractiveInput(prompt_pattern=".*", input="test")]
|
inputs = [InteractiveInput(prompt_pattern=".*", input="test")]
|
||||||
|
|
||||||
result = self.shell.exec(
|
result = shell.exec(
|
||||||
f'python3 -c "{script}"',
|
f'python3 -c "{script}"',
|
||||||
CommandOptions(interactive_inputs=inputs, check=False),
|
CommandOptions(interactive_inputs=inputs, check=False),
|
||||||
)
|
)
|
||||||
self.assertIn("SyntaxError", result.stdout)
|
assert "SyntaxError" in result.stdout
|
||||||
self.assertEqual(1, result.return_code)
|
assert result.return_code == 1
|
||||||
|
|
||||||
def test_non_existing_binary(self):
|
def test_non_existing_binary(self, shell: SSHShell):
|
||||||
inputs = [InteractiveInput(prompt_pattern=".*", input="test")]
|
inputs = [InteractiveInput(prompt_pattern=".*", input="test")]
|
||||||
|
|
||||||
with self.assertRaises(RuntimeError) as raised:
|
with pytest.raises(RuntimeError) as raised:
|
||||||
self.shell.exec("not-a-command", CommandOptions(interactive_inputs=inputs))
|
shell.exec("not-a-command", CommandOptions(interactive_inputs=inputs))
|
||||||
|
|
||||||
error = format_error_details(raised.exception)
|
error = format_error_details(raised.value)
|
||||||
self.assertIn("return code: 127", error)
|
assert "return code: 127" in error
|
||||||
|
|
||||||
|
|
||||||
class TestSSHShellNonInteractive(TestCase):
|
class TestSSHShellNonInteractive:
|
||||||
@classmethod
|
def test_correct_command(self, shell: SSHShell):
|
||||||
def setUpClass(cls):
|
|
||||||
cls.shell = init_shell()
|
|
||||||
|
|
||||||
def test_correct_command(self):
|
|
||||||
script = "print('test')"
|
script = "print('test')"
|
||||||
|
|
||||||
result = self.shell.exec(f'python3 -c "{script}"')
|
result = shell.exec(f'python3 -c "{script}"')
|
||||||
|
|
||||||
self.assertEqual(0, result.return_code)
|
assert result.return_code == 0
|
||||||
self.assertEqual("test", result.stdout.strip())
|
assert result.stdout.strip() == "test"
|
||||||
self.assertEqual("", result.stderr)
|
assert not result.stderr
|
||||||
|
|
||||||
def test_invalid_command_with_check(self):
|
def test_invalid_command_with_check(self, shell: SSHShell):
|
||||||
script = "invalid script"
|
script = "invalid script"
|
||||||
|
|
||||||
with self.assertRaises(RuntimeError) as raised:
|
with pytest.raises(RuntimeError) as raised:
|
||||||
self.shell.exec(f'python3 -c "{script}"')
|
shell.exec(f'python3 -c "{script}"')
|
||||||
|
|
||||||
error = format_error_details(raised.exception)
|
error = format_error_details(raised.value)
|
||||||
self.assertIn("Error", error)
|
assert "Error" in error
|
||||||
self.assertIn("return code: 1", error)
|
assert "return code: 1" in error
|
||||||
|
|
||||||
def test_invalid_command_without_check(self):
|
def test_invalid_command_without_check(self, shell: SSHShell):
|
||||||
script = "invalid script"
|
script = "invalid script"
|
||||||
|
|
||||||
result = self.shell.exec(f'python3 -c "{script}"', CommandOptions(check=False))
|
result = shell.exec(f'python3 -c "{script}"', CommandOptions(check=False))
|
||||||
|
|
||||||
self.assertEqual(1, result.return_code)
|
assert result.return_code == 1
|
||||||
# TODO: we have inconsistency with local shell here, the local shell captures error info
|
# TODO: we have inconsistency with local shell here, the local shell captures error info
|
||||||
# in stdout while ssh shell captures it in stderr
|
# in stdout while ssh shell captures it in stderr
|
||||||
self.assertIn("Error", result.stderr)
|
assert "Error" in result.stderr
|
||||||
|
|
||||||
def test_non_existing_binary(self):
|
def test_non_existing_binary(self, shell: SSHShell):
|
||||||
with self.assertRaises(RuntimeError) as exc:
|
with pytest.raises(RuntimeError) as raised:
|
||||||
self.shell.exec("not-a-command")
|
shell.exec("not-a-command")
|
||||||
|
|
||||||
error = format_error_details(exc.exception)
|
error = format_error_details(raised.value)
|
||||||
self.assertIn("Error", error)
|
assert "Error" in error
|
||||||
self.assertIn("return code: 127", error)
|
assert "return code: 127" in error
|
||||||
|
|
||||||
|
|
||||||
|
class TestSSHShellConnection:
|
||||||
|
def test_connection_provider_is_singleton(self):
|
||||||
|
provider = SshConnectionProvider()
|
||||||
|
provider2 = SshConnectionProvider()
|
||||||
|
assert id(provider) == id(provider2)
|
||||||
|
|
||||||
|
def test_connection_provider_has_creds(self, shell: SSHShell):
|
||||||
|
provider = SshConnectionProvider()
|
||||||
|
assert len(provider.creds) == 1
|
||||||
|
assert len(provider.connections) == 0
|
||||||
|
|
||||||
|
def test_connection_provider_has_only_one_connection(self, shell: SSHShell):
|
||||||
|
provider = SshConnectionProvider()
|
||||||
|
assert len(provider.connections) == 0
|
||||||
|
shell.exec("echo 1")
|
||||||
|
assert len(provider.connections) == 1
|
||||||
|
shell.exec("echo 2")
|
||||||
|
assert len(provider.connections) == 1
|
||||||
|
shell.drop()
|
||||||
|
assert len(provider.connections) == 0
|
||||||
|
|
||||||
|
def test_connection_same_host(self, shell: SSHShell, shell_same_host: SSHShell):
|
||||||
|
provider = SshConnectionProvider()
|
||||||
|
assert len(provider.connections) == 0
|
||||||
|
|
||||||
|
shell.exec("echo 1")
|
||||||
|
assert len(provider.connections) == 1
|
||||||
|
|
||||||
|
shell_same_host.exec("echo 2")
|
||||||
|
assert len(provider.connections) == 1
|
||||||
|
|
||||||
|
shell.drop()
|
||||||
|
assert len(provider.connections) == 0
|
||||||
|
|
||||||
|
shell.exec("echo 3")
|
||||||
|
assert len(provider.connections) == 1
|
||||||
|
|
||||||
|
def test_connection_another_host(self, shell: SSHShell, shell_another_host: SSHShell):
|
||||||
|
provider = SshConnectionProvider()
|
||||||
|
assert len(provider.connections) == 0
|
||||||
|
|
||||||
|
shell.exec("echo 1")
|
||||||
|
assert len(provider.connections) == 1
|
||||||
|
|
||||||
|
shell_another_host.exec("echo 2")
|
||||||
|
assert len(provider.connections) == 2
|
||||||
|
|
||||||
|
shell.drop()
|
||||||
|
assert len(provider.connections) == 1
|
||||||
|
|
||||||
|
shell_another_host.drop()
|
||||||
|
assert len(provider.connections) == 0
|
||||||
|
|
Loading…
Reference in a new issue