Skip to content

Commit

Permalink
updated training script
Browse files Browse the repository at this point in the history
  • Loading branch information
JonasFrey96 committed Nov 2, 2023
1 parent 8038373 commit aee1250
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 20 deletions.
5 changes: 4 additions & 1 deletion wild_visual_navigation/cfg/experiment_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class GeneralParams:
log_confidence: bool = True
use_threshold: bool = True
folder: str = os.path.join(WVN_ROOT_DIR, "results")
perugia_root: str = "TBD"

general: GeneralParams = GeneralParams()

Expand Down Expand Up @@ -60,6 +61,7 @@ class LossAnomalyParams:

@dataclass
class TrainerParams:
default_root_dir: Optional[str] = None
precision: int = 32
accumulate_grad_batches: int = 1
fast_dev_run: bool = False
Expand All @@ -75,6 +77,7 @@ class TrainerParams:
enable_progress_bar: bool = True
weights_summary: Optional[str] = "top"
progress_bar_refresh_rate: Optional[int] = None
gpus: int = -1

trainer: TrainerParams = TrainerParams()

Expand Down Expand Up @@ -163,7 +166,7 @@ class VisuParams:

@dataclass
class LearningVisuParams:
p_visu: Optional[bool] = None
p_visu: Optional[str] = None
store: bool = True
log: bool = True

Expand Down
6 changes: 3 additions & 3 deletions wild_visual_navigation/general/training_routine.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,11 @@ def training_routine(exp: ExperimentParams, seed=42) -> torch.Tensor:
with read_write(exp):
# Update model paths
exp.general.model_path = model_path
exp.general.name = os.path.relpath(model_path, exp.folder)
exp.general.name = os.path.relpath(model_path, exp.general.folder)
exp.trainer.default_root_dir = model_path
exp.visu.learning_visu.p_visu = join(model_path, "visu")

logger = get_logger(exp, env)
logger = get_logger(exp)

# Set gpus
exp.trainer.gpus = 1 if torch.cuda.is_available() else None
Expand All @@ -66,7 +66,7 @@ def training_routine(exp: ExperimentParams, seed=42) -> torch.Tensor:

train_dl, val_dl, test_dl = get_ablation_module(
**exp.ablation_data_module,
perugia_root=env.perugia_root,
perugia_root=exp.general.perugia_root,
get_train_val_dataset=not exp.general.skip_train,
get_test_dataset=not exp.ablation_data_module.val_equals_test,
)
Expand Down
31 changes: 15 additions & 16 deletions wild_visual_navigation/utils/get_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,20 +30,19 @@ def get_neptune_run(neptune_project_name: str, tags: [str]) -> any:
return run


def get_neptune_logger(exp: dict, env: dict) -> NeptuneLogger:
def get_neptune_logger(exp: dict) -> NeptuneLogger:
"""Returns NeptuneLogger
Args:
exp (dict): Content of experiment file
env (dict): Content of experiment file
Returns:
(logger): Logger
"""
project_name = exp["logger"]["neptune_project_name"] # Neptune AI project_name "username/project"
project_name = exp.logger.neptune_project_name # Neptune AI project_name "username/project"

params = flatten_dict(exp)

name_full = exp["general"]["name"]
name_full = exp.general.name
name_short = "__".join(name_full.split("/")[-2:])

proxies = None
Expand All @@ -59,7 +58,7 @@ def get_neptune_logger(exp: dict, env: dict) -> NeptuneLogger:
)


def get_wandb_logger(exp: dict) -> WandbLogger:
    """Build a Weights & Biases logger for the experiment.

    Args:
        exp (dict): Experiment configuration with attribute-style access
            (e.g. an OmegaConf object) — assumed from usage; confirm caller.
    Returns:
        WandbLogger: Configured W&B logger (online mode).
    """
    project_name = exp.logger.wandb_project_name  # W&B project name "username/project"
    save_dir = os.path.join(exp.general.model_path)  # directory where W&B files are written
    # NOTE(review): flattened params are computed but never passed to the
    # logger — confirm whether they should be logged as hyperparameters.
    params = flatten_dict(exp)
    # Shorten "results/a/b" to "a__b" for a readable run name.
    name_full = exp.general.name
    name_short = "__".join(name_full.split("/")[-2:])
    return WandbLogger(
        name=name_short, project=project_name, entity=exp.logger.wandb_entity, save_dir=save_dir, offline=False
    )


def get_tensorboard_logger(exp: dict) -> TensorBoardLogger:
    """Build a TensorBoard logger for the experiment.

    Args:
        exp (dict): Experiment configuration with attribute-style access
            (e.g. an OmegaConf object) — assumed from usage; confirm caller.
    Returns:
        TensorBoardLogger: Configured TensorBoard logger.
    """
    params = flatten_dict(exp)
    # NOTE(review): save_dir uses exp.name while the other factories use
    # exp.general.name — verify this attribute exists on the config.
    # NOTE(review): Lightning's default_hp_metric parameter expects a bool;
    # passing the flattened params dict looks unintended — confirm against
    # the pytorch_lightning version in use.
    return TensorBoardLogger(save_dir=exp.name, name="tensorboard", default_hp_metric=params)


def get_skip_logger(exp: dict) -> None:
    """No-op logger factory: returns None so training runs without logging.

    Args:
        exp (dict): Experiment configuration (unused).
    Returns:
        None: Lightning interprets a None logger as "logging disabled".
    """
    return None


def get_logger(exp: dict) -> any:
    """Look up and invoke the logger factory selected by ``exp.logger.name``.

    Dispatches to one of this module's ``get_<name>_logger`` functions
    (e.g. name "wandb" selects ``get_wandb_logger``).

    Args:
        exp (dict): Experiment configuration with attribute-style access
            (e.g. an OmegaConf object) — assumed from usage; confirm caller.
    Returns:
        The constructed logger instance, or None for the "skip" logger.
    Raises:
        KeyError: If no ``get_<name>_logger`` function exists for the
            configured name.
    """
    name = exp.logger.name
    # (The previous revision also computed a save_dir here; it was never
    # used by any factory, so the dead local has been removed.)
    # Build a dispatch table of every function defined in this module.
    register = {k: v for k, v in globals().items() if inspect.isfunction(v)}
    return register[f"get_{name}_logger"](exp)

0 comments on commit aee1250

Please sign in to comment.