Skip to content

Commit

Permalink
Fix CLI so --help 100% matches previous output
Browse files Browse the repository at this point in the history
Signed-off-by: Fabrice Normandin <normandf@mila.quebec>
  • Loading branch information
lebrice committed Aug 28, 2023
1 parent f575d82 commit d9ef604
Showing 1 changed file with 95 additions and 31 deletions.
126 changes: 95 additions & 31 deletions milatools/cli/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from argparse import ArgumentParser, _HelpAction
from contextlib import ExitStack
from pathlib import Path
from typing import Sequence
from typing import Any, Sequence
from urllib.parse import urlencode

import questionary as qn
Expand Down Expand Up @@ -92,7 +92,13 @@ def main():
def mila():
parser_kwargs = dict()
parser = ArgumentParser(prog="mila", description=__doc__, add_help=True, **parser_kwargs)
parser.add_argument("-v", "--version", action="version", version=f"milatools v{mversion}")
parser.add_argument(
"--version",
"-v",
action="version",
version=f"milatools v{mversion}",
help="Milatools version",
)

subparsers = parser.add_subparsers(
dest="command",
Expand All @@ -101,15 +107,15 @@ def mila():
)

docs_parser = subparsers.add_parser(
"docs", help="Open the Mila cluster documentation", formatter_class=SortingHelpFormatter
"docs", help="Open the Mila cluster documentation.", formatter_class=SortingHelpFormatter
)
docs_parser.add_argument("search", nargs=argparse.REMAINDER, help="Search terms")
docs_parser.add_argument("SEARCH", nargs=argparse.REMAINDER, help="Search terms")
docs_parser.set_defaults(function=docs)

intranet_parser = subparsers.add_parser(
"intranet", help="Open the Mila intranet in a browser."
)
intranet_parser.add_argument("search", nargs=argparse.REMAINDER, help="Search terms")
intranet_parser.add_argument("SEARCH", nargs=argparse.REMAINDER, help="Search terms")
intranet_parser.set_defaults(function=intranet)

init_parser = subparsers.add_parser(
Expand All @@ -124,27 +130,45 @@ def mila():
help="Forward a port on a compute node to your local machine.",
formatter_class=SortingHelpFormatter,
)
forward_parser.add_argument("remote", help="node:port to forward")
forward_parser.add_argument("REMOTE", help="node:port to forward")
forward_parser.add_argument(
"--page",
# nargs="?",
help="String to append after the URL",
default=None,
metavar="VALUE",
)
forward_parser.add_argument(
"--page", nargs="?", help="String to append after the URL", default=None
"--port", type=int, help="Port to open on the local machine", default=None, metavar="VALUE"
)
forward_parser.add_argument("--port", type=int, help="Local port to forward to", default=None)
forward_parser.set_defaults(function=forward)

code_parser = subparsers.add_parser(
"code",
help="Open a remote VSCode session on a compute node.",
formatter_class=SortingHelpFormatter,
)
code_parser.add_argument("path", help="Path to open on the remote machine", type=str)
_add_find_allocation_args(code_parser)
code_parser.add_argument("PATH", help="Path to open on the remote machine", type=str)
code_parser.add_argument(
"--alloc", nargs=argparse.REMAINDER, help="Extra options to pass to slurm", metavar="VALUE"
)
code_parser.add_argument(
"--command",
default=os.environ.get("MILATOOLS_CODE_COMMAND", "code"),
help=(
"Command to use to start vscode\n"
"(defaults to 'code' or the value of $MILATOOLS_CODE_COMMAND)"
'(defaults to "code" or the value of $MILATOOLS_CODE_COMMAND)'
),
metavar="VALUE",
)
# TODO: Use the `_add_find_allocation_args` to avoid duplication. It was expanded so the args
# ordering matches (weird that it's even necessary, since we're using this
# SortingHelpFormatter..)
code_parser.add_argument(
"--job", type=str, default=None, help="Job ID to connect to", metavar="VALUE"
)
code_parser.add_argument(
"--node", type=str, default=None, help="Node to connect to", metavar="VALUE"
)
code_parser.add_argument(
"--persist", action="store_true", help="Whether the server should persist or not"
Expand All @@ -158,13 +182,22 @@ def mila():
)
serve_subparsers = serve_parser.add_subparsers(dest="serve_command", required=True)

# class _KeepSpacesFormatter(SortingHelpFormatter, argparse.RawTextHelpFormatter):
# ...

serve_connect_parser = serve_subparsers.add_parser(
"connect",
help="Reconnect to a persistent server.",
formatter_class=SortingHelpFormatter,
)
serve_connect_parser.add_argument(
"identifier", type=str, help="Server identifier output by the original mila serve command"
"IDENTIFIER",
type=str,
# todo:remove spaces
help="Server identifier output by the original mila serve command",
)
serve_connect_parser.add_argument(
"--port", type=int, help="Port to open on the local machine", default=None, metavar="VALUE"
)
serve_connect_parser.set_defaults(function=connect)

Expand All @@ -174,8 +207,10 @@ def mila():
formatter_class=SortingHelpFormatter,
)
serve_kill_parser.add_argument(
"identifier",
"IDENTIFIER",
type=str,
nargs="?",
default=None,
help="Server identifier output by the original mila serve command",
)
serve_kill_parser.add_argument("--all", action="store_true", help="Kill all servers")
Expand All @@ -193,7 +228,7 @@ def mila():
"lab", help="Start a Jupyterlab server.", formatter_class=SortingHelpFormatter
)
serve_lab_parser.add_argument(
"path", default=None, nargs="?", help="Path to open on the remote machine"
"PATH", default=None, nargs="?", help="Path to open on the remote machine"
)
_add_standard_server_args(serve_lab_parser)
serve_lab_parser.set_defaults(function=lab)
Expand All @@ -202,29 +237,29 @@ def mila():
"notebook", help="Start a Jupyter Notebook server.", formatter_class=SortingHelpFormatter
)
serve_notebook_parser.add_argument(
"path", default=None, nargs="?", help="Path to open on the remote machine"
"PATH", default=None, nargs="?", help="Path to open on the remote machine"
)
_add_standard_server_args(serve_notebook_parser)
serve_notebook_parser.set_defaults(function=notebook)

