diff --git a/milatools/cli/code.py b/milatools/cli/code.py index bf3f66dc..3e8beaeb 100644 --- a/milatools/cli/code.py +++ b/milatools/cli/code.py @@ -1,6 +1,7 @@ from __future__ import annotations import argparse +import asyncio import shlex import shutil import sys @@ -11,6 +12,7 @@ from milatools.cli import console from milatools.cli.common import ( check_disk_quota, + check_disk_quota_v1, find_allocation, ) from milatools.cli.init_command import DRAC_CLUSTERS @@ -103,6 +105,15 @@ def add_mila_code_arguments(subparsers: argparse._SubParsersAction): code_parser.set_defaults(function=code) +async def _check_disk_quota_task(remote: RemoteV2) -> None: + try: + await check_disk_quota(remote) + except MilatoolsUserError: + raise + except Exception as exc: + logger.warning(f"Unable to check the disk-quota on the cluster: {exc}") + + async def code( path: str, command: str, @@ -134,16 +145,11 @@ async def code( if not path.startswith("/"): # Get $HOME because we have to give the full path to code - home = login_node.get_output("echo $HOME", display=False, hide=True) + home = await login_node.get_output_async("echo $HOME", display=False, hide=True) path = home if path == "." else f"{home}/{path}" - try: - check_disk_quota(login_node) - except MilatoolsUserError: - # Raise errors that are meant to be shown to the user (disk quota is reached). - raise - except Exception as exc: - logger.warning(f"Unable to check the disk-quota on the cluster: {exc}") + check_disk_quota_task = asyncio.create_task(_check_disk_quota_task(login_node)) + # Raise errors that are meant to be shown to the user (disk quota is reached). # NOTE: Perhaps we could eventually do this check dynamically, if the cluster is an # unknown cluster? @@ -151,23 +157,23 @@ async def code( # Sync the VsCode extensions from the local machine over to the target cluster. console.log( f"Installing VSCode extensions that are on the local machine on " - f"{cluster} in the background.", + f"{cluster}.", style="cyan", ) # todo: use the mila or the local machine as the reference for vscode # extensions? # todo: could also perhaps make this function asynchronous instead of using a # multiprocessing process for it. - copy_vscode_extensions_process = make_process( - sync_vscode_extensions, - LocalV2(), - [login_node], + sync_vscode_extensions_task = asyncio.create_task( + asyncio.to_thread( + sync_vscode_extensions, + LocalV2(), + [login_node], + ), + name="sync_vscode_extensions", ) - copy_vscode_extensions_process.start() - # todo: could potentially do this at the same time as the blocks above and just wait - # for the result here, instead of running each block in sequence. - if currently_in_a_test(): - copy_vscode_extensions_process.join() + + compute_node_task: asyncio.Task[ComputeNode] if job or node: if job and node: @@ -177,7 +183,9 @@ async def code( ) job_id_or_node = job or node assert job_id_or_node is not None - compute_node = await connect_to_running_job(job_id_or_node, login_node) + compute_node_task = asyncio.create_task( + connect_to_running_job(job_id_or_node, login_node) + ) else: if cluster in DRAC_CLUSTERS and not any("--account" in flag for flag in alloc): logger.warning( @@ -195,24 +203,57 @@ async def code( # todo: Get the job name from the flags instead? raise MilatoolsUserError( "The job name flag (--job-name or -J) should be left unset for now " - "because we use the job name to gage how many people use `mila code` " - "on the various clusters. We also make use of the job name when the " - "call to `salloc` is interrupted before we have a chance to know the " - "job id." + "because we use the job name to measure how many people use `mila " + "code` on the various clusters. We also make use of the job name when " + "the call to `salloc` is interrupted before we have a chance to know " + "the job id." ) job_name = "mila-code" alloc = alloc + [f"--job-name={job_name}"] if persist: - compute_node = await sbatch( - login_node, sbatch_flags=alloc, job_name=job_name + compute_node_task = asyncio.create_task( + sbatch(login_node, sbatch_flags=alloc, job_name=job_name) ) + # compute_node = await sbatch( + # login_node, sbatch_flags=alloc, job_name=job_name + # ) else: # NOTE: Here we actually need the job name to be known, so that we can # scancel jobs if the call is interrupted. - compute_node = await salloc( - login_node, salloc_flags=alloc, job_name=job_name + compute_node_task = asyncio.create_task( + salloc(login_node, salloc_flags=alloc, job_name=job_name) ) + # compute_node = await salloc( + # login_node, salloc_flags=alloc, job_name=job_name + # ) + try: + _, _, compute_node = await asyncio.gather( + check_disk_quota_task, + sync_vscode_extensions_task, + compute_node_task, + return_exceptions=True, + ) + except: + # If any of the tasks failed, we want to raise the exception. + for task in ( + check_disk_quota_task, + sync_vscode_extensions_task, + compute_node_task, + ): + if not task.done(): + task.cancel() + for task in ( + check_disk_quota_task, + sync_vscode_extensions_task, + compute_node_task, + ): + if exception := task.exception(): + raise exception + raise + + if isinstance(compute_node, BaseException): + raise compute_node try: while True: @@ -332,7 +373,7 @@ def code_v1( command = get_code_command() try: - check_disk_quota(remote) + check_disk_quota_v1(remote) except MilatoolsUserError: raise except Exception as exc: diff --git a/milatools/cli/commands.py b/milatools/cli/commands.py index bf26159a..171fe40c 100644 --- a/milatools/cli/commands.py +++ b/milatools/cli/commands.py @@ -23,6 +23,7 @@ import rich.logging from typing_extensions import TypedDict +from milatools.cli import console from milatools.utils.vscode_utils import ( sync_vscode_extensions_with_hostnames, ) @@ -394,8 +395,10 @@ def setup_logging(verbose: int) -> None: ) logging.basicConfig( level=global_loglevel, - format="%(asctime)s - %(levelname)s - %(message)s", - handlers=[rich.logging.RichHandler(markup=True, rich_tracebacks=True)], + format="%(message)s", + handlers=[ + rich.logging.RichHandler(markup=True, rich_tracebacks=True, console=console) + ], ) get_logger("milatools").setLevel(package_loglevel) diff --git a/milatools/cli/common.py b/milatools/cli/common.py index 39818891..f239a269 100644 --- a/milatools/cli/common.py +++ b/milatools/cli/common.py @@ -80,7 +80,7 @@ def _parse_lfs_quota_output( return (used_gb, max_gb), (used_files, max_files) -def check_disk_quota(remote: RemoteV1 | RemoteV2) -> None: +async def check_disk_quota(remote: RemoteV2) -> None: cluster = remote.hostname # NOTE: This is what the output of the command looks like on the Mila cluster: @@ -92,17 +92,36 @@ def check_disk_quota(remote: RemoteV1 | RemoteV2) -> None: # uid 1471600598 is using default block quota setting # uid 1471600598 is using default file quota setting - # Need to assert this, otherwise .get_output calls .run which would spawn a job! - assert not isinstance(remote, SlurmRemote) - if not remote.get_output("which lfs", display=False, hide=True): + if not (await remote.get_output_async("which lfs", display=False, hide=True)): logger.debug("Cluster doesn't have the lfs command. Skipping check.") return console.log("Checking disk quota on $HOME...") + home_disk_quota_output = await remote.get_output_async( + "lfs quota -u $USER $HOME", display=False, hide=True + ) + _check_disk_quota_common_part(home_disk_quota_output, cluster) + + +def check_disk_quota_v1(remote: RemoteV1 | RemoteV2) -> None: + cluster = remote.hostname + # Need to check for this, because SlurmRemote is a subclass of RemoteV1 and + # .get_output calls SlurmRemote.run which would spawn a job! + assert not isinstance(remote, SlurmRemote) + + if not (remote.get_output("which lfs", display=False, hide=True)): + logger.debug("Cluster doesn't have the lfs command. Skipping check.") + return + + console.log("Checking disk quota on $HOME...") home_disk_quota_output = remote.get_output( "lfs quota -u $USER $HOME", display=False, hide=True ) + _check_disk_quota_common_part(home_disk_quota_output, cluster) + + +def _check_disk_quota_common_part(home_disk_quota_output: str, cluster: str): if "not on a mounted Lustre filesystem" in home_disk_quota_output: logger.debug("Cluster doesn't use lustre on $HOME filesystem. Skipping check.") return diff --git a/milatools/utils/compute_node.py b/milatools/utils/compute_node.py index 6ce6b9f8..fe3eaf90 100644 --- a/milatools/utils/compute_node.py +++ b/milatools/utils/compute_node.py @@ -236,11 +236,8 @@ async def salloc( login_node: RemoteV2, salloc_flags: list[str], job_name: str ) -> ComputeNode: """Runs `salloc` and returns a remote connected to the compute node.""" - - salloc_command = "salloc " + shlex.join(salloc_flags) - if login_node.hostname in DRAC_CLUSTERS: - salloc_command = f"cd $SCRATCH && {salloc_command}" - + # NOTE: Some SLURM clusters prevent submitting jobs from $HOME. + salloc_command = "cd $SCRATCH && salloc " + shlex.join(salloc_flags) command = ssh_command( hostname=login_node.hostname, control_path=login_node.control_path, @@ -320,12 +317,20 @@ async def sbatch( This then waits asynchronously until the job show us as RUNNING in the output of the `sacct` command. """ - # idea: Find the job length from the sbatch flags if possible so we can do - # --wrap='sleep {job_duration}' instead of 'sleep 7d'. + # NOTE: cd to $SCRATCH because some SLURM clusters prevent submitting jobs from the + # HOME directory. Also, if a cluster doesn't have $SCRACTCH set, we just stay in the + # home directory, so no harm done. # todo: Should we use --ntasks=1 --overlap in the wrapped `srun`, so that only one # task sleeps? Does that change anything? + # todo: Find the job length from the sbatch flags if possible so we can do + # --wrap='sleep {job_duration}' instead of 'sleep 7d' (perhaps it would change the + # exit state of the job if the actual srun ends peacefull instead of being + # cancelled? (from CANCELLED to COMPLETED)? Although I'm not sure what the + # final job state currently is.) sbatch_command = ( - "sbatch --parsable " + shlex.join(sbatch_flags) + " --wrap 'srun sleep 7d'" + "cd $SCRATCH && sbatch --parsable " + + shlex.join(sbatch_flags) + + " --wrap 'srun sleep 7d'" ) cluster = login_node.hostname if cluster in DRAC_CLUSTERS: diff --git a/milatools/utils/vscode_utils.py b/milatools/utils/vscode_utils.py index 6d1955fa..c6f066f9 100644 --- a/milatools/utils/vscode_utils.py +++ b/milatools/utils/vscode_utils.py @@ -61,7 +61,7 @@ def get_code_command() -> str: return os.environ.get("MILATOOLS_CODE_COMMAND", "code") -def get_vscode_executable_path(code_command: str | None = None) -> str: +def get_local_vscode_executable_path(code_command: str | None = None) -> str: if code_command is None: code_command = get_code_command() @@ -73,7 +73,7 @@ def get_vscode_executable_path(code_command: str | None = None) -> str: def vscode_installed() -> bool: try: - _ = get_vscode_executable_path() + _ = get_local_vscode_executable_path() except CommandNotFoundError: return False return True @@ -227,7 +227,7 @@ def _update_progress( if isinstance(remote, LocalV2): assert dest_hostname == "localhost" - code_server_executable = get_vscode_executable_path() + code_server_executable = get_local_vscode_executable_path() extensions_on_dest = get_local_vscode_extensions() else: dest_hostname = remote.hostname @@ -332,7 +332,7 @@ def install_vscode_extension( def get_local_vscode_extensions(code_command: str | None = None) -> dict[str, str]: output = subprocess.run( ( - get_vscode_executable_path(code_command=code_command), + get_local_vscode_executable_path(code_command=code_command), "--list-extensions", "--show-versions", ), @@ -398,7 +398,7 @@ def extensions_to_install( def find_code_server_executable( - remote: RemoteV1 | RemoteV2, remote_vscode_server_dir: str = "~/.vscode-server" + remote: RemoteV2, remote_vscode_server_dir: str = "~/.vscode-server" ) -> str | None: """Find the most recent `code-server` executable on the remote. diff --git a/tests/integration/test_code.py b/tests/integration/test_code.py index 3d2e8ff1..391cf436 100644 --- a/tests/integration/test_code.py +++ b/tests/integration/test_code.py @@ -17,14 +17,11 @@ from milatools.cli.common import check_disk_quota from milatools.cli.utils import get_hostname_to_use_for_compute_node from milatools.utils.compute_node import ComputeNode -from milatools.utils.remote_v1 import RemoteV1 from milatools.utils.remote_v2 import RemoteV2 -from ..cli.common import skip_param_if_on_github_ci from ..conftest import job_name, launches_jobs from .conftest import ( - skip_if_not_already_logged_in, - skip_param_if_not_already_logged_in, + SLURM_CLUSTER, ) from .test_slurm_remote import get_recent_jobs_info_dicts @@ -32,37 +29,25 @@ @pytest.mark.slow -@pytest.mark.parametrize( - "cluster", - [ - skip_param_if_on_github_ci("mila"), - skip_param_if_not_already_logged_in("narval"), - skip_param_if_not_already_logged_in("beluga"), - skip_param_if_not_already_logged_in("cedar"), - pytest.param( - "graham", - marks=[ - skip_if_not_already_logged_in("graham"), - pytest.mark.xfail( - raises=subprocess.CalledProcessError, - reason="Graham doesn't use a lustre filesystem for $HOME.", - strict=True, - ), - ], - ), - skip_param_if_not_already_logged_in("niagara"), - ], - indirect=True, +@pytest.mark.xfail( + SLURM_CLUSTER == "graham", + raises=subprocess.CalledProcessError, + reason="Graham doesn't use a lustre filesystem for $HOME.", + strict=True, ) def test_check_disk_quota( - login_node: RemoteV1 | RemoteV2, + login_node_v2: RemoteV2, capsys: pytest.LogCaptureFixture, caplog: pytest.LogCaptureFixture, -): # noqa: F811 +): + if login_node_v2.hostname == "localhost": + pytest.skip(reason="Test doesn't work on localhost.") + with caplog.at_level(logging.DEBUG): - check_disk_quota(remote=login_node) - # TODO: Maybe figure out a way to actually test this, (not just by running it and - # expecting no errors). + check_disk_quota(remote=login_node_v2) + # TODO: Maybe figure out a way to actually test this, (apart from just running it + # and expecting no errors). + # Check that it doesn't raise any errors. # IF the quota is nearly met, then a warning is logged. # IF the quota is met, then a `MilatoolsUserError` is logged. @@ -96,9 +81,9 @@ async def get_job_info( @pytest.mark.parametrize("persist", [True, False], ids=["sbatch", "salloc"]) @pytest.mark.parametrize( job_name.__name__, - [ - None, - ], + # Don't set the `--job-name` in the `allocation_flags` fixture + # (this is necessary for `mila code` to work properly). + [None], ids=[""], indirect=True, ) diff --git a/tests/integration/test_code/test_code_mila0_None_True_.txt b/tests/integration/test_code/test_code_mila0_None_True_.txt deleted file mode 100644 index 240c33f7..00000000 --- a/tests/integration/test_code/test_code_mila0_None_True_.txt +++ /dev/null @@ -1,14 +0,0 @@ -Checking disk quota on $HOME... -Disk usage: X / LIMIT GiB and X / LIMIT files -(SLURM_ACCOUNT) $ sbatch --parsable --wckey=SLURM_ACCOUNTtools_test --account=SLURM_ACCOUNT --nodes=1 --ntasks=1 --cpus-per-task=1 --mem=1G --time=0:05:00 --oversubscribe --job-name=SLURM_ACCOUNT-code --wrap 'srun sleep 7d' -JOB_ID - -(local) echo --new-window --wait --remote ssh-remote+cn-f002.server.SLURM_ACCOUNT.quebec /home/SLURM_ACCOUNT/n/normandf/bob ---new-window --wait --remote ssh-remote+cn-f002.server.SLURM_ACCOUNT.quebec /home/SLURM_ACCOUNT/n/normandf/bob - -The editor was closed. Reopen it with or terminate the process with (maybe twice). -This allocation is persistent and is still active. -To reconnect to this job, run the following: - SLURM_ACCOUNT code /home/SLURM_ACCOUNT/n/normandf/bob --job JOB_ID -To kill this allocation: - ssh SLURM_ACCOUNT scancel JOB_ID \ No newline at end of file diff --git a/tests/integration/test_code/test_code_mila__salloc_.txt b/tests/integration/test_code/test_code_mila__salloc_.txt index b08db608..c18befc7 100644 --- a/tests/integration/test_code/test_code_mila__salloc_.txt +++ b/tests/integration/test_code/test_code_mila__salloc_.txt @@ -1,6 +1,6 @@ Checking disk quota on $HOME... Disk usage: X / LIMIT GiB and X / LIMIT files -(mila) $ salloc --wckey=milatools_test --account=SLURM_ACCOUNT --nodes=1 --ntasks=1 --cpus-per-task=1 --mem=1G --time=0:05:00 --oversubscribe --job-name=mila-code +(mila) $ cd $SCRATCH && salloc --wckey=milatools_test --account=SLURM_ACCOUNT --nodes=1 --ntasks=1 --cpus-per-task=1 --mem=1G --time=0:05:00 --oversubscribe --job-name=mila-code Waiting for job JOB_ID to start. (local) echo --new-window --wait --remote ssh-remote+COMPUTE_NODE $HOME/bob --new-window --wait --remote ssh-remote+COMPUTE_NODE $HOME/bob diff --git a/tests/integration/test_code/test_code_mila__sbatch_.txt b/tests/integration/test_code/test_code_mila__sbatch_.txt index 702af262..ec932601 100644 --- a/tests/integration/test_code/test_code_mila__sbatch_.txt +++ b/tests/integration/test_code/test_code_mila__sbatch_.txt @@ -1,6 +1,6 @@ Checking disk quota on $HOME... Disk usage: X / LIMIT GiB and X / LIMIT files -(mila) $ sbatch --parsable --wckey=milatools_test --account=SLURM_ACCOUNT --nodes=1 --ntasks=1 --cpus-per-task=1 --mem=1G --time=0:05:00 --oversubscribe --job-name=mila-code --wrap 'srun sleep 7d' +(mila) $ cd $SCRATCH && sbatch --parsable --wckey=milatools_test --account=SLURM_ACCOUNT --nodes=1 --ntasks=1 --cpus-per-task=1 --mem=1G --time=0:05:00 --oversubscribe --job-name=mila-code --wrap 'srun sleep 7d' JOB_ID (local) echo --new-window --wait --remote ssh-remote+COMPUTE_NODE $HOME/bob diff --git a/tests/utils/test_vscode_utils.py b/tests/utils/test_vscode_utils.py index 217766b7..293c7288 100644 --- a/tests/utils/test_vscode_utils.py +++ b/tests/utils/test_vscode_utils.py @@ -19,9 +19,9 @@ find_code_server_executable, get_code_command, get_expected_vscode_settings_json_path, + get_local_vscode_executable_path, get_local_vscode_extensions, get_remote_vscode_extensions, - get_vscode_executable_path, install_vscode_extension, install_vscode_extensions_task_function, sync_vscode_extensions, @@ -90,13 +90,13 @@ def test_running_inside_WSL(): def test_get_vscode_executable_path(): if vscode_installed(): - code = get_vscode_executable_path() + code = get_local_vscode_executable_path() assert Path(code).exists() else: with pytest.raises( MilatoolsUserError, match="Command 'code' does not exist locally." ): - get_vscode_executable_path() + get_local_vscode_executable_path() @pytest.fixture @@ -105,7 +105,8 @@ def mock_find_code_server_executable(monkeypatch: pytest.MonkeyPatch): import milatools.utils.vscode_utils mock_find_code_server_executable = Mock( - spec=find_code_server_executable, return_value=get_vscode_executable_path() + spec=find_code_server_executable, + return_value=get_local_vscode_executable_path(), ) monkeypatch.setattr( milatools.utils.vscode_utils, @@ -128,7 +129,8 @@ def test_sync_vscode_extensions_in_parallel_with_hostnames( milatools.utils.vscode_utils, find_code_server_executable.__name__, Mock( - spec=find_code_server_executable, return_value=get_vscode_executable_path() + spec=find_code_server_executable, + return_value=get_local_vscode_executable_path(), ), ) sync_vscode_extensions_with_hostnames( @@ -170,7 +172,7 @@ def vscode_extensions( # the "remote" (just to localhost during tests) is missing some extensions mock_remote_extensions = Mock( spec=get_remote_vscode_extensions, - return_value=(installed_extensions, str(get_vscode_executable_path())), + return_value=(installed_extensions, str(get_local_vscode_executable_path())), ) monkeypatch.setattr( milatools.utils.vscode_utils, @@ -251,7 +253,7 @@ def test_install_vscode_extension(missing_extensions: dict[str, str]): extension_name, version = next(iter(missing_extensions.items())) result = install_vscode_extension( remote=RemoteV2("localhost"), - code_server_executable=str(get_vscode_executable_path()), + code_server_executable=str(get_local_vscode_executable_path()), extension=f"{extension_name}@{version}", verbose=False, ) @@ -280,7 +282,7 @@ def test_get_remote_vscode_extensions(): # therefore the "remote" extensions are the same as the local extensions. fake_remote = RemoteV2("localhost") - local_vscode_executable = get_vscode_executable_path() + local_vscode_executable = get_local_vscode_executable_path() assert local_vscode_executable is not None fake_remote_extensions = get_remote_vscode_extensions(