From aee1250aa33fc9273ae0654a9eaa94d10e556796 Mon Sep 17 00:00:00 2001 From: Jonas Frey Date: Thu, 2 Nov 2023 10:50:01 +0100 Subject: [PATCH] updated training script --- .../cfg/experiment_params.py | 5 ++- .../general/training_routine.py | 6 ++-- wild_visual_navigation/utils/get_logger.py | 31 +++++++++---------- 3 files changed, 22 insertions(+), 20 deletions(-) diff --git a/wild_visual_navigation/cfg/experiment_params.py b/wild_visual_navigation/cfg/experiment_params.py index d1e2968e..5fcdc110 100644 --- a/wild_visual_navigation/cfg/experiment_params.py +++ b/wild_visual_navigation/cfg/experiment_params.py @@ -20,6 +20,7 @@ class GeneralParams: log_confidence: bool = True use_threshold: bool = True folder: str = os.path.join(WVN_ROOT_DIR, "results") + perugia_root: str = "TBD" general: GeneralParams = GeneralParams() @@ -60,6 +61,7 @@ class LossAnomalyParams: @dataclass class TrainerParams: + default_root_dir: Optional[str] = None precision: int = 32 accumulate_grad_batches: int = 1 fast_dev_run: bool = False @@ -75,6 +77,7 @@ class TrainerParams: enable_progress_bar: bool = True weights_summary: Optional[str] = "top" progress_bar_refresh_rate: Optional[int] = None + gpus: int = -1 trainer: TrainerParams = TrainerParams() @@ -163,7 +166,7 @@ class VisuParams: @dataclass class LearningVisuParams: - p_visu: Optional[bool] = None + p_visu: Optional[str] = None store: bool = True log: bool = True diff --git a/wild_visual_navigation/general/training_routine.py b/wild_visual_navigation/general/training_routine.py index bcb99aae..0c2d527a 100644 --- a/wild_visual_navigation/general/training_routine.py +++ b/wild_visual_navigation/general/training_routine.py @@ -36,11 +36,11 @@ def training_routine(exp: ExperimentParams, seed=42) -> torch.Tensor: with read_write(exp): # Update model paths exp.general.model_path = model_path - exp.general.name = os.path.relpath(model_path, exp.folder) + exp.general.name = os.path.relpath(model_path, exp.general.folder) exp.trainer.default_root_dir = model_path exp.visu.learning_visu.p_visu = join(model_path, "visu") - logger = get_logger(exp, env) + logger = get_logger(exp) # Set gpus exp.trainer.gpus = 1 if torch.cuda.is_available() else None @@ -66,7 +66,7 @@ def training_routine(exp: ExperimentParams, seed=42) -> torch.Tensor: train_dl, val_dl, test_dl = get_ablation_module( **exp.ablation_data_module, - perugia_root=env.perugia_root, + perugia_root=exp.general.perugia_root, get_train_val_dataset=not exp.general.skip_train, get_test_dataset=not exp.ablation_data_module.val_equals_test, ) diff --git a/wild_visual_navigation/utils/get_logger.py b/wild_visual_navigation/utils/get_logger.py index 0bf277e5..457802ee 100644 --- a/wild_visual_navigation/utils/get_logger.py +++ b/wild_visual_navigation/utils/get_logger.py @@ -30,20 +30,19 @@ def get_neptune_run(neptune_project_name: str, tags: [str]) -> any: return run -def get_neptune_logger(exp: dict, env: dict) -> NeptuneLogger: +def get_neptune_logger(exp: dict) -> NeptuneLogger: """Returns NeptuneLogger Args: exp (dict): Content of environment file - env (dict): Content of experiment file Returns: (logger): Logger """ - project_name = exp["logger"]["neptune_project_name"] # Neptune AI project_name "username/project" + project_name = exp.logger.neptune_project_name # Neptune AI project_name "username/project" params = flatten_dict(exp) - name_full = exp["general"]["name"] + name_full = exp.general.name name_short = "__".join(name_full.split("/")[-2:]) proxies = None @@ -59,7 +58,7 @@ def get_neptune_logger(exp: dict, env: dict) -> NeptuneLogger: ) -def get_wandb_logger(exp: dict, env: dict) -> WandbLogger: +def get_wandb_logger(exp: dict) -> WandbLogger: """Returns NeptuneLogger Args: @@ -68,17 +67,17 @@ def get_wandb_logger(exp: dict, env: dict) -> WandbLogger: Returns: (logger): Logger """ - project_name = exp["logger"]["wandb_project_name"] # project_name (str): W&B project_name - save_dir = os.path.join(exp["general"]["model_path"]) # save_dir (str): File path to save directory + project_name = exp.logger.wandb_project_name # project_name (str): W&B project_name + save_dir = os.path.join(exp.general.model_path) # save_dir (str): File path to save directory params = flatten_dict(exp) - name_full = exp["general"]["name"] + name_full = exp.general.name name_short = "__".join(name_full.split("/")[-2:]) return WandbLogger( - name=name_short, project=project_name, entity=exp["logger"]["wandb_entity"], save_dir=save_dir, offline=False + name=name_short, project=project_name, entity=exp.logger.wandb_entity, save_dir=save_dir, offline=False ) -def get_tensorboard_logger(exp: dict, env: dict) -> TensorBoardLogger: +def get_tensorboard_logger(exp: dict) -> TensorBoardLogger: """Returns TensorboardLoggers Args: @@ -88,10 +87,10 @@ def get_tensorboard_logger(exp: dict, env: dict) -> TensorBoardLogger: (logger): Logger """ params = flatten_dict(exp) - return TensorBoardLogger(save_dir=exp["name"], name="tensorboard", default_hp_metric=params) + return TensorBoardLogger(save_dir=exp.name, name="tensorboard", default_hp_metric=params) -def get_skip_logger(exp: dict, env: dict) -> None: +def get_skip_logger(exp: dict) -> None: """Returns None Args: @@ -103,8 +102,8 @@ def get_skip_logger(exp: dict, env: dict) -> None: return None -def get_logger(exp: dict, env: dict) -> any: - name = exp["logger"]["name"] - save_dir = os.path.join(env["base"], exp["general"]["name"]) +def get_logger(exp: dict) -> any: + name = exp.logger.name + save_dir = os.path.join(exp.general.folder, exp.general.name) register = {k: v for k, v in globals().items() if inspect.isfunction(v)} - return register[f"get_{name}_logger"](exp, env) + return register[f"get_{name}_logger"](exp)