# Forked from TrueCloudLab/frostfs-testlib (333 lines, 12 KiB, Python).
import logging
import socket
import textwrap
from datetime import datetime, timezone
from functools import lru_cache, wraps
from time import sleep
from typing import ClassVar, Optional, Tuple

from paramiko import AutoAddPolicy, Channel, ECDSAKey, Ed25519Key, PKey, RSAKey, SSHClient, SSHException, ssh_exception
from paramiko.ssh_exception import AuthenticationException

from frostfs_testlib import reporter
from frostfs_testlib.shell.interfaces import CommandInspector, CommandOptions, CommandResult, Shell, SshCredentials
|
|
|
|
# Module-level logger used by the SSH connection and shell helpers defined in this file.
logger = logging.getLogger("frostfs.testlib.shell")
|
|
|
|
|
|
class SshConnectionProvider:
    """Shared cache of SSH credentials and open SSH connections, keyed by host.

    ``__new__`` always returns the same object and the ``connections`` / ``creds``
    dictionaries are class attributes, so every user of this class works with the same
    stored credentials and the same ``SSHClient`` per host.
    """

    # Number of connection attempts made before the host is reported as unavailable.
    SSH_CONNECTION_ATTEMPTS: ClassVar[int] = 4
    # Pause between two consecutive connection attempts, in seconds.
    SSH_ATTEMPTS_INTERVAL: ClassVar[int] = 10
    # Value passed as ``timeout`` to ``SSHClient.connect``.
    CONNECTION_TIMEOUT = 60

    # The single shared instance, created lazily by ``__new__``.
    instance = None
    # Open clients by host.
    connections: dict[str, SSHClient] = {}
    # Credentials by host; must be stored before a connection can be provided.
    creds: dict[str, SshCredentials] = {}

    def __new__(cls):
        """Return the shared instance, creating it on first use."""
        if not cls.instance:
            cls.instance = super(SshConnectionProvider, cls).__new__(cls)
        return cls.instance

    def store_creds(self, host: str, ssh_creds: SshCredentials):
        """Remember credentials for ``host``, replacing any previously stored ones."""
        self.creds[host] = ssh_creds

    def provide(self, host: str, port: str) -> SSHClient:
        """Return a connected client for ``host``, opening a connection if none is cached.

        Args:
            host: Host to connect to; also the cache key.
            port: SSH port, used only when a new connection has to be opened.

        Returns:
            The cached client for the host, or a freshly connected one.

        Raises:
            RuntimeError: If no credentials were stored for the host.
            HostIsNotAvailable: If the connection could not be established.
        """
        if host not in self.creds:
            raise RuntimeError(f"Please add credentials for host {host}")

        cached_client = self.connections.get(host)
        if cached_client:
            return cached_client

        client = self._create_connection(host, port, self.creds[host])
        self.connections[host] = client
        return client

    def drop(self, host: str):
        """Close and forget the cached connection for ``host``; no-op if there is none."""
        client = self.connections.pop(host, None)
        if client is not None:
            client.close()

    def drop_all(self):
        """Close and forget every cached connection."""
        # Iterate over a snapshot because drop() mutates the dictionary.
        for host in list(self.connections):
            self.drop(host)

    def _create_connection(
        self,
        host: str,
        port: str,
        creds: SshCredentials,
    ) -> SSHClient:
        """Open a new SSH connection, retrying on connection-level failures.

        Key-based authentication is used when ``creds.ssh_key_path`` is set, password
        authentication otherwise.

        Args:
            host: Host to connect to.
            port: SSH port of the host.
            creds: Login and either key or password to authenticate with.

        Returns:
            A connected client.

        Raises:
            AuthenticationException: Re-raised immediately, without retrying.
            HostIsNotAvailable: If all attempts failed, or no attempt was configured.
        """
        for attempt in range(self.SSH_CONNECTION_ATTEMPTS):
            connection = SSHClient()
            connection.set_missing_host_key_policy(AutoAddPolicy())
            try:
                if creds.ssh_key_path:
                    logger.info(
                        f"Trying to connect to host {host} as {creds.ssh_login} using SSH key "
                        f"{creds.ssh_key_path} (attempt {attempt})"
                    )
                    connection.connect(
                        hostname=host,
                        port=port,
                        username=creds.ssh_login,
                        pkey=_load_private_key(creds.ssh_key_path, creds.ssh_key_passphrase),
                        timeout=self.CONNECTION_TIMEOUT,
                    )
                else:
                    logger.info(
                        f"Trying to connect to host {host} as {creds.ssh_login} using password (attempt {attempt})"
                    )
                    connection.connect(
                        hostname=host,
                        port=port,
                        username=creds.ssh_login,
                        password=creds.ssh_password,
                        timeout=self.CONNECTION_TIMEOUT,
                    )
                return connection
            except AuthenticationException:
                # Handled before the broader clause below so that it is never retried.
                connection.close()
                logger.exception(f"Can't connect to host {host}")
                raise
            except (
                SSHException,
                ssh_exception.NoValidConnectionsError,
                AttributeError,
                socket.timeout,
                OSError,
            ) as exc:
                connection.close()
                can_retry = attempt + 1 < self.SSH_CONNECTION_ATTEMPTS
                if can_retry:
                    logger.warning(
                        f"Can't connect to host {host}, will retry after {self.SSH_ATTEMPTS_INTERVAL}s. Error: {exc}"
                    )
                    sleep(self.SSH_ATTEMPTS_INTERVAL)
                    continue
                logger.exception(f"Can't connect to host {host}")
                raise HostIsNotAvailable(host) from exc

        # Reached only when SSH_CONNECTION_ATTEMPTS is not positive: fail explicitly
        # instead of returning None in place of a client.
        raise HostIsNotAvailable(host)
