-
Notifications
You must be signed in to change notification settings - Fork 1
/
predict.py
155 lines (134 loc) · 5.93 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import torch
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import tokenizer_image_token
from transformers.generation.streamers import TextIteratorStreamer
from PIL import Image
import requests
from io import BytesIO
from cog import BasePredictor, Input, Path, ConcatenateIterator
import time
import subprocess
from threading import Thread
import os
# Keep Hugging Face hub downloads under ./weights (relative to the working
# directory the predictor is started from).
# NOTE(review): huggingface_hub presumably reads this variable when it is
# first imported, and the transformers/llava imports above run before this
# assignment — confirm the cache location actually takes effect, or move this
# line above those imports.
os.environ["HUGGINGFACE_HUB_CACHE"] = f"{os.getcwd()}/weights"

# Base URL of the Replicate weights mirror; an entry's "src" path is appended.
REPLICATE_WEIGHTS_URL = "https://weights.replicate.delivery/default"

# Model files to pull from the mirror. Each entry names the local directory
# ("dest"), optionally the path on the mirror ("src"), and the file names.
# NOTE(review): the AVG-LLaVA entry carries no "src" key (its mirror path is
# commented out), so consumers must not assume every entry has one. The file
# list appears to have been copied from llava-v1.5-13b — confirm it matches
# the AVG-LLaVA repository.
weights = [
    {
        "dest": "zhibinlan/AVG-LLaVA",
        # Disabled mirror path (a Hugging Face git commit hash):
        # "src": "llava-v1.5-13b/006818fc465ebda4c003c0998674d9141d8d95f8",
        "files": [
            "config.json",
            "generation_config.json",
            "pytorch_model-00001-of-00003.bin",
            "pytorch_model-00002-of-00003.bin",
            "pytorch_model-00003-of-00003.bin",
            "pytorch_model.bin.index.json",
            "special_tokens_map.json",
            "tokenizer.model",
            "tokenizer_config.json",
        ],
    },
    {
        "dest": "openai/clip-vit-large-patch14-336",
        "src": "clip-vit-large-patch14-336/ce19dc912ca5cd21c8a653c79e251e808ccabcd1",
        "files": [
            "config.json",
            "preprocessor_config.json",
            "pytorch_model.bin",
        ],
    },
]
def download_json(url: str, dest: Path, timeout: float = 60.0):
    """Download a small (JSON) file from *url* and write its bytes to *dest*.

    A response that is not HTTP 200, or that has an empty body, is reported
    with a printed message and nothing is written (best-effort, as before).
    Network-level failures, including a timeout, propagate as ``requests``
    exceptions.

    Args:
        url: Absolute URL of the file on the weights mirror.
        dest: Local file path to write; its parent directory must exist.
        timeout: Seconds to wait for the server. Without a timeout a stalled
            mirror would hang model setup indefinitely.
    """
    res = requests.get(url, allow_redirects=True, timeout=timeout)
    if res.status_code == 200 and res.content:
        dest.write_bytes(res.content)
    else:
        print(f"Failed to download {url}. Status code: {res.status_code}")
def download_weights(baseurl: str, basedest: str, files: list[str]):
    """Fetch *files* from the weights mirror into the local directory *basedest*.

    Files that already exist locally are skipped, so repeated setup runs only
    download what is missing. ``.json`` files go through ``download_json``;
    all other files are fetched with the ``pget`` command-line tool.

    Args:
        baseurl: Path of the model on the mirror, relative to
            ``REPLICATE_WEIGHTS_URL``.
        basedest: Local directory to download into; created if missing.
        files: File names, relative to both ``baseurl`` and ``basedest``.

    Raises:
        subprocess.CalledProcessError: if ``pget`` exits with a non-zero status.
    """
    basedest = Path(basedest)
    start = time.time()
    print("downloading to: ", basedest)
    basedest.mkdir(parents=True, exist_ok=True)
    for f in files:
        dest = basedest / f
        if dest.exists():
            continue
        # Join URL parts with "/" explicitly: os.path.join uses the OS path
        # separator and drops earlier parts when a later one starts with "/".
        url = "/".join(
            (REPLICATE_WEIGHTS_URL.rstrip("/"), baseurl.strip("/"), f.lstrip("/"))
        )
        print("downloading url: ", url)
        # File names may include subdirectories; make sure they exist first.
        dest.parent.mkdir(parents=True, exist_ok=True)
        if dest.suffix == ".json":
            download_json(url, dest)
        else:
            subprocess.check_call(["pget", url, str(dest)], close_fds=False)
    print("downloading took: ", time.time() - start)
