diff --git a/geest/core/tasks/study_area.py b/geest/core/tasks/study_area.py index 81e258e..0611b5a 100644 --- a/geest/core/tasks/study_area.py +++ b/geest/core/tasks/study_area.py @@ -23,6 +23,7 @@ QgsVectorFileWriter, QgsFields, QgsCoordinateTransformContext, + QgsWkbTypes, Qgis, ) from qgis.PyQt.QtCore import QVariant @@ -316,24 +317,12 @@ def process_singlepart_geometry( ) # Process the geometry based on the selected mode if self.mode == "vector": - log_message( - f"Creating vector grid for {normalized_name}.", - tag="Geest", - level=Qgis.Info, - ) + log_message(f"Creating vector grid for {normalized_name}.") self.create_and_save_grid(geom, bbox) elif self.mode == "raster": - log_message( - f"Creating raster mask for {normalized_name}.", - tag="Geest", - level=Qgis.Info, - ) + log_message(f"Creating raster mask for {normalized_name}.") self.create_raster_mask(geom, bbox, normalized_name) - log_message( - f"Creating vector grid for {normalized_name}.", - tag="Geest", - level=Qgis.Info, - ) + log_message(f"Creating vector grid for {normalized_name}.") self.create_and_save_grid(geom, bbox) self.counter += 1 @@ -684,13 +673,15 @@ def create_and_save_grid(self, geom: QgsGeometry, bbox: QgsRectangle) -> None: def create_raster_mask( self, geom: QgsGeometry, aligned_box: QgsRectangle, mask_name: str - ) -> None: + ) -> str: """ Creates a 1-bit raster mask for a single geometry. :param geom: Geometry to be rasterized. :param aligned_box: Aligned bounding box for the geometry. :param mask_name: Name for the output raster file. + + :return: The path to the created raster mask. """ mask_filepath = os.path.join(self.working_dir, "study_area", f"{mask_name}.tif") @@ -700,14 +691,38 @@ def create_raster_mask( ) temp_layer_data_provider = temp_layer.dataProvider() # get the geometry as a linestring - multiline = geom.convertToType( - QgsGeometry.Type.Line, QgsGeometry.MultiComponent - ) + multiline = geom.coerceToType(QgsWkbTypes.LineString)[0] + + # Write multiline geometry as WKT to /tmp/multiline.wkt + # multiline_wkt_path = "/tmp/multiline.wkt" + # with open(multiline_wkt_path, "w") as wkt_file: + # wkt_file.write(multiline.asWkt()) + # log_message(f"Multiline geometry written to {multiline_wkt_path}") # select all grid cells that intersect the linestring gpkg_layer_path = f"{self.gpkg_path}|layername=study_area_grid" gpkg_layer = QgsVectorLayer(gpkg_layer_path, "study_area_grid", "ogr") + # Create a spatial index for efficient spatial querying + spatial_index = QgsSpatialIndex(gpkg_layer.getFeatures()) + + # Get feature IDs of candidates that may intersect with the multiline geometry + candidate_ids = spatial_index.intersects(multiline.boundingBox()) + + # Filter candidates by precise geometry intersection + intersecting_ids = [] + for feature_id in candidate_ids: + feature = gpkg_layer.getFeature(feature_id) + if feature.geometry().intersects(multiline): + intersecting_ids.append(feature_id) + + # Select intersecting features in the layer + gpkg_layer.selectByIds(intersecting_ids) + + log_message( + f"Selected {len(intersecting_ids)} features that intersect with the multiline geometry." + ) + # Define a field to store the mask value temp_layer_data_provider.addAttributes( [QgsField(self.field_name, QVariant.String)] @@ -718,8 +733,46 @@ def create_raster_mask( temp_feature = QgsFeature() temp_feature.setGeometry(geom) temp_feature.setAttributes(["1"]) # Setting an arbitrary value for the mask + + # Add the main geometry for this part of the country temp_layer_data_provider.addFeature(temp_feature) + # Now all the grid cells get added that intersect with the geometry border + # since gdal rasterize only includes cells that have 50% coverage or more + # by the looks of things. + selected_features = gpkg_layer.selectedFeatures() + new_features = [] + + for feature in selected_features: + # Create a new feature for emp_layer + new_feature = QgsFeature() + new_feature.setGeometry(feature.geometry()) + new_feature.setAttributes(["1"]) + new_features.append(new_feature) + + # Add the features to temp_layer + temp_layer_data_provider.addFeatures(new_features) + + # commit all changes + temp_layer.updateExtents() + temp_layer.commitChanges() + # check how many features we have + feature_count = temp_layer.featureCount() + log_message( + f"Added {feature_count} features to the temp layer for mask creation." + ) + + # Write temp_layer to /tmp/result.shp + result_shp_path = "/tmp/result.shp" + QgsVectorFileWriter.writeAsVectorFormat( + temp_layer, + result_shp_path, + "utf-8", + temp_layer.crs(), + "ESRI Shapefile", + ) + log_message(f"Temp layer written to {result_shp_path}") + # Ensure resolution parameters are properly formatted as float values x_res = self.cell_size_m # 100m pixel size in X direction y_res = self.cell_size_m # 100m pixel size in Y direction @@ -745,6 +798,7 @@ def create_raster_mask( } processing.run("gdal:rasterize", params) log_message(f"Created raster mask: {mask_filepath}") + return mask_filepath def calculate_utm_zone(self, bbox: QgsRectangle) -> int: """ diff --git a/test/test_study_area_processing_task.py b/test/test_study_area_processing_task.py index b1104eb..d877006 100644 --- a/test/test_study_area_processing_task.py +++ b/test/test_study_area_processing_task.py @@ -107,7 +107,10 @@ def test_process_study_area(self): gpkg_path = os.path.join( self.working_directory, "study_area", "study_area.gpkg" ) - self.assertTrue(os.path.exists(gpkg_path)) + self.assertTrue( + os.path.exists(gpkg_path), + msg=f"GeoPackage not created in {self.working_directory}", + ) def test_process_singlepart_geometry(self): """Test processing of singlepart geometry.""" @@ -128,7 +131,18 @@ def test_process_singlepart_geometry(self): gpkg_path = os.path.join( self.working_directory, "study_area", "study_area.gpkg" ) - self.assertTrue(os.path.exists(gpkg_path)) + self.assertTrue( + os.path.exists(gpkg_path), + msg=f"GeoPackage not created in {self.working_directory}", + ) + # Validate mask is a valid file + mask_path = os.path.join( + self.working_directory, "study_area", "saint_lucia_part0.tif" + ) + self.assertTrue( + os.path.exists(mask_path), + msg=f"mask saint_lucia_part0.tif not created in {mask_path}", + ) def test_grid_aligned_bbox(self): """Test grid alignment of bounding boxes.""" @@ -175,12 +189,16 @@ def test_create_raster_vrt(self): vrt_path = os.path.join( self.working_directory, "study_area", "combined_mask.vrt" ) - self.assertTrue(os.path.exists(vrt_path)) + self.assertTrue( + os.path.exists(vrt_path), + msg=f"VRT file not created in {self.working_directory}", + ) @classmethod def tearDownClass(cls): """Clean up shared resources.""" - if os.path.exists(cls.working_directory): + cleanup = False + if os.path.exists(cls.working_directory) and cleanup: for root, dirs, files in os.walk(cls.working_directory, topdown=False): for name in files: os.remove(os.path.join(root, name))