Qisheng Robert He edited this page Jan 12, 2024 · 23 revisions

Packages

torchmanager

Contains callbacks for the fit method of the manager

Contains organized configuration tools

Contains wrapped loss functions

Contains wrapped metric functions

Classes

BaseManager

The basic manager

  • Properties:
    • compiled: A bool flag indicating whether the manager has been compiled
    • loss_fn: A Callable that takes the truth and predictions in torch.Tensor and returns a loss torch.Tensor
    • metrics: A dict of metrics mapping a name in str to a Callable that takes the truth and predictions in torch.Tensor and returns a metric value in torch.Tensor
    • model: A target torch.nn.Module to be trained
    • optimizer: A torch.optim.Optimizer to train the model
  • Methods:
    • Constructor
      • Parameters:
        • loss_fn: An optional Loss object to calculate a single loss, or a dict of Loss objects keyed by their names in str to calculate multiple losses
        • metrics: An optional dict of metrics mapping a name in str to a Metric object that calculates the metric
        • model: An optional target torch.nn.Module to be trained
        • optimizer: An optional torch.optim.Optimizer to train the model
    • compile
      • Compiles the manager
        • Parameters:
          • loss_fn: A Loss object to calculate a single loss, or a dict of Loss objects keyed by their names in str to calculate multiple losses
          • metrics: A dict of metrics mapping a name in str to a Metric object that calculates the metric
          • optimizer: A torch.optim.Optimizer to train the model
    • from_checkpoint (classmethod)
      • Loads a manager from a saved Checkpoint. The loaded manager is not compiled with a loss function or metrics.
        • Returns: A loaded Manager
    • to_checkpoint
      • Converts the current manager to a checkpoint
        • Returns: A Checkpoint with its model in Module type
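The metrics property is described as a dict that maps a name to a callable taking the truth and the predictions. The sketch below shows, with plain Python lists instead of torch.Tensor, how such a dict can be evaluated into a summary dict keyed by metric name. The helper names (accuracy, error_count, evaluate_metrics) are hypothetical, and the (truth, predictions) argument order simply follows the wording above; check the library's Metric class for its actual call convention.

```python
from typing import Callable, Dict, List

# Hypothetical metric signature following the description above:
# a callable that takes the truth and the predictions and returns a number.
MetricFn = Callable[[List[int], List[int]], float]


def accuracy(y_true: List[int], y_pred: List[int]) -> float:
    """Fraction of positions where the prediction equals the truth."""
    if len(y_true) != len(y_pred):
        raise ValueError("truth and predictions must have the same length")
    correct = sum(1 for t, p in zip(y_true, y_pred) if t == p)
    return correct / len(y_true)


def error_count(y_true: List[int], y_pred: List[int]) -> float:
    """Number of positions where the prediction differs from the truth."""
    return float(sum(1 for t, p in zip(y_true, y_pred) if t != p))


def evaluate_metrics(
    metrics: Dict[str, MetricFn], y_true: List[int], y_pred: List[int]
) -> Dict[str, float]:
    """Apply every named metric and collect the results in a summary dict."""
    return {name: fn(y_true, y_pred) for name, fn in metrics.items()}


metrics = {"accuracy": accuracy, "errors": error_count}
summary = evaluate_metrics(metrics, [0, 1, 1, 0], [0, 1, 0, 0])
print(summary)  # {'accuracy': 0.75, 'errors': 1.0}
```

Keeping metrics in a name-keyed dict means the summary keys come straight from the dict keys, so adding a metric never requires changing the evaluation code.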

DataManager

The manager to load data during training or testing

  • Methods:
    • unpack_data
      • Unpacks data into input and target
        • Parameters:
          • data: Any kind of data object
        • Returns: A tuple of Any kind of input and Any kind of target
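Since unpack_data accepts any kind of data object, a subclass whose data loader yields a non-default batch format would override it. The standalone function below is a minimal sketch of that kind of logic, assuming a batch is either a 2-item tuple/list or a dict with the hypothetical keys "input" and "target"; it is not the library's own implementation.

```python
from typing import Any, Tuple


def unpack_data(data: Any) -> Tuple[Any, Any]:
    """Split one batch into (input, target).

    Assumes a batch is either a 2-item tuple/list or a dict with the
    hypothetical keys "input" and "target".
    """
    if isinstance(data, (tuple, list)) and len(data) == 2:
        return data[0], data[1]
    if isinstance(data, dict) and "input" in data and "target" in data:
        return data["input"], data["target"]
    raise TypeError(f"Unsupported batch format: {type(data).__name__}")


x, y = unpack_data(([1.0, 2.0], 1))
print(x, y)  # [1.0, 2.0] 1
x, y = unpack_data({"input": [3.0], "target": 0})
print(x, y)  # [3.0] 0
```

Raising on unknown formats, rather than guessing, surfaces a mismatched data loader at the first batch instead of producing silently wrong inputs.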

TestingManager

A testing manager, only used for testing

  • extends: BaseManager, DataManager
  • Properties
    • compiled_losses: The loss function in Loss, which must exist
    • compiled_metrics: The dict of metrics in Metric that does not contain losses
  • Methods
    • test
      • Test target model
        • Parameters:
          • dataset: Either SizedIterable or data.DataLoader to load the dataset
          • device: An optional torch.device to test on
          • use_multi_gpus: A bool flag indicating whether to use multiple GPUs during testing
          • show_verbose: A bool flag to show the progress bar during testing
        • Returns: A summary dict of the testing results
    • test_step
      • A single testing step
        • Parameters:
          • x_train: The testing data in torch.Tensor
          • y_train: The testing label in torch.Tensor
        • Returns: A summary dict of the testing step
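The relationship between test and test_step can be pictured as a loop that calls the per-batch step and combines the per-batch summaries into one dict. The sketch below assumes each key is a simple mean over batches and uses plain Python numbers; run_test and constant_test_step are hypothetical names, and the real test method may accumulate metrics differently (for example, weighted by batch size).

```python
from typing import Any, Callable, Dict, Iterable, Tuple

StepFn = Callable[[Any, Any], Dict[str, float]]


def run_test(dataset: Iterable[Tuple[Any, Any]], test_step: StepFn) -> Dict[str, float]:
    """Call test_step on each (x, y) batch and average each summary key."""
    totals: Dict[str, float] = {}
    num_batches = 0
    for x_test, y_test in dataset:
        step_summary = test_step(x_test, y_test)
        for name, value in step_summary.items():
            totals[name] = totals.get(name, 0.0) + value
        num_batches += 1
    if num_batches == 0:
        raise ValueError("dataset is empty")
    return {name: total / num_batches for name, total in totals.items()}


# Hypothetical step: absolute error of a model that always predicts 0.5.
def constant_test_step(x_test: float, y_test: float) -> Dict[str, float]:
    return {"mae": abs(0.5 - y_test)}


batches = [(0.0, 0.0), (1.0, 1.0), (2.0, 0.5)]
result = run_test(batches, constant_test_step)
print(result)  # {'mae': 0.3333333333333333}
```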

Manager

A training manager

  • extends: TestingManager
  • [Deprecation Warning]: The train method became protected in v1.0.2, and the public method will be removed in v1.2.0. Override the _train method instead.
  • Compile a model, optimizer, loss function, and metrics into the manager:
```python
import torch
from torchmanager import Manager, losses, metrics

class SomeModel(torch.nn.Module): ...

model = SomeModel()
optimizer = torch.optim.SGD(...)
loss_fn = losses.Loss(...)
metric_fns = {
    ...  # name in str -> metrics.Metric object
}
# metrics are passed with the `metrics` keyword, matching the constructor parameters
manager = Manager(model, optimizer=optimizer, loss_fn=loss_fn, metrics=metric_fns)
```
  • Train using the fit method:
```python
from torch.utils.data import Dataset, DataLoader

dataset = Dataset(...)
dataset = DataLoader(dataset, ...)
epochs: int = ...
manager.fit(dataset, epochs, ...)
```
  • Properties
    • current_epoch: The int index of the current training epoch
    • compiled_optimizer: The torch.optim.Optimizer, which must exist
  • Methods
    • _train
      • Trains the model for a single epoch
        • Parameters:
          • dataset: A SizedIterable training dataset
          • iterations: An optional int of total training iterations, which must not exceed the size of the dataset
          • device: A torch.device where the data is moved to, which should be the same device as the model
          • use_multi_gpus: A bool flag indicating whether to use multiple GPUs
          • show_verbose: A bool flag indicating whether to show the progress bar
          • verbose_type: A view.VerboseType that controls the display of verbose
          • callbacks_list: A list of callbacks in Callback
        • Returns: A summary dict with str keys and float values
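The dataset parameter of _train is a SizedIterable, and iterations caps how many batches are drawn. The sketch below assumes SizedIterable means "any object with __len__ and __iter__" and models it with typing.Protocol; train_one_epoch is a hypothetical stand-in for _train that skips devices, progress bars, and callbacks, and it reads "must not exceed the size of the dataset" as iterations <= len(dataset).

```python
import itertools
from typing import Any, Callable, Dict, Iterator, Optional, Protocol, runtime_checkable


@runtime_checkable
class SizedIterable(Protocol):
    """Assumed shape of SizedIterable: any object with __len__ and __iter__."""

    def __len__(self) -> int: ...

    def __iter__(self) -> Iterator[Any]: ...


def train_one_epoch(
    dataset: SizedIterable,
    train_step: Callable[[Any, Any], Dict[str, float]],
    iterations: Optional[int] = None,
) -> Dict[str, float]:
    """Run train_step on at most `iterations` batches and average each summary key."""
    if iterations is not None and iterations > len(dataset):
        raise ValueError("iterations must not exceed the size of the dataset")
    # islice stops early once the requested number of batches has been drawn.
    batches = dataset if iterations is None else itertools.islice(dataset, iterations)
    totals: Dict[str, float] = {}
    count = 0
    for x_train, y_train in batches:
        for name, value in train_step(x_train, y_train).items():
            totals[name] = totals.get(name, 0.0) + value
        count += 1
    return {name: total / count for name, total in totals.items()} if count else {}


data = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0)]
assert isinstance(data, SizedIterable)  # a list has both __len__ and __iter__
summary = train_one_epoch(data, lambda x, y: {"x_mean": x}, iterations=2)
print(summary)  # {'x_mean': 1.5}
```

Requiring __len__ is what makes the iterations check possible up front, before any batch is consumed.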
    • fit
      • Training algorithm
        • Parameters:
          • training_dataset: Any kind of training dataset, which must conform to SizedIterable
          • epochs: An optional int number of training epochs
          • iterations: An optional int number of training iterations
          • lr_scheduler: An optional torch.optim.lr_scheduler._LRScheduler to update the learning rate per epoch
          • is_dynamic_pruning: A bool flag indicating whether to use dynamic pruning
          • val_dataset: An optional validation dataset of Any kind
          • device: An optional torch.device where the data is moved to; if not specified, a GPU is used when available
          • use_multi_gpus: A bool flag indicating whether to use multiple GPUs
          • callbacks_list: A list of callbacks in Callback
          • **kwargs: Additional keyword arguments passed to the train method
        • Returns: A trained torch.nn.Module
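The fit parameters combine an epoch loop, a learning-rate scheduler that is updated once per epoch, and a list of callbacks. The sketch below shows that control flow in plain Python under stated assumptions: EpochCallback, HistoryCallback, StepDecay, and fit_loop_sketch are hypothetical stand-ins, not the library's Callback or torch.optim.lr_scheduler classes, and the real fit method also handles validation, devices, and iteration-based stopping.

```python
from typing import Callable, Dict, List, Optional


class EpochCallback:
    """Hypothetical callback interface: notified after every epoch."""

    def on_epoch_end(self, epoch: int, summary: Dict[str, float]) -> None:
        pass


class HistoryCallback(EpochCallback):
    """Records a copy of each epoch's summary."""

    def __init__(self) -> None:
        self.history: List[Dict[str, float]] = []

    def on_epoch_end(self, epoch: int, summary: Dict[str, float]) -> None:
        self.history.append(dict(summary))


class StepDecay:
    """Stand-in scheduler: multiplies the lr by gamma every `step_size` epochs."""

    def __init__(self, lr: float, step_size: int, gamma: float) -> None:
        self.lr = lr
        self.step_size = step_size
        self.gamma = gamma
        self.epochs_done = 0

    def step(self) -> None:
        self.epochs_done += 1
        if self.epochs_done % self.step_size == 0:
            self.lr *= self.gamma


def fit_loop_sketch(
    run_epoch: Callable[[int], Dict[str, float]],
    epochs: int,
    lr_scheduler: Optional[StepDecay] = None,
    callbacks_list: Optional[List[EpochCallback]] = None,
) -> None:
    """Run `epochs` epochs; step the scheduler and notify callbacks after each one."""
    callbacks = callbacks_list or []
    for epoch in range(epochs):
        summary = run_epoch(epoch)
        if lr_scheduler is not None:
            summary["lr"] = lr_scheduler.lr  # lr used during this epoch
            lr_scheduler.step()  # update the lr once per epoch
        for callback in callbacks:
            callback.on_epoch_end(epoch, summary)


scheduler = StepDecay(lr=0.1, step_size=2, gamma=0.5)
history = HistoryCallback()
fit_loop_sketch(lambda epoch: {"loss": 1.0 / (epoch + 1)}, epochs=4,
                lr_scheduler=scheduler, callbacks_list=[history])
print([round(s["lr"], 3) for s in history.history])  # [0.1, 0.1, 0.05, 0.05]
```

Stepping the scheduler after the epoch's work, rather than before, keeps the first epoch at the initial learning rate, which matches the usual PyTorch convention of calling scheduler.step() at the end of each epoch.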
    • train_step
      • A single training step
        • Parameters:
          • x_train: The training data
          • y_train: The training label
        • Returns: A summary dict with str keys and float values
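A single training step typically runs a forward pass, computes the loss, computes gradients, updates the parameters, and returns a summary dict. The sketch below shows those stages without torch, using a hypothetical one-parameter LinearModel, a mean squared error loss, and a hand-derived gradient in place of autograd and torch.optim. The model and lr arguments are illustrative additions; the documented train_step only takes x_train and y_train because the manager already holds the model and optimizer.

```python
from typing import Dict, List


class LinearModel:
    """Toy one-parameter model y = w * x (stand-in for a torch.nn.Module)."""

    def __init__(self, w: float = 0.0) -> None:
        self.w = w

    def __call__(self, xs: List[float]) -> List[float]:
        return [self.w * x for x in xs]


def train_step(
    model: LinearModel, x_train: List[float], y_train: List[float], lr: float = 0.1
) -> Dict[str, float]:
    """One step: forward pass, MSE loss, manual gradient, SGD update."""
    y_pred = model(x_train)  # forward pass
    n = len(x_train)
    loss = sum((p - t) ** 2 for p, t in zip(y_pred, y_train)) / n
    # d/dw of mean((w * x - y)^2) = mean(2 * (w * x - y) * x)
    grad = sum(2 * (p - t) * x for p, t, x in zip(y_pred, y_train, x_train)) / n
    model.w -= lr * grad  # plain SGD update
    return {"loss": loss}


model = LinearModel()
first = train_step(model, [1.0, 2.0], [2.0, 4.0])
second = train_step(model, [1.0, 2.0], [2.0, 4.0])
print(first, second, model.w)  # {'loss': 10.0} {'loss': 2.5} 1.5
```

The reported loss is the value before the update, which is why it reflects the parameters the step started with.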