diff --git a/src/super_gradients/training/sg_trainer/sg_trainer.py b/src/super_gradients/training/sg_trainer/sg_trainer.py index dab8d82a5a..53b2f7e39f 100755 --- a/src/super_gradients/training/sg_trainer/sg_trainer.py +++ b/src/super_gradients/training/sg_trainer/sg_trainer.py @@ -5,6 +5,7 @@ import warnings from copy import deepcopy from typing import Union, Tuple, Mapping, Dict, Any, List, Optional +import concurrent.futures import hydra import numpy as np @@ -115,12 +116,12 @@ logger = get_logger(__name__) + class PrefetchIterable: def __init__(self, iterable): self.iterable = iterable def __iter__(self): - import concurrent.futures executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) try: @@ -141,6 +142,7 @@ def _prefetch(): finally: executor.shutdown() + class Trainer: """ SuperGradient Model - Base Class for Sg Models