-
Notifications
You must be signed in to change notification settings - Fork 58
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #271 from Labelbox/develop
3.3.0
- Loading branch information
Showing
47 changed files
with
1,146 additions
and
247 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -29,3 +29,4 @@ | |
from .collection import LabelGenerator | ||
|
||
from .metrics import ScalarMetric | ||
from .metrics import MetricAggregation |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .scalar import ScalarMetric | ||
from .aggregations import MetricAggregation |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
from enum import Enum | ||
|
||
|
||
class MetricAggregation(Enum):
    """Aggregation kinds that can be selected for a metric.

    Every member's value is the same string as the member's name.
    """

    ARITHMETIC_MEAN = "ARITHMETIC_MEAN"
    GEOMETRIC_MEAN = "GEOMETRIC_MEAN"
    HARMONIC_MEAN = "HARMONIC_MEAN"
    SUM = "SUM"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
from labelbox.data.annotation_types.metrics.aggregations import MetricAggregation | ||
from typing import Any, Dict, Optional | ||
from pydantic import BaseModel | ||
|
||
|
||
class ScalarMetric(BaseModel):
    """ Class representing metrics

    # For backwards compatibility, metric_name is optional. This will eventually be deprecated
    # The metric_name will be set to a default name in the editor if it is not set.
    # aggregation will be ignored without providing a metric name.
    # Not providing a metric name is deprecated.
    """
    value: float
    metric_name: Optional[str] = None
    feature_name: Optional[str] = None
    subclass_name: Optional[str] = None
    aggregation: MetricAggregation = MetricAggregation.ARITHMETIC_MEAN
    extra: Dict[str, Any] = {}

    def dict(self, *args, **kwargs):
        """
        Serializes the metric, omitting every field whose value is None.

        `aggregation` is also omitted when no `metric_name` is set, since it
        has no meaning without one.

        Args:
            *args, **kwargs: Forwarded unchanged to `BaseModel.dict`.
        Returns:
            A dict of the non-None fields.
        """
        res = super().dict(*args, **kwargs)
        # Decide from the model itself rather than from `res`: the caller may
        # have used include/exclude to drop `metric_name` or `aggregation`,
        # in which case indexing / popping without a default raises KeyError.
        if self.metric_name is None:
            res.pop('aggregation', None)
        return {k: v for k, v in res.items() if v is not None}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
""" | ||
Tools for grouping features and labels so that we can compute metrics on the individual groups | ||
""" | ||
from collections import defaultdict | ||
from typing import Dict, List, Tuple, Union | ||
try: | ||
from typing import Literal | ||
except ImportError: | ||
from typing_extensions import Literal | ||
|
||
from labelbox.data.annotation_types import Label | ||
from labelbox.data.annotation_types.collection import LabelList | ||
from labelbox.data.annotation_types.feature import FeatureSchema | ||
|
||
|
||
def get_identifying_key(
    features_a: List[FeatureSchema], features_b: List[FeatureSchema]
) -> Union[Literal['name'], Literal['feature_schema_id']]:
    """
    Picks the field that can be used to match features across both lists.

    Args:
        features_a : List of FeatureSchemas (usually ObjectAnnotations or ClassificationAnnotations)
        features_b : List of FeatureSchemas (usually ObjectAnnotations or ClassificationAnnotations)
    Returns:
        'name' when every feature in both lists has a name, otherwise
        'feature_schema_id' when every feature in both lists has a schema id.
    Raises:
        ValueError: If features_a has neither all names nor all schema ids set,
            or if the two lists share no fully populated identifying field.
    """
    ids_set_a, names_set_a = all_have_key(features_a)
    if not (ids_set_a or names_set_a):
        raise ValueError("All data must have feature_schema_ids or names set")

    ids_set_b, names_set_b = all_have_key(features_b)

    # Names win when both are usable because they are meaningful to the user;
    # schema ids are the fallback.
    if names_set_a and names_set_b:
        return 'name'
    if ids_set_a and ids_set_b:
        return 'feature_schema_id'
    raise ValueError(
        "Ground truth and prediction annotations must have set all name or feature ids. "
        "Otherwise there is no key to match on. Please update.")
|
||
|
||
def all_have_key(features: List[FeatureSchema]) -> Tuple[bool, bool]:
    """
    Reports whether every feature has a feature_schema_id and whether every feature has a name.

    Args:
        features (List[FeatureSchema]) : The features to inspect.
    Returns:
        A tuple `(all_schema_ids_set, all_names_set)`. Note the order: the
        schema id flag comes first, the name flag second.
        Both flags are True for an empty list.
    """
    # Two passes are fine here: `features` is a list, so it can be re-iterated.
    all_schemas = all(
        feature.feature_schema_id is not None for feature in features)
    all_names = all(feature.name is not None for feature in features)
    return all_schemas, all_names
