Fix hanging of ssh shell
SSH shell was hanging while waiting for remote process exit code. The hanging occurs when stdout/stderr contain large amount of data. The fix changes how we read the data and how we wait for remote process's exit code. Signed-off-by: Vladimir Domnich <v.domnich@yadro.com>
This commit is contained in:
parent
64430486f1
commit
a79b608b4b
1 changed files with 66 additions and 13 deletions
|
@ -4,10 +4,11 @@ import textwrap
|
|||
from datetime import datetime
|
||||
from functools import lru_cache, wraps
|
||||
from time import sleep
|
||||
from typing import ClassVar, Optional
|
||||
from typing import ClassVar, Optional, Tuple
|
||||
|
||||
from paramiko import (
|
||||
AutoAddPolicy,
|
||||
Channel,
|
||||
ECDSAKey,
|
||||
Ed25519Key,
|
||||
PKey,
|
||||
|
@ -38,7 +39,7 @@ def log_command(func):
|
|||
def wrapper(shell: "SSHShell", command: str, *args, **kwargs) -> CommandResult:
|
||||
command_info = command.removeprefix("$ProgressPreference='SilentlyContinue'\n")
|
||||
with reporter.step(command_info):
|
||||
logging.info(f'Execute command "{command}" on "{shell.host}"')
|
||||
logger.info(f'Execute command "{command}" on "{shell.host}"')
|
||||
|
||||
start_time = datetime.utcnow()
|
||||
result = func(shell, command, *args, **kwargs)
|
||||
|
@ -146,17 +147,17 @@ class SSHShell(Shell):
|
|||
stdin.write(input)
|
||||
except OSError:
|
||||
logger.exception(f"Error while feeding {input} into command {command}")
|
||||
|
||||
if options.close_stdin:
|
||||
stdin.close()
|
||||
sleep(self.DELAY_AFTER_EXIT)
|
||||
|
||||
# Wait for command to complete and flush its buffer before we attempt to read output
|
||||
sleep(self.DELAY_AFTER_EXIT)
|
||||
decoded_stdout, decoded_stderr = self._read_channels(stdout.channel, stderr.channel)
|
||||
return_code = stdout.channel.recv_exit_status()
|
||||
sleep(self.DELAY_AFTER_EXIT)
|
||||
|
||||
result = CommandResult(
|
||||
stdout=stdout.read().decode(errors="ignore"),
|
||||
stderr=stderr.read().decode(errors="ignore"),
|
||||
stdout=decoded_stdout,
|
||||
stderr=decoded_stderr,
|
||||
return_code=return_code,
|
||||
)
|
||||
return result
|
||||
|
@ -169,13 +170,12 @@ class SSHShell(Shell):
|
|||
if options.close_stdin:
|
||||
stdin.close()
|
||||
|
||||
# Wait for command to complete and flush its buffer before we attempt to read output
|
||||
decoded_stdout, decoded_stderr = self._read_channels(stdout.channel, stderr.channel)
|
||||
return_code = stdout.channel.recv_exit_status()
|
||||
sleep(self.DELAY_AFTER_EXIT)
|
||||
|
||||
return CommandResult(
|
||||
stdout=stdout.read().decode(errors="ignore"),
|
||||
stderr=stderr.read().decode(errors="ignore"),
|
||||
stdout=decoded_stdout,
|
||||
stderr=decoded_stderr,
|
||||
return_code=return_code,
|
||||
)
|
||||
except (
|
||||
|
@ -190,13 +190,66 @@ class SSHShell(Shell):
|
|||
self._reset_connection()
|
||||
raise HostIsNotAvailable(self.host) from exc
|
||||
|
||||
def _read_channels(
|
||||
self,
|
||||
stdout: Channel,
|
||||
stderr: Channel,
|
||||
chunk_size: int = 4096,
|
||||
) -> Tuple[str, str]:
|
||||
"""Reads data from stdout/stderr channels.
|
||||
|
||||
Reading channels is required before we wait for exit status of the remote process.
|
||||
Otherwise waiting step will hang indefinitely, see the warning from paramiko docs:
|
||||
# https://docs.paramiko.org/en/stable/api/channel.html#paramiko.channel.Channel.recv_exit_status
|
||||
|
||||
Args:
|
||||
stdout: Channel of stdout stream of the remote process.
|
||||
stderr: Channel of stderr stream of the remote process.
|
||||
chunk_size: Max size of data chunk that we read from channel at a time.
|
||||
|
||||
Returns:
|
||||
Tuple with stdout and stderr channels decoded into strings.
|
||||
"""
|
||||
# We read data in chunks
|
||||
stdout_chunks = []
|
||||
stderr_chunks = []
|
||||
|
||||
# Read from channels (if data is ready) until process exits
|
||||
while not stdout.exit_status_ready():
|
||||
if stdout.recv_ready():
|
||||
stdout_chunks.append(stdout.recv(chunk_size))
|
||||
if stderr.recv_stderr_ready():
|
||||
stderr_chunks.append(stderr.recv_stderr(chunk_size))
|
||||
|
||||
# Wait for command to complete and flush its buffer before we read final output
|
||||
sleep(self.DELAY_AFTER_EXIT)
|
||||
|
||||
# Read the remaining data from the channels:
|
||||
# If channel returns empty data chunk, it means that all data has been read
|
||||
while True:
|
||||
data_chunk = stdout.recv(chunk_size)
|
||||
if not data_chunk:
|
||||
break
|
||||
stdout_chunks.append(data_chunk)
|
||||
while True:
|
||||
data_chunk = stderr.recv_stderr(chunk_size)
|
||||
if not data_chunk:
|
||||
break
|
||||
stderr_chunks.append(data_chunk)
|
||||
|
||||
# Combine chunks and decode results into regular strings
|
||||
full_stdout = b"".join(stdout_chunks)
|
||||
full_stderr = b"".join(stderr_chunks)
|
||||
|
||||
return (full_stdout.decode(errors="ignore"), full_stderr.decode(errors="ignore"))
|
||||
|
||||
def _create_connection(self, attempts: int = SSH_CONNECTION_ATTEMPTS) -> SSHClient:
|
||||
for attempt in range(attempts):
|
||||
connection = SSHClient()
|
||||
connection.set_missing_host_key_policy(AutoAddPolicy())
|
||||
try:
|
||||
if self.private_key_path:
|
||||
logging.info(
|
||||
logger.info(
|
||||
f"Trying to connect to host {self.host} as {self.login} using SSH key "
|
||||
f"{self.private_key_path} (attempt {attempt})"
|
||||
)
|
||||
|
@ -208,7 +261,7 @@ class SSHShell(Shell):
|
|||
timeout=self.CONNECTION_TIMEOUT,
|
||||
)
|
||||
else:
|
||||
logging.info(
|
||||
logger.info(
|
||||
f"Trying to connect to host {self.host} as {self.login} using password "
|
||||
f"(attempt {attempt})"
|
||||
)
|
||||
|
|
Loading…
Reference in a new issue