diff --git a/src/frostfs_testlib/hosting/config.py b/src/frostfs_testlib/hosting/config.py index 88fe3e7..4ab66d7 100644 --- a/src/frostfs_testlib/hosting/config.py +++ b/src/frostfs_testlib/hosting/config.py @@ -67,6 +67,7 @@ class HostConfig: clis: list[CLIConfig] = field(default_factory=list) attributes: dict[str, str] = field(default_factory=dict) interfaces: dict[str, str] = field(default_factory=dict) + environment: dict[str, str] = field(default_factory=dict) def __post_init__(self) -> None: self.services = [ServiceConfig(**service) for service in self.services or []] diff --git a/src/frostfs_testlib/shell/command_inspectors.py b/src/frostfs_testlib/shell/command_inspectors.py index 0003017..8fe2f34 100644 --- a/src/frostfs_testlib/shell/command_inspectors.py +++ b/src/frostfs_testlib/shell/command_inspectors.py @@ -9,7 +9,7 @@ class SudoInspector(CommandInspector): def inspect(self, original_command: str, command: str) -> str: if not command.startswith("sudo"): - return f"sudo -i {command}" + return f"sudo {command}" return command diff --git a/src/frostfs_testlib/shell/ssh_shell.py b/src/frostfs_testlib/shell/ssh_shell.py index a7e6e1d..e718b4d 100644 --- a/src/frostfs_testlib/shell/ssh_shell.py +++ b/src/frostfs_testlib/shell/ssh_shell.py @@ -185,6 +185,7 @@ class SSHShell(Shell): private_key_passphrase: Optional[str] = None, port: str = "22", command_inspectors: Optional[list[CommandInspector]] = None, + custom_environment: Optional[dict] = None ) -> None: super().__init__() self.connection_provider = SshConnectionProvider() @@ -196,6 +197,8 @@ class SSHShell(Shell): self.command_inspectors = command_inspectors or [] + self.environment = custom_environment + @property def _connection(self): return self.connection_provider.provide(self.host, self.port) @@ -224,7 +227,7 @@ class SSHShell(Shell): @log_command def _exec_interactive(self, command: str, options: CommandOptions) -> CommandResult: - stdin, stdout, stderr = self._connection.exec_command(command, timeout=options.timeout, get_pty=True) + stdin, stdout, stderr = self._connection.exec_command(command, timeout=options.timeout, get_pty=True, environment=self.environment) for interactive_input in options.interactive_inputs: input = interactive_input.input if not input.endswith("\n"): @@ -251,7 +254,7 @@ class SSHShell(Shell): @log_command def _exec_non_interactive(self, command: str, options: CommandOptions) -> CommandResult: try: - stdin, stdout, stderr = self._connection.exec_command(command, timeout=options.timeout) + stdin, stdout, stderr = self._connection.exec_command(command, timeout=options.timeout, environment=self.environment) if options.close_stdin: stdin.close()