|
||
|
||
def get_label_pairs(labels_a: LabelList,
                    labels_b: LabelList,
                    match_on="uid",
                    filter=False) -> Dict[str, Tuple[Label, Label]]:
    """
    Pairs up labels from two collections (e.g. predictions and ground truth) by a data row key.

    There are a few potential problems with this function.
    We are assuming that the data row `uid` or `external id` have been provided by the user.
    However, these particular fields are not required and can be empty.
    If this assumption fails, then the user has to determine their own matching strategy.

    Args:
        labels_a (LabelList): A collection of labels to match with labels_b
        labels_b (LabelList): A collection of labels to match with labels_a
        match_on ('uid' or 'external_id'): The data row key to match labels by. Can either be uid or external id.
        filter (bool): Whether or not to ignore mismatches
    Returns:
        A dict mapping each matched uid / external id to a tuple
        `(label_from_a, label_from_b)`. Keys appear in the order they are
        first seen in labels_a, then labels_b.
    Raises:
        ValueError: If `match_on` is not supported, if any label's data row
            lacks the `match_on` key, or if a key is missing from one of the
            collections while `filter` is False.
    """
    if match_on not in ['uid', 'external_id']:
        raise ValueError("Can only match on `uid` or `external_id`.")

    label_lookup_a = {
        getattr(label.data, match_on, None): label for label in labels_a
    }
    label_lookup_b = {
        getattr(label.data, match_on, None): label for label in labels_b
    }
    if None in label_lookup_a or None in label_lookup_b:
        raise ValueError(
            f"One or more of the labels has a data row without the required key {match_on}."
            " It cannot be determined which labels match without this information."
            f" Either assign {match_on} to each Label or create your own pairing function."
        )

    # Walk keys in input order (a first, then keys only present in b) so that
    # the result and the first reported mismatch are deterministic.
    all_keys = list(label_lookup_a)
    all_keys.extend(key for key in label_lookup_b if key not in label_lookup_a)

    pairs = {}
    for key in all_keys:
        if key not in label_lookup_a or key not in label_lookup_b:
            if not filter:
                raise ValueError(
                    f"{match_on} {key} is not available in both LabelLists. "
                    "Set `filter = True` to filter out these examples, assign the ids manually, or create your own matching function."
                )
            continue
        # Store a flat (a, b) pair, as the return annotation promises, so
        # callers can unpack each value directly.
        pairs[key] = (label_lookup_a[key], label_lookup_b[key])
    return pairs
|
||
|
||
def get_feature_pairs(
    features_a: List[FeatureSchema], features_b: List[FeatureSchema]
) -> Dict[str, Tuple[List[FeatureSchema], List[FeatureSchema]]]:
    """
    Groups the features of both lists by their shared identifying key.

    The identifying key is whatever `get_identifying_key` selects for the
    two lists (feature name or feature schema id).

    Args:
        features_a (List[FeatureSchema]): A list of features to match with features_b
        features_b (List[FeatureSchema]): A list of features to match with features_a
    Returns:
        The matched features as dict. Each key is an identifying value found in
        either list and each value holds two lists: the features from
        features_a and the features from features_b with that value (either
        may be empty).
    """
    id_field = get_identifying_key(features_a, features_b)
    grouped_a = _create_feature_lookup(features_a, id_field)
    grouped_b = _create_feature_lookup(features_b, id_field)

    matched = defaultdict(list)
    every_key = set(grouped_a.keys()).union(set(grouped_b.keys()))
    for group_key in every_key:
        # Indexing a lookup with a key it lacks yields an empty list.
        from_a = grouped_a[group_key]
        from_b = grouped_b[group_key]
        matched[group_key].extend([from_a, from_b])
    return matched
|
||
|
||
def _create_feature_lookup(features: List[FeatureSchema],
                           key: str) -> Dict[str, List[FeatureSchema]]:
    """
    Buckets features by the value of one of their attributes.

    Args:
        features: List of features to group
        key: Name of the attribute to group by
    Returns:
        a defaultdict where each key is a value of that attribute
        and the value is the list of features carrying it, in input order
    """
    lookup = defaultdict(list)
    for item in features:
        bucket_id = getattr(item, key)
        lookup[bucket_id].append(item)
    return lookup
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .calculation import * | ||
from .iou import * |
Oops, something went wrong.