serve_tensorboard_parser = serve_subparsers.add_parser(
"tensorboard", help="Start a Tensorboard server.", formatter_class=SortingHelpFormatter
)
serve_tensorboard_parser.add_argument("logdir", type=str, help="Path to the experiment logs")
serve_tensorboard_parser.add_argument("LOGDIR", type=str, help="Path to the experiment logs")
_add_standard_server_args(serve_tensorboard_parser)
serve_notebook_parser.set_defaults(function=tensorboard)

serve_mlflow_parser = serve_subparsers.add_parser(
"mlflow", help="Start an MLFlow server.", formatter_class=SortingHelpFormatter
)
serve_mlflow_parser.add_argument("logdir", type=str, help="Path to the experiment logs")
serve_mlflow_parser.add_argument("LOGDIR", type=str, help="Path to the experiment logs")
_add_standard_server_args(serve_mlflow_parser)
serve_mlflow_parser.set_defaults(function=mlflow)

serve_aim_parser = serve_subparsers.add_parser(
"aim", help="Start an Aim server.", formatter_class=SortingHelpFormatter
"aim", help="Start an AIM server.", formatter_class=SortingHelpFormatter
)
serve_aim_parser.add_argument("logdir", type=str, help="Path to the experiment logs")
serve_aim_parser.add_argument("LOGDIR", type=str, help="Path to the experiment logs")
_add_standard_server_args(serve_aim_parser)
serve_aim_parser.set_defaults(function=aim)

Expand All @@ -233,10 +268,16 @@ def mila():
function = args_dict.pop("function")
_ = args_dict.pop("command")
_ = args_dict.pop("serve_command", None)
# replace SEARCH -> "search", REMOTE -> "remote", etc.
args_dict = _convert_uppercase_keys_to_lowercase(args_dict)
assert callable(function)
return function(**args_dict)


def _convert_uppercase_keys_to_lowercase(args_dict: dict[str, Any]) -> dict[str, Any]:
return {(k.lower() if k.isupper() else k): v for k, v in args_dict.items()}


def docs(search: Sequence[str]) -> None:
url = "https://docs.mila.quebec"
terms = "+".join(search)
Expand Down Expand Up @@ -474,12 +515,13 @@ def connect(identifier: str, port: int | None = None):
local_proc.kill()


def kill(identifier: str, all: bool = False):
def kill(identifier: str | None, all: bool = False):
"""Kill a persistent server."""
remote = Remote("mila")

if all:
for identifier in remote.get_lines("ls .milatools/control", hide=True):
assert isinstance(identifier, str) # note: was implicit before.
info = _get_server_info(remote, identifier, hide=True)
if "jobid" in info:
remote.run(f"scancel {info['jobid']}")
Expand Down Expand Up @@ -666,14 +708,30 @@ def add_arguments(self, actions):


def _add_standard_server_args(parser: ArgumentParser):
_add_find_allocation_args(parser)
# parser.add_argument("--alloc", nargs=argparse.REMAINDER, help="Extra options to pass to slurm")
# parser.add_argument("--job", type=str, default=None, help="Job ID to connect to")
parser.add_argument("--name", default=None, type=str, help="Name of the persistent server")
# parser.add_argument("--node", type=str, default=None, help="Node to connect to")
parser.add_argument("--persist", action="store_true", help="Whether the server should persist")
parser.add_argument("--port", type=int, default=None, help="Port to open on the local machine")
parser.add_argument("--profile", default=None, type=str, help="Name of the profile to use")
# TODO: Use this function to avoid duplicating code. It was expanded so the order of args match
# _add_find_allocation_args(parser)

parser.add_argument(
"--alloc", nargs=argparse.REMAINDER, help="Extra options to pass to slurm", metavar="VALUE"
)
parser.add_argument(
"--job", type=str, default=None, help="Job ID to connect to", metavar="VALUE"
)
parser.add_argument(
"--name", default=None, type=str, help="Name of the persistent server", metavar="VALUE"
)
parser.add_argument(
"--node", type=str, default=None, help="Node to connect to", metavar="VALUE"
)
parser.add_argument(
"--persist", action="store_true", help="Whether the server should persist or not"
)
parser.add_argument(
"--port", type=int, default=None, help="Port to open on the local machine", metavar="VALUE"
)
parser.add_argument(
"--profile", default=None, type=str, help="Name of the profile to use", metavar="VALUE"
)


def _standard_server(
Expand Down Expand Up @@ -832,9 +890,15 @@ def _standard_server(


def _add_find_allocation_args(parser: ArgumentParser):
parser.add_argument("--alloc", nargs=argparse.REMAINDER, help="Extra options to pass to slurm")
parser.add_argument("--node", type=str, default=None, help="Node to connect to")
parser.add_argument("--job", type=str, default=None, help="Job ID to connect to")
parser.add_argument(
"--alloc", nargs=argparse.REMAINDER, help="Extra options to pass to slurm", metavar="VALUE"
)
parser.add_argument(
"--node", type=str, default=None, help="Node to connect to", metavar="VALUE"
)
parser.add_argument(
"--job", type=str, default=None, help="Job ID to connect to", metavar="VALUE"
)


def _find_allocation(
Expand Down

0 comments on commit d9ef604

Please sign in to comment.