import logging
import socket
import textwrap
from datetime import datetime, timezone
from functools import lru_cache, wraps
from time import sleep
from typing import ClassVar, Optional, Tuple

from paramiko import AutoAddPolicy, Channel, ECDSAKey, Ed25519Key, PKey, RSAKey, SSHClient, SSHException, ssh_exception
from paramiko.ssh_exception import AuthenticationException

from frostfs_testlib import reporter
from frostfs_testlib.shell.interfaces import CommandInspector, CommandOptions, CommandResult, Shell, SshCredentials

logger = logging.getLogger("frostfs.testlib.shell")


class SshConnectionProvider:
    """Singleton that stores SSH credentials and caches one SSH client per host.

    ``__new__`` always returns the same instance, and ``connections`` / ``creds``
    are class-level dictionaries, so every ``SshConnectionProvider()`` shares the
    same state. Both dictionaries are keyed by host name only (the port is not
    part of the key).
    """

    # Number of connection attempts made before a host is declared unavailable
    SSH_CONNECTION_ATTEMPTS: ClassVar[int] = 4
    # Pause (seconds) between two consecutive connection attempts
    SSH_ATTEMPTS_INTERVAL: ClassVar[int] = 10
    # Timeout (seconds) passed to paramiko when opening the connection
    CONNECTION_TIMEOUT = 60

    instance = None
    connections: dict[str, SSHClient] = {}
    creds: dict[str, SshCredentials] = {}

    def __new__(cls):
        if not cls.instance:
            cls.instance = super(SshConnectionProvider, cls).__new__(cls)
        return cls.instance

    def store_creds(self, host: str, ssh_creds: SshCredentials):
        """Remembers credentials that will be used to connect to ``host``."""
        self.creds[host] = ssh_creds

    def provide(self, host: str, port: str) -> SSHClient:
        """Returns a connected SSH client for ``host``, reusing the cached one when it is alive.

        Args:
            host: Host to connect to; credentials must have been stored for it.
            port: SSH port, used only when a new connection has to be created.

        Returns:
            Cached client if its transport is still active, otherwise a fresh one.

        Raises:
            RuntimeError: If no credentials were stored for ``host``.
            HostIsNotAvailable: If a new connection could not be established.
            AuthenticationException: If the host rejected the credentials.
        """
        if host not in self.creds:
            raise RuntimeError(f"Please add credentials for host {host}")

        client = self.connections.get(host)
        if client:
            if self._is_alive(client):
                return client
            # The cached session has died (e.g. remote side closed it); returning it would
            # make every subsequent command fail, so discard it and reconnect instead.
            logger.warning(f"Cached SSH connection to host {host} is not active, reconnecting")
            self.drop(host)

        client = self._create_connection(host, port, self.creds[host])
        self.connections[host] = client
        return client

    def drop(self, host: str):
        """Closes and forgets the cached connection for ``host`` (no-op if there is none)."""
        client = self.connections.pop(host, None)
        if client:
            client.close()

    def drop_all(self):
        """Closes and forgets all cached connections."""
        # Iterate over a snapshot of keys because drop() mutates the dictionary
        for host in list(self.connections):
            self.drop(host)

    @staticmethod
    def _is_alive(client: SSHClient) -> bool:
        """Tells whether the client still has an active underlying transport."""
        transport = client.get_transport()
        return bool(transport is not None and transport.is_active())

    def _create_connection(
        self,
        host: str,
        port: str,
        creds: SshCredentials,
    ) -> SSHClient:
        """Opens a new SSH connection, retrying on connectivity errors.

        Key-based authentication is used when ``creds.ssh_key_path`` is set, password
        authentication otherwise.

        Raises:
            AuthenticationException: Immediately (without retries) when authentication fails.
            HostIsNotAvailable: When all ``SSH_CONNECTION_ATTEMPTS`` attempts have failed.
        """
        for attempt in range(self.SSH_CONNECTION_ATTEMPTS):
            connection = SSHClient()
            connection.set_missing_host_key_policy(AutoAddPolicy())
            try:
                if creds.ssh_key_path:
                    logger.info(
                        f"Trying to connect to host {host} as {creds.ssh_login} using SSH key "
                        f"{creds.ssh_key_path} (attempt {attempt})"
                    )
                    connection.connect(
                        hostname=host,
                        port=port,
                        username=creds.ssh_login,
                        pkey=_load_private_key(creds.ssh_key_path, creds.ssh_key_passphrase),
                        timeout=self.CONNECTION_TIMEOUT,
                    )
                else:
                    logger.info(
                        f"Trying to connect to host {host} as {creds.ssh_login} using password " f"(attempt {attempt})"
                    )
                    connection.connect(
                        hostname=host,
                        port=port,
                        username=creds.ssh_login,
                        password=creds.ssh_password,
                        timeout=self.CONNECTION_TIMEOUT,
                    )
                return connection
            except AuthenticationException:
                # Wrong credentials will not become right on retry, so fail fast.
                # Must be handled before SSHException, which is listed in the clause below.
                connection.close()
                logger.exception(f"Can't connect to host {host}")
                raise
            except (
                SSHException,
                ssh_exception.NoValidConnectionsError,
                AttributeError,
                socket.timeout,
                OSError,
            ) as exc:
                connection.close()
                can_retry = attempt + 1 < self.SSH_CONNECTION_ATTEMPTS
                if can_retry:
                    logger.warning(
                        f"Can't connect to host {host}, will retry after {self.SSH_ATTEMPTS_INTERVAL}s. Error: {exc}"
                    )
                    sleep(self.SSH_ATTEMPTS_INTERVAL)
                    continue
                logger.exception(f"Can't connect to host {host}")
                raise HostIsNotAvailable(host) from exc

        # Reachable only when SSH_CONNECTION_ATTEMPTS is overridden with a non-positive value;
        # fail explicitly instead of silently returning None.
        raise HostIsNotAvailable(host)