|
|
|
|
|
|
class HostIsNotAvailable(Exception):
    """Raised when host is not reachable via SSH connection.

    Attributes:
        host: The host value given to the constructor; ``None`` when none was supplied.
    """

    def __init__(self, host: Optional[str] = None):
        # Keep the host on the exception so handlers can read it without parsing the message.
        self.host = host
        super().__init__(f"Host {host} is not available")
|
|
|
|
|
|
def log_command(func):
    """Decorate a shell command executor so that each call is logged and reported.

    The wrapped callable is invoked as ``func(shell, command, options, *args, **kwargs)``
    inside a reporter step. After it returns, a summary containing the host, command,
    return code, stdout, stderr and timing is attached to the report; the same summary is
    also written to the log unless ``options.no_log`` is set.

    Args:
        func: Executor to wrap; its result must expose ``return_code``, ``stdout`` and ``stderr``.

    Returns:
        The wrapping function, which returns whatever ``func`` returned.
    """

    @wraps(func)
    def wrapper(shell: "SSHShell", command: str, options: CommandOptions, *args, **kwargs) -> CommandResult:
        # The prefix is stripped from the step title only; the command itself runs unchanged.
        command_info = command.removeprefix("$ProgressPreference='SilentlyContinue'\n")
        with reporter.step(command_info):
            logger.info(f'Execute command "{command}" on "{shell.host}"')

            # Timezone-aware replacement for the deprecated datetime.utcnow(); .time() below
            # drops the tzinfo, so the rendered timestamps keep their previous format.
            start_time = datetime.now(timezone.utc)
            result = func(shell, command, options, *args, **kwargs)
            end_time = datetime.now(timezone.utc)

            elapsed_time = end_time - start_time
            log_message = (
                f"HOST: {shell.host}\n"
                f"COMMAND:\n{textwrap.indent(command, ' ')}\n"
                f"RC:\n {result.return_code}\n"
                f"STDOUT:\n{textwrap.indent(result.stdout, ' ')}\n"
                f"STDERR:\n{textwrap.indent(result.stderr, ' ')}\n"
                f"Start / End / Elapsed\t {start_time.time()} / {end_time.time()} / {elapsed_time}"
            )

            if not options.no_log:
                logger.info(log_message)

            reporter.attach(log_message, "SSH command.txt")
            return result

    return wrapper
