✨ Make `log_steps` and `save_steps` accept float values as a fraction in `TrainerConfig`

arxyzan committed Jun 14, 2024
1 parent 348cd1f commit 28c39e0
Showing 3 changed files with 32 additions and 8 deletions.
docs/tutorial/training/trainer.md (4 additions, 3 deletions)

@@ -111,9 +111,10 @@ Let's explore all the available parameters:
 - **metric_for_best_model** (str): Reference metric key to watch for the best model.
   Recommended to have a {`train.` | `evaluation.`} prefix (e.g., evaluation.f1, train.accuracy, etc.) but if not, defaults
   to `evaluation.{metric_for_best_model}`.
-- **save_steps** (int): Save the trainer outputs every `save_steps` steps. Set to `0` or `None` to ignore saving between
-  training steps.
-- **log_steps** (int): Save training metrics every `log_steps` steps.
+- **save_steps** (int): Save the trainer outputs every `save_steps` steps. Leave as None to ignore saving in-between training steps.
+  If set to a float value between 0 and 1, it will be interpreted as a fraction of the total steps.
+- **log_steps** (int): Save training metrics every `log_steps` steps. If set to a float value between 0 and 1, it will be
+  interpreted as a fraction of the total steps.
 - **checkpoints_dir** (str): Path to the checkpoints' folder. The actual files will be saved under `{output_dir}/{checkpoints_dir}`.
 - **logs_dir** (str): Path to the logs' folder. The actual log files will be saved under `{output_dir}/{logs_dir}`.

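In practice, the fractional form expresses save/log frequency relative to the length of the run instead of a hard-coded step count. A minimal sketch of a config using it (values are hypothetical; `task` and `output_dir` are plausible required fields from the library, not part of this diff):

```python
from hezar.configs import TrainerConfig

# Hypothetical example: per this commit, a float in (0, 1] is resolved
# to a concrete step interval as a fraction of the total training steps.
config = TrainerConfig(
    task="text_classification",  # assumed field, not shown in this diff
    output_dir="my_model",       # assumed field, not shown in this diff
    num_epochs=3,
    batch_size=16,
    save_steps=0.25,  # save 4 times over the whole run
    log_steps=0.1,    # log metrics 10 times over the whole run
)
```

Integer values keep their old meaning (a literal step interval), so existing configs are unaffected.
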
hezar/configs.py (9 additions, 5 deletions)

@@ -453,10 +453,14 @@ class TrainerConfig(Config):
             `evaluation.{metric_for_best_model}`.
         save_freq (int) (DEPRECATED):
             Deprecated and renamed to `save_steps`.
+        save_enabled (bool):
+            Whether to save checkpoints at all. `False` disables even the saves in-between the epochs.
         save_steps (int):
-            Save the trainer outputs every `save_steps` steps. Leave as `0` to ignore saving between training steps.
+            Save the trainer outputs every `save_steps` steps. Leave as None to ignore saving in-between training steps.
+            If set to a float value between 0 and 1, it will be interpreted as a fraction of the total steps.
         log_steps (int):
-            Save training metrics every `log_steps` steps.
+            Save training metrics every `log_steps` steps. If set to a float value between 0 and 1, it will be
+            interpreted as a fraction of the total steps.
         checkpoints_dir (str):
             Path to the checkpoints' folder. The actual files will be saved under `{output_dir}/{checkpoints_dir}`.
         logs_dir (str):
@@ -492,7 +496,7 @@ class TrainerConfig(Config):
     metrics: List[str | MetricConfig] = None
     metric_for_best_model: str = "loss"
     save_enabled: bool = True
-    save_freq: int = None
+    save_freq: int = "deprecated"
     save_steps: int = None
     log_steps: int = None
     checkpoints_dir: str = "checkpoints"
@@ -519,15 +523,15 @@ def __post_init__(self):
             self.metric_for_best_model = f"train.{self.metric_for_best_model}"

         # Validate steps
-        if self.save_steps is not None and self.save_steps % self.gradient_accumulation_steps != 0:
+        if isinstance(self.save_steps, int) and self.save_steps % self.gradient_accumulation_steps != 0:
             logger.warning(
                 f"It's recommended to set a `save_steps` dividable by `gradient_accumulation_steps`, "
                 f"otherwise, the saved model will have non-updated weights!\n"
                 f"`save_steps={self.save_steps}`, `gradient_accumulation_steps={self.gradient_accumulation_steps}`"
             )

         # Validate deprecated fields
-        if self.save_freq is not None:
+        if self.save_freq != "deprecated":
             logger.warning(
                 "Trainer argument `save_freq` is deprecated! Use `save_steps` (number of training steps per save)."
                 "Note that saving is also done at the end of each epoch unless you set `save_enabled` to `False` !"
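A side note on the `save_freq` default changing from `None` to the string sentinel `"deprecated"`: since `None` is a legitimate value for step settings, a sentinel is what lets `__post_init__` distinguish "the caller never passed `save_freq`" from an explicit `None`. A standalone sketch of the same pattern (the names and the forwarding to `save_steps` are illustrative; the hunk above only shows the warning):

```python
from dataclasses import dataclass

_UNSET = "deprecated"  # sentinel: distinguishes "never passed" from any real value


@dataclass
class SketchConfig:
    save_freq: int = _UNSET  # deprecated alias kept for backward compatibility
    save_steps: int = None

    def __post_init__(self):
        if self.save_freq != _UNSET:
            print("warning: `save_freq` is deprecated, use `save_steps` instead")
            if self.save_steps is None:
                self.save_steps = self.save_freq  # illustrative forwarding


SketchConfig(save_freq=500)   # triggers the deprecation message
SketchConfig(save_steps=500)  # silent
```
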
hezar/trainer/trainer.py (19 additions, 0 deletions)

@@ -159,11 +159,30 @@ def __init__(
         self.eval_dataset = eval_dataset
         self.data_collator = data_collator or getattr(self.train_dataset, "data_collator", None)

+        # Configure steps
         self.num_batches = math.ceil(len(self.train_dataset) / self.config.batch_size)
+        self.config.max_steps = (
+            math.ceil(self.config.max_steps * self.num_batches) * self.config.num_epochs
+            if isinstance(self.config.max_steps, float)
+            and 0 < self.config.max_steps <= 1
+            else self.config.max_steps
+        )
         self.total_steps = min(
             self.config.max_steps or self.num_batches * self.config.num_epochs,
             self.num_batches * self.config.num_epochs
         )
+        self.config.save_steps = (
+            math.ceil(self.config.save_steps * self.total_steps)
+            if isinstance(self.config.save_steps, float)
+            and 0 < self.config.save_steps <= 1
+            else self.config.save_steps
+        )
+        self.config.log_steps = (
+            math.ceil(self.config.log_steps * self.total_steps)
+            if isinstance(self.config.log_steps, float)
+            and 0 < self.config.log_steps <= 1
+            else self.config.log_steps
+        )
         self.steps_in_epoch = min(self.num_batches, self.total_steps)
         self.config.save_steps = self.steps_in_epoch if not self.config.save_steps else self.config.save_steps
         self.saves_in_epoch = math.floor(self.steps_in_epoch / self.config.save_steps)
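To make the resolution arithmetic above concrete, here is a standalone trace of the same logic with made-up numbers (dataset size, batch size, and epoch count are hypothetical):

```python
import math

# Hypothetical training setup
dataset_size = 10_000
batch_size = 32
num_epochs = 3
save_steps = 0.25  # float in (0, 1]: treated as a fraction of total steps
log_steps = 100    # int: passed through unchanged

num_batches = math.ceil(dataset_size / batch_size)  # 313 batches per epoch
total_steps = num_batches * num_epochs              # 939 (no max_steps cap here)

# Mirrors the conversion done in Trainer.__init__ above
if isinstance(save_steps, float) and 0 < save_steps <= 1:
    save_steps = math.ceil(save_steps * total_steps)  # ceil(0.25 * 939) = 235

print(num_batches, total_steps, save_steps, log_steps)  # 313 939 235 100
```

One consequence worth noting: `save_steps=1` (int) means saving every step, while `save_steps=1.0` (float) resolves to `total_steps`, effectively a single in-between save at the end of the run.
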
