Fix hanging of ssh shell
SSH shell was hanging while waiting for remote process exit code. The hanging occurs when stdout/stderr contain large amount of data. The fix changes how we read the data and how we wait for remote process's exit code. Signed-off-by: Vladimir Domnich <v.domnich@yadro.com>
This commit is contained in:
parent
64430486f1
commit
a79b608b4b
1 changed files with 66 additions and 13 deletions
|
@ -4,10 +4,11 @@ import textwrap
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from functools import lru_cache, wraps
|
from functools import lru_cache, wraps
|
||||||
from time import sleep
|
from time import sleep
|
||||||
from typing import ClassVar, Optional
|
from typing import ClassVar, Optional, Tuple
|
||||||
|
|
||||||
from paramiko import (
|
from paramiko import (
|
||||||
AutoAddPolicy,
|
AutoAddPolicy,
|
||||||
|
Channel,
|
||||||
ECDSAKey,
|
ECDSAKey,
|
||||||
Ed25519Key,
|
Ed25519Key,
|
||||||
PKey,
|
PKey,
|
||||||
|
@ -38,7 +39,7 @@ def log_command(func):
|
||||||
def wrapper(shell: "SSHShell", command: str, *args, **kwargs) -> CommandResult:
|
def wrapper(shell: "SSHShell", command: str, *args, **kwargs) -> CommandResult:
|
||||||
command_info = command.removeprefix("$ProgressPreference='SilentlyContinue'\n")
|
command_info = command.removeprefix("$ProgressPreference='SilentlyContinue'\n")
|
||||||
with reporter.step(command_info):
|
with reporter.step(command_info):
|
||||||
logging.info(f'Execute command "{command}" on "{shell.host}"')
|
logger.info(f'Execute command "{command}" on "{shell.host}"')
|
||||||
|
|
||||||
start_time = datetime.utcnow()
|
start_time = datetime.utcnow()
|
||||||
result = func(shell, command, *args, **kwargs)
|
result = func(shell, command, *args, **kwargs)
|
||||||
|
@ -146,17 +147,17 @@ class SSHShell(Shell):
|
||||||
stdin.write(input)
|
stdin.write(input)
|
||||||
except OSError:
|
except OSError:
|
||||||
logger.exception(f"Error while feeding {input} into command {command}")
|
logger.exception(f"Error while feeding {input} into command {command}")
|
||||||
|
|
||||||
if options.close_stdin:
|
if options.close_stdin:
|
||||||
stdin.close()
|
stdin.close()
|
||||||
|
sleep(self.DELAY_AFTER_EXIT)
|
||||||
|
|
||||||
# Wait for command to complete and flush its buffer before we attempt to read output
|
decoded_stdout, decoded_stderr = self._read_channels(stdout.channel, stderr.channel)
|
||||||
sleep(self.DELAY_AFTER_EXIT)
|
|
||||||
return_code = stdout.channel.recv_exit_status()
|
return_code = stdout.channel.recv_exit_status()
|
||||||
sleep(self.DELAY_AFTER_EXIT)
|
|
||||||
|
|
||||||
result = CommandResult(
|
result = CommandResult(
|
||||||
stdout=stdout.read().decode(errors="ignore"),
|
stdout=decoded_stdout,
|
||||||
stderr=stderr.read().decode(errors="ignore"),
|
stderr=decoded_stderr,
|
||||||
return_code=return_code,
|
return_code=return_code,
|
||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
|
@ -169,13 +170,12 @@ class SSHShell(Shell):
|
||||||
if options.close_stdin:
|
if options.close_stdin:
|
||||||
stdin.close()
|
stdin.close()
|
||||||
|
|
||||||
# Wait for command to complete and flush its buffer before we attempt to read output
|
decoded_stdout, decoded_stderr = self._read_channels(stdout.channel, stderr.channel)
|
||||||
return_code = stdout.channel.recv_exit_status()
|
return_code = stdout.channel.recv_exit_status()
|
||||||
sleep(self.DELAY_AFTER_EXIT)
|
|
||||||
|
|
||||||
return CommandResult(
|
return CommandResult(
|
||||||
stdout=stdout.read().decode(errors="ignore"),
|
stdout=decoded_stdout,
|
||||||
stderr=stderr.read().decode(errors="ignore"),
|
stderr=decoded_stderr,
|
||||||
return_code=return_code,
|
return_code=return_code,
|
||||||
)
|
)
|
||||||
except (
|
except (
|
||||||
|
@ -190,13 +190,66 @@ class SSHShell(Shell):
|
||||||
self._reset_connection()
|
self._reset_connection()
|
||||||
raise HostIsNotAvailable(self.host) from exc
|
raise HostIsNotAvailable(self.host) from exc
|
||||||
|
|
||||||
|
def _read_channels(
|
||||||
|
self,
|
||||||
|
stdout: Channel,
|
||||||
|
stderr: Channel,
|
||||||
|
chunk_size: int = 4096,
|
||||||
|
) -> Tuple[str, str]:
|
||||||
|
"""Reads data from stdout/stderr channels.
|
||||||
|
|
||||||
|
Reading channels is required before we wait for exit status of the remote process.
|
||||||
|
Otherwise waiting step will hang indefinitely, see the warning from paramiko docs:
|
||||||
|
# https://docs.paramiko.org/en/stable/api/channel.html#paramiko.channel.Channel.recv_exit_status
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stdout: Channel of stdout stream of the remote process.
|
||||||
|
stderr: Channel of stderr stream of the remote process.
|
||||||
|
chunk_size: Max size of data chunk that we read from channel at a time.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple with stdout and stderr channels decoded into strings.
|
||||||
|
"""
|
||||||
|
# We read data in chunks
|
||||||
|
stdout_chunks = []
|
||||||
|
stderr_chunks = []
|
||||||
|
|
||||||
|
# Read from channels (if data is ready) until process exits
|
||||||
|
while not stdout.exit_status_ready():
|
||||||
|
if stdout.recv_ready():
|
||||||
|
stdout_chunks.append(stdout.recv(chunk_size))
|
||||||
|
if stderr.recv_stderr_ready():
|
||||||
|
stderr_chunks.append(stderr.recv_stderr(chunk_size))
|
||||||
|
|
||||||
|
# Wait for command to complete and flush its buffer before we read final output
|
||||||
|
sleep(self.DELAY_AFTER_EXIT)
|
||||||
|
|
||||||
|
# Read the remaining data from the channels:
|
||||||
|
# If channel returns empty data chunk, it means that all data has been read
|
||||||
|
while True:
|
||||||
|
data_chunk = stdout.recv(chunk_size)
|
||||||
|
if not data_chunk:
|
||||||
|
break
|
||||||
|
stdout_chunks.append(data_chunk)
|
||||||
|
while True:
|
||||||
|
data_chunk = stderr.recv_stderr(chunk_size)
|
||||||
|
if not data_chunk:
|
||||||
|
break
|
||||||
|
stderr_chunks.append(data_chunk)
|
||||||
|
|
||||||
|
# Combine chunks and decode results into regular strings
|
||||||
|
full_stdout = b"".join(stdout_chunks)
|
||||||
|
full_stderr = b"".join(stderr_chunks)
|
||||||
|
|
||||||
|
return (full_stdout.decode(errors="ignore"), full_stderr.decode(errors="ignore"))
|
||||||
|
|
||||||
def _create_connection(self, attempts: int = SSH_CONNECTION_ATTEMPTS) -> SSHClient:
|
def _create_connection(self, attempts: int = SSH_CONNECTION_ATTEMPTS) -> SSHClient:
|
||||||
for attempt in range(attempts):
|
for attempt in range(attempts):
|
||||||
connection = SSHClient()
|
connection = SSHClient()
|
||||||
connection.set_missing_host_key_policy(AutoAddPolicy())
|
connection.set_missing_host_key_policy(AutoAddPolicy())
|
||||||
try:
|
try:
|
||||||
if self.private_key_path:
|
if self.private_key_path:
|
||||||
logging.info(
|
logger.info(
|
||||||
f"Trying to connect to host {self.host} as {self.login} using SSH key "
|
f"Trying to connect to host {self.host} as {self.login} using SSH key "
|
||||||
f"{self.private_key_path} (attempt {attempt})"
|
f"{self.private_key_path} (attempt {attempt})"
|
||||||
)
|
)
|
||||||
|
@ -208,7 +261,7 @@ class SSHShell(Shell):
|
||||||
timeout=self.CONNECTION_TIMEOUT,
|
timeout=self.CONNECTION_TIMEOUT,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logging.info(
|
logger.info(
|
||||||
f"Trying to connect to host {self.host} as {self.login} using password "
|
f"Trying to connect to host {self.host} as {self.login} using password "
|
||||||
f"(attempt {attempt})"
|
f"(attempt {attempt})"
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in a new issue