Skip to content

Commit

Permalink
Tweak cli. add --sbatch and --salloc as alternatives to --alloc and -…
Browse files Browse the repository at this point in the history
…-persist (#119)

* Split `mila` function into chunks

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Make `mila code` default to `mila code .`

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Make the job_id an int instead of str

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Add --salloc and --sbatch flags (see desc.)

- Adds a --salloc flag which is exactly the same as using the '--alloc'
  flag (without the --persist flag).
- Adds a --sbatch flag which is the same as doing --persist --alloc ...

I think these are more naturally understood as the arguments that are
passed to `salloc` and `sbatch` respectively.

Also, these two new args are in a mutually exclusive group with
--persist.

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Put the --alloc/--salloc/--sbatch args last

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Add missing regression test file

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

---------

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>
  • Loading branch information
lebrice authored Jun 3, 2024
1 parent 9b4a7cd commit bf82ca9
Show file tree
Hide file tree
Showing 10 changed files with 219 additions and 81 deletions.
127 changes: 85 additions & 42 deletions milatools/cli/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from contextlib import ExitStack
from logging import getLogger as get_logger
from pathlib import Path
from typing import Any
from typing import Any, Callable
from urllib.parse import urlencode

import questionary as qn
Expand Down Expand Up @@ -54,6 +54,7 @@
from .profile import ensure_program, setup_profile
from .utils import (
CLUSTERS,
AllocationFlagsAction,
Cluster,
CommandNotFoundError,
MilatoolsUserError,
Expand Down Expand Up @@ -130,6 +131,13 @@ def main():

def mila():
    """Run the `mila` command-line interface.

    Builds the top-level argument parser, registers all arguments on it,
    parses the command line, configures logging from the parsed verbosity,
    and finally calls the selected command function with the remaining
    parsed arguments as keyword arguments.

    Returns whatever the selected command function returns.
    """
    cli_parser = ArgumentParser(prog="mila", description=__doc__, add_help=True)
    add_arguments(cli_parser)

    # parse_args hands back the verbosity, the command callable and its kwargs.
    parsed = parse_args(cli_parser)
    verbose, function, args_dict = parsed

    # Logging is configured before the command runs so its output is captured.
    setup_logging(verbose)
    return function(**args_dict)


def add_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--version",
action="version",
Expand Down Expand Up @@ -198,24 +206,26 @@ def mila():
code_parser = subparsers.add_parser(
"code",
help="Open a remote VSCode session on a compute node.",
formatter_class=SortingHelpFormatter,
)
code_parser.add_argument(
"PATH", help="Path to open on the remote machine", type=str
"PATH",
help=(
"Path to open on the remote machine. Defaults to $HOME.\n"
"Can be a relative or absolute path. When a relative path (that doesn't "
"start with a '/', like foo/bar) is passed, the path is relative to the "
"$HOME directory on the selected cluster.\n"
"For example, foo/project will be interpreted as $HOME/foo/project."
),
type=str,
default=".",
nargs="?",
)
code_parser.add_argument(
"--cluster",
choices=CLUSTERS,
default="mila",
help="Which cluster to connect to.",
)
code_parser.add_argument(
"--alloc",
nargs=argparse.REMAINDER,
help="Extra options to pass to slurm",
metavar="VALUE",
default=[],
)
code_parser.add_argument(
"--command",
default=get_code_command(),
Expand All @@ -227,23 +237,20 @@ def mila():
)
code_parser.add_argument(
"--job",
type=str,
type=int,
default=None,
help="Job ID to connect to",
metavar="VALUE",
metavar="JOB_ID",
)
code_parser.add_argument(
"--node",
type=str,
default=None,
help="Node to connect to",
metavar="VALUE",
)
code_parser.add_argument(
"--persist",
action="store_true",
help="Whether the server should persist or not",
metavar="NODE",
)
_add_allocation_options(code_parser)

code_parser.set_defaults(function=code)

# ----- mila sync vscode-extensions ------
Expand Down Expand Up @@ -353,7 +360,6 @@ def mila():
serve_lab_parser = serve_subparsers.add_parser(
"lab",
help="Start a Jupyterlab server.",
formatter_class=SortingHelpFormatter,
)
serve_lab_parser.add_argument(
"PATH",
Expand All @@ -369,7 +375,6 @@ def mila():
serve_notebook_parser = serve_subparsers.add_parser(
"notebook",
help="Start a Jupyter Notebook server.",
formatter_class=SortingHelpFormatter,
)
serve_notebook_parser.add_argument(
"PATH",
Expand All @@ -385,7 +390,6 @@ def mila():
serve_tensorboard_parser = serve_subparsers.add_parser(
"tensorboard",
help="Start a Tensorboard server.",
formatter_class=SortingHelpFormatter,
)
serve_tensorboard_parser.add_argument(
"LOGDIR", type=str, help="Path to the experiment logs"
Expand All @@ -398,7 +402,6 @@ def mila():
serve_mlflow_parser = serve_subparsers.add_parser(
"mlflow",
help="Start an MLFlow server.",
formatter_class=SortingHelpFormatter,
)
serve_mlflow_parser.add_argument(
"LOGDIR", type=str, help="Path to the experiment logs"
Expand All @@ -411,22 +414,29 @@ def mila():
serve_aim_parser = serve_subparsers.add_parser(
"aim",
help="Start an AIM server.",
formatter_class=SortingHelpFormatter,
)
serve_aim_parser.add_argument(
"LOGDIR", type=str, help="Path to the experiment logs"
)
_add_standard_server_args(serve_aim_parser)
serve_aim_parser.set_defaults(function=aim)


def parse_args(parser: argparse.ArgumentParser) -> tuple[int, Callable, dict[str, Any]]:
"""Parses the command-line arguments.
Returns the verbosity level, the function (or awaitable) to call, and the arguments
to the function.
"""
args = parser.parse_args()
args_dict = vars(args)

verbose: int = args_dict.pop("verbose")

function = args_dict.pop("function")
_ = args_dict.pop("<command>")
_ = args_dict.pop("<serve_subcommand>", None)
_ = args_dict.pop("<sync_subcommand>", None)
setup_logging(verbose)
# replace SEARCH -> "search", REMOTE -> "remote", etc.
args_dict = _convert_uppercase_keys_to_lowercase(args_dict)

Expand All @@ -438,7 +448,7 @@ def mila():
return

assert callable(function)
return function(**args_dict)
return verbose, function, args_dict


def setup_logging(verbose: int) -> None:
Expand Down Expand Up @@ -550,7 +560,7 @@ def code(
path: str,
command: str,
persist: bool,
job: str | None,
job: int | None,
node: str | None,
alloc: list[str],
cluster: Cluster = "mila",
Expand Down Expand Up @@ -788,7 +798,7 @@ class StandardServerArgs(TypedDict):
alloc: list[str]
"""Extra options to pass to slurm."""

job: str | None
job: int | None
"""Job ID to connect to."""

name: str | None
Expand Down Expand Up @@ -931,20 +941,56 @@ def add_arguments(self, actions):
super().add_arguments(actions)


def _add_standard_server_args(parser: ArgumentParser):
parser.add_argument(
def _add_allocation_options(parser: ArgumentParser):
# note: Ideally we'd like [--persist --alloc] | [--salloc] | [--sbatch] (i.e. a
# subgroup with alloc and persist within a mutually exclusive group with salloc and
# sbatch) but that doesn't seem possible with argparse as far as I can tell.
arg_group = parser.add_argument_group(
"Allocation options", description="Extra options to pass to slurm."
)
alloc_group = arg_group.add_mutually_exclusive_group()
common_kwargs = {
"dest": "alloc",
"nargs": argparse.REMAINDER,
"action": AllocationFlagsAction,
"metavar": "VALUE",
"default": [],
}
alloc_group.add_argument(
"--persist",
action="store_true",
help="Whether the server should persist or not when using --alloc",
)
# --persist can be used with --alloc
arg_group.add_argument(
"--alloc",
nargs=argparse.REMAINDER,
help="Extra options to pass to slurm",
metavar="VALUE",
default=[],
**common_kwargs,
help="Extra options to pass to salloc or to sbatch if --persist is set.",
)
# --persist cannot be used with --salloc or --sbatch.
# Note: REMAINDER args like --alloc, --sbatch and --salloc are already mutually
# exclusive in a sense, since it's only possible to use one correctly, the other
# args are stored in the first one (e.g. mila code --alloc --salloc bob will have
# alloc of ["--salloc", "bob"]).
alloc_group.add_argument(
"--salloc",
**common_kwargs,
help="Extra options to pass to salloc. Same as using --alloc without --persist.",
)
alloc_group.add_argument(
"--sbatch",
**common_kwargs,
help="Extra options to pass to sbatch. Same as using --alloc with --persist.",
)


def _add_standard_server_args(parser: ArgumentParser):
parser.add_argument(
"--job",
type=str,
type=int,
default=None,
help="Job ID to connect to",
metavar="VALUE",
metavar="JOB_ID",
)
parser.add_argument(
"--name",
Expand All @@ -960,11 +1006,6 @@ def _add_standard_server_args(parser: ArgumentParser):
help="Node to connect to",
metavar="VALUE",
)
parser.add_argument(
"--persist",
action="store_true",
help="Whether the server should persist or not",
)
parser.add_argument(
"--port",
type=int,
Expand All @@ -979,6 +1020,8 @@ def _add_standard_server_args(parser: ArgumentParser):
help="Name of the profile to use",
metavar="VALUE",
)
# Add these arguments last because we want them to show up last in the usage message
_add_allocation_options(parser)


def _standard_server(
Expand All @@ -992,7 +1035,7 @@ def _standard_server(
port: int | None,
name: str | None,
node: str | None,
job: str | None,
job: int | None,
alloc: list[str],
port_pattern=None,
token_pattern=None,
Expand Down Expand Up @@ -1277,7 +1320,7 @@ def get_colour(used: float, max: float) -> str:
def _find_allocation(
remote: RemoteV1,
node: str | None,
job: str | None,
job: int | str | None,
alloc: list[str],
cluster: Cluster = "mila",
job_name: str = "mila-tools",
Expand Down
31 changes: 31 additions & 0 deletions milatools/cli/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import argparse
import contextvars
import functools
import itertools
Expand All @@ -13,6 +14,7 @@
import warnings
from collections.abc import Callable, Iterable
from contextlib import contextmanager
from logging import getLogger as get_logger
from pathlib import Path
from typing import Any, Literal, TypeVar, Union, get_args

Expand All @@ -27,6 +29,7 @@
from milatools.utils.remote_v1 import RemoteV1


logger = get_logger(__name__)
control_file_var = contextvars.ContextVar("control_file", default="/dev/null")

SSH_CONFIG_FILE = Path.home() / ".ssh" / "config"
Expand Down Expand Up @@ -375,3 +378,31 @@ def removesuffix(s: str, suffix: str) -> str:
return s
else:
removesuffix = str.removesuffix


class AllocationFlagsAction(argparse._StoreAction):
    """Argparse action shared by the `--alloc`, `--salloc` and `--sbatch` flags.

    All three flags store their values in `namespace.alloc`. They differ only
    in how they interact with `namespace.persist`:

    - `--alloc`: stores the values, leaves `persist` untouched.
    - `--salloc`: stores the values; rejected if `--persist` was already set.
    - `--sbatch`: stores the values and sets `persist` to True; rejected if
      `--persist` was already set (it is redundant with `--sbatch`).

    Invalid combinations raise `argparse.ArgumentError`, which argparse turns
    into a regular usage error, instead of relying on `assert` statements
    (which are stripped under `python -O`).
    """

    def __call__(
        self,
        parser: argparse.ArgumentParser,
        namespace,
        values: list[str],
        option_string: str | None = None,
    ):
        """Store `values` in `namespace.alloc` and update `namespace.persist`.

        Raises:
            argparse.ArgumentError: if `--salloc` or `--sbatch` is combined
                with an already-set `--persist`, or if the action was
                registered under an unsupported option string.
        """
        # Tolerate parsers that do not define a --persist flag at all.
        persist = bool(getattr(namespace, "persist", False))

        if option_string == "--alloc":
            namespace.alloc = values
            return

        if option_string == "--salloc":
            # --salloc is meant to be mutually exclusive with --persist.
            if persist:
                raise argparse.ArgumentError(
                    argument=self,
                    message="Cannot use --salloc with --persist, use only --sbatch for a persistent session.",
                )
            namespace.alloc = values
            return

        if option_string == "--sbatch":
            # --sbatch is meant to be mutually exclusive with --persist.
            if persist:
                raise argparse.ArgumentError(
                    argument=self,
                    message="Cannot use --sbatch with --persist, --sbatch already implies a persistent session.",
                )
            namespace.alloc = values
            namespace.persist = True
            return

        # Fail loudly rather than silently treating an unknown flag as --sbatch.
        raise argparse.ArgumentError(
            argument=self,
            message=(
                f"Unsupported option {option_string!r}: expected one of "
                "--alloc, --salloc or --sbatch."
            ),
        )
2 changes: 1 addition & 1 deletion tests/cli/test_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def test_help(
[
"mila", # Error: Missing a subcommand.
"mila search conda",
"mila code", # Error: Missing the required PATH argument.
"mila code --boo", # Error: Unknown argument.
"mila serve", # Error: Missing the subcommand.
"mila forward", # Error: Missing the REMOTE argument.
],
Expand Down
31 changes: 23 additions & 8 deletions tests/cli/test_commands/test_help_mila_code_.txt
Original file line number Diff line number Diff line change
@@ -1,18 +1,33 @@
usage: mila code [-h] [--cluster {mila,cedar,narval,beluga,graham}]
[--alloc ...] [--command VALUE] [--job VALUE] [--node VALUE]
[--persist]
PATH
[--command VALUE] [--job JOB_ID] [--node NODE] [--persist]
[--alloc ...] [--salloc ...] [--sbatch ...]
[PATH]

positional arguments:
PATH Path to open on the remote machine
PATH Path to open on the remote machine. Defaults to $HOME.
Can be a relative or absolute path. When a relative
path (that doesn't start with a '/', like foo/bar) is
passed, the path is relative to the $HOME directory on
the selected cluster. For example, foo/project will be
interpreted as $HOME/foo/project.

optional arguments:
-h, --help show this help message and exit
--alloc ... Extra options to pass to slurm
--cluster {mila,cedar,narval,beluga,graham}
Which cluster to connect to.
--command VALUE Command to use to start vscode (defaults to "code" or
the value of $MILATOOLS_CODE_COMMAND)
--job VALUE Job ID to connect to
--node VALUE Node to connect to
--persist Whether the server should persist or not
--job JOB_ID Job ID to connect to
--node NODE Node to connect to

Allocation optional arguments:
Extra options to pass to slurm.

--persist Whether the server should persist or not when using
--alloc
--alloc ... Extra options to pass to salloc or to sbatch if
--persist is set.
--salloc ... Extra options to pass to salloc. Same as using --alloc
without --persist.
--sbatch ... Extra options to pass to sbatch. Same as using --alloc
with --persist.
Loading

0 comments on commit bf82ca9

Please sign in to comment.