class HostIsNotAvailable(Exception):
    """Raised when host is not reachable via SSH connection."""

    def __init__(self, host: Optional[str] = None):
        msg = f"Host {host} is not available"
        super().__init__(msg)


def log_command(func):
    """Decorator for ``SSHShell`` command executors that reports and logs each command.

    The wrapped call runs inside a ``reporter.step``; afterwards a summary with host,
    command, return code, stdout, stderr and timings is attached to the report and,
    unless ``options.no_log`` is set, also written to the logger.
    """

    @wraps(func)
    def wrapper(shell: "SSHShell", command: str, options: CommandOptions, *args, **kwargs) -> CommandResult:
        # Strip the fixed preamble so that the step title shows only the meaningful command
        command_info = command.removeprefix("$ProgressPreference='SilentlyContinue'\n")
        with reporter.step(command_info):
            logger.info(f'Execute command "{command}" on "{shell.host}"')

            # .time() below drops tzinfo, so the rendered values are plain UTC clock times
            start_time = datetime.now(timezone.utc)
            result = func(shell, command, options, *args, **kwargs)
            end_time = datetime.now(timezone.utc)

            elapsed_time = end_time - start_time
            log_message = (
                f"HOST: {shell.host}\n"
                f"COMMAND:\n{textwrap.indent(command, ' ')}\n"
                f"RC:\n {result.return_code}\n"
                f"STDOUT:\n{textwrap.indent(result.stdout, ' ')}\n"
                f"STDERR:\n{textwrap.indent(result.stderr, ' ')}\n"
                f"Start / End / Elapsed\t {start_time.time()} / {end_time.time()} / {elapsed_time}"
            )

            if not options.no_log:
                logger.info(log_message)

            reporter.attach(log_message, "SSH command.txt")
        return result

    return wrapper


