diff --git a/milatools/cli/commands.py b/milatools/cli/commands.py index f257808f..8a465c02 100644 --- a/milatools/cli/commands.py +++ b/milatools/cli/commands.py @@ -52,6 +52,7 @@ from .remote import Remote, SlurmRemote from .utils import ( CLUSTERS, + AllocationFlagsAction, Cluster, CommandNotFoundError, MilatoolsUserError, @@ -224,13 +225,6 @@ def add_arguments(parser: argparse.ArgumentParser): default="mila", help="Which cluster to connect to.", ) - code_parser.add_argument( - "--alloc", - nargs=argparse.REMAINDER, - help="Extra options to pass to slurm", - metavar="VALUE", - default=[], - ) code_parser.add_argument( "--command", default=get_code_command(), @@ -252,13 +246,10 @@ def add_arguments(parser: argparse.ArgumentParser): type=str, default=None, help="Node to connect to", - metavar="VALUE", - ) - code_parser.add_argument( - "--persist", - action="store_true", - help="Whether the server should persist or not", + metavar="NODE", ) + _add_allocation_options(code_parser) + code_parser.set_defaults(function=code) # ----- mila sync vscode-extensions ------ @@ -944,14 +935,49 @@ def add_arguments(self, actions): super().add_arguments(actions) -def _add_standard_server_args(parser: ArgumentParser): - parser.add_argument( +def _add_allocation_options(parser: ArgumentParser): + arg_group = parser.add_argument_group( + "Allocation options", description="Extra options to pass to slurm." + ) + common_kwargs = { + "dest": "alloc", + "nargs": argparse.REMAINDER, + "action": AllocationFlagsAction, + "metavar": "VALUE", + "default": [], + } + arg_group.add_argument( "--alloc", - nargs=argparse.REMAINDER, - help="Extra options to pass to slurm", - metavar="VALUE", - default=[], + **common_kwargs, + help="Extra options to pass to salloc or to sbatch if --persist is set.", ) + alloc_group = arg_group.add_mutually_exclusive_group() + alloc_group.add_argument( + "--persist", + action="store_true", + help="Whether the server should persist or not when using --alloc", + ) + + # --persist cannot be used with --salloc or --sbatch. + # Note: REMAINDER args like --alloc, --sbatch and --salloc are already mutually + # exclusive in a sense, since it's only possible to use one correctly, the other + # args are stored in the first one (e.g. mila code --alloc --salloc bob will have + # alloc of ["--salloc", "bob"]). + + alloc_group.add_argument( + "--salloc", + **common_kwargs, + help="Extra options to pass to salloc. Same as using --alloc without --persist.", + ) + alloc_group.add_argument( + "--sbatch", + **common_kwargs, + help="Extra options to pass to sbatch (equivalent to --persist --alloc [...])", + ) + + +def _add_standard_server_args(parser: ArgumentParser): + _add_allocation_options(parser) parser.add_argument( "--job", type=int, @@ -973,11 +999,6 @@ def _add_standard_server_args(parser: ArgumentParser): help="Node to connect to", metavar="VALUE", ) - parser.add_argument( - "--persist", - action="store_true", - help="Whether the server should persist or not", - ) parser.add_argument( "--port", type=int, diff --git a/milatools/cli/utils.py b/milatools/cli/utils.py index e62b79ce..6660f68b 100644 --- a/milatools/cli/utils.py +++ b/milatools/cli/utils.py @@ -1,5 +1,6 @@ from __future__ import annotations +import argparse import contextvars import functools import itertools @@ -13,6 +14,7 @@ import warnings from collections.abc import Callable, Iterable from contextlib import contextmanager +from logging import getLogger as get_logger from pathlib import Path from typing import Any, Literal, TypeVar, Union, get_args @@ -26,6 +28,7 @@ if typing.TYPE_CHECKING: from milatools.cli.remote import Remote +logger = get_logger(__name__) control_file_var = contextvars.ContextVar("control_file", default="/dev/null") @@ -347,3 +350,31 @@ def removesuffix(s: str, suffix: str) -> str: return s else: removesuffix = str.removesuffix + + +class AllocationFlagsAction(argparse._StoreAction): + def __call__( + self, + parser: argparse.ArgumentParser, + namespace, + values: list[str], + option_string: str | None = None, + ): + persist: bool | None = namespace.persist + if option_string == "--alloc": + namespace.alloc = values + elif option_string == "--salloc": + # --salloc is in a mutually exclusive group with --persist + assert not persist + if persist: + raise argparse.ArgumentError( + argument=self, + message="Cannot use --salloc with --persist, use only --sbatch for a persistent session.", + ) + namespace.alloc = values + else: + assert option_string == "--sbatch", option_string + # --sbatch is in a mutually exclusive group with --persist + assert not persist + namespace.alloc = values + namespace.persist = True