class Predictor(BasePredictor):
    """Cog predictor that streams AVG-LLaVA answers about a single image."""

    def setup(self) -> None:
        """Download mirrored weights and load the model into memory.

        Runs once per container so individual predictions do not pay the
        model-loading cost.
        """
        for weight in weights:
            src = weight.get("src")
            if src is None:
                # No mirror path configured for this entry. Indexing
                # weight["src"] used to abort setup with a KeyError here.
                # Presumably load_pretrained_model fetches this model by name
                # from the Hugging Face Hub instead — confirm.
                continue
            download_weights(src, weight["dest"], weight["files"])
        disable_torch_init()
        (
            self.tokenizer,
            self.model,
            self.image_processor,
            self.context_len,
        ) = load_pretrained_model(
            "zhibinlan/AVG-LLaVA",
            model_name="zhibinlan/AVG-LLaVA",
            model_base=None,
            load_8bit=False,
            load_4bit=False,
        )

    def predict(
        self,
        image: Path = Input(description="Input image"),
        prompt: str = Input(description="Prompt to use for text generation"),
        top_p: float = Input(description="When decoding text, samples from the top p percentage of most likely tokens; lower to ignore less likely tokens", ge=0.0, le=1.0, default=1.0),
        temperature: float = Input(description="Adjusts randomness of outputs, greater than 1 is random and 0 is deterministic", default=0.2, ge=0.0),
        max_tokens: int = Input(description="Maximum number of tokens to generate. A word is generally 2-3 tokens", default=1024, ge=0),
    ) -> ConcatenateIterator[str]:
        """Stream the model's answer to *prompt* about *image*, chunk by chunk.

        A temperature of 0 selects greedy (deterministic) decoding. Any
        positive temperature samples with the given temperature and top_p.
        An exception raised by generation is re-raised here after the chunks
        produced so far have been yielded.
        """
        conv = conv_templates["llava_v1"].copy()
        image_data = load_image(str(image))
        image_tensor = (
            self.image_processor.preprocess(image_data, return_tensors="pt")["pixel_values"]
            .half()
            .cuda()
        )

        # Single-turn conversation: always prepend the image placeholder token.
        conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\n" + prompt)
        conv.append_message(conv.roles[1], None)
        full_prompt = conv.get_prompt()
        input_ids = (
            tokenizer_image_token(full_prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
            .unsqueeze(0)
            .cuda()
        )
        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2

        streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, timeout=20.0)

        # temperature == 0 is advertised as deterministic, but transformers
        # rejects sampling with a non-positive temperature. Fall back to
        # greedy decoding, as LLaVA's own CLI does.
        do_sample = temperature > 0
        gen_kwargs = dict(
            inputs=input_ids,
            images=image_tensor,
            do_sample=do_sample,
            max_new_tokens=max_tokens,
            streamer=streamer,
            use_cache=True,
        )
        if do_sample:
            gen_kwargs.update(temperature=temperature, top_p=top_p)

        errors = []

        def _generate():
            # Grad mode is thread-local, so inference_mode has to be entered in
            # the worker thread that actually runs generate().
            try:
                with torch.inference_mode():
                    self.model.generate(**gen_kwargs)
            except Exception as err:
                # Record the failure and end the stream so the consumer loop
                # stops promptly instead of waiting out the streamer timeout.
                errors.append(err)
                streamer.end()

        thread = Thread(target=_generate)
        thread.start()

        # Workaround: the second-to-last chunk is always " ". Hold it back and
        # only emit it if more text follows that is not the stop string.
        prepend_space = False
        for new_text in streamer:
            if new_text == " ":
                prepend_space = True
                continue
            if new_text.endswith(stop_str):
                new_text = new_text[:-len(stop_str)].strip()
                prepend_space = False
            elif prepend_space:
                new_text = " " + new_text
                prepend_space = False
            if len(new_text):
                yield new_text
        if prepend_space:
            yield " "
        thread.join()
        if errors:
            raise errors[0]
def load_image(image_file, timeout: float = 60.0):
    """Load an image from a local path or an http(s) URL, converted to RGB.

    Args:
        image_file: Local file path, or a URL beginning with ``http://`` or
            ``https://``. Only a real scheme prefix counts as a URL, so a
            local path such as ``"http_cache/cat.png"`` is read from disk.
        timeout: Seconds to wait for a remote server (URLs only).

    Returns:
        The image in RGB mode.

    Raises:
        requests.HTTPError: if the URL answers with an error status.
    """
    if image_file.startswith(("http://", "https://")):
        response = requests.get(image_file, timeout=timeout)
        # Report the HTTP error directly instead of handing an error page to PIL.
        response.raise_for_status()
        return Image.open(BytesIO(response.content)).convert("RGB")
    # convert() returns a new image, so the source file can be closed at once.
    with Image.open(image_file) as img:
        return img.convert("RGB")