Skip to content

Commit

Permalink
Merge pull request #1052 from tensorflow/sync
Browse files Browse the repository at this point in the history
Sync (0-30 of 206)
  • Loading branch information
danjan1234 authored Apr 19, 2024
2 parents 906be52 + bb2193e commit 2be44ad
Show file tree
Hide file tree
Showing 32 changed files with 554 additions and 319 deletions.
30 changes: 29 additions & 1 deletion models/official/detection/dataloader/tf_example_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,11 @@ def __init__(
# copypara:strip_end
regenerate_source_id=False,
label_key='image/object/class/label',
label_dtype=tf.int64):
label_dtype=tf.int64,
include_keypoint=False,
num_keypoints_per_instance=0):
self._include_mask = include_mask
self._include_keypoint = include_keypoint
# copypara:strip_begin
self._include_polygon = include_polygon
# copypara:strip_end
Expand All @@ -62,6 +65,13 @@ def __init__(
self._label_dtype)
if include_mask:
self._keys_to_features['image/object/mask'] = tf.VarLenFeature(tf.string)
if include_keypoint:
self._num_keypoints_per_instance = num_keypoints_per_instance
self._keys_to_features.update({
'image/object/keypoint/visibility': tf.io.VarLenFeature(tf.int64),
'image/object/keypoint/x': tf.io.VarLenFeature(tf.float32),
'image/object/keypoint/y': tf.io.VarLenFeature(tf.float32),
})

def _decode_image(self, parsed_tensors):
"""Decodes the image and set its static shape."""
Expand Down Expand Up @@ -106,6 +116,17 @@ def _decode_areas(self, parsed_tensors):
lambda: parsed_tensors['image/object/area'],
lambda: (xmax - xmin) * (ymax - ymin) * height * width)

def _decode_keypoints(self, parsed_tensors):
"""Decode keypoint coordinates and visibilities."""
keypoint_x = parsed_tensors['image/object/keypoint/x']
keypoint_y = parsed_tensors['image/object/keypoint/y']
keypoints = tf.stack([keypoint_y, keypoint_x], axis=-1)
keypoints = tf.reshape(keypoints, [-1, self._num_keypoints_per_instance, 2])
keypoint_visibilities = parsed_tensors['image/object/keypoint/visibility']
keypoint_visibilities = tf.reshape(keypoint_visibilities,
[-1, self._num_keypoints_per_instance])
return keypoints, keypoint_visibilities

