diff --git a/src/cleaner/selfclean.py b/src/cleaner/selfclean.py index 05e0cd4..e6a3c50 100644 --- a/src/cleaner/selfclean.py +++ b/src/cleaner/selfclean.py @@ -13,7 +13,7 @@ from torchvision.transforms import InterpolationMode from ..cleaner.selfclean_cleaner import SelfCleanCleaner -from ..ssl_library.src.augmentations.ibot import iBOTDataAugmentation +from ..ssl_library.src.augmentations.multi_crop import MultiCropAugmentation from ..ssl_library.src.pkg import Embedder, embed_dataset from ..ssl_library.src.trainers.dino_trainer import DINOTrainer from ..ssl_library.src.utils.logging import set_log_level @@ -37,17 +37,15 @@ "model": { "out_dim": 4096, "emb_dim": 192, - "base_model": "vit_tiny", + "base_model": "pretrained_imagenet_vit_tiny", "model_type": "VIT", "use_bn_in_head": False, "norm_last_layer": True, "student": { "drop_path_rate": 0.1, - "pretrained": True, }, "teacher": { "drop_path_rate": 0.1, - "pretrained": True, }, "eval": {"n_last_blocks": 4, "avgpool_patchtokens": False}, }, @@ -57,7 +55,7 @@ "local_crops_scale": "(0.05, 0.4)", "global_crops_number": 2, "local_crops_number": 12, - "random_rotation": True, + "apply_random_rotation": True, } }, "loss": { @@ -305,7 +303,7 @@ def train_dino( hyperparameters["work_dir"] = work_dir init_distributed_mode() - ssl_augmentation = iBOTDataAugmentation( + ssl_augmentation = MultiCropAugmentation( **hyperparameters["dataset"]["augmentations"] ) set_dataset_transformation(dataset=dataset, transform=ssl_augmentation) diff --git a/src/ssl_library b/src/ssl_library index 40f16de..7637a5d 160000 --- a/src/ssl_library +++ b/src/ssl_library @@ -1 +1 @@ -Subproject commit 40f16de553d018a9f47ab24cc8b32200bd029ee8 +Subproject commit 7637a5df885fa1a983868c4c96ad53e140a598df