From 94a0d3a85aed4fa54bde11a770aff9f278f836bc Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Thu, 11 Apr 2024 11:02:58 -0400 Subject: [PATCH 1/4] Add new `mila login` command Signed-off-by: Fabrice Normandin --- milatools/cli/commands.py | 13 ++++- milatools/cli/login.py | 47 +++++++++++++++++++ tests/cli/test_commands/test_help_mila_.txt | 6 ++- .../test_invalid_command_output_mila_.txt | 2 +- ...alid_command_output_mila_search_conda_.txt | 4 +- tests/cli/test_login.py | 38 +++++++++++++++ 6 files changed, 103 insertions(+), 7 deletions(-) create mode 100644 milatools/cli/login.py create mode 100644 tests/cli/test_login.py diff --git a/milatools/cli/commands.py b/milatools/cli/commands.py index c38b4e88..1e8a78fa 100644 --- a/milatools/cli/commands.py +++ b/milatools/cli/commands.py @@ -32,12 +32,12 @@ from typing_extensions import TypedDict from milatools.cli import console +from milatools.cli.login import login from milatools.utils.local_v1 import LocalV1 from milatools.utils.remote_v1 import RemoteV1, SlurmRemote -from milatools.utils.remote_v2 import RemoteV2 +from milatools.utils.remote_v2 import SSH_CONFIG_FILE, RemoteV2 from milatools.utils.vscode_utils import ( get_code_command, - # install_local_vscode_extensions_on_remote, sync_vscode_extensions, sync_vscode_extensions_with_hostnames, ) @@ -170,6 +170,15 @@ def mila(): init_parser.set_defaults(function=init) + # ----- mila login ------ + login_parser = subparsers.add_parser( + "login", + help="Sets up reusable SSH connections to the entries of the SSH config.", + formatter_class=SortingHelpFormatter, + ) + login_parser.add_argument("--ssh_config_path", type=Path, default=SSH_CONFIG_FILE) + login_parser.set_defaults(function=login) + # ----- mila forward ------ forward_parser = subparsers.add_parser( diff --git a/milatools/cli/login.py b/milatools/cli/login.py new file mode 100644 index 00000000..9a5bfcc2 --- /dev/null +++ b/milatools/cli/login.py @@ -0,0 +1,47 @@ +from __future__ import 
annotations + +import asyncio +from pathlib import Path + +from paramiko import SSHConfig + +from milatools.cli import console +from milatools.utils.remote_v2 import SSH_CONFIG_FILE, RemoteV2 + + +async def login( + ssh_config_path: Path = SSH_CONFIG_FILE, +) -> list[RemoteV2]: + """Logs in and sets up reusable SSH connections to all the hosts in the SSH config. + + Returns the list of remotes where the connection was successfully established. + """ + ssh_config = SSHConfig.from_path(str(ssh_config_path.expanduser())) + potential_clusters = [ + host + for host in ssh_config.get_hostnames() + if not any(c in host for c in ["*", "?", "!"]) + ] + # take out entries like `mila-cpu` that have a proxy and remote command. + potential_clusters = [ + hostname + for hostname in potential_clusters + if not ( + (config := ssh_config.lookup(hostname)).get("proxycommand") + and config.get("remotecommand") + ) + ] + remotes = await asyncio.gather( + *( + RemoteV2.connect(hostname, ssh_config_path=ssh_config_path) + for hostname in potential_clusters + ), + return_exceptions=True, + ) + remotes = [remote for remote in remotes if isinstance(remote, RemoteV2)] + console.log(f"Successfully connected to {[remote.hostname for remote in remotes]}") + return remotes + + +if __name__ == "__main__": + asyncio.run(login()) diff --git a/tests/cli/test_commands/test_help_mila_.txt b/tests/cli/test_commands/test_help_mila_.txt index 88b5da3e..9b09c790 100644 --- a/tests/cli/test_commands/test_help_mila_.txt +++ b/tests/cli/test_commands/test_help_mila_.txt @@ -1,14 +1,16 @@ usage: mila [-h] [--version] [-v] - {docs,intranet,init,forward,code,sync,serve} ... + {docs,intranet,init,login,forward,code,sync,serve} ... Tools to connect to and interact with the Mila cluster. Cluster documentation: https://docs.mila.quebec/ positional arguments: - {docs,intranet,init,forward,code,sync,serve} + {docs,intranet,init,login,forward,code,sync,serve} docs Open the Mila cluster documentation. 
intranet Open the Mila intranet in a browser. init Set up your configuration and credentials. + login Sets up reusable SSH connections to the entries of the + SSH config. forward Forward a port on a compute node to your local machine. code Open a remote VSCode session on a compute node. diff --git a/tests/cli/test_commands/test_invalid_command_output_mila_.txt b/tests/cli/test_commands/test_invalid_command_output_mila_.txt index fa5c4b84..3ebe3e53 100644 --- a/tests/cli/test_commands/test_invalid_command_output_mila_.txt +++ b/tests/cli/test_commands/test_invalid_command_output_mila_.txt @@ -1,3 +1,3 @@ usage: mila [-h] [--version] [-v] - {docs,intranet,init,forward,code,sync,serve} ... + {docs,intranet,init,login,forward,code,sync,serve} ... mila: error: the following arguments are required: diff --git a/tests/cli/test_commands/test_invalid_command_output_mila_search_conda_.txt b/tests/cli/test_commands/test_invalid_command_output_mila_search_conda_.txt index 725b21ac..fa96195d 100644 --- a/tests/cli/test_commands/test_invalid_command_output_mila_search_conda_.txt +++ b/tests/cli/test_commands/test_invalid_command_output_mila_search_conda_.txt @@ -1,3 +1,3 @@ usage: mila [-h] [--version] [-v] - {docs,intranet,init,forward,code,sync,serve} ... -mila: error: argument : invalid choice: 'search' (choose from 'docs', 'intranet', 'init', 'forward', 'code', 'sync', 'serve') + {docs,intranet,init,login,forward,code,sync,serve} ... 
+mila: error: argument : invalid choice: 'search' (choose from 'docs', 'intranet', 'init', 'login', 'forward', 'code', 'sync', 'serve') diff --git a/tests/cli/test_login.py b/tests/cli/test_login.py new file mode 100644 index 00000000..b0a3f5eb --- /dev/null +++ b/tests/cli/test_login.py @@ -0,0 +1,38 @@ +import textwrap +from logging import getLogger as get_logger +from pathlib import Path + +import pytest + +from milatools.cli.login import login +from milatools.utils.remote_v2 import SSH_CACHE_DIR, RemoteV2 + +from .common import requires_ssh_to_localhost + +logger = get_logger(__name__) + + +@requires_ssh_to_localhost +@pytest.mark.asyncio +async def test_login(tmp_path: Path): # ssh_config_file: Path): + assert SSH_CACHE_DIR.exists() + ssh_config_path = tmp_path / "ssh_config" + ssh_config_path.write_text( + textwrap.dedent( + """\ + Host foo + hostname localhost + Host bar + hostname localhost + """ + ) + + "\n" + ) + + # Should create a connection to every host in the ssh config file. 
+ remotes = await login(ssh_config_path=ssh_config_path) + assert all(isinstance(remote, RemoteV2) for remote in remotes) + assert set(remote.hostname for remote in remotes) == {"foo", "bar"} + for remote in remotes: + logger.info(f"Removing control socket at {remote.control_path}") + remote.control_path.unlink() From a02bdce2b3dd6209c19ab47e259ff7887d6c94fa Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Mon, 15 Apr 2024 13:13:46 -0400 Subject: [PATCH 2/4] Add new `mila run` command (based on `login`) Signed-off-by: Fabrice Normandin --- milatools/cli/commands.py | 13 ++++++ milatools/cli/login.py | 5 ++- milatools/cli/run.py | 95 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 112 insertions(+), 1 deletion(-) create mode 100644 milatools/cli/run.py diff --git a/milatools/cli/commands.py b/milatools/cli/commands.py index 1e8a78fa..14bd9a96 100644 --- a/milatools/cli/commands.py +++ b/milatools/cli/commands.py @@ -33,6 +33,7 @@ from milatools.cli import console from milatools.cli.login import login +from milatools.cli.run import run_command from milatools.utils.local_v1 import LocalV1 from milatools.utils.remote_v1 import RemoteV1, SlurmRemote from milatools.utils.remote_v2 import SSH_CONFIG_FILE, RemoteV2 @@ -179,6 +180,18 @@ def mila(): login_parser.add_argument("--ssh_config_path", type=Path, default=SSH_CONFIG_FILE) login_parser.set_defaults(function=login) + # ----- mila run ------ + run_parser = subparsers.add_parser( + "run", + help="Runs a command over SSH on all the slurm clusters in the SSH config.", + formatter_class=SortingHelpFormatter, + ) + run_parser.add_argument("--ssh_config_path", type=Path, default=SSH_CONFIG_FILE) + run_parser.add_argument( + "command", type=str, nargs=argparse.REMAINDER, help="The command to run." 
+ ) + run_parser.set_defaults(function=run_command) + # ----- mila forward ------ forward_parser = subparsers.add_parser( diff --git a/milatools/cli/login.py b/milatools/cli/login.py index 9a5bfcc2..2f41b857 100644 --- a/milatools/cli/login.py +++ b/milatools/cli/login.py @@ -6,6 +6,7 @@ from paramiko import SSHConfig from milatools.cli import console +from milatools.cli.utils import CLUSTERS from milatools.utils.remote_v2 import SSH_CONFIG_FILE, RemoteV2 @@ -22,10 +23,12 @@ async def login( for host in ssh_config.get_hostnames() if not any(c in host for c in ["*", "?", "!"]) ] - # take out entries like `mila-cpu` that have a proxy and remote command. potential_clusters = [ hostname for hostname in potential_clusters + if hostname in CLUSTERS + # TODO: make this more generic with something like this: + # take out entries like `mila-cpu` that have a proxy and remote command. if not ( (config := ssh_config.lookup(hostname)).get("proxycommand") and config.get("remotecommand") diff --git a/milatools/cli/run.py b/milatools/cli/run.py new file mode 100644 index 00000000..72e86c97 --- /dev/null +++ b/milatools/cli/run.py @@ -0,0 +1,95 @@ +import asyncio +import shlex +import sys +from pathlib import Path + +import rich +import rich.columns +import rich.live +import rich.table +import rich.text + +from milatools.cli import console +from milatools.cli.login import login +from milatools.cli.utils import SSH_CONFIG_FILE +from milatools.utils.remote_v2 import RemoteV2 + + +async def run_command( + command: str | list[str], ssh_config_path: Path = SSH_CONFIG_FILE +): + command = shlex.join(command) if isinstance(command, list) else command + if command.startswith("'") and command.endswith("'"): + # NOTE: Need to remove leading and trailing quotes so the ssh subprocess doesn't + # give an error. 
For example, with `mila run 'echo $SCRATCH'`, we would + # otherwise get the error: bash: echo: command not found + command = command[1:-1] + + remotes = await login(ssh_config_path=ssh_config_path) + + async def _is_slurm_cluster(remote: RemoteV2) -> bool: + sbatch_path = await remote.get_output_async( + "which sbatch", warn=True, hide=True, display=False + ) + return bool(sbatch_path) + + is_slurm_cluster = await asyncio.gather( + *(_is_slurm_cluster(remote) for remote in remotes), + ) + cluster_login_nodes = [ + remote + for remote, is_slurm_cluster in zip(remotes, is_slurm_cluster) + if is_slurm_cluster + ] + + results = await asyncio.gather( + *( + login_node.run_async(command=command, warn=True, display=True, hide=False) + for login_node in cluster_login_nodes + ) + ) + for remote, result in zip(cluster_login_nodes, results): + for line in result.stdout.splitlines(): + print(f"({remote.hostname}) {line}") + for line in result.stderr.splitlines(): + print(f"({remote.hostname}) {line}", file=sys.stderr) + + return results + + table = rich.table.Table(title=command) + table.add_column("Cluster") + + # need an stdout column. 
+ need_stdout_column = any(result.stdout for result in results) + need_stderr_column = any(result.stderr for result in results) + + if not need_stderr_column and not need_stdout_column: + return results + + if need_stdout_column: + table.add_column("stdout") + if need_stderr_column: + table.add_column("stderr") + + for remote, result in zip(cluster_login_nodes, results): + row = [remote.hostname] + if need_stdout_column: + row.append(result.stdout) + if need_stderr_column: + row.append(result.stderr) + table.add_row(*row, end_section=True) + + console.print(table) + # table = rich.table.Table(title=command) + # with rich.live.Live(table, refresh_per_second=1): + + # async with asyncio.TaskGroup() as group: + # for remote in remotes: + # table.add_column(remote.hostname, no_wrap=True) + # task = group.create_task(remote.run_async(command)) + # task.add_done_callback(lambda _: table.add_row()) + return results + + +if __name__ == "__main__": + asyncio.run(run_command("hostname")) From 237a8efc62fd4a3b3a4b83c862afef47f7b48a16 Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Tue, 16 Apr 2024 10:55:27 -0400 Subject: [PATCH 3/4] add `show_table` parameter and argument Signed-off-by: Fabrice Normandin --- milatools/cli/commands.py | 1 + milatools/cli/run.py | 30 +++++++++++++++++++++++++----- 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/milatools/cli/commands.py b/milatools/cli/commands.py index 14bd9a96..bd804711 100644 --- a/milatools/cli/commands.py +++ b/milatools/cli/commands.py @@ -187,6 +187,7 @@ def mila(): formatter_class=SortingHelpFormatter, ) run_parser.add_argument("--ssh_config_path", type=Path, default=SSH_CONFIG_FILE) + run_parser.add_argument("--show-table", action="store_true", default=False) run_parser.add_argument( "command", type=str, nargs=argparse.REMAINDER, help="The command to run."
) diff --git a/milatools/cli/run.py b/milatools/cli/run.py index 72e86c97..140ce926 100644 --- a/milatools/cli/run.py +++ b/milatools/cli/run.py @@ -1,5 +1,6 @@ import asyncio import shlex +import subprocess import sys from pathlib import Path @@ -16,7 +17,9 @@ async def run_command( - command: str | list[str], ssh_config_path: Path = SSH_CONFIG_FILE + command: str | list[str], + ssh_config_path: Path = SSH_CONFIG_FILE, + show_table: bool = False, ): command = shlex.join(command) if isinstance(command, list) else command if command.startswith("'") and command.endswith("'"): @@ -44,18 +47,36 @@ async def _is_slurm_cluster(remote: RemoteV2) -> bool: results = await asyncio.gather( *( - login_node.run_async(command=command, warn=True, display=True, hide=False) + login_node.run_async(command=command, warn=True, display=True, hide=True) for login_node in cluster_login_nodes ) ) + if show_table: + _print_with_table(command, cluster_login_nodes, results) + else: + _print_with_prefix(command, cluster_login_nodes, results) + return results + + +def _print_with_prefix( + command: str, + cluster_login_nodes: list[RemoteV2], + results: list[subprocess.CompletedProcess[str]], +): for remote, result in zip(cluster_login_nodes, results): for line in result.stdout.splitlines(): - print(f"({remote.hostname}) {line}") + console.print(f"[bold]({remote.hostname})[/bold] {line}", markup=True) for line in result.stderr.splitlines(): print(f"({remote.hostname}) {line}", file=sys.stderr) - return results + # return results + +def _print_with_table( + command: str, + cluster_login_nodes: list[RemoteV2], + results: list[subprocess.CompletedProcess[str]], +): table = rich.table.Table(title=command) table.add_column("Cluster") @@ -88,7 +109,6 @@ async def _is_slurm_cluster(remote: RemoteV2) -> bool: # table.add_column(remote.hostname, no_wrap=True) # task = group.create_task(remote.run_async(command)) # task.add_done_callback(lambda _: table.add_row()) - return results if __name__ == "__main__":
From 54d9ed2bf732baff864d7c4e9aee4221629c8a50 Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Thu, 18 Apr 2024 10:47:56 -0400 Subject: [PATCH 4/4] Add missing __future__ import for type hints Signed-off-by: Fabrice Normandin --- milatools/cli/run.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/milatools/cli/run.py b/milatools/cli/run.py index 140ce926..d66f8cb9 100644 --- a/milatools/cli/run.py +++ b/milatools/cli/run.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio import shlex import subprocess