From bd93924396c98a83f7359e57b5f7e59446a9aa7c Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Tue, 19 Sep 2023 14:47:31 -0400 Subject: [PATCH] [MT-67] No ssh connection to get the $HOME path Signed-off-by: Fabrice Normandin --- milatools/cli/commands.py | 61 +++++++++++++++++++++++++++------------ 1 file changed, 43 insertions(+), 18 deletions(-) diff --git a/milatools/cli/commands.py b/milatools/cli/commands.py index c721e84c..fbcde1fe 100644 --- a/milatools/cli/commands.py +++ b/milatools/cli/commands.py @@ -35,6 +35,7 @@ from .utils import ( CommandNotFoundError, MilatoolsUserError, + SSHConfig, SSHConnectionError, T, get_fully_qualified_name, @@ -73,9 +74,7 @@ def main(): "title": f"[v{mversion}] Issue running the command " + (f"`mila {command}`" if command else "`mila`"), } - github_issue_url = ( - f"https://github.com/mila-iqia/milatools/issues/new?{urlencode(options)}" - ) + github_issue_url = f"https://github.com/mila-iqia/milatools/issues/new?{urlencode(options)}" print( T.bold_yellow( f"An error occured during the execution of the command " @@ -116,7 +115,9 @@ def mila(): help="Open the Mila cluster documentation.", formatter_class=SortingHelpFormatter, ) - docs_parser.add_argument("SEARCH", nargs=argparse.REMAINDER, help="Search terms") + docs_parser.add_argument( + "SEARCH", nargs=argparse.REMAINDER, help="Search terms" + ) docs_parser.set_defaults(function=docs) # ----- mila intranet ------ @@ -353,7 +354,9 @@ def mila(): return function(**args_dict) -def _convert_uppercase_keys_to_lowercase(args_dict: dict[str, Any]) -> dict[str, Any]: +def _convert_uppercase_keys_to_lowercase( + args_dict: dict[str, Any] +) -> dict[str, Any]: return {(k.lower() if k.isupper() else k): v for k, v in args_dict.items()} @@ -466,7 +469,11 @@ def init(): ################### print(T.bold_cyan("=" * 60)) - print(T.bold_cyan("Congrats! You are now ready to start working on the cluster!")) + print( + T.bold_cyan( + "Congrats! You are now ready to start working on the cluster!" + ) + ) print(T.bold_cyan("=" * 60)) print(T.bold("To connect to a login node:")) print(" ssh mila") @@ -475,7 +482,9 @@ def init(): print(T.bold("To open a directory on the cluster with VSCode:")) print(" mila code path/to/code/on/cluster") print(T.bold("Same as above, but allocate 1 GPU, 4 CPUs, 32G of RAM:")) - print(" mila code path/to/code/on/cluster --alloc --gres=gpu:1 --mem=32G -c 4") + print( + " mila code path/to/code/on/cluster --alloc --gres=gpu:1 --mem=32G -c 4" + ) print() print( "For more information, read the milatools documentation at", @@ -559,11 +568,11 @@ def code( data, proc = cnode.ensure_allocation() node_name = data["node_name"] - if not path.startswith("/"): # Get $HOME because we have to give the full path to code - home = remote.home() - path = "/".join([home, path]) + user = get_user_from_ssh_config() + home = f"/home/mila/{user[0]}/{user}" + path = f"{home}/{path}" command_path = shutil.which(command) if not command_path: @@ -652,12 +661,16 @@ def serve_list(purge: bool): for identifier in remote.get_lines("ls .milatools/control", hide=True): info = _get_server_info(remote, identifier, hide=True) jobid = info.get("jobid", None) - status = remote.get_output(f"squeue -j {jobid} -ho %T", hide=True, warn=True) + status = remote.get_output( + f"squeue -j {jobid} -ho %T", hide=True, warn=True + ) program = info.pop("program", "???") if status == "RUNNING": necessary_keys = {"node_name", "to_forward"} if any(k not in info for k in necessary_keys): - qn.print(f"{identifier} ({program}, MISSING INFO)", style="bold red") + qn.print( + f"{identifier} ({program}, MISSING INFO)", style="bold red" + ) to_purge.append((identifier, jobid)) else: qn.print(f"{identifier} ({program})", style="bold yellow") @@ -728,7 +741,9 @@ def notebook(path: str | None, **kwargs: Unpack[StandardServerArgs]): path: Path to open on the remote machine """ if path and path.endswith(".ipynb"): - exit("Only directories can be given to the mila serve notebook command") + exit( + "Only directories can be given to the mila serve notebook command" + ) _standard_server( path, @@ -992,7 +1007,9 @@ def _standard_server( if cf is not None: remote.simple_run(f"echo program = {program} >> {cf}") - remote.simple_run(f"echo node_name = {results['node_name']} >> {cf}") + remote.simple_run( + f"echo node_name = {results['node_name']} >> {cf}" + ) remote.simple_run(f"echo host = {host} >> {cf}") remote.simple_run(f"echo to_forward = {to_forward} >> {cf}") if token_pattern: @@ -1051,7 +1068,9 @@ def _get_disk_quota_usage( normandf|1471600598|| 97.20 GiB| 100.00 GiB|| 806898| 1000000 ``` """ - disk_quota_output = remote.get_output("disk-quota", hide=not print_command_output) + disk_quota_output = remote.get_output( + "disk-quota", hide=not print_command_output + ) last_line_parts = disk_quota_output.splitlines()[-1] ( _username, @@ -1071,9 +1090,7 @@ def _get_disk_quota_usage( def check_disk_quota(remote: Remote) -> None: - cluster = ( - "mila" # todo: if we run this on CC, then we should use `diskusage_report` - ) + cluster = "mila" # todo: if we run this on CC, then we should use `diskusage_report` # todo: Check the disk-quota of other filesystems if needed. filesystem = "$HOME" logger.debug("Checking disk quota on $HOME...") @@ -1216,5 +1233,13 @@ def _forward( return proc, port +def get_user_from_ssh_config() -> str: + ssh_config_path = Path("~/.ssh/config") + ssh_config = SSHConfig(ssh_config_path) + mila_entry = ssh_config.host("mila") + user: str = mila_entry["user"] + return user + + if __name__ == "__main__": main()