Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tweak cli. add --sbatch and --salloc as alternatives to --alloc and --persist #119

Merged
merged 6 commits into from
Jun 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 86 additions & 42 deletions milatools/cli/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

Cluster documentation: https://docs.mila.quebec/
"""

from __future__ import annotations

import argparse
Expand All @@ -21,7 +22,7 @@
from contextlib import ExitStack
from logging import getLogger as get_logger
from pathlib import Path
from typing import Any
from typing import Any, Callable
from urllib.parse import urlencode

import questionary as qn
Expand Down Expand Up @@ -51,6 +52,7 @@
from .remote import Remote, SlurmRemote
from .utils import (
CLUSTERS,
AllocationFlagsAction,
Cluster,
CommandNotFoundError,
MilatoolsUserError,
Expand Down Expand Up @@ -127,6 +129,13 @@ def main():

def mila():
parser = ArgumentParser(prog="mila", description=__doc__, add_help=True)
add_arguments(parser)
verbose, function, args_dict = parse_args(parser)
setup_logging(verbose)
return function(**args_dict)


def add_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--version",
action="version",
Expand Down Expand Up @@ -195,24 +204,26 @@ def mila():
code_parser = subparsers.add_parser(
"code",
help="Open a remote VSCode session on a compute node.",
formatter_class=SortingHelpFormatter,
)
code_parser.add_argument(
"PATH", help="Path to open on the remote machine", type=str
"PATH",
help=(
"Path to open on the remote machine. Defaults to $HOME.\n"
"Can be a relative or absolute path. When a relative path (that doesn't "
"start with a '/', like foo/bar) is passed, the path is relative to the "
"$HOME directory on the selected cluster.\n"
"For example, foo/project will be interpreted as $HOME/foo/project."
),
type=str,
default=".",
nargs="?",
)
code_parser.add_argument(
"--cluster",
choices=CLUSTERS,
default="mila",
help="Which cluster to connect to.",
)
code_parser.add_argument(
"--alloc",
nargs=argparse.REMAINDER,
help="Extra options to pass to slurm",
metavar="VALUE",
default=[],
)
code_parser.add_argument(
"--command",
default=get_code_command(),
Expand All @@ -224,23 +235,20 @@ def mila():
)
code_parser.add_argument(
"--job",
type=str,
type=int,
default=None,
help="Job ID to connect to",
metavar="VALUE",
metavar="JOB_ID",
)
code_parser.add_argument(
"--node",
type=str,
default=None,
help="Node to connect to",
metavar="VALUE",
)
code_parser.add_argument(
"--persist",
action="store_true",
help="Whether the server should persist or not",
metavar="NODE",
)
_add_allocation_options(code_parser)

code_parser.set_defaults(function=code)

# ----- mila sync vscode-extensions ------
Expand Down Expand Up @@ -350,7 +358,6 @@ def mila():
serve_lab_parser = serve_subparsers.add_parser(
"lab",
help="Start a Jupyterlab server.",
formatter_class=SortingHelpFormatter,
)
serve_lab_parser.add_argument(
"PATH",
Expand All @@ -366,7 +373,6 @@ def mila():
serve_notebook_parser = serve_subparsers.add_parser(
"notebook",
help="Start a Jupyter Notebook server.",
formatter_class=SortingHelpFormatter,
)
serve_notebook_parser.add_argument(
"PATH",
Expand All @@ -382,7 +388,6 @@ def mila():
serve_tensorboard_parser = serve_subparsers.add_parser(
"tensorboard",
help="Start a Tensorboard server.",
formatter_class=SortingHelpFormatter,
)
serve_tensorboard_parser.add_argument(
"LOGDIR", type=str, help="Path to the experiment logs"
Expand All @@ -395,7 +400,6 @@ def mila():
serve_mlflow_parser = serve_subparsers.add_parser(
"mlflow",
help="Start an MLFlow server.",
formatter_class=SortingHelpFormatter,
)
serve_mlflow_parser.add_argument(
"LOGDIR", type=str, help="Path to the experiment logs"
Expand All @@ -408,26 +412,33 @@ def mila():
serve_aim_parser = serve_subparsers.add_parser(
"aim",
help="Start an AIM server.",
formatter_class=SortingHelpFormatter,
)
serve_aim_parser.add_argument(
"LOGDIR", type=str, help="Path to the experiment logs"
)
_add_standard_server_args(serve_aim_parser)
serve_aim_parser.set_defaults(function=aim)


def parse_args(parser: argparse.ArgumentParser) -> tuple[int, Callable, dict[str, Any]]:
"""Parses the command-line arguments.

