How can an image be fed into the model to produce its scene graph and bounding boxes for visualization? #7
Comments
Hi :). I was also deeply impressed by the EGTR paper, and I am facing the same issue. As I am not very knowledgeable in the SGG domain, I only understand that it involves predicting relationships between objects. I need to perform SGG as a preliminary step so that I can use scene graph embeddings for downstream tasks. I would like to use EGTR as the SGG module, but could you please let me know how the output scene graph is generated?
First of all, I have not tried to extract the scene graph itself beyond computing the evaluation metrics.

```python
from glob import glob

import torch
from PIL import Image

from model.deformable_detr import DeformableDetrConfig, DeformableDetrFeatureExtractor
from model.egtr import DetrForSceneGraphGeneration

# config
architecture = "SenseTime/deformable-detr"
min_size = 800
max_size = 1333
artifact_path = YOUR_ARTIFACT_PATH

# feature extractor
feature_extractor = DeformableDetrFeatureExtractor.from_pretrained(
    architecture, size=min_size, max_size=max_size
)

# inference image
image = Image.open(YOUR_IMAGE_PATH)
image = feature_extractor(image, return_tensors="pt")

# model
config = DeformableDetrConfig.from_pretrained(artifact_path)
model = DetrForSceneGraphGeneration.from_pretrained(
    architecture, config=config, ignore_mismatched_sizes=True
)
ckpt_path = sorted(
    glob(f"{artifact_path}/checkpoints/epoch=*.ckpt"),
    key=lambda x: int(x.split("epoch=")[1].split("-")[0]),
)[-1]
state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
for k in list(state_dict.keys()):
    state_dict[k[6:]] = state_dict.pop(k)  # strip the leading "model." prefix from checkpoint keys
model.load_state_dict(state_dict)
model.cuda()
model.eval()

# output
outputs = model(
    pixel_values=image["pixel_values"].cuda(),
    pixel_mask=image["pixel_mask"].cuda(),
    output_attention_states=True,
)

pred_logits = outputs["logits"][0]
obj_scores, pred_classes = torch.max(pred_logits.softmax(-1), -1)
pred_boxes = outputs["pred_boxes"][0]

pred_connectivity = outputs["pred_connectivity"][0]
pred_rel = outputs["pred_rel"][0]
pred_rel = torch.mul(pred_rel, pred_connectivity)

# get valid objects and triplets
obj_threshold = YOUR_OBJ_THRESHOLD
valid_obj_indices = (obj_scores >= obj_threshold).nonzero()[:, 0]

valid_obj_classes = pred_classes[valid_obj_indices]  # [num_valid_objects]
valid_obj_boxes = pred_boxes[valid_obj_indices]  # [num_valid_objects, 4]

rel_threshold = YOUR_REL_THRESHOLD
valid_triplets = (pred_rel[valid_obj_indices][:, valid_obj_indices] >= rel_threshold).nonzero()  # [num_valid_triplets, 3]
```

You can generate a scene graph based on valid_obj_classes, valid_obj_boxes, and valid_triplets. I built the scene graph using thresholds in this example, but it can also be done by selecting the top k objects or triplets.
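For reference, a minimal sketch (not from the repository) of turning those tensors into readable triplets; it assumes id2label and id2relation mappings like the ones discussed later in this thread:

```python
# Hedged sketch: map the valid objects/triplets above to (subject, predicate, object) strings.
# Assumes id2label (object class index -> name) and id2relation (relation index -> name) exist.
scene_graph = []
for sub_i, obj_i, rel_i in valid_triplets.tolist():
    subject = id2label[valid_obj_classes[sub_i].item()]
    predicate = id2relation[rel_i]
    obj = id2label[valid_obj_classes[obj_i].item()]
    scene_graph.append((subject, predicate, obj))

for subject, predicate, obj in scene_graph:
    print(subject, predicate, obj)
```

Note that sub_i and obj_i index into the list of valid objects, so valid_obj_boxes[sub_i] and valid_obj_boxes[obj_i] give the corresponding boxes.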
@PiPiSang
(1) obj boxes: As I mentioned before, pred_boxes are in normalized cxcywh format (center x, center y, width, height).
(2) obj scores: We used Deformable DETR rather than DETR, and Deformable DETR uses focal loss for object detection instead of cross-entropy loss. Therefore, it is more natural to use sigmoid instead of softmax (https://github.com/huggingface/transformers/blob/409fcfdfccde77a14b7cc36972b774cabc371ae1/src/transformers/models/deformable_detr/image_processing_deformable_detr.py#L1555), but we used softmax to get obj_scores.
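For reference, a minimal sketch (not part of the repository) covering both points; it assumes the pred_logits and pred_boxes tensors from the snippet above, and box_cxcywh_to_xyxy is an illustrative helper, not a repository function:

```python
import torch

def box_cxcywh_to_xyxy(boxes: torch.Tensor, img_w: int, img_h: int) -> torch.Tensor:
    """boxes: [N, 4] normalized (cx, cy, w, h) -> [N, 4] absolute (x_min, y_min, x_max, y_max)."""
    cx, cy, w, h = boxes.unbind(-1)
    return torch.stack(
        [(cx - 0.5 * w) * img_w, (cy - 0.5 * h) * img_h,
         (cx + 0.5 * w) * img_w, (cy + 0.5 * h) * img_h],
        dim=-1,
    )

# softmax was used above for obj_scores; sigmoid is closer to how Deformable DETR is trained
obj_scores_sigmoid, pred_classes_sigmoid = pred_logits.sigmoid().max(-1)
```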
@Aoihigashi Yes, the obj_threshold and rel_threshold I use are 0.3 and 1e-4. But based on my tests, I strongly suggest directly selecting the top_k triplets by their scores. You can also read the function in train_egtr.py where the author calculates the loss; it demonstrates how to select the top_k triplets and perform the transformation.
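As a reference, here is a compact sketch of top_k triplet selection (not taken from train_egtr.py); it assumes pred_rel with shape [Q, Q, R], already multiplied by pred_connectivity, and obj_scores with shape [Q] from the earlier inference snippet:

```python
import torch

top_k = 20  # illustrative value

# weight relation scores by subject/object confidence, then rank all (s, o, r) combinations
triplet_scores = pred_rel * torch.outer(obj_scores, obj_scores).unsqueeze(-1)  # [Q, Q, R]
q, _, num_rel = triplet_scores.shape
self_mask = torch.eye(q, dtype=torch.bool, device=triplet_scores.device)
triplet_scores[self_mask] = 0.0  # prevent self-connections

scores, flat_idx = triplet_scores.flatten().topk(top_k)
sub_idx = flat_idx // (q * num_rel)   # subject query index
obj_idx = (flat_idx // num_rel) % q   # object query index
rel_idx = flat_idx % num_rel          # relation class index
```

sub_idx and obj_idx index the query dimension, so pred_classes[sub_idx] and pred_boxes[sub_idx] give the corresponding labels and boxes.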
@PiPiSang Could you please explain in detail how to run it? Thank you very much.
@riverjunhao Sure! Here's my code. However, I have only visualized the bounding boxes; the scene graph is written to a text file in textual form. Since I am not working on scene graph research, I haven't looked into how to draw the scene graph on the image itself, so if you need that, you can explore how to implement it yourself. If you share your findings, I believe it would be very helpful for other beginners. I modified the calculate_fps function in the original evaluate_egtr.py and added some necessary auxiliary functions. However, since I haven't run this code for more than a month, there might be some issues.

```python
import argparse
import json
import os
from glob import glob

import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont
from tqdm import tqdm

# argsort_desc, id2label, and id2relation are assumed to come from the surrounding
# evaluate_egtr.py context, as in the original script.


@torch.no_grad()
def calculate_fps(model, dataloader, top_k):
    model.eval()
    i = 0
    for batch in tqdm(dataloader):
        i = i + 1
        if i > 10:  # only process the first 10 batches
            break
        outputs = model(
            pixel_values=batch["pixel_values"].cuda(),
            pixel_mask=batch["pixel_mask"].cuda(),
            output_attentions=False,
            output_attention_states=True,
            output_hidden_states=True,
        )
        pred_logits = outputs["logits"][0]
        obj_scores, pred_classes = torch.max(pred_logits.softmax(-1), -1)
        pred_classes_numpy = pred_classes.cpu().clone().numpy()
        sub_ob_scores = torch.outer(obj_scores, obj_scores)
        sub_ob_scores[
            torch.arange(pred_logits.size(0)), torch.arange(pred_logits.size(0))
        ] = 0.0  # prevent self-connection
        pred_boxes = outputs["pred_boxes"][0]
        pred_boxes_numpy = pred_boxes.cpu().clone().numpy()
        pred_connectivity = outputs["pred_connectivity"][0]
        pred_rel = outputs["pred_rel"][0]
        pred_rel = torch.mul(pred_rel, pred_connectivity)
        triplet_scores = torch.mul(pred_rel.max(-1)[0], sub_ob_scores)
        pred_rel_inds = argsort_desc(triplet_scores.cpu().clone().numpy())[:top_k, :]  # [pred_rels, 2(s,o)]
        rel_scores = pred_rel.cpu().clone().numpy()[pred_rel_inds[:, 0], pred_rel_inds[:, 1]]
        pred_rels = np.column_stack((pred_rel_inds, rel_scores.argmax(1)))  # [top_k, 3(s,o,r)]
        triplets, triplets_box = get_triplets(pred_rels, pred_classes_numpy, pred_boxes_numpy)

        # Get valid objects and triplets
        obj_threshold = 0.3
        valid_obj_indices = (obj_scores >= obj_threshold).nonzero()[:top_k, 0]
        valid_obj_classes = pred_classes[valid_obj_indices]  # [num_valid_objects]
        valid_obj_boxes = pred_boxes[valid_obj_indices]  # [num_valid_objects, 4] (x_center, y_center, width, height)

        if not os.path.exists("./val_images"):
            os.mkdir("./val_images")
        if not os.path.exists("./relationship"):
            os.mkdir("./relationship")

        # Note: batch["id"] and batch["img"] are not in the default batch dictionary;
        # they require a customized dataset/collate_fn (see the follow-up question below).
        filename = f"./val_images/{batch['id']}"
        image = Image.fromarray(batch["img"], "RGB")
        if not os.path.exists(f"{filename}.png"):
            image.save(f"{filename}.png")
        visualization(image, valid_obj_boxes, filename, valid_obj_classes)
        write_triplets(triplets, batch["id"])


def get_triplets(pred_triplets, classes, boxes):
    # Map (subject index, object index, relation index) rows to readable labels and boxes.
    triplets, triplet_boxes = [], []
    for sub, obj, rel in pred_triplets:
        triplet_boxes.append([boxes[sub], boxes[obj]])
        triplets.append([id2label[classes[sub]], id2relation[rel], id2label[classes[obj]]])
    return triplets, triplet_boxes


def write_triplets(triplets, id):
    file_path = f"./relationship/{id}.txt"
    with open(file_path, "w") as file:
        for sub, rel, obj in triplets:
            file.write(f"{sub} {rel} {obj}\n")


def get_relation(triplets, id):
    # Unused helper kept from the original snippet.
    relation_list = []
    for sub, obj, rel in triplets:
        sub = id2label[sub]
        obj = id2label[obj]
        rel = id2relation[rel]
        relation_list.append([sub, obj, rel])
    file_path = f"./relationship/{id}.txt"
    with open(file_path, "w") as file:
        for sub, obj, rel in relation_list:
            file.write(f"{obj} {rel} {sub}\n")


def visualization(img, box, filename, obj_class):
    colors = ["red", "blue", "orange", "green", "yellow", "purple"]
    obj_list = obj_class.cpu().numpy()
    bboxes = box.cpu().numpy()
    draw = ImageDraw.Draw(img)
    font = ImageFont.load_default()
    i = 0
    for bbox, obj in zip(bboxes, obj_list):
        # boxes are normalized (cx, cy, w, h); convert to absolute pixel corners
        x_center, y_center, width, height = bbox
        x_max = int((x_center + width / 2) * img.width)
        y_max = int((y_center + height / 2) * img.height)
        x_min = int((x_center - width / 2) * img.width)
        y_min = int((y_center - height / 2) * img.height)
        label = id2label[obj]
        text_width, text_height = 10, 10
        text_x = max(x_min + text_width + 5, 0)
        text_y = y_min + text_height + 5
        draw.text((text_x, text_y), str(label), fill=colors[i], font=font)
        draw.rectangle([(x_min, y_min), (x_max, y_max)], outline=colors[i], fill=None, width=3)
        i = (i + 1) % 6
    img.save(f"{filename}_box.png")
```
Thank you very much for your reply! I need to map the output indices to the actual category names, which seems to require the id2label and id2relation mappings in your code, but I didn't find any relevant information in the code provided by the author. Could you tell me how these were obtained?
@riverjunhao In fact, the author already provides id2label and id2relation. You can take a closer look at this answer and then modify the two mappings as needed. I hope this helps. I tested this on the Visual Genome dataset, where the id2label and id2relation variables are built when the dataset is loaded. The code is as follows:

```python
# Dataset (VGDataset, CocoEvaluator, feature_extractor, and args come from evaluate_egtr.py)
if "visual_genome" in args.data_path:
    test_dataset = VGDataset(
        data_folder=args.data_path,
        feature_extractor=feature_extractor,
        split=args.split,
        num_object_queries=args.num_queries,
    )
    id2relation = test_dataset.rel_categories
    id2label = {
        k - 1: v["name"] for k, v in test_dataset.coco.cats.items()
    }  # 0 ~ 149
    coco_evaluator = CocoEvaluator(
        test_dataset.coco, ["bbox"]
    )  # initialize evaluator with ground truths
    oi_evaluator = None
```
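If you want to reuse these mappings outside the evaluation script, an optional step (not from the repository) is to dump them once and reload them in your own inference code:

```python
# Optional helper (not from the EGTR repo): persist the mappings for standalone inference.
import json

with open("id2label.json", "w") as f:
    json.dump(id2label, f)
with open("id2relation.json", "w") as f:
    json.dump(id2relation, f)

# Later, e.g. in an inference-only script (JSON turns the integer keys into strings):
with open("id2label.json") as f:
    id2label = {int(k): v for k, v in json.load(f).items()}
with open("id2relation.json") as f:
    id2relation = json.load(f)  # a list of relation names, loads back directly
```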
@SuperFanBin Thank you very much for your reply! I have tried the code you provided and made some modifications, but there is still an issue: writing the information to the specified files requires reading batch['id'] and batch['img'], which do not seem to be in the batch dictionary. I printed the batch dictionary and tried to use the image_id field under labels as the path name for saving the file, but I still couldn't recover the content of batch['img'] and save it as an image. Could you help me with this? Thank you very much!
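One possible workaround (untested, and not an answer given in the thread): recover the image id from batch["labels"] and reload the original image from disk instead of relying on batch['img']. The images subfolder and .jpg extension are assumptions about the local data layout, and visualization, write_triplets, and the valid_* tensors come from the code above.

```python
# Hedged sketch: derive the filename from the image_id in the labels instead of batch["img"].
import os
from PIL import Image

image_id = batch["labels"][0]["image_id"].item()  # assumes one image per batch and a tensor image_id
image_path = os.path.join(args.data_path, "images", f"{image_id}.jpg")  # path layout is an assumption
if os.path.exists(image_path):
    image = Image.open(image_path).convert("RGB")
    filename = f"./val_images/{image_id}"
    image.save(f"{filename}.png")
    visualization(image, valid_obj_boxes, filename, valid_obj_classes)
    write_triplets(triplets, image_id)
```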
First of all, the work is very instructive, thank you! As a novice to scene graph generation, my current research requires extracting information from images via scene graph generation, so I'm curious how to use the model's output to build a scene graph, and I'm confused about how to visualize that output. Any guidance would be greatly appreciated.