possibility for non-gpu dino training
FabianGroeger96 committed Mar 19, 2024
1 parent cf671fb commit e4daa88
Showing 4 changed files with 76 additions and 26 deletions.
25 changes: 9 additions & 16 deletions Makefile
@@ -40,15 +40,7 @@ ifeq ($(origin CONTAINER_NAME), undefined)
 endif
 
-ifeq ($(origin LOCAL_DATA_DIR), undefined)
-LOCAL_DATA_DIR := /data/
-endif
-
-ifeq ($(origin DOCKER_SRC_DIR), undefined)
-DOCKER_SRC_DIR := "/workspace/"
-endif
-
 ifeq ($(origin LOCAL_DATA_DIR), undefined)
-LOCAL_DATA_DIR := /data/
+LOCAL_DATA_DIR := $$PWD/data/
 endif
 
 ifeq ($(origin GPU_ID), undefined)
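Aside on the new default: the doubled dollar sign is Make's escape for a literal `$`, so the variable stores the text `$PWD/data/` and the shell expands `$PWD` only when the value reaches a recipe. A minimal sketch of that behavior (the `show-data-dir` target is illustrative, not part of the repository):

LOCAL_DATA_DIR := $$PWD/data/   # stored as the literal text "$PWD/data/"

show-data-dir:
	@echo $(LOCAL_DATA_DIR)     # the shell expands $PWD here, at recipe time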
@@ -59,7 +51,11 @@ else
 endif
 
 ifeq ("$(GPU)", "false")
+ifeq (, $(shell which nvidia-smi))
+GPU_ARGS :=
+else
 GPU_ARGS := --gpus '"device="'
+endif
 DOCKER_CONTAINER_NAME := --name $(PROJECTNAME)_$(CONTAINER_NAME)
 else
 GPU_ARGS := --gpus '"device=$(GPU_ID)"' --shm-size 200G --ipc=host
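This hunk is the heart of the commit on the Docker side: with `GPU=false`, the Makefile now also probes for `nvidia-smi` and passes no `--gpus` flag at all on hosts without NVIDIA tooling, instead of always emitting an empty device list that docker may reject when the NVIDIA runtime is missing. A self-contained sketch of the detection pattern (the `check-gpu` target is illustrative, not part of the repository):

# `which nvidia-smi` prints the binary's path, or nothing when no NVIDIA
# driver tooling is installed; comparing against the empty string
# therefore selects the CPU-only branch.
ifeq (, $(shell which nvidia-smi))
GPU_ARGS :=
else
GPU_ARGS := --gpus '"device="'
endif

check-gpu:
	@echo "GPU_ARGS='$(GPU_ARGS)'"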
@@ -88,10 +84,6 @@ DOCKER_CMD := docker run $(DOCKER_ARGS) $(GPU_ARGS) $(DOCKER_CONTAINER_NAME) -it
 ###########################
 # PROJECT UTILS
 ###########################
-.PHONY: init
-init: ##@Utils initializes the project and pulls all the nessecary data
-	@git submodule update --init --recursive
-
 .PHONY: install
 install: ##@Utils install the dependencies for the project
 	@python3 -m pip install -r requirements.txt
@@ -122,8 +114,9 @@ run_bash: _build ##@Docker run an interactive bash inside the docker image (def
 
 start_jupyter: _build ##@Docker start a jupyter notebook inside the docker image
 	@echo "Starting jupyter notebook"
-	@-docker rm $(PROJECTNAME)_gpu_$(GPU_ID)
-	$(DOCKER_GPU_CMD) /bin/bash -c "jupyter notebook --allow-root --ip 0.0.0.0 --port 8888"
+	@-docker rm $(DOCKER_CONTAINER_NAME)
+	$(DOCKER_CMD) /bin/bash -c "jupyter notebook --allow-root --ip 0.0.0.0 --port 8888"
+.DEFAULT_GOAL := help
 
 ###########################
 # TESTS
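With these Makefile changes the Docker targets should degrade gracefully on CPU-only hosts, for example (a hypothetical invocation; the exact variable values depend on the local setup):

# No NVIDIA driver installed: GPU_ARGS expands to nothing and the
# container starts without any --gpus flag.
make start_jupyter GPU=false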
31 changes: 22 additions & 9 deletions src/cleaner/selfclean.py
@@ -4,6 +4,7 @@
 from typing import Optional, Union
 
 import numpy as np
+import torch
 from torch.utils.data import DataLoader, Dataset, DistributedSampler
 from torchvision import transforms
 from torchvision.datasets import ImageFolder
@@ -119,7 +120,8 @@ def run_on_image_folder(
         epochs: int = 100,
         batch_size: int = 32,
         ssl_pre_training: bool = True,
-        num_workers: int = 48,
+        work_dir: Optional[str] = None,
+        num_workers: int = 24,
         pretraining_type: PretrainingType = PretrainingType.DINO,
         hyperparameters: dict = DINO_STANDARD_HYPERPARAMETERS,
         # embedding
@@ -139,6 +141,7 @@
             epochs=epochs,
             batch_size=batch_size,
             ssl_pre_training=ssl_pre_training,
+            work_dir=work_dir,
             num_workers=num_workers,
             pretraining_type=pretraining_type,
             hyperparameters=hyperparameters,
@@ -157,7 +160,8 @@ def run_on_dataset(
         epochs: int = 100,
         batch_size: int = 32,
         ssl_pre_training: bool = True,
-        num_workers: int = 48,
+        work_dir: Optional[str] = None,
+        num_workers: int = 24,
         pretraining_type: PretrainingType = PretrainingType.DINO,
         hyperparameters: dict = DINO_STANDARD_HYPERPARAMETERS,
         # embedding
@@ -173,6 +177,7 @@
             epochs=epochs,
             batch_size=batch_size,
             ssl_pre_training=ssl_pre_training,
+            work_dir=work_dir,
             num_workers=num_workers,
             pretraining_type=pretraining_type,
             hyperparameters=hyperparameters,
@@ -191,7 +196,8 @@ def _run(
         epochs: int = 100,
         batch_size: int = 32,
         ssl_pre_training: bool = True,
-        num_workers: int = 48,
+        work_dir: Optional[str] = None,
+        num_workers: int = 24,
         pretraining_type: PretrainingType = PretrainingType.DINO,
         hyperparameters: dict = DINO_STANDARD_HYPERPARAMETERS,
         # embedding
@@ -209,6 +215,7 @@
             epochs=epochs,
             batch_size=batch_size,
             ssl_pre_training=ssl_pre_training,
+            work_dir=work_dir,
             hyperparameters=hyperparameters,
             num_workers=num_workers,
             additional_run_info=additional_run_info,
@@ -257,8 +264,9 @@ def train_dino(
         epochs: int = 100,
         batch_size: int = 32,
         ssl_pre_training: bool = True,
+        work_dir: Optional[str] = None,
         hyperparameters: dict = DINO_STANDARD_HYPERPARAMETERS,
-        num_workers: int = 48,
+        num_workers: int = 24,
         # logging
         additional_run_info: str = "",
         wandb_logging: bool = False,
@@ -268,24 +276,29 @@
             key in hyperparameters for key in DINO_STANDARD_HYPERPARAMETERS
         ), "`hyperparameters` need to contain all standard hyperparameters."
 
-        init_distributed_mode()
-
         hyperparameters["epochs"] = epochs
         hyperparameters["batch_size"] = batch_size
         hyperparameters["ssl_pre_training"] = ssl_pre_training
+        if work_dir is not None:
+            hyperparameters["work_dir"] = work_dir
 
+        init_distributed_mode()
         ssl_augmentation = iBOTDataAugmentation(
             **hyperparameters["dataset"]["augmentations"]
         )
         set_dataset_transformation(dataset=dataset, transform=ssl_augmentation)
-        sampler = DistributedSampler(dataset, shuffle=True)
+        if torch.cuda.is_available():
+            sampler = DistributedSampler(dataset, shuffle=True)
+            kwargs = {"sampler": sampler}
+        else:
+            kwargs = {"shuffle": True}
         train_loader = DataLoader(
             dataset,
             batch_size=batch_size,
-            sampler=sampler,
             num_workers=num_workers,
             drop_last=False,
             pin_memory=True,
+            **kwargs,
         )
         trainer = DINOTrainer(
             train_dataset=train_loader,
@@ -295,7 +308,7 @@
             wandb_project_name=wandb_project_name,
         )
         model = trainer.fit()
-        del trainer
+        del trainer, train_loader
         gc.collect()
         cleanup()
         return model
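The functional core of the Python change: `DistributedSampler` presupposes a distributed process group, which the CPU-only path does not have, so `train_dino` now picks the sampling strategy at runtime, while the new `work_dir` override is threaded through `run_on_image_folder` / `run_on_dataset` / `_run` into the trainer's hyperparameters. A condensed, self-contained sketch of the sampler switch (the random `TensorDataset` is a placeholder; in the repository the branch runs after `init_distributed_mode()`):

import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

dataset = TensorDataset(torch.randn(64, 3))  # placeholder dataset

# With CUDA (and an initialized process group), shard the data across
# processes; on CPU a plainly shuffled loader is enough. Building a
# kwargs dict avoids passing both `sampler` and `shuffle`, which
# DataLoader forbids.
if torch.cuda.is_available():
    kwargs = {"sampler": DistributedSampler(dataset, shuffle=True)}
else:
    kwargs = {"shuffle": True}

train_loader = DataLoader(dataset, batch_size=8, **kwargs)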
2 changes: 1 addition & 1 deletion src/ssl_library (submodule pointer update)
44 changes: 44 additions & 0 deletions tests/integration_tests/test_selfclean_IT.py
@@ -1,6 +1,7 @@
 import os
 import re
 import shutil
+import tempfile
 import unittest
 
 from torchvision.datasets import FakeData
@@ -19,11 +20,49 @@ def tearDownClass(cls):
             if re.search(pattern, dir_path):
                 shutil.rmtree(searchDir + dir_path)
 
+    def test_run_with_files_dino_in_workdir(self):
+        temp_work_dir = tempfile.TemporaryDirectory()
+        selfclean = SelfClean()
+        out_dict = selfclean.run_on_image_folder(
+            input_path=testfiles_path,
+            pretraining_type=PretrainingType.DINO,
+            work_dir=temp_work_dir.name,
+            epochs=1,
+            num_workers=4,
+        )
+        self.assertTrue("irrelevants" in out_dict)
+        self.assertTrue("near_duplicates" in out_dict)
+        self.assertTrue("label_errors" in out_dict)
+        for v in out_dict.values():
+            self.assertTrue("indices" in v)
+            self.assertTrue("scores" in v)
+            self.assertIsNotNone(v["indices"])
+            self.assertIsNotNone(v["scores"])
+
+    def test_run_with_files_dino_wo_pretraining(self):
+        selfclean = SelfClean()
+        out_dict = selfclean.run_on_image_folder(
+            input_path=testfiles_path,
+            pretraining_type=PretrainingType.DINO,
+            ssl_pre_training=False,
+            num_workers=4,
+        )
+        self.assertTrue("irrelevants" in out_dict)
+        self.assertTrue("near_duplicates" in out_dict)
+        self.assertTrue("label_errors" in out_dict)
+        for v in out_dict.values():
+            self.assertTrue("indices" in v)
+            self.assertTrue("scores" in v)
+            self.assertIsNotNone(v["indices"])
+            self.assertIsNotNone(v["scores"])
+
     def test_run_with_files_dino(self):
         selfclean = SelfClean()
         out_dict = selfclean.run_on_image_folder(
             input_path=testfiles_path,
             pretraining_type=PretrainingType.DINO,
+            epochs=1,
+            num_workers=4,
         )
         self.assertTrue("irrelevants" in out_dict)
         self.assertTrue("near_duplicates" in out_dict)
@@ -39,6 +78,7 @@ def test_run_with_files_imagenet(self):
         out_dict = selfclean.run_on_image_folder(
             input_path=testfiles_path,
             pretraining_type=PretrainingType.IMAGENET,
+            num_workers=4,
         )
         self.assertTrue("irrelevants" in out_dict)
         self.assertTrue("near_duplicates" in out_dict)
@@ -54,6 +94,8 @@ def test_run_with_files_imagenet_vit(self):
         out_dict = selfclean.run_on_image_folder(
             input_path=testfiles_path,
             pretraining_type=PretrainingType.IMAGENET_VIT,
+            epochs=1,
+            num_workers=4,
         )
         self.assertTrue("irrelevants" in out_dict)
         self.assertTrue("near_duplicates" in out_dict)
@@ -69,6 +111,8 @@ def test_run_with_dataset(self):
         selfclean = SelfClean()
         out_dict = selfclean.run_on_dataset(
             dataset=fake_dataset,
+            epochs=1,
+            num_workers=4,
        )
         self.assertTrue("irrelevants" in out_dict)
         self.assertTrue("near_duplicates" in out_dict)
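The new `test_run_with_files_dino_in_workdir` pins down the intended use of the added `work_dir` parameter: training artifacts go to a caller-chosen directory instead of the default location. A hedged usage sketch mirroring that test (the import locations and input path are assumptions; the diff only shows the call sites):

import tempfile

# Import paths are assumed, not taken from this diff.
from src.cleaner.selfclean import SelfClean, PretrainingType

with tempfile.TemporaryDirectory() as work_dir:
    out_dict = SelfClean().run_on_image_folder(
        input_path="path/to/image_folder",  # placeholder path
        pretraining_type=PretrainingType.DINO,
        work_dir=work_dir,  # checkpoints and logs land here
        epochs=1,
        num_workers=4,
    )
    # the same result structure the integration tests assert on
    print(out_dict["near_duplicates"]["indices"])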
