-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtask_vlcs.py
64 lines (58 loc) · 2.18 KB
/
task_vlcs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import os
from torchvision import transforms
from domainlab.tasks.task_folder_mk import mk_task_folder
from domainlab.tasks.utils_task import ImSize
# Image roots are resolved relative to this file's own directory instead of
# being hard-coded as absolute paths: this file is used by the tests, so it
# must work from wherever the repository is checked out.
path_this_file = os.path.dirname(os.path.realpath(__file__))

# Domains of the mini VLCS task, in the order they are registered.
_VLCS_DOMAINS = ("caltech", "sun", "labelme")


def _mk_img_trans():
    """Build the image pipeline shared by every domain and by the test split.

    Resize to 256x256, center-crop to 224x224, convert to a tensor, then
    normalize per channel (the mean/std values are the widely used ImageNet
    normalization constants). A new ``Compose`` is built on each call so that
    every domain gets its own pipeline object, as in the original spelled-out
    version.
    """
    return transforms.Compose(
        [
            transforms.Resize((256, 256)),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )


def _vlcs_mini_root(domain):
    """Return the image-root directory of ``domain`` in the bundled vlcs_mini data."""
    return os.path.join(path_this_file, f"../../domainlab/zdata/vlcs_mini/{domain}/")


chain = mk_task_folder(
    extensions={domain: "jpg" for domain in _VLCS_DOMAINS},
    list_str_y=["chair", "car"],
    # Per-domain mapping from sub-folder name to a class label in list_str_y;
    # each domain stores the same two classes under different folder names.
    dict_domain_folder_name2class={
        "caltech": {"auto": "car", "stuhl": "chair"},
        "sun": {"vehicle": "car", "sofa": "chair"},
        "labelme": {"drive": "car", "sit": "chair"},
    },
    dict_domain_img_trans={domain: _mk_img_trans() for domain in _VLCS_DOMAINS},
    img_trans_te=_mk_img_trans(),
    # Matches the 224x224 center crop produced by _mk_img_trans
    # (3 channels, presumably RGB — confirm against the data).
    isize=ImSize(3, 224, 224),
    dict_domain2imgroot={
        domain: _vlcs_mini_root(domain) for domain in _VLCS_DOMAINS
    },
    taskna="e_mini_vlcs",
)
def get_task(na=None):
    """Return the task object built at import time by ``mk_task_folder``.

    :param na: accepted but not used; the same ``chain`` object is returned
        for every value (presumably kept for a uniform ``get_task`` signature
        across task modules — confirm with callers).
    :return: the module-level ``chain``.
    """
    del na  # intentionally unused
    return chain