
How can an image be input into a model to output its scene graph information and bounding box information for visualization? #7

Open
SuperFanBin opened this issue Aug 29, 2024 · 13 comments

Comments

@SuperFanBin

First of all, the work is very instructive, thank you! As a novice to scene graph generation, my current research requires methods to extract information from images using scene graph generation. So I'm curious about how to use the output of the model to generate a scene graph, and I'm confused about how to visualize the output. Any guidance would be greatly appreciated.

@JeongSooHwan

Hi :). I was also deeply impressed by the EGTR paper. I am facing the same issue. As I am not very knowledgeable in the SGG domain, I understand that it involves predicting relationships between objects. I need to perform the SGG as a preliminary step to use SG embedding for downstream tasks. I would like to refer to EGTR as the SGG module, but could you please let me know how the output scene graph is generated?

@jinbae
Collaborator

jinbae commented Aug 30, 2024

First of all, I have not attempted to extract the scene graph itself beyond computing the evaluation metrics.
So, although I am sharing a code snippet for extracting the scene graph, I recommend adapting it to your own needs.
The following code snippet is based on evaluate_egtr.py.

from glob import glob

import torch
from PIL import Image

from model.deformable_detr import DeformableDetrConfig, DeformableDetrFeatureExtractor
from model.egtr import DetrForSceneGraphGeneration

# config
architecture = "SenseTime/deformable-detr"
min_size = 800
max_size = 1333
artifact_path = YOUR_ARTIFACT_PATH

# feature extractor
feature_extractor = DeformableDetrFeatureExtractor.from_pretrained(
    architecture, size=min_size, max_size=max_size
)

# inference image
image = Image.open(YOUR_IMAGE_PATH)
image = feature_extractor(image, return_tensors="pt")

# model
config = DeformableDetrConfig.from_pretrained(artifact_path)
model = DetrForSceneGraphGeneration.from_pretrained(
    architecture, config=config, ignore_mismatched_sizes=True
)
ckpt_path = sorted(
    glob(f"{artifact_path}/checkpoints/epoch=*.ckpt"),
    key=lambda x: int(x.split("epoch=")[1].split("-")[0]),
)[-1]
state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
for k in list(state_dict.keys()):
    state_dict[k[6:]] = state_dict.pop(k)  # "model."

model.load_state_dict(state_dict)
model.cuda()
model.eval()

# output
outputs = model(
    pixel_values=image['pixel_values'].cuda(), 
    pixel_mask=image['pixel_mask'].cuda(), 
    output_attention_states=True
)

pred_logits = outputs['logits'][0]
obj_scores, pred_classes = torch.max(pred_logits.softmax(-1), -1)
pred_boxes = outputs['pred_boxes'][0]

pred_connectivity = outputs['pred_connectivity'][0]
pred_rel = outputs['pred_rel'][0]
pred_rel = torch.mul(pred_rel, pred_connectivity)

# get valid objects and triplets
obj_threshold = YOUR_OBJ_THRESHOLD
valid_obj_indices = (obj_scores >= obj_threshold).nonzero()[:, 0]

valid_obj_classes = pred_classes[valid_obj_indices] # [num_valid_objects]
valid_obj_boxes = pred_boxes[valid_obj_indices] # [num_valid_objects, 4]

rel_threshold = YOUR_REL_THRESHOLD
valid_triplets = (pred_rel[valid_obj_indices][:, valid_obj_indices] >= rel_threshold).nonzero() # [num_valid_triplets, 3]

You can generate a scene graph based on valid_obj_classes, valid_obj_boxes, and valid_triplets.

  • valid_obj_classes: object classes
  • valid_obj_boxes: object bounding boxes (cxcywh format)
  • valid_triplets: relation triplets (subject entity index, object entity index, relation class)
    • Please note that subject entity index and object entity index indicate the indices of valid objects.

I built the scene graph using thresholds in this example, but it could also be implemented by selecting the top-k objects or triplets, as sketched below.
Since these thresholds were never explored in our experiments, it may be important to set them carefully.
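
A minimal sketch of the top-k alternative, reusing obj_scores and pred_rel from the snippet above (the value of k and the way the scores are combined are assumptions, not something that was tuned):

# A sketch of selecting the top-k triplets instead of thresholding (k is an assumption).
k = 20

# Combine subject/object confidences with the best relation score per pair.
sub_ob_scores = torch.outer(obj_scores, obj_scores)
sub_ob_scores.fill_diagonal_(0.0)  # prevent self-connections
pair_scores = pred_rel.max(-1)[0] * sub_ob_scores  # [num_queries, num_queries]

# Take the k best (subject, object) pairs and their best relation class.
flat_idx = pair_scores.flatten().topk(k).indices
subj_idx = torch.div(flat_idx, pair_scores.size(1), rounding_mode="floor")
obj_idx = flat_idx % pair_scores.size(1)
rel_idx = pred_rel[subj_idx, obj_idx].argmax(-1)

# Note: these indices refer to all object queries, not only thresholded "valid" objects.
top_k_triplets = torch.stack([subj_idx, obj_idx, rel_idx], dim=-1)  # [k, 3]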

@jinbae
Collaborator

jinbae commented Sep 2, 2024

@PiPiSang

(1) obj boxes

valid_obj_boxes: object bounding boxes (cxcywh format)

As I mentioned before, pred_boxes are in (normalized) cxcywh format.
Please make sure the boxes are converted to xyxy format before bbox visualization, for example as sketched below.
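
For example (a sketch; it assumes the boxes are the normalized cxcywh tensors from the snippet above, and that img_w / img_h are the original image width and height):

import torch
from torchvision.ops import box_convert

# normalized (cx, cy, w, h) -> absolute (x_min, y_min, x_max, y_max)
boxes_xyxy = box_convert(valid_obj_boxes, in_fmt="cxcywh", out_fmt="xyxy")
boxes_xyxy = boxes_xyxy * torch.tensor([img_w, img_h, img_w, img_h], device=boxes_xyxy.device)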

(2) obj scores

We used Deformable DETR rather than DETR, and in Deformable DETR, focal loss is used for object detection instead of cross entropy loss.

obj_scores, pred_classes = torch.max(pred_logits.softmax(-1), -1)

Therefore, it would be more natural to use sigmoid instead of softmax (https://github.com/huggingface/transformers/blob/409fcfdfccde77a14b7cc36972b774cabc371ae1/src/transformers/models/deformable_detr/image_processing_deformable_detr.py#L1555), but we used softmax to obtain obj_scores.
obj_scores may therefore be low compared to models trained with a cross-entropy loss.
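
If you prefer to follow the Deformable DETR convention, a sketch of the sigmoid variant (how the no-object class should be handled here is an assumption, so please verify it against your checkpoint):

# Per-class sigmoid scores instead of a softmax over classes.
probs = pred_logits.sigmoid()            # [num_queries, num_classes]
obj_scores, pred_classes = probs.max(-1)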

@Aoihigashi

@jinbae
Thank you very much for your response regarding scene graph visualization. I followed the code you provided for inference on a single image, using the pretrained weights for the OI dataset mentioned in the README, but I encountered the following error.
[screenshot of the error message]
I then tried the weights I trained myself on the OI dataset, but the same error occurred. I would like to know why the class dimension in the weights is 91 instead of 601, the number of classes in the OI dataset. I am eagerly awaiting your response.

@SuperFanBin
Author

SuperFanBin commented Sep 5, 2024

@Aoihigashi
In fact, this is not an error. This so-called error is emitted during the execution of the model initialization. You can try debugging this test code, and you will find that after executing the code on line 229, the console will output this message. Moreover, the code on line 229 does not actually load the trained weights into the model. This code simply creates a model that matches the structure specified in the configuration file and initializes it with the official default weights. Therefore, this message is indicating that the structure of the model does not fully match the official model structure, which leads to the inability to completely load the weights during initialization. The code that truly loads our own trained weights into the model is on line 240. So, you can relax, there is no issue.

@Aoihigashi

@PiPiSang
Thank you! I see, that’s exactly the case, and it does run. However, I have to set the obj_threshold and rel_threshold very low (0.1) to get any output. Have you encountered a similar issue?

@SuperFanBin
Author

@Aoihigashi Yes. The obj_threshold and rel_threshold for me are 0.3 and 1e-4. But based on my tests, I strongly suggest that you directly select the top_k triplets based on their scores. Then, you can read the function where the author calculates the loss in train_egtr.py. This demonstrates how to select the top_k triplets and perform the transformation.

@Aoihigashi

@PiPiSang
Thank you for your suggestions and response. I will try following the advice.

@riverjunhao

@Aoihigashi Yes. The obj_threshold and rel_threshold for me are 0.3 and 1e-4. But based on my tests, I strongly suggest that you directly select the top_k triplets based on their scores. Then, you can read the function where the author calculates the loss in train_egtr.py. This demonstrates how to select the top_k triplets and perform the transformation.

Could you please explain in detail how to operate it? Thank you very much.

@SuperFanBin
Author

SuperFanBin commented Sep 27, 2024

@riverjunhao Sure! Here is my code. However, I have only visualized the bounding boxes; the scene graph itself is stored as text in a file. Since I am not working on scene graph research, I have not looked into how to visualize the scene graph on the image, so if you need that feature you can explore it yourself. If you share your findings, I believe it would be very helpful for other beginners. I modified the calculate_fps function in the original evaluate_egtr.py and added some auxiliary functions. However, since I have not run this code for more than a month, there might be some issues.
[screenshot attached]

import os
import argparse
import json
from glob import glob

import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont
from tqdm import tqdm

# argsort_desc and the id2label / id2relation mappings are assumed to be
# available from the EGTR repository, as they are in evaluate_egtr.py.

@torch.no_grad()
def calculate_fps(model, dataloader, top_k):
    model.eval()
    i = 0
    for batch in tqdm(dataloader):
        i = i + 1
        if i > 10:
            break
        outputs = model(
            pixel_values=batch["pixel_values"].cuda(),
            pixel_mask=batch["pixel_mask"].cuda(),
            output_attentions=False,
            output_attention_states=True,
            output_hidden_states=True,
        )
        pred_logits = outputs['logits'][0]
        obj_scores, pred_classes = torch.max(pred_logits.softmax(-1), -1)
        pred_classes_numpy = pred_classes.cpu().clone().numpy()
        sub_ob_scores = torch.outer(obj_scores, obj_scores)
        sub_ob_scores[
            torch.arange(pred_logits.size(0)), torch.arange(pred_logits.size(0))
        ] = 0.0  # prevent self-connection

        pred_boxes = outputs['pred_boxes'][0]
        pred_boxes_numpy = pred_boxes.cpu().clone().numpy()
        pred_connectivity = outputs["pred_connectivity"][0]
        pred_rel = outputs["pred_rel"][0]
        pred_rel = torch.mul(pred_rel, pred_connectivity)
        triplet_scores = torch.mul(pred_rel.max(-1)[0], sub_ob_scores)
        pred_rel_inds = argsort_desc(triplet_scores.cpu().clone().numpy())[:top_k, :]  # [pred_rels, 2(s,o)]
        rel_scores = (pred_rel.cpu().clone().numpy()[pred_rel_inds[:, 0], pred_rel_inds[:, 1]])
        pred_rels = np.column_stack((pred_rel_inds, rel_scores.argmax(1)))
        triplets, triplets_box = get_triplets(pred_rels, pred_classes_numpy, pred_boxes_numpy)
        
        # Get valid objects and triplets
        obj_threshold = 0.3
        valid_obj_indices = (obj_scores >= obj_threshold).nonzero()[:top_k, 0]
        valid_obj_classes = pred_classes[valid_obj_indices]  # [num_valid_objects]
        valid_obj_boxes = pred_boxes[valid_obj_indices]  # [num_valid_objects, 4](x_center, y_center, width, height)
        
        if not os.path.exists("./val_images"):
            os.mkdir("./val_images")
        if not os.path.exists('./relationship'):
            os.mkdir("./relationship")
        if not os.path.exists(f"./val_images/{batch['id']}.png"):
            image = Image.fromarray(batch['img'], 'RGB')
            filename = f"./val_images/{batch['id']}"
            image.save(f"{filename}.png")
            visualization(image, valid_obj_boxes, filename, valid_obj_classes)
            write_triplets(triplets, batch['id'])

def get_triplets(pred_triplets, classes, boxes):
    triplets, triplet_boxes = [], []
    for sub, obj, rel in pred_triplets:
        triplet_boxes.append([boxes[sub], boxes[obj]])
        sub = id2label[classes[sub]]
        obj = id2label[classes[obj]]
        rel = id2relation[rel]
        triplets.append([obj, rel, sub])  # stored as [object, relation, subject]; write_triplets swaps again, so the file reads "subject relation object"
    return triplets, triplet_boxes

def write_triplets(triplets, id):
    file_path = f'./relationship/{id}.txt'
    with open(file_path, 'w') as file:
        for sub, rel, obj in triplets:
            file.write(str(obj) + ' ' + str(rel) + ' ' + str(sub) + '\n')

def get_relation(triplets, id):
    relation_list = []
    for sub, obj, rel in triplets:
        sub = id2label[sub]
        obj = id2label[obj]
        rel = id2relation[rel]
        relation_list.append([sub, obj, rel])
    file_path = f'./relationship/{id}.txt'
    with open(file_path, 'w') as file:
        for sub, obj, rel in relation_list:
            file.write(str(obj) + ' ' + str(rel) + ' ' + str(sub) + '\n')

def visualization(img, box, filename, obj_class):
    colors = ['red', 'blue', 'orange', 'green', 'yellow', 'purple']
    # The default PIL font is small; swap in ImageFont.truetype(...) with a font
    # available on your system if you want larger labels.
    font = ImageFont.load_default()
    obj_list = obj_class.cpu().numpy()
    bboxes = box.cpu().numpy()
    draw = ImageDraw.Draw(img)
    i = 0
    for bbox, obj in zip(bboxes, obj_list):
        x_center, y_center, width, height = bbox
        x_max = int((x_center + width/2) * img.width)
        y_max = int((y_center + height/2) * img.height)
        x_min = int((x_center - width/2) * img.width)
        y_min = int((y_center - height/2) * img.height)
        obj = id2label[obj]
        text_width, text_height = 10, 10
        text_x = max(x_min + text_width + 5, 0)
        text_y = y_min + text_height + 5
        draw.text((text_x, text_y), str(obj), fill=colors[i], font=font)
        draw.rectangle([(x_min, y_min), (x_max, y_max)], outline=colors[i], fill=None, width=3)
        i = (i + 1)%6
    img.save(f"{filename}_box.png")
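
In case someone wants an image of the graph itself rather than only the boxes, here is a minimal sketch using networkx and matplotlib (both assumed to be installed); it expects readable (subject, relation, object) label triplets, e.g. the lines written by write_triplets above:

import matplotlib
matplotlib.use("Agg")  # render to file without a display
import matplotlib.pyplot as plt
import networkx as nx

def draw_scene_graph(triplets, out_path):
    """triplets: iterable of (subject_label, relation_label, object_label) strings."""
    graph = nx.DiGraph()
    edge_labels = {}
    for sub, rel, obj in triplets:
        graph.add_edge(sub, obj)
        edge_labels[(sub, obj)] = rel
    pos = nx.spring_layout(graph, seed=0)
    nx.draw(graph, pos, with_labels=True, node_color="lightblue", node_size=1500, font_size=8)
    nx.draw_networkx_edge_labels(graph, pos, edge_labels=edge_labels, font_size=7)
    plt.savefig(out_path, bbox_inches="tight")
    plt.close()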

@riverjunhao

[quotes the previous comment and code snippet]

Thank you very much for your reply! I need to map the output indices or labels to the actual category names, which seems to require the id2label and id2relation used in your code, but I did not find any relevant information in the code provided by the author. Could you tell me how these were obtained?

@SuperFanBin
Author

@riverjunhao In fact, the author already provides id2label and id2relation. You can take a closer look at this answer and then adapt these two mappings as needed; I hope this helps. I tested this on the Visual Genome dataset, where the id2label and id2relation variables are built when the dataset is loaded. The code is as follows:

# Dataset
if "visual_genome" in args.data_path:
    test_dataset = VGDataset(
        data_folder=args.data_path,
        feature_extractor=feature_extractor,
        split=args.split,
        num_object_queries=args.num_queries,
    )
    id2relation = test_dataset.rel_categories
    id2label = {
        k - 1: v["name"] for k, v in test_dataset.coco.cats.items()
    }  # 0 ~ 149
    coco_evaluator = CocoEvaluator(
        test_dataset.coco, ["bbox"]
    )  # initialize evaluator with ground truths
    oi_evaluator = None
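
Once you have these two mappings, decoding the outputs of the earlier snippet into readable triplets is straightforward; a sketch (valid_obj_classes and valid_triplets come from the snippet shared above):

# valid_triplets rows are (subject index, object index, relation class),
# where the indices point into the *valid* objects, not all queries.
readable_triplets = []
for sub_idx, obj_idx, rel_idx in valid_triplets.tolist():
    sub_label = id2label[valid_obj_classes[sub_idx].item()]
    obj_label = id2label[valid_obj_classes[obj_idx].item()]
    readable_triplets.append((sub_label, id2relation[rel_idx], obj_label))
print(readable_triplets)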

@riverjunhao

@SuperFanBin Thank you very much for your reply! I have tried the code you provided and made some modifications, but there is still an issue: writing the information to the specified file requires reading batch['id'] and batch['img'], which do not seem to be in the batch dictionary. I tried printing the batch dictionary and using the image_id entry under batch['labels'] as the file name for saving, but I still could not extract batch['img'] properly to save the image. Could you help me with this? Thank you very much!
[screenshot attached]
