fixed problem with apple silicon + option to save more frequently
FabianGroeger96 committed Mar 22, 2024
1 parent a4d5e25 commit 0e6a0cf
Showing 1 changed file with 16 additions and 1 deletion.
src/cleaner/selfclean.py (16 additions, 1 deletion)
@@ -1,5 +1,6 @@
 import gc
 import os
+import platform
 from enum import Enum
 from pathlib import Path
 from typing import Optional, Union
@@ -131,6 +132,7 @@ def run_on_image_folder(
         epochs: int = 100,
         batch_size: int = 64,
         ssl_pre_training: bool = True,
+        save_every_n_epochs: int = 10,
         work_dir: Optional[str] = None,
         num_workers: int = os.cpu_count(),
         pretraining_type: PretrainingType = PretrainingType.DINO,
@@ -152,6 +154,7 @@ def run_on_image_folder(
             epochs=epochs,
             batch_size=batch_size,
             ssl_pre_training=ssl_pre_training,
+            save_every_n_epochs=save_every_n_epochs,
             work_dir=work_dir,
             num_workers=num_workers,
             pretraining_type=pretraining_type,
@@ -171,6 +174,7 @@ def run_on_dataset(
         epochs: int = 100,
         batch_size: int = 64,
         ssl_pre_training: bool = True,
+        save_every_n_epochs: int = 10,
         work_dir: Optional[str] = None,
         num_workers: int = os.cpu_count(),
         pretraining_type: PretrainingType = PretrainingType.DINO,
@@ -188,6 +192,7 @@ def run_on_dataset(
             epochs=epochs,
             batch_size=batch_size,
             ssl_pre_training=ssl_pre_training,
+            save_every_n_epochs=save_every_n_epochs,
             work_dir=work_dir,
             num_workers=num_workers,
             pretraining_type=pretraining_type,
@@ -207,6 +212,7 @@ def _run(
         epochs: int = 100,
         batch_size: int = 64,
         ssl_pre_training: bool = True,
+        save_every_n_epochs: int = 10,
         work_dir: Optional[str] = None,
         num_workers: int = os.cpu_count(),
         pretraining_type: PretrainingType = PretrainingType.DINO,
@@ -226,6 +232,7 @@ def _run(
             epochs=epochs,
             batch_size=batch_size,
             ssl_pre_training=ssl_pre_training,
+            save_every_n_epochs=save_every_n_epochs,
             work_dir=work_dir,
             hyperparameters=hyperparameters,
             num_workers=num_workers,
@@ -278,6 +285,7 @@ def train_dino(
         epochs: int = 100,
         batch_size: int = 64,
         ssl_pre_training: bool = True,
+        save_every_n_epochs: int = 10,
         work_dir: Optional[str] = None,
         hyperparameters: dict = DINO_STANDARD_HYPERPARAMETERS,
         num_workers: int = os.cpu_count(),
@@ -293,6 +301,7 @@ def train_dino(
         hyperparameters["epochs"] = epochs
         hyperparameters["batch_size"] = batch_size
         hyperparameters["ssl_pre_training"] = ssl_pre_training
+        hyperparameters["save_every_n_epochs"] = save_every_n_epochs
         if work_dir is not None:
             hyperparameters["work_dir"] = work_dir

@@ -306,10 +315,16 @@ def train_dino(
             kwargs = {"sampler": sampler}
         else:
             kwargs = {"shuffle": True}
+
+        # due to a problem with worker spawning on apple silicon
+        # we set it here to 0
+        kwargs["num_workers"] = num_workers
+        if platform.machine().lower() == "arm64":
+            kwargs["num_workers"] = 0
+
         train_loader = DataLoader(
             dataset,
             batch_size=batch_size,
-            num_workers=num_workers,
             drop_last=False,
             pin_memory=True,
             **kwargs,
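The last hunk is the Apple Silicon fix from the commit title: spawning DataLoader worker processes can fail on arm64 Macs, so the commit forces num_workers=0 (batches are then loaded in the main process) whenever platform.machine() reports "arm64". A minimal standalone sketch of the same guard, using a stand-in TensorDataset rather than SelfClean's own dataset:

import platform

import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in dataset; any torch Dataset behaves the same way here.
dataset = TensorDataset(torch.randn(8, 3))

num_workers = 4
if platform.machine().lower() == "arm64":
    # Mirror the workaround above: load batches in the main process
    # on Apple Silicon instead of spawning worker processes.
    num_workers = 0

loader = DataLoader(dataset, batch_size=4, num_workers=num_workers)
for (batch,) in loader:
    print(batch.shape)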
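The other change threads a new save_every_n_epochs knob from the public entry points (run_on_image_folder, run_on_dataset) through _run and train_dino into the DINO hyperparameters dict. The trainer that consumes it is not part of this diff; presumably it checkpoints on that interval, along the lines of this hypothetical sketch (the model, file names, and loop body are illustrative only):

import torch
from torch import nn

model = nn.Linear(4, 2)  # stand-in for the SSL backbone
epochs, save_every_n_epochs = 100, 10

for epoch in range(1, epochs + 1):
    # ... one epoch of DINO pre-training would run here ...
    if epoch % save_every_n_epochs == 0:
        # checkpoint every N epochs instead of only at the end
        torch.save(model.state_dict(), f"checkpoint_epoch_{epoch:04d}.pt")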

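From the caller's side, the new parameter is exposed on both public entry points. A hedged usage sketch: the SelfClean class name matches this repository, but the import path and the input_path keyword are assumptions not visible in the hunks above.

from selfclean import SelfClean  # import path assumed

cleaner = SelfClean()
cleaner.run_on_image_folder(
    input_path="path/to/images",  # keyword assumed, not shown in this diff
    epochs=100,
    batch_size=64,
    save_every_n_epochs=5,  # new: checkpoint twice as often as the default of 10
)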