forked from Kashu7100/Qualia2.0
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathceleb_a.py
31 lines (27 loc) · 951 Bytes
/
celeb_a.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
# -*- coding: utf-8 -*-
from .. import to_cpu
from ..core import *
from .dataset import *
from .transforms import Compose, ToTensor, Normalize
import matplotlib.pyplot as plt
import os
import glob
class CelebA(Dataset):
    ''' CelebA Dataset\n
    Args:
        data_dir (str): the location of downloaded dataset
        train (bool): if True, load training dataset
        transforms (transforms): transforms to apply on the features. When not
            given, defaults to
            ``Compose([ToTensor(), Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])])``;
            pass None explicitly to apply no feature transforms.
        target_transforms (transforms): transforms to apply on the labels
    '''
    # Sentinel meaning "caller did not pass transforms". Building the default
    # Compose in the signature would evaluate it once at class-definition time
    # and share that single object across every CelebA instance; the sentinel
    # lets each instance get its own default while None keeps meaning
    # "no transforms".
    _DEFAULT_TRANSFORMS = object()

    def __init__(self, data_dir,
                 train=True,
                 transforms=_DEFAULT_TRANSFORMS,
                 target_transforms=None):
        if transforms is CelebA._DEFAULT_TRANSFORMS:
            transforms = Compose([ToTensor(), Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])])
        # Stored before super().__init__ so it is already available if the
        # base constructor calls prepare() (which reads it).
        self.data_dir = data_dir
        super().__init__(train, transforms, target_transforms)

    def state_dict(self):
        # Placeholder: this class does not load any data yet, so an empty
        # dict is returned.
        # NOTE(review): confirm what the Dataset base class expects here.
        return {}

    def prepare(self):
        '''Check that the dataset files are present in ``data_dir``.

        This class contains no download logic, so when ``data_dir`` is missing
        or contains no (non-hidden) entries a ``UserWarning`` is issued instead
        of silently continuing.
        '''
        import warnings
        # Escape the directory so glob metacharacters in it (e.g. '[') are
        # matched literally; glob returns [] for a non-existent directory, so
        # this single check covers both "missing" and "empty".
        pattern = os.path.join(glob.escape(os.fspath(self.data_dir)), '*')
        if not glob.glob(pattern):
            warnings.warn(
                "CelebA: no files found in '{}'. This class does not download "
                "the dataset; download CelebA manually and pass its location "
                "as data_dir.".format(self.data_dir))