-
Notifications
You must be signed in to change notification settings - Fork 0
/
simple_inference.py
56 lines (48 loc) · 2.18 KB
/
simple_inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import os
# Must be set before `transformers` is imported: silences the HuggingFace
# tokenizers fork/parallelism warning.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import argparse
from glob import glob
import torch
torch.manual_seed(42)  # fixed seed for reproducible model behavior
from PIL import Image
import numpy as np
np.random.seed(42)  # fixed seed for reproducible numpy randomness
from typing import List
from model_swin import SalFormer
from transformers import AutoImageProcessor, AutoTokenizer, BertModel, SwinModel
# Inference device; requires a CUDA-capable GPU. TODO(review): consider
# falling back to 'cpu' when CUDA is unavailable.
DEVICE = 'cuda'
def predict(ques: str, img_path: str) -> List:
    """Run VisSalFormer on one image/question pair.

    Uses the module-level ``image_processor``, ``tokenizer`` and ``model``
    (initialized in the ``__main__`` block).

    Args:
        ques: question string to feed into VisSalFormer.
        img_path: path to the input image file.

    Returns:
        A 3-element list:
        - heatmap: np.uint8 array of shape (H, W), the saliency mask
          rescaled to the original image size (values 0-255)
        - image: the original image as an RGB PIL.Image
        - grey: the greyscale image as an np.array
    """
    image = Image.open(img_path).convert("RGB")
    img_pt = image_processor(image, return_tensors="pt").to(DEVICE)
    inputs = tokenizer(ques, return_tensors="pt").to(DEVICE)
    # Inference only — skip autograd graph construction.
    with torch.no_grad():
        mask = model(img_pt['pixel_values'], inputs)
    mask = mask.cpu().squeeze().numpy()
    heatmap = (mask * 255).astype(np.uint8)
    im_grey = image.convert('L')
    # BUGFIX: np.resize tiles/truncates the flattened buffer instead of
    # spatially rescaling, which scrambles the heatmap whenever the model's
    # output resolution differs from the image size. Use an interpolating
    # image resize instead. PIL's size argument is (width, height).
    heatmap = np.array(
        Image.fromarray(heatmap).resize(image.size, Image.BILINEAR)
    )
    return [heatmap, image, np.array(im_grey)]
if __name__ == '__main__':
    # CLI: image path, free-text query, and a tag used in the output filename.
    parser = argparse.ArgumentParser()
    parser.add_argument("--img_path", type=str, default="/netpool/homes/wangyo/Projects/chi2025_scanpath/evaluation/images/economist_daily_chart_85.png")
    parser.add_argument("--query", type=str, default="type your query")
    parser.add_argument("--type", type=str, default="A")
    args = vars(parser.parse_args())
    print(args["query"])
    # Backbones: Swin-Tiny vision encoder + BERT-base text encoder.
    image_processor = AutoImageProcessor.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
    vit = SwinModel.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
    tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
    bert = BertModel.from_pretrained("bert-base-uncased")
    model = SalFormer(vit, bert).to(DEVICE)
    # map_location keeps loading robust even if the checkpoint was saved
    # from a different device.
    checkpoint = torch.load('/netpool/homes/wangyo/Projects/vega_editor_backend/model/model_lr6e-5_wd1e-4.tar', map_location=DEVICE)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    predictions = predict(args['query'], args['img_path'])
    # BUGFIX: the original used basename.strip(".png"), but str.strip removes
    # any of the characters '.', 'p', 'n', 'g' from BOTH ends (e.g.
    # "graph.png" -> "raph"); splitext drops only the extension.
    stem = os.path.splitext(os.path.basename(args["img_path"]))[0]
    # Ensure the output directory exists so np.save cannot fail on it.
    os.makedirs('predictions', exist_ok=True)
    # Persist the heatmap normalized to [0, 1].
    np.save(f'predictions/{stem}_{args["type"]}.npy', predictions[0] / 255.)