Skip to content

Commit

Permalink
Split read_params_from_cmdline for job and main script and use argparse
Browse files Browse the repository at this point in the history
`read_params_from_cmdline()` was previously used by both the
cluster_utils main script (e.g. grid_search.py) and the job scripts.
They have different requirements and thus the function was pretty
convoluted.  With this change now

- `read_params_from_cmdline()` is responsible only for the job
  scripts.  It is refactored and uses argparse internally for hopefully
  more readable code and `--help` support.
- `read_main_script_params_with_smart_settings()` is the counterpart
  for the grid_search/hp_optimization scripts.

BREAKING: `read_params_from_cmdline()`: The job script now expects
named arguments instead of positional ones.  This makes the optionality
of server information and distinction between settings file vs
dictionary string much easier to implement.
This affects manual calls of the job script and may break non-python job
scripts which operate on the arguments.
Furthermore, support for custom hooks has been removed, as it doesn't
seem to be relevant here.
  • Loading branch information
luator committed May 28, 2024
1 parent 10b45fb commit 5b20f93
Show file tree
Hide file tree
Showing 3 changed files with 156 additions and 68 deletions.
16 changes: 16 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,22 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
cluster_utils.grid_search ...`)
- **Breaking:** All exit codes other than 0 or 3 (the magic "restart for resume" code)
are now considered as failures. Previously only 1 was considered as failure.
- **Breaking:** Changed the parsing of arguments in `read_params_from_cmdline()`.
Previously it was expecting something like this:
```
script.py [server_communication_info] \
(config_file [cmd_line_parameters ...] | config_dictionary)
```
Now it expects named arguments:
```
script.py [--server-connection-info VALUE] \
(--parameter-dict VALUE | --parameter-file VALUE) \
[--parameters KEY_VALUE [KEY_VALUE ...]]
```
Use `--help` to get more information on the arguments.
This change is relevant
- if you want to run the job scripts manually, and
- if you are using non-python scripts where you parse the arguments yourself.
- The raw data of `grid_search` is saved to a file "all_data.csv" instead of
"results_raw.csv" to be consistent with `hp_optimization` (the format of the file
didn't change, only the name).
Expand Down
24 changes: 9 additions & 15 deletions cluster_utils/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,30 +150,24 @@ def generate_execution_cmd(self, paths, cmd_prefix: Optional[str] = None):

self.final_settings = current_setting

arguments = (
f'--server-connection-info="{self.comm_server_info}"'
f' --parameter-dict="{current_setting}"'
)
if is_python_script:
run_script_as_module_main = paths.get("run_as_module", False)
setting_string = '"' + str(current_setting) + '"'
comm_info_string = '"' + str(self.comm_server_info) + '"'
if run_script_as_module_main:
# convert path to module name
module_name = (
paths["script_to_run"].replace("/", ".").replace(".py", "")
)
exec_cmd = f"{python_executor} -m {module_name} {comm_info_string} {setting_string}"
exec_cmd = f"{python_executor} -m {module_name} {arguments}"
else:
base_exec_cmd = "{}".format(python_executor) + " {} {} {}"
exec_cmd = base_exec_cmd.format(
os.path.join(paths["main_path"], paths["script_to_run"]),
comm_info_string,
setting_string,
)
script_path = os.path.join(paths["main_path"], paths["script_to_run"])
exec_cmd = f"{python_executor} {script_path} {arguments}"
else:
base_exec_cmd = "{} {} {}"
exec_cmd = base_exec_cmd.format(
os.path.join(paths["main_path"], paths["script_to_run"]),
'"' + str(self.comm_server_info) + '"',
'"' + str(current_setting) + '"',
)
script_path = os.path.join(paths["main_path"], paths["script_to_run"])
exec_cmd = f"{script_path} {arguments}"

if self.singularity_settings:
exec_cmd = self.singularity_wrap(
Expand Down
184 changes: 131 additions & 53 deletions cluster_utils/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import sys
import time
import traceback
from typing import Any, NamedTuple
from typing import Any, NamedTuple, Optional

import pyuv
import smart_settings
Expand Down Expand Up @@ -352,9 +352,9 @@ def read_main_script_params_from_args(args: argparse.Namespace):
Returns:
smart_settings parameter structure.
"""
return read_params_from_cmdline(
cmd_line=[sys.argv[0], str(args.settings_file), *args.settings],
verbose=False,
return read_main_script_params_with_smart_settings(
settings_file=args.settings_file,
cmdline_settings=args.settings,
pre_unpack_hooks=[check_import_in_fixed_params],
post_unpack_hooks=[
rename_import_promise,
Expand All @@ -363,6 +363,14 @@ def read_main_script_params_from_args(args: argparse.Namespace):
)


def check_reserved_params(orig_dict: dict) -> None:
    """Reject parameter dictionaries that use a reserved parameter name.

    Args:
        orig_dict: Dictionary whose keys are checked against
            ``constants.RESERVED_PARAMS``.

    Raises:
        ValueError: For the first key (in iteration order) that is listed in
            ``constants.RESERVED_PARAMS``.
    """
    reserved_names = constants.RESERVED_PARAMS
    for name in orig_dict:
        if name not in reserved_names:
            continue
        # Fail on the first offending key; later keys are not inspected.
        raise ValueError(f"'{name}' is a reserved param name")


def add_cmd_line_params(base_dict, extra_flags):
for extra_flag in extra_flags:
lhs, eq, rhs = extra_flag.rpartition("=")
Expand All @@ -375,74 +383,143 @@ def add_cmd_line_params(base_dict, extra_flags):
raise RuntimeError(f"Command {cmd} failed") from e


def read_params_from_cmdline(
cmd_line=None,
make_immutable=True,
verbose=True,
dynamic=True,
pre_unpack_hooks=None,
post_unpack_hooks=None,
save_params=True,
):
"""Updates default settings based on command line input.
:param cmd_line: Expecting (same format as) sys.argv
:param verbose: Boolean to determine if final settings are pretty printed
:return: Settings object with (deep) dot access.
def read_main_script_params_with_smart_settings(
    settings_file: pathlib.Path,
    cmdline_settings: Optional[list[str]] = None,
    make_immutable: bool = True,
    dynamic: bool = True,
    pre_unpack_hooks: Optional[list] = None,
    post_unpack_hooks: Optional[list] = None,
) -> smart_settings.AttributeDict:
    """Load the parameters of a cluster_utils main script via smart_settings.

    The settings file is loaded with ``smart_settings.load()``.  Two
    post-unpack hooks are always run before any caller-supplied ones: first
    the command line settings are applied with ``add_cmd_line_params()``, then
    the result is checked with ``check_reserved_params()``.

    Args:
        settings_file: Path to the settings file.
        cmdline_settings: List of additional parameters provided via command
            line.  ``None`` or an empty list means no additional parameters.
        make_immutable: See ``smart_settings.load()``
        dynamic: See ``smart_settings.load()``
        pre_unpack_hooks: See ``smart_settings.load()``
        post_unpack_hooks: See ``smart_settings.load()``

    Returns:
        Parameters as loaded by smart_settings.

    Raises:
        ValueError: If ``is_settings_file()`` does not accept ``settings_file``.
    """
    # Normalise the optional list arguments (None/empty -> fresh empty list).
    extra_settings = cmdline_settings if cmdline_settings else []
    hooks_before_unpack = pre_unpack_hooks if pre_unpack_hooks else []
    user_hooks_after_unpack = post_unpack_hooks if post_unpack_hooks else []

    settings_path = os.fspath(settings_file)
    if not is_settings_file(settings_path):
        raise ValueError(f"{settings_file} is not a supported settings file.")

    def _apply_cmdline_settings(orig_dict):
        add_cmd_line_params(orig_dict, extra_settings)

    # Built-in hooks come first so user hooks see the command line overrides.
    hooks_after_unpack = [
        _apply_cmdline_settings,
        check_reserved_params,
    ] + user_hooks_after_unpack

    return smart_settings.load(
        settings_path,
        make_immutable=make_immutable,
        dynamic=dynamic,
        post_unpack_hooks=hooks_after_unpack,
        pre_unpack_hooks=hooks_before_unpack,
    )


def read_params_from_cmdline(
cmd_line: Optional[list[str]] = None,
make_immutable: bool = True,
verbose: bool = True,
dynamic: bool = True,
save_params: bool = True,
) -> smart_settings.AttributeDict:
"""Read parameters based on command line input.
Args:
cmd_line: Command line arguments (defaults to sys.argv).
make_immutable: See ``smart_settings.loads()``
verbose: If true, print the loaded parameters.
dynamic: See ``smart_settings.loads()``
save_params: If true, save the settings as JSON file in the working_dir.
Returns:
Parameters as loaded by smart_settings.
"""
if not cmd_line:
cmd_line = sys.argv

try:
connection_details = ast.literal_eval(cmd_line[1])
except (SyntaxError, ValueError):
connection_details = {}
pass

if set(connection_details.keys()) == {constants.ID, "ip", "port"}:
submission_state.communication_server_ip = connection_details["ip"]
submission_state.communication_server_port = connection_details["port"]
submission_state.job_id = connection_details[constants.ID]
del cmd_line[1]
submission_state.connection_details_available = True
submission_state.connection_active = False
# expected keys of the server connection dictionary
server_connection_keys = {constants.ID, "ip", "port"}

def check_reserved_params(orig_dict):
for key in orig_dict:
if key in constants.RESERVED_PARAMS:
raise ValueError(f"{key} is a reserved param name")
parser = argparse.ArgumentParser()
parser.add_argument(
"--server-connection-info",
type=ast.literal_eval,
help="""Information to communicate with the cluster_utils main process.
Dictionary with keys {}.
""".format(
server_connection_keys
),
)
param_group = parser.add_mutually_exclusive_group(required=True)
param_group.add_argument(
"--parameter-dict",
type=ast.literal_eval,
help="Dictionary with the job parameters.",
)
param_group.add_argument(
"--parameter-file", type=pathlib.Path, help="File with the job parameters."
)
parser.add_argument(
"--parameters",
nargs="+",
type=str,
default=[],
metavar="KEY_VALUE",
help="""Additional parameters in the format '<key>=<value>'. Values provided
here overwrite parameters provided via `--parameter-dict` or
`--parameter-file`. Key has to match a configuration option, value has to
be valid Python. Example: `--parameters 'results_dir="/tmp"'
'optimization_setting.run_local=True'`
""",
)

args = parser.parse_args(cmd_line[1:])

if len(cmd_line) < 2:
final_params = {}
elif is_settings_file(cmd_line[1]):
if args.server_connection_info:
if not isinstance(args.server_connection_info, dict):
msg = "'--server-connection-info' must be a dictionary or `None`."
raise ValueError(msg)
elif set(args.server_connection_info.keys()) != server_connection_keys:
msg = (
f"'--server-connection-info' must contain keys {server_connection_keys}"
)
raise ValueError(msg)

def add_cmd_params(orig_dict):
add_cmd_line_params(orig_dict, cmd_line[2:])
submission_state.communication_server_ip = args.server_connection_info["ip"]
submission_state.communication_server_port = args.server_connection_info["port"]
submission_state.job_id = args.server_connection_info[constants.ID]
submission_state.connection_details_available = True
submission_state.connection_active = False

def add_cmd_params(orig_dict):
add_cmd_line_params(orig_dict, args.parameters)

if args.parameter_file:
final_params = smart_settings.load(
cmd_line[1],
os.fspath(args.parameter_file),
make_immutable=make_immutable,
dynamic=dynamic,
post_unpack_hooks=(
[add_cmd_params, check_reserved_params] + post_unpack_hooks
),
pre_unpack_hooks=pre_unpack_hooks,
post_unpack_hooks=([add_cmd_params, check_reserved_params]),
)
else:
if not isinstance(args.parameter_dict, dict):
msg = "'--parameter-dict' must be a dictionary."
raise ValueError(msg)

elif len(cmd_line) == 2 and is_parseable_dict(cmd_line[1]):
final_params = ast.literal_eval(cmd_line[1])
final_params = smart_settings.loads(
json.dumps(final_params),
json.dumps(args.parameter_dict),
make_immutable=make_immutable,
dynamic=dynamic,
post_unpack_hooks=[check_reserved_params] + post_unpack_hooks,
pre_unpack_hooks=pre_unpack_hooks,
post_unpack_hooks=([add_cmd_params, check_reserved_params]),
)
else:
raise ValueError("Failed to parse command line")

if verbose:
print(final_params)
Expand All @@ -455,7 +532,8 @@ def add_cmd_params(orig_dict):
sys.excepthook = report_error_at_server
atexit.register(report_exit_at_server)
submission_state.connection_active = True
read_params_from_cmdline.start_time = time.time()

read_params_from_cmdline.start_time = time.time() # type: ignore

if save_params and "working_dir" in final_params:
os.makedirs(final_params.working_dir, exist_ok=True)
Expand Down

0 comments on commit 5b20f93

Please sign in to comment.