From 45f9a2461526379372b50488d56221ec2487f415 Mon Sep 17 00:00:00 2001 From: Harim Kang Date: Fri, 20 Sep 2024 12:14:49 +0900 Subject: [PATCH] Add type checker in converter for callable functions (optimizer, scheduler) (#3968) Fix converter callable functions (optimizer, scheduler) --- src/otx/tools/converter.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/otx/tools/converter.py b/src/otx/tools/converter.py index 98c9d4aee86..23818c8ce0a 100644 --- a/src/otx/tools/converter.py +++ b/src/otx/tools/converter.py @@ -277,12 +277,22 @@ def update_inference_batch_size(param_value: int) -> None: config["data"]["test_subset"]["batch_size"] = param_value def update_learning_rate(param_value: float) -> None: - config["model"]["init_args"]["optimizer"]["init_args"]["lr"] = param_value + optimizer = config["model"]["init_args"]["optimizer"] + if isinstance(optimizer, dict) and "init_args" in optimizer: + optimizer["init_args"]["lr"] = param_value + else: + warn("Warning: learning_rate is not updated", stacklevel=1) def update_learning_rate_warmup_iters(param_value: int) -> None: scheduler = config["model"]["init_args"]["scheduler"] - if scheduler["class_path"] == "otx.core.schedulers.LinearWarmupSchedulerCallable": + if ( + isinstance(scheduler, dict) + and "class_path" in scheduler + and scheduler["class_path"] == "otx.core.schedulers.LinearWarmupSchedulerCallable" + ): scheduler["init_args"]["num_warmup_steps"] = param_value + else: + warn("Warning: learning_rate_warmup_iters is not updated", stacklevel=1) def update_num_iters(param_value: int) -> None: config["max_epochs"] = param_value