forked from TrueCloudLab/frostfs-testlib
Keep only one ssh connection per host
Signed-off-by: Andrey Berezin <a.berezin@yadro.com>
This commit is contained in:
parent
d039bcc221
commit
2c2af7f8ed
4 changed files with 261 additions and 144 deletions
|
@ -1,3 +1,3 @@
|
|||
from frostfs_testlib.shell.interfaces import CommandOptions, CommandResult, InteractiveInput, Shell
|
||||
from frostfs_testlib.shell.local_shell import LocalShell
|
||||
from frostfs_testlib.shell.ssh_shell import SSHShell
|
||||
from frostfs_testlib.shell.ssh_shell import SshConnectionProvider, SSHShell
|
||||
|
|
|
@ -20,12 +20,117 @@ from paramiko import (
|
|||
from paramiko.ssh_exception import AuthenticationException
|
||||
|
||||
from frostfs_testlib.reporter import get_reporter
|
||||
from frostfs_testlib.shell.interfaces import CommandInspector, CommandOptions, CommandResult, Shell
|
||||
from frostfs_testlib.shell.interfaces import (
|
||||
CommandInspector,
|
||||
CommandOptions,
|
||||
CommandResult,
|
||||
Shell,
|
||||
SshCredentials,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("frostfs.testlib.shell")
|
||||
reporter = get_reporter()
|
||||
|
||||
|
||||
class SshConnectionProvider:
|
||||
SSH_CONNECTION_ATTEMPTS: ClassVar[int] = 4
|
||||
SSH_ATTEMPTS_INTERVAL: ClassVar[int] = 10
|
||||
CONNECTION_TIMEOUT = 60
|
||||
|
||||
instance = None
|
||||
connections: dict[str, SSHClient] = {}
|
||||
creds: dict[str, SshCredentials] = {}
|
||||
|
||||
def __new__(cls):
|
||||
if not cls.instance:
|
||||
cls.instance = super(SshConnectionProvider, cls).__new__(cls)
|
||||
return cls.instance
|
||||
|
||||
def store_creds(self, host: str, ssh_creds: SshCredentials):
|
||||
self.creds[host] = ssh_creds
|
||||
|
||||
def provide(self, host: str, port: str) -> SSHClient:
|
||||
if host not in self.creds:
|
||||
raise RuntimeError(f"Please add credentials for host {host}")
|
||||
|
||||
if host in self.connections:
|
||||
client = self.connections[host]
|
||||
if client:
|
||||
return client
|
||||
|
||||
creds = self.creds[host]
|
||||
client = self._create_connection(host, port, creds)
|
||||
self.connections[host] = client
|
||||
return client
|
||||
|
||||
def drop(self, host: str):
|
||||
if host in self.connections:
|
||||
client = self.connections.pop(host)
|
||||
client.close()
|
||||
|
||||
def drop_all(self):
|
||||
hosts = list(self.connections.keys())
|
||||
for host in hosts:
|
||||
self.drop(host)
|
||||
|
||||
def _create_connection(
|
||||
self,
|
||||
host: str,
|
||||
port: str,
|
||||
creds: SshCredentials,
|
||||
) -> SSHClient:
|
||||
for attempt in range(self.SSH_CONNECTION_ATTEMPTS):
|
||||
connection = SSHClient()
|
||||
connection.set_missing_host_key_policy(AutoAddPolicy())
|
||||
try:
|
||||
if creds.ssh_key_path:
|
||||
logger.info(
|
||||
f"Trying to connect to host {host} as {creds.ssh_login} using SSH key "
|
||||
f"{creds.ssh_key_path} (attempt {attempt})"
|
||||
)
|
||||
connection.connect(
|
||||
hostname=host,
|
||||
port=port,
|
||||
username=creds.ssh_login,
|
||||
pkey=_load_private_key(creds.ssh_key_path, creds.ssh_key_passphrase),
|
||||
timeout=self.CONNECTION_TIMEOUT,
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"Trying to connect to host {host} as {creds.ssh_login} using password "
|
||||
f"(attempt {attempt})"
|
||||
)
|
||||
connection.connect(
|
||||
hostname=host,
|
||||
port=port,
|
||||
username=creds.ssh_login,
|
||||
password=creds.ssh_password,
|
||||
timeout=self.CONNECTION_TIMEOUT,
|
||||
)
|
||||
return connection
|
||||
except AuthenticationException:
|
||||
connection.close()
|
||||
logger.exception(f"Can't connect to host {host}")
|
||||
raise
|
||||
except (
|
||||
SSHException,
|
||||
ssh_exception.NoValidConnectionsError,
|
||||
AttributeError,
|
||||
socket.timeout,
|
||||
OSError,
|
||||
) as exc:
|
||||
connection.close()
|
||||
can_retry = attempt + 1 < self.SSH_CONNECTION_ATTEMPTS
|
||||
if can_retry:
|
||||
logger.warn(
|
||||
f"Can't connect to host {host}, will retry after {self.SSH_ATTEMPTS_INTERVAL}s. Error: {exc}"
|
||||
)
|
||||
sleep(self.SSH_ATTEMPTS_INTERVAL)
|
||||
continue
|
||||
logger.exception(f"Can't connect to host {host}")
|
||||
raise HostIsNotAvailable(host) from exc
|
||||
|
||||
|
||||
class HostIsNotAvailable(Exception):
|
||||
"""Raised when host is not reachable via SSH connection."""
|
||||
|
||||
|
@ -91,10 +196,6 @@ class SSHShell(Shell):
|
|||
# to allow remote command to flush its output buffer
|
||||
DELAY_AFTER_EXIT = 0.2
|
||||
|
||||
SSH_CONNECTION_ATTEMPTS: ClassVar[int] = 4
|
||||
SSH_ATTEMPTS_INTERVAL: ClassVar[int] = 10
|
||||
CONNECTION_TIMEOUT = 60
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: str,
|
||||
|
@ -106,23 +207,21 @@ class SSHShell(Shell):
|
|||
command_inspectors: Optional[list[CommandInspector]] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.connection_provider = SshConnectionProvider()
|
||||
self.connection_provider.store_creds(
|
||||
host, SshCredentials(login, password, private_key_path, private_key_passphrase)
|
||||
)
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.login = login
|
||||
self.password = password
|
||||
self.private_key_path = private_key_path
|
||||
self.private_key_passphrase = private_key_passphrase
|
||||
|
||||
self.command_inspectors = command_inspectors or []
|
||||
self.__connection: Optional[SSHClient] = None
|
||||
|
||||
@property
|
||||
def _connection(self):
|
||||
if not self.__connection:
|
||||
self.__connection = self._create_connection()
|
||||
return self.__connection
|
||||
return self.connection_provider.provide(self.host, self.port)
|
||||
|
||||
def drop(self):
|
||||
self._reset_connection()
|
||||
self.connection_provider.drop(self.host)
|
||||
|
||||
def exec(self, command: str, options: Optional[CommandOptions] = None) -> CommandResult:
|
||||
options = options or CommandOptions()
|
||||
|
@ -196,7 +295,7 @@ class SSHShell(Shell):
|
|||
socket.timeout,
|
||||
) as exc:
|
||||
logger.exception(f"Can't execute command {command} on host: {self.host}")
|
||||
self._reset_connection()
|
||||
self.drop()
|
||||
raise HostIsNotAvailable(self.host) from exc
|
||||
|
||||
def _read_channels(
|
||||
|
@ -251,62 +350,3 @@ class SSHShell(Shell):
|
|||
full_stderr = b"".join(stderr_chunks)
|
||||
|
||||
return (full_stdout.decode(errors="ignore"), full_stderr.decode(errors="ignore"))
|
||||
|
||||
def _create_connection(
|
||||
self, attempts: int = SSH_CONNECTION_ATTEMPTS, interval: int = SSH_ATTEMPTS_INTERVAL
|
||||
) -> SSHClient:
|
||||
for attempt in range(attempts):
|
||||
connection = SSHClient()
|
||||
connection.set_missing_host_key_policy(AutoAddPolicy())
|
||||
try:
|
||||
if self.private_key_path:
|
||||
logger.info(
|
||||
f"Trying to connect to host {self.host} as {self.login} using SSH key "
|
||||
f"{self.private_key_path} (attempt {attempt})"
|
||||
)
|
||||
connection.connect(
|
||||
hostname=self.host,
|
||||
port=self.port,
|
||||
username=self.login,
|
||||
pkey=_load_private_key(self.private_key_path, self.private_key_passphrase),
|
||||
timeout=self.CONNECTION_TIMEOUT,
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"Trying to connect to host {self.host} as {self.login} using password "
|
||||
f"(attempt {attempt})"
|
||||
)
|
||||
connection.connect(
|
||||
hostname=self.host,
|
||||
port=self.port,
|
||||
username=self.login,
|
||||
password=self.password,
|
||||
timeout=self.CONNECTION_TIMEOUT,
|
||||
)
|
||||
return connection
|
||||
except AuthenticationException:
|
||||
connection.close()
|
||||
logger.exception(f"Can't connect to host {self.host}")
|
||||
raise
|
||||
except (
|
||||
SSHException,
|
||||
ssh_exception.NoValidConnectionsError,
|
||||
AttributeError,
|
||||
socket.timeout,
|
||||
OSError,
|
||||
) as exc:
|
||||
connection.close()
|
||||
can_retry = attempt + 1 < attempts
|
||||
if can_retry:
|
||||
logger.warn(
|
||||
f"Can't connect to host {self.host}, will retry after {interval}s. Error: {exc}"
|
||||
)
|
||||
sleep(interval)
|
||||
continue
|
||||
logger.exception(f"Can't connect to host {self.host}")
|
||||
raise HostIsNotAvailable(self.host) from exc
|
||||
|
||||
def _reset_connection(self) -> None:
|
||||
if self.__connection:
|
||||
self.__connection.close()
|
||||
self.__connection = None
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue