Skip to content

Commit

Permalink
[MT-67] No ssh connection to get the $HOME path
Browse files Browse the repository at this point in the history
Signed-off-by: Fabrice Normandin <normandf@mila.quebec>
  • Loading branch information
lebrice committed Sep 25, 2023
1 parent 0812760 commit bd93924
Showing 1 changed file with 43 additions and 18 deletions.
61 changes: 43 additions & 18 deletions milatools/cli/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from .utils import (
CommandNotFoundError,
MilatoolsUserError,
SSHConfig,
SSHConnectionError,
T,
get_fully_qualified_name,
Expand Down Expand Up @@ -73,9 +74,7 @@ def main():
"title": f"[v{mversion}] Issue running the command "
+ (f"`mila {command}`" if command else "`mila`"),
}
github_issue_url = (
f"https://github.com/mila-iqia/milatools/issues/new?{urlencode(options)}"
)
github_issue_url = f"https://github.com/mila-iqia/milatools/issues/new?{urlencode(options)}"
print(
T.bold_yellow(
f"An error occured during the execution of the command "
Expand Down Expand Up @@ -116,7 +115,9 @@ def mila():
help="Open the Mila cluster documentation.",
formatter_class=SortingHelpFormatter,
)
docs_parser.add_argument("SEARCH", nargs=argparse.REMAINDER, help="Search terms")
docs_parser.add_argument(
"SEARCH", nargs=argparse.REMAINDER, help="Search terms"
)
docs_parser.set_defaults(function=docs)

# ----- mila intranet ------
Expand Down Expand Up @@ -353,7 +354,9 @@ def mila():
return function(**args_dict)


def _convert_uppercase_keys_to_lowercase(args_dict: dict[str, Any]) -> dict[str, Any]:
def _convert_uppercase_keys_to_lowercase(
args_dict: dict[str, Any]
) -> dict[str, Any]:
return {(k.lower() if k.isupper() else k): v for k, v in args_dict.items()}


Expand Down Expand Up @@ -466,7 +469,11 @@ def init():
###################

print(T.bold_cyan("=" * 60))
print(T.bold_cyan("Congrats! You are now ready to start working on the cluster!"))
print(
T.bold_cyan(
"Congrats! You are now ready to start working on the cluster!"
)
)
print(T.bold_cyan("=" * 60))
print(T.bold("To connect to a login node:"))
print(" ssh mila")
Expand All @@ -475,7 +482,9 @@ def init():
print(T.bold("To open a directory on the cluster with VSCode:"))
print(" mila code path/to/code/on/cluster")
print(T.bold("Same as above, but allocate 1 GPU, 4 CPUs, 32G of RAM:"))
print(" mila code path/to/code/on/cluster --alloc --gres=gpu:1 --mem=32G -c 4")
print(
" mila code path/to/code/on/cluster --alloc --gres=gpu:1 --mem=32G -c 4"
)
print()
print(
"For more information, read the milatools documentation at",
Expand Down Expand Up @@ -559,11 +568,11 @@ def code(
data, proc = cnode.ensure_allocation()

node_name = data["node_name"]

if not path.startswith("/"):
# Get $HOME because we have to give the full path to code
home = remote.home()
path = "/".join([home, path])
user = get_user_from_ssh_config()
home = f"/home/mila/{user[0]}/{user}"
path = f"{home}/{path}"

command_path = shutil.which(command)
if not command_path:
Expand Down Expand Up @@ -652,12 +661,16 @@ def serve_list(purge: bool):
for identifier in remote.get_lines("ls .milatools/control", hide=True):
info = _get_server_info(remote, identifier, hide=True)
jobid = info.get("jobid", None)
status = remote.get_output(f"squeue -j {jobid} -ho %T", hide=True, warn=True)
status = remote.get_output(
f"squeue -j {jobid} -ho %T", hide=True, warn=True
)
program = info.pop("program", "???")
if status == "RUNNING":
necessary_keys = {"node_name", "to_forward"}
if any(k not in info for k in necessary_keys):
qn.print(f"{identifier} ({program}, MISSING INFO)", style="bold red")
qn.print(
f"{identifier} ({program}, MISSING INFO)", style="bold red"
)
to_purge.append((identifier, jobid))
else:
qn.print(f"{identifier} ({program})", style="bold yellow")
Expand Down Expand Up @@ -728,7 +741,9 @@ def notebook(path: str | None, **kwargs: Unpack[StandardServerArgs]):
path: Path to open on the remote machine
"""
if path and path.endswith(".ipynb"):
exit("Only directories can be given to the mila serve notebook command")
exit(
"Only directories can be given to the mila serve notebook command"
)

_standard_server(
path,
Expand Down Expand Up @@ -992,7 +1007,9 @@ def _standard_server(

if cf is not None:
remote.simple_run(f"echo program = {program} >> {cf}")
remote.simple_run(f"echo node_name = {results['node_name']} >> {cf}")
remote.simple_run(
f"echo node_name = {results['node_name']} >> {cf}"
)
remote.simple_run(f"echo host = {host} >> {cf}")
remote.simple_run(f"echo to_forward = {to_forward} >> {cf}")
if token_pattern:
Expand Down Expand Up @@ -1051,7 +1068,9 @@ def _get_disk_quota_usage(
normandf|1471600598|| 97.20 GiB| 100.00 GiB|| 806898| 1000000
```
"""
disk_quota_output = remote.get_output("disk-quota", hide=not print_command_output)
disk_quota_output = remote.get_output(
"disk-quota", hide=not print_command_output
)
last_line_parts = disk_quota_output.splitlines()[-1]
(
_username,
Expand All @@ -1071,9 +1090,7 @@ def _get_disk_quota_usage(


def check_disk_quota(remote: Remote) -> None:
cluster = (
"mila" # todo: if we run this on CC, then we should use `diskusage_report`
)
cluster = "mila" # todo: if we run this on CC, then we should use `diskusage_report`
# todo: Check the disk-quota of other filesystems if needed.
filesystem = "$HOME"
logger.debug("Checking disk quota on $HOME...")
Expand Down Expand Up @@ -1216,5 +1233,13 @@ def _forward(
return proc, port


def get_user_from_ssh_config() -> str:
ssh_config_path = Path("~/.ssh/config")
ssh_config = SSHConfig(ssh_config_path)
mila_entry = ssh_config.host("mila")
user: str = mila_entry["user"]
return user


if __name__ == "__main__":
main()

0 comments on commit bd93924

Please sign in to comment.