Returns the verbosity level, the function (or awaitable) to call, and the arguments
to the function.
"""
args = parser.parse_args()
args_dict = vars(args)

verbose: int = args_dict.pop("verbose")

function = args_dict.pop("function")
_ = args_dict.pop("<command>")
_ = args_dict.pop("<serve_subcommand>", None)
_ = args_dict.pop("<sync_subcommand>", None)
setup_logging(verbose)
# replace SEARCH -> "search", REMOTE -> "remote", etc.
args_dict = _convert_uppercase_keys_to_lowercase(args_dict)
assert callable(function)
return function(**args_dict)
return verbose, function, args_dict


def setup_logging(verbose: int) -> None:
Expand Down Expand Up @@ -537,7 +548,7 @@ def code(
path: str,
command: str,
persist: bool,
job: str | None,
job: int | None,
node: str | None,
alloc: list[str],
cluster: Cluster = "mila",
Expand Down Expand Up @@ -775,7 +786,7 @@ class StandardServerArgs(TypedDict):
alloc: list[str]
"""Extra options to pass to slurm."""

job: str | None
job: int | None
"""Job ID to connect to."""

name: str | None
Expand Down Expand Up @@ -918,20 +929,56 @@ def add_arguments(self, actions):
super().add_arguments(actions)


def _add_standard_server_args(parser: ArgumentParser):
parser.add_argument(
def _add_allocation_options(parser: ArgumentParser):
# note: Ideally we'd like [--persist --alloc] | [--salloc] | [--sbatch] (i.e. a
# subgroup with alloc and persist within a mutually exclusive group with salloc and
# sbatch) but that doesn't seem possible with argparse as far as I can tell.
satyaog marked this conversation as resolved.
Show resolved Hide resolved
arg_group = parser.add_argument_group(
"Allocation options", description="Extra options to pass to slurm."
)
alloc_group = arg_group.add_mutually_exclusive_group()
common_kwargs = {
"dest": "alloc",
"nargs": argparse.REMAINDER,
"action": AllocationFlagsAction,
"metavar": "VALUE",
"default": [],
}
alloc_group.add_argument(
"--persist",
action="store_true",
help="Whether the server should persist or not when using --alloc",
)
# --persist can be used with --alloc
arg_group.add_argument(
"--alloc",
nargs=argparse.REMAINDER,
help="Extra options to pass to slurm",
metavar="VALUE",
default=[],
**common_kwargs,
help="Extra options to pass to salloc or to sbatch if --persist is set.",
)
# --persist cannot be used with --salloc or --sbatch.
# Note: REMAINDER args like --alloc, --sbatch and --salloc are already mutually
# exclusive in a sense, since it's only possible to use one correctly, the other
# args are stored in the first one (e.g. mila code --alloc --salloc bob will have
# alloc of ["--salloc", "bob"]).
alloc_group.add_argument(
"--salloc",
**common_kwargs,
help="Extra options to pass to salloc. Same as using --alloc without --persist.",
)
alloc_group.add_argument(
"--sbatch",
**common_kwargs,
help="Extra options to pass to sbatch. Same as using --alloc with --persist.",
)


def _add_standard_server_args(parser: ArgumentParser):
parser.add_argument(
"--job",
type=str,
type=int,
default=None,
help="Job ID to connect to",
metavar="VALUE",
metavar="JOB_ID",
)
parser.add_argument(
"--name",
Expand All @@ -947,11 +994,6 @@ def _add_standard_server_args(parser: ArgumentParser):
help="Node to connect to",
metavar="VALUE",
)
parser.add_argument(
"--persist",
action="store_true",
help="Whether the server should persist or not",
)
parser.add_argument(
"--port",
type=int,
Expand All @@ -966,6 +1008,8 @@ def _add_standard_server_args(parser: ArgumentParser):
help="Name of the profile to use",
metavar="VALUE",
)
# Add these arguments last because we want them to show up last in the usage message
_add_allocation_options(parser)


def _standard_server(
Expand All @@ -979,7 +1023,7 @@ def _standard_server(
port: int | None,
name: str | None,
node: str | None,
job: str | None,
job: int | None,
alloc: list[str],
port_pattern=None,
token_pattern=None,
Expand Down Expand Up @@ -1264,7 +1308,7 @@ def get_colour(used: float, max: float) -> str:
def _find_allocation(
remote: Remote,
node: str | None,
job: str | None,
job: int | str | None,
alloc: list[str],
cluster: Cluster = "mila",
job_name: str = "mila-tools",
Expand Down
31 changes: 31 additions & 0 deletions milatools/cli/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import argparse
import contextvars
import functools
import itertools
Expand All @@ -13,6 +14,7 @@
import warnings
from collections.abc import Callable, Iterable
from contextlib import contextmanager
from logging import getLogger as get_logger
from pathlib import Path
from typing import Any, Literal, TypeVar, Union, get_args

Expand All @@ -26,6 +28,7 @@
if typing.TYPE_CHECKING:
from milatools.cli.remote import Remote

logger = get_logger(__name__)
control_file_var = contextvars.ContextVar("control_file", default="/dev/null")


Expand Down Expand Up @@ -347,3 +350,31 @@ def removesuffix(s: str, suffix: str) -> str:
return s
else:
removesuffix = str.removesuffix


class AllocationFlagsAction(argparse._StoreAction):
def __call__(
self,
parser: argparse.ArgumentParser,
namespace,
values: list[str],
option_string: str | None = None,
):
persist: bool | None = namespace.persist
if option_string == "--alloc":
namespace.alloc = values
elif option_string == "--salloc":
# --salloc is in a mutually exclusive group with --persist
assert not persist
if persist:
raise argparse.ArgumentError(
argument=self,
message="Cannot use --salloc with --persist, use only --sbatch for a persistent session.",
)
namespace.alloc = values
else:
assert option_string == "--sbatch", option_string
# --sbatch is in a mutually exclusive group with --persist
assert not persist
namespace.alloc = values
namespace.persist = True
2 changes: 1 addition & 1 deletion tests/cli/test_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def test_help(
[
"mila", # Error: Missing a subcommand.
"mila search conda",
"mila code", # Error: Missing the required PATH argument.
"mila code --boo", # Error: Unknown argument.
"mila serve", # Error: Missing the subcommand.
"mila forward", # Error: Missing the REMOTE argument.
],
Expand Down
Loading
Loading