Fix hanging of ssh shell

SSH shell was hanging while waiting for remote process exit code. The
hanging occurs when stdout/stderr contain large amount of data. The fix
changes how we read the data and how we wait for remote process's exit
code.

Signed-off-by: Vladimir Domnich <v.domnich@yadro.com>
This commit is contained in:
Vladimir Domnich 2022-10-17 14:15:25 +04:00 committed by Vladimir
parent 64430486f1
commit a79b608b4b

View file

@ -4,10 +4,11 @@ import textwrap
from datetime import datetime
from functools import lru_cache, wraps
from time import sleep
from typing import ClassVar, Optional
from typing import ClassVar, Optional, Tuple
from paramiko import (
AutoAddPolicy,
Channel,
ECDSAKey,
Ed25519Key,
PKey,
@ -38,7 +39,7 @@ def log_command(func):
def wrapper(shell: "SSHShell", command: str, *args, **kwargs) -> CommandResult:
command_info = command.removeprefix("$ProgressPreference='SilentlyContinue'\n")
with reporter.step(command_info):
logging.info(f'Execute command "{command}" on "{shell.host}"')
logger.info(f'Execute command "{command}" on "{shell.host}"')
start_time = datetime.utcnow()
result = func(shell, command, *args, **kwargs)
@ -146,17 +147,17 @@ class SSHShell(Shell):
stdin.write(input)
except OSError:
logger.exception(f"Error while feeding {input} into command {command}")
if options.close_stdin:
stdin.close()
sleep(self.DELAY_AFTER_EXIT)
# Wait for command to complete and flush its buffer before we attempt to read output
sleep(self.DELAY_AFTER_EXIT)
decoded_stdout, decoded_stderr = self._read_channels(stdout.channel, stderr.channel)
return_code = stdout.channel.recv_exit_status()
sleep(self.DELAY_AFTER_EXIT)
result = CommandResult(
stdout=stdout.read().decode(errors="ignore"),
stderr=stderr.read().decode(errors="ignore"),
stdout=decoded_stdout,
stderr=decoded_stderr,
return_code=return_code,
)
return result
@ -169,13 +170,12 @@ class SSHShell(Shell):
if options.close_stdin:
stdin.close()
# Wait for command to complete and flush its buffer before we attempt to read output
decoded_stdout, decoded_stderr = self._read_channels(stdout.channel, stderr.channel)
return_code = stdout.channel.recv_exit_status()
sleep(self.DELAY_AFTER_EXIT)
return CommandResult(
stdout=stdout.read().decode(errors="ignore"),
stderr=stderr.read().decode(errors="ignore"),
stdout=decoded_stdout,
stderr=decoded_stderr,
return_code=return_code,
)
except (
@ -190,13 +190,66 @@ class SSHShell(Shell):
self._reset_connection()
raise HostIsNotAvailable(self.host) from exc
def _read_channels(
self,
stdout: Channel,
stderr: Channel,
chunk_size: int = 4096,
) -> Tuple[str, str]:
"""Reads data from stdout/stderr channels.
Reading channels is required before we wait for exit status of the remote process.
Otherwise waiting step will hang indefinitely, see the warning from paramiko docs:
# https://docs.paramiko.org/en/stable/api/channel.html#paramiko.channel.Channel.recv_exit_status
Args:
stdout: Channel of stdout stream of the remote process.
stderr: Channel of stderr stream of the remote process.
chunk_size: Max size of data chunk that we read from channel at a time.
Returns:
Tuple with stdout and stderr channels decoded into strings.
"""
# We read data in chunks
stdout_chunks = []
stderr_chunks = []
# Read from channels (if data is ready) until process exits
while not stdout.exit_status_ready():
if stdout.recv_ready():
stdout_chunks.append(stdout.recv(chunk_size))
if stderr.recv_stderr_ready():
stderr_chunks.append(stderr.recv_stderr(chunk_size))
# Wait for command to complete and flush its buffer before we read final output
sleep(self.DELAY_AFTER_EXIT)
# Read the remaining data from the channels:
# If channel returns empty data chunk, it means that all data has been read
while True:
data_chunk = stdout.recv(chunk_size)
if not data_chunk:
break
stdout_chunks.append(data_chunk)
while True:
data_chunk = stderr.recv_stderr(chunk_size)
if not data_chunk:
break
stderr_chunks.append(data_chunk)
# Combine chunks and decode results into regular strings
full_stdout = b"".join(stdout_chunks)
full_stderr = b"".join(stderr_chunks)
return (full_stdout.decode(errors="ignore"), full_stderr.decode(errors="ignore"))
def _create_connection(self, attempts: int = SSH_CONNECTION_ATTEMPTS) -> SSHClient:
for attempt in range(attempts):
connection = SSHClient()
connection.set_missing_host_key_policy(AutoAddPolicy())
try:
if self.private_key_path:
logging.info(
logger.info(
f"Trying to connect to host {self.host} as {self.login} using SSH key "
f"{self.private_key_path} (attempt {attempt})"
)
@ -208,7 +261,7 @@ class SSHShell(Shell):
timeout=self.CONNECTION_TIMEOUT,
)
else:
logging.info(
logger.info(
f"Trying to connect to host {self.host} as {self.login} using password "
f"(attempt {attempt})"
)