Skip to content

Commit

Permalink
✏️ Minor changes and fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
arxyzan committed Aug 31, 2024
1 parent 02b12f3 commit 7c570c8
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 10 deletions.
16 changes: 7 additions & 9 deletions hezar/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def __init__(self, config: ModelConfig, *args, **kwargs):
self.config = config.update(kwargs)
self._preprocessor = None
self._loss_func = self._set_loss_func(self.loss_func_name, **self.loss_func_kwargs)
self.inference_fn = self.generate if self.is_generative else self.__call__

def __repr__(self):
representation = super().__repr__()
Expand Down Expand Up @@ -397,9 +398,9 @@ def predict(
Prediction results, each model or task can have its own type and structure
"""
# Unpack kwargs for each step
preprocess_kwargs, forward_kwargs, post_process_kwargs = self._unpack_prediction_kwargs(**kwargs)
preprocess_kwargs, inference_kwargs, post_process_kwargs = self._unpack_prediction_kwargs(**kwargs)
invalid_kwargs = {
k: v for k, v in kwargs.items() if k not in {**preprocess_kwargs, **forward_kwargs, **post_process_kwargs}
k: v for k, v in kwargs.items() if k not in {**preprocess_kwargs, **inference_kwargs, **post_process_kwargs}
}
if len(invalid_kwargs):
logger.warning(
Expand All @@ -417,14 +418,11 @@ def predict(
model_inputs = self._move_inputs_to_device(model_inputs, device)
self.to(device)

# Specify model inference function
inference_fn = self.generate if self.is_generative else self.__call__

# Model inference step (forward for regular models and generate for generative models)
if isinstance(model_inputs, dict) and unpack_forward_inputs:
model_outputs = inference_fn(**model_inputs, **forward_kwargs)
model_outputs = self.inference_fn(**model_inputs, **inference_kwargs)
else:
model_outputs = inference_fn(model_inputs, **forward_kwargs)
model_outputs = self.inference_fn(model_inputs, **inference_kwargs)

# Post-processing step
processed_outputs = self.post_process(model_outputs, **post_process_kwargs) if post_process else model_outputs
Expand Down Expand Up @@ -470,10 +468,10 @@ def _unpack_prediction_kwargs(self, **kwargs):
inference_fn = type(self).generate if self.is_generative else type(self).forward

preprocess_kwargs = sanitize_function_parameters(type(self).preprocess, kwargs)
forward_kwargs = sanitize_function_parameters(inference_fn, kwargs)
inference_kwargs = sanitize_function_parameters(inference_fn, kwargs)
post_process_kwargs = sanitize_function_parameters(type(self).post_process, kwargs)

return preprocess_kwargs, forward_kwargs, post_process_kwargs
return preprocess_kwargs, inference_kwargs, post_process_kwargs

@property
def device(self):
Expand Down
2 changes: 1 addition & 1 deletion hezar/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -712,7 +712,7 @@ def train(self, resume_from_checkpoint="deprecated"):
The full training process like training, evaluation, logging and saving model checkpoints.
The steps are as follows:
- The following is run for `self.config.num_epochs` times
The following is run for `self.config.num_epochs` times
- Run the training loop on the train dataset
- Save checkpoints
- Run evaluation on the evaluation dataset
Expand Down

0 comments on commit 7c570c8

Please sign in to comment.