Use tasks for each subpart of mila code
Signed-off-by: Fabrice Normandin <normandf@mila.quebec>
lebrice committed Apr 19, 2024
1 parent 7199c5a commit a2efc28
Showing 10 changed files with 145 additions and 104 deletions.
97 changes: 69 additions & 28 deletions milatools/cli/code.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import argparse
import asyncio
import shlex
import shutil
import sys
@@ -11,6 +12,7 @@
from milatools.cli import console
from milatools.cli.common import (
check_disk_quota,
check_disk_quota_v1,
find_allocation,
)
from milatools.cli.init_command import DRAC_CLUSTERS
@@ -103,6 +105,15 @@ def add_mila_code_arguments(subparsers: argparse._SubParsersAction):
code_parser.set_defaults(function=code)


async def _check_disk_quota_task(remote: RemoteV2) -> None:
try:
await check_disk_quota(remote)
except MilatoolsUserError:
raise
except Exception as exc:
logger.warning(f"Unable to check the disk-quota on the cluster: {exc}")


async def code(
path: str,
command: str,
@@ -134,40 +145,35 @@ async def code(

if not path.startswith("/"):
# Get $HOME because we have to give the full path to code
home = login_node.get_output("echo $HOME", display=False, hide=True)
home = await login_node.get_output_async("echo $HOME", display=False, hide=True)
path = home if path == "." else f"{home}/{path}"

try:
check_disk_quota(login_node)
except MilatoolsUserError:
# Raise errors that are meant to be shown to the user (disk quota is reached).
raise
except Exception as exc:
logger.warning(f"Unable to check the disk-quota on the cluster: {exc}")
check_disk_quota_task = asyncio.create_task(_check_disk_quota_task(login_node))
# Raise errors that are meant to be shown to the user (disk quota is reached).

# NOTE: Perhaps we could eventually do this check dynamically, if the cluster is an
# unknown cluster?
if no_internet_on_compute_nodes(cluster):
# Sync the VsCode extensions from the local machine over to the target cluster.
console.log(
f"Installing VSCode extensions that are on the local machine on "
f"{cluster} in the background.",
f"{cluster}.",
style="cyan",
)
# todo: use the mila or the local machine as the reference for vscode
# extensions?
# todo: could also perhaps make this function asynchronous instead of using a
# multiprocessing process for it.
copy_vscode_extensions_process = make_process(
sync_vscode_extensions,
LocalV2(),
[login_node],
sync_vscode_extensions_task = asyncio.create_task(
asyncio.to_thread(
sync_vscode_extensions,
LocalV2(),
[login_node],
),
name="sync_vscode_extensions",
)
copy_vscode_extensions_process.start()
# todo: could potentially do this at the same time as the blocks above and just wait
# for the result here, instead of running each block in sequence.
if currently_in_a_test():
copy_vscode_extensions_process.join()

compute_node_task: asyncio.Task[ComputeNode]

if job or node:
if job and node:
@@ -177,7 +183,9 @@
)
job_id_or_node = job or node
assert job_id_or_node is not None
compute_node = await connect_to_running_job(job_id_or_node, login_node)
compute_node_task = asyncio.create_task(
connect_to_running_job(job_id_or_node, login_node)
)
else:
if cluster in DRAC_CLUSTERS and not any("--account" in flag for flag in alloc):
logger.warning(
@@ -195,24 +203,57 @@
# todo: Get the job name from the flags instead?
raise MilatoolsUserError(
"The job name flag (--job-name or -J) should be left unset for now "
"because we use the job name to gage how many people use `mila code` "
"on the various clusters. We also make use of the job name when the "
"call to `salloc` is interrupted before we have a chance to know the "
"job id."
"because we use the job name to measure how many people use `mila "
"code` on the various clusters. We also make use of the job name when "
"the call to `salloc` is interrupted before we have a chance to know "
"the job id."
)
job_name = "mila-code"
alloc = alloc + [f"--job-name={job_name}"]

if persist:
compute_node = await sbatch(
login_node, sbatch_flags=alloc, job_name=job_name
compute_node_task = asyncio.create_task(
sbatch(login_node, sbatch_flags=alloc, job_name=job_name)
)
# compute_node = await sbatch(
# login_node, sbatch_flags=alloc, job_name=job_name
# )
else:
# NOTE: Here we actually need the job name to be known, so that we can
# scancel jobs if the call is interrupted.
compute_node = await salloc(
login_node, salloc_flags=alloc, job_name=job_name
compute_node_task = asyncio.create_task(
salloc(login_node, salloc_flags=alloc, job_name=job_name)
)
# compute_node = await salloc(
# login_node, salloc_flags=alloc, job_name=job_name
# )
try:
_, _, compute_node = await asyncio.gather(
check_disk_quota_task,
sync_vscode_extensions_task,
compute_node_task,
return_exceptions=True,
)
except:
# If any of the tasks failed, we want to raise the exception.
for task in (
check_disk_quota_task,
sync_vscode_extensions_task,
compute_node_task,
):
if not task.done():
task.cancel()
for task in (
check_disk_quota_task,
sync_vscode_extensions_task,
compute_node_task,
):
if exception := task.exception():
raise exception
raise

if isinstance(compute_node, BaseException):
raise compute_node

try:
while True:
@@ -332,7 +373,7 @@ def code_v1(
command = get_code_command()

try:
check_disk_quota(remote)
check_disk_quota_v1(remote)
except MilatoolsUserError:
raise
except Exception as exc:
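
The pattern this diff introduces can be read in isolation: each subpart of `mila code` (the disk-quota check, the VS Code extension sync, the compute-node allocation) becomes its own asyncio task, the blocking sync runs in a worker thread via asyncio.to_thread, and the results are collected with asyncio.gather(..., return_exceptions=True). Below is a minimal, self-contained sketch of that same idea; the _quota_check, _sync_extensions, and _allocate helpers are illustrative stand-ins, not milatools functions.

import asyncio


async def _quota_check() -> None:
    """Stand-in for the background disk-quota check."""
    await asyncio.sleep(0.1)


def _sync_extensions() -> None:
    """Stand-in for the blocking VS Code extension sync (run in a thread)."""


async def _allocate() -> str:
    """Stand-in for salloc/sbatch; returns a fake compute node hostname."""
    await asyncio.sleep(0.2)
    return "cn-a001"


async def main() -> str:
    # Start each subpart as its own task so they overlap instead of running
    # one after the other.
    quota_task = asyncio.create_task(_quota_check())
    sync_task = asyncio.create_task(asyncio.to_thread(_sync_extensions))
    node_task = asyncio.create_task(_allocate())

    # With return_exceptions=True, failures come back as values instead of
    # cancelling the whole gather; only the compute-node result is re-raised.
    _, _, node = await asyncio.gather(
        quota_task, sync_task, node_task, return_exceptions=True
    )
    if isinstance(node, BaseException):
        raise node
    return node


if __name__ == "__main__":
    print(asyncio.run(main()))
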
7 changes: 5 additions & 2 deletions milatools/cli/commands.py
@@ -23,6 +23,7 @@
import rich.logging
from typing_extensions import TypedDict

from milatools.cli import console
from milatools.utils.vscode_utils import (
sync_vscode_extensions_with_hostnames,
)
@@ -394,8 +395,10 @@ def setup_logging(verbose: int) -> None:
)
logging.basicConfig(
level=global_loglevel,
format="%(asctime)s - %(levelname)s - %(message)s",
handlers=[rich.logging.RichHandler(markup=True, rich_tracebacks=True)],
format="%(message)s",
handlers=[
rich.logging.RichHandler(markup=True, rich_tracebacks=True, console=console)
],
)
get_logger("milatools").setLevel(package_loglevel)

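A rough, standalone illustration of the logging change above (not the milatools code itself): when a RichHandler writes to an explicit Console, the format string only needs the bare message, since Rich already renders the timestamp and level on its own.

import logging

import rich.logging
from rich.console import Console

console = Console()  # stand-in for milatools.cli.console

logging.basicConfig(
    level=logging.INFO,
    # RichHandler prints the time and level itself, so only the message is needed.
    format="%(message)s",
    handlers=[
        rich.logging.RichHandler(markup=True, rich_tracebacks=True, console=console)
    ],
)
logging.getLogger("demo").info("Hello from [bold cyan]rich[/] logging!")
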
27 changes: 23 additions & 4 deletions milatools/cli/common.py
@@ -80,7 +80,7 @@ def _parse_lfs_quota_output(
return (used_gb, max_gb), (used_files, max_files)


def check_disk_quota(remote: RemoteV1 | RemoteV2) -> None:
async def check_disk_quota(remote: RemoteV2) -> None:
cluster = remote.hostname

# NOTE: This is what the output of the command looks like on the Mila cluster:
@@ -92,17 +92,36 @@ def check_disk_quota(remote: RemoteV1 | RemoteV2) -> None:
# uid 1471600598 is using default block quota setting
# uid 1471600598 is using default file quota setting

# Need to assert this, otherwise .get_output calls .run which would spawn a job!
assert not isinstance(remote, SlurmRemote)
if not remote.get_output("which lfs", display=False, hide=True):
if not (await remote.get_output_async("which lfs", display=False, hide=True)):
logger.debug("Cluster doesn't have the lfs command. Skipping check.")
return

console.log("Checking disk quota on $HOME...")

home_disk_quota_output = await remote.get_output_async(
"lfs quota -u $USER $HOME", display=False, hide=True
)
_check_disk_quota_common_part(home_disk_quota_output, cluster)


def check_disk_quota_v1(remote: RemoteV1 | RemoteV2) -> None:
cluster = remote.hostname
# Need to check for this, because SlurmRemote is a subclass of RemoteV1 and
# .get_output calls SlurmRemote.run which would spawn a job!
assert not isinstance(remote, SlurmRemote)

if not (remote.get_output("which lfs", display=False, hide=True)):
logger.debug("Cluster doesn't have the lfs command. Skipping check.")
return

console.log("Checking disk quota on $HOME...")
home_disk_quota_output = remote.get_output(
"lfs quota -u $USER $HOME", display=False, hide=True
)
_check_disk_quota_common_part(home_disk_quota_output, cluster)


def _check_disk_quota_common_part(home_disk_quota_output: str, cluster: str):
if "not on a mounted Lustre filesystem" in home_disk_quota_output:
logger.debug("Cluster doesn't use lustre on $HOME filesystem. Skipping check.")
return
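
The split between check_disk_quota (async, RemoteV2 only) and check_disk_quota_v1 (blocking) comes down to which command runner the remote exposes. A hedged, standalone sketch of that sync/async split — FakeRemote is an assumption for illustration, not the real RemoteV2 interface:

import asyncio
import subprocess


class FakeRemote:
    """Illustrative stand-in for a remote with blocking and async command runners."""

    def get_output(self, cmd: str) -> str:
        result = subprocess.run(
            ["bash", "-c", cmd], capture_output=True, text=True, check=True
        )
        return result.stdout.strip()

    async def get_output_async(self, cmd: str) -> str:
        # Run the blocking call in a worker thread so the event loop stays responsive.
        return await asyncio.to_thread(self.get_output, cmd)


def check_disk_quota_v1(remote: FakeRemote) -> None:
    # Synchronous path: callable from plain (non-async) code.
    if not remote.get_output("which lfs || true"):
        print("Cluster doesn't have the lfs command. Skipping check.")


async def check_disk_quota(remote: FakeRemote) -> None:
    # Async path: can run concurrently with other tasks via asyncio.create_task.
    if not await remote.get_output_async("which lfs || true"):
        print("Cluster doesn't have the lfs command. Skipping check.")


remote = FakeRemote()
check_disk_quota_v1(remote)
asyncio.run(check_disk_quota(remote))
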
21 changes: 13 additions & 8 deletions milatools/utils/compute_node.py
@@ -236,11 +236,8 @@ async def salloc(
login_node: RemoteV2, salloc_flags: list[str], job_name: str
) -> ComputeNode:
"""Runs `salloc` and returns a remote connected to the compute node."""

salloc_command = "salloc " + shlex.join(salloc_flags)
if login_node.hostname in DRAC_CLUSTERS:
salloc_command = f"cd $SCRATCH && {salloc_command}"

# NOTE: Some SLURM clusters prevent submitting jobs from $HOME.
salloc_command = "cd $SCRATCH && salloc " + shlex.join(salloc_flags)
command = ssh_command(
hostname=login_node.hostname,
control_path=login_node.control_path,
@@ -320,12 +317,20 @@
This then waits asynchronously until the job show us as RUNNING in the output of the
`sacct` command.
"""
# idea: Find the job length from the sbatch flags if possible so we can do
# --wrap='sleep {job_duration}' instead of 'sleep 7d'.
# NOTE: cd to $SCRATCH because some SLURM clusters prevent submitting jobs from the
# HOME directory. Also, if a cluster doesn't have $SCRATCH set, we just stay in the
# home directory, so no harm done.
# todo: Should we use --ntasks=1 --overlap in the wrapped `srun`, so that only one
# task sleeps? Does that change anything?
# todo: Find the job length from the sbatch flags if possible so we can do
# --wrap='sleep {job_duration}' instead of 'sleep 7d'. (Perhaps it would change the
# exit state of the job if the actual srun ends peacefully instead of being
# cancelled, e.g. from CANCELLED to COMPLETED? Although I'm not sure what the
# final job state currently is.)
sbatch_command = (
"sbatch --parsable " + shlex.join(sbatch_flags) + " --wrap 'srun sleep 7d'"
"cd $SCRATCH && sbatch --parsable "
+ shlex.join(sbatch_flags)
+ " --wrap 'srun sleep 7d'"
)
cluster = login_node.hostname
if cluster in DRAC_CLUSTERS:
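
The submit command assembled in this hunk can be reproduced on its own; a small sketch (the flag values passed in at the bottom are made up for illustration):

import shlex


def build_sbatch_command(sbatch_flags: list[str]) -> str:
    # cd to $SCRATCH first because some SLURM clusters refuse submissions from
    # $HOME; --parsable makes sbatch print only the job id, and the wrapped
    # `srun sleep 7d` keeps the allocation alive until the job is cancelled.
    return (
        "cd $SCRATCH && sbatch --parsable "
        + shlex.join(sbatch_flags)
        + " --wrap 'srun sleep 7d'"
    )


print(build_sbatch_command(["--time=01:00:00", "--job-name=mila-code"]))
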
10 changes: 5 additions & 5 deletions milatools/utils/vscode_utils.py
@@ -61,7 +61,7 @@ def get_code_command() -> str:
return os.environ.get("MILATOOLS_CODE_COMMAND", "code")


def get_vscode_executable_path(code_command: str | None = None) -> str:
def get_local_vscode_executable_path(code_command: str | None = None) -> str:
if code_command is None:
code_command = get_code_command()

@@ -73,7 +73,7 @@ def vscode_installed() -> bool:

def vscode_installed() -> bool:
try:
_ = get_vscode_executable_path()
_ = get_local_vscode_executable_path()
except CommandNotFoundError:
return False
return True
@@ -227,7 +227,7 @@ def _update_progress(

if isinstance(remote, LocalV2):
assert dest_hostname == "localhost"
code_server_executable = get_vscode_executable_path()
code_server_executable = get_local_vscode_executable_path()
extensions_on_dest = get_local_vscode_extensions()
else:
dest_hostname = remote.hostname
@@ -332,7 +332,7 @@ def install_vscode_extension(
def get_local_vscode_extensions(code_command: str | None = None) -> dict[str, str]:
output = subprocess.run(
(
get_vscode_executable_path(code_command=code_command),
get_local_vscode_executable_path(code_command=code_command),
"--list-extensions",
"--show-versions",
),
@@ -398,7 +398,7 @@ def extensions_to_install(


def find_code_server_executable(
remote: RemoteV1 | RemoteV2, remote_vscode_server_dir: str = "~/.vscode-server"
remote: RemoteV2, remote_vscode_server_dir: str = "~/.vscode-server"
) -> str | None:
"""Find the most recent `code-server` executable on the remote.
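
Finally, the get_vscode_executable_path → get_local_vscode_executable_path rename only touches the local lookup. The sketch below assumes the lookup is a plain PATH search and that CommandNotFoundError behaves like an ordinary exception; it is not the actual milatools implementation:

import shutil


class CommandNotFoundError(RuntimeError):
    """Raised when the requested editor command is not on PATH (assumed behaviour)."""


def get_local_vscode_executable_path(code_command: str = "code") -> str:
    # shutil.which resolves the command the same way the shell would.
    path = shutil.which(code_command)
    if path is None:
        raise CommandNotFoundError(code_command)
    return path


def vscode_installed(code_command: str = "code") -> bool:
    try:
        get_local_vscode_executable_path(code_command)
    except CommandNotFoundError:
        return False
    return True


print(vscode_installed())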