diff --git a/src/cleaner/selfclean.py b/src/cleaner/selfclean.py index 06b2edd..4822bae 100644 --- a/src/cleaner/selfclean.py +++ b/src/cleaner/selfclean.py @@ -1,5 +1,6 @@ import gc import os +import platform from enum import Enum from pathlib import Path from typing import Optional, Union @@ -131,6 +132,7 @@ def run_on_image_folder( epochs: int = 100, batch_size: int = 64, ssl_pre_training: bool = True, + save_every_n_epochs: int = 10, work_dir: Optional[str] = None, num_workers: int = os.cpu_count(), pretraining_type: PretrainingType = PretrainingType.DINO, @@ -152,6 +154,7 @@ def run_on_image_folder( epochs=epochs, batch_size=batch_size, ssl_pre_training=ssl_pre_training, + save_every_n_epochs=save_every_n_epochs, work_dir=work_dir, num_workers=num_workers, pretraining_type=pretraining_type, @@ -171,6 +174,7 @@ def run_on_dataset( epochs: int = 100, batch_size: int = 64, ssl_pre_training: bool = True, + save_every_n_epochs: int = 10, work_dir: Optional[str] = None, num_workers: int = os.cpu_count(), pretraining_type: PretrainingType = PretrainingType.DINO, @@ -188,6 +192,7 @@ def run_on_dataset( epochs=epochs, batch_size=batch_size, ssl_pre_training=ssl_pre_training, + save_every_n_epochs=save_every_n_epochs, work_dir=work_dir, num_workers=num_workers, pretraining_type=pretraining_type, @@ -207,6 +212,7 @@ def _run( epochs: int = 100, batch_size: int = 64, ssl_pre_training: bool = True, + save_every_n_epochs: int = 10, work_dir: Optional[str] = None, num_workers: int = os.cpu_count(), pretraining_type: PretrainingType = PretrainingType.DINO, @@ -226,6 +232,7 @@ def _run( epochs=epochs, batch_size=batch_size, ssl_pre_training=ssl_pre_training, + save_every_n_epochs=save_every_n_epochs, work_dir=work_dir, hyperparameters=hyperparameters, num_workers=num_workers, @@ -278,6 +285,7 @@ def train_dino( epochs: int = 100, batch_size: int = 64, ssl_pre_training: bool = True, + save_every_n_epochs: int = 10, work_dir: Optional[str] = None, hyperparameters: dict = DINO_STANDARD_HYPERPARAMETERS, num_workers: int = os.cpu_count(), @@ -293,6 +301,7 @@ def train_dino( hyperparameters["epochs"] = epochs hyperparameters["batch_size"] = batch_size hyperparameters["ssl_pre_training"] = ssl_pre_training + hyperparameters["save_every_n_epochs"] = save_every_n_epochs if work_dir is not None: hyperparameters["work_dir"] = work_dir @@ -306,10 +315,16 @@ def train_dino( kwargs = {"sampler": sampler} else: kwargs = {"shuffle": True} + + # due to a problem with worker spawning on apple silicon + # we set it here to 0 + kwargs["num_workers"] = num_workers + if platform.machine().lower() == "arm64": + kwargs["num_workers"] = 0 + train_loader = DataLoader( dataset, batch_size=batch_size, - num_workers=num_workers, drop_last=False, pin_memory=True, **kwargs,