Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MT-67] No ssh connection to get the $HOME path #55

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 38 additions & 21 deletions milatools/cli/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,11 @@
from .init_command import setup_ssh_config
from .local import Local
from .profile import ensure_program, setup_profile
from .remote import Remote, SlurmRemote
from .remote import PersistAllocationInfo, Remote, SlurmRemote
from .utils import (
CommandNotFoundError,
MilatoolsUserError,
SSHConfig,
SSHConnectionError,
T,
get_fully_qualified_name,
Expand Down Expand Up @@ -536,34 +537,42 @@ def code(
node: Node to connect to
alloc: Extra options to pass to slurm
"""
if (node is not None) + (job is not None) + bool(alloc) > 1:
exit("ERROR: --node, --job and --alloc are mutually exclusive")

here = Local()
remote = Remote("mila")

if command is None:
command = os.environ.get("MILATOOLS_CODE_COMMAND", "code")

if not path.startswith("/"):
# Get $HOME because we have to give the full path to code
user = get_user_from_ssh_config()
path = f"/home/mila/{user[0]}/{user}/{path}"

if node is None:
remote = Remote("mila")

check_disk_quota(remote)

cnode = _find_allocation(
remote, job=job, node=node, alloc=alloc, job_name="mila-code"
remote, job_name="mila-code", job=job, node=None, alloc=alloc
)
if persist:
cnode = cnode.persist()
data, proc = cnode.ensure_allocation()

cnode = _find_allocation(
remote, job_name="mila-code", job=job, node=node, alloc=alloc
)
if persist:
cnode = cnode.persist()
data, proc = cnode.ensure_allocation()

node_name = data["node_name"]
data, proc = cnode.ensure_allocation()
assert "jobid" in data
job = data["jobid"]
assert isinstance(job, str)
else:
data, proc = cnode.ensure_allocation()
job = None

if not path.startswith("/"):
# Get $HOME because we have to give the full path to code
home = remote.home()
path = "/".join([home, path])
node_name = data["node_name"]
assert isinstance(node_name, str)
node = node_name
else:
proc = None

command_path = shutil.which(command)
if not command_path:
Expand All @@ -574,7 +583,7 @@ def code(
command_path,
"-nw",
"--remote",
f"ssh-remote+{qualified(node_name)}",
f"ssh-remote+{qualified(node)}",
path,
)
print(
Expand All @@ -587,14 +596,14 @@ def code(
if not persist:
if proc is not None:
proc.kill()
print(f"Ended session on '{node_name}'")
print(f"Ended session on '{node}'")

if persist:
print("This allocation is persistent and is still active.")
print("To reconnect to this node:")
print(T.bold(f" mila code {path} --node {node_name}"))
print(T.bold(f" mila code {path} --node {node}"))
print("To kill this allocation:")
print(T.bold(f" ssh mila scancel {data['jobid']}"))
print(T.bold(f" ssh mila scancel {job}"))


def connect(identifier: str, port: int | None):
Expand Down Expand Up @@ -1216,5 +1225,13 @@ def _forward(
return proc, port


def get_user_from_ssh_config() -> str:
ssh_config_path = Path("~/.ssh/config")
ssh_config = SSHConfig(ssh_config_path)
mila_entry = ssh_config.host("mila")
user: str = mila_entry["user"]
return user


if __name__ == "__main__":
main()
19 changes: 17 additions & 2 deletions milatools/cli/remote.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
from __future__ import annotations

import re
import socket
import tempfile
import time
from pathlib import Path
from queue import Empty, Queue

import fabric.runners
import paramiko
import questionary as qn
from fabric import Connection
from typing_extensions import TypedDict

from .utils import SSHConnectionError, T, control_file_var, here, shjoin

Expand Down Expand Up @@ -76,6 +80,15 @@ def get_first_node_name(node_names_out: str) -> str:
return base + inside_brackets.split("-")[0]


class AllocationInfo(TypedDict):
node_name: str


class PersistAllocationInfo(AllocationInfo):
node_name: str
jobid: str


class Remote:
def __init__(self, hostname, connection=None, transforms=(), keepalive=60):
self.hostname = hostname
Expand Down Expand Up @@ -192,7 +205,7 @@ def persist(self):
)
return self

def ensure_allocation(self):
def ensure_allocation(self) -> tuple[AllocationInfo, None]:
return {"node_name": self.hostname}, None

def run_script(self, name, *args, **kwargs):
Expand Down Expand Up @@ -252,7 +265,9 @@ def with_transforms(self, *transforms, persist=None):
def persist(self):
return self.with_transforms(persist=True)

def ensure_allocation(self):
def ensure_allocation(
self,
) -> tuple[AllocationInfo | PersistAllocationInfo, fabric.runners.Remote]:
if self._persist:
proc, results = self.extract(
"echo @@@ $(hostname) @@@ && sleep 1000d",
Expand Down
2 changes: 1 addition & 1 deletion milatools/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def hoststring(self, host):
return "\n".join(lines)


def qualified(node_name):
def qualified(node_name: str):
"""Return the fully qualified name corresponding to this node name."""

if "." not in node_name and not node_name.endswith(".server.mila.quebec"):
Expand Down
Loading