diff --git a/milatools/cli/commands.py b/milatools/cli/commands.py index fb59e22b..72255142 100644 --- a/milatools/cli/commands.py +++ b/milatools/cli/commands.py @@ -1,3 +1,11 @@ +"""Tools to connect to and interact with the Mila cluster. + +Cluster documentation: https://docs.mila.quebec/ +""" +from __future__ import annotations + +import argparse +import operator import os import re import shutil @@ -6,14 +14,17 @@ import sys import time import traceback +import typing import webbrowser +from argparse import ArgumentParser, _HelpAction from contextlib import ExitStack from pathlib import Path +from typing import Sequence from urllib.parse import urlencode import questionary as qn -from coleo import Option, auto_cli, default, tooled from invoke import UnexpectedExit +from typing_extensions import TypedDict from ..version import version as mversion from .init_command import setup_ssh_config @@ -32,17 +43,17 @@ yn, ) +if typing.TYPE_CHECKING: + from typing_extensions import Unpack + def main(): - """Entry point for milatools.""" on_mila = get_fully_qualified_name().endswith(".server.mila.quebec") if on_mila: - exit( - "ERROR: 'mila ...' should be run on your local machine and not on the Mila cluster" - ) + exit("ERROR: 'mila ...' should be run on your local machine and not on the Mila cluster") try: - auto_cli(milatools) + mila() except MilatoolsUserError as exc: # These are user errors and should not be reported print("ERROR:", exc, file=sys.stderr) @@ -54,16 +65,14 @@ def main(): options = { "labels": ",".join([sys.argv[1], mversion]), "template": "bug_report.md", - "title": f"[v{mversion}] Issue running the command `mila " - f"{sys.argv[1]}`", + "title": f"[v{mversion}] Issue running the command `mila " f"{sys.argv[1]}`", } github_issue_url = ( f"https://github.com/mila-iqia/milatools/issues/new?{urlencode(options)}" ) print( T.bold_yellow( - f"An error occured during the execution of the command " - f"`{sys.argv[1]}`. " + f"An error occured during the execution of the command " f"`{sys.argv[1]}`. " ) + T.yellow( "Please try updating milatools by running\n" @@ -73,486 +82,615 @@ def main(): ) + T.italic_yellow(github_issue_url) + T.yellow( - "\nPlease provide the error traceback with the report " - "(the red text above)." + "\nPlease provide the error traceback with the report " "(the red text above)." ), file=sys.stderr, ) exit(1) -class milatools: - """Tools to connect to and interact with the Mila cluster. +def mila(): + parser_kwargs = dict() + parser = ArgumentParser(prog="mila", description=__doc__, add_help=True, **parser_kwargs) + parser.add_argument("-v", "--version", action="version", version=f"milatools v{mversion}") - Cluster documentation: https://docs.mila.quebec/ - """ + subparsers = parser.add_subparsers( + dest="command", + required=True, + # parser_class=ArgumentParser, + ) - def __main__(): - # This path is triggered when no command is passed - - # Milatools version - # [alias: -v] - version: Option & bool = default(False) - - if version: - print(f"milatools v{mversion}") - - def docs(): - """Open the Mila cluster documentation.""" - # Search terms - # [remainder] - search: Option = default([]) - url = "https://docs.mila.quebec" - if search: - terms = "+".join(search) - url = f"{url}/search.html?q={terms}" - print(f"Opening the docs: {url}") - webbrowser.open(url) - - def intranet(): - """Open the Mila intranet in a browser.""" - # Search terms - # [remainder] - search: Option = default([]) - if search: - terms = "+".join(search) - url = f"https://sites.google.com/search/mila.quebec/mila-intranet?query={terms}&scope=site&showTabs=false" - else: - url = "https://intranet.mila.quebec" - print(f"Opening the intranet: {url}") - webbrowser.open(url) + docs_parser = subparsers.add_parser( + "docs", help="Open the Mila cluster documentation", formatter_class=SortingHelpFormatter + ) + docs_parser.add_argument("search", nargs=argparse.REMAINDER, help="Search terms") + docs_parser.set_defaults(function=docs) + + intranet_parser = subparsers.add_parser( + "intranet", help="Open the Mila intranet in a browser." + ) + intranet_parser.add_argument("search", nargs=argparse.REMAINDER, help="Search terms") + intranet_parser.set_defaults(function=intranet) - def init(): - """Set up your configuration and credentials.""" + init_parser = subparsers.add_parser( + "init", + help="Set up your configuration and credentials.", + formatter_class=SortingHelpFormatter, + ) + init_parser.set_defaults(function=init) - ############################# - # Step 1: SSH Configuration # - ############################# + forward_parser = subparsers.add_parser( + "forward", + help="Forward a port on a compute node to your local machine.", + formatter_class=SortingHelpFormatter, + ) + forward_parser.add_argument("remote", help="node:port to forward") + forward_parser.add_argument( + "--page", nargs="?", help="String to append after the URL", default=None + ) + forward_parser.add_argument("--port", type=int, help="Local port to forward to", default=None) + forward_parser.set_defaults(function=forward) - print("Checking ssh config") + code_parser = subparsers.add_parser( + "code", + help="Open a remote VSCode session on a compute node.", + formatter_class=SortingHelpFormatter, + ) + code_parser.add_argument("path", help="Path to open on the remote machine", type=str) + _add_find_allocation_args(code_parser) + code_parser.add_argument( + "--command", + default=os.environ.get("MILATOOLS_CODE_COMMAND", "code"), + help=( + "Command to use to start vscode\n" + "(defaults to 'code' or the value of $MILATOOLS_CODE_COMMAND)" + ), + ) + code_parser.add_argument( + "--persist", action="store_true", help="Whether the server should persist or not" + ) + code_parser.set_defaults(function=code) - setup_ssh_config() - # TODO: Move the rest of this command to functions in the init_command module, - # so they can more easily be tested. + serve_parser = subparsers.add_parser( + "serve", + help="Start services on compute nodes and forward them to your local machine.", + formatter_class=SortingHelpFormatter, + ) + serve_subparsers = serve_parser.add_subparsers(dest="serve_command", required=True) - print("# OK") + serve_connect_parser = serve_subparsers.add_parser( + "connect", + help="Reconnect to a persistent server.", + formatter_class=SortingHelpFormatter, + ) + serve_connect_parser.add_argument( + "identifier", type=str, help="Server identifier output by the original mila serve command" + ) + serve_connect_parser.set_defaults(function=connect) - ############################# - # Step 2: Passwordless auth # - ############################# + serve_kill_parser = serve_subparsers.add_parser( + "kill", + help="Kill a persistent server.", + formatter_class=SortingHelpFormatter, + ) + serve_kill_parser.add_argument( + "identifier", + type=str, + help="Server identifier output by the original mila serve command", + ) + serve_kill_parser.add_argument("--all", action="store_true", help="Kill all servers") + serve_kill_parser.set_defaults(function=kill) - print("Checking passwordless authentication") + serve_list_parser = serve_subparsers.add_parser( + "list", help="List active servers.", formatter_class=SortingHelpFormatter + ) + serve_list_parser.add_argument( + "--purge", action="store_true", help="Purge dead or invalid servers" + ) + serve_list_parser.set_defaults(function=serve_list) - here = Local() + serve_lab_parser = serve_subparsers.add_parser( + "lab", help="Start a Jupyterlab server.", formatter_class=SortingHelpFormatter + ) + serve_lab_parser.add_argument( + "path", default=None, nargs="?", help="Path to open on the remote machine" + ) + _add_standard_server_args(serve_lab_parser) + serve_lab_parser.set_defaults(function=lab) - # Check that there is an id file + serve_notebook_parser = serve_subparsers.add_parser( + "notebook", help="Start a Jupyter Notebook server.", formatter_class=SortingHelpFormatter + ) + serve_notebook_parser.add_argument( + "path", default=None, nargs="?", help="Path to open on the remote machine" + ) + _add_standard_server_args(serve_notebook_parser) + serve_notebook_parser.set_defaults(function=notebook) - sshdir = os.path.expanduser("~/.ssh") - if not any( - entry.startswith("id") and entry.endswith(".pub") - for entry in os.listdir(sshdir) - ): - if yn("You have no public keys. Generate one?"): - here.run("ssh-keygen") - else: - exit("No public keys.") + serve_tensorboard_parser = serve_subparsers.add_parser( + "tensorboard", help="Start a Tensorboard server.", formatter_class=SortingHelpFormatter + ) + serve_tensorboard_parser.add_argument("logdir", type=str, help="Path to the experiment logs") + _add_standard_server_args(serve_tensorboard_parser) + serve_notebook_parser.set_defaults(function=tensorboard) - # Check that it is possible to connect using the key + serve_mlflow_parser = serve_subparsers.add_parser( + "mlflow", help="Start an MLFlow server.", formatter_class=SortingHelpFormatter + ) + serve_mlflow_parser.add_argument("logdir", type=str, help="Path to the experiment logs") + _add_standard_server_args(serve_mlflow_parser) + serve_mlflow_parser.set_defaults(function=mlflow) - if not here.check_passwordless("mila"): - if yn( - "Your public key does not appear be registered on the cluster. Register it?" - ): - here.run("ssh-copy-id", "mila") - if not here.check_passwordless("mila"): - exit("ssh-copy-id appears to have failed") - else: - exit("No passwordless login.") + serve_aim_parser = serve_subparsers.add_parser( + "aim", help="Start an Aim server.", formatter_class=SortingHelpFormatter + ) + serve_aim_parser.add_argument("logdir", type=str, help="Path to the experiment logs") + _add_standard_server_args(serve_aim_parser) + serve_aim_parser.set_defaults(function=aim) + + args = parser.parse_args() + args_dict = vars(args) + function = args_dict.pop("function") + _ = args_dict.pop("command") + _ = args_dict.pop("serve_command", None) + assert callable(function) + return function(**args_dict) + + +def docs(search: Sequence[str]) -> None: + url = "https://docs.mila.quebec" + terms = "+".join(search) + url = f"{url}/search.html?q={terms}" + print(f"Opening the docs: {url}") + webbrowser.open(url) - ##################################### - # Step 3: Set up keys on login node # - ##################################### - print("Checking connection to compute nodes") +def intranet(search: Sequence[str]) -> None: + """Open the Mila intranet in a browser.""" + if search: + terms = "+".join(search) + url = f"https://sites.google.com/search/mila.quebec/mila-intranet?query={terms}&scope=site&showTabs=false" + else: + url = "https://intranet.mila.quebec" + print(f"Opening the intranet: {url}") + webbrowser.open(url) - remote = Remote("mila") - try: - pubkeys = remote.get_lines("ls -t ~/.ssh/id*.pub") - print("# OK") - except UnexpectedExit: - print("# MISSING") - if yn("You have no public keys on the login node. Generate them?"): - # print("(Note: You can just press Enter 3x to accept the defaults)") - # _, keyfile = remote.extract("ssh-keygen", pattern="Your public key has been saved in ([^ ]+)", wait=True) - private_file = "~/.ssh/id_rsa" - remote.run(f'ssh-keygen -q -t rsa -N "" -f {private_file}') - pubkeys = [f"{private_file}.pub"] - else: - exit("Cannot proceed because there is no public key") - common = remote.with_bash().get_output( - "comm -12 <(sort ~/.ssh/authorized_keys) <(sort ~/.ssh/*.pub)" - ) - if common: - print("# OK") +def init(): + """Set up your configuration and credentials.""" + + ############################# + # Step 1: SSH Configuration # + ############################# + + print("Checking ssh config") + + setup_ssh_config() + # TODO: Move the rest of this command to functions in the init_command module, + # so they can more easily be tested. + + print("# OK") + + ############################# + # Step 2: Passwordless auth # + ############################# + + print("Checking passwordless authentication") + + here = Local() + + # Check that there is an id file + + sshdir = os.path.expanduser("~/.ssh") + if not any(entry.startswith("id") and entry.endswith(".pub") for entry in os.listdir(sshdir)): + if yn("You have no public keys. Generate one?"): + here.run("ssh-keygen") else: - print("# MISSING") - if yn( - "To connect to a compute node from a login node you need one id_*.pub to be in " - "authorized_keys. Do it?" - ): - pubkey = pubkeys[0] - remote.run(f"cat {pubkey} >> ~/.ssh/authorized_keys") - else: - exit("You will not be able to SSH to a compute node") + exit("No public keys.") - ################### - # Welcome message # - ################### + # Check that it is possible to connect using the key - print(T.bold_cyan("=" * 60)) - print( - T.bold_cyan("Congrats! You are now ready to start working on the cluster!") - ) - print(T.bold_cyan("=" * 60)) - print(T.bold("To connect to a login node:")) - print(" ssh mila") - print(T.bold("To allocate and connect to a compute node:")) - print(" ssh mila-cpu") - print(T.bold("To open a directory on the cluster with VSCode:")) - print(" mila code path/to/code/on/cluster") - print(T.bold("Same as above, but allocate 1 GPU, 4 CPUs, 32G of RAM:")) - print( - " mila code path/to/code/on/cluster --alloc --gres=gpu:1 --mem=32G -c 4" - ) - print() - print( - "For more information, read the milatools documentation at", - T.bold_cyan("https://github.com/mila-iqia/milatools"), - "or run `mila --help`.", - "Also make sure you read the Mila cluster documentation at", - T.bold_cyan("https://docs.mila.quebec/"), - "and join the", - T.bold_green("#mila-cluster"), - "channel on Slack.", - ) + if not here.check_passwordless("mila"): + if yn("Your public key does not appear be registered on the cluster. Register it?"): + here.run("ssh-copy-id", "mila") + if not here.check_passwordless("mila"): + exit("ssh-copy-id appears to have failed") + else: + exit("No passwordless login.") - def forward(): - """Forward a port on a compute node to your local machine.""" + ##################################### + # Step 3: Set up keys on login node # + ##################################### - # node:port to forward - # [positional] - remote: Option + print("Checking connection to compute nodes") - node, remote_port = remote.split(":") - try: - remote_port = int(remote_port) - except ValueError: - pass - - # String to append after the URL - page: Option = default(None) - - local_proc = _forward( - local=Local(), - node=f"{node}.server.mila.quebec", - to_forward=remote_port, - page=page, - ) + remote = Remote("mila") + try: + pubkeys = remote.get_lines("ls -t ~/.ssh/id*.pub") + print("# OK") + except UnexpectedExit: + print("# MISSING") + if yn("You have no public keys on the login node. Generate them?"): + # print("(Note: You can just press Enter 3x to accept the defaults)") + # _, keyfile = remote.extract("ssh-keygen", pattern="Your public key has been saved in ([^ ]+)", wait=True) + private_file = "~/.ssh/id_rsa" + remote.run(f'ssh-keygen -q -t rsa -N "" -f {private_file}') + pubkeys = [f"{private_file}.pub"] + else: + exit("Cannot proceed because there is no public key") - try: - local_proc.wait() - except KeyboardInterrupt: - exit("Terminated by user.") - finally: - local_proc.kill() + common = remote.with_bash().get_output( + "comm -12 <(sort ~/.ssh/authorized_keys) <(sort ~/.ssh/*.pub)" + ) + if common: + print("# OK") + else: + print("# MISSING") + if yn( + "To connect to a compute node from a login node you need one id_*.pub to be in " + "authorized_keys. Do it?" + ): + pubkey = pubkeys[0] + remote.run(f"cat {pubkey} >> ~/.ssh/authorized_keys") + else: + exit("You will not be able to SSH to a compute node") + + ################### + # Welcome message # + ################### + + print(T.bold_cyan("=" * 60)) + print(T.bold_cyan("Congrats! You are now ready to start working on the cluster!")) + print(T.bold_cyan("=" * 60)) + print(T.bold("To connect to a login node:")) + print(" ssh mila") + print(T.bold("To allocate and connect to a compute node:")) + print(" ssh mila-cpu") + print(T.bold("To open a directory on the cluster with VSCode:")) + print(" mila code path/to/code/on/cluster") + print(T.bold("Same as above, but allocate 1 GPU, 4 CPUs, 32G of RAM:")) + print(" mila code path/to/code/on/cluster --alloc --gres=gpu:1 --mem=32G -c 4") + print() + print( + "For more information, read the milatools documentation at", + T.bold_cyan("https://github.com/mila-iqia/milatools"), + "or run `mila --help`.", + "Also make sure you read the Mila cluster documentation at", + T.bold_cyan("https://docs.mila.quebec/"), + "and join the", + T.bold_green("#mila-cluster"), + "channel on Slack.", + ) - def code(): - """Open a remote VSCode session on a compute node.""" - # Path to open on the remote machine - # [positional] - path: Option - # Command to use to start vscode - # (defaults to "code" or the value of $MILATOOLS_CODE_COMMAND) - command: Option = None +def forward(remote: str, page: str | None, port: int | None, **kwargs: Unpack[StandardServerArgs]): + """Forward a port on a compute node to your local machine.""" + node, remote_port = remote.split(":") + try: + remote_port = int(remote_port) + except ValueError: + pass - # Whether the server should persist or not - persist: Option & bool = default(False) + local_proc, _ = _forward( + local=Local(), + node=f"{node}.server.mila.quebec", + to_forward=remote_port, + page=page, + port=port, + ) + + try: + local_proc.wait() + except KeyboardInterrupt: + exit("Terminated by user.") + finally: + local_proc.kill() - if command is None: - command = os.environ.get("MILATOOLS_CODE_COMMAND", "code") - command_path = shutil.which(command) - if not command_path: - raise CommandNotFoundError(command) +def code( + path: str, command: str, persist: bool, job: str | None, node: str | None, alloc: Sequence[str] +): + """Open a remote VSCode session on a compute node. + + Arguments: + path: Path to open on the remote machine + command: Command to use to start vscode + (defaults to "code" or the value of $MILATOOLS_CODE_COMMAND) + persist: Whether the server should persist or not + job: Job ID to connect to + node: Node to connect to + alloc: Extra options to pass to slurm + """ + if command is None: + command = os.environ.get("MILATOOLS_CODE_COMMAND", "code") - remote = Remote("mila") - here = Local() + command_path = shutil.which(command) + if not command_path: + raise CommandNotFoundError(command) - cnode = _find_allocation(remote, job_name="mila-code") - if persist: - cnode = cnode.persist() - data, proc = cnode.ensure_allocation() + remote = Remote("mila") + here = Local() - node_name = data["node_name"] + cnode = _find_allocation(remote, job_name="mila-code", job=job, node=node, alloc=alloc) + if persist: + cnode = cnode.persist() + data, proc = cnode.ensure_allocation() - if not path.startswith("/"): - # Get $HOME because we have to give the full path to code - home = remote.home() - path = "/".join([home, path]) + node_name = data["node_name"] - try: - while True: - here.run( - command_path, - "-nw", - "--remote", - f"ssh-remote+{qualified(node_name)}", - path, - ) - print( - "The editor was closed. Reopen it with " - " or terminate the process with " - ) - input() - - except KeyboardInterrupt: - if not persist: - if proc is not None: - proc.kill() - print(f"Ended session on '{node_name}'") + if not path.startswith("/"): + # Get $HOME because we have to give the full path to code + home = remote.home() + path = "/".join([home, path]) - if persist: - print(f"This allocation is persistent and is still active.") - print(f"To reconnect to this node:") - print(T.bold(f" mila code {path} --node {node_name}")) - print(f"To kill this allocation:") - print(T.bold(f" ssh mila scancel {data['jobid']}")) - - class serve: - """Start services on compute nodes and forward them to your local machine.""" - - def connect(): - """Reconnect to a persistent server.""" - - remote = Remote("mila") - _, info = _get_server_info_command(remote) - - local_proc = _forward( - local=Local(), - node=f"{info['node_name']}.server.mila.quebec", - to_forward=info["to_forward"], - options={"token": info.get("token", None)}, - preferred_port=info["local_port"], - through_login=info["host"] == "0.0.0.0", + try: + while True: + here.run( + command_path, + "-nw", + "--remote", + f"ssh-remote+{qualified(node_name)}", + path, + ) + print( + "The editor was closed. Reopen it with " + " or terminate the process with " ) + input() - try: - local_proc.wait() - except KeyboardInterrupt: - exit("Terminated by user.") - finally: - local_proc.kill() + except KeyboardInterrupt: + if not persist: + if proc is not None: + proc.kill() + print(f"Ended session on '{node_name}'") - def kill(): - """Kill a persistent server.""" + if persist: + print("This allocation is persistent and is still active.") + print("To reconnect to this node:") + print(T.bold(f" mila code {path} --node {node_name}")) + print("To kill this allocation:") + print(T.bold(f" ssh mila scancel {data['jobid']}")) - # Server identifier output by the original mila serve command - # [positional: ?] - identifier: Option = default(None) - # Kill all servers - all: Option & bool = default(False) +def connect(identifier: str, port: int | None = None): + """Reconnect to a persistent server.""" - remote = Remote("mila") + remote = Remote("mila") + info = _get_server_info(remote, identifier) + local_proc, _ = _forward( + local=Local(), + node=f"{info['node_name']}.server.mila.quebec", + to_forward=info["to_forward"], + options={"token": info.get("token", None)}, + port=port or info["local_port"], + through_login=info["host"] == "0.0.0.0", + ) - if all: - for identifier in remote.get_lines("ls .milatools/control", hide=True): - info = _get_server_info(remote, identifier, hide=True) - if "jobid" in info: - remote.run(f"scancel {info['jobid']}") - remote.run(f"rm .milatools/control/{identifier}") + try: + local_proc.wait() + except KeyboardInterrupt: + exit("Terminated by user.") + finally: + local_proc.kill() - elif identifier is None: - exit("Please give the name of the server to kill") - else: - info = _get_server_info(remote, identifier) +def kill(identifier: str, all: bool = False): + """Kill a persistent server.""" + remote = Remote("mila") + if all: + for identifier in remote.get_lines("ls .milatools/control", hide=True): + info = _get_server_info(remote, identifier, hide=True) + if "jobid" in info: remote.run(f"scancel {info['jobid']}") - remote.run(f"rm .milatools/control/{identifier}") - - def list(): - """List active servers.""" - - # Purge dead or invalid servers - purge: Option & bool = default(False) - - remote = Remote("mila") - - to_purge = [] - - remote.run("mkdir -p ~/.milatools/control", hide=True) - - for identifier in remote.get_lines("ls .milatools/control", hide=True): - info = _get_server_info(remote, identifier, hide=True) - jobid = info.get("jobid", None) - status = remote.get_output( - f"squeue -j {jobid} -ho %T", hide=True, warn=True - ) - program = info.pop("program", "???") - if status == "RUNNING": - necessary_keys = {"node_name", "to_forward"} - if any(k not in info for k in necessary_keys): - qn.print( - f"{identifier} ({program}, MISSING INFO)", style="bold red" - ) - to_purge.append((identifier, jobid)) - else: - qn.print(f"{identifier} ({program})", style="bold yellow") - else: - qn.print(f"{identifier} ({program}, DEAD)", style="bold red") - to_purge.append((identifier, None)) - for k, v in info.items(): - print(f" {k:20} : {v}") - - if purge: - for identifier, jobid in to_purge: - if jobid is not None: - remote.run(f"scancel {jobid}") - remote.run(f"rm .milatools/control/{identifier}") - - def lab(): - """Start a Jupyterlab server.""" - - # Path to open on the remote machine - # [positional: ?] - path: Option = default(None) - - if path and path.endswith(".ipynb"): - exit("Only directories can be given to the mila serve lab command") - - _standard_server( - path, - program="jupyter-lab", - installers={ - "conda": "conda install -y jupyterlab", - "pip": "pip install jupyterlab", - }, - command="jupyter lab --sock {sock} {path}", - # command="jupyter lab --ip {host} --port 0", - token_pattern=r"\?token=([a-f0-9]+)", - ) + remote.run(f"rm .milatools/control/{identifier}") - def notebook(): - """Start a Jupyter Notebook server.""" + elif identifier is None: + exit("Please give the name of the server to kill") - # Path to open on the remote machine - # [positional: ?] - path: Option = default(None) + else: + info = _get_server_info(remote, identifier) - if path and path.endswith(".ipynb"): - exit("Only directories can be given to the mila serve notebook command") + remote.run(f"scancel {info['jobid']}") + remote.run(f"rm .milatools/control/{identifier}") - _standard_server( - path, - program="jupyter-notebook", - installers={ - "conda": "conda install -y jupyter", - "pip": "pip install jupyter", - }, - command="jupyter notebook --sock {sock} {path}", - # command="jupyter notebook --ip {host} --port 0", - token_pattern=r"\?token=([a-f0-9]+)", - ) - def tensorboard(): - """Start a Tensorboard server.""" - - # Path to the experiment logs - # [positional] - logdir: Option - - _standard_server( - logdir, - program="tensorboard", - installers={ - "conda": "conda install -y tensorboard", - "pip": "pip install tensorboard", - }, - command="tensorboard --logdir {path} --host {host} --port 0", - port_pattern="TensorBoard [^ ]+ at http://[^:]+:([0-9]+)/", - ) +def serve_list(purge: bool): + """List active servers.""" + remote = Remote("mila") - def mlflow(): - """Start an MLFlow server.""" - - # Path to the experiment logs - # [positional] - logdir: Option - - _standard_server( - logdir, - program="mlflow", - installers={ - "pip": "pip install mlflow", - }, - command="mlflow ui --backend-store-uri {path} --host {host} --port 0", - port_pattern="Listening at: http://[^:]+:([0-9]+)", - ) + to_purge = [] - def aim(): - """Start an AIM server.""" - - # Path to the experiment logs - # [positional] - logdir: Option - - _standard_server( - logdir, - program="aim", - installers={ - "pip": "pip install aim", - }, - command="aim up --repo {path} --host {host} --port 0", - port_pattern=f"Open http://[^:]+:([0-9]+)", - ) + remote.run("mkdir -p ~/.milatools/control", hide=True) + for identifier in remote.get_lines("ls .milatools/control", hide=True): + info = _get_server_info(remote, identifier, hide=True) + jobid = info.get("jobid", None) + status = remote.get_output(f"squeue -j {jobid} -ho %T", hide=True, warn=True) + program = info.pop("program", "???") + if status == "RUNNING": + necessary_keys = {"node_name", "to_forward"} + if any(k not in info for k in necessary_keys): + qn.print(f"{identifier} ({program}, MISSING INFO)", style="bold red") + to_purge.append((identifier, jobid)) + else: + qn.print(f"{identifier} ({program})", style="bold yellow") + else: + qn.print(f"{identifier} ({program}, DEAD)", style="bold red") + to_purge.append((identifier, None)) + for k, v in info.items(): + print(f" {k:20} : {v}") + + if purge: + for identifier, jobid in to_purge: + if jobid is not None: + remote.run(f"scancel {jobid}") + remote.run(f"rm .milatools/control/{identifier}") + + +class StandardServerArgs(TypedDict): + profile: str | None + """Name of the profile to use""" + persist: bool + """Whether the server should persist or not""" + name: str | None + """Name of the persistent server""" + node: str | None + """Node to connect to""" -def _get_server_info(remote, identifier, hide=False): + job: str | None + """Job ID to connect to""" + + alloc: Sequence[str] + """Extra options to pass to slurm""" + + +def lab(path: str | None, **kwargs: Unpack[StandardServerArgs]): + """Start a Jupyterlab server. + + Arguments: + path: Path to open on the remote machine + """ + + if path and path.endswith(".ipynb"): + exit("Only directories can be given to the mila serve lab command") + + _standard_server( + path, + program="jupyter-lab", + installers={ + "conda": "conda install -y jupyterlab", + "pip": "pip install jupyterlab", + }, + command="jupyter lab --sock {sock} {path}", + # command="jupyter lab --ip {host} --port 0", + token_pattern=r"\?token=([a-f0-9]+)", + **kwargs, + ) + + +def notebook(path: str | None, **kwargs: Unpack[StandardServerArgs]): + """Start a Jupyter Notebook server. + + Arguments: + path: Path to open on the remote machine + """ + if path and path.endswith(".ipynb"): + exit("Only directories can be given to the mila serve notebook command") + + _standard_server( + path, + program="jupyter-notebook", + installers={ + "conda": "conda install -y jupyter", + "pip": "pip install jupyter", + }, + command="jupyter notebook --sock {sock} {path}", + # command="jupyter notebook --ip {host} --port 0", + token_pattern=r"\?token=([a-f0-9]+)", + **kwargs, + ) + + +def tensorboard(logdir: str, **kwargs: Unpack[StandardServerArgs]): + """Start a Tensorboard server. + + Arguments: + logdir: Path to the experiment logs + """ + + _standard_server( + logdir, + program="tensorboard", + installers={ + "conda": "conda install -y tensorboard", + "pip": "pip install tensorboard", + }, + command="tensorboard --logdir {path} --host {host} --port 0", + port_pattern="TensorBoard [^ ]+ at http://[^:]+:([0-9]+)/", + **kwargs, + ) + + +def mlflow(logdir: str, **kwargs: Unpack[StandardServerArgs]): + """Start an MLFlow server. + + Arguments: + logdir: Path to the experiment logs + """ + + _standard_server( + logdir, + program="mlflow", + installers={ + "pip": "pip install mlflow", + }, + command="mlflow ui --backend-store-uri {path} --host {host} --port 0", + port_pattern="Listening at: http://[^:]+:([0-9]+)", + **kwargs, + ) + + +def aim(logdir: str, **kwargs: Unpack[StandardServerArgs]): + """Start an AIM server. + + Arguments: + logdir: Path to the experiment logs + """ + _standard_server( + logdir, + program="aim", + installers={ + "pip": "pip install aim", + }, + command="aim up --repo {path} --host {host} --port 0", + port_pattern=r"Open http://[^:]+:([0-9]+)", + **kwargs, + ) + + +def _get_server_info(remote: Remote, identifier: str, hide: bool = False) -> dict[str, str]: text = remote.get_output(f"cat .milatools/control/{identifier}", hide=hide) info = dict(line.split(" = ") for line in text.split("\n") if line) return info -@tooled -def _get_server_info_command(remote): - # Server identifier output by the original mila serve command - # [positional] - identifier: Option +class SortingHelpFormatter(argparse.HelpFormatter): + """Taken and adapted from https://stackoverflow.com/a/12269143/6388696""" + + def add_arguments(self, actions): + actions = sorted(actions, key=operator.attrgetter("option_strings")) + # put help actions first. + actions = sorted(actions, key=lambda action: not isinstance(action, _HelpAction)) + super().add_arguments(actions) + - return identifier, _get_server_info(remote, identifier) +def _add_standard_server_args(parser: ArgumentParser): + _add_find_allocation_args(parser) + # parser.add_argument("--alloc", nargs=argparse.REMAINDER, help="Extra options to pass to slurm") + # parser.add_argument("--job", type=str, default=None, help="Job ID to connect to") + parser.add_argument("--name", default=None, type=str, help="Name of the persistent server") + # parser.add_argument("--node", type=str, default=None, help="Node to connect to") + parser.add_argument("--persist", action="store_true", help="Whether the server should persist") + parser.add_argument("--port", type=int, default=None, help="Port to open on the local machine") + parser.add_argument("--profile", default=None, type=str, help="Name of the profile to use") -@tooled def _standard_server( - path, - program, + path: str | None, + *, + program: str, installers, command, + profile: str | None, + persist: bool, + name: str | None, + node: str | None, + job: str | None, + alloc: Sequence[str], port_pattern=None, token_pattern=None, ): - # Name of the profile to use - profile: Option = default(None) - - # Whether the server should persist or not - persist: Option & bool = default(False) - - # Name of the persistent server - name: Option = default(None) - # Make the server visible from the login node (other users will be able to connect) # share: Option & bool = default(False) # Temporarily disabled @@ -569,6 +707,11 @@ def _standard_server( if path == "~" or path.startswith("~/"): path = remote.home() + path[1:] + results: dict | None = None + node_name: str | None = None + to_forward: int | str | None = None + cf: str | None = None + proc = None with ExitStack() as stack: if persist: cf = stack.enter_context(with_control_file(remote, name=name)) @@ -583,9 +726,9 @@ def _standard_server( qn.print(f"Using profile: {prof}") cat_result = remote.run(f"cat {prof}", hide=True, warn=True) if cat_result.ok: - qn.print(f"=" * 50) + qn.print("=" * 50) qn.print(cat_result.stdout.rstrip()) - qn.print(f"=" * 50) + qn.print("=" * 50) else: exit(f"Could not find or load profile: {prof}") @@ -598,7 +741,9 @@ def _standard_server( ): exit(f"Exit: {program} is not installed.") - cnode = _find_allocation(remote, job_name=f"mila-serve-{program}") + cnode = _find_allocation( + remote, job_name=f"mila-serve-{program}", node=node, job=job, alloc=alloc + ) patterns = { "node_name": "#### ([A-Za-z0-9_-]+)", @@ -607,9 +752,7 @@ def _standard_server( if port_pattern: patterns["port"] = port_pattern elif share: - exit( - "Server cannot be shared because it is serving over a Unix domain socket" - ) + exit("Server cannot be shared because it is serving over a Unix domain socket") else: remote.run("mkdir -p ~/.milatools/sockets", hide=True) @@ -654,12 +797,16 @@ def _standard_server( if token_pattern: remote.simple_run(f"echo token = {results['token']} >> {cf}") + assert results is not None + assert node_name is not None + assert to_forward is not None + assert proc is not None if token_pattern: options = {"token": results["token"]} else: options = {} - local_proc = _forward( + local_proc, local_port = _forward( local=Local(), node=qualified(node_name), to_forward=to_forward, @@ -667,7 +814,7 @@ def _standard_server( ) if cf is not None: - remote.simple_run(f"echo local_port = {local_proc.local_port} >> {cf}") + remote.simple_run(f"echo local_port = {local_port} >> {cf}") try: local_proc.wait() @@ -684,17 +831,28 @@ def _standard_server( proc.kill() -@tooled -def _find_allocation(remote, job_name="mila-tools"): - # Node to connect to - node: Option = default(None) +def _add_find_allocation_args(parser: ArgumentParser): + parser.add_argument("--alloc", nargs=argparse.REMAINDER, help="Extra options to pass to slurm") + parser.add_argument("--node", type=str, default=None, help="Node to connect to") + parser.add_argument("--job", type=str, default=None, help="Job ID to connect to") - # Job ID to connect to - job: Option = default(None) - # Extra options to pass to slurm - # [nargs: --] - alloc: Option = default([]) +def _find_allocation( + remote, + node: str | None, + job: str | None, + alloc: Sequence[str], + job_name: str = "mila-tools", +): + # # Node to connect to + # node: Option = default(None) + + # # Job ID to connect to + # job: Option = default(None) + + # # Extra options to pass to slurm + # # [nargs: --] + # alloc: Option = default([]) if (node is not None) + (job is not None) + bool(alloc) > 1: exit("ERROR: --node, --job and --alloc are mutually exclusive") @@ -715,19 +873,15 @@ def _find_allocation(remote, job_name="mila-tools"): ) -@tooled def _forward( - local, - node, - to_forward, - page=None, - options={}, - preferred_port=None, - through_login=False, + local: Local, + node: str, + to_forward: int | str, + page: str | None = None, + options: dict[str, str | None] = {}, + port: int | str | None = None, + through_login: bool = False, ): - # Port to open on the local machine - port: Option = default(preferred_port) - if port is None: sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) # Find a free local port by binding to port 0 @@ -774,9 +928,9 @@ def _forward( try: # This feels stupid, there's probably a better way local.silent_get("nc", "-z", "localhost", str(port)) - except subprocess.CalledProcessError as exc: + except subprocess.CalledProcessError: continue - except Exception as exc: + except Exception: break break @@ -785,5 +939,8 @@ def _forward( style="bold", ) webbrowser.open(url) - proc.local_port = port - return proc + return proc, port + + +if __name__ == "__main__": + main()