def decode(self, serialized_example):
"""Decode the serialized example.
Expand Down Expand Up @@ -165,6 +186,8 @@ def decode(self, serialized_example):
lambda: _get_source_id_from_encoded_image(parsed_tensors))
if self._include_mask:
masks = self._decode_masks(parsed_tensors)
if self._include_keypoint:
keypoints, keypoint_visibilities = self._decode_keypoints(parsed_tensors)

groundtruth_classes = parsed_tensors[self._label_key]
decoded_tensors = {
Expand All @@ -182,4 +205,9 @@ def decode(self, serialized_example):
'groundtruth_instance_masks': masks,
'groundtruth_instance_masks_png': parsed_tensors['image/object/mask'],
})
if self._include_keypoint:
decoded_tensors.update({
'groundtruth_keypoints': keypoints,
'groundtruth_keypoint_visibilities': keypoint_visibilities,
})
return decoded_tensors
163 changes: 117 additions & 46 deletions models/official/detection/evaluation/coco_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ def __init__(
need_rescale_bboxes=True,
per_category_metrics=False,
remove_invalid_boxes=False,
include_keypoint=False,
need_rescale_keypoints=False,
kpt_oks_sigmas=None
):
"""Constructs COCO evaluation class.
Expand All @@ -80,6 +83,13 @@ def __init__(
per_category_metrics: Whether to return per category metrics.
remove_invalid_boxes: A boolean indicating whether to remove invalid box
during evaluation.
include_keypoint: a boolean to indicate whether or not to include the
keypoint eval.
need_rescale_keypoints: If true keypoints in `predictions` will be
rescaled back to absolute values (`image_info` is needed in this case).
kpt_oks_sigmas: The sigmas used to calculate keypoint OKS. See
http://cocodataset.org/#keypoints-eval. When None, it will use the
defaults in COCO.
"""
if annotation_file:
if annotation_file.startswith('gs://'):
Expand All @@ -105,7 +115,8 @@ def __init__(
'detection_boxes'
]
self._need_rescale_bboxes = need_rescale_bboxes
if self._need_rescale_bboxes:
self._need_rescale_keypoints = need_rescale_keypoints
if self._need_rescale_bboxes or self._need_rescale_keypoints:
self._required_prediction_fields.append('image_info')
self._required_groundtruth_fields = [
'source_id', 'height', 'width', 'classes', 'boxes'
Expand All @@ -116,6 +127,18 @@ def __init__(
self._required_prediction_fields.extend(['detection_masks'])
self._required_groundtruth_fields.extend(['masks'])
self.remove_invalid_boxes = remove_invalid_boxes
self._include_keypoint = include_keypoint
self._kpt_oks_sigmas = kpt_oks_sigmas
if self._include_keypoint:
keypoint_metric_names = [
'AP', 'AP50', 'AP75', 'APm', 'APl', 'ARmax1', 'ARmax10', 'ARmax100',
'ARm', 'ARl'
]
keypoint_metric_names = ['keypoint_' + x for x in keypoint_metric_names]
self._metric_names.extend(keypoint_metric_names)
self._required_prediction_fields.extend(['detection_keypoints'])
self._required_groundtruth_fields.extend(['keypoints'])

self.reset()

def reset(self):
Expand Down Expand Up @@ -168,7 +191,7 @@ def evaluate(self):
coco_eval.evaluate()
coco_eval.accumulate()
coco_eval.summarize()
coco_metrics = coco_eval.stats
metrics = coco_eval.stats

if self._include_mask:
mcoco_eval = cocoeval.COCOeval(coco_gt, coco_dt, iouType='segm')
Expand All @@ -177,11 +200,17 @@ def evaluate(self):
mcoco_eval.accumulate()
mcoco_eval.summarize()
mask_coco_metrics = mcoco_eval.stats

if self._include_mask:
metrics = np.hstack((coco_metrics, mask_coco_metrics))
else:
metrics = coco_metrics
metrics = np.hstack((metrics, mask_coco_metrics))

if self._include_keypoint:
kcoco_eval = cocoeval.COCOeval(coco_gt, coco_dt, iouType='keypoints',
kpt_oks_sigmas=self._kpt_oks_sigmas)
kcoco_eval.params.imgIds = image_ids
kcoco_eval.evaluate()
kcoco_eval.accumulate()
kcoco_eval.summarize()
keypoint_coco_metrics = kcoco_eval.stats
metrics = np.hstack((metrics, keypoint_coco_metrics))

# Cleans up the internal variables in order for a fresh eval next time.
self.reset()
Expand All @@ -192,46 +221,64 @@ def evaluate(self):

# Adds metrics per category.
if self._per_category_metrics and hasattr(coco_eval, 'category_stats'):
for category_index, category_id in enumerate(coco_eval.params.catIds):
metrics_dict['Precision mAP ByCategory/{}'.format(
category_id)] = coco_eval.category_stats[0][category_index].astype(
np.float32)
metrics_dict['Precision mAP ByCategory@50IoU/{}'.format(
category_id)] = coco_eval.category_stats[1][category_index].astype(
np.float32)
metrics_dict['Precision mAP ByCategory@75IoU/{}'.format(
category_id)] = coco_eval.category_stats[2][category_index].astype(
np.float32)
metrics_dict['Precision mAP ByCategory (small) /{}'.format(
category_id)] = coco_eval.category_stats[3][category_index].astype(
np.float32)
metrics_dict['Precision mAP ByCategory (medium) /{}'.format(
category_id)] = coco_eval.category_stats[4][category_index].astype(
np.float32)
metrics_dict['Precision mAP ByCategory (large) /{}'.format(
category_id)] = coco_eval.category_stats[5][category_index].astype(
np.float32)
metrics_dict['Recall AR@1 ByCategory/{}'.format(
category_id)] = coco_eval.category_stats[6][category_index].astype(
np.float32)
metrics_dict['Recall AR@10 ByCategory/{}'.format(
category_id)] = coco_eval.category_stats[7][category_index].astype(
np.float32)
metrics_dict['Recall AR@100 ByCategory/{}'.format(
category_id)] = coco_eval.category_stats[8][category_index].astype(
np.float32)
metrics_dict['Recall AR (small) ByCategory/{}'.format(
category_id)] = coco_eval.category_stats[9][category_index].astype(
np.float32)
metrics_dict['Recall AR (medium) ByCategory/{}'.format(
category_id)] = coco_eval.category_stats[10][category_index].astype(
np.float32)
metrics_dict['Recall AR (large) ByCategory/{}'.format(
category_id)] = coco_eval.category_stats[11][category_index].astype(
np.float32)
metrics_dict.update(self._retrieve_per_category_metrics(coco_eval))

if self._include_keypoint:
metrics_dict.update(self._retrieve_per_category_metrics(
kcoco_eval, prefix='keypoints'))
return metrics_dict

def _process_predictions(self, predictions):
def _retrieve_per_category_metrics(self, coco_eval, prefix=''):
"""Retrieves and per-category metrics and returns them in a dict.
Args:
coco_eval: a cocoeval.COCOeval object containing evaluation data.
prefix: str, A string used to prefix metric names.
Returns:
metrics_dict: A dictionary with per category metrics.
"""

metrics_dict = {}
if prefix:
prefix = prefix + ' '

for category_index, category_id in enumerate(coco_eval.params.catIds):
if 'keypoints' in prefix:
metrics_dict_keys = [
'Precision mAP ByCategory',
'Precision mAP ByCategory@50IoU',
'Precision mAP ByCategory@75IoU',
'Precision mAP ByCategory (medium)',
'Precision mAP ByCategory (large)',
'Recall AR@1 ByCategory',
'Recall AR@10 ByCategory',
'Recall AR@100 ByCategory',
'Recall AR (medium) ByCategory',
'Recall AR (large) ByCategory',
]
else:
metrics_dict_keys = [
'Precision mAP ByCategory',
'Precision mAP ByCategory@50IoU',
'Precision mAP ByCategory@75IoU',
'Precision mAP ByCategory (small)',
'Precision mAP ByCategory (medium)',
'Precision mAP ByCategory (large)',
'Recall AR@1 ByCategory',
'Recall AR@10 ByCategory',
'Recall AR@100 ByCategory',
'Recall AR (small) ByCategory',
'Recall AR (medium) ByCategory',
'Recall AR (large) ByCategory',
]
for idx, key in enumerate(metrics_dict_keys):
metrics_dict[prefix + key + '/{}'.format(
category_id)] = coco_eval.category_stats[idx][
category_index].astype(np.float32)
return metrics_dict

def _process_bboxes_predictions(self, predictions):
image_scale = np.tile(predictions['image_info'][:, 2:3, :], (1, 1, 2))
predictions['detection_boxes'] = (
predictions['detection_boxes'].astype(np.float32))
Expand All @@ -241,6 +288,13 @@ def _process_predictions(self, predictions):
predictions['detection_outer_boxes'].astype(np.float32))
predictions['detection_outer_boxes'] /= image_scale

def _process_keypoints_predictions(self, predictions):
image_scale = tf.reshape(predictions['image_info'][:, 2:3, :],
[-1, 1, 1, 2])
predictions['detection_keypoints'] = (
predictions['detection_keypoints'].astype(np.float32))
predictions['detection_keypoints'] /= image_scale

def update(self, predictions, groundtruths=None):
"""Update and aggregate detection results and groundtruth data.
Expand Down Expand Up @@ -286,7 +340,9 @@ def update(self, predictions, groundtruths=None):
raise ValueError(
'Missing the required key `{}` in predictions!'.format(k))
if self._need_rescale_bboxes:
self._process_predictions(predictions)
self._process_bboxes_predictions(predictions)
if self._need_rescale_keypoints:
self._process_keypoints_predictions(predictions)
for k, v in six.iteritems(predictions):
if k not in self._predictions:
self._predictions[k] = [v]
Expand All @@ -305,6 +361,20 @@ def update(self, predictions, groundtruths=None):
else:
self._groundtruths[k].append(v)

def merge(self, other):
"""Merges the states from the other CocoEvaluator."""
for k, v in other._predictions.items(): # pylint: disable=protected-access
if k not in self._predictions:
self._predictions[k] = v
else:
self._predictions[k].extend(v)

for k, v in other._groundtruths.items(): # pylint: disable=protected-access
if k not in self._groundtruths:
self._groundtruths[k] = v
else:
self._groundtruths[k].extend(v)


class ShapeMaskCOCOEvaluator(COCOEvaluator):
"""COCO evaluation metric class for ShapeMask."""
Expand Down Expand Up @@ -463,6 +533,7 @@ def __init__(
self._metric_names.extend(mask_metric_names)
self._required_prediction_fields.extend(['detection_masks'])
self._required_groundtruth_fields.extend(['masks'])
self._need_rescale_keypoints = False

self.reset()

Expand Down
36 changes: 36 additions & 0 deletions models/official/detection/evaluation/coco_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,8 @@ def convert_predictions_to_coco_annotations(
Optional fields:
- detection_masks: a list of numpy arrays of float of shape
[batch_size, K, mask_height, mask_width].
- detection_keypoints: a list of numpy arrays of float of shape
[batch_size, K, num_keypoints, 2]
remove_invalid_boxes: A boolean indicating whether to remove invalid box
during evaluation.
Expand All @@ -160,6 +162,19 @@ def convert_predictions_to_coco_annotations(

# NOTE: Batch size may differ between chunks.
batch_size = predictions['source_id'][i].shape[0]
if 'detection_keypoints' in predictions:
# Adds extra ones to indicate the visibility for each keypoint as is
# recommended by MSCOCO. Also, convert keypoint from [y, x] to [x, y]
# as mandated by COCO.
num_keypoints = predictions['detection_keypoints'][i].shape[2]
coco_keypoints = np.concatenate(
[
predictions['detection_keypoints'][i][..., 1:],
predictions['detection_keypoints'][i][..., :1],
np.ones([batch_size, max_num_detections, num_keypoints, 1]),
],
axis=-1,
).astype(int)
for j in range(batch_size):
if 'detection_masks' in predictions:
image_masks = mask_utils.paste_instance_masks(
Expand All @@ -185,6 +200,8 @@ def convert_predictions_to_coco_annotations(
ann['score'] = predictions['detection_scores'][i][j, k]
if 'detection_masks' in predictions:
ann['segmentation'] = encoded_masks[k]
if 'detection_keypoints' in predictions:
ann['keypoints'] = coco_keypoints[j, k].flatten().tolist()
coco_predictions.append(ann)

for i, ann in enumerate(coco_predictions):
Expand Down Expand Up @@ -272,6 +289,25 @@ def convert_groundtruths_to_coco_dataset(groundtruths, label_map=None):
ann['segmentation'] = encoded_mask
if 'areas' not in groundtruths:
ann['area'] = mask_api.area(encoded_mask)
if 'keypoints' in groundtruths:
keypoints = groundtruths['keypoints'][i]
coco_keypoints = []
num_valid_keypoints = 0
for z in range(len(keypoints[j, k, :, 1])):
# Convert from [y, x] to [x, y] as mandated by COCO.
x = float(keypoints[j, k, z, 1])
y = float(keypoints[j, k, z, 0])
coco_keypoints.append(x)
coco_keypoints.append(y)
if tf.math.is_nan(x) or tf.math.is_nan(y) or (
x == 0 and y == 0):
visibility = 0
else:
visibility = 2
num_valid_keypoints = num_valid_keypoints + 1
coco_keypoints.append(visibility)
ann['keypoints'] = coco_keypoints
ann['num_keypoints'] = num_valid_keypoints
gt_annotations.append(ann)

for i, ann in enumerate(gt_annotations):
Expand Down
Loading

0 comments on commit 2be44ad

Please sign in to comment.