Skip to content

Commit

Permalink
fix ML-AOI schema URI - succeeding STAC Item validation test
Browse files Browse the repository at this point in the history
  • Loading branch information
fmigneault committed Mar 18, 2024
1 parent e01d0b5 commit ef2f130
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 7 deletions.
2 changes: 1 addition & 1 deletion json-schema/schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"$schema": "http://json-schema.org/draft-07/schema#",
"$id": "https://stac-extensions.github.io/ml-aoi/v0.1.0/schema.json#",
"title": "ML AOI Extension",
"description": "ML AOI Extension for STAC Items.",
"description": "ML AOI Extension for STAC definitions.",
"oneOf": [
{
"$comment": "This is the schema for STAC Collections.",
Expand Down
2 changes: 1 addition & 1 deletion pystac_ml_aoi/extensions/ml_aoi.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
ML_AOI_SCHEMA = json.load(schema_file)

ML_AOI_SCHEMA_ID: SchemaName = get_args(SchemaName)[0]
ML_AOI_SCHEMA_URI: str = ML_AOI_SCHEMA["$id"]
ML_AOI_SCHEMA_URI: str = ML_AOI_SCHEMA["$id"].split("#")[0]
ML_AOI_PREFIX = f"{ML_AOI_SCHEMA_ID}:"
ML_AOI_PROPERTY = f"{ML_AOI_SCHEMA_ID}_".replace("-", "_")

Expand Down
59 changes: 54 additions & 5 deletions tests/test_pystac_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,30 @@
from pystac.extensions.label import LabelExtension, LabelTask, LabelType
from pystac_ml_aoi.extensions.ml_aoi import (
ML_AOI_Extension,
ML_AOI_CollectionExtension,
ML_AOI_ItemExtension,
ML_AOI_SCHEMA_PATH,
ML_AOI_SCHEMA_URI,
ML_AOI_Role,
ML_AOI_Split,
)


EUROSAT_EXAMPLE_BASE_URL = "https://raw.githubusercontent.com/ai-extensions/stac-data-loader/0.5.0/data/EuroSAT"
EUROSAT_EXAMPLE_ASSET_ITEM_URL = (
EUROSAT_EXAMPLE_BASE_URL +
"/stac/subset/train/item-42.json"
)
EUROSAT_EXAMPLE_ASSET_LABEL_URL = (
EUROSAT_EXAMPLE_BASE_URL +
"/data/subset/ds/images/remote_sensing/otherDatasets/sentinel_2/label/Residential/Residential_1331.geojson"
)
EUROSAT_EXAMPLE_ASSET_RASTER_URL = (
EUROSAT_EXAMPLE_BASE_URL +
"/data/subset/ds/images/remote_sensing/otherDatasets/sentinel_2/tif/Residential/Residential_1331.tif"
)


@pytest.fixture(scope="session", name="stac_validator", autouse=True)
def make_stac_ml_aoi_validator(
request: pytest.FixtureRequest,
Expand All @@ -38,8 +55,7 @@ def make_stac_ml_aoi_validator(
validator = pystac.validation.RegisteredValidator.get_validator()
validator = cast(pystac.validation.stac_validator.JsonSchemaSTACValidator, validator)
validation_schema = json.loads(pystac.StacIO.default().read_text(ML_AOI_SCHEMA_PATH))
validation_uri = ML_AOI_SCHEMA_URI.split("#")[0]
validator.schema_cache[validation_uri] = validation_schema
validator.schema_cache[ML_AOI_SCHEMA_URI] = validation_schema
pystac.validation.RegisteredValidator.set_validator(validator) # apply globally to allow 'STACObject.validate()'
return validator

Expand Down Expand Up @@ -111,20 +127,39 @@ def make_base_stac_item() -> pystac.Item:
]
}
bbox = list(shapely.geometry.shape(geom).bounds)
asset_label = pystac.Asset(
href=EUROSAT_EXAMPLE_ASSET_LABEL_URL,
media_type=pystac.MediaType.GEOJSON,
roles=["data"],
extra_fields={
"ml-aoi:role": ML_AOI_Role.LABEL,
}
)
asset_raster = pystac.Asset(
href=EUROSAT_EXAMPLE_ASSET_RASTER_URL,
media_type=pystac.MediaType.GEOTIFF,
roles=["data"],
extra_fields={
"ml-aoi:reference-grid": True,
"ml-aoi:role": ML_AOI_Role.FEATURE,
}
)
item = pystac.Item(
id="EuroSAT-subset-train-sample-42-class-Residential",
geometry=geom,
bbox=bbox,
start_datetime=dt_parse("2015-06-27T10:25:31.456Z"),
end_datetime=dt_parse("2017-06-14T00:00:00Z"),
datetime=None,
properties={}
properties={},
assets={"label": asset_label, "raster": asset_raster},
)
label_item = LabelExtension.ext(item, add_if_missing=True)
label_item.apply(
label_description="ml-aoi-test",
label_type=LabelType.RASTER,
label_tasks=[LabelTask.CLASSIFICATION],
label_type=LabelType.VECTOR,
label_tasks=[LabelTask.CLASSIFICATION, LabelTask.SEGMENTATION],
label_properties=["class"],
)
return item

Expand Down Expand Up @@ -193,5 +228,19 @@ def test_ml_aoi_pystac_item_with_field_property(
assert ml_aoi_item_json["properties"]["ml-aoi:split"] == "train"


def test_ml_aoi_pystac_item_filter_assets(
item: pystac.Item,
stac_validator: pystac.validation.stac_validator.JsonSchemaSTACValidator,
) -> None:
"""
Validate extending a STAC Collection with ML-AOI extension.
"""
assert ML_AOI_SCHEMA_URI not in item.stac_extensions
ml_aoi_item = cast(ML_AOI_ItemExtension, ML_AOI_Extension.ext(item, add_if_missing=True))
ml_aoi_item.split = ML_AOI_Split.TRAIN
assets = ml_aoi_item.get_assets(reference_grid=True)
assert len(assets) == 1 and "raster" in assets


if __name__ == "__main__":
unittest.main()

0 comments on commit ef2f130

Please sign in to comment.