From 2c2af7f8ed0ca9199d0a21d0091f260083fbc243 Mon Sep 17 00:00:00 2001 From: Andrey Berezin Date: Tue, 10 Oct 2023 17:47:46 +0300 Subject: [PATCH] Keep only one ssh connection per host Signed-off-by: Andrey Berezin --- src/frostfs_testlib/shell/__init__.py | 2 +- src/frostfs_testlib/shell/ssh_shell.py | 188 +++++++++------- .../controllers/cluster_state_controller.py | 15 +- tests/test_ssh_shell.py | 200 ++++++++++++------ 4 files changed, 261 insertions(+), 144 deletions(-) diff --git a/src/frostfs_testlib/shell/__init__.py b/src/frostfs_testlib/shell/__init__.py index 0300ff8..980d119 100644 --- a/src/frostfs_testlib/shell/__init__.py +++ b/src/frostfs_testlib/shell/__init__.py @@ -1,3 +1,3 @@ from frostfs_testlib.shell.interfaces import CommandOptions, CommandResult, InteractiveInput, Shell from frostfs_testlib.shell.local_shell import LocalShell -from frostfs_testlib.shell.ssh_shell import SSHShell +from frostfs_testlib.shell.ssh_shell import SshConnectionProvider, SSHShell diff --git a/src/frostfs_testlib/shell/ssh_shell.py b/src/frostfs_testlib/shell/ssh_shell.py index 435a494..6db7d51 100644 --- a/src/frostfs_testlib/shell/ssh_shell.py +++ b/src/frostfs_testlib/shell/ssh_shell.py @@ -20,12 +20,117 @@ from paramiko import ( from paramiko.ssh_exception import AuthenticationException from frostfs_testlib.reporter import get_reporter -from frostfs_testlib.shell.interfaces import CommandInspector, CommandOptions, CommandResult, Shell +from frostfs_testlib.shell.interfaces import ( + CommandInspector, + CommandOptions, + CommandResult, + Shell, + SshCredentials, +) logger = logging.getLogger("frostfs.testlib.shell") reporter = get_reporter() +class SshConnectionProvider: + SSH_CONNECTION_ATTEMPTS: ClassVar[int] = 4 + SSH_ATTEMPTS_INTERVAL: ClassVar[int] = 10 + CONNECTION_TIMEOUT = 60 + + instance = None + connections: dict[str, SSHClient] = {} + creds: dict[str, SshCredentials] = {} + + def __new__(cls): + if not cls.instance: + cls.instance = super(SshConnectionProvider, cls).__new__(cls) + return cls.instance + + def store_creds(self, host: str, ssh_creds: SshCredentials): + self.creds[host] = ssh_creds + + def provide(self, host: str, port: str) -> SSHClient: + if host not in self.creds: + raise RuntimeError(f"Please add credentials for host {host}") + + if host in self.connections: + client = self.connections[host] + if client: + return client + + creds = self.creds[host] + client = self._create_connection(host, port, creds) + self.connections[host] = client + return client + + def drop(self, host: str): + if host in self.connections: + client = self.connections.pop(host) + client.close() + + def drop_all(self): + hosts = list(self.connections.keys()) + for host in hosts: + self.drop(host) + + def _create_connection( + self, + host: str, + port: str, + creds: SshCredentials, + ) -> SSHClient: + for attempt in range(self.SSH_CONNECTION_ATTEMPTS): + connection = SSHClient() + connection.set_missing_host_key_policy(AutoAddPolicy()) + try: + if creds.ssh_key_path: + logger.info( + f"Trying to connect to host {host} as {creds.ssh_login} using SSH key " + f"{creds.ssh_key_path} (attempt {attempt})" + ) + connection.connect( + hostname=host, + port=port, + username=creds.ssh_login, + pkey=_load_private_key(creds.ssh_key_path, creds.ssh_key_passphrase), + timeout=self.CONNECTION_TIMEOUT, + ) + else: + logger.info( + f"Trying to connect to host {host} as {creds.ssh_login} using password " + f"(attempt {attempt})" + ) + connection.connect( + hostname=host, + port=port, + username=creds.ssh_login, + password=creds.ssh_password, + timeout=self.CONNECTION_TIMEOUT, + ) + return connection + except AuthenticationException: + connection.close() + logger.exception(f"Can't connect to host {host}") + raise + except ( + SSHException, + ssh_exception.NoValidConnectionsError, + AttributeError, + socket.timeout, + OSError, + ) as exc: + connection.close() + can_retry = attempt + 1 < self.SSH_CONNECTION_ATTEMPTS + if can_retry: + logger.warn( + f"Can't connect to host {host}, will retry after {self.SSH_ATTEMPTS_INTERVAL}s. Error: {exc}" + ) + sleep(self.SSH_ATTEMPTS_INTERVAL) + continue + logger.exception(f"Can't connect to host {host}") + raise HostIsNotAvailable(host) from exc + + class HostIsNotAvailable(Exception): """Raised when host is not reachable via SSH connection.""" @@ -91,10 +196,6 @@ class SSHShell(Shell): # to allow remote command to flush its output buffer DELAY_AFTER_EXIT = 0.2 - SSH_CONNECTION_ATTEMPTS: ClassVar[int] = 4 - SSH_ATTEMPTS_INTERVAL: ClassVar[int] = 10 - CONNECTION_TIMEOUT = 60 - def __init__( self, host: str, @@ -106,23 +207,21 @@ class SSHShell(Shell): command_inspectors: Optional[list[CommandInspector]] = None, ) -> None: super().__init__() + self.connection_provider = SshConnectionProvider() + self.connection_provider.store_creds( + host, SshCredentials(login, password, private_key_path, private_key_passphrase) + ) self.host = host self.port = port - self.login = login - self.password = password - self.private_key_path = private_key_path - self.private_key_passphrase = private_key_passphrase + self.command_inspectors = command_inspectors or [] - self.__connection: Optional[SSHClient] = None @property def _connection(self): - if not self.__connection: - self.__connection = self._create_connection() - return self.__connection + return self.connection_provider.provide(self.host, self.port) def drop(self): - self._reset_connection() + self.connection_provider.drop(self.host) def exec(self, command: str, options: Optional[CommandOptions] = None) -> CommandResult: options = options or CommandOptions() @@ -196,7 +295,7 @@ class SSHShell(Shell): socket.timeout, ) as exc: logger.exception(f"Can't execute command {command} on host: {self.host}") - self._reset_connection() + self.drop() raise HostIsNotAvailable(self.host) from exc def _read_channels( @@ -251,62 +350,3 @@ class SSHShell(Shell): full_stderr = b"".join(stderr_chunks) return (full_stdout.decode(errors="ignore"), full_stderr.decode(errors="ignore")) - - def _create_connection( - self, attempts: int = SSH_CONNECTION_ATTEMPTS, interval: int = SSH_ATTEMPTS_INTERVAL - ) -> SSHClient: - for attempt in range(attempts): - connection = SSHClient() - connection.set_missing_host_key_policy(AutoAddPolicy()) - try: - if self.private_key_path: - logger.info( - f"Trying to connect to host {self.host} as {self.login} using SSH key " - f"{self.private_key_path} (attempt {attempt})" - ) - connection.connect( - hostname=self.host, - port=self.port, - username=self.login, - pkey=_load_private_key(self.private_key_path, self.private_key_passphrase), - timeout=self.CONNECTION_TIMEOUT, - ) - else: - logger.info( - f"Trying to connect to host {self.host} as {self.login} using password " - f"(attempt {attempt})" - ) - connection.connect( - hostname=self.host, - port=self.port, - username=self.login, - password=self.password, - timeout=self.CONNECTION_TIMEOUT, - ) - return connection - except AuthenticationException: - connection.close() - logger.exception(f"Can't connect to host {self.host}") - raise - except ( - SSHException, - ssh_exception.NoValidConnectionsError, - AttributeError, - socket.timeout, - OSError, - ) as exc: - connection.close() - can_retry = attempt + 1 < attempts - if can_retry: - logger.warn( - f"Can't connect to host {self.host}, will retry after {interval}s. Error: {exc}" - ) - sleep(interval) - continue - logger.exception(f"Can't connect to host {self.host}") - raise HostIsNotAvailable(self.host) from exc - - def _reset_connection(self) -> None: - if self.__connection: - self.__connection.close() - self.__connection = None diff --git a/src/frostfs_testlib/storage/controllers/cluster_state_controller.py b/src/frostfs_testlib/storage/controllers/cluster_state_controller.py index ed82167..c6391f5 100644 --- a/src/frostfs_testlib/storage/controllers/cluster_state_controller.py +++ b/src/frostfs_testlib/storage/controllers/cluster_state_controller.py @@ -3,7 +3,7 @@ import time import frostfs_testlib.resources.optionals as optionals from frostfs_testlib.reporter import get_reporter -from frostfs_testlib.shell import CommandOptions, Shell +from frostfs_testlib.shell import CommandOptions, Shell, SshConnectionProvider from frostfs_testlib.steps.network import IfUpDownHelper, IpTablesHelper from frostfs_testlib.storage.cluster import Cluster, ClusterNode, StorageNode from frostfs_testlib.storage.controllers.disk_controller import DiskController @@ -37,6 +37,10 @@ class ClusterStateController: @run_optionally(optionals.OPTIONAL_FAILOVER_ENABLED) @reporter.step_deco("Stop host of node {node}") def stop_node_host(self, node: ClusterNode, mode: str): + # Drop ssh connection for this node before shutdown + provider = SshConnectionProvider() + provider.drop(node.host_ip) + with reporter.step(f"Stop host {node.host.config.address}"): node.host.stop_host(mode=mode) wait_for_host_offline(self.shell, node.storage_node) @@ -48,6 +52,11 @@ class ClusterStateController: nodes = ( reversed(self.cluster.cluster_nodes) if reversed_order else self.cluster.cluster_nodes ) + + # Drop all ssh connections before shutdown + provider = SshConnectionProvider() + provider.drop_all() + for node in nodes: with reporter.step(f"Stop host {node.host.config.address}"): self.stopped_nodes.append(node) @@ -307,6 +316,10 @@ class ClusterStateController: options = CommandOptions(close_stdin=True, timeout=1, check=False) shell.exec('sudo sh -c "echo b > /proc/sysrq-trigger"', options) + # Drop ssh connection for this node + provider = SshConnectionProvider() + provider.drop(node.host_ip) + if wait_for_return: # Let the things to be settled # A little wait here to prevent ssh stuck during panic diff --git a/tests/test_ssh_shell.py b/tests/test_ssh_shell.py index 4d1c0fd..ecd8c3c 100644 --- a/tests/test_ssh_shell.py +++ b/tests/test_ssh_shell.py @@ -1,50 +1,68 @@ import os -from unittest import SkipTest, TestCase + +import pytest from frostfs_testlib.shell.interfaces import CommandOptions, InteractiveInput -from frostfs_testlib.shell.ssh_shell import SSHShell +from frostfs_testlib.shell.ssh_shell import SshConnectionProvider, SSHShell from helpers import format_error_details, get_output_lines -def init_shell() -> SSHShell: - host = os.getenv("SSH_SHELL_HOST") +def get_shell(host: str): port = os.getenv("SSH_SHELL_PORT", "22") login = os.getenv("SSH_SHELL_LOGIN") - private_key_path = os.getenv("SSH_SHELL_PRIVATE_KEY_PATH") - private_key_passphrase = os.getenv("SSH_SHELL_PRIVATE_KEY_PASSPHRASE") + + password = os.getenv("SSH_SHELL_PASSWORD", "") + private_key_path = os.getenv("SSH_SHELL_PRIVATE_KEY_PATH", "") + private_key_passphrase = os.getenv("SSH_SHELL_PRIVATE_KEY_PASSPHRASE", "") if not all([host, login, private_key_path, private_key_passphrase]): # TODO: in the future we might use https://pypi.org/project/mock-ssh-server, # at the moment it is not suitable for us because of its issues with stdin - raise SkipTest("SSH connection is not configured") + pytest.skip("SSH connection is not configured") return SSHShell( host=host, port=port, login=login, + password=password, private_key_path=private_key_path, private_key_passphrase=private_key_passphrase, ) -class TestSSHShellInteractive(TestCase): - @classmethod - def setUpClass(cls): - cls.shell = init_shell() +@pytest.fixture(scope="module") +def shell() -> SSHShell: + return get_shell(host=os.getenv("SSH_SHELL_HOST")) - def test_command_with_one_prompt(self): + +@pytest.fixture(scope="module") +def shell_same_host() -> SSHShell: + return get_shell(host=os.getenv("SSH_SHELL_HOST")) + + +@pytest.fixture(scope="module") +def shell_another_host() -> SSHShell: + return get_shell(host=os.getenv("SSH_SHELL_HOST_2")) + + +@pytest.fixture(scope="function", autouse=True) +def reset_connection(): + provider = SshConnectionProvider() + provider.drop_all() + + +class TestSSHShellInteractive: + def test_command_with_one_prompt(self, shell: SSHShell): script = "password = input('Password: '); print('\\n' + password)" inputs = [InteractiveInput(prompt_pattern="Password", input="test")] - result = self.shell.exec( - f'python3 -c "{script}"', CommandOptions(interactive_inputs=inputs) - ) + result = shell.exec(f'python3 -c "{script}"', CommandOptions(interactive_inputs=inputs)) - self.assertEqual(0, result.return_code) - self.assertEqual(["Password: test", "test"], get_output_lines(result)) - self.assertEqual("", result.stderr) + assert result.return_code == 0 + assert ["Password: test", "test"] == get_output_lines(result) + assert not result.stderr - def test_command_with_several_prompts(self): + def test_command_with_several_prompts(self, shell: SSHShell): script = ( "input1 = input('Input1: '); print('\\n' + input1); " "input2 = input('Input2: '); print('\\n' + input2)" @@ -54,86 +72,132 @@ class TestSSHShellInteractive(TestCase): InteractiveInput(prompt_pattern="Input2", input="test2"), ] - result = self.shell.exec( - f'python3 -c "{script}"', CommandOptions(interactive_inputs=inputs) - ) + result = shell.exec(f'python3 -c "{script}"', CommandOptions(interactive_inputs=inputs)) - self.assertEqual(0, result.return_code) - self.assertEqual( - ["Input1: test1", "test1", "Input2: test2", "test2"], get_output_lines(result) - ) - self.assertEqual("", result.stderr) + assert result.return_code == 0 + assert ["Input1: test1", "test1", "Input2: test2", "test2"] == get_output_lines(result) + assert not result.stderr - def test_invalid_command_with_check(self): + def test_invalid_command_with_check(self, shell: SSHShell): script = "invalid script" inputs = [InteractiveInput(prompt_pattern=".*", input="test")] - with self.assertRaises(RuntimeError) as raised: - self.shell.exec(f'python3 -c "{script}"', CommandOptions(interactive_inputs=inputs)) + with pytest.raises(RuntimeError) as raised: + shell.exec(f'python3 -c "{script}"', CommandOptions(interactive_inputs=inputs)) - error = format_error_details(raised.exception) - self.assertIn("SyntaxError", error) - self.assertIn("return code: 1", error) + error = format_error_details(raised.value) + assert "SyntaxError" in error + assert "return code: 1" in error - def test_invalid_command_without_check(self): + def test_invalid_command_without_check(self, shell: SSHShell): script = "invalid script" inputs = [InteractiveInput(prompt_pattern=".*", input="test")] - result = self.shell.exec( + result = shell.exec( f'python3 -c "{script}"', CommandOptions(interactive_inputs=inputs, check=False), ) - self.assertIn("SyntaxError", result.stdout) - self.assertEqual(1, result.return_code) + assert "SyntaxError" in result.stdout + assert result.return_code == 1 - def test_non_existing_binary(self): + def test_non_existing_binary(self, shell: SSHShell): inputs = [InteractiveInput(prompt_pattern=".*", input="test")] - with self.assertRaises(RuntimeError) as raised: - self.shell.exec("not-a-command", CommandOptions(interactive_inputs=inputs)) + with pytest.raises(RuntimeError) as raised: + shell.exec("not-a-command", CommandOptions(interactive_inputs=inputs)) - error = format_error_details(raised.exception) - self.assertIn("return code: 127", error) + error = format_error_details(raised.value) + assert "return code: 127" in error -class TestSSHShellNonInteractive(TestCase): - @classmethod - def setUpClass(cls): - cls.shell = init_shell() - - def test_correct_command(self): +class TestSSHShellNonInteractive: + def test_correct_command(self, shell: SSHShell): script = "print('test')" - result = self.shell.exec(f'python3 -c "{script}"') + result = shell.exec(f'python3 -c "{script}"') - self.assertEqual(0, result.return_code) - self.assertEqual("test", result.stdout.strip()) - self.assertEqual("", result.stderr) + assert result.return_code == 0 + assert result.stdout.strip() == "test" + assert not result.stderr - def test_invalid_command_with_check(self): + def test_invalid_command_with_check(self, shell: SSHShell): script = "invalid script" - with self.assertRaises(RuntimeError) as raised: - self.shell.exec(f'python3 -c "{script}"') + with pytest.raises(RuntimeError) as raised: + shell.exec(f'python3 -c "{script}"') - error = format_error_details(raised.exception) - self.assertIn("Error", error) - self.assertIn("return code: 1", error) + error = format_error_details(raised.value) + assert "Error" in error + assert "return code: 1" in error - def test_invalid_command_without_check(self): + def test_invalid_command_without_check(self, shell: SSHShell): script = "invalid script" - result = self.shell.exec(f'python3 -c "{script}"', CommandOptions(check=False)) + result = shell.exec(f'python3 -c "{script}"', CommandOptions(check=False)) - self.assertEqual(1, result.return_code) + assert result.return_code == 1 # TODO: we have inconsistency with local shell here, the local shell captures error info # in stdout while ssh shell captures it in stderr - self.assertIn("Error", result.stderr) + assert "Error" in result.stderr - def test_non_existing_binary(self): - with self.assertRaises(RuntimeError) as exc: - self.shell.exec("not-a-command") + def test_non_existing_binary(self, shell: SSHShell): + with pytest.raises(RuntimeError) as raised: + shell.exec("not-a-command") - error = format_error_details(exc.exception) - self.assertIn("Error", error) - self.assertIn("return code: 127", error) + error = format_error_details(raised.value) + assert "Error" in error + assert "return code: 127" in error + + +class TestSSHShellConnection: + def test_connection_provider_is_singleton(self): + provider = SshConnectionProvider() + provider2 = SshConnectionProvider() + assert id(provider) == id(provider2) + + def test_connection_provider_has_creds(self, shell: SSHShell): + provider = SshConnectionProvider() + assert len(provider.creds) == 1 + assert len(provider.connections) == 0 + + def test_connection_provider_has_only_one_connection(self, shell: SSHShell): + provider = SshConnectionProvider() + assert len(provider.connections) == 0 + shell.exec("echo 1") + assert len(provider.connections) == 1 + shell.exec("echo 2") + assert len(provider.connections) == 1 + shell.drop() + assert len(provider.connections) == 0 + + def test_connection_same_host(self, shell: SSHShell, shell_same_host: SSHShell): + provider = SshConnectionProvider() + assert len(provider.connections) == 0 + + shell.exec("echo 1") + assert len(provider.connections) == 1 + + shell_same_host.exec("echo 2") + assert len(provider.connections) == 1 + + shell.drop() + assert len(provider.connections) == 0 + + shell.exec("echo 3") + assert len(provider.connections) == 1 + + def test_connection_another_host(self, shell: SSHShell, shell_another_host: SSHShell): + provider = SshConnectionProvider() + assert len(provider.connections) == 0 + + shell.exec("echo 1") + assert len(provider.connections) == 1 + + shell_another_host.exec("echo 2") + assert len(provider.connections) == 2 + + shell.drop() + assert len(provider.connections) == 1 + + shell_another_host.drop() + assert len(provider.connections) == 0