Skip to content

Commit

Permalink
Add --salloc and --sbatch flags (see desc.)
Browse files Browse the repository at this point in the history
- Adds a --salloc flag which is exactly the same as using the '--alloc'
  flag (without the --persist flag).
- Adds a --sbatch flag which is the same as doing --persist --alloc ...

I think these are more naturally understood as the arguments that are
passed to `salloc` and `sbatch` respectively.

Also, these two new args are in a mutually exclusive group with
--persist.

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>
  • Loading branch information
lebrice committed Apr 25, 2024
1 parent 6d8ebac commit e137096
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 24 deletions.
69 changes: 45 additions & 24 deletions milatools/cli/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
from .remote import Remote, SlurmRemote
from .utils import (
CLUSTERS,
AllocationFlagsAction,
Cluster,
CommandNotFoundError,
MilatoolsUserError,
Expand Down Expand Up @@ -224,13 +225,6 @@ def add_arguments(parser: argparse.ArgumentParser):
default="mila",
help="Which cluster to connect to.",
)
code_parser.add_argument(
"--alloc",
nargs=argparse.REMAINDER,
help="Extra options to pass to slurm",
metavar="VALUE",
default=[],
)
code_parser.add_argument(
"--command",
default=get_code_command(),
Expand All @@ -252,13 +246,10 @@ def add_arguments(parser: argparse.ArgumentParser):
type=str,
default=None,
help="Node to connect to",
metavar="VALUE",
)
code_parser.add_argument(
"--persist",
action="store_true",
help="Whether the server should persist or not",
metavar="NODE",
)
_add_allocation_options(code_parser)

code_parser.set_defaults(function=code)

# ----- mila sync vscode-extensions ------
Expand Down Expand Up @@ -944,14 +935,49 @@ def add_arguments(self, actions):
super().add_arguments(actions)


def _add_standard_server_args(parser: ArgumentParser):
parser.add_argument(
def _add_allocation_options(parser: ArgumentParser):
arg_group = parser.add_argument_group(
"Allocation options", description="Extra options to pass to slurm."
)
common_kwargs = {
"dest": "alloc",
"nargs": argparse.REMAINDER,
"action": AllocationFlagsAction,
"metavar": "VALUE",
"default": [],
}
arg_group.add_argument(
"--alloc",
nargs=argparse.REMAINDER,
help="Extra options to pass to slurm",
metavar="VALUE",
default=[],
**common_kwargs,
help="Extra options to pass to salloc or to sbatch if --persist is set.",
)
alloc_group = arg_group.add_mutually_exclusive_group()
alloc_group.add_argument(
"--persist",
action="store_true",
help="Whether the server should persist or not when using --alloc",
)

# --persist cannot be used with --salloc or --sbatch.
# Note: REMAINDER args like --alloc, --sbatch and --salloc are already mutually
# exclusive in a sense, since it's only possible to use one correctly, the other
# args are stored in the first one (e.g. mila code --alloc --salloc bob will have
# alloc of ["--salloc", "bob"]).

alloc_group.add_argument(
"--salloc",
**common_kwargs,
help="Extra options to pass to salloc. Same as using --alloc without --persist.",
)
alloc_group.add_argument(
"--sbatch",
**common_kwargs,
help="Extra options to pass to sbatch (equivalent to --persist --alloc [...])",
)


def _add_standard_server_args(parser: ArgumentParser):
_add_allocation_options(parser)
parser.add_argument(
"--job",
type=int,
Expand All @@ -973,11 +999,6 @@ def _add_standard_server_args(parser: ArgumentParser):
help="Node to connect to",
metavar="VALUE",
)
parser.add_argument(
"--persist",
action="store_true",
help="Whether the server should persist or not",
)
parser.add_argument(
"--port",
type=int,
Expand Down
31 changes: 31 additions & 0 deletions milatools/cli/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import argparse
import contextvars
import functools
import itertools
Expand All @@ -13,6 +14,7 @@
import warnings
from collections.abc import Callable, Iterable
from contextlib import contextmanager
from logging import getLogger as get_logger
from pathlib import Path
from typing import Any, Literal, TypeVar, Union, get_args

Expand All @@ -26,6 +28,7 @@
if typing.TYPE_CHECKING:
from milatools.cli.remote import Remote

logger = get_logger(__name__)
control_file_var = contextvars.ContextVar("control_file", default="/dev/null")


Expand Down Expand Up @@ -347,3 +350,31 @@ def removesuffix(s: str, suffix: str) -> str:
return s
else:
removesuffix = str.removesuffix


class AllocationFlagsAction(argparse._StoreAction):
    """Argparse action shared by the ``--alloc``, ``--salloc`` and ``--sbatch`` flags.

    Whichever of the three flags is used, the collected values are stored in
    ``namespace.alloc``. In addition:

    - ``--salloc`` refuses to run if ``namespace.persist`` is already truthy.
    - ``--sbatch`` sets ``namespace.persist`` to ``True``.
    """

    def __call__(
        self,
        parser: argparse.ArgumentParser,
        namespace: argparse.Namespace,
        values: list[str],
        option_string: str | None = None,
    ) -> None:
        """Store ``values`` in ``namespace.alloc`` and update ``namespace.persist``.

        Args:
            parser: The parser that invoked this action (unused).
            namespace: The namespace being populated.
            values: The arguments collected for the flag.
            option_string: The flag that triggered the action; one of
                ``--alloc``, ``--salloc`` or ``--sbatch``.

        Raises:
            argparse.ArgumentError: If ``--salloc`` is used while
                ``namespace.persist`` is already truthy.
            ValueError: If the action is attached to any other option string.
        """
        # The namespace may not have a `persist` attribute (e.g. when the action
        # is used on a parser without a --persist flag); treat that as "not set".
        persist: bool | None = getattr(namespace, "persist", None)

        if option_string == "--alloc":
            namespace.alloc = values
        elif option_string == "--salloc":
            # Validate with a real usage error rather than an `assert`: asserts
            # are stripped under `python -O` and would otherwise surface as an
            # AssertionError instead of an argparse error message.
            if persist:
                raise argparse.ArgumentError(
                    argument=self,
                    message="Cannot use --salloc with --persist, use only --sbatch for a persistent session.",
                )
            namespace.alloc = values
        elif option_string == "--sbatch":
            # `persist` already being True is consistent with --sbatch, so there
            # is nothing to reject here.
            namespace.alloc = values
            namespace.persist = True
        else:
            # Programming error: the action was registered under another flag.
            raise ValueError(
                f"{type(self).__name__} only supports --alloc, --salloc and "
                f"--sbatch, got {option_string!r}"
            )

0 comments on commit e137096

Please sign in to comment.