-
Notifications
You must be signed in to change notification settings - Fork 33
/
Copy pathapp.py
127 lines (102 loc) · 3.88 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
from predict import Predictor, model_cfg
from PIL import Image
import gradio as gr
# Module-level state shared by the Gradio callbacks defined below.
predictor = None  # active Predictor instance; (re)built by set_predictor()
# Default class names. Each list item is ONE class name -- the same shape
# set_vocabulary() produces from the comma-separated "Object Names" textbox.
vocabulary: "list[str]" = ["bat man", "woman"]
input_image: "Image.Image | None" = None  # latest image from the input component (set_input)
outputs: "dict | None" = None  # result of predictor.predict, replayed as **kwargs by visualize()
cur_model_name: "str | None" = None  # model_cfg key the current predictor was built from
def set_vocabulary(text):
    """Parse the comma-separated "Object Names" textbox into the global vocabulary.

    Each comma-separated entry becomes one class name. Surrounding whitespace
    is stripped and empty entries are dropped. An empty (or None) textbox
    therefore yields an empty vocabulary, matching the textbox label, which
    says empty input falls back to the `Vocabulary Expansion` set.

    Args:
        text: Raw textbox contents, e.g. "bat man, woman".
    """
    global vocabulary
    # `text or ""` guards against the component sending None.
    vocabulary = [name.strip() for name in (text or "").split(",") if name.strip()]
    print("set vocabulary to", vocabulary)
def set_input(image):
    """Cache the picture from the input component for the next segmentation run.

    Wired to the input image component's change event; the value is stored
    unmodified in the module-level ``input_image``.
    """
    global input_image
    input_image = image
    print("set input image to", image)
def set_predictor(model_name: str):
    """Build the global Predictor for *model_name* unless it is already active.

    The predictor is constructed from the keyword arguments in
    ``model_cfg[model_name]``. ``cur_model_name`` is updated only after the
    constructor returns, so a failed load is retried on the next call.
    """
    global predictor, cur_model_name
    if cur_model_name != model_name:
        predictor = Predictor(**model_cfg[model_name])
        print("set predictor to", model_name)
        cur_model_name = model_name
# Eagerly build the predictor for the first entry of model_cfg at import time.
set_predictor(list(model_cfg)[0])
# for visualization
def visualize(vis_mode):
    """Render the cached prediction in *vis_mode*.

    Returns None when no prediction has been stored yet. Otherwise the stored
    ``outputs`` dict is forwarded as keyword arguments to
    ``predictor.visualize`` together with ``mode=vis_mode``.
    """
    if outputs is not None:
        return predictor.visualize(**outputs, mode=vis_mode)
    return None
def segment_image(vis_mode, voc_mode, model_name):
    """Segment the current input image and return its visualization.

    Wired to the Run button and to ``demo.load``.

    Args:
        vis_mode: Visualization mode forwarded to ``visualize`` (the UI offers
            "overlay" and "mask").
        voc_mode: Vocabulary Expansion choice, passed to ``predictor.predict``
            as ``augment_vocabulary`` (the UI offers "COCO-all"/"COCO-stuff").
        model_name: Key into ``model_cfg``; the global predictor is switched
            to this model first.

    Returns:
        The rendered prediction, or None when no input image is set.
    """
    global outputs
    # Switch models before the image check, so a model change is applied even
    # when there is nothing to segment yet (e.g. the demo.load call).
    set_predictor(model_name)
    if input_image is None:
        # Forget the previous prediction; otherwise a later visualization-mode
        # change would redraw a result for an image that is no longer set.
        outputs = None
        return None
    # Assign only after predict() returns, so a failure keeps the old result.
    outputs = predictor.predict(
        input_image, vocabulary=vocabulary, augment_vocabulary=voc_mode
    )
    return visualize(vis_mode)
def segment_e2e(image, vis_mode, voc_mode="COCO-all", model_name=None):
    """Set *image* as the input and segment it in a single call.

    Args:
        image: Picture to segment; stored via ``set_input``.
        vis_mode: Visualization mode forwarded to ``segment_image``.
        voc_mode: Vocabulary Expansion choice; defaults to "COCO-all", the
            same default as the UI dropdown.
        model_name: ``model_cfg`` key to use; None keeps the currently
            loaded model (``cur_model_name``).

    Returns:
        Whatever ``segment_image`` returns for these arguments.
    """
    set_input(image)
    if model_name is None:
        model_name = cur_model_name
    # segment_image requires all three settings positionally.
    return segment_image(vis_mode, voc_mode, model_name)
# gradio UI
with gr.Blocks(
    css="""
    #submit {background: #3498db; color: white; border: none; padding: 10px 20px; border-radius: 5px;width: 20%;margin: 0 auto; display: block;}
    """
) as demo:
    gr.Markdown(
        "<h1 style='text-align: center; margin-bottom: 1rem'>Side Adapter Network for Open-Vocabulary Semantic Segmentation</h1>"
    )
    # Markdown bodies start at column 0 so indentation is never rendered as a code block.
    gr.Markdown(
        """
This is the demo for our conference paper: "[Side Adapter Network for Open-Vocabulary Semantic Segmentation](https://arxiv.org/abs/2302.12242)".
"""
    )
    # gr.Image(type="pil", value="./resources/arch.png", shape=(460, 200), elem_id="arch")
    gr.Markdown(
        """
---
"""
    )
    with gr.Row():
        image = gr.Image(type="pil", elem_id="input_image")
        plt = gr.Image(type="pil", elem_id="output_image")
    with gr.Row():
        model_name = gr.Dropdown(
            list(model_cfg.keys()), label="Model", value="san_vit_b_16"
        )
        augment_vocabulary = gr.Dropdown(
            ["COCO-all", "COCO-stuff"],
            label="Vocabulary Expansion",
            value="COCO-all",
        )
        vis_mode = gr.Dropdown(
            ["overlay", "mask"], label="Visualization Mode", value="overlay"
        )
    object_names = gr.Textbox(
        value=",".join(vocabulary),
        label="Object Names (Empty inputs will use the vocabulary specified in `Vocabulary Expansion`. Multiple names should be separated with ,.)",
        lines=5,
    )
    button = gr.Button("Run", elem_id="submit")
    note = gr.Markdown(
        """
---
### FAQ
- **Q**: What is the `Vocabulary Expansion` option for?
  **A**: The vocabulary expansion option is used to expand the vocabulary of the model. The model assigns a category to each area with `argmax`, so when only a vocabulary with a few thing classes is provided, it produces many false positives.
- **Q**: Error: `Unexpected token '<', " <h"... is not valid JSON.`. What should I do?
  **A**: It is caused by a timeout error. Possibly your image is too large for a CPU server. Please try to use a smaller image or run it locally on a GPU server.
"""
    )

    # Event wiring: widget changes update the module-level state; Run (and the
    # initial page load) runs segmentation; a visualization-mode change only
    # re-renders the cached prediction.
    # NOTE(review): queue=False bypasses the queue enabled by demo.queue(); long
    # CPU inference may then hit the HTTP timeout described in the FAQ --
    # consider queueing the Run/load events.
    object_names.change(set_vocabulary, [object_names], queue=False)
    image.change(set_input, [image], queue=False)
    vis_mode.change(visualize, [vis_mode], plt, queue=False)
    button.click(
        segment_image, [vis_mode, augment_vocabulary, model_name], plt, queue=False
    )
    demo.load(
        segment_image, [vis_mode, augment_vocabulary, model_name], plt, queue=False
    )

demo.queue().launch()