Keep only one ssh connection per host

Signed-off-by: Andrey Berezin <a.berezin@yadro.com>
This commit is contained in:
Andrey Berezin 2023-10-10 17:47:46 +03:00
parent d039bcc221
commit 2c2af7f8ed
4 changed files with 261 additions and 144 deletions

View file

@ -1,3 +1,3 @@
from frostfs_testlib.shell.interfaces import CommandOptions, CommandResult, InteractiveInput, Shell
from frostfs_testlib.shell.local_shell import LocalShell
from frostfs_testlib.shell.ssh_shell import SSHShell
from frostfs_testlib.shell.ssh_shell import SshConnectionProvider, SSHShell

View file

@ -20,12 +20,117 @@ from paramiko import (
from paramiko.ssh_exception import AuthenticationException
from frostfs_testlib.reporter import get_reporter
from frostfs_testlib.shell.interfaces import CommandInspector, CommandOptions, CommandResult, Shell
from frostfs_testlib.shell.interfaces import (
CommandInspector,
CommandOptions,
CommandResult,
Shell,
SshCredentials,
)
logger = logging.getLogger("frostfs.testlib.shell")
reporter = get_reporter()
class SshConnectionProvider:
SSH_CONNECTION_ATTEMPTS: ClassVar[int] = 4
SSH_ATTEMPTS_INTERVAL: ClassVar[int] = 10
CONNECTION_TIMEOUT = 60
instance = None
connections: dict[str, SSHClient] = {}
creds: dict[str, SshCredentials] = {}
def __new__(cls):
if not cls.instance:
cls.instance = super(SshConnectionProvider, cls).__new__(cls)
return cls.instance
def store_creds(self, host: str, ssh_creds: SshCredentials):
self.creds[host] = ssh_creds
def provide(self, host: str, port: str) -> SSHClient:
if host not in self.creds:
raise RuntimeError(f"Please add credentials for host {host}")
if host in self.connections:
client = self.connections[host]
if client:
return client
creds = self.creds[host]
client = self._create_connection(host, port, creds)
self.connections[host] = client
return client
def drop(self, host: str):
if host in self.connections:
client = self.connections.pop(host)
client.close()
def drop_all(self):
hosts = list(self.connections.keys())
for host in hosts:
self.drop(host)
def _create_connection(
self,
host: str,
port: str,
creds: SshCredentials,
) -> SSHClient:
for attempt in range(self.SSH_CONNECTION_ATTEMPTS):
connection = SSHClient()
connection.set_missing_host_key_policy(AutoAddPolicy())
try:
if creds.ssh_key_path:
logger.info(
f"Trying to connect to host {host} as {creds.ssh_login} using SSH key "
f"{creds.ssh_key_path} (attempt {attempt})"
)
connection.connect(
hostname=host,
port=port,
username=creds.ssh_login,
pkey=_load_private_key(creds.ssh_key_path, creds.ssh_key_passphrase),
timeout=self.CONNECTION_TIMEOUT,
)
else:
logger.info(
f"Trying to connect to host {host} as {creds.ssh_login} using password "
f"(attempt {attempt})"
)
connection.connect(
hostname=host,
port=port,
username=creds.ssh_login,
password=creds.ssh_password,
timeout=self.CONNECTION_TIMEOUT,
)
return connection
except AuthenticationException:
connection.close()
logger.exception(f"Can't connect to host {host}")
raise
except (
SSHException,
ssh_exception.NoValidConnectionsError,
AttributeError,
socket.timeout,
OSError,
) as exc:
connection.close()
can_retry = attempt + 1 < self.SSH_CONNECTION_ATTEMPTS
if can_retry:
logger.warn(
f"Can't connect to host {host}, will retry after {self.SSH_ATTEMPTS_INTERVAL}s. Error: {exc}"
)
sleep(self.SSH_ATTEMPTS_INTERVAL)
continue
logger.exception(f"Can't connect to host {host}")
raise HostIsNotAvailable(host) from exc
class HostIsNotAvailable(Exception):
"""Raised when host is not reachable via SSH connection."""
@ -91,10 +196,6 @@ class SSHShell(Shell):
# to allow remote command to flush its output buffer
DELAY_AFTER_EXIT = 0.2
SSH_CONNECTION_ATTEMPTS: ClassVar[int] = 4
SSH_ATTEMPTS_INTERVAL: ClassVar[int] = 10
CONNECTION_TIMEOUT = 60
def __init__(
self,
host: str,
@ -106,23 +207,21 @@ class SSHShell(Shell):
command_inspectors: Optional[list[CommandInspector]] = None,
) -> None:
super().__init__()
self.connection_provider = SshConnectionProvider()
self.connection_provider.store_creds(
host, SshCredentials(login, password, private_key_path, private_key_passphrase)
)
self.host = host
self.port = port
self.login = login
self.password = password
self.private_key_path = private_key_path
self.private_key_passphrase = private_key_passphrase
self.command_inspectors = command_inspectors or []
self.__connection: Optional[SSHClient] = None
@property
def _connection(self):
if not self.__connection:
self.__connection = self._create_connection()
return self.__connection
return self.connection_provider.provide(self.host, self.port)
def drop(self):
self._reset_connection()
self.connection_provider.drop(self.host)
def exec(self, command: str, options: Optional[CommandOptions] = None) -> CommandResult:
options = options or CommandOptions()
@ -196,7 +295,7 @@ class SSHShell(Shell):
socket.timeout,
) as exc:
logger.exception(f"Can't execute command {command} on host: {self.host}")
self._reset_connection()
self.drop()
raise HostIsNotAvailable(self.host) from exc
def _read_channels(
@ -251,62 +350,3 @@ class SSHShell(Shell):
full_stderr = b"".join(stderr_chunks)
return (full_stdout.decode(errors="ignore"), full_stderr.decode(errors="ignore"))
def _create_connection(
self, attempts: int = SSH_CONNECTION_ATTEMPTS, interval: int = SSH_ATTEMPTS_INTERVAL
) -> SSHClient:
for attempt in range(attempts):
connection = SSHClient()
connection.set_missing_host_key_policy(AutoAddPolicy())
try:
if self.private_key_path:
logger.info(
f"Trying to connect to host {self.host} as {self.login} using SSH key "
f"{self.private_key_path} (attempt {attempt})"
)
connection.connect(
hostname=self.host,
port=self.port,
username=self.login,
pkey=_load_private_key(self.private_key_path, self.private_key_passphrase),
timeout=self.CONNECTION_TIMEOUT,
)
else:
logger.info(
f"Trying to connect to host {self.host} as {self.login} using password "
f"(attempt {attempt})"
)
connection.connect(
hostname=self.host,
port=self.port,
username=self.login,
password=self.password,
timeout=self.CONNECTION_TIMEOUT,
)
return connection
except AuthenticationException:
connection.close()
logger.exception(f"Can't connect to host {self.host}")
raise
except (
SSHException,
ssh_exception.NoValidConnectionsError,
AttributeError,
socket.timeout,
OSError,
) as exc:
connection.close()
can_retry = attempt + 1 < attempts
if can_retry:
logger.warn(
f"Can't connect to host {self.host}, will retry after {interval}s. Error: {exc}"
)
sleep(interval)
continue
logger.exception(f"Can't connect to host {self.host}")
raise HostIsNotAvailable(self.host) from exc
def _reset_connection(self) -> None:
if self.__connection:
self.__connection.close()
self.__connection = None

View file

@ -3,7 +3,7 @@ import time
import frostfs_testlib.resources.optionals as optionals
from frostfs_testlib.reporter import get_reporter
from frostfs_testlib.shell import CommandOptions, Shell
from frostfs_testlib.shell import CommandOptions, Shell, SshConnectionProvider
from frostfs_testlib.steps.network import IfUpDownHelper, IpTablesHelper
from frostfs_testlib.storage.cluster import Cluster, ClusterNode, StorageNode
from frostfs_testlib.storage.controllers.disk_controller import DiskController
@ -37,6 +37,10 @@ class ClusterStateController:
@run_optionally(optionals.OPTIONAL_FAILOVER_ENABLED)
@reporter.step_deco("Stop host of node {node}")
def stop_node_host(self, node: ClusterNode, mode: str):
# Drop ssh connection for this node before shutdown
provider = SshConnectionProvider()
provider.drop(node.host_ip)
with reporter.step(f"Stop host {node.host.config.address}"):
node.host.stop_host(mode=mode)
wait_for_host_offline(self.shell, node.storage_node)
@ -48,6 +52,11 @@ class ClusterStateController:
nodes = (
reversed(self.cluster.cluster_nodes) if reversed_order else self.cluster.cluster_nodes
)
# Drop all ssh connections before shutdown
provider = SshConnectionProvider()
provider.drop_all()
for node in nodes:
with reporter.step(f"Stop host {node.host.config.address}"):
self.stopped_nodes.append(node)
@ -307,6 +316,10 @@ class ClusterStateController:
options = CommandOptions(close_stdin=True, timeout=1, check=False)
shell.exec('sudo sh -c "echo b > /proc/sysrq-trigger"', options)
# Drop ssh connection for this node
provider = SshConnectionProvider()
provider.drop(node.host_ip)
if wait_for_return:
# Let the things to be settled
# A little wait here to prevent ssh stuck during panic