-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathview_crops.py
138 lines (104 loc) · 3.75 KB
/
view_crops.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import streamlit as st
from copy import deepcopy as dc
import pandas as pd
from pathlib import Path
from PIL import Image
import numpy as np
from src.label_checker_automata import LabelCheckerAutomata
from src.armoria_api import ArmoriaAPIPayload, ArmoriaAPIWrapper
from src.caption import Caption
# When True, only a small random sample of the images is processed.
DEBUG = False


# NOTE(review): recent Streamlit releases deprecate ``st.cache`` in favour of
# ``st.cache_data`` -- confirm the pinned Streamlit version before switching.
@st.cache
def load_data():
    """Load the image thumbnails, their captions and the Armoria renderings.

    Every non-hidden ``.jpg`` file in ``data/new/images`` is opened, shrunk in
    place to fit within 150x150 pixels and converted to a numpy array. The
    caption is the file stem with everything up to and including the first
    underscore removed.

    For each caption that ``Caption`` reports as valid, an image is requested
    through ``ArmoriaAPIWrapper``. Invalid captions, and requests that raise
    ``ValueError``, get ``None`` instead.

    Returns:
        A ``pandas.DataFrame`` with the columns ``image``, ``caption`` and
        ``generated_image``.
    """
    # Other directories previously used here: "data/cropped_coas/out/images"
    # and "data/cropped_coas/out_valid".
    data_dir = Path("data/new/images")

    images, captions = [], []
    for image_fn in data_dir.iterdir():
        if image_fn.suffix != ".jpg" or image_fn.name.startswith("."):
            continue
        # The context manager releases the file handle deterministically.
        with Image.open(image_fn) as image:
            image.thumbnail((150, 150))
            images.append(np.asarray(image))
        # Drop the first underscore-separated token of the stem.
        captions.append(image_fn.stem.partition("_")[2])

    df = pd.DataFrame.from_dict({
        "image": images,
        "caption": captions,
    })

    if DEBUG:
        # Clamp the sample size so small data sets do not raise ValueError.
        df = pd.DataFrame(df.sample(n=min(10, len(df))))

    generated_images = []
    for caption_text in df.caption:
        caption = Caption(caption_text, support_plural=False)
        if not caption.is_valid:
            generated_images.append(None)
            continue
        try:
            armoria_payload = caption.get_armoria_payload_dict()
            generated_image = np.asarray(ArmoriaAPIWrapper(
                size=150,
                format="png",
                coa=armoria_payload,
            ).get_image_bytes())
        except ValueError:
            # Best effort: a caption that cannot be rendered simply has no
            # generated image.
            generated_image = None
        generated_images.append(generated_image)

    df["generated_image"] = generated_images
    return df
# Deep-copy the cached result before working with it.
df = dc(load_data())

n_images = len(df)
has_generated = df.generated_image.notnull()

with st.container():
    st.write(f"""## Data description
* {n_images} images
* {has_generated.sum()} images with armoria generated image"""
)

# Only rows that have a generated image are kept from here on.
df = pd.DataFrame(df[has_generated])
@st.cache
def get_caption_data(df):
    """Parse and align the caption of every row in *df*.

    Each caption is passed to ``LabelCheckerAutomata.parse_label``; the
    caption and its parsed label are then handed to ``align_parsed_label``.

    Returns:
        A ``pandas.DataFrame`` constructed from the list of aligned results,
        in the row order of *df*.
    """
    checker = LabelCheckerAutomata()

    def _align(caption):
        # Parse first, then align the parsed label against the raw caption.
        return checker.align_parsed_label(caption, checker.parse_label(caption))

    return pd.DataFrame([_align(row.caption) for _, row in df.iterrows()])
# Upper bound on the number of samples rendered on one page.
MAX_SAMPLES = 100

# Reset to a 0..n-1 index so df lines up row by row with caption_data, which
# is built from a plain list and therefore carries a default index.
df.reset_index(inplace=True)
caption_data = get_caption_data(df)
df = pd.concat([df, caption_data], axis=1)

# Distinct values across all captions, used as options and default selection.
all_colors = sorted({it for ll in caption_data.colors for it in ll})
all_objects = sorted({it for ll in caption_data.objects for it in ll})
all_modifiers = sorted({it for ll in caption_data.modifiers for it in ll})

selected_colors = st.sidebar.multiselect(
    'colors',
    all_colors,
    all_colors
)
selected_objects = st.sidebar.multiselect(
    'objects',
    all_objects,
    all_objects
)
selected_modifiers = st.sidebar.multiselect(
    'modifiers',
    all_modifiers,
    all_modifiers
)

# Build the selection sets once instead of once per row.
selected_color_set = set(selected_colors)
selected_object_set = set(selected_objects)
selected_modifier_set = set(selected_modifiers)

# A row is shown when all of its colors are selected AND it shares at least
# one object AND at least one modifier with the selection.
# NOTE(review): rows whose objects or modifiers list is empty can never pass
# the "at least one" tests -- confirm that this is intended.
row_mask = np.logical_and.reduce([
    df.colors.apply(lambda color_list: set(color_list) <= selected_color_set),
    df.objects.apply(lambda object_list: not selected_object_set.isdisjoint(object_list)),
    df.modifiers.apply(lambda modifier_list: not selected_modifier_set.isdisjoint(modifier_list)),
])
view = df[row_mask]

# Down-sample only when there are more rows than the limit. The previous
# guard (> 10) combined with sample(n=100) raised ValueError for 11-99 rows.
if len(view) > MAX_SAMPLES:
    len_current_filter_set = len(view)
    view = view.sample(n=MAX_SAMPLES)
    st.write(f"## !reduce result set to maximum of {MAX_SAMPLES} samples!")
    st.write(f"The current filter set actually returned {len_current_filter_set} samples")

# One container per sample: source image on the left, generated image on the right.
for _, row in view.iterrows():
    with st.container():
        st.write(f"## {row.caption}")
        col1, col2 = st.columns(2)
        col1.image(row.image)
        if row.generated_image is not None:
            col2.image(row.generated_image)