diff --git a/src/neofs_testlib/shell/ssh_shell.py b/src/neofs_testlib/shell/ssh_shell.py index d56e2c4e..4e918c18 100644 --- a/src/neofs_testlib/shell/ssh_shell.py +++ b/src/neofs_testlib/shell/ssh_shell.py @@ -4,10 +4,11 @@ import textwrap from datetime import datetime from functools import lru_cache, wraps from time import sleep -from typing import ClassVar, Optional +from typing import ClassVar, Optional, Tuple from paramiko import ( AutoAddPolicy, + Channel, ECDSAKey, Ed25519Key, PKey, @@ -38,7 +39,7 @@ def log_command(func): def wrapper(shell: "SSHShell", command: str, *args, **kwargs) -> CommandResult: command_info = command.removeprefix("$ProgressPreference='SilentlyContinue'\n") with reporter.step(command_info): - logging.info(f'Execute command "{command}" on "{shell.host}"') + logger.info(f'Execute command "{command}" on "{shell.host}"') start_time = datetime.utcnow() result = func(shell, command, *args, **kwargs) @@ -146,17 +147,17 @@ class SSHShell(Shell): stdin.write(input) except OSError: logger.exception(f"Error while feeding {input} into command {command}") + if options.close_stdin: stdin.close() + sleep(self.DELAY_AFTER_EXIT) - # Wait for command to complete and flush its buffer before we attempt to read output - sleep(self.DELAY_AFTER_EXIT) + decoded_stdout, decoded_stderr = self._read_channels(stdout.channel, stderr.channel) return_code = stdout.channel.recv_exit_status() - sleep(self.DELAY_AFTER_EXIT) result = CommandResult( - stdout=stdout.read().decode(errors="ignore"), - stderr=stderr.read().decode(errors="ignore"), + stdout=decoded_stdout, + stderr=decoded_stderr, return_code=return_code, ) return result @@ -169,13 +170,12 @@ class SSHShell(Shell): if options.close_stdin: stdin.close() - # Wait for command to complete and flush its buffer before we attempt to read output + decoded_stdout, decoded_stderr = self._read_channels(stdout.channel, stderr.channel) return_code = stdout.channel.recv_exit_status() - sleep(self.DELAY_AFTER_EXIT) return CommandResult( - stdout=stdout.read().decode(errors="ignore"), - stderr=stderr.read().decode(errors="ignore"), + stdout=decoded_stdout, + stderr=decoded_stderr, return_code=return_code, ) except ( @@ -190,13 +190,66 @@ class SSHShell(Shell): self._reset_connection() raise HostIsNotAvailable(self.host) from exc + def _read_channels( + self, + stdout: Channel, + stderr: Channel, + chunk_size: int = 4096, + ) -> Tuple[str, str]: + """Reads data from stdout/stderr channels. + + Reading channels is required before we wait for exit status of the remote process. + Otherwise waiting step will hang indefinitely, see the warning from paramiko docs: + # https://docs.paramiko.org/en/stable/api/channel.html#paramiko.channel.Channel.recv_exit_status + + Args: + stdout: Channel of stdout stream of the remote process. + stderr: Channel of stderr stream of the remote process. + chunk_size: Max size of data chunk that we read from channel at a time. + + Returns: + Tuple with stdout and stderr channels decoded into strings. + """ + # We read data in chunks + stdout_chunks = [] + stderr_chunks = [] + + # Read from channels (if data is ready) until process exits + while not stdout.exit_status_ready(): + if stdout.recv_ready(): + stdout_chunks.append(stdout.recv(chunk_size)) + if stderr.recv_stderr_ready(): + stderr_chunks.append(stderr.recv_stderr(chunk_size)) + + # Wait for command to complete and flush its buffer before we read final output + sleep(self.DELAY_AFTER_EXIT) + + # Read the remaining data from the channels: + # If channel returns empty data chunk, it means that all data has been read + while True: + data_chunk = stdout.recv(chunk_size) + if not data_chunk: + break + stdout_chunks.append(data_chunk) + while True: + data_chunk = stderr.recv_stderr(chunk_size) + if not data_chunk: + break + stderr_chunks.append(data_chunk) + + # Combine chunks and decode results into regular strings + full_stdout = b"".join(stdout_chunks) + full_stderr = b"".join(stderr_chunks) + + return (full_stdout.decode(errors="ignore"), full_stderr.decode(errors="ignore")) + def _create_connection(self, attempts: int = SSH_CONNECTION_ATTEMPTS) -> SSHClient: for attempt in range(attempts): connection = SSHClient() connection.set_missing_host_key_policy(AutoAddPolicy()) try: if self.private_key_path: - logging.info( + logger.info( f"Trying to connect to host {self.host} as {self.login} using SSH key " f"{self.private_key_path} (attempt {attempt})" ) @@ -208,7 +261,7 @@ class SSHShell(Shell): timeout=self.CONNECTION_TIMEOUT, ) else: - logging.info( + logger.info( f"Trying to connect to host {self.host} as {self.login} using password " f"(attempt {attempt})" )