@lru_cache
def _load_private_key(file_path: str, password: Optional[str]) -> PKey:
    """Loads private key from specified file.

    We support several type formats, however paramiko doesn't provide functionality to determine
    key type in advance. So we attempt to load file with each of the supported formats and then
    cache the result so that we don't need to figure out type again on subsequent calls.

    Raises:
        SSHException: If the file could not be loaded as any of the supported key types.
    """
    logger.debug(f"Loading ssh key from {file_path}")
    for key_type in (Ed25519Key, ECDSAKey, RSAKey):
        try:
            return key_type.from_private_key_file(file_path, password)
        except SSHException as ex:
            logger.warning(f"SSH key {file_path} can't be loaded with {key_type}: {ex}")
            continue
    raise SSHException(f"SSH key {file_path} is not supported")


class SSHShell(Shell):
    """Implements command shell on a remote machine via SSH connection."""

    # Time in seconds to delay after remote command has completed. The delay is required
    # to allow remote command to flush its output buffer
    DELAY_AFTER_EXIT = 0.2

    # Time in seconds to wait between channel polls when no output is ready yet. Keeps the
    # read loop from spinning at full CPU while the remote command is running
    CHANNEL_POLL_INTERVAL = 0.01

    def __init__(
        self,
        host: str,
        login: str,
        password: Optional[str] = None,
        private_key_path: Optional[str] = None,
        private_key_passphrase: Optional[str] = None,
        port: str = "22",
        command_inspectors: Optional[list[CommandInspector]] = None,
        custom_environment: Optional[dict] = None,
    ) -> None:
        """Registers credentials for ``host`` in the shared connection provider.

        No connection is opened here; it is obtained from the provider on first use.

        Args:
            host: Remote host to run commands on.
            login: SSH user name.
            password: SSH password (used when no private key path is given).
            private_key_path: Path to the private key file.
            private_key_passphrase: Passphrase for the private key.
            port: SSH port.
            command_inspectors: Inspectors applied to every command before execution.
            custom_environment: Environment passed to every executed command.
        """
        super().__init__()
        self.connection_provider = SshConnectionProvider()
        self.connection_provider.store_creds(
            host, SshCredentials(login, password, private_key_path, private_key_passphrase)
        )
        self.host = host
        self.port = port

        self.command_inspectors = command_inspectors or []

        self.environment = custom_environment

    @property
    def _connection(self):
        """SSH client for this shell's host, obtained from the connection provider."""
        return self.connection_provider.provide(self.host, self.port)

    def drop(self):
        """Closes the cached connection to this shell's host."""
        self.connection_provider.drop(self.host)

    def exec(self, command: str, options: Optional[CommandOptions] = None) -> CommandResult:
        """Executes ``command`` on the remote host.

        The command is first passed through the shell's inspectors followed by
        ``options.extra_inspectors``. It runs interactively when
        ``options.interactive_inputs`` is set and non-interactively otherwise.

        Raises:
            RuntimeError: If ``options.check`` is set and the command returned non-zero code.
        """
        options = options or CommandOptions()

        original_command = command
        extra_inspectors = options.extra_inspectors if options.extra_inspectors else []
        for inspector in [*self.command_inspectors, *extra_inspectors]:
            command = inspector.inspect(original_command, command)

        if options.interactive_inputs:
            result = self._exec_interactive(command, options)
        else:
            result = self._exec_non_interactive(command, options)

        if options.check and result.return_code != 0:
            raise RuntimeError(
                f"Command: {command}\nreturn code: {result.return_code}\nOutput: {result.stdout}\nStderr: {result.stderr}\n"
            )
        return result

    @log_command
    def _exec_interactive(self, command: str, options: CommandOptions) -> CommandResult:
        """Runs the command with a pty and feeds ``options.interactive_inputs`` into its stdin."""
        stdin, stdout, stderr = self._connection.exec_command(
            command, timeout=options.timeout, get_pty=True, environment=self.environment
        )
        for interactive_input in options.interactive_inputs:
            input_text = interactive_input.input
            if not input_text.endswith("\n"):
                input_text = f"{input_text}\n"
            try:
                stdin.write(input_text)
            except OSError:
                # Best effort: the failure is logged and the remaining inputs are still attempted
                logger.exception(f"Error while feeding {input_text} into command {command}")

        if options.close_stdin:
            stdin.close()
        sleep(self.DELAY_AFTER_EXIT)

        decoded_stdout, decoded_stderr = self._read_channels(stdout.channel, stderr.channel)
        return_code = stdout.channel.recv_exit_status()

        result = CommandResult(
            stdout=decoded_stdout,
            stderr=decoded_stderr,
            return_code=return_code,
        )
        return result

    @log_command
    def _exec_non_interactive(self, command: str, options: CommandOptions) -> CommandResult:
        """Runs the command without a pty.

        Raises:
            HostIsNotAvailable: On SSH/connection errors; the cached connection is dropped
                first so that the next command opens a fresh one.
        """
        try:
            stdin, stdout, stderr = self._connection.exec_command(
                command, timeout=options.timeout, environment=self.environment
            )

            if options.close_stdin:
                stdin.close()

            decoded_stdout, decoded_stderr = self._read_channels(stdout.channel, stderr.channel)
            return_code = stdout.channel.recv_exit_status()

            return CommandResult(
                stdout=decoded_stdout,
                stderr=decoded_stderr,
                return_code=return_code,
            )
        except (
            SSHException,
            TimeoutError,
            ssh_exception.NoValidConnectionsError,
            ConnectionResetError,
            AttributeError,
            socket.timeout,
        ) as exc:
            logger.exception(f"Can't execute command {command} on host: {self.host}")
            self.drop()
            raise HostIsNotAvailable(self.host) from exc

    def _read_channels(
        self,
        stdout: Channel,
        stderr: Channel,
        chunk_size: int = 4096,
    ) -> Tuple[str, str]:
        """Reads data from stdout/stderr channels.

        Reading channels is required before we wait for exit status of the remote process.
        Otherwise waiting step will hang indefinitely, see the warning from paramiko docs:
        # https://docs.paramiko.org/en/stable/api/channel.html#paramiko.channel.Channel.recv_exit_status

        Args:
            stdout: Channel of stdout stream of the remote process.
            stderr: Channel of stderr stream of the remote process.
            chunk_size: Max size of data chunk that we read from channel at a time.

        Returns:
            Tuple with stdout and stderr channels decoded into strings.
        """
        # We read data in chunks
        stdout_chunks = []
        stderr_chunks = []

        # Read from channels (if data is ready) until process exits
        while not stdout.exit_status_ready():
            received = False
            if stdout.recv_ready():
                stdout_chunks.append(stdout.recv(chunk_size))
                received = True
            if stderr.recv_stderr_ready():
                stderr_chunks.append(stderr.recv_stderr(chunk_size))
                received = True
            if not received:
                # Nothing to read yet: yield the CPU instead of busy-polling the channel
                sleep(self.CHANNEL_POLL_INTERVAL)

        # Wait for command to complete and flush its buffer before we read final output
        sleep(self.DELAY_AFTER_EXIT)

        # Read the remaining data from the channels:
        # If channel returns empty data chunk, it means that all data has been read
        while data_chunk := stdout.recv(chunk_size):
            stdout_chunks.append(data_chunk)
        while data_chunk := stderr.recv_stderr(chunk_size):
            stderr_chunks.append(data_chunk)

        # Combine chunks and decode results into regular strings
        full_stdout = b"".join(stdout_chunks)
        full_stderr = b"".join(stderr_chunks)

        return (full_stdout.decode(errors="ignore"), full_stderr.decode(errors="ignore"))