-
Notifications
You must be signed in to change notification settings - Fork 1
/
live.py
65 lines (55 loc) · 2 KB
/
live.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import torch
import gradio as gr
from transformers import CLIPProcessor, CLIPModel
import shlex
# The model and its processor are loaded from the same checkpoint name so
# that the preprocessing matches the weights.
_CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"

model = CLIPModel.from_pretrained(_CLIP_CHECKPOINT)
processor = CLIPProcessor.from_pretrained(_CLIP_CHECKPOINT)
def similarity(image, text, threshold, order):
    """Score an image against each line of text with CLIP.

    Args:
        image: The image to score; it is passed straight to the CLIP
            processor. ``None`` (presumably what arrives before any frame
            has been captured -- confirm against the UI) yields empty
            outputs instead of an error.
        text: Candidate descriptions, one per line. Blank and
            whitespace-only lines are ignored.
        threshold: A description is marked "yes" when its score is strictly
            greater than this value, otherwise "no".
        order: If true, results are sorted from highest to lowest score;
            otherwise they keep the order in which the lines were entered.

    Returns:
        A ``(detections, clicmd)`` pair. ``detections`` is a list of
        ``("<line>: <score>", "yes" | "no")`` tuples and ``clicmd`` is a
        shell-quoted ``pluginctl run`` command line carrying the same
        threshold and descriptions. Both elements are ``""`` when there is
        no image or no non-blank description to score.
    """
    # Whitespace-only lines carry no description: scoring them is
    # meaningless and they would show up as empty '' arguments in the
    # generated command, so drop them up front.
    lines = [line for line in (text or "").splitlines() if line.strip()]

    # Without an image the model call below would fail for lack of pixel
    # values; treat it like the "nothing to score" case.
    if image is None or not lines:
        return "", ""

    inputs = processor(text=lines, images=image, return_tensors="pt", padding=True)
    with torch.no_grad():
        outputs = model(**inputs)

    # Flatten the image-vs-text logits into plain Python floats for display.
    scores = outputs.logits_per_image.view(-1).tolist()

    scored = zip(scores, lines)
    if order:
        # Tuples sort by score first; ties fall back to the text itself.
        scored = sorted(scored, reverse=True)

    detections = [
        (f"{line}: {score:0.2f}", "yes" if score > threshold else "no")
        for score, line in scored
    ]
    # TODO add better indication of detection

    # shlex.join quotes every argument, so descriptions containing spaces
    # or shell metacharacters remain single arguments when pasted.
    clicmd = shlex.join([
        "pluginctl",
        "run",
        "--name", "clip-app",
        "waggle/clip-app:0.11.0",
        "--",
        "--input=bottom",
        f"--threshold={threshold}",
        *lines,
    ])
    return detections, clicmd
# Input components, listed in the same order as the parameters of
# `similarity` (image, text, threshold, order).
_inputs = [
    gr.Image(label="Webcam", source="webcam", streaming=True),
    gr.TextArea(label="Text descriptions"),
    gr.Slider(0, 40, 26, label="Similarity threshold"),
    gr.Checkbox(value=True, label="Order by similarity score?"),
]

# Colours for the "yes" / "no" labels that `similarity` attaches to each
# scored line.
_score_colors = {
    "yes": "green",
    "no": "red",
}

# Output components, matching the (detections, clicmd) pair returned by
# `similarity`.
# A plain-text alternative that was used previously:
#   [gr.TextArea(label="Image-text similarity scores"), gr.TextArea(label="Pluginctl command")]
_outputs = [
    gr.HighlightedText(label="Image-text similarity scores", color_map=_score_colors),
    gr.TextArea(label="Pluginctl command"),
]

demo = gr.Interface(
    fn=similarity,
    inputs=_inputs,
    outputs=_outputs,
    title="CLIP Explorer",
    description="Input an image and lines of text then press submit to output the image-text similarity scores.",
    live=True,
)


# Only start the web UI when this file is run as a script, not on import.
if __name__ == "__main__":
    demo.launch()