From a9e9b9c8f6c5c6bf3c1fb2158d4989d342dda1e6 Mon Sep 17 00:00:00 2001 From: brando90 Date: Sat, 5 Feb 2022 10:39:25 -0600 Subject: [PATCH 1/3] added the rfs transforms for cifarfs (and fc100) added the rfs transforms for cifarfs (and fc100) --- .../vision/benchmarks/cifarfs_benchmark.py | 49 ++++++++++++++++++- 1 file changed, 48 insertions(+), 1 deletion(-) diff --git a/learn2learn/vision/benchmarks/cifarfs_benchmark.py b/learn2learn/vision/benchmarks/cifarfs_benchmark.py index 886cb679..41b830ad 100644 --- a/learn2learn/vision/benchmarks/cifarfs_benchmark.py +++ b/learn2learn/vision/benchmarks/cifarfs_benchmark.py @@ -12,11 +12,58 @@ def cifarfs_tasksets( test_ways=5, test_samples=10, root='~/data', + data_augmentation=None, device=None, **kwargs, ): """Tasksets for CIFAR-FS benchmarks.""" - data_transform = tv.transforms.ToTensor() + if data_augmentation is None: + train_data_transforms = tv.transforms.ToTensor() + test_data_transforms = tv.transforms.ToTensor() + elif data_augmentation == 'normalize': + train_data_transforms = Compose([ + lambda x: x / 255.0, + ]) + test_data_transforms = train_data_transforms + elif data_augmentation == 'rfs2020': + """ + # original + if augment: + transform = transforms.Compose([ + lambda x: Image.fromarray(x), + transforms.RandomCrop(32, padding=4), + transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4), + transforms.RandomHorizontalFlip(), + lambda x: np.asarray(x), + transforms.ToTensor(), + normalize_cifar100 + ]) + else: + transform = transforms.Compose([ + lambda x: Image.fromarray(x), + transforms.ToTensor(), + normalize_cifar100 + ]) + return transform + """ + mean = [0.5071, 0.4867, 0.4408] + std = [0.2675, 0.2565, 0.2761] + normalize = tv.transforms.Normalize(mean=mean, std=std) + train_data_transforms = Compose([ + ToPILImage(), + RandomCrop(32, padding=4), + ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4), + RandomHorizontalFlip(), + ToTensor(), + normalize, + ]) + test_data_transforms = Compose([ + ToTensor(), + normalize, + ]) + else: + raise('Invalid data_augmentation argument.') + train_dataset = l2l.vision.datasets.CIFARFS(root=root, transform=data_transform, mode='train', From 6b01dea4c5873a4e2ce8a88309baba5c076cf067 Mon Sep 17 00:00:00 2001 From: brando90 Date: Sat, 5 Feb 2022 12:12:03 -0600 Subject: [PATCH 2/3] Update cifarfs_benchmark.py --- learn2learn/vision/benchmarks/cifarfs_benchmark.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/learn2learn/vision/benchmarks/cifarfs_benchmark.py b/learn2learn/vision/benchmarks/cifarfs_benchmark.py index 41b830ad..5fbc2fc7 100644 --- a/learn2learn/vision/benchmarks/cifarfs_benchmark.py +++ b/learn2learn/vision/benchmarks/cifarfs_benchmark.py @@ -65,15 +65,15 @@ def cifarfs_tasksets( raise('Invalid data_augmentation argument.') train_dataset = l2l.vision.datasets.CIFARFS(root=root, - transform=data_transform, + transform=train_data_transforms, mode='train', download=True) valid_dataset = l2l.vision.datasets.CIFARFS(root=root, - transform=data_transform, + transform=train_data_transforms, mode='validation', download=True) test_dataset = l2l.vision.datasets.CIFARFS(root=root, - transform=data_transform, + transform=test_data_transforms, mode='test', download=True) if device is not None: From cbafcd062a48fc57f09217df1fee2bee5d717989 Mon Sep 17 00:00:00 2001 From: brando90 Date: Sat, 5 Feb 2022 12:46:34 -0600 Subject: [PATCH 3/3] forgot from torchvision.transforms import Compose --- learn2learn/vision/benchmarks/cifarfs_benchmark.py | 1 + 1 file changed, 1 insertion(+) diff --git a/learn2learn/vision/benchmarks/cifarfs_benchmark.py b/learn2learn/vision/benchmarks/cifarfs_benchmark.py index 5fbc2fc7..b03e27f6 100644 --- a/learn2learn/vision/benchmarks/cifarfs_benchmark.py +++ b/learn2learn/vision/benchmarks/cifarfs_benchmark.py @@ -5,6 +5,7 @@ from learn2learn.data.transforms import NWays, KShots, LoadData, RemapLabels, ConsecutiveLabels +from torchvision.transforms import Compose def cifarfs_tasksets( train_ways=5,