Skip to content

Commit

Permalink
Add type checker in converter for callable functions (optimizer, sche…
Browse files Browse the repository at this point in the history
…duler) (#3968)

Fix converter callable functions (optimizer, scheduler)
  • Loading branch information
harimkang authored Sep 20, 2024
1 parent 93f1a55 commit 45f9a24
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions src/otx/tools/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,12 +277,22 @@ def update_inference_batch_size(param_value: int) -> None:
config["data"]["test_subset"]["batch_size"] = param_value

def update_learning_rate(param_value: float) -> None:
config["model"]["init_args"]["optimizer"]["init_args"]["lr"] = param_value
optimizer = config["model"]["init_args"]["optimizer"]
if isinstance(optimizer, dict) and "init_args" in optimizer:
optimizer["init_args"]["lr"] = param_value
else:
warn("Warning: learning_rate is not updated", stacklevel=1)

def update_learning_rate_warmup_iters(param_value: int) -> None:
scheduler = config["model"]["init_args"]["scheduler"]
if scheduler["class_path"] == "otx.core.schedulers.LinearWarmupSchedulerCallable":
if (
isinstance(scheduler, dict)
and "class_path" in scheduler
and scheduler["class_path"] == "otx.core.schedulers.LinearWarmupSchedulerCallable"
):
scheduler["init_args"]["num_warmup_steps"] = param_value
else:
warn("Warning: learning_rate_warmup_iters is not updated", stacklevel=1)

def update_num_iters(param_value: int) -> None:
config["max_epochs"] = param_value
Expand Down

0 comments on commit 45f9a24

Please sign in to comment.