diff --git a/src/data_collection/data_collector.py b/src/data_collection/data_collector.py index 3dee8e8..cd39b5e 100644 --- a/src/data_collection/data_collector.py +++ b/src/data_collection/data_collector.py @@ -107,7 +107,8 @@ def collect_data(self) -> None: object_types.append(obj.category) object_bounding_boxes.append(np.array(obj.xywh)) object_xy.append(np.array(obj.xy)) - last_idx = -1 if not hasattr(obj, 'last_xy') or obj.last_xy == (0,0) else self.episode_object_xy[-1].index(np.array(obj.last_xy)) + last_idx = -1 if len(self.episode_object_xy) == 0 or obj.prev_xy == (0,0) else \ + self.episode_object_xy[-1].index(np.array(obj.prev_xy)) object_last_idx.append(last_idx) object_bounding_boxes = np.array(object_bounding_boxes)