The goal of this Google Colab notebook is to fine-tune Facebook's DETR (DEtection TRansformer).
Figure: from left to right, results obtained with pre-trained DETR, and after fine-tuning on the balloon dataset.
- Acquire a dataset, e.g. the balloon dataset,
- Convert the dataset to the COCO format (a minimal sketch of the expected annotation file is shown below),
- Run finetune_detr.ipynb to fine-tune DETR on this dataset.
- Alternatively, run finetune_detectron2.ipynb to rely on the detectron2 wrapper.
NB: Fine-tuning is recommended if your dataset has fewer than 10k images. Otherwise, training from scratch would be an option.
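For reference, the COCO format expects one JSON file per split, with "images", "annotations" and "categories" entries. Below is a minimal sketch of such a file being written in Python; the file name matches the directory layout shown further down, but the field values are placeholders and the notebook's actual conversion code may differ.

```python
import json

# Minimal COCO-style annotation file (illustrative placeholders, not real balloon annotations).
coco = {
    "images": [
        # One entry per image in the split.
        {"id": 0, "file_name": "example.jpg", "width": 1024, "height": 768},
    ],
    "annotations": [
        # One entry per object instance; "bbox" is [x, y, width, height] in pixels.
        {"id": 0, "image_id": 0, "category_id": 0,
         "bbox": [100, 150, 200, 250], "area": 200 * 250, "iscrowd": 0},
    ],
    "categories": [
        # Category ids must be consistent with the number of classes used at training time
        # (see the note about the number of classes below).
        {"id": 0, "name": "balloon"},
    ],
}

with open("path/to/coco/annotations/custom_train.json", "w") as f:
    json.dump(coco, f)
```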
DETR will be fine-tuned on a tiny dataset: the balloon dataset. We refer to it as the custom dataset.

There are 61 images in the training set, and 13 images in the validation set.
We expect the directory structure to be the following:
```
path/to/coco/
├ annotations/  # JSON annotations
│  ├ custom_train.json
│  └ custom_val.json
├ train2017/    # training images
└ val2017/      # validation images
```
NB: if you are confused about the number of classes, check this Github issue.
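One way to sanity-check the converted annotations (and the number of classes) is to load them with the COCO API. A minimal sketch with pycocotools, assuming the directory layout above:

```python
from pycocotools.coco import COCO

# Load the custom training annotations (path follows the layout above).
coco = COCO("path/to/coco/annotations/custom_train.json")

# Number of images and annotated instances in the split.
print(len(coco.getImgIds()), "images,", len(coco.getAnnIds()), "annotations")

# Categories and their ids, which determine the number of classes used by DETR.
for cat in coco.loadCats(coco.getCatIds()):
    print(cat["id"], cat["name"])
```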
Typical metrics to monitor, partially shown in this notebook, include (see the plotting sketch after this list):
- the Average Precision (AP), which is the primary challenge metric for the COCO dataset,
- losses (total loss, classification loss, L1 bounding-box distance loss, GIoU loss),
- errors (cardinality error, class error).
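Once training has produced a log file, these quantities can be plotted from it. A minimal sketch with pandas, assuming DETR's default log.txt format (one JSON object per epoch) and its usual field names such as train_loss and train_class_error; the file location is an assumption:

```python
import pandas as pd
import matplotlib.pyplot as plt

# DETR appends one JSON object per epoch to output_dir/log.txt (assumed location and format).
log = pd.read_json("outputs/log.txt", lines=True)

# Plot a few of the monitored quantities over epochs (field names are assumptions).
for field in ["train_loss", "train_loss_ce", "train_loss_bbox", "train_loss_giou", "train_class_error"]:
    if field in log.columns:
        plt.plot(log["epoch"], log[field], label=field)

plt.xlabel("epoch")
plt.legend()
plt.show()
```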
As mentioned in the paper, there are 3 components to the matching cost and to the total loss:
- classification loss,
```python
def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
    """Classification loss (NLL)
    targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
    """
    [...]
    loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)
    losses = {'loss_ce': loss_ce}
```
- l1 bounding box distance loss,
```python
def loss_boxes(self, outputs, targets, indices, num_boxes):
    """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
    targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
    The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.
    """
    [...]
    loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')
    losses['loss_bbox'] = loss_bbox.sum() / num_boxes
```
- Generalized Intersection over Union (GIoU) loss, which is scale-invariant.
```python
    loss_giou = 1 - torch.diag(box_ops.generalized_box_iou(
        box_ops.box_cxcywh_to_xyxy(src_boxes),
        box_ops.box_cxcywh_to_xyxy(target_boxes)))
    losses['loss_giou'] = loss_giou.sum() / num_boxes
```
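As a quick illustration of the scale-invariance mentioned above: scaling a pair of boxes by a common factor leaves their GIoU unchanged. A small sketch using torchvision's generalized_box_iou (available in torchvision >= 0.10, boxes in (x1, y1, x2, y2) format), which computes the same pairwise GIoU as DETR's box_ops helper:

```python
import torch
from torchvision.ops import generalized_box_iou

# Two boxes in (x1, y1, x2, y2) format.
boxes_a = torch.tensor([[10.0, 10.0, 50.0, 50.0]])
boxes_b = torch.tensor([[30.0, 30.0, 80.0, 70.0]])

giou = generalized_box_iou(boxes_a, boxes_b)

# Scaling both boxes by the same factor leaves the GIoU unchanged (scale-invariance).
giou_scaled = generalized_box_iou(4 * boxes_a, 4 * boxes_b)

print(giou.item(), giou_scaled.item())  # identical values
```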
Moreover, there are two errors:
- cardinality error,
```python
def loss_cardinality(self, outputs, targets, indices, num_boxes):
    """Compute the cardinality error, ie the absolute error in the number of predicted non-empty boxes.
    This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients
    """
    [...]
    # Count the number of predictions that are NOT "no-object" (which is the last class)
    card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1)
    card_err = F.l1_loss(card_pred.float(), tgt_lengths.float())
    losses = {'cardinality_error': card_err}
```
- class error,
```python
    # TODO this should probably be a separate loss, not hacked in this one here
    losses['class_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0]
```
where accuracy is:
```python
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
```
You should obtain acceptable results with 10 epochs, which requires a few minutes of fine-tuning.
Out of curiosity, I have over-fine-tuned the model for 300 epochs (close to 3 hours). Here are:
- the last checkpoint (~ 500 MB),
- the log file.
All of the validation results are shown in view_balloon_validation.ipynb.
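For reference, such a fine-tuned checkpoint can be reloaded for inference. A minimal sketch, assuming DETR's torch.hub entry point; the num_classes value and the checkpoint path are assumptions and must match your own fine-tuning run (see the note about the number of classes above):

```python
import torch

# Build a DETR model with the custom number of classes
# (num_classes=1 is an assumption for a single balloon class).
model = torch.hub.load('facebookresearch/detr', 'detr_resnet50',
                       pretrained=False, num_classes=1)

# Load the fine-tuned weights; the path is a placeholder for the downloaded checkpoint.
# This only succeeds if num_classes matches the value used during fine-tuning.
checkpoint = torch.load('checkpoint.pth', map_location='cpu')
model.load_state_dict(checkpoint['model'])
model.eval()
```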
- Official repositories:
- Facebook's DETR (and the paper)
- Facebook's detectron2 wrapper for DETR; caveat: this wrapper only supports box detection
- DETR checkpoints: remove the classification head, then fine-tune
- My forks:
- Official notebooks:
- An official notebook showcasing DETR
- An official notebook showcasing the COCO API
- An official notebook showcasing the detectron2 wrapper for DETR
- Tutorials:
- A Github issue discussing the fine-tuning of DETR
- A Github Gist explaining how to fine-tune DETR
- A Github issue explaining how to load a fine-tuned DETR
- Datasets: