Skip to content

Commit

Permalink
[MT-67] No ssh connection to get the $HOME path
Browse files Browse the repository at this point in the history
Signed-off-by: Fabrice Normandin <normandf@mila.quebec>
  • Loading branch information
lebrice committed Sep 25, 2023
1 parent 60f419e commit fbd6d23
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 21 deletions.
60 changes: 42 additions & 18 deletions milatools/cli/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,11 @@
from .init_command import setup_ssh_config
from .local import Local
from .profile import ensure_program, setup_profile
from .remote import Remote, SlurmRemote
from .remote import PersistAllocationInfo, Remote, SlurmRemote
from .utils import (
CommandNotFoundError,
MilatoolsUserError,
SSHConfig,
SSHConnectionError,
T,
get_fully_qualified_name,
Expand Down Expand Up @@ -536,27 +537,42 @@ def code(
node: Node to connect to
alloc: Extra options to pass to slurm
"""
if (node is not None) + (job is not None) + bool(alloc) > 1:
exit("ERROR: --node, --job and --alloc are mutually exclusive")

here = Local()
remote = Remote("mila")

if command is None:
command = os.environ.get("MILATOOLS_CODE_COMMAND", "code")

check_disk_quota(remote)
if not path.startswith("/"):
# Get $HOME because we have to give the full path to code
user = get_user_from_ssh_config()
path = f"/home/mila/{user[0]}/{user}/{path}"

cnode = _find_allocation(
remote, job_name="mila-code", job=job, node=node, alloc=alloc
)
if persist:
cnode = cnode.persist()
data, proc = cnode.ensure_allocation()
if node is None:
remote = Remote("mila")

node_name = data["node_name"]
check_disk_quota(remote)

if not path.startswith("/"):
# Get $HOME because we have to give the full path to code
home = remote.home()
path = "/".join([home, path])
cnode = _find_allocation(
remote, job_name="mila-code", job=job, node=None, alloc=alloc
)
if persist:
cnode = cnode.persist()
data, proc = cnode.ensure_allocation()
assert "jobid" in data
job = data["jobid"]
assert isinstance(job, str)
else:
data, proc = cnode.ensure_allocation()
job = None

node_name = data["node_name"]
assert isinstance(node_name, str)
node = node_name
else:
proc = None

command_path = shutil.which(command)
if not command_path:
Expand All @@ -567,7 +583,7 @@ def code(
command_path,
"-nw",
"--remote",
f"ssh-remote+{qualified(node_name)}",
f"ssh-remote+{qualified(node)}",
path,
)
print(
Expand All @@ -580,14 +596,14 @@ def code(
if not persist:
if proc is not None:
proc.kill()
print(f"Ended session on '{node_name}'")
print(f"Ended session on '{node}'")

if persist:
print("This allocation is persistent and is still active.")
print("To reconnect to this node:")
print(T.bold(f" mila code {path} --node {node_name}"))
print(T.bold(f" mila code {path} --node {node}"))
print("To kill this allocation:")
print(T.bold(f" ssh mila scancel {data['jobid']}"))
print(T.bold(f" ssh mila scancel {job}"))


def connect(identifier: str, port: int | None):
Expand Down Expand Up @@ -1209,5 +1225,13 @@ def _forward(
return proc, port


def get_user_from_ssh_config() -> str:
    """Return the username configured for the ``mila`` host in the SSH config.

    Reads the user's ``~/.ssh/config`` through ``SSHConfig``, looks up the
    ``mila`` host entry and returns the value of its ``user`` field. No SSH
    connection is opened.

    NOTE(review): what happens when the file, the ``mila`` entry or its
    ``user`` field is missing depends on ``SSHConfig`` — confirm whether a
    friendlier error should be raised here.
    """
    # `Path` does not expand "~" by itself, and neither does `open()`; expand
    # it explicitly so the real file in the user's home directory is read.
    # (`expanduser()` is a no-op on an already-expanded path.)
    ssh_config_path = Path("~/.ssh/config").expanduser()
    ssh_config = SSHConfig(ssh_config_path)
    mila_entry = ssh_config.host("mila")
    user: str = mila_entry["user"]
    return user


if __name__ == "__main__":
main()
19 changes: 17 additions & 2 deletions milatools/cli/remote.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from __future__ import annotations

import re
import socket
import tempfile
import time
from pathlib import Path
from queue import Empty, Queue
from typing import TypedDict

import fabric.runners
import paramiko
import questionary as qn
from fabric import Connection
Expand Down Expand Up @@ -76,6 +80,15 @@ def get_first_node_name(node_names_out: str) -> str:
return base + inside_brackets.split("-")[0]


class AllocationInfo(TypedDict):
    """Dictionary describing an allocation, as returned by `ensure_allocation`."""

    # Name of the node associated with the allocation.
    node_name: str


class PersistAllocationInfo(AllocationInfo):
    """Allocation info that additionally carries a ``jobid`` entry.

    The ``node_name`` key is inherited from ``AllocationInfo``; it is not
    redeclared here because type checkers such as mypy report redeclaring an
    inherited TypedDict field as an error, and the resulting set of keys is
    the same either way.
    """

    # Job id of the allocation, stored as a string.
    jobid: str


class Remote:
def __init__(self, hostname, connection=None, transforms=(), keepalive=60):
self.hostname = hostname
Expand Down Expand Up @@ -192,7 +205,7 @@ def persist(self):
)
return self

def ensure_allocation(self):
def ensure_allocation(self) -> tuple[AllocationInfo, None]:
return {"node_name": self.hostname}, None

def run_script(self, name, *args, **kwargs):
Expand Down Expand Up @@ -252,7 +265,9 @@ def with_transforms(self, *transforms, persist=None):
def persist(self):
return self.with_transforms(persist=True)

def ensure_allocation(self):
def ensure_allocation(
self,
) -> tuple[AllocationInfo | PersistAllocationInfo, fabric.runners.Remote]:
if self._persist:
proc, results = self.extract(
"echo @@@ $(hostname) @@@ && sleep 1000d",
Expand Down
2 changes: 1 addition & 1 deletion milatools/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def hoststring(self, host):
return "\n".join(lines)


def qualified(node_name):
def qualified(node_name: str):
"""Return the fully qualified name corresponding to this node name."""

if "." not in node_name and not node_name.endswith(".server.mila.quebec"):
Expand Down

0 comments on commit fbd6d23

Please sign in to comment.