|
|
|
|
|
|
@lru_cache
def _load_private_key(file_path: str, password: Optional[str]) -> PKey:
    """Load a private key from the specified file.

    Several key formats are supported, but paramiko offers no way to determine the key type
    in advance. Each supported format is therefore tried in turn, and the result is cached
    so that the type does not have to be figured out again on subsequent calls.

    Args:
        file_path: Path to the private key file.
        password: Passphrase of the key, or ``None``.

    Returns:
        The key loaded with the first format that accepted the file.

    Raises:
        SSHException: If none of the supported formats could load the file.
    """
    logger.debug(f"Loading ssh key from {file_path}")
    for key_type in (Ed25519Key, ECDSAKey, RSAKey):
        try:
            return key_type.from_private_key_file(file_path, password)
        except SSHException as ex:
            # A failure here only means "not this format"; move on to the next one.
            logger.warning(f"SSH key {file_path} can't be loaded with {key_type}: {ex}")
    raise SSHException(f"SSH key {file_path} is not supported")
|
|
|
|
|
|
class SSHShell(Shell):
    """Implements command shell on a remote machine via SSH connection."""

    # Time in seconds to delay after remote command has completed. The delay is required
    # to allow remote command to flush its output buffer
    DELAY_AFTER_EXIT = 0.2

    # Time in seconds to pause between polls of the channels while the remote command is
    # still running and has produced no new output
    POLL_INTERVAL = 0.01

    def __init__(
        self,
        host: str,
        login: str,
        password: Optional[str] = None,
        private_key_path: Optional[str] = None,
        private_key_passphrase: Optional[str] = None,
        port: str = "22",
        command_inspectors: Optional[list[CommandInspector]] = None,
        custom_environment: Optional[dict] = None,
    ) -> None:
        """Register credentials for the host; the connection itself is opened on first use.

        Args:
            host: Remote host to run commands on.
            login: User name stored in the SSH credentials.
            password: Password stored in the SSH credentials.
            private_key_path: Private key path stored in the SSH credentials.
            private_key_passphrase: Key passphrase stored in the SSH credentials.
            port: SSH port handed to the connection provider.
            command_inspectors: Inspectors applied to every command before execution.
            custom_environment: Passed as ``environment`` to each executed command.
        """
        super().__init__()
        self.connection_provider = SshConnectionProvider()
        self.connection_provider.store_creds(
            host, SshCredentials(login, password, private_key_path, private_key_passphrase)
        )
        self.host = host
        self.port = port

        self.command_inspectors = command_inspectors or []

        self.environment = custom_environment

    @property
    def _connection(self) -> SSHClient:
        """Client for this shell's host, obtained from the connection provider."""
        return self.connection_provider.provide(self.host, self.port)

    def drop(self):
        """Ask the connection provider to drop the connection to this shell's host."""
        self.connection_provider.drop(self.host)

    def exec(self, command: str, options: Optional[CommandOptions] = None) -> CommandResult:
        """Run a command on the remote host.

        The shell's inspectors and then ``options.extra_inspectors`` are applied in order;
        each receives the original command and the command produced so far. The command is
        executed interactively when ``options.interactive_inputs`` is set.

        Args:
            command: Command to execute.
            options: Execution options; defaults to ``CommandOptions()``.

        Returns:
            Result of the executed command.

        Raises:
            RuntimeError: If ``options.check`` is set and the return code is not zero.
        """
        options = options or CommandOptions()

        original_command = command
        extra_inspectors = options.extra_inspectors or []
        for inspector in [*self.command_inspectors, *extra_inspectors]:
            command = inspector.inspect(original_command, command)

        if options.interactive_inputs:
            result = self._exec_interactive(command, options)
        else:
            result = self._exec_non_interactive(command, options)

        if options.check and result.return_code != 0:
            raise RuntimeError(
                f"Command: {command}\nreturn code: {result.return_code}\nOutput: {result.stdout}\nStderr: {result.stderr}\n"
            )
        return result

    @log_command
    def _exec_interactive(self, command: str, options: CommandOptions) -> CommandResult:
        """Run a command with a PTY and feed it ``options.interactive_inputs`` through stdin.

        Each input is written in order and terminated with a newline if it lacks one. An
        ``OSError`` raised while writing is logged and the remaining inputs are still sent.
        """
        stdin, stdout, stderr = self._connection.exec_command(
            command,
            timeout=options.timeout,
            get_pty=True,
            environment=self.environment,
        )
        for interactive_input in options.interactive_inputs:
            input_text = interactive_input.input
            if not input_text.endswith("\n"):
                input_text = f"{input_text}\n"
            try:
                stdin.write(input_text)
            except OSError:
                logger.exception(f"Error while feeding {input_text} into command {command}")

        if options.close_stdin:
            stdin.close()
            sleep(self.DELAY_AFTER_EXIT)

        decoded_stdout, decoded_stderr = self._read_channels(stdout.channel, stderr.channel)
        return_code = stdout.channel.recv_exit_status()

        return CommandResult(
            stdout=decoded_stdout,
            stderr=decoded_stderr,
            return_code=return_code,
        )

    @log_command
    def _exec_non_interactive(self, command: str, options: CommandOptions) -> CommandResult:
        """Run a command without feeding it any input.

        Raises:
            HostIsNotAvailable: If one of the exceptions listed below occurs while executing
                the command; the connection is dropped first so the next call reconnects.
        """
        try:
            stdin, stdout, stderr = self._connection.exec_command(
                command,
                timeout=options.timeout,
                environment=self.environment,
            )

            if options.close_stdin:
                stdin.close()

            decoded_stdout, decoded_stderr = self._read_channels(stdout.channel, stderr.channel)
            return_code = stdout.channel.recv_exit_status()

            return CommandResult(
                stdout=decoded_stdout,
                stderr=decoded_stderr,
                return_code=return_code,
            )
        except (
            SSHException,
            TimeoutError,
            ssh_exception.NoValidConnectionsError,
            ConnectionResetError,
            AttributeError,
            socket.timeout,
        ) as exc:
            logger.exception(f"Can't execute command {command} on host: {self.host}")
            self.drop()
            raise HostIsNotAvailable(self.host) from exc

    def _read_channels(
        self,
        stdout: Channel,
        stderr: Channel,
        chunk_size: int = 4096,
    ) -> Tuple[str, str]:
        """Reads data from stdout/stderr channels.

        Reading channels is required before we wait for exit status of the remote process.
        Otherwise waiting step will hang indefinitely, see the warning from paramiko docs:
        # https://docs.paramiko.org/en/stable/api/channel.html#paramiko.channel.Channel.recv_exit_status

        Args:
            stdout: Channel of stdout stream of the remote process.
            stderr: Channel of stderr stream of the remote process.
            chunk_size: Max size of data chunk that we read from channel at a time.

        Returns:
            Tuple with stdout and stderr channels decoded into strings.
        """
        # We read data in chunks
        stdout_chunks = []
        stderr_chunks = []

        # Read from channels (if data is ready) until process exits
        while not stdout.exit_status_ready():
            received = False
            if stdout.recv_ready():
                stdout_chunks.append(stdout.recv(chunk_size))
                received = True
            if stderr.recv_stderr_ready():
                stderr_chunks.append(stderr.recv_stderr(chunk_size))
                received = True
            if not received:
                # Nothing to read yet: pause briefly instead of spinning in a tight loop
                sleep(self.POLL_INTERVAL)

        # Wait for command to complete and flush its buffer before we read final output
        sleep(self.DELAY_AFTER_EXIT)

        # Read the remaining data from the channels:
        # If channel returns empty data chunk, it means that all data has been read
        while True:
            data_chunk = stdout.recv(chunk_size)
            if not data_chunk:
                break
            stdout_chunks.append(data_chunk)
        while True:
            data_chunk = stderr.recv_stderr(chunk_size)
            if not data_chunk:
                break
            stderr_chunks.append(data_chunk)

        # Combine chunks and decode results into regular strings
        full_stdout = b"".join(stdout_chunks)
        full_stderr = b"".join(stderr_chunks)

        return (full_stdout.decode(errors="ignore"), full_stderr.decode(errors="ignore"))
|