Skip to content

Commit

Permalink
Add test for ComputeNode.connect, tweak CN.__eq__
Browse files Browse the repository at this point in the history
Signed-off-by: Fabrice Normandin <normandf@mila.quebec>
  • Loading branch information
lebrice committed Apr 29, 2024
1 parent 98e7901 commit 052bb82
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 4 deletions.
5 changes: 3 additions & 2 deletions milatools/utils/compute_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ class ComputeNode(Runner):
job_id: int
"""The job ID of the job running on the compute node."""

salloc_subprocess: asyncio.subprocess.Process | None = None
salloc_subprocess: asyncio.subprocess.Process | None = dataclasses.field(
default=None, repr=False, compare=False
)
"""A handle to the subprocess that is running the `salloc` command."""

hostname: str = dataclasses.field(init=False)
Expand All @@ -74,7 +76,6 @@ def __post_init__(self):
@staticmethod
async def connect(
login_node: RemoteV2,
*,
job_id_or_node_name: int | str,
) -> ComputeNode:
return await connect_to_running_job(
Expand Down
33 changes: 31 additions & 2 deletions tests/utils/test_compute_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import pytest
import pytest_asyncio

from milatools.cli.utils import removesuffix
from milatools.utils.compute_node import (
ComputeNode,
JobNotRunningError,
Expand Down Expand Up @@ -221,6 +222,33 @@ async def test_run_gets_executed_in_job_step(
assert job_step_a >= 0
assert job_step_b == job_step_a + 1

@pytest.mark.asyncio
async def test_connect(self, runner: ComputeNode):
    """Check that `ComputeNode.connect` works with a job ID and with a node name.

    In both cases the resulting `ComputeNode` must compare equal to `runner` and
    must not hold a handle to an `salloc` subprocess.
    """
    login_node = runner.login_node
    job_id = runner.job_id
    node_hostname = runner.hostname

    # Connect using the job ID.
    compute_node_with_jobid = await ComputeNode.connect(
        login_node, job_id_or_node_name=job_id
    )
    assert compute_node_with_jobid.salloc_subprocess is None
    assert compute_node_with_jobid == runner

    # `connect` must be given the node name, not the full node hostname.
    # For the `mila` cluster, we don't currently have a `cn-?????` entry in the ssh
    # config (although we could!), so we connect to the node with its full
    # hostname. However, squeue expects the node name, so we have to truncate the
    # hostname manually for now.
    if node_hostname.endswith(".server.mila.quebec"):
        node_name = removesuffix(node_hostname, ".server.mila.quebec")
    else:
        node_name = node_hostname

    # Connect using the node name.
    compute_node_with_node_name = await ComputeNode.connect(
        login_node, job_id_or_node_name=node_name
    )
    # Check the object that was just created (not the one obtained from the job ID).
    assert compute_node_with_node_name.salloc_subprocess is None
    assert compute_node_with_node_name == runner

@pytest.mark.parametrize("use_async", [False, True], ids=["sync", "async"])
@pytest.mark.asyncio
async def test_close(
Expand All @@ -233,7 +261,8 @@ async def test_close(
):
if login_node_v2.hostname == "localhost":
pytest.skip(reason="Test doesn't currently work on the mock slurm cluster.")
# needs to be the last test with this remote though!
# Here we create a new job allocation just to cancel it. We could reuse the
# `runner` fixture, but that would require us to run this test as the very last one.
if persist:
compute_node = await sbatch(
login_node_v2, sbatch_flags=allocation_flags, job_name=job_name
Expand All @@ -257,7 +286,7 @@ async def test_close(
# batch jobs are scancelled.
assert job_state.startswith("CANCELLED")
else:
# interactive jobs are exited cleanly by just exiting in the terminal.
# interactive jobs are exited cleanly.
assert job_state == "COMPLETED"


Expand Down

0 comments on commit 052bb82

Please sign in to comment.