From 62917fa66d9fe5e164d1b10370ad4c653d111b6c Mon Sep 17 00:00:00 2001 From: <> Date: Sat, 19 Aug 2023 17:42:40 +0000 Subject: [PATCH] Deployed dd691e3 with MkDocs version: 1.5.2 --- .nojekyll | 0 404.html | 641 ++ CNAME | 1 + assets/_mkdocstrings.css | 16 + assets/images/favicon.png | Bin 0 -> 1870 bytes assets/javascripts/bundle.220ee61c.min.js | 29 + assets/javascripts/bundle.220ee61c.min.js.map | 8 + assets/javascripts/lunr/min/lunr.ar.min.js | 1 + assets/javascripts/lunr/min/lunr.da.min.js | 18 + assets/javascripts/lunr/min/lunr.de.min.js | 18 + assets/javascripts/lunr/min/lunr.du.min.js | 18 + assets/javascripts/lunr/min/lunr.es.min.js | 18 + assets/javascripts/lunr/min/lunr.fi.min.js | 18 + assets/javascripts/lunr/min/lunr.fr.min.js | 18 + assets/javascripts/lunr/min/lunr.hi.min.js | 1 + assets/javascripts/lunr/min/lunr.hu.min.js | 18 + assets/javascripts/lunr/min/lunr.hy.min.js | 1 + assets/javascripts/lunr/min/lunr.it.min.js | 18 + assets/javascripts/lunr/min/lunr.ja.min.js | 1 + assets/javascripts/lunr/min/lunr.jp.min.js | 1 + assets/javascripts/lunr/min/lunr.kn.min.js | 1 + assets/javascripts/lunr/min/lunr.ko.min.js | 1 + assets/javascripts/lunr/min/lunr.multi.min.js | 1 + assets/javascripts/lunr/min/lunr.nl.min.js | 18 + assets/javascripts/lunr/min/lunr.no.min.js | 18 + assets/javascripts/lunr/min/lunr.pt.min.js | 18 + assets/javascripts/lunr/min/lunr.ro.min.js | 18 + assets/javascripts/lunr/min/lunr.ru.min.js | 18 + assets/javascripts/lunr/min/lunr.sa.min.js | 1 + .../lunr/min/lunr.stemmer.support.min.js | 1 + assets/javascripts/lunr/min/lunr.sv.min.js | 18 + assets/javascripts/lunr/min/lunr.ta.min.js | 1 + assets/javascripts/lunr/min/lunr.te.min.js | 1 + assets/javascripts/lunr/min/lunr.th.min.js | 1 + assets/javascripts/lunr/min/lunr.tr.min.js | 18 + assets/javascripts/lunr/min/lunr.vi.min.js | 1 + assets/javascripts/lunr/min/lunr.zh.min.js | 1 + assets/javascripts/lunr/tinyseg.js | 206 + assets/javascripts/lunr/wordcut.js | 6708 +++++++++++++++++ 
.../workers/search.74e28a9f.min.js | 42 + .../workers/search.74e28a9f.min.js.map | 8 + assets/stylesheets/main.eebd395e.min.css | 1 + assets/stylesheets/main.eebd395e.min.css.map | 1 + assets/stylesheets/palette.ecc896b0.min.css | 1 + .../stylesheets/palette.ecc896b0.min.css.map | 1 + changelog/changelog.md | 211 + changelog/index.html | 1156 +++ common/common.md | 3 + common/index.html | 6495 ++++++++++++++++ contributing/contributing.md | 125 + contributing/index.html | 975 +++ examples/arcgis/arcgis.ipynb | 466 ++ examples/arcgis/index.html | 2927 +++++++ .../automatic_mask_generator.ipynb | 386 + examples/automatic_mask_generator/index.html | 2660 +++++++ .../automatic_mask_generator_hq.ipynb | 382 + .../automatic_mask_generator_hq/index.html | 2658 +++++++ examples/box_prompts/box_prompts.ipynb | 358 + examples/box_prompts/index.html | 2591 +++++++ examples/data/tree_boxes.geojson | 91 + examples/input_prompts/index.html | 2348 ++++++ examples/input_prompts/input_prompts.ipynb | 268 + examples/input_prompts_hq/index.html | 2347 ++++++ .../input_prompts_hq/input_prompts_hq.ipynb | 266 + examples/satellite-predictor/index.html | 2359 ++++++ .../satellite-predictor.ipynb | 318 + examples/satellite/index.html | 2480 ++++++ examples/satellite/satellite.ipynb | 320 + examples/swimming_pools/index.html | 2588 +++++++ examples/swimming_pools/swimming_pools.ipynb | 368 + examples/text_prompts/index.html | 2577 +++++++ examples/text_prompts/text_prompts.ipynb | 361 + examples/text_prompts_batch/index.html | 2366 ++++++ .../text_prompts_batch.ipynb | 261 + faq/faq.md | 1 + faq/index.html | 689 ++ hq_sam/hq_sam.md | 3 + hq_sam/index.html | 4166 ++++++++++ index.html | 951 +++ index.md | 105 + installation/index.html | 867 +++ installation/installation.md | 102 + objects.inv | Bin 0 -> 854 bytes overrides/main.html | 11 + samgeo/index.html | 4155 ++++++++++ samgeo/samgeo.md | 4 + search/search_index.json | 1 + sitemap.xml | 108 + sitemap.xml.gz | Bin 0 -> 365 bytes 
text_sam/index.html | 2804 +++++++ text_sam/text_sam.md | 3 + usage/index.html | 746 ++ usage/usage.md | 40 + 93 files changed, 63415 insertions(+) create mode 100644 .nojekyll create mode 100644 404.html create mode 100644 CNAME create mode 100644 assets/_mkdocstrings.css create mode 100644 assets/images/favicon.png create mode 100644 assets/javascripts/bundle.220ee61c.min.js create mode 100644 assets/javascripts/bundle.220ee61c.min.js.map create mode 100644 assets/javascripts/lunr/min/lunr.ar.min.js create mode 100644 assets/javascripts/lunr/min/lunr.da.min.js create mode 100644 assets/javascripts/lunr/min/lunr.de.min.js create mode 100644 assets/javascripts/lunr/min/lunr.du.min.js create mode 100644 assets/javascripts/lunr/min/lunr.es.min.js create mode 100644 assets/javascripts/lunr/min/lunr.fi.min.js create mode 100644 assets/javascripts/lunr/min/lunr.fr.min.js create mode 100644 assets/javascripts/lunr/min/lunr.hi.min.js create mode 100644 assets/javascripts/lunr/min/lunr.hu.min.js create mode 100644 assets/javascripts/lunr/min/lunr.hy.min.js create mode 100644 assets/javascripts/lunr/min/lunr.it.min.js create mode 100644 assets/javascripts/lunr/min/lunr.ja.min.js create mode 100644 assets/javascripts/lunr/min/lunr.jp.min.js create mode 100644 assets/javascripts/lunr/min/lunr.kn.min.js create mode 100644 assets/javascripts/lunr/min/lunr.ko.min.js create mode 100644 assets/javascripts/lunr/min/lunr.multi.min.js create mode 100644 assets/javascripts/lunr/min/lunr.nl.min.js create mode 100644 assets/javascripts/lunr/min/lunr.no.min.js create mode 100644 assets/javascripts/lunr/min/lunr.pt.min.js create mode 100644 assets/javascripts/lunr/min/lunr.ro.min.js create mode 100644 assets/javascripts/lunr/min/lunr.ru.min.js create mode 100644 assets/javascripts/lunr/min/lunr.sa.min.js create mode 100644 assets/javascripts/lunr/min/lunr.stemmer.support.min.js create mode 100644 assets/javascripts/lunr/min/lunr.sv.min.js create mode 100644 
assets/javascripts/lunr/min/lunr.ta.min.js create mode 100644 assets/javascripts/lunr/min/lunr.te.min.js create mode 100644 assets/javascripts/lunr/min/lunr.th.min.js create mode 100644 assets/javascripts/lunr/min/lunr.tr.min.js create mode 100644 assets/javascripts/lunr/min/lunr.vi.min.js create mode 100644 assets/javascripts/lunr/min/lunr.zh.min.js create mode 100644 assets/javascripts/lunr/tinyseg.js create mode 100644 assets/javascripts/lunr/wordcut.js create mode 100644 assets/javascripts/workers/search.74e28a9f.min.js create mode 100644 assets/javascripts/workers/search.74e28a9f.min.js.map create mode 100644 assets/stylesheets/main.eebd395e.min.css create mode 100644 assets/stylesheets/main.eebd395e.min.css.map create mode 100644 assets/stylesheets/palette.ecc896b0.min.css create mode 100644 assets/stylesheets/palette.ecc896b0.min.css.map create mode 100644 changelog/changelog.md create mode 100644 changelog/index.html create mode 100644 common/common.md create mode 100644 common/index.html create mode 100644 contributing/contributing.md create mode 100644 contributing/index.html create mode 100644 examples/arcgis/arcgis.ipynb create mode 100644 examples/arcgis/index.html create mode 100644 examples/automatic_mask_generator/automatic_mask_generator.ipynb create mode 100644 examples/automatic_mask_generator/index.html create mode 100644 examples/automatic_mask_generator_hq/automatic_mask_generator_hq.ipynb create mode 100644 examples/automatic_mask_generator_hq/index.html create mode 100644 examples/box_prompts/box_prompts.ipynb create mode 100644 examples/box_prompts/index.html create mode 100644 examples/data/tree_boxes.geojson create mode 100644 examples/input_prompts/index.html create mode 100644 examples/input_prompts/input_prompts.ipynb create mode 100644 examples/input_prompts_hq/index.html create mode 100644 examples/input_prompts_hq/input_prompts_hq.ipynb create mode 100644 examples/satellite-predictor/index.html create mode 100644 
examples/satellite-predictor/satellite-predictor.ipynb create mode 100644 examples/satellite/index.html create mode 100644 examples/satellite/satellite.ipynb create mode 100644 examples/swimming_pools/index.html create mode 100644 examples/swimming_pools/swimming_pools.ipynb create mode 100644 examples/text_prompts/index.html create mode 100644 examples/text_prompts/text_prompts.ipynb create mode 100644 examples/text_prompts_batch/index.html create mode 100644 examples/text_prompts_batch/text_prompts_batch.ipynb create mode 100644 faq/faq.md create mode 100644 faq/index.html create mode 100644 hq_sam/hq_sam.md create mode 100644 hq_sam/index.html create mode 100644 index.html create mode 100644 index.md create mode 100644 installation/index.html create mode 100644 installation/installation.md create mode 100644 objects.inv create mode 100644 overrides/main.html create mode 100644 samgeo/index.html create mode 100644 samgeo/samgeo.md create mode 100644 search/search_index.json create mode 100644 sitemap.xml create mode 100644 sitemap.xml.gz create mode 100644 text_sam/index.html create mode 100644 text_sam/text_sam.md create mode 100644 usage/index.html create mode 100644 usage/usage.md diff --git a/.nojekyll b/.nojekyll new file mode 100644 index 00000000..e69de29b diff --git a/404.html b/404.html new file mode 100644 index 00000000..b4dd8a0c --- /dev/null +++ b/404.html @@ -0,0 +1,641 @@ + + + +
+ + + + + + + + + + + + + +New Features
+New Features
+Improvements
+Improvements
+Improvements
+New Features
+Improvements
+New Features
+Improvements
+Improvements
+New Features
+Improvements
+New Features
+Improvements
+Contributors
+@p-vdp @LucasOsco
+Improvements
+New Features
+New Features
+Improvements
+Demos
+New Features
+Improvements
+Demos
+ +New Features
+SamGeo
class, including show_masks
, save_masks
, show_anns
, making it much easier to save segmentation results in GeoTIFF and vector formats.common
module, including array_to_image
, show_image
, download_file
, overlay_images
, blend_images
, and update_package
SamGeoPredictor
classImprovements
+SamGeo.generate()
methodDemos
+Contributors
+@darrenwiens
+New Features
+get_basemaps
, reproject
, tiff_to_shp
, and tiff_to_geojson
Improvement
+tiff_to_vector
crs bug #12crs
parameter to tms_to_geotiff
New Features
+SamGeo.generate
methodSamGeo.tiff_to_vector
methodNew Features
+SamGeo
classInitial release
+ +The source code is adapted from https://github.com/aliaksandr960/segment-anything-eo. Credit to the author Aliaksandr Hancharenka.
+ + + +array_to_image(array, output, source=None, dtype=None, compress='deflate', **kwargs)
+
+
+¶Save a NumPy array as a GeoTIFF using the projection information from an existing GeoTIFF file.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
array |
+ np.ndarray |
+ The NumPy array to be saved as a GeoTIFF. |
+ required | +
output |
+ str |
+ The path to the output image. |
+ required | +
source |
+ str |
+ The path to an existing GeoTIFF file with map projection information. Defaults to None. |
+ None |
+
dtype |
+ np.dtype |
+ The data type of the output array. Defaults to None. |
+ None |
+
compress |
+ str |
+ The compression method. Can be one of the following: "deflate", "lzw", "packbits", "jpeg". Defaults to "deflate". |
+ 'deflate' |
+
samgeo/common.py
def array_to_image(
    array, output, source=None, dtype=None, compress="deflate", **kwargs
):
    """Save a NumPy array as a GeoTIFF using the projection information from an existing GeoTIFF file.

    If ``output`` ends with ".tif" and ``source`` is provided, the array is
    written with rasterio using the CRS and affine transform read from
    ``source``. In every other case the array is saved with Pillow.

    Args:
        array (np.ndarray | str): The NumPy array to be saved. A path to an existing
            image file is also accepted; it is read with OpenCV and converted to RGB.
        output (str): The path to the output image.
        source (str, optional): The path to an existing GeoTIFF file with map projection information. Defaults to None.
        dtype (np.dtype, optional): The data type of the output array. When None, a dtype
            is chosen from the value range of the array. Defaults to None.
        compress (str, optional): The compression method. Can be one of the following: "deflate", "lzw", "packbits", "jpeg".
            When None, the compression reported by ``source`` is reused, if any. Defaults to "deflate".
        **kwargs: Additional keyword arguments passed to the Pillow ``save`` call
            when the output is not written as a GeoTIFF.

    Raises:
        ValueError: If a GeoTIFF is requested and the array is not 2D or 3D.
    """
    if isinstance(array, str) and os.path.exists(array):
        # OpenCV decodes to BGR; convert so the channel order is RGB.
        array = cv2.imread(array)
        array = cv2.cvtColor(array, cv2.COLOR_BGR2RGB)

    if not output.endswith(".tif") or source is None:
        # Imported lazily: Pillow is only needed for the non-GeoTIFF path.
        from PIL import Image

        Image.fromarray(array).save(output, **kwargs)
        return

    # Validate before opening any file so a bad array fails fast and with
    # the right message (previously this error was tied to ``compress``).
    if array.ndim not in (2, 3):
        raise ValueError("Array must be 2D or 3D.")

    with rasterio.open(source) as src:
        crs = src.crs
        transform = src.transform
        if compress is None:
            compress = src.compression

    if dtype is None:
        # Pick the narrowest dtype whose range covers the data.
        min_value = np.min(array)
        max_value = np.max(array)
        if min_value >= 0 and max_value <= 1:
            dtype = np.float32
        elif min_value >= 0 and max_value <= 255:
            dtype = np.uint8
        elif min_value >= -128 and max_value <= 127:
            dtype = np.int8
        elif min_value >= 0 and max_value <= 65535:
            dtype = np.uint16
        elif min_value >= -32768 and max_value <= 32767:
            dtype = np.int16
        else:
            dtype = np.float64

    array = array.astype(dtype)

    # A 3D array is written band by band along its last axis.
    band_count = 1 if array.ndim == 2 else array.shape[2]
    metadata = {
        "driver": "GTiff",
        "height": array.shape[0],
        "width": array.shape[1],
        "count": band_count,
        "dtype": array.dtype,
        "crs": crs,
        "transform": transform,
    }
    # Compression is optional; its absence is not an error.
    if compress is not None:
        metadata["compress"] = compress

    with rasterio.open(output, "w", **metadata) as dst:
        if array.ndim == 2:
            dst.write(array, 1)
        else:
            for band in range(band_count):
                dst.write(array[:, :, band], band + 1)
+
bbox_to_xy(src_fp, coords, coord_crs='epsg:4326', **kwargs)
+
+
+¶Converts a list of coordinates to pixel coordinates, i.e., (col, row) coordinates. + Note that map bbox coords are [minx, miny, maxx, maxy], from bottom left to top right, + while rasterio bbox coords are [minx, maxy, maxx, miny], from top left to bottom right.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
src_fp |
+ str |
+ The source raster file path. |
+ required | +
coords |
+ list |
+ A list of coordinates in the format of [[minx, miny, maxx, maxy], [minx, miny, maxx, maxy], ...] |
+ required | +
coord_crs |
+ str |
+ The coordinate CRS of the input coordinates. Defaults to "epsg:4326". |
+ 'epsg:4326' |
+
Returns:
+Type | +Description | +
---|---|
list |
+ A list of pixel coordinates in the format of [[minx, maxy, maxx, miny], ...] from top left to bottom right. |
+
samgeo/common.py
def bbox_to_xy(
    src_fp: str, coords: list, coord_crs: str = "epsg:4326", **kwargs
) -> list:
    """Convert bounding boxes to pixel coordinates, i.e., (col, row) coordinates.

    Map bounding boxes are [minx, miny, maxx, maxy] (bottom left to top right),
    whereas the boxes returned here are [minx, maxy, maxx, miny] (top left to
    bottom right).

    Args:
        src_fp (str): The source raster file path.
        coords (list): A list of boxes in the format of
            [[minx, miny, maxx, maxy], [minx, miny, maxx, maxy], ...]. A single flat
            box, a NumPy array, a GeoJSON-like dict, or a path readable by
            geopandas is also accepted.
        coord_crs (str, optional): The coordinate CRS of the input coordinates.
            Replaced by the file's CRS when ``coords`` is a path with a CRS.
            Defaults to "epsg:4326".
        **kwargs: Forwarded to both ``transform_coords`` and
            ``rasterio.transform.rowcol``.

    Returns:
        list: Boxes in the format of [[minx, maxy, maxx, miny], ...] that fall
            inside the raster. A single flat box is returned when exactly one
            box is valid, and None when no box is valid.

    Raises:
        ValueError: If ``coords`` is not one of the accepted input types.
    """
    # Normalise every accepted input flavour to a plain list of boxes.
    if isinstance(coords, str):
        frame = gpd.read_file(coords)
        coords = frame.geometry.bounds.values.tolist()
        if frame.crs is not None:
            coord_crs = f"epsg:{frame.crs.to_epsg()}"
    elif isinstance(coords, np.ndarray):
        coords = coords.tolist()

    if isinstance(coords, dict):
        import json

        frame = gpd.read_file(json.dumps(coords), driver="GeoJSON")
        coords = frame.geometry.bounds.values.tolist()
    elif not isinstance(coords, list):
        raise ValueError("coords must be a list of coordinates.")

    # A single flat box is wrapped so the loop below sees a list of boxes.
    if not isinstance(coords[0], list):
        coords = [coords]

    candidates = []
    with rasterio.open(src_fp) as src:
        width = src.width
        height = src.height

        for minx, miny, maxx, maxy in coords:
            if coord_crs != src.crs:
                minx, miny = transform_coords(minx, miny, coord_crs, src.crs, **kwargs)
                maxx, maxy = transform_coords(maxx, maxy, coord_crs, src.crs, **kwargs)

                row_a, col_a = rasterio.transform.rowcol(
                    src.transform, minx, miny, **kwargs
                )
                row_b, col_b = rasterio.transform.rowcol(
                    src.transform, maxx, maxy, **kwargs
                )
                candidates.append([col_a, row_a, col_b, row_b])
            else:
                # NOTE(review): when the CRS already matches, the values are kept
                # as given and not converted with rowcol — confirm this is intended.
                candidates.append([minx, miny, maxx, maxy])

    # Keep only boxes fully inside the raster, flipping the y order so each
    # box runs from top left to bottom right.
    result = [
        [x0, y1, x1, y0]
        for x0, y0, x1, y1 in candidates
        if 0 <= x0 < width
        and 0 <= y0 < height
        and 0 <= x1 < width
        and 0 <= y1 < height
    ]

    if len(result) == 0:
        print("No valid pixel coordinates found.")
        return None
    if len(result) == 1:
        return result[0]
    if len(result) < len(coords):
        print("Some coordinates are out of the image boundary.")

    return result
+
blend_images(img1, img2, alpha=0.5, output=False, show=True, figsize=(12, 10), axis='off', **kwargs)
+
+
+¶Blends two images together using the addWeighted function from the OpenCV library.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
img1 |
+ numpy.ndarray |
+ The first input image on top represented as a NumPy array. |
+ required | +
img2 |
+ numpy.ndarray |
+ The second input image at the bottom represented as a NumPy array. |
+ required | +
alpha |
+ float |
+ The weighting factor for the first image in the blend. By default, this is set to 0.5. |
+ 0.5 |
+
output |
+ str |
+ The path to the output image. Defaults to False. |
+ False |
+
show |
+ bool |
+ Whether to display the blended image. Defaults to True. |
+ True |
+
figsize |
+ tuple |
+ The size of the figure. Defaults to (12, 10). |
+ (12, 10) |
+
axis |
+ str |
+ The axis of the figure. Defaults to "off". |
+ 'off' |
+
**kwargs |
+ + | Additional keyword arguments to pass to the cv2.addWeighted() function. |
+ {} |
+
Returns:
+Type | +Description | +
---|---|
numpy.ndarray |
+ The blended image as a NumPy array. |
+
samgeo/common.py
def blend_images(
    img1,
    img2,
    alpha=0.5,
    output=False,
    show=True,
    figsize=(12, 10),
    axis="off",
    **kwargs,
):
    """
    Blends two images together using the addWeighted function from the OpenCV library.

    Args:
        img1 (numpy.ndarray | str): The first input image on top, as a NumPy array,
            a local file path, or an HTTP(S) URL.
        img2 (numpy.ndarray | str): The second input image at the bottom, as a NumPy
            array, a local file path, or an HTTP(S) URL.
        alpha (float): The weighting factor for the first image in the blend. By default, this is set to 0.5.
        output (str, optional): The path to the output image. Defaults to False.
        show (bool, optional): Whether to display the blended image. Defaults to True.
        figsize (tuple, optional): The size of the figure. Defaults to (12, 10).
        axis (str, optional): The axis of the figure. Defaults to "off".
        **kwargs: Additional keyword arguments to pass to the cv2.addWeighted() function.

    Returns:
        numpy.ndarray: The blended image as a NumPy array. It is only returned
            when ``show`` is False; otherwise the image is displayed and
            nothing is returned.

    Raises:
        ValueError: If an input path does not exist.
    """
    # File path of the bottom image, kept so the saved output can reuse its
    # georeferencing. It stays None when img2 is passed as an array.
    img2_path = None

    if isinstance(img1, str):
        if img1.startswith("http"):
            img1 = download_file(img1)

        if not os.path.exists(img1):
            raise ValueError(f"Input path {img1} does not exist.")

        img1 = cv2.imread(img1)

    if isinstance(img2, str):
        if img2.startswith("http"):
            img2 = download_file(img2)

        if not os.path.exists(img2):
            raise ValueError(f"Input path {img2} does not exist.")

        img2_path = img2
        img2 = cv2.imread(img2)

    # Bring both images to a common dtype; float32 input is scaled by 255.
    if img1.dtype == np.float32:
        img1 = (img1 * 255).astype(np.uint8)

    if img2.dtype == np.float32:
        img2 = (img2 * 255).astype(np.uint8)

    if img1.dtype != img2.dtype:
        img2 = img2.astype(img1.dtype)

    # Resize the top image to the dimensions of the bottom image
    img1 = cv2.resize(img1, (img2.shape[1], img2.shape[0]))

    # Blend the images using the addWeighted function
    beta = 1 - alpha
    blend_img = cv2.addWeighted(img1, alpha, img2, beta, 0, **kwargs)

    if output:
        # Pass the file path, not the decoded pixels: array_to_image opens
        # its source argument as a raster file.
        array_to_image(blend_img, output, img2_path)

    if show:
        plt.figure(figsize=figsize)
        plt.imshow(blend_img)
        plt.axis(axis)
        plt.show()
    else:
        return blend_img
+
boxes_to_vector(coords, src_crs, dst_crs='EPSG:4326', output=None, **kwargs)
+
+
+¶Convert a list of bounding box coordinates to vector data.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
coords |
+ list |
+ A list of bounding box coordinates in the format [[left, top, right, bottom], [left, top, right, bottom], ...]. |
+ required | +
src_crs |
+ int or str |
+ The EPSG code or proj4 string representing the source coordinate reference system (CRS) of the input coordinates. |
+ required | +
dst_crs |
+ int or str |
+ The EPSG code or proj4 string representing the destination CRS to reproject the data (default is "EPSG:4326"). |
+ 'EPSG:4326' |
+
output |
+ str or None |
+ The full file path (including the directory and filename without the extension) where the vector data should be saved. + If None (default), the function returns the GeoDataFrame without saving it to a file. |
+ None |
+
**kwargs |
+ + | Additional keyword arguments to pass to geopandas.GeoDataFrame.to_file() when saving the vector data. |
+ {} |
+
Returns:
+Type | +Description | +
---|---|
geopandas.GeoDataFrame or None |
+ The GeoDataFrame with the converted vector data if output is None, otherwise None if the data is saved to a file. |
+
samgeo/common.py
def boxes_to_vector(coords, src_crs, dst_crs="EPSG:4326", output=None, **kwargs):
    """
    Turn bounding boxes into polygon vector data, reprojected to ``dst_crs``.

    Args:
        coords (list): A list of bounding box coordinates in the format [[left, top, right, bottom], [left, top, right, bottom], ...].
            Each entry is unpacked into ``shapely.geometry.box``.
        src_crs (int or str): The EPSG code or proj4 string representing the source coordinate reference system (CRS) of the input coordinates.
        dst_crs (int or str, optional): The EPSG code or proj4 string representing the destination CRS to reproject the data (default is "EPSG:4326").
        output (str or None, optional): The file path where the vector data should be saved.
            If None (default), the function returns the GeoDataFrame without saving it to a file.
        **kwargs: Additional keyword arguments to pass to geopandas.GeoDataFrame.to_file() when saving the vector data.

    Returns:
        geopandas.GeoDataFrame or None: The reprojected GeoDataFrame if output is None, otherwise None (the data is written to ``output``).
    """

    from shapely.geometry import box

    # One rectangle polygon per bounding box.
    geometries = []
    for bounds in coords:
        geometries.append(box(*bounds))

    frame = gpd.GeoDataFrame({"geometry": geometries}, crs=src_crs)
    reprojected = frame.to_crs(dst_crs)

    if output is None:
        return reprojected

    reprojected.to_file(output, **kwargs)
+
check_file_path(file_path, make_dirs=True)
+
+
+¶Gets the absolute file path.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
file_path |
+ str |
+ The path to the file. |
+ required | +
make_dirs |
+ bool |
+ Whether to create the directory if it does not exist. Defaults to True. |
+ True |
+
Exceptions:
+Type | +Description | +
---|---|
FileNotFoundError |
+ If the directory could not be found. |
+
TypeError |
+ If the input directory path is not a string. |
+
Returns:
+Type | +Description | +
---|---|
str |
+ The absolute path to the file. |
+
samgeo/common.py
def check_file_path(file_path, make_dirs=True):
    """Gets the absolute file path.

    A path starting with "~" is expanded with ``os.path.expanduser``; any
    other path is made absolute with ``os.path.abspath``.

    Args:
        file_path (str): The path to the file.
        make_dirs (bool, optional): Whether to create the parent directory if it does not exist. Defaults to True.

    Raises:
        TypeError: If the input file path is not a string.

    Returns:
        str: The absolute path to the file.
    """
    if not isinstance(file_path, str):
        raise TypeError("The provided file path must be a string.")

    if file_path.startswith("~"):
        file_path = os.path.expanduser(file_path)
    else:
        file_path = os.path.abspath(file_path)

    file_dir = os.path.dirname(file_path)
    if make_dirs and not os.path.exists(file_dir):
        # exist_ok guards against the directory being created by another
        # process between the existence check and this call.
        os.makedirs(file_dir, exist_ok=True)

    return file_path
+
coords_to_geojson(coords, output=None)
+
+
+¶Convert a list of coordinates (lon, lat) to a GeoJSON string or file.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
coords |
+ list |
+ A list of coordinates (lon, lat). |
+ required | +
output |
+ str |
+ The output file path. Defaults to None. |
+ None |
+
Returns:
+Type | +Description | +
---|---|
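str |
+ A GeoJSON string when no output path is given; nothing is returned when the result is written to a file or the input is empty. |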
+
samgeo/common.py
def coords_to_geojson(coords, output=None):
    """Convert a list of coordinates (lon, lat) to a GeoJSON string or file.

    Each coordinate becomes a Point feature with empty properties inside a
    single FeatureCollection.

    Args:
        coords (list): A list of coordinates (lon, lat).
        output (str, optional): The output file path. Defaults to None.

    Returns:
        str or None: The FeatureCollection serialized as a JSON string when
            ``output`` is None. None is returned when ``coords`` is empty or
            when the string is written to ``output`` instead.
    """

    import json

    if len(coords) == 0:
        return

    feature_collection = {
        "type": "FeatureCollection",
        "features": [
            {
                "type": "Feature",
                "geometry": {"type": "Point", "coordinates": coord},
                "properties": {},
            }
            for coord in coords
        ],
    }
    geojson_str = json.dumps(feature_collection)

    if output is None:
        return geojson_str

    with open(output, "w") as f:
        f.write(geojson_str)
+
coords_to_xy(src_fp, coords, coord_crs='epsg:4326', **kwargs)
+
+
+¶Converts a list of coordinates to pixel coordinates, i.e., (col, row) coordinates.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
src_fp |
+ str |
+ The source raster file path. |
+ required | +
coords |
+ list |
+ A list of coordinates in the format of [[x1, y1], [x2, y2], ...] |
+ required | +
coord_crs |
+ str |
+ The coordinate CRS of the input coordinates. Defaults to "epsg:4326". |
+ 'epsg:4326' |
+
**kwargs |
+ + | Additional keyword arguments to pass to rasterio.transform.rowcol. |
+ {} |
+
Returns:
+Type | +Description | +
---|---|
list |
+ A list of pixel coordinates in the format of [[x1, y1], [x2, y2], ...] |
+
samgeo/common.py
def coords_to_xy(
    src_fp: str, coords: list, coord_crs: str = "epsg:4326", **kwargs
) -> list:
    """Map point coordinates onto pixel positions, i.e., (col, row), of a raster.

    Args:
        src_fp: The source raster file path.
        coords: A list of coordinates in the format of [[x1, y1], [x2, y2], ...].
            A NumPy array is converted to a list first.
        coord_crs: The coordinate CRS of the input coordinates. Defaults to "epsg:4326".
        **kwargs: Additional keyword arguments to pass to rasterio.transform.rowcol.
            They are also forwarded to transform_coords when reprojection is needed.

    Returns:
        A list of pixel coordinates in the format of [[x1, y1], [x2, y2], ...],
        restricted to the points that fall inside the raster.
    """
    if isinstance(coords, np.ndarray):
        coords = coords.tolist()

    xs, ys = zip(*coords)
    with rasterio.open(src_fp) as src:
        width = src.width
        height = src.height
        # Reproject only when the input CRS differs from the raster CRS.
        if coord_crs != src.crs:
            xs, ys = transform_coords(xs, ys, coord_crs, src.crs, **kwargs)
        rows, cols = rasterio.transform.rowcol(src.transform, xs, ys, **kwargs)

    # Drop every point that lands outside the raster extent.
    inside = [
        [col, row]
        for col, row in zip(cols, rows)
        if 0 <= col < width and 0 <= row < height
    ]

    if len(inside) == 0:
        print("No valid pixel coordinates found.")
    elif len(inside) < len(coords):
        print("Some coordinates are out of the image boundary.")

    return inside
+
download_checkpoint(model_type='vit_h', checkpoint_dir=None, hq=False)
+
+
+¶Download the SAM model checkpoint.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
model_type |
+ str |
+ The model type. Can be one of ['vit_h', 'vit_l', 'vit_b']. +Defaults to 'vit_h'. See https://bit.ly/3VrpxUh for more details. |
+ 'vit_h' |
+
checkpoint_dir |
+ str |
+ The checkpoint_dir directory. Defaults to None, "~/.cache/torch/hub/checkpoints". |
+ None |
+
hq |
+ bool |
+ Whether to use HQ-SAM model (https://github.com/SysCV/sam-hq). Defaults to False. |
+ False |
+
samgeo/common.py
def download_checkpoint(model_type="vit_h", checkpoint_dir=None, hq=False):
    """Locate the SAM model checkpoint, downloading it first if it is missing.

    Args:
        model_type (str, optional): The model type. Can be one of ['vit_h', 'vit_l', 'vit_b'];
            'vit_tiny' is additionally available when ``hq`` is True.
            Defaults to 'vit_h'. See https://bit.ly/3VrpxUh for more details.
        checkpoint_dir (str, optional): The directory holding the checkpoints. When None, the
            TORCH_HOME environment variable is used if set, otherwise
            "~/.cache/torch/hub/checkpoints". Defaults to None.
        hq (bool, optional): Whether to use HQ-SAM model (https://github.com/SysCV/sam-hq). Defaults to False.

    Returns:
        str: The path to the checkpoint file.

    Raises:
        ValueError: If ``model_type`` is not available for the selected model family.
    """
    sam_models = {
        "vit_h": {
            "name": "sam_vit_h_4b8939.pth",
            "url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
        },
        "vit_l": {
            "name": "sam_vit_l_0b3195.pth",
            "url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
        },
        "vit_b": {
            "name": "sam_vit_b_01ec64.pth",
            "url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
        },
    }
    hq_models = {
        "vit_h": {
            "name": "sam_hq_vit_h.pth",
            "url": "https://drive.google.com/file/d/1qobFYrI4eyIANfBSmYcGuWRaSIXfMOQ8/view?usp=sharing",
        },
        "vit_l": {
            "name": "sam_hq_vit_l.pth",
            "url": "https://drive.google.com/file/d/1Uk17tDKX1YAKas5knI4y9ZJCo0lRVL0G/view?usp=sharing",
        },
        "vit_b": {
            "name": "sam_hq_vit_b.pth",
            "url": "https://drive.google.com/file/d/11yExZLOve38kRZPfRx_MRxfIAKmfMY47/view?usp=sharing",
        },
        "vit_tiny": {
            "name": "sam_hq_vit_tiny.pth",
            "url": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_tiny.pth",
        },
    }
    model_types = hq_models if hq else sam_models

    if model_type not in model_types:
        raise ValueError(
            f"Invalid model_type: {model_type}. It must be one of {', '.join(model_types)}"
        )
    selected = model_types[model_type]

    if checkpoint_dir is None:
        # NOTE(review): TORCH_HOME is used as the checkpoint directory itself,
        # without the "hub/checkpoints" suffix of the fallback — confirm intended.
        checkpoint_dir = os.environ.get(
            "TORCH_HOME", os.path.expanduser("~/.cache/torch/hub/checkpoints")
        )

    checkpoint = os.path.join(checkpoint_dir, selected["name"])
    if os.path.exists(checkpoint):
        return checkpoint

    print(f"Model checkpoint for {model_type} not found.")
    download_file(selected["url"], checkpoint)
    return checkpoint
+
download_checkpoint_legacy(url=None, output=None, overwrite=False, **kwargs)
+
+
+¶Download a checkpoint from URL. It can be one of the following: sam_vit_h_4b8939.pth, sam_vit_l_0b3195.pth, sam_vit_b_01ec64.pth.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
url |
+ str |
+ The checkpoint URL. Defaults to None. |
+ None |
+
output |
+ str |
+ The output file path. Defaults to None. |
+ None |
+
overwrite |
+ bool |
+ Overwrite the file if it already exists. Defaults to False. |
+ False |
+
Returns:
+Type | +Description | +
---|---|
str |
+ The output file path. |
+
samgeo/common.py
def download_checkpoint_legacy(url=None, output=None, overwrite=False, **kwargs):
    """Download a SAM checkpoint (or any other file) through download_file.

    Args:
        url (str, optional): Either a download URL or one of the known
            checkpoint file names (sam_vit_h_4b8939.pth, sam_vit_l_0b3195.pth,
            sam_vit_b_01ec64.pth), which is expanded to its URL. Defaults to
            None, which selects sam_vit_h_4b8939.pth.
        output (str, optional): The output file path. Defaults to None, in
            which case the basename of the resolved URL is used.
        overwrite (bool, optional): Overwrite the file if it already exists. Defaults to False.
        **kwargs: Additional keyword arguments forwarded to download_file.

    Returns:
        str: The output file path.
    """
    known_checkpoints = {
        "sam_vit_h_4b8939.pth": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
        "sam_vit_l_0b3195.pth": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
        "sam_vit_b_01ec64.pth": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
    }

    if url is None:
        # No URL given: fall back to the ViT-H checkpoint.
        url = known_checkpoints["sam_vit_h_4b8939.pth"]
    elif isinstance(url, str) and url in known_checkpoints:
        # A bare checkpoint file name is shorthand for its download URL.
        url = known_checkpoints[url]

    target = os.path.basename(url) if output is None else output
    return download_file(url, target, overwrite=overwrite, **kwargs)
+
download_file(url=None, output=None, quiet=False, proxy=None, speed=None, use_cookies=True, verify=True, id=None, fuzzy=False, resume=False, unzip=True, overwrite=False, subfolder=False)
+
+
+¶Download a file from URL, including Google Drive shared URL.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
url |
+ str |
+ Google Drive URL is also supported. Defaults to None. |
+ None |
+
output |
+ str |
+ Output filename. Default is basename of URL. |
+ None |
+
quiet |
+ bool |
+ Suppress terminal output. Default is False. |
+ False |
+
proxy |
+ str |
+ Proxy. Defaults to None. |
+ None |
+
speed |
+ float |
+ Download byte size per second (e.g., 256KB/s = 256 * 1024). Defaults to None. |
+ None |
+
use_cookies |
+ bool |
+ Flag to use cookies. Defaults to True. |
+ True |
+
verify |
+ bool | str |
+ Either a bool, in which case it controls whether the server's TLS certificate is verified, or a string, +in which case it must be a path to a CA bundle to use. Defaults to True. |
+ True |
+
id |
+ str |
+ Google Drive's file ID. Defaults to None. |
+ None |
+
fuzzy |
+ bool |
+ Fuzzy extraction of Google Drive's file Id. Defaults to False. |
+ False |
+
resume |
+ bool |
+ Resume the download from existing tmp file if possible. Defaults to False. |
+ False |
+
unzip |
+ bool |
+ Unzip the file. Defaults to True. |
+ True |
+
overwrite |
+ bool |
+ Overwrite the file if it already exists. Defaults to False. |
+ False |
+
subfolder |
+ bool |
+ Create a subfolder with the same name as the file. Defaults to False. |
+ False |
+
Returns:
+Type | +Description | +
---|---|
str |
+ The output file path. |
+
samgeo/common.py
def download_file(
    url=None,
    output=None,
    quiet=False,
    proxy=None,
    speed=None,
    use_cookies=True,
    verify=True,
    id=None,
    fuzzy=False,
    resume=False,
    unzip=True,
    overwrite=False,
    subfolder=False,
):
    """Download a file from URL, including Google Drive shared URL.

    Args:
        url (str, optional): Google Drive URL is also supported. Defaults to None.
        output (str, optional): Output filename. Default is basename of URL.
        quiet (bool, optional): Suppress terminal output. Default is False.
        proxy (str, optional): Proxy. Defaults to None.
        speed (float, optional): Download byte size per second (e.g., 256KB/s = 256 * 1024). Defaults to None.
        use_cookies (bool, optional): Flag to use cookies. Defaults to True.
        verify (bool | str, optional): Either a bool, in which case it controls whether the server's TLS certificate is verified, or a string,
            in which case it must be a path to a CA bundle to use. Defaults to True.
        id (str, optional): Google Drive's file ID. Defaults to None.
        fuzzy (bool, optional): Fuzzy extraction of Google Drive's file Id. Defaults to False.
        resume (bool, optional): Resume the download from existing tmp file if possible. Defaults to False.
        unzip (bool, optional): Unzip the file. Defaults to True.
        overwrite (bool, optional): Overwrite the file if it already exists. Defaults to False.
        subfolder (bool, optional): Create a subfolder with the same name as the file. Defaults to False.

    Returns:
        str: The absolute output path (the extraction folder when a zip file is
            unpacked with subfolder=True). None is returned when gdown is not
            installed.

    Raises:
        RuntimeError: If gdown reports that nothing was downloaded.
    """
    import zipfile

    try:
        import gdown
    except ImportError:
        print(
            "The gdown package is required for this function. Use `pip install gdown` to install it."
        )
        return

    url_is_str = isinstance(url, str)

    if output is None and url_is_str and url.startswith("http"):
        output = os.path.basename(url)

    # When no output name can be derived (e.g. download by Drive `id`), the
    # file name is left to gdown and the working directory is used.
    if output is None:
        out_dir = os.getcwd()
    else:
        out_dir = os.path.abspath(os.path.dirname(output))
    os.makedirs(out_dir, exist_ok=True)

    if url_is_str:
        if (
            output is not None
            and os.path.exists(os.path.abspath(output))
            and not overwrite
        ):
            print(
                f"{output} already exists. Skip downloading. Set overwrite=True to overwrite."
            )
            return os.path.abspath(output)

        url = github_raw_url(url)

        # Shared Drive links embed the file id in the path; gdown needs
        # fuzzy matching to extract it.
        if "https://drive.google.com/file/d/" in url:
            fuzzy = True

    # Keyword arguments keep this call independent of gdown's positional order.
    downloaded = gdown.download(
        url=url,
        output=output,
        quiet=quiet,
        proxy=proxy,
        speed=speed,
        use_cookies=use_cookies,
        verify=verify,
        id=id,
        fuzzy=fuzzy,
        resume=resume,
    )
    if downloaded is None:
        raise RuntimeError(f"Failed to download {url if url is not None else id}.")
    output = downloaded

    if unzip and output.endswith(".zip"):
        with zipfile.ZipFile(output, "r") as zip_ref:
            if not quiet:
                print("Extracting files...")
            if subfolder:
                basename = os.path.splitext(os.path.basename(output))[0]

                output = os.path.join(out_dir, basename)
                os.makedirs(output, exist_ok=True)
                zip_ref.extractall(output)
            else:
                # Resolve to an absolute path first so a bare file name
                # extracts next to the archive rather than into "".
                zip_ref.extractall(os.path.dirname(os.path.abspath(output)))

    return os.path.abspath(output)
+
geojson_to_coords(geojson, src_crs='epsg:4326', dst_crs='epsg:4326')
+
+
+¶Converts a geojson file or a dictionary of feature collection to a list of centroid coordinates.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
geojson |
+ str | dict |
+ The geojson file path or a dictionary of feature collection. |
+ required | +
src_crs |
+ str |
+ The source CRS. Defaults to "epsg:4326". |
+ 'epsg:4326' |
+
dst_crs |
+ str |
+ The destination CRS. Defaults to "epsg:4326". |
+ 'epsg:4326' |
+
Returns:
+Type | +Description | +
---|---|
list |
+ A list of centroid coordinates in the format of [[x1, y1], [x2, y2], ...] |
+
samgeo/common.py
def geojson_to_coords(
    geojson: str, src_crs: str = "epsg:4326", dst_crs: str = "epsg:4326"
) -> list:
    """Converts a geojson file or a dictionary of feature collection to a list of centroid coordinates.

    Args:
        geojson (str | dict): The geojson file path or a dictionary of feature collection.
        src_crs (str, optional): The source CRS. Defaults to "epsg:4326".
        dst_crs (str, optional): The destination CRS. Defaults to "epsg:4326".

    Returns:
        list: A list of centroid coordinates in the format of [[x1, y1], [x2, y2], ...]
    """

    import json
    import warnings

    if isinstance(geojson, dict):
        # A feature-collection dict is serialized so it can be read like a file.
        geojson = json.dumps(geojson)

    # Warnings are silenced only for the duration of this call, so the
    # caller's global warning filters are left untouched.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        gdf = gpd.read_file(geojson, driver="GeoJSON")
        centroids = gdf.geometry.centroid
        centroid_list = [[point.x, point.y] for point in centroids]

        if src_crs != dst_crs:
            transformed = transform_coords(
                [xy[0] for xy in centroid_list],
                [xy[1] for xy in centroid_list],
                src_crs,
                dst_crs,
            )
            # transform_coords returns the x values and y values separately;
            # pair them back up.
            centroid_list = [[x, y] for x, y in zip(transformed[0], transformed[1])]

    return centroid_list
+
geojson_to_xy(src_fp, geojson, coord_crs='epsg:4326', **kwargs)
+
+
+¶Converts a geojson file or a dictionary of feature collection to a list of pixel coordinates.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
src_fp |
+ str |
+ The source raster file path. |
+ required | +
geojson |
+ str |
+ The geojson file path or a dictionary of feature collection. |
+ required | +
coord_crs |
+ str |
+ The coordinate CRS of the input coordinates. Defaults to "epsg:4326". |
+ 'epsg:4326' |
+
**kwargs |
+ + | Additional keyword arguments to pass to rasterio.transform.rowcol. |
+ {} |
+
Returns:
+Type | +Description | +
---|---|
list |
+ A list of pixel coordinates in the format of [[x1, y1], [x2, y2], ...] |
+
samgeo/common.py
def geojson_to_xy(
    src_fp: str, geojson: str, coord_crs: str = "epsg:4326", **kwargs
) -> list:
    """Turn the features of a GeoJSON source into pixel coordinates of a raster.

    Args:
        src_fp: The source raster file path.
        geojson: The geojson file path or a dictionary of feature collection.
        coord_crs: The coordinate CRS of the input coordinates. Defaults to "epsg:4326".
        **kwargs: Additional keyword arguments to pass to rasterio.transform.rowcol.

    Returns:
        A list of pixel coordinates in the format of [[x1, y1], [x2, y2], ...]
    """
    # Only the raster's CRS is needed from the dataset.
    with rasterio.open(src_fp) as dataset:
        raster_crs = dataset.crs

    # Centroids are reprojected from coord_crs into the raster CRS first,
    # then converted to pixel positions.
    centroids = geojson_to_coords(geojson, coord_crs, raster_crs)
    return coords_to_xy(src_fp, centroids, raster_crs, **kwargs)
+
get_basemaps(free_only=True)
+
+
+¶Returns a dictionary of xyz basemaps.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
free_only |
+ bool |
+ Whether to return only free xyz tile services that do not require an access token. Defaults to True. |
+ True |
+
Returns:
+Type | +Description | +
---|---|
dict |
+ A dictionary of xyz basemaps. |
+
samgeo/common.py
def get_basemaps(free_only=True):
    """Build a mapping from xyz basemap name to its tile URL.

    Args:
        free_only (bool, optional): Whether to return only free xyz tile services that do not require an access token. Defaults to True.

    Returns:
        dict: A dictionary of xyz basemaps, keyed by provider name with the
            built tile URL as value.
    """
    providers = get_xyz_dict(free_only=free_only)
    return {
        provider.name: provider.build_url() for provider in providers.values()
    }
+
get_vector_crs(filename, **kwargs)
+
+
+¶Gets the CRS of a vector file.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
filename |
+ str |
+ The vector file path. |
+ required | +
Returns:
+Type | +Description | +
---|---|
str |
+ The CRS of the vector file. |
+
samgeo/common.py
def get_vector_crs(filename, **kwargs):
    """Gets the CRS of a vector file.

    Args:
        filename (str): The vector file path.
        **kwargs: Additional keyword arguments to pass to geopandas.read_file.

    Returns:
        str: The CRS as an "EPSG:<code>" string when an EPSG code can be
            determined; otherwise the CRS object itself. None if the file has
            no CRS.
    """
    gdf = gpd.read_file(filename, **kwargs)
    crs = gdf.crs
    if crs is None:
        # No CRS is defined, so there is nothing to look up.
        return None

    epsg = crs.to_epsg()
    if epsg is None:
        return crs
    return f"EPSG:{epsg}"
+
get_xyz_dict(free_only=True)
+
+
+¶Returns a dictionary of xyz services.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
free_only |
+ bool |
+ Whether to return only free xyz tile services that do not require an access token. Defaults to True. |
+ True |
+
Returns:
+Type | +Description | +
---|---|
dict |
+ A dictionary of xyz services. |
+
samgeo/common.py
def get_xyz_dict(free_only=True):
    """Collect the xyzservices tile providers into a name-sorted dictionary.

    Args:
        free_only (bool, optional): Whether to return only free xyz tile services that do not require an access token. Defaults to True.

    Returns:
        dict: A dictionary of xyz services.
    """
    import collections
    import xyzservices.providers as xyz

    def _resolve(root, dotted_name):
        # Walk a dotted provider name attribute by attribute.
        node = root
        for attr in dotted_name.split("."):
            node = getattr(node, attr)
        return node

    def _register(registry, provider_name):
        # Keep the provider unless it needs a token and only free ones are wanted.
        provider = _resolve(xyz, provider_name)
        if free_only and provider.requires_token():
            return
        registry[provider_name] = provider

    services = {}
    for entry in xyz.values():
        try:
            # Entries carrying a "name" are handled directly.
            _register(services, entry["name"])
        except Exception:
            # Any failure above falls back to treating the entry as a
            # collection of named sub-entries.
            for key in entry:
                _register(services, entry[key]["name"])

    return collections.OrderedDict(sorted(services.items()))
+
github_raw_url(url)
+
+
+¶Get the raw URL for a GitHub file.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
url |
+ str |
+ The GitHub URL. |
+ required | +
Returns:
+Type | +Description | +
---|---|
str |
+ The raw URL. |
+
samgeo/common.py
def github_raw_url(url):
    """Get the raw URL for a GitHub file.

    Only URLs of the form https://github.com/<owner>/<repo>/blob/<rest> are
    rewritten; anything else (including non-string input) is returned as is.

    Args:
        url (str): The GitHub URL.
    Returns:
        str: The raw URL.
    """
    prefix = "https://github.com/"
    if not (isinstance(url, str) and url.startswith(prefix)):
        return url

    # Split into owner / repo / "blob" / <ref and file path>. Only the third
    # segment is checked, so a repository or directory whose name merely
    # contains "blob" is never rewritten or stripped.
    parts = url[len(prefix):].split("/", 3)
    if len(parts) == 4 and parts[2] == "blob":
        owner, repo, _, rest = parts
        return f"https://raw.githubusercontent.com/{owner}/{repo}/{rest}"
    return url
+
image_to_cog(source, dst_path=None, profile='deflate', **kwargs)
+
+
+¶Converts an image to a COG file.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
source |
+ str |
+ A dataset path, URL or rasterio.io.DatasetReader object. |
+ required | +
dst_path |
+ str |
+ An output dataset path or PathLike object. Defaults to None. |
+ None |
+
profile |
+ str |
+ COG profile. More at https://cogeotiff.github.io/rio-cogeo/profile. Defaults to "deflate". |
+ 'deflate' |
+
Exceptions:
+Type | +Description | +
---|---|
ImportError |
+ If rio-cogeo is not installed. |
+
FileNotFoundError |
+ If the source file could not be found. |
+
samgeo/common.py
def image_to_cog(source, dst_path=None, profile="deflate", **kwargs):
    """Write a Cloud Optimized GeoTIFF copy of an image.

    Args:
        source (str): A dataset path or URL.
        dst_path (str, optional): An output dataset path. Defaults to None, in
            which case a local source is written next to itself with a
            "_cog.tif" suffix and a URL source is written to a temporary file.
        profile (str, optional): COG profile. More at https://cogeotiff.github.io/rio-cogeo/profile. Defaults to "deflate".
        **kwargs: Additional keyword arguments to pass to rio_cogeo's cog_translate.

    Raises:
        ImportError: If rio-cogeo is not installed.
        FileNotFoundError: If the source file could not be found.
    """
    try:
        from rio_cogeo.cogeo import cog_translate
        from rio_cogeo.profiles import cog_profiles

    except ImportError:
        raise ImportError(
            "The rio-cogeo package is not installed. Please install it with `pip install rio-cogeo` or `conda install rio-cogeo -c conda-forge`."
        )

    is_remote = source.startswith("http")

    if not is_remote:
        # Local inputs are normalized and must exist on disk.
        source = check_file_path(source)
        if not os.path.exists(source):
            raise FileNotFoundError("The provided input file could not be found.")

    if dst_path is None:
        if is_remote:
            dst_path = temp_file_path(extension=".tif")
        else:
            stem, _ = os.path.splitext(source)
            dst_path = stem + "_cog.tif"

    dst_path = check_file_path(dst_path)

    cog_translate(source, dst_path, cog_profiles.get(profile), **kwargs)
+
install_package(package)
+
+
+¶Install a Python package.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
package |
+ str | list |
+ The package name or a GitHub URL or a list of package names or GitHub URLs. |
+ required | +
samgeo/common.py
def install_package(package):
    """Install a Python package.

    Args:
        package (str | list): The package name or a GitHub URL or a list of package names or GitHub URLs.
    """
    import subprocess
    import sys

    # Accept a single name as well as any sequence of names. Previously a
    # list argument left `packages` undefined.
    packages = [package] if isinstance(package, str) else list(package)

    # Run pip through the current interpreter so the package lands in the
    # environment this code is running in; fall back to `pip` on PATH when
    # the interpreter path is unknown.
    pip_command = [sys.executable, "-m", "pip"] if sys.executable else ["pip"]

    for name in packages:
        if name.startswith("https://github.com"):
            name = f"git+{name}"

        # Whitespace-separated extras in the spec become separate arguments.
        command = pip_command + ["install"] + name.split()

        # Stream pip's output line by line while it runs; leaving the `with`
        # block waits for the process to finish.
        with subprocess.Popen(command, stdout=subprocess.PIPE) as process:
            for raw_line in process.stdout:
                print(raw_line.decode("utf-8").strip())
+
is_colab()
+
+
+¶Tests if the code is being executed within Google Colab.
+ +samgeo/common.py
def is_colab():
    """Report whether the google.colab module has been loaded.

    Returns:
        bool: True if "google.colab" is present in sys.modules, else False.
    """
    import sys

    return "google.colab" in sys.modules
+
merge_rasters(input_dir, output, input_pattern='*.tif', output_format='GTiff', output_nodata=None, output_options=['COMPRESS=DEFLATE'])
+
+
+¶Merge a directory of rasters into a single raster.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
input_dir |
+ str |
+ The path to the input directory. |
+ required | +
output |
+ str |
+ The path to the output raster. |
+ required | +
input_pattern |
+ str |
+ The pattern to match the input files. Defaults to "*.tif". |
+ '*.tif' |
+
output_format |
+ str |
+ The output format. Defaults to "GTiff". |
+ 'GTiff' |
+
output_nodata |
+ float |
+ The output nodata value. Defaults to None. |
+ None |
+
output_options |
+ list |
+ A list of output options. Defaults to ["COMPRESS=DEFLATE"]. |
+ ['COMPRESS=DEFLATE'] |
+
Exceptions:
+Type | +Description | +
---|---|
ImportError |
+ Raised if GDAL is not installed. |
+
samgeo/common.py
def merge_rasters(
    input_dir,
    output,
    input_pattern="*.tif",
    output_format="GTiff",
    output_nodata=None,
    output_options=["COMPRESS=DEFLATE"],
):
    """Merge a directory of rasters into a single raster.

    Args:
        input_dir (str): The path to the input directory.
        output (str): The path to the output raster.
        input_pattern (str, optional): The pattern to match the input files. Defaults to "*.tif".
        output_format (str, optional): The output format. Defaults to "GTiff".
        output_nodata (float, optional): The output nodata value. Defaults to None.
        output_options (list, optional): A list of output options. Defaults to ["COMPRESS=DEFLATE"].

    Raises:
        ImportError: Raised if GDAL is not installed.
    """

    import glob

    try:
        from osgeo import gdal
    except ImportError:
        raise ImportError(
            "GDAL is required to use this function. Install it with `conda install gdal -c conda-forge`"
        )
    # Get a list of all the input files
    input_files = glob.glob(os.path.join(input_dir, input_pattern))

    # Hand GDAL a private copy: the default list is shared across calls, and
    # the callee must not be able to modify it (or the caller's own list).
    # NOTE(review): GDAL's Python option builder appears to extend the list
    # it receives in place -- confirm against the installed GDAL version.
    options = None if output_options is None else list(output_options)

    # Merge the input files into a single output file
    gdal.Warp(
        output,
        input_files,
        format=output_format,
        dstNodata=output_nodata,
        options=options,
    )
+
overlay_images(image1, image2, alpha=0.5, backend='TkAgg', height_ratios=[10, 1], show_args1={}, show_args2={})
+
+
+¶Overlays two images using a slider to control the opacity of the top image.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
image1 |
+ str | np.ndarray |
+ The first input image at the bottom represented as a NumPy array or the path to the image. |
+ required | +
image2 |
+ str | np.ndarray |
+ The second input image on top represented as a NumPy array or the path to the image. |
+ required | +
alpha |
+ float |
+ The alpha value of the top image. Defaults to 0.5. |
+ 0.5 |
+
backend |
+ str |
+ The backend of the matplotlib plot. Defaults to "TkAgg". |
+ 'TkAgg' |
+
height_ratios |
+ list |
+ The height ratios of the two subplots. Defaults to [10, 1]. |
+ [10, 1] |
+
show_args1 |
+ dict |
+ The keyword arguments to pass to the imshow() function for the first image. Defaults to {}. |
+ {} |
+
show_args2 |
+ dict |
+ The keyword arguments to pass to the imshow() function for the second image. Defaults to {}. |
+ {} |
+
samgeo/common.py
def overlay_images(
    image1,
    image2,
    alpha=0.5,
    backend="TkAgg",
    height_ratios=[10, 1],
    show_args1={},
    show_args2={},
):
    """Show two stacked images with a slider that sets the top image's opacity.

    Args:
        image1 (str | np.ndarray): The first input image at the bottom represented as a NumPy array or the path to the image.
        image2 (str | np.ndarray): The second input image on top represented as a NumPy array or the path to the image.
        alpha (float, optional): The alpha value of the top image. Defaults to 0.5.
        backend (str, optional): The backend of the matplotlib plot. Defaults to "TkAgg".
        height_ratios (list, optional): The height ratios of the two subplots. Defaults to [10, 1].
        show_args1 (dict, optional): The keyword arguments to pass to the imshow() function for the first image. Defaults to {}.
        show_args2 (dict, optional): The keyword arguments to pass to the imshow() function for the second image. Defaults to {}.

    Raises:
        ValueError: If an image is given as a path that does not exist.
    """
    import sys
    import matplotlib
    import matplotlib.widgets as mpwidgets

    # Nothing is plotted on Colab; the user is only told why.
    if "google.colab" in sys.modules:
        print(
            "The TkAgg backend is not supported in Google Colab. The overlay_images function will not work on Colab."
        )
        return

    matplotlib.use(backend)

    def _as_local_path(image):
        # String inputs are paths or URLs: URLs are downloaded first, and the
        # resulting path must exist. Other inputs pass through untouched.
        if isinstance(image, str):
            if image.startswith("http"):
                image = download_file(image)
            if not os.path.exists(image):
                raise ValueError(f"Input path {image} does not exist.")
        return image

    image1 = _as_local_path(image1)
    image2 = _as_local_path(image2)

    bottom = plt.imread(image1)
    top = plt.imread(image2)

    # Upper axes hold both images, lower axes hold the slider.
    fig, (image_ax, slider_ax) = plt.subplots(
        2, 1, gridspec_kw={"height_ratios": height_ratios}
    )
    image_ax.imshow(bottom, **show_args1)
    top_artist = image_ax.imshow(top, alpha=alpha, **show_args2)

    def _on_alpha_change(value):
        top_artist.set_alpha(value)
        fig.canvas.draw_idle()

    # The slider is bound to a name so it stays referenced while the plot is shown.
    alpha_slider = mpwidgets.Slider(
        ax=slider_ax, label="alpha", valmin=0, valmax=1, valinit=alpha
    )
    alpha_slider.on_changed(_on_alpha_change)

    plt.show()
+
random_string(string_length=6)
+
+
+¶Generates a random string of fixed length.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
string_length |
+ int |
+ Fixed length. Defaults to 6. |
+ 6 |
+
Returns:
+Type | +Description | +
---|---|
str |
+ A random string |
+
samgeo/common.py
def random_string(string_length=6):
    """Generates a random string of fixed length.

    The string is made of lowercase ASCII letters drawn with the `random`
    module, so it is not suitable for security-sensitive tokens.

    Args:
        string_length (int, optional): Fixed length. Defaults to 6.

    Returns:
        str: A random string
    """
    import random
    import string

    letters = string.ascii_lowercase
    return "".join(random.choice(letters) for _ in range(string_length))
+
raster_to_geojson(tiff_path, output, simplify_tolerance=None, **kwargs)
+
+
+¶Convert a tiff file to a GeoJSON file.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
tiff_path |
+ str |
+ The path to the tiff file. |
+ required | +
output |
+ str |
+ The path to the GeoJSON file. |
+ required | +
simplify_tolerance |
+ float |
+ The maximum allowed geometry displacement. +The higher this value, the smaller the number of vertices in the resulting geometry. |
+ None |
+
samgeo/common.py
def raster_to_geojson(tiff_path, output, simplify_tolerance=None, **kwargs):
    """Vectorize a tiff file and save the result as GeoJSON.

    Args:
        tiff_path (str): The path to the tiff file.
        output (str): The path to the GeoJSON file. ".geojson" is appended
            when the path does not already end with it.
        simplify_tolerance (float, optional): The maximum allowed geometry displacement.
            The higher this value, the smaller the number of vertices in the resulting geometry.
        **kwargs: Additional keyword arguments forwarded to raster_to_vector.
    """
    extension = ".geojson"
    target = output if output.endswith(extension) else output + extension

    raster_to_vector(tiff_path, target, simplify_tolerance=simplify_tolerance, **kwargs)
+
raster_to_gpkg(tiff_path, output, simplify_tolerance=None, **kwargs)
+
+
+¶Convert a tiff file to a gpkg file.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
tiff_path |
+ str |
+ The path to the tiff file. |
+ required | +
output |
+ str |
+ The path to the gpkg file. |
+ required | +
simplify_tolerance |
+ float |
+ The maximum allowed geometry displacement. +The higher this value, the smaller the number of vertices in the resulting geometry. |
+ None |
+
samgeo/common.py
def raster_to_gpkg(tiff_path, output, simplify_tolerance=None, **kwargs):
    """Vectorize a tiff file and save the result as a gpkg file.

    Args:
        tiff_path (str): The path to the tiff file.
        output (str): The path to the gpkg file. ".gpkg" is appended when the
            path does not already end with it.
        simplify_tolerance (float, optional): The maximum allowed geometry displacement.
            The higher this value, the smaller the number of vertices in the resulting geometry.
        **kwargs: Additional keyword arguments forwarded to raster_to_vector.
    """
    extension = ".gpkg"
    target = output if output.endswith(extension) else output + extension

    raster_to_vector(tiff_path, target, simplify_tolerance=simplify_tolerance, **kwargs)
+
raster_to_shp(tiff_path, output, simplify_tolerance=None, **kwargs)
+
+
+¶Convert a tiff file to a shapefile.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
tiff_path |
+ str |
+ The path to the tiff file. |
+ required | +
output |
+ str |
+ The path to the shapefile. |
+ required | +
simplify_tolerance |
+ float |
+ The maximum allowed geometry displacement. +The higher this value, the smaller the number of vertices in the resulting geometry. |
+ None |
+
samgeo/common.py
def raster_to_shp(tiff_path, output, simplify_tolerance=None, **kwargs):
    """Vectorize a tiff file and save the result as a shapefile.

    Args:
        tiff_path (str): The path to the tiff file.
        output (str): The path to the shapefile. ".shp" is appended when the
            path does not already end with it.
        simplify_tolerance (float, optional): The maximum allowed geometry displacement.
            The higher this value, the smaller the number of vertices in the resulting geometry.
        **kwargs: Additional keyword arguments forwarded to raster_to_vector.
    """
    extension = ".shp"
    target = output if output.endswith(extension) else output + extension

    raster_to_vector(tiff_path, target, simplify_tolerance=simplify_tolerance, **kwargs)
+
raster_to_vector(source, output, simplify_tolerance=None, **kwargs)
+
+
+¶Vectorize a raster dataset.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
source |
+ str |
+ The path to the tiff file. |
+ required | +
output |
+ str |
+ The path to the vector file. |
+ required | +
simplify_tolerance |
+ float |
+ The maximum allowed geometry displacement. +The higher this value, the smaller the number of vertices in the resulting geometry. |
+ None |
+
samgeo/common.py
def raster_to_vector(source, output, simplify_tolerance=None, **kwargs):
    """Polygonize the non-zero pixels of a raster and write them to a vector file.

    Each output feature carries a "value" property holding the pixel value of
    the region it was traced from.

    Args:
        source (str): The path to the tiff file.
        output (str): The path to the vector file.
        simplify_tolerance (float, optional): The maximum allowed geometry displacement.
            The higher this value, the smaller the number of vertices in the resulting geometry.
        **kwargs: Additional keyword arguments to pass to GeoDataFrame.to_file.
    """
    from rasterio import features

    def _to_geometry(geojson_shape):
        # Build the geometry and simplify it when a tolerance was requested.
        geometry = shapely.geometry.shape(geojson_shape)
        if simplify_tolerance is not None:
            geometry = geometry.simplify(tolerance=simplify_tolerance)
        return geometry

    with rasterio.open(source) as src:
        pixels = src.read()
        raster_crs = src.crs

        # Zero-valued pixels are masked out, so only non-zero regions are traced.
        traced = features.shapes(pixels, mask=pixels != 0, transform=src.transform)
        records = [
            {"geometry": _to_geometry(shape), "properties": {"value": value}}
            for shape, value in traced
        ]

    gdf = gpd.GeoDataFrame.from_features(records)
    if raster_crs is not None:
        gdf.set_crs(crs=raster_crs, inplace=True)
    gdf.to_file(output, **kwargs)
+
regularize(source, output=None, crs='EPSG:4326', **kwargs)
+
+
+¶Regularize a polygon GeoDataFrame.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
source |
+ str | gpd.GeoDataFrame |
+ The input file path or a GeoDataFrame. |
+ required | +
output |
+ str |
+ The output file path. Defaults to None. |
+ None |
+
Returns:
+Type | +Description | +
---|---|
gpd.GeoDataFrame |
+ The output GeoDataFrame. |
+
samgeo/common.py
def regularize(source, output=None, crs="EPSG:4326", **kwargs):
    """Replace each geometry with its minimum rotated rectangle.

    Args:
        source (str | gpd.GeoDataFrame): The input file path or a GeoDataFrame.
        output (str, optional): The output file path. Defaults to None.
        crs (str, optional): The CRS the result is reprojected to; pass None
            to skip reprojection. Defaults to "EPSG:4326".
        **kwargs: Additional keyword arguments to pass to GeoDataFrame.to_file.

    Returns:
        gpd.GeoDataFrame: The output GeoDataFrame. It is returned only when
            `output` is None; otherwise the result is written to `output`.

    Raises:
        ValueError: If `source` is neither a file path nor a GeoDataFrame.
    """
    if not isinstance(source, (str, gpd.GeoDataFrame)):
        raise ValueError("The input source must be a GeoDataFrame or a file path.")

    gdf = gpd.read_file(source) if isinstance(source, str) else source

    rectangles = gdf.geometry.apply(lambda geom: geom.minimum_rotated_rectangle)
    attributes = gdf.drop("geometry", axis=1)
    result = gpd.GeoDataFrame(geometry=rectangles, data=attributes)

    if crs is not None:
        result.to_crs(crs, inplace=True)

    if output is None:
        return result
    result.to_file(output, **kwargs)
+
reproject(image, output, dst_crs='EPSG:4326', resampling='nearest', to_cog=True, **kwargs)
+
+
+¶Reprojects an image.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
image |
+ str |
+ The input image filepath. |
+ required | +
output |
+ str |
+ The output image filepath. |
+ required | +
dst_crs |
+ str |
+ The destination CRS. Defaults to "EPSG:4326". |
+ 'EPSG:4326' |
+
resampling |
+ Resampling |
+ The resampling method. Defaults to "nearest". |
+ 'nearest' |
+
to_cog |
+ bool |
+ Whether to convert the output image to a Cloud Optimized GeoTIFF. Defaults to True. |
+ True |
+
**kwargs |
+ + | Additional keyword arguments to pass to rasterio.open. |
+ {} |
+
samgeo/common.py
def reproject(
    image, output, dst_crs="EPSG:4326", resampling="nearest", to_cog=True, **kwargs
):
    """Reprojects an image.

    Args:
        image (str): The input image filepath.
        output (str): The output image filepath.
        dst_crs (str, optional): The destination CRS. Defaults to "EPSG:4326".
        resampling (Resampling, optional): The resampling method, either a
            rasterio Resampling member or its name. Defaults to "nearest".
        to_cog (bool, optional): Whether to convert the output image to a Cloud Optimized GeoTIFF. Defaults to True.
        **kwargs: Additional keyword arguments to pass to rasterio.open.

    """
    import rasterio as rio
    from rasterio.warp import Resampling, calculate_default_transform

    # Imported under an alias so it is not confused with this function.
    from rasterio.warp import reproject as warp_reproject

    if isinstance(resampling, str):
        resampling = getattr(Resampling, resampling)

    image = os.path.abspath(image)
    output = os.path.abspath(output)

    os.makedirs(os.path.dirname(output), exist_ok=True)

    with rio.open(image, **kwargs) as src:
        transform, width, height = calculate_default_transform(
            src.crs, dst_crs, src.width, src.height, *src.bounds
        )

        # The output profile gets its own name: it must neither overwrite the
        # caller's rasterio.open kwargs nor be passed to the warp call below.
        profile = src.meta.copy()
        profile.update(
            {
                "crs": dst_crs,
                "transform": transform,
                "width": width,
                "height": height,
            }
        )

        with rio.open(output, "w", **profile) as dst:
            # Warp band by band into the destination dataset.
            for band_index in range(1, src.count + 1):
                warp_reproject(
                    source=rio.band(src, band_index),
                    destination=rio.band(dst, band_index),
                    src_transform=src.transform,
                    src_crs=src.crs,
                    dst_transform=transform,
                    dst_crs=dst_crs,
                    resampling=resampling,
                )

    if to_cog:
        # The reprojected file is converted to a COG in place.
        image_to_cog(output, output)
+
rowcol_to_xy(src_fp, rows=None, cols=None, boxes=None, zs=None, offset='center', output=None, dst_crs='EPSG:4326', **kwargs)
+
+
+¶Converts a list of (row, col) coordinates to (x, y) coordinates.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
src_fp |
+ str |
+ The source raster file path. |
+ required | +
rows |
+ list |
+ A list of row coordinates. Defaults to None. |
+ None |
+
cols |
+ list |
+ A list of col coordinates. Defaults to None. |
+ None |
+
boxes |
+ list |
+ A list of (row, col) coordinates in the format of [[left, top, right, bottom], [left, top, right, bottom], ...] |
+ None |
+
zs |
+ + | Height (list or float, optional) associated with coordinates. Primarily used for RPC based coordinate transformations. |
+ None |
+
offset |
+ str |
+ Determines if the returned coordinates are for the center of the pixel or for a corner. |
+ 'center' |
+
output |
+ str |
+ The output vector file path. Defaults to None. |
+ None |
+
dst_crs |
+ str |
+ The destination CRS. Defaults to "EPSG:4326". |
+ 'EPSG:4326' |
+
**kwargs |
+ + | Additional keyword arguments to pass to rasterio.transform.xy. |
+ {} |
+
Returns:
+Type | +Description | +
---|---|
+ | A list of (x, y) coordinates. |
+
samgeo/common.py
def rowcol_to_xy(
    src_fp,
    rows=None,
    cols=None,
    boxes=None,
    zs=None,
    offset="center",
    output=None,
    dst_crs="EPSG:4326",
    **kwargs,
):
    """Translate pixel (row, col) positions of a raster into (x, y) coordinates.

    Args:
        src_fp (str): Path of the raster whose transform is used.
        rows (list, optional): Row indices. Replaced by the box edges when
            ``boxes`` is given. Defaults to None.
        cols (list, optional): Column indices. Replaced by the box edges when
            ``boxes`` is given. Defaults to None.
        boxes (list, optional): Pixel boxes in the format
            [[left, top, right, bottom], [left, top, right, bottom], ...].
            Defaults to None.
        zs (list or float, optional): Height associated with coordinates.
            Primarily used for RPC based coordinate transformations.
        offset (str, optional): Whether the returned coordinates refer to the
            center of the pixel or to a corner. Defaults to "center".
        output (str, optional): Output vector file path. Only consulted when
            ``boxes`` is given. Defaults to None.
        dst_crs (str, optional): Destination CRS handed to ``boxes_to_vector``.
            Defaults to "EPSG:4326".
        **kwargs: Additional keyword arguments passed to rasterio.transform.xy.

    Returns:
        list | None: Without ``boxes``, a list of [x, y] pairs. With ``boxes``
        and no ``output``, a list of [x_left, y_bottom, x_right, y_top] boxes.
        With ``boxes`` and ``output``, the boxes are written through
        ``boxes_to_vector`` and nothing is returned.

    Raises:
        ValueError: If neither ``boxes`` nor both ``rows`` and ``cols`` are given.
    """
    use_boxes = boxes is not None

    if use_boxes:
        # Each box contributes two corners: (top, left) then (bottom, right).
        rows = [edge for box in boxes for edge in (box[1], box[3])]
        cols = [edge for box in boxes for edge in (box[0], box[2])]

    if rows is None or cols is None:
        raise ValueError("rows and cols must be provided.")

    with rasterio.open(src_fp) as src:
        xs, ys = rasterio.transform.xy(src.transform, rows, cols, zs, offset, **kwargs)
        src_crs = src.crs

    if not use_boxes:
        return [[x, y] for x, y in zip(xs, ys)]

    # Consecutive coordinate pairs belong to the same box.
    result = []
    for k in range(0, len(xs), 2):
        result.append([xs[k], ys[k + 1], xs[k + 1], ys[k]])

    if output is None:
        return result

    boxes_to_vector(result, src_crs, dst_crs, output)
+
sam_map_gui(sam, basemap='SATELLITE', repeat_mode=True, out_dir=None, **kwargs)
+
+
+¶Display the SAM Map GUI.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
sam |
+ SamGeo |
+ + | required | +
basemap |
+ str |
+ The basemap to use. Defaults to "SATELLITE". |
+ 'SATELLITE' |
+
repeat_mode |
+ bool |
+ Whether to use the repeat mode for the draw control. Defaults to True. |
+ True |
+
out_dir |
+ str |
+ The output directory. Defaults to None. |
+ None |
+
samgeo/common.py
def sam_map_gui(sam, basemap="SATELLITE", repeat_mode=True, out_dir=None, **kwargs):
    """Display the SAM Map GUI.

    Builds a leafmap map with a toolbar for collecting foreground/background
    point prompts (by clicking or by loading a vector file), running
    ``sam.predict`` on them, and saving or resetting the result.

    Args:
        sam (SamGeo): The segmentation object. Its ``source`` is shown as the
            "Image" layer and its ``predict`` method is called with the
            collected points.
        basemap (str, optional): The basemap to use. Defaults to "SATELLITE".
        repeat_mode (bool, optional): Whether to use the repeat mode for the
            draw control. Defaults to True.
        out_dir (str, optional): The output directory for the mask files.
            Defaults to None, which uses the system temporary directory.
        **kwargs: Additional keyword arguments passed to leafmap.Map.

    Returns:
        leafmap.Map: The map with the toolbar control attached.

    Raises:
        ImportError: If leafmap or one of the widget packages is missing.
    """
    try:
        import shutil
        import tempfile
        import leafmap
        import ipyleaflet
        import ipyevents
        import ipywidgets as widgets
        from ipyfilechooser import FileChooser
    except ImportError as e:
        raise ImportError(
            "The sam_map function requires the leafmap package. Please install it first."
        ) from e

    if out_dir is None:
        out_dir = tempfile.gettempdir()

    m = leafmap.Map(repeat_mode=repeat_mode, **kwargs)
    m.default_style = {"cursor": "crosshair"}
    m.add_basemap(basemap, show=False)

    # Skip the image layer if localtileserver is not available
    try:
        m.add_raster(sam.source, layer_name="Image")
    except Exception:
        pass

    m.fg_markers = []
    m.bg_markers = []

    fg_layer = ipyleaflet.LayerGroup(layers=m.fg_markers, name="Foreground")
    bg_layer = ipyleaflet.LayerGroup(layers=m.bg_markers, name="Background")
    m.add(fg_layer)
    m.add(bg_layer)
    m.fg_layer = fg_layer
    m.bg_layer = bg_layer

    widget_width = "280px"
    button_width = "90px"
    padding = "0px 0px 0px 4px"  # upper, right, bottom, left
    style = {"description_width": "initial"}

    toolbar_button = widgets.ToggleButton(
        value=True,
        tooltip="Toolbar",
        icon="gear",
        layout=widgets.Layout(width="28px", height="28px", padding=padding),
    )

    close_button = widgets.ToggleButton(
        value=False,
        tooltip="Close the tool",
        icon="times",
        button_style="primary",
        layout=widgets.Layout(height="28px", width="28px", padding=padding),
    )

    plus_button = widgets.ToggleButton(
        value=False,
        tooltip="Load foreground points",
        icon="plus-circle",
        button_style="primary",
        layout=widgets.Layout(height="28px", width="28px", padding=padding),
    )

    minus_button = widgets.ToggleButton(
        value=False,
        tooltip="Load background points",
        icon="minus-circle",
        button_style="primary",
        layout=widgets.Layout(height="28px", width="28px", padding=padding),
    )

    radio_buttons = widgets.RadioButtons(
        options=["Foreground", "Background"],
        description="Class Type:",
        disabled=False,
        style=style,
        layout=widgets.Layout(width=widget_width, padding=padding),
    )

    fg_count = widgets.IntText(
        value=0,
        description="Foreground #:",
        disabled=True,
        style=style,
        layout=widgets.Layout(width="135px", padding=padding),
    )
    bg_count = widgets.IntText(
        value=0,
        description="Background #:",
        disabled=True,
        style=style,
        layout=widgets.Layout(width="135px", padding=padding),
    )

    segment_button = widgets.ToggleButton(
        description="Segment",
        value=False,
        button_style="primary",
        layout=widgets.Layout(padding=padding),
    )

    save_button = widgets.ToggleButton(
        description="Save", value=False, button_style="primary"
    )

    reset_button = widgets.ToggleButton(
        description="Reset", value=False, button_style="primary"
    )
    segment_button.layout.width = button_width
    save_button.layout.width = button_width
    reset_button.layout.width = button_width

    opacity_slider = widgets.FloatSlider(
        description="Mask opacity:",
        min=0,
        max=1,
        value=0.5,
        readout=True,
        continuous_update=True,
        layout=widgets.Layout(width=widget_width, padding=padding),
        style=style,
    )

    rectangular = widgets.Checkbox(
        value=False,
        description="Regularize",
        layout=widgets.Layout(width="130px", padding=padding),
        style=style,
    )

    colorpicker = widgets.ColorPicker(
        concise=False,
        description="Color",
        value="#ffff00",
        layout=widgets.Layout(width="140px", padding=padding),
        style=style,
    )

    buttons = widgets.VBox(
        [
            radio_buttons,
            widgets.HBox([fg_count, bg_count]),
            opacity_slider,
            widgets.HBox([rectangular, colorpicker]),
            widgets.HBox(
                [segment_button, save_button, reset_button],
                layout=widgets.Layout(padding="0px 4px 0px 4px"),
            ),
        ]
    )

    def opacity_changed(change):
        # Do not test the truthiness of change["new"]: an opacity of 0 is a
        # valid value and must be applied as well.
        mask_layer = m.find_layer("Masks")
        if mask_layer is not None:
            mask_layer.interact(opacity=opacity_slider.value)

    opacity_slider.observe(opacity_changed, "value")

    output = widgets.Output(
        layout=widgets.Layout(
            width=widget_width, padding=padding, max_width=widget_width
        )
    )

    toolbar_header = widgets.HBox()
    toolbar_header.children = [close_button, plus_button, minus_button, toolbar_button]
    toolbar_footer = widgets.VBox()
    toolbar_footer.children = [
        buttons,
        output,
    ]
    toolbar_widget = widgets.VBox()
    toolbar_widget.children = [toolbar_header, toolbar_footer]

    toolbar_event = ipyevents.Event(
        source=toolbar_widget, watched_events=["mouseenter", "mouseleave"]
    )

    def add_marker(location, foreground):
        """Create one prompt marker at (lat, lon), register it and update its counter."""
        color = "green" if foreground else "red"
        if is_colab():  # Colab does not support AwesomeIcon
            marker = ipyleaflet.CircleMarker(
                location=tuple(location),
                radius=2,
                color=color,
                fill_color=color,
            )
        else:
            marker = ipyleaflet.Marker(
                location=location,
                icon=ipyleaflet.AwesomeIcon(
                    name="plus-circle" if foreground else "minus-circle",
                    marker_color=color,
                    icon_color="darkred",
                ),
            )

        if foreground:
            m.fg_layer.add(marker)
            m.fg_markers.append(marker)
            fg_count.value = len(m.fg_markers)
        else:
            m.bg_layer.add(marker)
            m.bg_markers.append(marker)
            bg_count.value = len(m.bg_markers)

    def marker_callback(chooser):
        with output:
            if chooser.selected is not None:
                try:
                    gdf = gpd.read_file(chooser.selected)
                    for point in gdf.centroid:
                        # Markers take (lat, lon), i.e. (y, x).
                        location = [point.y, point.x]
                        if plus_button.value:
                            add_marker(location, True)
                        elif minus_button.value:
                            add_marker(location, False)

                except Exception as e:
                    print(e)

                # The control may already have been removed by the button handler.
                if hasattr(m, "marker_control") and m.marker_control in m.controls:
                    m.remove_control(m.marker_control)
                    delattr(m, "marker_control")

                plus_button.value = False
                minus_button.value = False

    def marker_button_click(change):
        if change["new"]:
            sandbox_path = os.environ.get("SANDBOX_PATH")
            filechooser = FileChooser(
                path=os.getcwd(),
                sandbox_path=sandbox_path,
                layout=widgets.Layout(width="454px"),
            )
            filechooser.use_dir_icons = True
            filechooser.filter_pattern = ["*.shp", "*.geojson", "*.gpkg"]
            filechooser.register_callback(marker_callback)
            marker_control = ipyleaflet.WidgetControl(
                widget=filechooser, position="topright"
            )
            m.add_control(marker_control)
            m.marker_control = marker_control
        else:
            if hasattr(m, "marker_control") and m.marker_control in m.controls:
                m.remove_control(m.marker_control)
                m.marker_control.close()

    plus_button.observe(marker_button_click, "value")
    minus_button.observe(marker_button_click, "value")

    def handle_toolbar_event(event):
        if event["type"] == "mouseenter":
            toolbar_widget.children = [toolbar_header, toolbar_footer]
        elif event["type"] == "mouseleave":
            if not toolbar_button.value:
                toolbar_widget.children = [toolbar_button]
                toolbar_button.value = False
                close_button.value = False

    toolbar_event.on_dom_event(handle_toolbar_event)

    def toolbar_btn_click(change):
        if change["new"]:
            close_button.value = False
            toolbar_widget.children = [toolbar_header, toolbar_footer]
        else:
            if not close_button.value:
                toolbar_widget.children = [toolbar_button]

    toolbar_button.observe(toolbar_btn_click, "value")

    def close_btn_click(change):
        if change["new"]:
            toolbar_button.value = False
            if m.toolbar_control in m.controls:
                m.remove_control(m.toolbar_control)
            toolbar_widget.close()

    close_button.observe(close_btn_click, "value")

    def handle_map_interaction(**kwargs):
        try:
            if kwargs.get("type") == "click":
                latlon = kwargs.get("coordinates")
                if radio_buttons.value == "Foreground":
                    add_marker(latlon, True)
                elif radio_buttons.value == "Background":
                    add_marker(latlon, False)

        except (TypeError, KeyError) as e:
            print(f"Error handling map interaction: {e}")

    m.on_interaction(handle_map_interaction)

    def segment_button_click(change):
        if change["new"]:
            segment_button.value = False
            with output:
                output.clear_output()
                if len(m.fg_markers) == 0:
                    print("Please add some foreground markers.")
                    segment_button.value = False
                    return

                try:
                    # Markers store (lat, lon); the predictor expects (x, y).
                    fg_points = [
                        [marker.location[1], marker.location[0]]
                        for marker in m.fg_markers
                    ]
                    bg_points = [
                        [marker.location[1], marker.location[0]]
                        for marker in m.bg_markers
                    ]
                    point_coords = fg_points + bg_points
                    point_labels = [1] * len(fg_points) + [0] * len(bg_points)

                    filename = f"masks_{random_string()}.tif"
                    filename = os.path.join(out_dir, filename)
                    sam.predict(
                        point_coords=point_coords,
                        point_labels=point_labels,
                        point_crs="EPSG:4326",
                        output=filename,
                    )
                    if m.find_layer("Masks") is not None:
                        m.remove_layer(m.find_layer("Masks"))
                    if m.find_layer("Regularized") is not None:
                        m.remove_layer(m.find_layer("Regularized"))

                    # Drop the mask file of the previous run (best effort).
                    if hasattr(sam, "prediction_fp") and os.path.exists(
                        sam.prediction_fp
                    ):
                        try:
                            os.remove(sam.prediction_fp)
                        except OSError:
                            pass

                    # Skip the image layer if localtileserver is not available
                    try:
                        m.add_raster(
                            filename,
                            nodata=0,
                            cmap="Blues",
                            opacity=opacity_slider.value,
                            layer_name="Masks",
                            zoom_to_layer=False,
                        )

                        if rectangular.value:
                            vector = filename.replace(".tif", ".gpkg")
                            vector_rec = filename.replace(".tif", "_rect.gpkg")
                            raster_to_vector(filename, vector)
                            regularize(vector, vector_rec)
                            vector_style = {"color": colorpicker.value}
                            m.add_vector(
                                vector_rec,
                                layer_name="Regularized",
                                style=vector_style,
                                info_mode=None,
                                zoom_to_layer=False,
                            )

                    except Exception:
                        pass
                    output.clear_output()
                    segment_button.value = False
                    sam.prediction_fp = filename
                except Exception as e:
                    segment_button.value = False
                    print(e)

    segment_button.observe(segment_button_click, "value")

    def filechooser_callback(chooser):
        with output:
            if chooser.selected is not None:
                try:
                    filename = chooser.selected
                    shutil.copy(sam.prediction_fp, filename)
                    vector = filename.replace(".tif", ".gpkg")
                    raster_to_vector(filename, vector)
                    if rectangular.value:
                        vector_rec = filename.replace(".tif", "_rect.gpkg")
                        regularize(vector, vector_rec)

                    fg_points = [
                        [marker.location[1], marker.location[0]]
                        for marker in m.fg_markers
                    ]
                    bg_points = [
                        [marker.location[1], marker.location[0]]
                        for marker in m.bg_markers
                    ]

                    coords_to_geojson(
                        fg_points, filename.replace(".tif", "_fg_markers.geojson")
                    )
                    coords_to_geojson(
                        bg_points, filename.replace(".tif", "_bg_markers.geojson")
                    )

                except Exception as e:
                    print(e)

                if hasattr(m, "save_control") and m.save_control in m.controls:
                    m.remove_control(m.save_control)
                    delattr(m, "save_control")
                save_button.value = False

    def save_button_click(change):
        if change["new"]:
            with output:
                sandbox_path = os.environ.get("SANDBOX_PATH")
                filechooser = FileChooser(
                    path=os.getcwd(),
                    filename="masks.tif",
                    sandbox_path=sandbox_path,
                    layout=widgets.Layout(width="454px"),
                )
                filechooser.use_dir_icons = True
                filechooser.filter_pattern = ["*.tif"]
                filechooser.register_callback(filechooser_callback)
                save_control = ipyleaflet.WidgetControl(
                    widget=filechooser, position="topright"
                )
                m.add_control(save_control)
                m.save_control = save_control
        else:
            if hasattr(m, "save_control") and m.save_control in m.controls:
                m.remove_control(m.save_control)
                delattr(m, "save_control")

    save_button.observe(save_button_click, "value")

    def reset_button_click(change):
        if change["new"]:
            segment_button.value = False
            reset_button.value = False
            opacity_slider.value = 0.5
            rectangular.value = False
            colorpicker.value = "#ffff00"
            output.clear_output()

            # Layer cleanup is best effort and must not prevent the markers
            # from being cleared (e.g. when Reset is pressed before Segment).
            try:
                for layer_name in ("Masks", "Regularized"):
                    layer = m.find_layer(layer_name)
                    if layer is not None:
                        m.remove_layer(layer)
                m.clear_drawings()
            except Exception:
                pass

            try:
                m.user_rois = None
                m.fg_markers = []
                m.bg_markers = []
                m.fg_layer.clear_layers()
                m.bg_layer.clear_layers()
                fg_count.value = 0
                bg_count.value = 0
            except Exception:
                pass

            prediction_fp = getattr(sam, "prediction_fp", None)
            if prediction_fp is not None:
                try:
                    os.remove(prediction_fp)
                except (OSError, TypeError):
                    pass

    reset_button.observe(reset_button_click, "value")

    toolbar_control = ipyleaflet.WidgetControl(
        widget=toolbar_widget, position="topright"
    )
    m.add_control(toolbar_control)
    m.toolbar_control = toolbar_control

    return m
+
show_canvas(image, fg_color=(0, 255, 0), bg_color=(0, 0, 255), radius=5)
+
+
+¶Show a canvas to collect foreground and background points.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
image |
+ str | np.ndarray |
+ The input image. |
+ required | +
fg_color |
+ tuple |
+ The color for the foreground points. Defaults to (0, 255, 0). |
+ (0, 255, 0) |
+
bg_color |
+ tuple |
+ The color for the background points. Defaults to (0, 0, 255). |
+ (0, 0, 255) |
+
radius |
+ int |
+ The radius of the points. Defaults to 5. |
+ 5 |
+
Returns:
+Type | +Description | +
---|---|
tuple |
+ A tuple of two lists of foreground and background points. |
+
samgeo/common.py
def show_canvas(image, fg_color=(0, 255, 0), bg_color=(0, 0, 255), radius=5):
    """Show a canvas to collect foreground and background points.

    Left clicks are recorded as foreground points and right clicks as
    background points. Each click is drawn as a filled circle on the image;
    when a NumPy array is passed in, the circles are drawn on that array.
    The window stays open until a key is pressed.

    Args:
        image (str | np.ndarray): The input image, given as a file path, an
            http(s) URL, or an array.
        fg_color (tuple, optional): The color for the foreground points. Defaults to (0, 255, 0).
        bg_color (tuple, optional): The color for the background points. Defaults to (0, 0, 255).
        radius (int, optional): The radius of the points. Defaults to 5.

    Returns:
        tuple: A tuple of two lists of (x, y) foreground and background points.

    Raises:
        ValueError: If the image is neither a string nor a NumPy array, or if
            the image file cannot be read.
    """
    if isinstance(image, str):
        if image.startswith("http"):
            image = download_file(image)

        image_path = image
        image = cv2.imread(image_path)
        # cv2.imread signals failure by returning None instead of raising.
        if image is None:
            raise ValueError(f"Unable to read the image file: {image_path}")
    elif not isinstance(image, np.ndarray):
        raise ValueError("Input image must be a URL or a NumPy array.")

    # Lists that collect the mouse click coordinates
    left_clicks = []
    right_clicks = []

    # OpenCV invokes the callback as (event, x, y, flags, param); the last two
    # are unused here but must be accepted, otherwise every event fails.
    def get_mouse_coordinates(event, x, y, flags=None, param=None):
        if event == cv2.EVENT_LBUTTONDOWN:
            clicks, color = left_clicks, fg_color
        elif event == cv2.EVENT_RBUTTONDOWN:
            clicks, color = right_clicks, bg_color
        else:
            return

        clicks.append((x, y))

        # Draw a filled circle at the click position and refresh the window
        cv2.circle(image, (x, y), radius, color, -1)
        cv2.imshow("Image", image)

    # Create the window and register the mouse callback
    cv2.namedWindow("Image")
    try:
        cv2.setMouseCallback("Image", get_mouse_coordinates)

        # Display the image in the window
        cv2.imshow("Image", image)

        # Wait for a key press to exit
        cv2.waitKey(0)
    finally:
        # Always release the window, even if displaying fails
        cv2.destroyAllWindows()

    return left_clicks, right_clicks
+
split_raster(filename, out_dir, tile_size=256, overlap=0)
+
+
+¶Split a raster into tiles.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
filename |
+ str |
+ The path or http URL to the raster file. |
+ required | +
out_dir |
+ str |
+ The path to the output directory. |
+ required | +
tile_size |
+ int | tuple |
+ The size of the tiles. Can be an integer or a tuple of (width, height). Defaults to 256. |
+ 256 |
+
overlap |
+ int |
+ The number of pixels to overlap between tiles. Defaults to 0. |
+ 0 |
+
Exceptions:
+Type | +Description | +
---|---|
ImportError |
+ Raised if GDAL is not installed. |
+
samgeo/common.py
def split_raster(filename, out_dir, tile_size=256, overlap=0):
    """Split a raster into tiles.

    Tiles are written as GeoTIFF files named ``tile_{column}_{row}.tif`` in
    ``out_dir``. The last tile of each row and column is shifted back so that
    it ends on the raster edge. If ``filename`` is an http URL, the file is
    first downloaded into the current working directory.

    Args:
        filename (str): The path or http URL to the raster file.
        out_dir (str): The path to the output directory.
        tile_size (int | tuple, optional): The size of the tiles. Can be an integer or a
            (width, height) tuple or list. Defaults to 256.
        overlap (int, optional): The number of pixels to overlap between tiles.
            Must be smaller than the tile width and height. Defaults to 0.

    Raises:
        ValueError: If tile_size or overlap is invalid, or the raster cannot be opened.
        ImportError: Raised if GDAL is not installed.
    """

    # Validate the arguments before touching GDAL or the file system.
    if isinstance(tile_size, int):
        tile_width = tile_height = tile_size
    elif isinstance(tile_size, (tuple, list)) and len(tile_size) == 2:
        tile_width, tile_height = tile_size
    else:
        raise ValueError(
            "tile_size must be an integer or a tuple of (width, height)."
        )

    if tile_width <= 0 or tile_height <= 0:
        raise ValueError("tile_size must be positive.")

    # The stride is tile size minus overlap, so it must stay positive.
    if overlap < 0 or overlap >= min(tile_width, tile_height):
        raise ValueError(
            "overlap must be non-negative and smaller than the tile width and height."
        )

    try:
        from osgeo import gdal
    except ImportError:
        raise ImportError(
            "GDAL is required to use this function. Install it with `conda install gdal -c conda-forge`"
        )

    if isinstance(filename, str) and filename.startswith("http"):
        output = filename.split("/")[-1]
        download_file(filename, output)
        filename = output

    # Open the input GeoTIFF file
    ds = gdal.Open(filename)
    if ds is None:
        raise ValueError(f"Unable to open the raster file: {filename}")

    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    # Get the size of the input raster
    width = ds.RasterXSize
    height = ds.RasterYSize

    step_x = tile_width - overlap
    step_y = tile_height - overlap

    # Number of tiles in each direction (ceiling division), at least one
    num_tiles_x = max(1, -(-(width - overlap) // step_x))
    num_tiles_y = max(1, -(-(height - overlap) // step_y))

    # Get the georeferencing information of the input raster
    geotransform = ds.GetGeoTransform()
    projection = ds.GetProjection()
    band_count = ds.RasterCount
    data_type = ds.GetRasterBand(1).DataType
    driver = gdal.GetDriverByName("GTiff")

    # Loop over all the tiles
    for i in range(num_tiles_x):
        for j in range(num_tiles_y):
            # Pixel window of the tile, clamped to the edge of the raster
            x_min = i * step_x
            y_min = j * step_y
            x_max = min(x_min + tile_width, width)
            y_max = min(y_min + tile_height, height)

            # Shift the last tile of each row/column back so it keeps the
            # full tile size whenever the raster is large enough
            if i == num_tiles_x - 1:
                x_min = max(x_max - tile_width, 0)
            if j == num_tiles_y - 1:
                y_min = max(y_max - tile_height, 0)

            # Size of this tile; kept separate from tile_width/tile_height,
            # which define the stride for every tile
            window_width = x_max - x_min
            window_height = y_max - y_min

            # Set the output file name
            output_file = os.path.join(out_dir, f"tile_{i}_{j}.tif")

            # Create a new dataset for the tile
            tile_ds = driver.Create(
                output_file,
                window_width,
                window_height,
                band_count,
                data_type,
            )

            # Origin of the tile from the full affine transform, so rotated
            # or sheared rasters stay correctly georeferenced
            tile_geotransform = (
                geotransform[0] + x_min * geotransform[1] + y_min * geotransform[2],
                geotransform[1],
                geotransform[2],
                geotransform[3] + x_min * geotransform[4] + y_min * geotransform[5],
                geotransform[4],
                geotransform[5],
            )

            # Set the geotransform and projection of the tile
            tile_ds.SetGeoTransform(tile_geotransform)
            tile_ds.SetProjection(projection)

            # Copy the pixel window band by band
            for k in range(1, band_count + 1):
                tile_data = ds.GetRasterBand(k).ReadAsArray(
                    x_min, y_min, window_width, window_height
                )
                tile_ds.GetRasterBand(k).WriteArray(tile_data)

            # Close the tile dataset (dereferencing flushes it to disk)
            tile_ds = None

    # Close the input dataset
    ds = None
+
temp_file_path(extension)
+
+
+¶Returns a temporary file path.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
extension |
+ str |
+ The file extension. |
+ required | +
Returns:
+Type | +Description | +
---|---|
str |
+ The temporary file path. |
+
samgeo/common.py
def temp_file_path(extension):
    """Build a unique file path inside the system temporary directory.

    The file itself is not created; only the path is generated.

    Args:
        extension (str): The file extension, with or without the leading dot.

    Returns:
        str: The temporary file path.
    """

    import tempfile
    import uuid

    # Normalise the extension so it always carries exactly one leading dot.
    suffix = extension if extension.startswith(".") else "." + extension
    unique_name = str(uuid.uuid4()) + suffix

    return os.path.join(tempfile.gettempdir(), unique_name)
+
text_sam_gui(sam, basemap='SATELLITE', out_dir=None, box_threshold=0.25, text_threshold=0.25, cmap='viridis', opacity=0.5, **kwargs)
+
+
+¶Display the SAM Map GUI.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
sam |
+ SamGeo |
+ + | required | +
basemap |
+ str |
+ The basemap to use. Defaults to "SATELLITE". |
+ 'SATELLITE' |
+
out_dir |
+ str |
+ The output directory. Defaults to None. |
+ None |
+
samgeo/common.py
def text_sam_gui(
    sam,
    basemap="SATELLITE",
    out_dir=None,
    box_threshold=0.25,
    text_threshold=0.25,
    cmap="viridis",
    opacity=0.5,
    **kwargs,
):
    """Display the text-prompt SAM Map GUI.

    Builds a leafmap map with a toolbar for entering a text prompt, running
    ``sam.predict`` on ``sam.source`` and saving or resetting the result.

    Args:
        sam (SamGeo): The segmentation object. Its ``source`` is shown as the
            "Image" layer and passed to its ``predict`` method.
        basemap (str, optional): The basemap to use. Defaults to "SATELLITE".
        out_dir (str, optional): The output directory for the mask files.
            Defaults to None, which uses the system temporary directory.
        box_threshold (float, optional): Initial value of the box threshold
            slider. Defaults to 0.25.
        text_threshold (float, optional): Initial value of the text threshold
            slider. Defaults to 0.25.
        cmap (str, optional): Initially selected palette. Defaults to "viridis".
        opacity (float, optional): Initial opacity of the mask layer.
            Defaults to 0.5.
        **kwargs: Additional keyword arguments passed to leafmap.Map.

    Returns:
        leafmap.Map: The map with the toolbar control attached.

    Raises:
        ImportError: If leafmap or one of the widget packages is missing.
    """
    try:
        import shutil
        import tempfile
        import leafmap
        import ipyleaflet
        import ipyevents
        import ipywidgets as widgets
        import leafmap.colormaps as cm
        from ipyfilechooser import FileChooser
    except ImportError as e:
        raise ImportError(
            "The sam_map function requires the leafmap package. Please install it first."
        ) from e

    if out_dir is None:
        out_dir = tempfile.gettempdir()

    m = leafmap.Map(**kwargs)
    m.default_style = {"cursor": "crosshair"}
    m.add_basemap(basemap, show=False)

    # Skip the image layer if localtileserver is not available
    try:
        m.add_raster(sam.source, layer_name="Image")
    except Exception:
        pass

    widget_width = "280px"
    button_width = "90px"
    padding = "0px 4px 0px 4px"  # upper, right, bottom, left
    style = {"description_width": "initial"}

    toolbar_button = widgets.ToggleButton(
        value=True,
        tooltip="Toolbar",
        icon="gear",
        layout=widgets.Layout(width="28px", height="28px", padding="0px 0px 0px 4px"),
    )

    close_button = widgets.ToggleButton(
        value=False,
        tooltip="Close the tool",
        icon="times",
        button_style="primary",
        layout=widgets.Layout(height="28px", width="28px", padding="0px 0px 0px 4px"),
    )

    text_prompt = widgets.Text(
        description="Text prompt:",
        style=style,
        layout=widgets.Layout(width=widget_width, padding=padding),
    )

    box_slider = widgets.FloatSlider(
        description="Box threshold:",
        min=0,
        max=1,
        value=box_threshold,
        step=0.01,
        readout=True,
        continuous_update=True,
        layout=widgets.Layout(width=widget_width, padding=padding),
        style=style,
    )

    text_slider = widgets.FloatSlider(
        description="Text threshold:",
        min=0,
        max=1,
        step=0.01,
        value=text_threshold,
        readout=True,
        continuous_update=True,
        layout=widgets.Layout(width=widget_width, padding=padding),
        style=style,
    )

    cmap_dropdown = widgets.Dropdown(
        description="Palette:",
        options=cm.list_colormaps(),
        value=cmap,
        style=style,
        layout=widgets.Layout(width=widget_width, padding=padding),
    )

    opacity_slider = widgets.FloatSlider(
        description="Opacity:",
        min=0,
        max=1,
        value=opacity,
        readout=True,
        continuous_update=True,
        layout=widgets.Layout(width=widget_width, padding=padding),
        style=style,
    )

    def opacity_changed(change):
        # Do not test the truthiness of change["new"]: an opacity of 0 is a
        # valid value and must be applied as well.
        if hasattr(m, "layer_name"):
            mask_layer = m.find_layer(m.layer_name)
            if mask_layer is not None:
                mask_layer.interact(opacity=opacity_slider.value)

    opacity_slider.observe(opacity_changed, "value")

    rectangular = widgets.Checkbox(
        value=False,
        description="Regularize",
        layout=widgets.Layout(width="130px", padding=padding),
        style=style,
    )

    colorpicker = widgets.ColorPicker(
        concise=False,
        description="Color",
        value="#ffff00",
        layout=widgets.Layout(width="140px", padding=padding),
        style=style,
    )

    segment_button = widgets.ToggleButton(
        description="Segment",
        value=False,
        button_style="primary",
        layout=widgets.Layout(padding=padding),
    )

    save_button = widgets.ToggleButton(
        description="Save", value=False, button_style="primary"
    )

    reset_button = widgets.ToggleButton(
        description="Reset", value=False, button_style="primary"
    )
    segment_button.layout.width = button_width
    save_button.layout.width = button_width
    reset_button.layout.width = button_width

    output = widgets.Output(
        layout=widgets.Layout(
            width=widget_width, padding=padding, max_width=widget_width
        )
    )

    toolbar_header = widgets.HBox()
    toolbar_header.children = [close_button, toolbar_button]
    toolbar_footer = widgets.VBox()
    toolbar_footer.children = [
        text_prompt,
        box_slider,
        text_slider,
        cmap_dropdown,
        opacity_slider,
        widgets.HBox([rectangular, colorpicker]),
        widgets.HBox(
            [segment_button, save_button, reset_button],
            layout=widgets.Layout(padding="0px 4px 0px 4px"),
        ),
        output,
    ]
    toolbar_widget = widgets.VBox()
    toolbar_widget.children = [toolbar_header, toolbar_footer]

    toolbar_event = ipyevents.Event(
        source=toolbar_widget, watched_events=["mouseenter", "mouseleave"]
    )

    def handle_toolbar_event(event):
        if event["type"] == "mouseenter":
            toolbar_widget.children = [toolbar_header, toolbar_footer]
        elif event["type"] == "mouseleave":
            if not toolbar_button.value:
                toolbar_widget.children = [toolbar_button]
                toolbar_button.value = False
                close_button.value = False

    toolbar_event.on_dom_event(handle_toolbar_event)

    def toolbar_btn_click(change):
        if change["new"]:
            close_button.value = False
            toolbar_widget.children = [toolbar_header, toolbar_footer]
        else:
            if not close_button.value:
                toolbar_widget.children = [toolbar_button]

    toolbar_button.observe(toolbar_btn_click, "value")

    def close_btn_click(change):
        if change["new"]:
            toolbar_button.value = False
            if m.toolbar_control in m.controls:
                m.remove_control(m.toolbar_control)
            toolbar_widget.close()

    close_button.observe(close_btn_click, "value")

    def segment_button_click(change):
        if change["new"]:
            segment_button.value = False
            with output:
                output.clear_output()
                if len(text_prompt.value) == 0:
                    print("Please enter a text prompt first.")
                elif sam.source is None:
                    print("Please run sam.set_image() first.")
                else:
                    print("Segmenting...")
                    layer_name = text_prompt.value.replace(" ", "_")
                    # Must match the name used when the vector layer is added
                    # below, otherwise stale layers are never removed.
                    rect_layer_name = f"{layer_name} Regularized"
                    filename = os.path.join(
                        out_dir, f"{layer_name}_{random_string()}.tif"
                    )
                    try:
                        sam.predict(
                            sam.source,
                            text_prompt.value,
                            box_slider.value,
                            text_slider.value,
                            output=filename,
                        )
                        sam.output = filename
                        if m.find_layer(layer_name) is not None:
                            m.remove_layer(m.find_layer(layer_name))
                        if m.find_layer(rect_layer_name) is not None:
                            m.remove_layer(m.find_layer(rect_layer_name))
                    except Exception as e:
                        output.clear_output()
                        print(e)
                    if os.path.exists(filename):
                        try:
                            m.add_raster(
                                filename,
                                layer_name=layer_name,
                                palette=cmap_dropdown.value,
                                opacity=opacity_slider.value,
                                nodata=0,
                                zoom_to_layer=False,
                            )
                            m.layer_name = layer_name

                            if rectangular.value:
                                vector = filename.replace(".tif", ".gpkg")
                                vector_rec = filename.replace(".tif", "_rect.gpkg")
                                raster_to_vector(filename, vector)
                                regularize(vector, vector_rec)
                                vector_style = {"color": colorpicker.value}
                                m.add_vector(
                                    vector_rec,
                                    layer_name=rect_layer_name,
                                    style=vector_style,
                                    info_mode=None,
                                    zoom_to_layer=False,
                                )

                            output.clear_output()
                        except Exception as e:
                            print(e)

    segment_button.observe(segment_button_click, "value")

    def filechooser_callback(chooser):
        with output:
            if chooser.selected is not None:
                try:
                    filename = chooser.selected
                    shutil.copy(sam.output, filename)
                    vector = filename.replace(".tif", ".gpkg")
                    raster_to_vector(filename, vector)
                    if rectangular.value:
                        vector_rec = filename.replace(".tif", "_rect.gpkg")
                        regularize(vector, vector_rec)
                except Exception as e:
                    print(e)

                if hasattr(m, "save_control") and m.save_control in m.controls:
                    m.remove_control(m.save_control)
                    delattr(m, "save_control")
                save_button.value = False

    def save_button_click(change):
        if change["new"]:
            with output:
                output.clear_output()
                if not hasattr(m, "layer_name"):
                    print("Please click the Segment button first.")
                else:
                    sandbox_path = os.environ.get("SANDBOX_PATH")
                    filechooser = FileChooser(
                        path=os.getcwd(),
                        filename=f"{m.layer_name}.tif",
                        sandbox_path=sandbox_path,
                        layout=widgets.Layout(width="454px"),
                    )
                    filechooser.use_dir_icons = True
                    filechooser.filter_pattern = ["*.tif"]
                    filechooser.register_callback(filechooser_callback)
                    save_control = ipyleaflet.WidgetControl(
                        widget=filechooser, position="topright"
                    )
                    m.add_control(save_control)
                    m.save_control = save_control

        else:
            if hasattr(m, "save_control") and m.save_control in m.controls:
                m.remove_control(m.save_control)
                delattr(m, "save_control")

    save_button.observe(save_button_click, "value")

    def reset_button_click(change):
        if change["new"]:
            segment_button.value = False
            save_button.value = False
            reset_button.value = False
            # Restore the values the GUI was created with rather than
            # hard-coded numbers that ignore the function arguments.
            opacity_slider.value = opacity
            box_slider.value = box_threshold
            text_slider.value = text_threshold
            cmap_dropdown.value = cmap
            text_prompt.value = ""
            output.clear_output()
            try:
                if hasattr(m, "layer_name"):
                    # Remove the mask layer and its regularized companion
                    for name in (m.layer_name, f"{m.layer_name} Regularized"):
                        layer = m.find_layer(name)
                        if layer is not None:
                            m.remove_layer(layer)
                m.clear_drawings()
            except Exception:
                pass

    reset_button.observe(reset_button_click, "value")

    toolbar_control = ipyleaflet.WidgetControl(
        widget=toolbar_widget, position="topright"
    )
    m.add_control(toolbar_control)
    m.toolbar_control = toolbar_control

    return m
+
tms_to_geotiff(output, bbox, zoom=None, resolution=None, source='OpenStreetMap', crs='EPSG:3857', to_cog=False, return_image=False, overwrite=False, quiet=False, **kwargs)
+
+
+¶Download TMS tiles and convert them to a GeoTIFF. The source is adapted from https://github.com/gumblex/tms2geotiff. + Credits to the GitHub user @gumblex.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
output |
+ str |
+ The output GeoTIFF file. |
+ required | +
bbox |
+ list |
+ The bounding box [minx, miny, maxx, maxy], e.g., [-122.5216, 37.733, -122.3661, 37.8095] |
+ required | +
zoom |
+ int |
+ The map zoom level. Defaults to None. |
+ None |
+
resolution |
+ float |
+ The resolution in meters. Defaults to None. |
+ None |
+
source |
+ str |
+ The tile source. It can be one of the following: "OPENSTREETMAP", "ROADMAP", +"SATELLITE", "TERRAIN", "HYBRID", or an HTTP URL. Defaults to "OpenStreetMap". |
+ 'OpenStreetMap' |
+
crs |
+ str |
+ The output CRS. Defaults to "EPSG:3857". |
+ 'EPSG:3857' |
+
to_cog |
+ bool |
+ Convert to Cloud Optimized GeoTIFF. Defaults to False. |
+ False |
+
return_image |
+ bool |
+ Return the image as PIL.Image. Defaults to False. |
+ False |
+
overwrite |
+ bool |
+ Overwrite the output file if it already exists. Defaults to False. |
+ False |
+
quiet |
+ bool |
+ Suppress output. Defaults to False. |
+ False |
+
**kwargs |
+ + | Additional arguments to pass to gdal.GetDriverByName("GTiff").Create(). |
+ {} |
+
samgeo/common.py
def tms_to_geotiff(
    output,
    bbox,
    zoom=None,
    resolution=None,
    source="OpenStreetMap",
    crs="EPSG:3857",
    to_cog=False,
    return_image=False,
    overwrite=False,
    quiet=False,
    **kwargs,
):
    """Download TMS tiles and convert them to a GeoTIFF. The source is adapted from https://github.com/gumblex/tms2geotiff.
        Credits to the GitHub user @gumblex.

    Args:
        output (str): The output GeoTIFF file.
        bbox (list | tuple): The bounding box [minx, miny, maxx, maxy], e.g., [-122.5216, 37.733, -122.3661, 37.8095]
        zoom (int, optional): The map zoom level. Exactly one of zoom and resolution must be provided. Defaults to None.
        resolution (float, optional): The resolution in meters. It is converted to an integer zoom level
            (fractional part truncated). Defaults to None.
        source (str, optional): The tile source. It can be one of the following: "OPENSTREETMAP", "ROADMAP",
            "SATELLITE", "TERRAIN", "HYBRID", a key of the dictionary returned by get_basemaps(), or an HTTP URL
            template with {x}, {y} and {z} placeholders. Defaults to "OpenStreetMap".
        crs (str, optional): The output CRS. Defaults to "EPSG:3857".
        to_cog (bool, optional): Convert to Cloud Optimized GeoTIFF. Defaults to False.
        return_image (bool, optional): Return the image as PIL.Image. Defaults to False.
        overwrite (bool, optional): Overwrite the output file if it already exists. Defaults to False.
        quiet (bool, optional): Suppress output. Defaults to False.
        **kwargs: Additional arguments to pass to gdal.GetDriverByName("GTiff").Create().

    Returns:
        PIL.Image.Image | None: The mosaicked image (in EPSG:3857 pixel space) if return_image is True,
            otherwise None. None is also returned, after printing a notice, when the output file
            already exists and overwrite is False.

    Raises:
        ImportError: If GDAL is not installed.
        ValueError: If source, bbox, zoom or resolution are invalid.
        RuntimeError: If no tile could be downloaded for the requested area.
    """

    import os
    import io
    import math
    import itertools
    import concurrent.futures

    import numpy
    from PIL import Image

    try:
        from osgeo import gdal, osr
    except ImportError:
        raise ImportError("GDAL is not installed. Install it with pip install GDAL")

    if not overwrite and os.path.exists(output):
        print(
            f"The output file {output} already exists. Use `overwrite=True` to overwrite it."
        )
        return

    xyz_tiles = {
        "OPENSTREETMAP": "https://tile.openstreetmap.org/{z}/{x}/{y}.png",
        "ROADMAP": "https://mt1.google.com/vt/lyrs=m&x={x}&y={y}&z={z}",
        "SATELLITE": "https://mt1.google.com/vt/lyrs=s&x={x}&y={y}&z={z}",
        "TERRAIN": "https://mt1.google.com/vt/lyrs=p&x={x}&y={y}&z={z}",
        "HYBRID": "https://mt1.google.com/vt/lyrs=y&x={x}&y={y}&z={z}",
    }

    basemaps = get_basemaps()

    # Resolve a symbolic source name to a URL template; the built-in names are
    # matched case-insensitively, basemap keys are matched exactly.
    if isinstance(source, str):
        if source.upper() in xyz_tiles:
            source = xyz_tiles[source.upper()]
        elif source in basemaps:
            source = basemaps[source]
        elif source.startswith("http"):
            pass
        else:
            raise ValueError(
                'source must be one of "OpenStreetMap", "ROADMAP", "SATELLITE", "TERRAIN", "HYBRID", or a URL'
            )

    def resolution_to_zoom_level(resolution):
        """
        Convert map resolution in meters to zoom level for Web Mercator (EPSG:3857) tiles.
        """
        # Ground resolution in meters per pixel at zoom level 0 for 256-pixel
        # tiles (equatorial circumference / 256).
        initial_resolution = 156543.03392804097

        # Each zoom level halves the resolution.
        zoom_level = math.log2(initial_resolution / resolution)

        return int(zoom_level)

    # Accept tuples as well as lists; both unpack identically.
    if isinstance(bbox, (list, tuple)) and len(bbox) == 4:
        west, south, east, north = bbox
    else:
        raise ValueError(
            "bbox must be a list of 4 coordinates in the format of [xmin, ymin, xmax, ymax]"
        )

    if zoom is None and resolution is None:
        raise ValueError("Either zoom or resolution must be provided")
    elif zoom is not None and resolution is not None:
        raise ValueError("Only one of zoom or resolution can be provided")

    if resolution is not None:
        if resolution <= 0:
            raise ValueError("resolution must be a positive number")
        zoom = resolution_to_zoom_level(resolution)

    EARTH_EQUATORIAL_RADIUS = 6378137.0

    # Large mosaics exceed Pillow's default pixel limit; this lifts the limit
    # on the PIL.Image module itself, i.e. for the whole process.
    Image.MAX_IMAGE_PIXELS = None

    gdal.UseExceptions()
    web_mercator = osr.SpatialReference()
    web_mercator.ImportFromEPSG(3857)

    WKT_3857 = web_mercator.ExportToWkt()

    def from4326_to3857(lat, lon):
        """Project a latitude/longitude pair (degrees) to Web Mercator meters."""
        xtile = math.radians(lon) * EARTH_EQUATORIAL_RADIUS
        ytile = (
            math.log(math.tan(math.radians(45 + lat / 2.0))) * EARTH_EQUATORIAL_RADIUS
        )
        return (xtile, ytile)

    def deg2num(lat, lon, zoom):
        """Return the fractional tile coordinates of a latitude/longitude pair."""
        lat_r = math.radians(lat)
        n = 2**zoom
        xtile = (lon + 180) / 360 * n
        ytile = (1 - math.log(math.tan(lat_r) + 1 / math.cos(lat_r)) / math.pi) / 2 * n
        return (xtile, ytile)

    def is_empty(im):
        """Return True if the image is fully transparent or all-zero in its first three bands."""
        extrema = im.getextrema()
        if len(extrema) >= 3:
            if len(extrema) > 3 and extrema[-1] == (0, 0):
                return True
            for ext in extrema[:3]:
                if ext != (0, 0):
                    return False
            return True
        else:
            return extrema[0] == (0, 0)

    def paste_tile(bigim, base_size, tile, corner_xy, bbox):
        """Paste one downloaded tile into the mosaic, creating the mosaic on first use."""
        if tile is None:
            return bigim
        im = Image.open(io.BytesIO(tile))
        mode = "RGB" if im.mode == "RGB" else "RGBA"
        size = im.size
        if bigim is None:
            # The first decoded tile defines the tile size and mode of the mosaic.
            base_size[0] = size[0]
            base_size[1] = size[1]
            newim = Image.new(
                mode, (size[0] * (bbox[2] - bbox[0]), size[1] * (bbox[3] - bbox[1]))
            )
        else:
            newim = bigim

        dx = abs(corner_xy[0] - bbox[0])
        dy = abs(corner_xy[1] - bbox[1])
        xy0 = (size[0] * dx, size[1] * dy)
        if mode == "RGB":
            newim.paste(im, xy0)
        else:
            if im.mode != mode:
                im = im.convert(mode)
            if not is_empty(im):
                newim.paste(im, xy0)
        im.close()
        return newim

    def finish_picture(bigim, base_size, bbox, x0, y0, x1, y1):
        """Crop the whole-tile mosaic down to the exact requested extent."""
        xfrac = x0 - bbox[0]
        yfrac = y0 - bbox[1]
        x2 = round(base_size[0] * xfrac)
        y2 = round(base_size[1] * yfrac)
        imgw = round(base_size[0] * (x1 - x0))
        imgh = round(base_size[1] * (y1 - y0))
        retim = bigim.crop((x2, y2, x2 + imgw, y2 + imgh))
        # Drop the alpha band when every pixel is fully opaque.
        if retim.mode == "RGBA" and retim.getextrema()[3] == (255, 255):
            retim = retim.convert("RGB")
        bigim.close()
        return retim

    def get_tile(url):
        """Fetch one tile; returns its bytes, or None for a 404 or an empty body."""
        retry = 3
        while 1:
            try:
                r = session.get(url, timeout=60)
                break
            except Exception:
                retry -= 1
                if not retry:
                    raise
        if r.status_code == 404:
            return None
        elif not r.content:
            return None
        r.raise_for_status()
        return r.content

    def draw_tile(
        source, lat0, lon0, lat1, lon1, zoom, filename, quiet=False, **kwargs
    ):
        """Download all tiles covering the extent, mosaic them and write the GeoTIFF."""
        x0, y0 = deg2num(lat0, lon0, zoom)
        x1, y1 = deg2num(lat1, lon1, zoom)
        x0, x1 = sorted([x0, x1])
        y0, y1 = sorted([y0, y1])
        corners = tuple(
            itertools.product(
                range(math.floor(x0), math.ceil(x1)),
                range(math.floor(y0), math.ceil(y1)),
            )
        )
        totalnum = len(corners)
        futures = []
        with concurrent.futures.ThreadPoolExecutor(5) as executor:
            for x, y in corners:
                futures.append(
                    executor.submit(get_tile, source.format(z=zoom, x=x, y=y))
                )
            bbox = (math.floor(x0), math.floor(y0), math.ceil(x1), math.ceil(y1))
            bigim = None
            base_size = [256, 256]
            for k, (fut, corner_xy) in enumerate(zip(futures, corners), 1):
                bigim = paste_tile(bigim, base_size, fut.result(), corner_xy, bbox)
                if not quiet:
                    print(
                        f"Downloaded image {str(k).zfill(len(str(totalnum)))}/{totalnum}"
                    )

        if bigim is None:
            # Every request returned 404/empty (or the extent covers no tile);
            # fail with a clear message instead of an AttributeError on None.
            raise RuntimeError(
                "No tiles could be downloaded for the given bbox, zoom and source."
            )

        if not quiet:
            print("Saving GeoTIFF. Please wait...")
        img = finish_picture(bigim, base_size, bbox, x0, y0, x1, y1)
        imgbands = len(img.getbands())
        driver = gdal.GetDriverByName("GTiff")

        if "options" not in kwargs:
            kwargs["options"] = [
                "COMPRESS=DEFLATE",
                "PREDICTOR=2",
                "ZLEVEL=9",
                "TILED=YES",
            ]

        gtiff = driver.Create(
            filename,
            img.size[0],
            img.size[1],
            imgbands,
            gdal.GDT_Byte,
            **kwargs,
        )
        xp0, yp0 = from4326_to3857(lat0, lon0)
        xp1, yp1 = from4326_to3857(lat1, lon1)
        pwidth = abs(xp1 - xp0) / img.size[0]
        pheight = abs(yp1 - yp0) / img.size[1]
        # Origin is the top-left corner; pixel height is negative (north-up).
        gtiff.SetGeoTransform((min(xp0, xp1), pwidth, 0, max(yp0, yp1), 0, -pheight))
        gtiff.SetProjection(WKT_3857)
        for band_index in range(imgbands):
            # The raster is GDT_Byte and the image is RGB/RGBA, so one byte
            # per sample is sufficient.
            array = numpy.array(img.getdata(band_index), dtype="u1")
            array = array.reshape((img.size[1], img.size[0]))
            gtiff.GetRasterBand(band_index + 1).WriteArray(array)
        gtiff.FlushCache()
        # Dereference the dataset so GDAL closes the file before it is reopened.
        gtiff = None

        if not quiet:
            print(f"Image saved to {filename}")
        return img

    # Create the HTTP session only once all arguments have been validated.
    try:
        import httpx

        session = httpx.Client()
    except ImportError:
        import requests

        session = requests.Session()

    try:
        image = draw_tile(
            source, south, west, north, east, zoom, output, quiet, **kwargs
        )
    finally:
        # Release pooled connections whether or not the download succeeded.
        session.close()

    # Post-process the file on disk before returning, so that crs/to_cog are
    # honoured even when return_image=True.
    if crs.upper() != "EPSG:3857":
        reproject(output, output, crs, to_cog=to_cog)
    elif to_cog:
        image_to_cog(output, output)

    if return_image:
        return image
+
transform_coords(x, y, src_crs, dst_crs, **kwargs)
+
+
+¶Transform coordinates from one CRS to another.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
x |
+ float |
+ The x coordinate. |
+ required | +
y |
+ float |
+ The y coordinate. |
+ required | +
src_crs |
+ str |
+ The source CRS, e.g., "EPSG:4326". |
+ required | +
dst_crs |
+ str |
+ The destination CRS, e.g., "EPSG:3857". |
+ required | +
Returns:
+Type | +Description | +
---|---|
tuple |
+ The transformed coordinates in the format of (x, y) |
+
samgeo/common.py
def transform_coords(x, y, src_crs, dst_crs, **kwargs):
    """Transform coordinates from one CRS to another.

    The transformer is created with always_xy=True, so coordinates are given
    and returned in (x, y) order, e.g. (longitude, latitude) for EPSG:4326.

    Args:
        x (float): The x coordinate.
        y (float): The y coordinate.
        src_crs (str): The source CRS, e.g., "EPSG:4326".
        dst_crs (str): The destination CRS, e.g., "EPSG:3857".
        **kwargs: Additional keyword arguments to pass to pyproj.Transformer.from_crs().

    Returns:
        tuple: The transformed coordinates in the format of (x, y)
    """
    transformer = pyproj.Transformer.from_crs(
        src_crs, dst_crs, always_xy=True, **kwargs
    )
    return transformer.transform(x, y)
+
update_package(out_dir=None, keep=False, **kwargs)
+
+
+¶Updates the package from the GitHub repository without the need to use pip or conda.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
out_dir |
+ str |
+ The output directory. Defaults to None. |
+ None |
+
keep |
+ bool |
+ Whether to keep the downloaded package. Defaults to False. |
+ False |
+
**kwargs |
+ + | Additional keyword arguments to pass to the download_file() function. |
+ {} |
+
samgeo/common.py
def update_package(out_dir=None, keep=False, **kwargs):
    """Updates the package from the GitHub repository without the need to use pip or conda.

    The main-branch archive is downloaded into out_dir, installed with pip from
    the extracted folder, and the downloaded files are removed afterwards
    unless keep is True.

    Args:
        out_dir (str, optional): The output directory. Defaults to None, which uses the current working directory.
        keep (bool, optional): Whether to keep the downloaded package. Defaults to False.
        **kwargs: Additional keyword arguments to pass to the download_file() function.

    Raises:
        subprocess.CalledProcessError: If the pip installation exits with a non-zero status.
    """

    import shutil
    import subprocess

    if out_dir is None:
        out_dir = os.getcwd()
    os.makedirs(out_dir, exist_ok=True)

    url = "https://github.com/opengeos/segment-geospatial/archive/refs/heads/main.zip"
    # Download into out_dir so that the archive and the extracted folder live
    # in the directory the caller asked for.
    # NOTE(review): assumes download_file() extracts the archive next to the
    # downloaded file - confirm against download_file().
    filename = os.path.join(out_dir, "segment-geospatial-main.zip")
    download_file(url, filename, **kwargs)

    pkg_dir = os.path.join(out_dir, "segment-geospatial-main")

    # Prefer "pip"; fall back to "pip3" when "pip" is not on PATH.
    pip_cmd = shutil.which("pip") or shutil.which("pip3") or "pip3"

    # Run pip inside the package folder without changing this process's
    # working directory, and fail loudly if the installation fails.
    subprocess.run([pip_cmd, "install", "."], cwd=pkg_dir, check=True)

    if not keep:
        shutil.rmtree(pkg_dir)
        try:
            os.remove(filename)
        except OSError:
            # Best-effort cleanup: the archive may already have been removed.
            pass

    print("Package updated successfully.")
+
vector_to_geojson(filename, output=None, **kwargs)
+
+
+¶Converts a vector file to a geojson file.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
filename |
+ str |
+ The vector file path. |
+ required | +
output |
+ str |
+ The output geojson file path. Defaults to None. |
+ None |
+
Returns:
+Type | +Description | +
---|---|
dict |
+ The geojson dictionary. |
+
samgeo/common.py
def vector_to_geojson(filename, output=None, **kwargs):
    """Converts a vector file to a geojson file.

    Args:
        filename (str): The vector file path or an HTTP(S) URL to a vector file.
        output (str, optional): The output geojson file path. Defaults to None.
        **kwargs: Additional keyword arguments to pass to geopandas.read_file().

    Returns:
        dict: The geojson dictionary if output is None; otherwise None, and
            the data is written to output as GeoJSON.
    """

    # Only remote files need downloading; local paths are read as they are.
    if isinstance(filename, str) and filename.startswith(("http://", "https://")):
        filename = download_file(filename)

    gdf = gpd.read_file(filename, **kwargs)
    if output is None:
        return gdf.__geo_interface__
    else:
        gdf.to_file(output, driver="GeoJSON")
+
Contributions are welcome, and they are greatly appreciated! Every +little bit helps, and credit will always be given.
+You can contribute in many ways:
+Report bugs at https://github.com/giswqs/segment-geospatial/issues.
+If you are reporting a bug, please include:
+Look through the GitHub issues for bugs. Anything tagged with bug
and
+help wanted
is open to whoever wants to implement it.
Look through the GitHub issues for features. Anything tagged with
+enhancement
and help wanted
is open to whoever wants to implement it.
segment-geospatial could always use more documentation, +whether as part of the official segment-geospatial docs, +in docstrings, or even on the web in blog posts, articles, and such.
+The best way to send feedback is to file an issue at +https://github.com/giswqs/segment-geospatial/issues.
+If you are proposing a feature:
+Ready to contribute? Here's how to set up segment-geospatial for local development.
+Fork the segment-geospatial repo on GitHub.
+Clone your fork locally:
+1 |
|
Install your local copy into a virtualenv. Assuming you have + virtualenvwrapper installed, this is how you set up your fork for + local development:
+1 +2 +3 |
|
Create a branch for local development:
+1 |
|
Now you can make your changes locally.
+When you're done making changes, check that your changes pass flake8 + and the tests, including testing other Python versions with tox:
+1 +2 +3 |
|
To get flake8 and tox, just pip install them into your virtualenv.
+Commit your changes and push your branch to GitHub:
+1 +2 +3 |
|
Submit a pull request through the GitHub website.
+Before you submit a pull request, check that it meets these guidelines:
+Unit tests are in the tests
folder. If you add new functionality to the package, please add a unit test for it. You can either add the test to an existing test file or create a new one. For example, if you add a new function to samgeo/samgeo.py
, you can add the unit test to tests/test_samgeo.py
. If you add a new module to samgeo/<MODULE-NAME>
, you can create a new test file in tests/test_<MODULE-NAME>
. Please refer to tests/test_samgeo.py
for examples. For more information about unit testing, please refer to this tutorial - Getting Started With Testing in Python.
To run the unit tests, navigate to the root directory of the package and run the following command:
+1 |
|
If your PR involves adding new dependencies, please make sure that the new dependencies are available on both PyPI and conda-forge. Search here to see if the package is available on conda-forge. If the package is not available on conda-forge, it can't be added as a required dependency in requirements.txt
. Instead, it should be added as an optional dependency in requirements_dev.txt
.
If the package is available on PyPI and conda-forge, but if it is challenging to install the package on some operating systems, we would recommend adding the package as an optional dependency in requirements_dev.txt
rather than a required dependency in requirements.txt
.
The dependencies required for building the documentation should be added to requirements_docs.txt
. In most cases, contributors do not need to add new dependencies to requirements_docs.txt
unless the documentation fails to build due to missing dependencies.
The notebook shows step-by-step instructions for using the Segment Anything Model (SAM) with ArcGIS Pro. Check out the YouTube tutorial here and the Resources for Unlocking the Power of Deep Learning Applications Using ArcGIS. Credit goes to Esri.
+ +Open Windows Registry Editor (regedit.exe
) and navigate to Computer\HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\Control\FileSystem
. Change the value of LongPathsEnabled
to 1
. See this screenshot. This is a known issue with the deep learning libraries for ArcGIS Pro 3.1. A future release might fix this issue.
Navigate to the Start Menu -> All apps -> ArcGIS folder, then open the Python Command Prompt.
+Create a new conda environment and install mamba and Python 3.9.x from the Esri Anaconda channel. Mamba is a drop-in replacement for conda that is much faster for installing Python packages and their dependencies.
+conda create conda-forge::mamba esri::python --name samgeo
Activate the new conda environment.
+conda activate samgeo
Install arcpy, deep-learning-essentials, segment-geospatial, and other dependencies (~4GB download).
+mamba install arcpy deep-learning-essentials leafmap localtileserver segment-geospatial -c esri -c conda-forge
Activate the new environment in ArcGIS Pro.
+proswap samgeo
Close the Python Command Prompt and open ArcGIS Pro.
+Download this notebook and run it in ArcGIS Pro.
+import os
+import leafmap
+from samgeo import SamGeo
+
+%matplotlib inline
+
In this example, we will use the high-resolution aerial imagery from the USDA National Agricultural Imagery Program (NAIP). You can download NAIP imagery using the USDA Data Gateway or the USDA NRCS Box Drive. I have downloaded some NAIP imagery and clipped it to a smaller area; the clipped images are available here.
+workspace = os.path.dirname(arcpy.env.workspace)
+os.chdir(workspace)
+arcpy.env.overwriteOutput = True
+
leafmap.download_file(
+ url="https://github.com/opengeos/data/blob/main/naip/buildings.tif",
+ quiet=True,
+ overwrite=True,
+)
+
leafmap.download_file(
+ url="https://github.com/opengeos/data/blob/main/naip/agriculture.tif",
+ quiet=True,
+ overwrite=True,
+)
+
leafmap.download_file(
+ url="https://github.com/opengeos/data/blob/main/naip/water.tif",
+ quiet=True,
+ overwrite=True,
+)
+
Specify the file path to the model checkpoint. If it is not specified, the model will be downloaded to the working directory.
+sam = SamGeo(
+ model_type="vit_h",
+ sam_kwargs=None,
+)
+
Specify the file path to the image we downloaded earlier.
+image = "agriculture.tif"
+
You can also use your own image. Uncomment and run the following cell to use your own image.
+# image = '/path/to/your/own/image.tif'
+
Segment the image and save the results to a GeoTIFF file. Set unique=True
to assign a unique ID to each object.
sam.generate(image, output="ag_masks.tif", foreground=True, unique=True)
+
If you run into GPU memory errors, uncomment the following code block and run it to empty cuda cache then rerun the code block above.
+# sam.clear_cuda_cache()
+
Show the segmentation result as a grayscale image.
+sam.show_masks(cmap="binary_r")
+
Show the object annotations (objects with random color) on the map.
+sam.show_anns(axis="off", alpha=1, output="ag_annotations.tif")
+
Add layers to ArcGIS Pro.
+m = leafmap.arc_active_map()
+
m.addDataFromPath(os.path.join(workspace, "agriculture.tif"))
+
m.addDataFromPath(os.path.join(workspace, "ag_annotations.tif"))
+
Convert the object annotations to vector format, such as GeoPackage, Shapefile, or GeoJSON.
+in_raster = os.path.join(workspace, "ag_masks.tif")
+out_shp = os.path.join(workspace, "ag_masks.shp")
+
arcpy.conversion.RasterToPolygon(in_raster, out_shp)
+
image = "water.tif"
+
sam.generate(image, output="water_masks.tif", foreground=True, unique=True)
+
# sam.clear_cuda_cache()
+
sam.show_masks(cmap="binary_r")
+
sam.show_anns(axis="off", alpha=1, output="water_annotations.tif")
+
m.addDataFromPath(os.path.join(workspace, "water.tif"))
+
m.addDataFromPath(os.path.join(workspace, "water_annotations.tif"))
+
in_raster = os.path.join(workspace, "water_masks.tif")
+out_shp = os.path.join(workspace, "water_masks.shp")
+
arcpy.conversion.RasterToPolygon(in_raster, out_shp)
+
There are several tunable parameters in automatic mask generation that control how densely points are sampled and what the thresholds are for removing low quality or duplicate masks. Additionally, generation can be automatically run on crops of the image to get improved performance on smaller objects, and post-processing can remove stray pixels and holes. Here is an example configuration that samples more masks:
+sam_kwargs = {
+ "points_per_side": 32,
+ "pred_iou_thresh": 0.86,
+ "stability_score_thresh": 0.92,
+ "crop_n_layers": 1,
+ "crop_n_points_downscale_factor": 2,
+ "min_mask_region_area": 100,
+}
+
sam = SamGeo(
+ model_type="vit_h",
+ sam_kwargs=sam_kwargs,
+)
+
sam.generate('agriculture.tif', output="ag_masks2.tif", foreground=True)
+
sam.show_masks(cmap="binary_r")
+
sam.show_anns(axis="off", alpha=0.5, output="ag_annotations2.tif")
+
This notebook shows how to segment objects from an image using the Segment Anything Model (SAM) with a few lines of code.
+Make sure you use GPU runtime for this notebook. For Google Colab, go to Runtime
-> Change runtime type
and select GPU
as the hardware accelerator.
The notebook is adapted from segment-anything/notebooks/automatic_mask_generator_example.ipynb, but I have made it much easier to save the segmentation results and visualize them.
+Uncomment and run the following cell to install the required dependencies.
+# %pip install segment-geospatial
+
import os
+import leafmap
+from samgeo import SamGeo, show_image, download_file, overlay_images, tms_to_geotiff
+
m = leafmap.Map(center=[37.8713, -122.2580], zoom=17, height="800px")
+m.add_basemap("SATELLITE")
+m
+
Pan and zoom the map to select the area of interest. Use the draw tools to draw a polygon or rectangle on the map
+if m.user_roi_bounds() is not None:
+ bbox = m.user_roi_bounds()
+else:
+ bbox = [-122.2659, 37.8682, -122.2521, 37.8741]
+
image = "satellite.tif"
+tms_to_geotiff(output=image, bbox=bbox, zoom=17, source="Satellite", overwrite=True)
+
You can also use your own image. Uncomment and run the following cell to use your own image.
+# image = '/path/to/your/own/image.tif'
+
Display the downloaded image on the map.
+m.layers[-1].visible = False
+m.add_raster(image, layer_name="Image")
+m
+
Specify the file path to the model checkpoint. If it is not specified, the model will be downloaded to the working directory.
+sam = SamGeo(
+ model_type="vit_h",
+ sam_kwargs=None,
+)
+
Segment the image and save the results to a GeoTIFF file. Set unique=True
to assign a unique ID to each object.
sam.generate(image, output="masks.tif", foreground=True, unique=True)
+
sam.show_masks(cmap="binary_r")
+
Show the object annotations (objects with random color) on the map.
+sam.show_anns(axis="off", alpha=1, output="annotations.tif")
+
Compare images with a slider.
+leafmap.image_comparison(
+ "satellite.tif",
+ "annotations.tif",
+ label1="Satellite Image",
+ label2="Image Segmentation",
+)
+
Add image to the map.
+m.add_raster("annotations.tif", alpha=0.5, layer_name="Masks")
+m
+
Convert the object annotations to vector format, such as GeoPackage, Shapefile, or GeoJSON.
+sam.tiff_to_vector("masks.tif", "masks.gpkg")
+
There are several tunable parameters in automatic mask generation that control how densely points are sampled and what the thresholds are for removing low quality or duplicate masks. Additionally, generation can be automatically run on crops of the image to get improved performance on smaller objects, and post-processing can remove stray pixels and holes. Here is an example configuration that samples more masks:
+sam_kwargs = {
+ "points_per_side": 32,
+ "pred_iou_thresh": 0.86,
+ "stability_score_thresh": 0.92,
+ "crop_n_layers": 1,
+ "crop_n_points_downscale_factor": 2,
+ "min_mask_region_area": 100,
+}
+
sam = SamGeo(
+ model_type="vit_h",
+ sam_kwargs=sam_kwargs,
+)
+
sam.generate(image, output="masks2.tif", foreground=True)
+
sam.show_masks(cmap="binary_r")
+
sam.show_anns(axis="off", opacity=1, output="annotations2.tif")
+
Compare images with a slider.
+leafmap.image_comparison(
+ image,
+ "annotations.tif",
+ label1="Image",
+ label2="Image Segmentation",
+)
+
Overlay the annotations on the image and use the slider to change the opacity interactively.
+overlay_images(image, "annotations2.tif", backend="TkAgg")
+
Uncomment and run the following cell to install the required dependencies.
+# %pip install segment-geospatial
+
import os
+import leafmap
+from samgeo.hq_sam import SamGeo, show_image, download_file, overlay_images, tms_to_geotiff
+
m = leafmap.Map(center=[37.8713, -122.2580], zoom=17, height="800px")
+m.add_basemap("SATELLITE")
+m
+
Pan and zoom the map to select the area of interest. Use the draw tools to draw a polygon or rectangle on the map
+if m.user_roi_bounds() is not None:
+ bbox = m.user_roi_bounds()
+else:
+ bbox = [-122.2659, 37.8682, -122.2521, 37.8741]
+
image = "satellite.tif"
+tms_to_geotiff(output=image, bbox=bbox, zoom=17, source="Satellite", overwrite=True)
+
You can also use your own image. Uncomment and run the following cell to use your own image.
+# image = '/path/to/your/own/image.tif'
+
Display the downloaded image on the map.
+m.layers[-1].visible = False
+m.add_raster(image, layer_name="Image")
+m
+
Specify the file path to the model checkpoint. If it is not specified, the model will be downloaded to the working directory.
+sam = SamGeo(
+ model_type="vit_h", # can be vit_h, vit_b, vit_l, vit_tiny
+ sam_kwargs=None,
+)
+
Segment the image and save the results to a GeoTIFF file. Set unique=True
to assign a unique ID to each object.
sam.generate(image, output="masks.tif", foreground=True, unique=True)
+
sam.show_masks(cmap="binary_r")
+
Show the object annotations (objects with random color) on the map.
+sam.show_anns(axis="off", alpha=1, output="annotations.tif")
+
Compare images with a slider.
+leafmap.image_comparison(
+ "satellite.tif",
+ "annotations.tif",
+ label1="Satellite Image",
+ label2="Image Segmentation",
+)
+
Add image to the map.
+m.add_raster("annotations.tif", alpha=0.5, layer_name="Masks")
+m
+
Convert the object annotations to vector format, such as GeoPackage, Shapefile, or GeoJSON.
+sam.tiff_to_vector("masks.tif", "masks.gpkg")
+
There are several tunable parameters in automatic mask generation that control how densely points are sampled and what the thresholds are for removing low quality or duplicate masks. Additionally, generation can be automatically run on crops of the image to get improved performance on smaller objects, and post-processing can remove stray pixels and holes. Here is an example configuration that samples more masks:
+sam_kwargs = {
+ "points_per_side": 32,
+ "pred_iou_thresh": 0.86,
+ "stability_score_thresh": 0.92,
+ "crop_n_layers": 1,
+ "crop_n_points_downscale_factor": 2,
+ "min_mask_region_area": 100,
+}
+
sam = SamGeo(
+ model_type="vit_h",
+ sam_kwargs=sam_kwargs,
+)
+
sam.generate(image, output="masks2.tif", foreground=True)
+
sam.show_masks(cmap="binary_r")
+
sam.show_anns(axis="off", opacity=1, output="annotations2.tif")
+
Compare images with a slider.
+leafmap.image_comparison(
+ image,
+ "annotations.tif",
+ label1="Image",
+ label2="Image Segmentation",
+)
+
Overlay the annotations on the image and use the slider to change the opacity interactively.
+overlay_images(image, "annotations2.tif", backend="TkAgg")
+
This notebook shows how to generate object masks from box prompts with the Segment Anything Model (SAM).
+Make sure you use GPU runtime for this notebook. For Google Colab, go to Runtime
-> Change runtime type
and select GPU
as the hardware accelerator.
Uncomment and run the following cell to install the required dependencies.
+# %pip install segment-geospatial
+
import leafmap
+from samgeo import tms_to_geotiff
+from samgeo import SamGeo
+
m = leafmap.Map(center=[-22.17615, -51.253043], zoom=18, height="800px")
+m.add_basemap("SATELLITE")
+m
+
Pan and zoom the map to select the area of interest. Use the draw tools to draw a polygon or rectangle on the map
+bbox = m.user_roi_bounds()
+if bbox is None:
+ bbox = [-51.2565, -22.1777, -51.2512, -22.175]
+
image = "Image.tif"
+tms_to_geotiff(output=image, bbox=bbox, zoom=19, source="Satellite", overwrite=True)
+
You can also use your own image. Uncomment and run the following cell to use your own image.
+# image = '/path/to/your/own/image.tif'
+
Display the downloaded image on the map.
+m.layers[-1].visible = False
+m.add_raster(image, layer_name="Image")
+m
+
The initialization of the SamGeo class might take a few minutes. The initialization downloads the model weights and sets up the model for inference.
+Specify the file path to the model checkpoint. If it is not specified, the model will be downloaded to the working directory.
+Set automatic=False
to disable the SamAutomaticMaskGenerator
and enable the SamPredictor
.
sam = SamGeo(
+ model_type="vit_h",
+ automatic=False,
+ sam_kwargs=None,
+)
+
Specify the image to segment.
+sam.set_image(image)
+
Display the map. Use the drawing tools to draw some rectangles around the features you want to extract, such as trees, buildings.
+m
+
If no rectangles are drawn, the default bounding boxes will be used as follows:
+if m.user_rois is not None:
+ boxes = m.user_rois
+else:
+ boxes = [
+ [-51.2546, -22.1771, -51.2541, -22.1767],
+ [-51.2538, -22.1764, -51.2535, -22.1761],
+ ]
+
Use the predict()
method to segment the image with specified bounding boxes. The boxes
parameter accepts a list of bounding box coordinates in the format of [[left, bottom, right, top], [left, bottom, right, top], ...], a GeoJSON dictionary, or a file path to a GeoJSON file.
sam.predict(boxes=boxes, point_crs="EPSG:4326", output="mask.tif", dtype="uint8")
+
Add the segmented image to the map.
+m.add_raster('mask.tif', cmap='viridis', nodata=0, layer_name='Mask')
+m
+
Alternatively, you can specify a file path to a vector file. Let's download a sample vector file from GitHub.
+url = 'https://opengeos.github.io/data/sam/tree_boxes.geojson'
+geojson = "tree_boxes.geojson"
+leafmap.download_file(url, geojson)
+
Display the vector data on the map.
+m = leafmap.Map()
+m.add_raster("Image.tif", layer_name="image")
+style = {
+ "color": "#ffff00",
+ "weight": 2,
+ "fillColor": "#7c4185",
+ "fillOpacity": 0,
+}
+m.add_vector(geojson, style=style, zoom_to_layer=True, layer_name="Bounding boxes")
+m
+
Segment the image using the specified file path to the vector mask.
+sam.predict(boxes=geojson, point_crs="EPSG:4326", output="mask2.tif", dtype="uint8")
+
Display the segmented masks on the map.
+m.add_raster("mask2.tif", cmap="Greens", nodata=0, opacity=0.5, layer_name="Tree masks")
+m
+
This notebook shows how to generate object masks from input prompts with the Segment Anything Model (SAM).
+Make sure you use GPU runtime for this notebook. For Google Colab, go to Runtime
-> Change runtime type
and select GPU
as the hardware accelerator.
The notebook is adapted from segment-anything/notebooks/predictor_example.ipynb, but I have made it much easier to save the segmentation results and visualize them.
+Uncomment and run the following cell to install the required dependencies.
+# %pip install segment-geospatial
+
import os
+import leafmap
+from samgeo import SamGeo, tms_to_geotiff
+
m = leafmap.Map(center=[37.6412, -122.1353], zoom=15, height="800px")
+m.add_basemap("SATELLITE")
+m
+
Pan and zoom the map to select the area of interest. Use the draw tools to draw a polygon or rectangle on the map
+if m.user_roi is not None:
+ bbox = m.user_roi_bounds()
+else:
+ bbox = [-122.1497, 37.6311, -122.1203, 37.6458]
+
image = "satellite.tif"
+tms_to_geotiff(output=image, bbox=bbox, zoom=16, source="Satellite", overwrite=True)
+
You can also use your own image. Uncomment and run the following cell to use your own image.
+# image = '/path/to/your/own/image.tif'
+
Display the downloaded image on the map.
+m.layers[-1].visible = False
+m.add_raster(image, layer_name="Image")
+m
+
Specify the file path to the model checkpoint. If it is not specified, the model will be downloaded to the working directory.
+Set automatic=False
to disable the SamAutomaticMaskGenerator
and enable the SamPredictor
.
sam = SamGeo(
+ model_type="vit_h",
+ automatic=False,
+ sam_kwargs=None,
+)
+
Specify the image to segment.
+sam.set_image(image)
+
A single point can be used to segment an object. The point can be specified as a tuple of (x, y), such as (col, row) or (lon, lat). The points can also be specified as a file path to a vector dataset. For non (col, row) input points, specify the point_crs
parameter, which will automatically transform the points to the image column and row coordinates.
Try a single point input:
+point_coords = [[-122.1419, 37.6383]]
+sam.predict(point_coords, point_labels=1, point_crs="EPSG:4326", output="mask1.tif")
+m.add_raster("mask1.tif", layer_name="Mask1", nodata=0, cmap="Blues", opacity=1)
+m
+
Try multiple points input:
+point_coords = [[-122.1464, 37.6431], [-122.1449, 37.6415], [-122.1451, 37.6395]]
+sam.predict(point_coords, point_labels=1, point_crs="EPSG:4326", output="mask2.tif")
+m.add_raster("mask2.tif", layer_name="Mask2", nodata=0, cmap="Greens", opacity=1)
+m
+
Display the interactive map and use the marker tool to draw points on the map. Then click on the Segment
button to segment the objects. The results will be added to the map automatically. Click on the Reset
button to clear the points and the results.
m = sam.show_map()
+m
+
This notebook shows how to generate object masks from input prompts with the High-Quality Segment Anything Model (HQ-SAM).
+Make sure you use GPU runtime for this notebook. For Google Colab, go to Runtime
-> Change runtime type
and select GPU
as the hardware accelerator.
Uncomment and run the following cell to install the required dependencies.
+# %pip install segment-geospatial
+
import os
+import leafmap
+from samgeo.hq_sam import SamGeo, tms_to_geotiff
+
m = leafmap.Map(center=[37.6412, -122.1353], zoom=15, height="800px")
+m.add_basemap("SATELLITE")
+m
+
Pan and zoom the map to select the area of interest. Use the draw tools to draw a polygon or rectangle on the map
+if m.user_roi is not None:
+ bbox = m.user_roi_bounds()
+else:
+ bbox = [-122.1497, 37.6311, -122.1203, 37.6458]
+
image = "satellite.tif"
+tms_to_geotiff(output=image, bbox=bbox, zoom=16, source="Satellite", overwrite=True)
+
You can also use your own image. Uncomment and run the following cell to use your own image.
+# image = '/path/to/your/own/image.tif'
+
Display the downloaded image on the map.
+m.layers[-1].visible = False
+m.add_raster(image, layer_name="Image")
+m
+
Specify the file path to the model checkpoint. If it is not specified, the model will be downloaded to the working directory.
+Set automatic=False
to disable the SamAutomaticMaskGenerator
and enable the SamPredictor
.
sam = SamGeo(
+ model_type="vit_h", # can be vit_h, vit_b, vit_l, vit_tiny
+ automatic=False,
+ sam_kwargs=None,
+)
+
Specify the image to segment.
+sam.set_image(image)
+
A single point can be used to segment an object. The point can be specified as a tuple of (x, y), such as (col, row) or (lon, lat). The points can also be specified as a file path to a vector dataset. For non (col, row) input points, specify the point_crs
parameter, which will automatically transform the points to the image column and row coordinates.
Try a single point input:
+point_coords = [[-122.1419, 37.6383]]
+sam.predict(point_coords, point_labels=1, point_crs="EPSG:4326", output="mask1.tif")
+m.add_raster("mask1.tif", layer_name="Mask1", nodata=0, cmap="Blues", opacity=1)
+m
+
Try multiple points input:
+point_coords = [[-122.1464, 37.6431], [-122.1449, 37.6415], [-122.1451, 37.6395]]
+sam.predict(point_coords, point_labels=1, point_crs="EPSG:4326", output="mask2.tif")
+m.add_raster("mask2.tif", layer_name="Mask2", nodata=0, cmap="Greens", opacity=1)
+m
+
Display the interactive map and use the marker tool to draw points on the map. Then click on the Segment
button to segment the objects. The results will be added to the map automatically. Click on the Reset
button to clear the points and the results.
m = sam.show_map()
+m
+
This notebook shows how to segment satellite imagery using the Segment Anything Model (SAM) with a few lines of code.
+Make sure you use GPU runtime for this notebook. For Google Colab, go to Runtime
-> Change runtime type
and select GPU
as the hardware accelerator.
Uncomment and run the following cell to install the required dependencies.
+# %pip install segment-geospatial
+
import os
+import leafmap
+from samgeo import SamGeoPredictor, tms_to_geotiff, get_basemaps
+from segment_anything import sam_model_registry
+
zoom = 16
+m = leafmap.Map(center=[45, -123], zoom=zoom)
+m.add_basemap("SATELLITE")
+m
+
Pan and zoom the map to select the area of interest. Use the draw tools to draw a polygon or rectangle on the map
+if m.user_roi_bounds() is not None:
+ bbox = m.user_roi_bounds()
+else:
+ bbox = [-123.0127, 44.9957, -122.9874, 45.0045]
+
Download map tiles and mosaic them into a single GeoTIFF file.
+image = "satellite.tif"
+# image = '/path/to/your/own/image.tif'
+
Besides the satellite
basemap, you can use any of the following basemaps returned by the get_basemaps()
function:
# get_basemaps().keys()
+
Specify the basemap as the source.
+tms_to_geotiff(
+ output=image, bbox=bbox, zoom=zoom + 1, source="Satellite", overwrite=True
+)
+
m.add_raster(image, layer_name="Image")
+m
+
Use the draw tools to draw a rectangle from which to subset segmentations on the map
+if m.user_roi_bounds() is not None:
+ clip_box = m.user_roi_bounds()
+else:
+ clip_box = [-123.0064, 44.9988, -123.0005, 45.0025]
+
clip_box
+
out_dir = os.path.join(os.path.expanduser("~"), "Downloads")
+checkpoint = os.path.join(out_dir, "sam_vit_h_4b8939.pth")
+
import cv2
+
+img_arr = cv2.imread(image)
+
+model_type = "vit_h"
+
+sam = sam_model_registry[model_type](checkpoint=checkpoint)
+
+predictor = SamGeoPredictor(sam)
+
+predictor.set_image(img_arr)
+
+masks, _, _ = predictor.predict(src_fp=image, geo_box=clip_box)
+
masks_img = "preds.tif"
+predictor.masks_to_geotiff(image, masks_img, masks.astype("uint8"))
+
vector = "feats.geojson"
+gdf = predictor.geotiff_to_geojson(masks_img, vector, bidx=1)
+gdf.plot()
+
style = {
+ "color": "#3388ff",
+ "weight": 2,
+ "fillColor": "#7c4185",
+ "fillOpacity": 0.5,
+}
+m.add_vector(vector, layer_name="Vector", style=style)
+m
+
This notebook shows how to segment satellite imagery using the Segment Anything Model (SAM) with a few lines of code.
+Make sure you use GPU runtime for this notebook. For Google Colab, go to Runtime
-> Change runtime type
and select GPU
as the hardware accelerator.
Uncomment and run the following cell to install the required dependencies.
+# %pip install segment-geospatial
+
import os
+import leafmap
+from samgeo import SamGeo, tms_to_geotiff, get_basemaps
+
m = leafmap.Map(center=[29.676840, -95.369222], zoom=19)
+m.add_basemap("SATELLITE")
+m
+
Pan and zoom the map to select the area of interest. Use the draw tools to draw a polygon or rectangle on the map
+if m.user_roi_bounds() is not None:
+ bbox = m.user_roi_bounds()
+else:
+ bbox = [-95.3704, 29.6762, -95.368, 29.6775]
+
Download map tiles and mosaic them into a single GeoTIFF file.
+image = "satellite.tif"
+
Besides the satellite
basemap, you can use any of the following basemaps returned by the get_basemaps()
function:
# get_basemaps().keys()
+
Specify the basemap as the source.
+tms_to_geotiff(output=image, bbox=bbox, zoom=20, source="Satellite", overwrite=True)
+
You can also use your own image. Uncomment and run the following cell to use your own image.
+# image = '/path/to/your/own/image.tif'
+
Display the downloaded image on the map.
+m.layers[-1].visible = False # turn off the basemap
+m.add_raster(image, layer_name="Image")
+m
+
sam = SamGeo(
+ model_type="vit_h",
+ checkpoint="sam_vit_h_4b8939.pth",
+ sam_kwargs=None,
+)
+
Set batch=True
to segment the image in batches. This is useful for large images that cannot fit in memory.
mask = "segment.tif"
+sam.generate(
+ image, mask, batch=True, foreground=True, erosion_kernel=(3, 3), mask_multiplier=255
+)
+
Save the segmentation results as a GeoPackage file.
+vector = "segment.gpkg"
+sam.tiff_to_gpkg(mask, vector, simplify_tolerance=None)
+
You can also save the segmentation results as any vector data format supported by GeoPandas.
+shapefile = "segment.shp"
+sam.tiff_to_vector(mask, shapefile)
+
style = {
+ "color": "#3388ff",
+ "weight": 2,
+ "fillColor": "#7c4185",
+ "fillOpacity": 0.5,
+}
+m.add_vector(vector, layer_name="Vector", style=style)
+m
+
This notebook shows how to map swimming pools with text prompts and the Segment Anything Model (SAM).
+Make sure you use GPU runtime for this notebook. For Google Colab, go to Runtime
-> Change runtime type
and select GPU
as the hardware accelerator.
Uncomment and run the following cell to install the required dependencies.
+# %pip install segment-geospatial groundingdino-py leafmap localtileserver
+
import leafmap
+from samgeo import tms_to_geotiff
+from samgeo.text_sam import LangSAM
+
m = leafmap.Map(center=[34.040984, -118.491668], zoom=19, height="600px")
+m.add_basemap("SATELLITE")
+m
+
Pan and zoom the map to select the area of interest. Use the draw tools to draw a polygon or rectangle on the map
+bbox = m.user_roi_bounds()
+if bbox is None:
+ bbox = [-118.4932, 34.0404, -118.4903, 34.0417]
+
image = "Image.tif"
+tms_to_geotiff(output=image, bbox=bbox, zoom=19, source="Satellite", overwrite=True)
+
You can also use your own image. Uncomment and run the following cell to use your own image.
+# image = '/path/to/your/own/image.tif'
+
Display the downloaded image on the map.
+m.layers[-1].visible = False
+m.add_raster(image, layer_name="Image")
+m
+
The initialization of the LangSAM class might take a few minutes. The initialization downloads the model weights and sets up the model for inference.
+sam = LangSAM()
+
text_prompt = "swimming pool"
+
Part of the model prediction includes setting appropriate thresholds for object detection and text association with the detected objects. These threshold values range from 0 to 1 and are set while calling the predict method of the LangSAM class.
+box_threshold
: This value is used for object detection in the image. A higher value makes the model more selective, identifying only the most confident object instances, leading to fewer overall detections. A lower value, conversely, makes the model more tolerant, leading to increased detections, including potentially less confident ones.
text_threshold
: This value is used to associate the detected objects with the provided text prompt. A higher value requires a stronger association between the object and the text prompt, leading to more precise but potentially fewer associations. A lower value allows for looser associations, which could increase the number of associations but also introduce less precise matches.
Remember to test different threshold values on your specific data. The optimal threshold can vary depending on the quality and nature of your images, as well as the specificity of your text prompts. Make sure to choose a balance that suits your requirements, whether that's precision or recall.
+sam.predict(image, text_prompt, box_threshold=0.24, text_threshold=0.24)
+
Show the result with bounding boxes on the map.
+sam.show_anns(
+ cmap='Blues',
+ box_color='red',
+ title='Automatic Segmentation of Swimming Pools',
+ blend=True,
+)
+
Show the result without bounding boxes on the map.
+sam.show_anns(
+ cmap='Blues',
+ add_boxes=False,
+ alpha=0.5,
+ title='Automatic Segmentation of Swimming Pools',
+)
+
Show the result as a grayscale image.
+sam.show_anns(
+ cmap='Greys_r',
+ add_boxes=False,
+ alpha=1,
+ title='Automatic Segmentation of Swimming Pools',
+ blend=False,
+ output='pools.tif',
+)
+
Convert the result to a vector format.
+sam.raster_to_vector("pools.tif", "pools.shp")
+
Show the results on the interactive map.
+m.add_raster("pools.tif", layer_name="Pools", palette="Blues", opacity=0.5, nodata=0)
+style = {
+ "color": "#3388ff",
+ "weight": 2,
+ "fillColor": "#7c4185",
+ "fillOpacity": 0.5,
+}
+m.add_vector("pools.shp", layer_name="Vector", style=style)
+m
+
sam.show_map()
+
This notebook shows how to generate object masks from text prompts with the Segment Anything Model (SAM).
+Make sure you use GPU runtime for this notebook. For Google Colab, go to Runtime
-> Change runtime type
and select GPU
as the hardware accelerator.
Uncomment and run the following cell to install the required dependencies.
+# %pip install segment-geospatial groundingdino-py leafmap localtileserver
+
import leafmap
+from samgeo import tms_to_geotiff
+from samgeo.text_sam import LangSAM
+
m = leafmap.Map(center=[-22.17615, -51.253043], zoom=18, height="800px")
+m.add_basemap("SATELLITE")
+m
+
Pan and zoom the map to select the area of interest. Use the draw tools to draw a polygon or rectangle on the map
+bbox = m.user_roi_bounds()
+if bbox is None:
+ bbox = [-51.2565, -22.1777, -51.2512, -22.175]
+
image = "Image.tif"
+tms_to_geotiff(output=image, bbox=bbox, zoom=19, source="Satellite", overwrite=True)
+
You can also use your own image. Uncomment and run the following cell to use your own image.
+# image = '/path/to/your/own/image.tif'
+
Display the downloaded image on the map.
+m.layers[-1].visible = False
+m.add_raster(image, layer_name="Image")
+m
+
The initialization of the LangSAM class might take a few minutes. The initialization downloads the model weights and sets up the model for inference.
+sam = LangSAM()
+
text_prompt = "tree"
+
Part of the model prediction includes setting appropriate thresholds for object detection and text association with the detected objects. These threshold values range from 0 to 1 and are set while calling the predict method of the LangSAM class.
+box_threshold
: This value is used for object detection in the image. A higher value makes the model more selective, identifying only the most confident object instances, leading to fewer overall detections. A lower value, conversely, makes the model more tolerant, leading to increased detections, including potentially less confident ones.
text_threshold
: This value is used to associate the detected objects with the provided text prompt. A higher value requires a stronger association between the object and the text prompt, leading to more precise but potentially fewer associations. A lower value allows for looser associations, which could increase the number of associations but also introduce less precise matches.
Remember to test different threshold values on your specific data. The optimal threshold can vary depending on the quality and nature of your images, as well as the specificity of your text prompts. Make sure to choose a balance that suits your requirements, whether that's precision or recall.
+sam.predict(image, text_prompt, box_threshold=0.24, text_threshold=0.24)
+
Show the result with bounding boxes on the map.
+sam.show_anns(
+ cmap='Greens',
+ box_color='red',
+ title='Automatic Segmentation of Trees',
+ blend=True,
+)
+
Show the result without bounding boxes on the map.
+sam.show_anns(
+ cmap='Greens',
+ add_boxes=False,
+ alpha=0.5,
+ title='Automatic Segmentation of Trees',
+)
+
Show the result as a grayscale image.
+sam.show_anns(
+ cmap='Greys_r',
+ add_boxes=False,
+ alpha=1,
+ title='Automatic Segmentation of Trees',
+ blend=False,
+ output='trees.tif',
+)
+
Convert the result to a vector format.
+sam.raster_to_vector("trees.tif", "trees.shp")
+
Show the results on the interactive map.
+m.add_raster("trees.tif", layer_name="Trees", palette="Greens", opacity=0.5, nodata=0)
+style = {
+ "color": "#3388ff",
+ "weight": 2,
+ "fillColor": "#7c4185",
+ "fillOpacity": 0.5,
+}
+m.add_vector("trees.shp", layer_name="Vector", style=style)
+m
+
sam.show_map()
+
This notebook shows how to generate object masks from text prompts with the Segment Anything Model (SAM).
+Make sure you use GPU runtime for this notebook. For Google Colab, go to Runtime
-> Change runtime type
and select GPU
as the hardware accelerator.
Uncomment and run the following cell to install the required dependencies.
+# %pip install segment-geospatial groundingdino-py leafmap localtileserver
+
import leafmap
+from samgeo import tms_to_geotiff, split_raster
+from samgeo.text_sam import LangSAM
+
m = leafmap.Map(center=[-22.1278, -51.4430], zoom=17, height="800px")
+m.add_basemap("SATELLITE")
+m
+
Pan and zoom the map to select the area of interest. Use the draw tools to draw a polygon or rectangle on the map
+bbox = m.user_roi_bounds()
+if bbox is None:
+ bbox = [-51.4494, -22.1307, -51.4371, -22.1244]
+
image = "Image.tif"
+tms_to_geotiff(output=image, bbox=bbox, zoom=19, source="Satellite", overwrite=True)
+
You can also use your own image. Uncomment and run the following cell to use your own image.
+# image = '/path/to/your/own/image.tif'
+
Display the downloaded image on the map.
+m.layers[-1].visible = False
+m.add_raster(image, layer_name="Image")
+m
+
split_raster(image, out_dir="tiles", tile_size=(1000, 1000), overlap=0)
+
The initialization of the LangSAM class might take a few minutes. The initialization downloads the model weights and sets up the model for inference.
+sam = LangSAM()
+
text_prompt = "tree"
+
Part of the model prediction includes setting appropriate thresholds for object detection and text association with the detected objects. These threshold values range from 0 to 1 and are set while calling the predict method of the LangSAM class.
+box_threshold
: This value is used for object detection in the image. A higher value makes the model more selective, identifying only the most confident object instances, leading to fewer overall detections. A lower value, conversely, makes the model more tolerant, leading to increased detections, including potentially less confident ones.
text_threshold
: This value is used to associate the detected objects with the provided text prompt. A higher value requires a stronger association between the object and the text prompt, leading to more precise but potentially fewer associations. A lower value allows for looser associations, which could increase the number of associations but also introduce less precise matches.
Remember to test different threshold values on your specific data. The optimal threshold can vary depending on the quality and nature of your images, as well as the specificity of your text prompts. Make sure to choose a balance that suits your requirements, whether that's precision or recall.
+sam.predict_batch(
+ images='tiles',
+ out_dir='masks',
+ text_prompt=text_prompt,
+ box_threshold=0.24,
+ text_threshold=0.24,
+ mask_multiplier=255,
+ dtype='uint8',
+ merge=True,
+ verbose=True,
+)
+
m.add_raster('masks/merged.tif', cmap='viridis', nodata=0, layer_name='Mask')
+m.add_layer_manager()
+m
+
Segment remote sensing imagery with HQ-SAM (High Quality Segment Anything Model). +See https://github.com/SysCV/sam-hq
+ + + +
+SamGeo
+
+
+
+¶The main class for segmenting geospatial data with the Segment Anything Model (SAM). See +https://github.com/facebookresearch/segment-anything for details.
+ +samgeo/hq_sam.py
class SamGeo:
+ """The main class for segmenting geospatial data with the Segment Anything Model (SAM). See
+ https://github.com/facebookresearch/segment-anything for details.
+ """
+
def __init__(
    self,
    model_type="vit_h",
    automatic=True,
    device=None,
    checkpoint_dir=None,
    hq=False,
    sam_kwargs=None,
    **kwargs,
):
    """Initialize the class and build the HQ-SAM model.

    Args:
        model_type (str, optional): The model type, used as the key into
            ``sam_model_registry``. It can be one of the following: vit_h, vit_l, vit_b.
            Defaults to 'vit_h'. See https://bit.ly/3VrpxUh for more details.
        automatic (bool, optional): Whether to use the automatic mask generator or input prompts.
            Defaults to True. The automatic mask generator will segment the entire image,
            while the input prompts will segment selected objects.
        device (str, optional): The device to use. It can be one of the following: cpu, cuda.
            Defaults to None, which will use cuda if available.
        checkpoint_dir (str, optional): Forwarded to ``download_checkpoint`` when the
            checkpoint has to be fetched. Defaults to None.
        hq (bool, optional): Accepted for interface compatibility only. The value is
            ignored: this class always requests the HQ-SAM checkpoint.
        sam_kwargs (dict, optional): Optional arguments for fine-tuning the SAM model. Defaults to None.
            They are passed as keyword arguments to ``SamAutomaticMaskGenerator`` (automatic=True)
            or ``SamPredictor`` (automatic=False).
            The available arguments with default values are listed below.
            See https://bit.ly/410RV0v for more details.

            points_per_side: Optional[int] = 32,
            points_per_batch: int = 64,
            pred_iou_thresh: float = 0.88,
            stability_score_thresh: float = 0.95,
            stability_score_offset: float = 1.0,
            box_nms_thresh: float = 0.7,
            crop_n_layers: int = 0,
            crop_nms_thresh: float = 0.7,
            crop_overlap_ratio: float = 512 / 1500,
            crop_n_points_downscale_factor: int = 1,
            point_grids: Optional[List[np.ndarray]] = None,
            min_mask_region_area: int = 0,
            output_mode: str = "binary_mask",

        **kwargs: Only ``checkpoint`` is read: the path to a local checkpoint file.
            If it is missing, None, or does not exist on disk, the checkpoint is
            obtained through ``download_checkpoint`` instead.
    """

    hq = True  # Using HQ-SAM, regardless of the value passed by the caller

    # A missing, None, or non-existent checkpoint path falls back to a download.
    checkpoint = kwargs.pop("checkpoint", None)
    if checkpoint is None or not os.path.exists(checkpoint):
        checkpoint = download_checkpoint(model_type, checkpoint_dir, hq)

    # Use cuda if available
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        if device == "cuda":
            torch.cuda.empty_cache()

    self.checkpoint = checkpoint
    self.model_type = model_type
    self.device = device
    self.sam_kwargs = sam_kwargs  # Optional arguments for fine-tuning the SAM model
    self.source = None  # Store the input image path
    self.image = None  # Store the input image as a numpy array
    # Store the masks as a list of dictionaries. Each mask is a dictionary
    # containing segmentation, area, bbox, predicted_iou, point_coords, stability_score, and crop_box
    self.masks = None
    self.objects = None  # Store the mask objects as a numpy array
    # Store the annotations (objects with random color) as a numpy array.
    self.annotations = None
    # Whether the last generate() call ran in batch (tiled) mode. Initialised
    # here so that show_masks() can be called before generate() without
    # failing on a missing attribute.
    self.batch = False

    # Store the predicted masks, iou_predictions, and low_res_masks
    self.prediction = None
    self.scores = None
    self.logits = None

    # Build the SAM model
    self.sam = sam_model_registry[self.model_type](checkpoint=self.checkpoint)
    self.sam.to(device=self.device)
    # Use optional arguments for fine-tuning the SAM model
    sam_kwargs = self.sam_kwargs if self.sam_kwargs is not None else {}

    if automatic:
        # Segment the entire image using the automatic mask generator
        self.mask_generator = SamAutomaticMaskGenerator(self.sam, **sam_kwargs)
    else:
        # Segment selected objects using input prompts
        self.predictor = SamPredictor(self.sam, **sam_kwargs)
+
def __call__(
    self,
    image,
    foreground=True,
    erosion_kernel=(3, 3),
    mask_multiplier=255,
    **kwargs,
):
    """Generate masks for the input tile. This function originates from the segment-anything-eo repository.
        See https://bit.ly/41pwiHw

    Args:
        image (np.ndarray): The input image as a numpy array. It is unpacked as
            (height, width, channels).
        foreground (bool, optional): Whether to generate the foreground mask. Defaults to True.
            When False the mask starts from an all-ones canvas instead of all zeros.
        erosion_kernel (tuple, optional): The erosion kernel for filtering object masks and extract borders.
            Set to None to skip border extraction. Defaults to (3, 3).
        mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1].
            You can use this parameter to scale the mask to a larger range, for example [0, 255]. Defaults to 255.
        **kwargs: Accepted but not used by this method.

    Returns:
        np.ndarray: The combined mask with object borders removed, multiplied by mask_multiplier.
    """
    h, w, _ = image.shape

    masks = self.mask_generator.generate(image)

    if foreground:  # Extract foreground objects only
        resulting_mask = np.zeros((h, w), dtype=np.uint8)
    else:
        resulting_mask = np.ones((h, w), dtype=np.uint8)
    resulting_borders = np.zeros((h, w), dtype=np.uint8)

    for m in masks:
        mask = (m["segmentation"] > 0).astype(np.uint8)
        # Combine with bitwise OR rather than "+=": a uint8 sum silently wraps
        # back to 0 once 256 masks cover the same pixel, which would drop it.
        resulting_mask |= mask

        # Apply erosion to the mask
        if erosion_kernel is not None:
            mask_erode = cv2.erode(mask, erosion_kernel, iterations=1)
            mask_erode = (mask_erode > 0).astype(np.uint8)
            edge_mask = mask - mask_erode
            resulting_borders |= (edge_mask > 0).astype(np.uint8)

    # Both arrays are already 0/1, so the subtraction only clears border pixels.
    resulting_mask_with_borders = resulting_mask - resulting_borders
    return resulting_mask_with_borders * mask_multiplier
+
def generate(
    self,
    source,
    output=None,
    foreground=True,
    batch=False,
    erosion_kernel=None,
    mask_multiplier=255,
    unique=True,
    **kwargs,
):
    """Generate masks for the input image.

    Args:
        source (str | np.ndarray): The path to the input image or the input image as a numpy array.
            A string starting with "http" is downloaded first.
        output (str, optional): The path to the output image. Defaults to None.
        foreground (bool, optional): Whether to generate the foreground mask. Defaults to True.
        batch (bool, optional): Whether to generate masks for a batch of image tiles. Defaults to False.
            Only honoured when source is a path; it is ignored for numpy array input.
        erosion_kernel (tuple, optional): The erosion kernel for filtering object masks and extract borders.
            Such as (3, 3) or (5, 5). Set to None to disable it. Defaults to None.
        mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1].
            You can use this parameter to scale the mask to a larger range, for example [0, 255]. Defaults to 255.
            The parameter is ignored if unique is True.
        unique (bool, optional): Whether to assign a unique value to each object. Defaults to True.
            The unique value increases from 1 to the number of objects. The larger the number, the larger the object area.
        **kwargs: Forwarded to tiff_to_tiff() in batch mode, otherwise to save_masks().

    Returns:
        The result of tiff_to_tiff() in batch mode; otherwise None.

    Raises:
        ValueError: If the path does not exist, the file cannot be decoded as an image,
            or source is neither a path nor a numpy array.
    """

    if isinstance(source, str):
        if source.startswith("http"):
            source = download_file(source)

        if not os.path.exists(source):
            raise ValueError(f"Input path {source} does not exist.")

        if batch:  # Subdivide the image into tiles and segment each tile
            self.batch = True
            self.source = source
            self.masks = output  # In batch mode this holds the output path
            return tiff_to_tiff(
                source,
                output,
                self,
                foreground=foreground,
                erosion_kernel=erosion_kernel,
                mask_multiplier=mask_multiplier,
                **kwargs,
            )

        image = cv2.imread(source)
        if image is None:
            # cv2.imread signals an unreadable file by returning None.
            raise ValueError(f"Input path {source} could not be read as an image.")
        # cv2.imread yields BGR channel order; convert to RGB, as set_image()
        # does, so the mask generator and show_anns() see the true colors.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    elif isinstance(source, np.ndarray):
        image = source
        source = None
    else:
        raise ValueError("Input source must be either a path or a numpy array.")

    self.source = source  # Store the input image path
    self.image = image  # Store the input image as a numpy array
    mask_generator = self.mask_generator  # The automatic mask generator
    masks = mask_generator.generate(image)  # Segment the input image
    self.masks = masks  # Store the masks as a list of dictionaries
    self.batch = False

    if output is not None:
        # Save the masks to the output path. The output is either a binary mask or a mask of objects with unique values.
        self.save_masks(
            output, foreground, unique, erosion_kernel, mask_multiplier, **kwargs
        )
+
def save_masks(
    self,
    output=None,
    foreground=True,
    unique=True,
    erosion_kernel=None,
    mask_multiplier=255,
    **kwargs,
):
    """Save the masks to the output path. The output is either a binary mask or a mask of objects with unique values.

    The resulting array is always stored in ``self.objects``; it is written to
    disk only when ``output`` is given.

    Args:
        output (str, optional): The path to the output image. Defaults to None, saving the masks to SamGeo.objects.
        foreground (bool, optional): Whether to generate the foreground mask. Defaults to True.
            Only used when unique is False.
        unique (bool, optional): Whether to assign a unique value to each object. Defaults to True.
            Values run from 1 upwards in order of increasing mask area; where masks
            overlap, the larger mask's value is written last.
        erosion_kernel (tuple, optional): The erosion kernel for filtering object masks and extract borders.
            Such as (3, 3) or (5, 5). Set to None to disable it. Defaults to None.
            Only used when unique is False.
        mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1].
            You can use this parameter to scale the mask to a larger range, for example [0, 255]. Defaults to 255.
            Only used when unique is False.
        **kwargs: Forwarded to array_to_image() when output is given.

    Raises:
        ValueError: If no masks have been generated yet.
    """

    if self.masks is None:
        raise ValueError("No masks found. Please run generate() first.")

    # Slice instead of unpacking three values so 2-D images are accepted too.
    h, w = self.image.shape[:2]
    masks = self.masks

    # Set output image data type based on the number of objects
    if len(masks) < 255:
        dtype = np.uint8
    elif len(masks) < 65535:
        dtype = np.uint16
    else:
        dtype = np.uint32

    # Generate a mask of objects with unique values
    if unique:
        # Sort the masks by area in ascending order
        sorted_masks = sorted(masks, key=lambda x: x["area"])

        # Create an output image with the same size as the masks. When no
        # object was found there is no mask to take the size from, so fall
        # back to the image size and produce an all-zero result.
        if sorted_masks:
            shape = sorted_masks[0]["segmentation"].shape[:2]
        else:
            shape = (h, w)
        objects = np.zeros(shape)

        # Assign a unique value to each object
        for value, ann in enumerate(sorted_masks, start=1):
            objects[ann["segmentation"]] = value

    # Generate a binary mask
    else:
        if foreground:  # Extract foreground objects only
            resulting_mask = np.zeros((h, w), dtype=dtype)
        else:
            resulting_mask = np.ones((h, w), dtype=dtype)
        resulting_borders = np.zeros((h, w), dtype=dtype)

        # dtype was sized from len(masks) above, so these per-pixel sums
        # cannot exceed its range.
        for m in masks:
            mask = (m["segmentation"] > 0).astype(dtype)
            resulting_mask += mask

            # Apply erosion to the mask
            if erosion_kernel is not None:
                mask_erode = cv2.erode(mask, erosion_kernel, iterations=1)
                mask_erode = (mask_erode > 0).astype(dtype)
                edge_mask = mask - mask_erode
                resulting_borders += edge_mask

        resulting_mask = (resulting_mask > 0).astype(dtype)
        resulting_borders = (resulting_borders > 0).astype(dtype)
        objects = resulting_mask - resulting_borders
        objects = objects * mask_multiplier

    objects = objects.astype(dtype)
    self.objects = objects

    if output is not None:  # Save the output image
        array_to_image(self.objects, output, self.source, **kwargs)
+
def show_masks(
    self, figsize=(12, 10), cmap="binary_r", axis="off", foreground=True, **kwargs
):
    """Show the binary mask or the mask of objects with unique values.

    In batch mode the mask image is re-read from the path stored in
    ``self.masks``. Otherwise ``self.objects`` is displayed, and it is first
    built with save_masks() if it does not exist yet.

    Args:
        figsize (tuple, optional): The figure size. Defaults to (12, 10).
        cmap (str, optional): The colormap. Defaults to "binary_r".
        axis (str, optional): Whether to show the axis. Defaults to "off".
        foreground (bool, optional): Whether to show the foreground mask only. Defaults to True.
        **kwargs: Other arguments for save_masks().
    """

    import matplotlib.pyplot as plt

    # self.batch is assigned by generate(). Default to False when it is
    # missing so that calling this method first reaches save_masks() and its
    # "run generate() first" error instead of failing on the attribute lookup.
    if getattr(self, "batch", False):
        self.objects = cv2.imread(self.masks)
    elif self.objects is None:
        self.save_masks(foreground=foreground, **kwargs)

    plt.figure(figsize=figsize)
    plt.imshow(self.objects, cmap=cmap)
    plt.axis(axis)
    plt.show()
+
def show_anns(
    self,
    figsize=(12, 10),
    axis="off",
    alpha=0.35,
    output=None,
    blend=True,
    **kwargs,
):
    """Draw every mask in a random color on top of the input image.

    The color overlay (without the alpha channel) is kept in
    ``self.annotations`` as a uint8 array. Nothing is drawn when no image has
    been set or when there are no masks.

    Args:
        figsize (tuple, optional): The figure size. Defaults to (12, 10).
        axis (str, optional): Whether to show the axis. Defaults to "off".
        alpha (float, optional): The alpha value for the annotations. Defaults to 0.35.
        output (str, optional): The path to the output image. Defaults to None.
        blend (bool, optional): Whether to blend the annotations with the input
            image in the saved output. Defaults to True.
        **kwargs: "dpi" and "bbox_inches" receive defaults here; the values are
            not passed on to any call within this method.
    """

    import matplotlib.pyplot as plt

    annotations = self.masks

    if self.image is None:
        print("Please run generate() first.")
        return

    if annotations is None or len(annotations) == 0:
        return

    plt.figure(figsize=figsize)
    plt.imshow(self.image)

    # Largest masks first, so smaller ones are painted on top of them.
    largest_first = sorted(annotations, key=lambda item: item["area"], reverse=True)

    axes = plt.gca()
    axes.set_autoscale_on(False)

    reference = largest_first[0]["segmentation"]
    overlay = np.ones((reference.shape[0], reference.shape[1], 4))
    overlay[:, :, 3] = 0  # Fully transparent wherever no mask is painted
    for item in largest_first:
        rgba = np.concatenate([np.random.random(3), [alpha]])
        overlay[item["segmentation"]] = rgba
    axes.imshow(overlay)

    kwargs.setdefault("dpi", 100)
    kwargs.setdefault("bbox_inches", "tight")

    plt.axis(axis)

    self.annotations = (overlay[:, :, 0:3] * 255).astype(np.uint8)

    if output is None:
        return

    if blend:
        result = blend_images(self.annotations, self.image, alpha=alpha, show=False)
    else:
        result = self.annotations
    array_to_image(result, output, self.source)
+
def set_image(self, image, image_format="RGB"):
    """Set the input image for the predictor.

    Args:
        image (str | np.ndarray): The input image as a numpy array, or a path/URL to
            an image file. A string starting with "http" is downloaded first.
        image_format (str, optional): The channel order of a numpy array input, can be
            RGB or BGR. Defaults to "RGB". Ignored for path input, which is always
            loaded and converted to RGB.

    Raises:
        ValueError: If the path does not exist, the file cannot be decoded as an image,
            or image is neither a path nor a numpy array.
    """
    if isinstance(image, str):
        if image.startswith("http"):
            image = download_file(image)

        if not os.path.exists(image):
            raise ValueError(f"Input path {image} does not exist.")

        path = image
        image = cv2.imread(path)
        if image is None:
            # cv2.imread signals an unreadable file by returning None.
            raise ValueError(f"Input path {path} could not be read as an image.")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        self.source = path
        # The array was just converted to RGB, so that is the format to
        # report to the predictor whatever the caller passed.
        image_format = "RGB"
    elif not isinstance(image, np.ndarray):
        raise ValueError("Input image must be either a path or a numpy array.")

    # Keep the array for both input kinds; previously an ndarray input left
    # self.image unset or pointing at an earlier image.
    self.image = image
    self.predictor.set_image(image, image_format=image_format)
+
def save_prediction(
    self,
    output,
    index=None,
    mask_multiplier=255,
    dtype=np.float32,
    vector=None,
    simplify_tolerance=None,
    **kwargs,
):
    """Write one predicted mask to a raster file, and optionally to a vector file.

    The selected mask, scaled by mask_multiplier, is also kept in
    ``self.prediction``.

    Args:
        output (str): The path to the output image.
        index (int, optional): The index of the mask to save. Defaults to None,
            which will save the mask with the highest score.
        mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1].
        dtype (np.dtype, optional): The data type of the output image. Defaults to np.float32.
        vector (str, optional): The path to the output vector file. Defaults to None.
        simplify_tolerance (float, optional): The maximum allowed geometry displacement.
            The higher this value, the smaller the number of vertices in the resulting geometry.
        **kwargs: Forwarded to array_to_image().

    Raises:
        ValueError: If no prediction scores are available yet.
    """
    if self.scores is None:
        raise ValueError("No predictions found. Please run predict() first.")

    # Fall back to the best-scoring mask when the caller did not pick one.
    chosen = self.scores.argmax(axis=0) if index is None else index

    scaled = self.masks[chosen] * mask_multiplier
    self.prediction = scaled
    array_to_image(scaled, output, self.source, dtype=dtype, **kwargs)

    if vector is None:
        return
    raster_to_vector(output, vector, simplify_tolerance=simplify_tolerance)
+
def predict(
    self,
    point_coords=None,
    point_labels=None,
    boxes=None,
    point_crs=None,
    mask_input=None,
    multimask_output=True,
    return_logits=False,
    output=None,
    index=None,
    mask_multiplier=255,
    dtype="float32",
    return_results=False,
    **kwargs,
):
    """Predict masks for the given input prompts, using the currently set image.

    A single box (or no box) is handled by ``predictor.predict``; a list of boxes
    is handled by ``predictor.predict_torch``. The results are stored on
    ``self.masks``, ``self.scores`` and ``self.logits``.

    Args:
        point_coords (str | dict | list | np.ndarray, optional): A Nx2 array of point prompts to the
            model. Each point is in (X,Y) in pixels. It can be a path to a vector file, a GeoJSON
            dictionary, a list of coordinates [lon, lat], or a numpy array. Defaults to None.
        point_labels (list | int | np.ndarray, optional): A length N array of labels for the
            point prompts. 1 indicates a foreground point and 0 indicates a background point.
            A single int or one-element list is repeated for every point.
        boxes (str | dict | list, optional): A box prompt in XYXY format, a list of such boxes,
            a path to a vector file, or a GeoJSON dictionary whose feature bounds are used as boxes.
        point_crs (str, optional): The coordinate reference system (CRS) of the point and box prompts.
            When given, prompts are converted to pixel coordinates of the source image.
        mask_input (np.ndarray, optional): A low resolution mask input to the model, typically
            coming from a previous prediction iteration. Has form 1xHxW, where for SAM, H=W=256.
        multimask_output (bool, optional): If true, the model will return three masks.
            For ambiguous input prompts (such as a single click), this will often
            produce better masks than a single prediction. If only a single
            mask is needed, the model's predicted quality score can be used
            to select the best mask. For non-ambiguous prompts, such as multiple
            input prompts, multimask_output=False can give better results.
            Only used for the single-box path; the multi-box path always requests multiple masks.
        return_logits (bool, optional): If true, returns un-thresholded masks logits
            instead of a binary mask. Only used for the single-box path.
        output (str, optional): The path to the output image. Defaults to None.
        index (int, optional): The index of the mask to save. Defaults to None,
            which will save the mask with the highest score (single-box path).
        mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1].
        dtype (np.dtype, optional): The data type of the output image. Defaults to np.float32.
        return_results (bool, optional): Whether to return the predicted masks, scores, and logits. Defaults to False.
        **kwargs: Extra keyword arguments used when saving the output image.

    Returns:
        tuple | None: ``(masks, scores, logits)`` when ``return_results`` is True, otherwise None.

    Raises:
        ValueError: If ``point_labels`` cannot be matched to ``point_coords`` by length.
    """

    # Boxes given as a vector file or GeoJSON: use the bounds of each feature.
    if isinstance(boxes, str):
        gdf = gpd.read_file(boxes)
        if gdf.crs is not None:
            gdf = gdf.to_crs("epsg:4326")
        boxes = gdf.geometry.bounds.values.tolist()
    elif isinstance(boxes, dict):
        import json

        geojson = json.dumps(boxes)
        gdf = gpd.read_file(geojson, driver="GeoJSON")
        boxes = gdf.geometry.bounds.values.tolist()

    if isinstance(point_coords, str):
        point_coords = vector_to_geojson(point_coords)

    if isinstance(point_coords, dict):
        point_coords = geojson_to_coords(point_coords)

    # Points stored on the instance (e.g. collected by show_canvas()) take
    # precedence over the arguments passed to this call.
    if hasattr(self, "point_coords"):
        point_coords = self.point_coords

    if hasattr(self, "point_labels"):
        point_labels = self.point_labels

    if (point_crs is not None) and (point_coords is not None):
        point_coords = coords_to_xy(self.source, point_coords, point_crs)

    if isinstance(point_coords, list):
        point_coords = np.array(point_coords)

    if point_coords is not None:
        if point_labels is None:
            point_labels = [1] * len(point_coords)
        elif isinstance(point_labels, int):
            point_labels = [point_labels] * len(point_coords)

    if isinstance(point_labels, list):
        if len(point_labels) != len(point_coords):
            if len(point_labels) == 1:
                point_labels = point_labels * len(point_coords)
            else:
                raise ValueError(
                    "The length of point_labels must be equal to the length of point_coords."
                )
        point_labels = np.array(point_labels)

    predictor = self.predictor

    input_boxes = None
    if isinstance(boxes, list) and (point_crs is not None):
        coords = bbox_to_xy(self.source, boxes, point_crs)
        input_boxes = np.array(coords)
        if isinstance(coords[0], int):
            # A single flat box: add the leading batch dimension.
            input_boxes = input_boxes[None, :]
        else:
            input_boxes = torch.tensor(input_boxes, device=self.device)
            input_boxes = predictor.transform.apply_boxes_torch(
                input_boxes, self.image.shape[:2]
            )
    elif isinstance(boxes, list) and (point_crs is None):
        input_boxes = np.array(boxes)
        if isinstance(boxes[0], int):
            input_boxes = input_boxes[None, :]

    self.boxes = input_boxes

    # A list of boxes (list of lists) is routed to the batched torch API.
    batched = boxes is not None and isinstance(boxes[0], list)

    if not batched:
        masks, scores, logits = predictor.predict(
            point_coords,
            point_labels,
            input_boxes,
            mask_input,
            multimask_output,
            return_logits,
        )
    else:
        # Fix: the labels must accompany the coordinates; previously the
        # coordinates were passed a second time as point_labels.
        masks, scores, logits = predictor.predict_torch(
            point_coords=point_coords,
            point_labels=point_labels,
            boxes=input_boxes,
            multimask_output=True,
        )

    self.masks = masks
    self.scores = scores
    self.logits = logits

    if output is not None:
        if not batched:
            self.save_prediction(output, index, mask_multiplier, dtype, **kwargs)
        else:
            self.tensor_to_numpy(
                index, output, mask_multiplier, dtype, save_args=kwargs
            )

    if return_results:
        return masks, scores, logits
+
def tensor_to_numpy(
    self, index=None, output=None, mask_multiplier=255, dtype="uint8", save_args=None
):
    """Merge the per-box predicted masks into a single numpy mask.

    For every box, the mask at position ``index`` along the second axis of
    ``self.masks`` is taken; the union of those masks is scaled by ``mask_multiplier``.

    Args:
        index (int, optional): The index of the mask to use for each box. Defaults to None,
            which uses index 1.
        output (str, optional): The path to the output image. Defaults to None.
        mask_multiplier (int, optional): The value written where any mask is set; the merged mask is binary [0, 1] before scaling.
        dtype (np.dtype, optional): The data type masks are cast to and the output image is saved with. Defaults to "uint8".
        save_args (dict, optional): Optional arguments for saving the output image. Defaults to None (no extra arguments).

    Returns:
        np.ndarray | None: The merged mask when ``output`` is None. None when the
        mask was written to ``output`` or when there are no boxes.
    """
    # Avoid a shared mutable default argument.
    save_args = {} if save_args is None else save_args

    boxes = self.boxes
    # Check for "nothing to do" before touching self.masks.
    if boxes is None or (len(boxes) == 0):  # No "object" instances found
        print("No objects found in the image.")
        return

    if index is None:
        index = 1

    # Integer indexing already drops the mask axis: (boxes, H, W).
    masks = self.masks[:, index, :, :]

    image_np = np.array(self.image)

    # Accumulate the union as booleans. Summing per-box ids in a narrow integer
    # dtype (e.g. uint8) can wrap around to 0 where masks overlap.
    covered = np.zeros(image_np.shape[:2], dtype=bool)
    for _, mask in zip(boxes, masks):
        if isinstance(mask, torch.Tensor):
            # If mask is on GPU, .cpu() is required before .numpy()
            mask = mask.cpu().numpy().astype(dtype)
        covered |= mask > 0

    # Binary mask scaled to [0, mask_multiplier]
    mask_overlay = covered * mask_multiplier

    if output is not None:
        array_to_image(mask_overlay, output, self.source, dtype=dtype, **save_args)
    else:
        return mask_overlay
+
def show_map(self, basemap="SATELLITE", repeat_mode=True, out_dir=None, **kwargs):
    """Open the interactive map GUI for this instance.

    Args:
        basemap (str, optional): The basemap. It can be one of the following: SATELLITE, ROADMAP, TERRAIN, HYBRID.
        repeat_mode (bool, optional): Whether to use the repeat mode for draw control. Defaults to True.
        out_dir (str, optional): The path to the output directory. Defaults to None.
        **kwargs: Extra keyword arguments forwarded to ``sam_map_gui``.

    Returns:
        leafmap.Map: The map object, as returned by ``sam_map_gui``.
    """
    gui_options = dict(
        kwargs, basemap=basemap, repeat_mode=repeat_mode, out_dir=out_dir
    )
    return sam_map_gui(self, **gui_options)
+
def show_canvas(self, fg_color=(0, 255, 0), bg_color=(0, 0, 255), radius=5):
    """Show a canvas on the current image to collect foreground and background points.

    The collected points are stored on the instance: ``fg_points`` and
    ``bg_points`` hold the two groups, ``point_coords`` holds foreground points
    followed by background points, and ``point_labels`` holds the matching
    labels (1 for foreground, 0 for background). Nothing is returned.

    Args:
        fg_color (tuple, optional): The color for the foreground points. Defaults to (0, 255, 0).
        bg_color (tuple, optional): The color for the background points. Defaults to (0, 0, 255).
        radius (int, optional): The radius of the points. Defaults to 5.

    Raises:
        ValueError: If no image has been set yet.
    """

    if self.image is None:
        raise ValueError("Please run set_image() first.")

    image = self.image
    # Calls the module-level show_canvas helper (not this method).
    fg_points, bg_points = show_canvas(image, fg_color, bg_color, radius)
    self.fg_points = fg_points
    self.bg_points = bg_points
    self.point_coords = fg_points + bg_points
    self.point_labels = [1] * len(fg_points) + [0] * len(bg_points)
+
def clear_cuda_cache(self):
    """Empty PyTorch's CUDA cache; does nothing when CUDA is unavailable."""
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
+
def image_to_image(self, image, **kwargs):
    """Call the module-level ``image_to_image`` helper with this instance.

    Args:
        image: The input passed through as the helper's first argument.
        **kwargs: Extra keyword arguments forwarded to the helper.

    Returns:
        Whatever the module-level ``image_to_image`` helper returns.
    """
    result = image_to_image(image, self, **kwargs)
    return result
+
def download_tms_as_tiff(self, source, pt1, pt2, zoom, dist):
    """Forward to ``draw_tile`` with the two points split into components.

    Args:
        source: Passed through to ``draw_tile`` as its first argument.
        pt1: An indexable pair; elements 0 and 1 are passed on individually.
        pt2: An indexable pair; elements 0 and 1 are passed on individually.
        zoom: Passed through to ``draw_tile``.
        dist: Passed through to ``draw_tile``.

    Returns:
        The value returned by ``draw_tile``.
    """
    first_a, first_b = pt1[0], pt1[1]
    second_a, second_b = pt2[0], pt2[1]
    return draw_tile(source, first_a, first_b, second_a, second_b, zoom, dist)
+
def raster_to_vector(self, image, output, simplify_tolerance=None, **kwargs):
    """Convert a raster image to a vector file.

    Args:
        image (str): The path to the image file.
        output (str): The path to the vector file.
        simplify_tolerance (float, optional): The maximum allowed geometry displacement.
            The higher this value, the smaller the number of vertices in the resulting geometry.
        **kwargs: Extra keyword arguments forwarded to the module-level ``raster_to_vector``.
    """
    options = dict(kwargs, simplify_tolerance=simplify_tolerance)
    # Resolves to the module-level helper of the same name.
    raster_to_vector(image, output, **options)
+
def tiff_to_vector(self, tiff_path, output, simplify_tolerance=None, **kwargs):
    """Convert a tiff file to a vector file.

    Args:
        tiff_path (str): The path to the tiff file.
        output (str): The path to the vector file.
        simplify_tolerance (float, optional): The maximum allowed geometry displacement.
            The higher this value, the smaller the number of vertices in the resulting geometry.
        **kwargs: Extra keyword arguments forwarded to ``raster_to_vector``.
    """

    raster_to_vector(
        tiff_path, output, simplify_tolerance=simplify_tolerance, **kwargs
    )
+
def tiff_to_gpkg(self, tiff_path, output, simplify_tolerance=None, **kwargs):
    """Convert a tiff file to a GeoPackage (gpkg) file.

    Args:
        tiff_path (str): The path to the tiff file.
        output (str): The path to the gpkg file.
        simplify_tolerance (float, optional): The maximum allowed geometry displacement.
            The higher this value, the smaller the number of vertices in the resulting geometry.
        **kwargs: Extra keyword arguments forwarded to ``raster_to_gpkg``.
    """
    options = dict(kwargs, simplify_tolerance=simplify_tolerance)
    raster_to_gpkg(tiff_path, output, **options)
+
def tiff_to_shp(self, tiff_path, output, simplify_tolerance=None, **kwargs):
    """Convert a tiff file to a shapefile.

    Args:
        tiff_path (str): The path to the tiff file.
        output (str): The path to the shapefile.
        simplify_tolerance (float, optional): The maximum allowed geometry displacement.
            The higher this value, the smaller the number of vertices in the resulting geometry.
        **kwargs: Extra keyword arguments forwarded to ``raster_to_shp``.
    """
    options = dict(kwargs, simplify_tolerance=simplify_tolerance)
    raster_to_shp(tiff_path, output, **options)
+
def tiff_to_geojson(self, tiff_path, output, simplify_tolerance=None, **kwargs):
    """Convert a tiff file to a GeoJSON file.

    Args:
        tiff_path (str): The path to the tiff file.
        output (str): The path to the GeoJSON file.
        simplify_tolerance (float, optional): The maximum allowed geometry displacement.
            The higher this value, the smaller the number of vertices in the resulting geometry.
        **kwargs: Extra keyword arguments forwarded to ``raster_to_geojson``.
    """
    options = dict(kwargs, simplify_tolerance=simplify_tolerance)
    raster_to_geojson(tiff_path, output, **options)
+
__call__(self, image, foreground=True, erosion_kernel=(3, 3), mask_multiplier=255, **kwargs)
+
+
+ special
+
+
+¶Generate masks for the input tile. This function originates from the segment-anything-eo repository. + See https://bit.ly/41pwiHw
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
image |
+ np.ndarray |
+ The input image as a numpy array. |
+ required | +
foreground |
+ bool |
+ Whether to generate the foreground mask. Defaults to True. |
+ True |
+
erosion_kernel |
+ tuple |
+ The erosion kernel for filtering object masks and extracting borders. Defaults to (3, 3). |
+ (3, 3) |
+
mask_multiplier |
+ int |
+ The mask multiplier for the output mask, which is usually a binary mask [0, 1]. +You can use this parameter to scale the mask to a larger range, for example [0, 255]. Defaults to 255. |
+ 255 |
+
samgeo/hq_sam.py
def __call__(
    self,
    image,
    foreground=True,
    erosion_kernel=(3, 3),
    mask_multiplier=255,
    **kwargs,
):
    """Generate a single merged mask for the input tile. This function originates from the
    segment-anything-eo repository. See https://bit.ly/41pwiHw

    Args:
        image (np.ndarray): The input image as a numpy array with three dimensions (height, width, channels).
        foreground (bool, optional): Whether to generate the foreground mask. Defaults to True.
            When False, the merged mask starts out filled with ones instead of zeros.
        erosion_kernel (tuple, optional): The erosion kernel for filtering object masks and extracting borders.
            Set to None to skip border extraction. Defaults to (3, 3).
        mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1].
            You can use this parameter to scale the mask to a larger range, for example [0, 255]. Defaults to 255.

    Returns:
        np.ndarray: The merged mask with object borders removed, scaled by ``mask_multiplier``.
    """
    height, width, _ = image.shape

    segments = self.mask_generator.generate(image)

    # Foreground extraction starts from an empty canvas, otherwise from a full one.
    make_canvas = np.zeros if foreground else np.ones
    merged = make_canvas((height, width), dtype=np.uint8)
    borders = np.zeros((height, width), dtype=np.uint8)

    for segment in segments:
        binary = (segment["segmentation"] > 0).astype(np.uint8)
        merged += binary

        if erosion_kernel is None:
            continue

        # The border of an object is what erosion removes from its mask.
        eroded = cv2.erode(binary, erosion_kernel, iterations=1)
        eroded = (eroded > 0).astype(np.uint8)
        borders += binary - eroded

    merged = (merged > 0).astype(np.uint8)
    borders = (borders > 0).astype(np.uint8)
    return (merged - borders) * mask_multiplier
+
__init__(self, model_type='vit_h', automatic=True, device=None, checkpoint_dir=None, hq=False, sam_kwargs=None, **kwargs)
+
+
+ special
+
+
+¶Initialize the class.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
model_type |
+ str |
+ The model type. It can be one of the following: vit_h, vit_l, vit_b. +Defaults to 'vit_h'. See https://bit.ly/3VrpxUh for more details. |
+ 'vit_h' |
+
automatic |
+ bool |
+ Whether to use the automatic mask generator or input prompts. Defaults to True. +The automatic mask generator will segment the entire image, while the input prompts will segment selected objects. |
+ True |
+
device |
+ str |
+ The device to use. It can be one of the following: cpu, cuda. +Defaults to None, which will use cuda if available. |
+ None |
+
hq |
+ bool |
+ Whether to use the HQ-SAM model. Defaults to False. |
+ False |
+
checkpoint_dir |
+ str |
+ The path to the model checkpoint. It can be one of the following: +sam_vit_h_4b8939.pth, sam_vit_l_0b3195.pth, sam_vit_b_01ec64.pth. +Defaults to None. See https://bit.ly/3VrpxUh for more details. |
+ None |
+
sam_kwargs |
+ dict |
+ Optional arguments for fine-tuning the SAM model. Defaults to None. +The available arguments with default values are listed below. See https://bit.ly/410RV0v for more details. +points_per_side: Optional[int] = 32, +points_per_batch: int = 64, +pred_iou_thresh: float = 0.88, +stability_score_thresh: float = 0.95, +stability_score_offset: float = 1.0, +box_nms_thresh: float = 0.7, +crop_n_layers: int = 0, +crop_nms_thresh: float = 0.7, +crop_overlap_ratio: float = 512 / 1500, +crop_n_points_downscale_factor: int = 1, +point_grids: Optional[List[np.ndarray]] = None, +min_mask_region_area: int = 0, +output_mode: str = "binary_mask", |
+ None |
+
samgeo/hq_sam.py
def __init__(
    self,
    model_type="vit_h",
    automatic=True,
    device=None,
    checkpoint_dir=None,
    hq=False,
    sam_kwargs=None,
    **kwargs,
):
    """Initialize the class.

    Args:
        model_type (str, optional): The model type. It can be one of the following: vit_h, vit_l, vit_b.
            Defaults to 'vit_h'. See https://bit.ly/3VrpxUh for more details.
        automatic (bool, optional): Whether to use the automatic mask generator or input prompts. Defaults to True.
            The automatic mask generator will segment the entire image, while the input prompts will segment selected objects.
        device (str, optional): The device to use. It can be one of the following: cpu, cuda.
            Defaults to None, which will use cuda if available.
        checkpoint_dir (str, optional): The directory passed to ``download_checkpoint`` when the
            checkpoint has to be downloaded. Defaults to None. See https://bit.ly/3VrpxUh for more details.
        hq (bool, optional): Whether to use the HQ-SAM model. The value passed in is ignored:
            this class always uses HQ-SAM.
        sam_kwargs (dict, optional): Optional arguments for fine-tuning the SAM model. Defaults to None.
            The available arguments with default values are listed below. See https://bit.ly/410RV0v for more details.

            points_per_side: Optional[int] = 32,
            points_per_batch: int = 64,
            pred_iou_thresh: float = 0.88,
            stability_score_thresh: float = 0.95,
            stability_score_offset: float = 1.0,
            box_nms_thresh: float = 0.7,
            crop_n_layers: int = 0,
            crop_nms_thresh: float = 0.7,
            crop_overlap_ratio: float = 512 / 1500,
            crop_n_points_downscale_factor: int = 1,
            point_grids: Optional[List[np.ndarray]] = None,
            min_mask_region_area: int = 0,
            output_mode: str = "binary_mask",
        **kwargs: May contain ``checkpoint``, a path to an existing checkpoint file;
            if that path does not exist the checkpoint is downloaded instead.
    """

    hq = True  # Using HQ-SAM

    # Use the caller-supplied checkpoint only if it exists on disk.
    if "checkpoint" in kwargs:
        checkpoint = kwargs.pop("checkpoint")
        must_download = not os.path.exists(checkpoint)
    else:
        must_download = True
    if must_download:
        checkpoint = download_checkpoint(model_type, checkpoint_dir, hq)

    # Use cuda if available
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        if device == "cuda":
            torch.cuda.empty_cache()

    self.checkpoint = checkpoint
    self.model_type = model_type
    self.device = device
    # Optional arguments for fine-tuning the SAM model
    self.sam_kwargs = sam_kwargs
    # The input image path and the input image as a numpy array
    self.source = None
    self.image = None
    # The masks as a list of dictionaries. Each mask is a dictionary containing
    # segmentation, area, bbox, predicted_iou, point_coords, stability_score, and crop_box
    self.masks = None
    # The mask objects as a numpy array
    self.objects = None
    # The annotations (objects with random color) as a numpy array
    self.annotations = None

    # The predicted masks, iou_predictions, and low_res_masks
    self.prediction = None
    self.scores = None
    self.logits = None

    # Build the SAM model
    self.sam = sam_model_registry[self.model_type](checkpoint=self.checkpoint)
    self.sam.to(device=self.device)

    extra_options = {} if self.sam_kwargs is None else self.sam_kwargs

    if automatic:
        # Segment the entire image using the automatic mask generator
        self.mask_generator = SamAutomaticMaskGenerator(self.sam, **extra_options)
    else:
        # Segment selected objects using input prompts
        self.predictor = SamPredictor(self.sam, **extra_options)
+
clear_cuda_cache(self)
+
+
+¶Clear the CUDA cache.
+ +samgeo/hq_sam.py
def clear_cuda_cache(self):
    """Empty PyTorch's CUDA cache; does nothing when CUDA is unavailable."""
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
+
generate(self, source, output=None, foreground=True, batch=False, erosion_kernel=None, mask_multiplier=255, unique=True, **kwargs)
+
+
+¶Generate masks for the input image.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
source |
+ str | np.ndarray |
+ The path to the input image or the input image as a numpy array. |
+ required | +
output |
+ str |
+ The path to the output image. Defaults to None. |
+ None |
+
foreground |
+ bool |
+ Whether to generate the foreground mask. Defaults to True. |
+ True |
+
batch |
+ bool |
+ Whether to generate masks for a batch of image tiles. Defaults to False. |
+ False |
+
erosion_kernel |
+ tuple |
+ The erosion kernel for filtering object masks and extract borders. +Such as (3, 3) or (5, 5). Set to None to disable it. Defaults to None. |
+ None |
+
mask_multiplier |
+ int |
+ The mask multiplier for the output mask, which is usually a binary mask [0, 1]. +You can use this parameter to scale the mask to a larger range, for example [0, 255]. Defaults to 255. +The parameter is ignored if unique is True. |
+ 255 |
+
unique |
+ bool |
+ Whether to assign a unique value to each object. Defaults to True. +The unique value increases from 1 to the number of objects. The larger the number, the larger the object area. |
+ True |
+
samgeo/hq_sam.py
def generate(
    self,
    source,
    output=None,
    foreground=True,
    batch=False,
    erosion_kernel=None,
    mask_multiplier=255,
    unique=True,
    **kwargs,
):
    """Generate masks for the input image.

    Args:
        source (str | np.ndarray): The path (or HTTP(S) URL) to the input image or the input image as a numpy array.
        output (str, optional): The path to the output image. Defaults to None.
        foreground (bool, optional): Whether to generate the foreground mask. Defaults to True.
        batch (bool, optional): Whether to generate masks for a batch of image tiles. Defaults to False.
            Only applies when ``source`` is a path.
        erosion_kernel (tuple, optional): The erosion kernel for filtering object masks and extracting borders.
            Such as (3, 3) or (5, 5). Set to None to disable it. Defaults to None.
        mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1].
            You can use this parameter to scale the mask to a larger range, for example [0, 255]. Defaults to 255.
            The parameter is ignored if unique is True.
        unique (bool, optional): Whether to assign a unique value to each object. Defaults to True.
            The unique value increases from 1 to the number of objects. The larger the number, the larger the object area.
        **kwargs: Extra keyword arguments forwarded to ``tiff_to_tiff`` (batch mode) or ``save_masks``.

    Returns:
        The result of ``tiff_to_tiff`` in batch mode, otherwise None.

    Raises:
        ValueError: If the path does not exist or ``source`` is neither a path nor a numpy array.
    """

    if isinstance(source, np.ndarray):
        # An in-memory image has no path to remember.
        image, source = source, None
    elif isinstance(source, str):
        if source.startswith("http"):
            source = download_file(source)

        if not os.path.exists(source):
            raise ValueError(f"Input path {source} does not exist.")

        if batch:
            # Subdivide the image into tiles and segment each tile
            self.batch = True
            self.source = source
            self.masks = output
            return tiff_to_tiff(
                source,
                output,
                self,
                foreground=foreground,
                erosion_kernel=erosion_kernel,
                mask_multiplier=mask_multiplier,
                **kwargs,
            )

        image = cv2.imread(source)
    else:
        raise ValueError("Input source must be either a path or a numpy array.")

    self.source = source  # The input image path (None for array input)
    self.image = image  # The input image as a numpy array
    # Segment the input image; masks are stored as a list of dictionaries
    self.masks = self.mask_generator.generate(image)
    self.batch = False

    if output is None:
        return

    # Save the masks to the output path. The output is either a binary mask
    # or a mask of objects with unique values.
    self.save_masks(
        output, foreground, unique, erosion_kernel, mask_multiplier, **kwargs
    )
+
predict(self, point_coords=None, point_labels=None, boxes=None, point_crs=None, mask_input=None, multimask_output=True, return_logits=False, output=None, index=None, mask_multiplier=255, dtype='float32', return_results=False, **kwargs)
+
+
+¶Predict masks for the given input prompts, using the currently set image.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
point_coords |
+ str | dict | list | np.ndarray |
+ A Nx2 array of point prompts to the +model. Each point is in (X,Y) in pixels. It can be a path to a vector file, a GeoJSON +dictionary, a list of coordinates [lon, lat], or a numpy array. Defaults to None. |
+ None |
+
point_labels |
+ list | int | np.ndarray |
+ A length N array of labels for the +point prompts. 1 indicates a foreground point and 0 indicates a background point. |
+ None |
+
point_crs |
+ str |
+ The coordinate reference system (CRS) of the point prompts. |
+ None |
+
boxes |
+ list | np.ndarray |
+ A length 4 array given a box prompt to the +model, in XYXY format. |
+ None |
+
mask_input |
+ np.ndarray |
+ A low resolution mask input to the model, typically +coming from a previous prediction iteration. Has form 1xHxW, where for SAM, H=W=256. +multimask_output (bool, optional): If true, the model will return three masks. +For ambiguous input prompts (such as a single click), this will often +produce better masks than a single prediction. If only a single +mask is needed, the model's predicted quality score can be used +to select the best mask. For non-ambiguous prompts, such as multiple +input prompts, multimask_output=False can give better results. |
+ None |
+
return_logits |
+ bool |
+ If true, returns un-thresholded masks logits +instead of a binary mask. |
+ False |
+
output |
+ str |
+ The path to the output image. Defaults to None. |
+ None |
+
index |
+ int |
+ The index of the mask to save. Defaults to None, +which will save the mask with the highest score. |
+ None |
+
mask_multiplier |
+ int |
+ The mask multiplier for the output mask, which is usually a binary mask [0, 1]. |
+ 255 |
+
dtype |
+ np.dtype |
+ The data type of the output image. Defaults to np.float32. |
+ 'float32' |
+
return_results |
+ bool |
+ Whether to return the predicted masks, scores, and logits. Defaults to False. |
+ False |
+
samgeo/hq_sam.py
def predict(
    self,
    point_coords=None,
    point_labels=None,
    boxes=None,
    point_crs=None,
    mask_input=None,
    multimask_output=True,
    return_logits=False,
    output=None,
    index=None,
    mask_multiplier=255,
    dtype="float32",
    return_results=False,
    **kwargs,
):
    """Predict masks for the given input prompts, using the currently set image.

    A single box (or no box) is handled by ``predictor.predict``; a list of boxes
    is handled by ``predictor.predict_torch``. The results are stored on
    ``self.masks``, ``self.scores`` and ``self.logits``.

    Args:
        point_coords (str | dict | list | np.ndarray, optional): A Nx2 array of point prompts to the
            model. Each point is in (X,Y) in pixels. It can be a path to a vector file, a GeoJSON
            dictionary, a list of coordinates [lon, lat], or a numpy array. Defaults to None.
        point_labels (list | int | np.ndarray, optional): A length N array of labels for the
            point prompts. 1 indicates a foreground point and 0 indicates a background point.
            A single int or one-element list is repeated for every point.
        boxes (str | dict | list, optional): A box prompt in XYXY format, a list of such boxes,
            a path to a vector file, or a GeoJSON dictionary whose feature bounds are used as boxes.
        point_crs (str, optional): The coordinate reference system (CRS) of the point and box prompts.
            When given, prompts are converted to pixel coordinates of the source image.
        mask_input (np.ndarray, optional): A low resolution mask input to the model, typically
            coming from a previous prediction iteration. Has form 1xHxW, where for SAM, H=W=256.
        multimask_output (bool, optional): If true, the model will return three masks.
            For ambiguous input prompts (such as a single click), this will often
            produce better masks than a single prediction. If only a single
            mask is needed, the model's predicted quality score can be used
            to select the best mask. For non-ambiguous prompts, such as multiple
            input prompts, multimask_output=False can give better results.
            Only used for the single-box path; the multi-box path always requests multiple masks.
        return_logits (bool, optional): If true, returns un-thresholded masks logits
            instead of a binary mask. Only used for the single-box path.
        output (str, optional): The path to the output image. Defaults to None.
        index (int, optional): The index of the mask to save. Defaults to None,
            which will save the mask with the highest score (single-box path).
        mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1].
        dtype (np.dtype, optional): The data type of the output image. Defaults to np.float32.
        return_results (bool, optional): Whether to return the predicted masks, scores, and logits. Defaults to False.
        **kwargs: Extra keyword arguments used when saving the output image.

    Returns:
        tuple | None: ``(masks, scores, logits)`` when ``return_results`` is True, otherwise None.

    Raises:
        ValueError: If ``point_labels`` cannot be matched to ``point_coords`` by length.
    """

    # Boxes given as a vector file or GeoJSON: use the bounds of each feature.
    if isinstance(boxes, str):
        gdf = gpd.read_file(boxes)
        if gdf.crs is not None:
            gdf = gdf.to_crs("epsg:4326")
        boxes = gdf.geometry.bounds.values.tolist()
    elif isinstance(boxes, dict):
        import json

        geojson = json.dumps(boxes)
        gdf = gpd.read_file(geojson, driver="GeoJSON")
        boxes = gdf.geometry.bounds.values.tolist()

    if isinstance(point_coords, str):
        point_coords = vector_to_geojson(point_coords)

    if isinstance(point_coords, dict):
        point_coords = geojson_to_coords(point_coords)

    # Points stored on the instance take precedence over the arguments
    # passed to this call.
    if hasattr(self, "point_coords"):
        point_coords = self.point_coords

    if hasattr(self, "point_labels"):
        point_labels = self.point_labels

    if (point_crs is not None) and (point_coords is not None):
        point_coords = coords_to_xy(self.source, point_coords, point_crs)

    if isinstance(point_coords, list):
        point_coords = np.array(point_coords)

    if point_coords is not None:
        if point_labels is None:
            point_labels = [1] * len(point_coords)
        elif isinstance(point_labels, int):
            point_labels = [point_labels] * len(point_coords)

    if isinstance(point_labels, list):
        if len(point_labels) != len(point_coords):
            if len(point_labels) == 1:
                point_labels = point_labels * len(point_coords)
            else:
                raise ValueError(
                    "The length of point_labels must be equal to the length of point_coords."
                )
        point_labels = np.array(point_labels)

    predictor = self.predictor

    input_boxes = None
    if isinstance(boxes, list) and (point_crs is not None):
        coords = bbox_to_xy(self.source, boxes, point_crs)
        input_boxes = np.array(coords)
        if isinstance(coords[0], int):
            # A single flat box: add the leading batch dimension.
            input_boxes = input_boxes[None, :]
        else:
            input_boxes = torch.tensor(input_boxes, device=self.device)
            input_boxes = predictor.transform.apply_boxes_torch(
                input_boxes, self.image.shape[:2]
            )
    elif isinstance(boxes, list) and (point_crs is None):
        input_boxes = np.array(boxes)
        if isinstance(boxes[0], int):
            input_boxes = input_boxes[None, :]

    self.boxes = input_boxes

    # A list of boxes (list of lists) is routed to the batched torch API.
    batched = boxes is not None and isinstance(boxes[0], list)

    if not batched:
        masks, scores, logits = predictor.predict(
            point_coords,
            point_labels,
            input_boxes,
            mask_input,
            multimask_output,
            return_logits,
        )
    else:
        # Fix: the labels must accompany the coordinates; previously the
        # coordinates were passed a second time as point_labels.
        masks, scores, logits = predictor.predict_torch(
            point_coords=point_coords,
            point_labels=point_labels,
            boxes=input_boxes,
            multimask_output=True,
        )

    self.masks = masks
    self.scores = scores
    self.logits = logits

    if output is not None:
        if not batched:
            self.save_prediction(output, index, mask_multiplier, dtype, **kwargs)
        else:
            self.tensor_to_numpy(
                index, output, mask_multiplier, dtype, save_args=kwargs
            )

    if return_results:
        return masks, scores, logits
+
raster_to_vector(self, image, output, simplify_tolerance=None, **kwargs)
+
+
+¶Save the result to a vector file.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
image |
+ str |
+ The path to the image file. |
+ required | +
output |
+ str |
+ The path to the vector file. |
+ required | +
simplify_tolerance |
+ float |
+ The maximum allowed geometry displacement. +The higher this value, the smaller the number of vertices in the resulting geometry. |
+ None |
+
samgeo/hq_sam.py
def raster_to_vector(self, image, output, simplify_tolerance=None, **kwargs):
    """Vectorize a raster file and write the result to a vector file.

    This method is a thin wrapper: it forwards every argument to the
    module-level ``raster_to_vector`` helper and returns nothing.

    Args:
        image (str): The path to the image file.
        output (str): The path to the vector file.
        simplify_tolerance (float, optional): The maximum allowed geometry displacement.
            The higher this value, the smaller the number of vertices in the resulting geometry.
            Defaults to None.
        **kwargs: Additional keyword arguments handed to the helper unchanged.
    """
    # Inside the method body the bare name resolves to the module-level
    # helper, not to this method.
    raster_to_vector(
        image,
        output,
        simplify_tolerance=simplify_tolerance,
        **kwargs,
    )
+
save_masks(self, output=None, foreground=True, unique=True, erosion_kernel=None, mask_multiplier=255, **kwargs)
+
+
+¶Save the masks to the output path. The output is either a binary mask or a mask of objects with unique values.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
output |
+ str |
+ The path to the output image. Defaults to None, saving the masks to SamGeo.objects. |
+ None |
+
foreground |
+ bool |
+ Whether to generate the foreground mask. Defaults to True. |
+ True |
+
unique |
+ bool |
+ Whether to assign a unique value to each object. Defaults to True. |
+ True |
+
erosion_kernel |
+ tuple |
+ The erosion kernel for filtering object masks and extract borders. +Such as (3, 3) or (5, 5). Set to None to disable it. Defaults to None. |
+ None |
+
mask_multiplier |
+ int |
+ The mask multiplier for the output mask, which is usually a binary mask [0, 1]. +You can use this parameter to scale the mask to a larger range, for example [0, 255]. Defaults to 255. |
+ 255 |
+
samgeo/hq_sam.py
def save_masks(
    self,
    output=None,
    foreground=True,
    unique=True,
    erosion_kernel=None,
    mask_multiplier=255,
    **kwargs,
):
    """Merge the generated masks into one image and optionally write it to disk.

    The merged image is stored on ``self.objects``. It is either a label
    image (one integer value per object) or a binary mask scaled by
    ``mask_multiplier``.

    Args:
        output (str, optional): The path to the output image. Defaults to None,
            in which case the result is only stored on ``self.objects``.
        foreground (bool, optional): Whether to generate the foreground mask. Defaults to True.
            Only used when ``unique`` is False.
        unique (bool, optional): Whether to assign a unique value to each object. Defaults to True.
        erosion_kernel (tuple, optional): The erosion kernel for filtering object masks and
            extracting borders, such as (3, 3) or (5, 5). Set to None to disable it.
            Defaults to None. Only used when ``unique`` is False.
        mask_multiplier (int, optional): Scale factor applied to the binary mask, e.g. 255
            to map [0, 1] to [0, 255]. Defaults to 255. Only used when ``unique`` is False.
        **kwargs: Extra keyword arguments forwarded to ``array_to_image``.

    Raises:
        ValueError: If ``self.masks`` is None.
    """
    if self.masks is None:
        raise ValueError("No masks found. Please run generate() first.")

    h, w, _ = self.image.shape
    annotations = self.masks

    # Pick the narrowest unsigned integer type that can hold one label per object.
    count = len(annotations)
    if count < 255:
        dtype = np.uint8
    elif count < 65535:
        dtype = np.uint16
    else:
        dtype = np.uint32

    if unique:
        # Paint small objects first so larger ones are written on top of them.
        ordered = sorted(annotations, key=lambda ann: ann["area"])
        template = ordered[0]["segmentation"]
        merged = np.zeros((template.shape[0], template.shape[1]))
        for label, ann in enumerate(ordered, start=1):
            merged[ann["segmentation"]] = label
    else:
        # Start from an empty canvas (foreground) or a filled one (background).
        make_canvas = np.zeros if foreground else np.ones
        coverage = make_canvas((h, w), dtype=dtype)
        borders = np.zeros((h, w), dtype=dtype)

        for ann in annotations:
            binary = (ann["segmentation"] > 0).astype(dtype)
            coverage += binary

            if erosion_kernel is not None:
                # The border is whatever a single erosion pass removes.
                eroded = cv2.erode(binary, erosion_kernel, iterations=1)
                eroded = (eroded > 0).astype(dtype)
                borders += binary - eroded

        coverage = (coverage > 0).astype(dtype)
        borders = (borders > 0).astype(dtype)
        merged = (coverage - borders) * mask_multiplier

    self.objects = merged.astype(dtype)

    if output is not None:
        array_to_image(self.objects, output, self.source, **kwargs)
+
save_prediction(self, output, index=None, mask_multiplier=255, dtype=<class 'numpy.float32'>, vector=None, simplify_tolerance=None, **kwargs)
+
+
+¶Save the predicted mask to the output path.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
output |
+ str |
+ The path to the output image. |
+ required | +
index |
+ int |
+ The index of the mask to save. Defaults to None, +which will save the mask with the highest score. |
+ None |
+
mask_multiplier |
+ int |
+ The mask multiplier for the output mask, which is usually a binary mask [0, 1]. |
+ 255 |
+
vector |
+ str |
+ The path to the output vector file. Defaults to None. |
+ None |
+
dtype |
+ np.dtype |
+ The data type of the output image. Defaults to np.float32. |
+ <class 'numpy.float32'> |
+
simplify_tolerance |
+ float |
+ The maximum allowed geometry displacement. +The higher this value, the smaller the number of vertices in the resulting geometry. |
+ None |
+
samgeo/hq_sam.py
def save_prediction(
    self,
    output,
    index=None,
    mask_multiplier=255,
    dtype=np.float32,
    vector=None,
    simplify_tolerance=None,
    **kwargs,
):
    """Write one predicted mask to an image file, and optionally vectorize it.

    The scaled mask is also kept on ``self.prediction``.

    Args:
        output (str): The path to the output image.
        index (int, optional): The index of the mask to save. Defaults to None,
            which selects the mask with the highest score.
        mask_multiplier (int, optional): Factor the selected mask is multiplied by
            before saving. Defaults to 255.
        dtype (np.dtype, optional): The data type of the output image. Defaults to np.float32.
        vector (str, optional): The path to the output vector file. Defaults to None,
            which skips vectorization.
        simplify_tolerance (float, optional): The maximum allowed geometry displacement.
            The higher this value, the smaller the number of vertices in the resulting geometry.
        **kwargs: Extra keyword arguments forwarded to ``array_to_image``.

    Raises:
        ValueError: If ``self.scores`` is None.
    """
    if self.scores is None:
        raise ValueError("No predictions found. Please run predict() first.")

    # Fall back to the best-scoring mask when the caller does not pick one.
    chosen = self.scores.argmax(axis=0) if index is None else index

    scaled = self.masks[chosen] * mask_multiplier
    self.prediction = scaled
    array_to_image(scaled, output, self.source, dtype=dtype, **kwargs)

    if vector is None:
        return

    # Vectorize the raster that was just written to ``output``.
    raster_to_vector(output, vector, simplify_tolerance=simplify_tolerance)
+
set_image(self, image, image_format='RGB')
+
+
+¶Set the input image as a numpy array.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
image |
+ np.ndarray |
+ The input image as a numpy array. |
+ required | +
image_format |
+ str |
+ The image format, can be RGB or BGR. Defaults to "RGB". |
+ 'RGB' |
+
samgeo/hq_sam.py
def set_image(self, image, image_format="RGB"):
    """Set the input image used for subsequent predictions.

    Args:
        image (str | np.ndarray): A local file path, an HTTP(S) URL, or the
            image as a numpy array. A URL is downloaded first; a file is read
            with OpenCV and converted from BGR to RGB.
        image_format (str, optional): The image format passed to the predictor,
            can be RGB or BGR. Defaults to "RGB".

    Raises:
        ValueError: If the path does not exist, the file cannot be decoded,
            or ``image`` is neither a string nor a numpy array.
    """
    if isinstance(image, str):
        if image.startswith("http"):
            image = download_file(image)

        if not os.path.exists(image):
            raise ValueError(f"Input path {image} does not exist.")

        # cv2.imread signals failure by returning None instead of raising,
        # so check explicitly rather than failing later inside cvtColor.
        pixels = cv2.imread(image)
        if pixels is None:
            raise ValueError(f"Unable to read an image from {image}.")

        self.source = image
        image = cv2.cvtColor(pixels, cv2.COLOR_BGR2RGB)
    elif not isinstance(image, np.ndarray):
        raise ValueError("Input image must be either a path or a numpy array.")

    # Keep self.image in sync with the predictor for array inputs as well;
    # other methods read self.image (e.g. for its shape or for display).
    self.image = image
    self.predictor.set_image(image, image_format=image_format)
+
show_anns(self, figsize=(12, 10), axis='off', alpha=0.35, output=None, blend=True, **kwargs)
+
+
+¶Show the annotations (objects with random color) on the input image.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
figsize |
+ tuple |
+ The figure size. Defaults to (12, 10). |
+ (12, 10) |
+
axis |
+ str |
+ Whether to show the axis. Defaults to "off". |
+ 'off' |
+
alpha |
+ float |
+ The alpha value for the annotations. Defaults to 0.35. |
+ 0.35 |
+
output |
+ str |
+ The path to the output image. Defaults to None. |
+ None |
+
blend |
+ bool |
+ Whether to show the input image. Defaults to True. |
+ True |
+
samgeo/hq_sam.py
def show_anns(
    self,
    figsize=(12, 10),
    axis="off",
    alpha=0.35,
    output=None,
    blend=True,
    **kwargs,
):
    """Draw every mask in a random color on top of the input image.

    The RGB rendering of the colored masks is stored on ``self.annotations``.
    Nothing is drawn when there is no image or there are no masks.

    Args:
        figsize (tuple, optional): The figure size. Defaults to (12, 10).
        axis (str, optional): Value passed to ``plt.axis``. Defaults to "off".
        alpha (float, optional): The alpha value for the annotations. Defaults to 0.35.
        output (str, optional): The path to the output image. Defaults to None.
        blend (bool, optional): Whether to blend the annotations with the input
            image before saving. Defaults to True.
        **kwargs: Accepted for compatibility. "dpi" and "bbox_inches" receive
            defaults below but are not passed on to any call in this method.
    """

    import matplotlib.pyplot as plt

    annotations = self.masks

    if self.image is None:
        print("Please run generate() first.")
        return

    if annotations is None or len(annotations) == 0:
        return

    plt.figure(figsize=figsize)
    plt.imshow(self.image)

    # Largest objects first, so smaller ones are painted over them.
    largest_first = sorted(annotations, key=lambda ann: ann["area"], reverse=True)

    axes = plt.gca()
    axes.set_autoscale_on(False)

    seg_shape = largest_first[0]["segmentation"].shape
    overlay = np.ones((seg_shape[0], seg_shape[1], 4))
    overlay[:, :, 3] = 0  # fully transparent wherever no mask is painted
    for ann in largest_first:
        region = ann["segmentation"]
        rgba = np.concatenate([np.random.random(3), [alpha]])
        overlay[region] = rgba
    axes.imshow(overlay)

    kwargs.setdefault("dpi", 100)
    kwargs.setdefault("bbox_inches", "tight")

    plt.axis(axis)

    self.annotations = (overlay[:, :, 0:3] * 255).astype(np.uint8)

    if output is not None:
        if blend:
            rendered = blend_images(
                self.annotations, self.image, alpha=alpha, show=False
            )
        else:
            rendered = self.annotations
        array_to_image(rendered, output, self.source)
+
show_canvas(self, fg_color=(0, 255, 0), bg_color=(0, 0, 255), radius=5)
+
+
+¶Show a canvas to collect foreground and background points.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
image |
+ str | np.ndarray |
+ The input image. |
+ required | +
fg_color |
+ tuple |
+ The color for the foreground points. Defaults to (0, 255, 0). |
+ (0, 255, 0) |
+
bg_color |
+ tuple |
+ The color for the background points. Defaults to (0, 0, 255). |
+ (0, 0, 255) |
+
radius |
+ int |
+ The radius of the points. Defaults to 5. |
+ 5 |
+
Returns:
+Type | +Description | +
---|---|
tuple |
+ A tuple of two lists of foreground and background points. |
+
samgeo/hq_sam.py
def show_canvas(self, fg_color=(0, 255, 0), bg_color=(0, 0, 255), radius=5):
    """Open a canvas on the current image to collect foreground and background points.

    The collected points are stored on the instance rather than returned:
    ``self.fg_points``, ``self.bg_points``, ``self.point_coords`` (foreground
    followed by background) and ``self.point_labels`` (1 for each foreground
    point, 0 for each background point).

    Args:
        fg_color (tuple, optional): The color for the foreground points. Defaults to (0, 255, 0).
        bg_color (tuple, optional): The color for the background points. Defaults to (0, 0, 255).
        radius (int, optional): The radius of the points. Defaults to 5.

    Raises:
        ValueError: If ``self.image`` is None.
    """
    canvas_image = self.image
    if canvas_image is None:
        raise ValueError("Please run set_image() first.")

    # The bare name refers to the module-level helper, not to this method.
    foreground, background = show_canvas(canvas_image, fg_color, bg_color, radius)
    self.fg_points = foreground
    self.bg_points = background

    combined = foreground + background
    labels = [1] * len(foreground) + [0] * len(background)
    self.point_coords = combined
    self.point_labels = labels
+
show_map(self, basemap='SATELLITE', repeat_mode=True, out_dir=None, **kwargs)
+
+
+¶Show the interactive map.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
basemap |
+ str |
+ The basemap. It can be one of the following: SATELLITE, ROADMAP, TERRAIN, HYBRID. |
+ 'SATELLITE' |
+
repeat_mode |
+ bool |
+ Whether to use the repeat mode for draw control. Defaults to True. |
+ True |
+
out_dir |
+ str |
+ The path to the output directory. Defaults to None. |
+ None |
+
Returns:
+Type | +Description | +
---|---|
leafmap.Map |
+ The map object. |
+
samgeo/hq_sam.py
def show_map(self, basemap="SATELLITE", repeat_mode=True, out_dir=None, **kwargs):
    """Build and return the interactive map for this instance.

    Delegates to ``sam_map_gui``, passing this object as the first argument.

    Args:
        basemap (str, optional): The basemap. It can be one of the following:
            SATELLITE, ROADMAP, TERRAIN, HYBRID. Defaults to "SATELLITE".
        repeat_mode (bool, optional): Whether to use the repeat mode for draw control.
            Defaults to True.
        out_dir (str, optional): The path to the output directory. Defaults to None.
        **kwargs: Extra keyword arguments forwarded to ``sam_map_gui``.

    Returns:
        leafmap.Map: The map object.
    """
    interactive_map = sam_map_gui(
        self,
        basemap=basemap,
        repeat_mode=repeat_mode,
        out_dir=out_dir,
        **kwargs,
    )
    return interactive_map
+
show_masks(self, figsize=(12, 10), cmap='binary_r', axis='off', foreground=True, **kwargs)
+
+
+¶Show the binary mask or the mask of objects with unique values.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
figsize |
+ tuple |
+ The figure size. Defaults to (12, 10). |
+ (12, 10) |
+
cmap |
+ str |
+ The colormap. Defaults to "binary_r". |
+ 'binary_r' |
+
axis |
+ str |
+ Whether to show the axis. Defaults to "off". |
+ 'off' |
+
foreground |
+ bool |
+ Whether to show the foreground mask only. Defaults to True. |
+ True |
+
**kwargs |
+ + | Other arguments for save_masks(). |
+ {} |
+
samgeo/hq_sam.py
def show_masks(
    self, figsize=(12, 10), cmap="binary_r", axis="off", foreground=True, **kwargs
):
    """Display the merged mask image held on ``self.objects``.

    In batch mode the image is read from the path stored in ``self.masks``.
    Otherwise ``save_masks()`` is called once to build ``self.objects`` if it
    has not been built yet.

    Args:
        figsize (tuple, optional): The figure size. Defaults to (12, 10).
        cmap (str, optional): The colormap. Defaults to "binary_r".
        axis (str, optional): Value passed to ``plt.axis``. Defaults to "off".
        foreground (bool, optional): Whether to show the foreground mask only.
            Defaults to True. Only used when ``save_masks()`` has to be called.
        **kwargs: Other arguments for save_masks().
    """

    import matplotlib.pyplot as plt

    if self.batch:
        self.objects = cv2.imread(self.masks)
    elif self.objects is None:
        self.save_masks(foreground=foreground, **kwargs)

    plt.figure(figsize=figsize)
    plt.imshow(self.objects, cmap=cmap)
    plt.axis(axis)
    plt.show()
+
tensor_to_numpy(self, index=None, output=None, mask_multiplier=255, dtype='uint8', save_args={})
+
+
+¶Convert the predicted masks from tensors to numpy arrays.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
index |
+ index |
+ The index of the mask to save. Defaults to None, +which will save the mask with the highest score. |
+ None |
+
output |
+ str |
+ The path to the output image. Defaults to None. |
+ None |
+
mask_multiplier |
+ int |
+ The mask multiplier for the output mask, which is usually a binary mask [0, 1]. |
+ 255 |
+
dtype |
+ np.dtype |
+ The data type of the output image. Defaults to np.uint8. |
+ 'uint8' |
+
save_args |
+ dict |
+ Optional arguments for saving the output image. Defaults to {}. |
+ {} |
+
Returns:
+Type | +Description | +
---|---|
np.ndarray |
+ The predicted mask as a numpy array. |
+
samgeo/hq_sam.py
def tensor_to_numpy(
    self, index=None, output=None, mask_multiplier=255, dtype="uint8", save_args=None
):
    """Merge the batched predicted masks into one binary numpy mask.

    One mask per box is taken from ``self.masks`` at position ``index`` along
    the second axis. The per-box masks are combined with a logical OR and the
    result is scaled by ``mask_multiplier``.

    Args:
        index (int, optional): The index of the mask to use for every box.
            Defaults to None, which selects index 1.
        output (str, optional): The path to the output image. Defaults to None.
        mask_multiplier (int, optional): The value written where any mask is set.
            Defaults to 255.
        dtype (np.dtype, optional): The data type each mask is cast to and the data
            type of the saved image. Defaults to "uint8".
        save_args (dict, optional): Optional arguments for saving the output image.
            Defaults to None, which is treated as an empty dict.

    Returns:
        np.ndarray | None: The merged mask when ``output`` is None. None when the
        mask was written to ``output`` or when ``self.boxes`` is None or empty.
    """
    boxes = self.boxes
    if boxes is None or (len(boxes) == 0):  # No "object" instances found
        print("No objects found in the image.")
        return None

    if index is None:
        index = 1
    if save_args is None:
        save_args = {}

    masks = self.masks[:, index, :, :]
    # NOTE(review): kept from the original; after the indexing above this only
    # drops axis 1 if it has length 1 -- confirm whether it is still needed.
    masks = masks.squeeze(1)

    height, width = np.asarray(self.image).shape[:2]

    # Combine the masks with a boolean OR. Summing per-mask labels in a small
    # integer dtype can wrap around to 0 where masks overlap, which would
    # silently drop those pixels from the final binary mask.
    union = np.zeros((height, width), dtype=bool)
    for _, mask in zip(boxes, masks):
        if isinstance(mask, torch.Tensor):
            # Move to host memory first; .numpy() is not available on GPU tensors.
            mask = mask.cpu().numpy().astype(dtype)
        union |= np.asarray(mask) > 0

    mask_overlay = union * mask_multiplier  # 0 or mask_multiplier per pixel

    if output is not None:
        array_to_image(mask_overlay, output, self.source, dtype=dtype, **save_args)
        return None
    return mask_overlay
+
tiff_to_geojson(self, tiff_path, output, simplify_tolerance=None, **kwargs)
+
+
+¶Convert a tiff file to a GeoJSON file.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
tiff_path |
+ str |
+ The path to the tiff file. |
+ required | +
output |
+ str |
+ The path to the GeoJSON file. |
+ required | +
simplify_tolerance |
+ float |
+ The maximum allowed geometry displacement. +The higher this value, the smaller the number of vertices in the resulting geometry. |
+ None |
+
samgeo/hq_sam.py
def tiff_to_geojson(self, tiff_path, output, simplify_tolerance=None, **kwargs):
    """Convert a tiff file to a GeoJSON file.

    Thin wrapper around the ``raster_to_geojson`` helper.

    Args:
        tiff_path (str): The path to the tiff file.
        output (str): The path to the GeoJSON file.
        simplify_tolerance (float, optional): The maximum allowed geometry displacement.
            The higher this value, the smaller the number of vertices in the resulting geometry.
            Defaults to None.
        **kwargs: Extra keyword arguments forwarded to ``raster_to_geojson``.
    """
    raster_to_geojson(
        tiff_path,
        output,
        simplify_tolerance=simplify_tolerance,
        **kwargs,
    )
+
tiff_to_gpkg(self, tiff_path, output, simplify_tolerance=None, **kwargs)
+
+
+¶Convert a tiff file to a gpkg file.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
tiff_path |
+ str |
+ The path to the tiff file. |
+ required | +
output |
+ str |
+ The path to the gpkg file. |
+ required | +
simplify_tolerance |
+ float |
+ The maximum allowed geometry displacement. +The higher this value, the smaller the number of vertices in the resulting geometry. |
+ None |
+
samgeo/hq_sam.py
def tiff_to_gpkg(self, tiff_path, output, simplify_tolerance=None, **kwargs):
    """Convert a tiff file to a gpkg file.

    Thin wrapper around the ``raster_to_gpkg`` helper.

    Args:
        tiff_path (str): The path to the tiff file.
        output (str): The path to the gpkg file.
        simplify_tolerance (float, optional): The maximum allowed geometry displacement.
            The higher this value, the smaller the number of vertices in the resulting geometry.
            Defaults to None.
        **kwargs: Extra keyword arguments forwarded to ``raster_to_gpkg``.
    """
    raster_to_gpkg(
        tiff_path,
        output,
        simplify_tolerance=simplify_tolerance,
        **kwargs,
    )
+
tiff_to_shp(self, tiff_path, output, simplify_tolerance=None, **kwargs)
+
+
+¶Convert a tiff file to a shapefile.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
tiff_path |
+ str |
+ The path to the tiff file. |
+ required | +
output |
+ str |
+ The path to the shapefile. |
+ required | +
simplify_tolerance |
+ float |
+ The maximum allowed geometry displacement. +The higher this value, the smaller the number of vertices in the resulting geometry. |
+ None |
+
samgeo/hq_sam.py
def tiff_to_shp(self, tiff_path, output, simplify_tolerance=None, **kwargs):
    """Convert a tiff file to a shapefile.

    Thin wrapper around the ``raster_to_shp`` helper.

    Args:
        tiff_path (str): The path to the tiff file.
        output (str): The path to the shapefile.
        simplify_tolerance (float, optional): The maximum allowed geometry displacement.
            The higher this value, the smaller the number of vertices in the resulting geometry.
            Defaults to None.
        **kwargs: Extra keyword arguments forwarded to ``raster_to_shp``.
    """
    raster_to_shp(
        tiff_path,
        output,
        simplify_tolerance=simplify_tolerance,
        **kwargs,
    )
+
tiff_to_vector(self, tiff_path, output, simplify_tolerance=None, **kwargs)
+
+
+¶Convert a tiff file to a vector file.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
tiff_path |
+ str |
+ The path to the tiff file. |
+ required | +
output |
+ str |
+ The path to the vector file. |
+ required | +
simplify_tolerance |
+ float |
+ The maximum allowed geometry displacement. +The higher this value, the smaller the number of vertices in the resulting geometry. |
+ None |
+
samgeo/hq_sam.py
def tiff_to_vector(self, tiff_path, output, simplify_tolerance=None, **kwargs):
    """Convert a tiff file to a vector file.

    Thin wrapper around the ``raster_to_vector`` helper.

    Args:
        tiff_path (str): The path to the tiff file.
        output (str): The path to the vector file.
        simplify_tolerance (float, optional): The maximum allowed geometry displacement.
            The higher this value, the smaller the number of vertices in the resulting geometry.
            Defaults to None.
        **kwargs: Extra keyword arguments forwarded to ``raster_to_vector``.
    """
    raster_to_vector(
        tiff_path,
        output,
        simplify_tolerance=simplify_tolerance,
        **kwargs,
    )
+
+SamGeoPredictor (SamPredictor)
+
+
+
+
+¶samgeo/hq_sam.py
class SamGeoPredictor(SamPredictor):
    """A ``SamPredictor`` that also accepts a geographic bounding box as prompt.

    When ``predict`` is given both a source raster and a ``geo_box``, the box
    is converted to pixel coordinates and the georeferencing details are kept
    on the instance so the masks can later be written out as a GeoTIFF.
    """

    def __init__(
        self,
        sam_model,
    ):
        """Store the SAM model and build the matching resize transform.

        NOTE(review): the parent ``__init__`` is not called here -- confirm
        that no other parent-side initialisation is required.

        Args:
            sam_model: The SAM model; its ``image_encoder.img_size`` sets the
                target size of the resize transform.
        """
        from segment_anything.utils.transforms import ResizeLongestSide

        self.model = sam_model
        self.transform = ResizeLongestSide(sam_model.image_encoder.img_size)

    def set_image(self, image):
        """Pass ``image`` to the parent ``set_image``.

        Args:
            image: The image handed to the parent predictor unchanged.
        """
        super().set_image(image)

    def predict(
        self,
        src_fp=None,
        geo_box=None,
        point_coords=None,
        point_labels=None,
        box=None,
        mask_input=None,
        multimask_output=True,
        return_logits=False,
    ):
        """Predict masks, optionally deriving the box prompt from a geographic box.

        If both ``geo_box`` and ``src_fp`` are given, ``box`` is replaced by the
        pixel box computed from ``geo_box``, and ``self.crs``, ``self.geo_box``,
        ``self.width``, ``self.height`` and ``self.geo_transform`` are set.

        Args:
            src_fp (str, optional): Path of the source raster. Defaults to None.
            geo_box (sequence, optional): Four coordinates; elements 0-1 and 2-3
                are transformed as two points from "EPSG:4326"
                (presumably south-west and north-east corners -- TODO confirm).
            point_coords: Forwarded to the parent ``predict``.
            point_labels: Forwarded to the parent ``predict``.
            box: Box prompt in pixels; overridden when ``geo_box`` is used.
            mask_input: Forwarded to the parent ``predict``.
            multimask_output (bool, optional): Forwarded to the parent. Defaults to True.
            return_logits (bool, optional): Forwarded to the parent. Defaults to False.

        Returns:
            tuple: The three values returned by the parent ``predict``
            (masks, IoU predictions, low-resolution masks).
        """
        if geo_box and src_fp:
            self.crs = "EPSG:4326"
            raster_crs = get_crs(src_fp)
            first_corner = transform_coords(
                geo_box[0], geo_box[1], self.crs, raster_crs
            )
            second_corner = transform_coords(
                geo_box[2], geo_box[3], self.crs, raster_crs
            )
            xs = np.array([first_corner[0], second_corner[0]])
            ys = np.array([first_corner[1], second_corner[1]])
            box = get_pixel_coords(src_fp, xs, ys)
            self.geo_box = geo_box
            self.width = box[2] - box[0]
            self.height = box[3] - box[1]
            self.geo_transform = set_transform(geo_box, self.width, self.height)

        masks, iou_predictions, low_res_masks = super().predict(
            point_coords,
            point_labels,
            box,
            mask_input,
            multimask_output,
            return_logits,
        )
        return masks, iou_predictions, low_res_masks

    def masks_to_geotiff(self, src_fp, dst_fp, masks):
        """Write ``masks`` to ``dst_fp`` using the georeferencing stored by ``predict``.

        Reads ``self.width``, ``self.height``, ``self.geo_transform`` and
        ``self.crs``, which are only set when ``predict`` ran with a ``geo_box``.

        Args:
            src_fp (str): Path of the source raster whose profile is reused.
            dst_fp (str): Path of the raster to write.
            masks: The masks to write.
        """
        source_profile = get_profile(src_fp)
        write_raster(
            dst_fp,
            masks,
            source_profile,
            self.width,
            self.height,
            self.geo_transform,
            self.crs,
        )

    def geotiff_to_geojson(self, src_fp, dst_fp, bidx=1):
        """Extract features from one band of a raster and write them to ``dst_fp``.

        Args:
            src_fp (str): Path of the raster to read.
            dst_fp (str): Path of the file the features are written to.
            bidx (int, optional): Band index passed to ``get_features``. Defaults to 1.

        Returns:
            The object returned by ``get_features`` (the same one that is written).
        """
        features = get_features(src_fp, bidx)
        write_features(features, dst_fp)
        return features
+
predict(self, src_fp=None, geo_box=None, point_coords=None, point_labels=None, box=None, mask_input=None, multimask_output=True, return_logits=False)
+
+
+¶Predict masks for the given input prompts, using the currently set image.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
point_coords |
+ np.ndarray or None |
+ A Nx2 array of point prompts to the +model. Each point is in (X,Y) in pixels. |
+ None |
+
point_labels |
+ np.ndarray or None |
+ A length N array of labels for the +point prompts. 1 indicates a foreground point and 0 indicates a +background point. |
+ None |
+
box |
+ np.ndarray or None |
+ A length 4 array given a box prompt to the +model, in XYXY format. |
+ None |
+
mask_input |
+ np.ndarray |
+ A low resolution mask input to the model, typically +coming from a previous prediction iteration. Has form 1xHxW, where +for SAM, H=W=256. |
+ None |
+
multimask_output |
+ bool |
+ If true, the model will return three masks. +For ambiguous input prompts (such as a single click), this will often +produce better masks than a single prediction. If only a single +mask is needed, the model's predicted quality score can be used +to select the best mask. For non-ambiguous prompts, such as multiple +input prompts, multimask_output=False can give better results. |
+ True |
+
return_logits |
+ bool |
+ If true, returns un-thresholded masks logits +instead of a binary mask. |
+ False |
+
Returns:
+Type | +Description | +
---|---|
(np.ndarray) |
+ The output masks in CxHxW format, where C is the + number of masks, and (H, W) is the original image size. +(np.ndarray): An array of length C containing the model's + predictions for the quality of each mask. +(np.ndarray): An array of shape CxHxW, where C is the number + of masks and H=W=256. These low resolution logits can be passed to + a subsequent iteration as mask input. |
+
samgeo/hq_sam.py
def predict(
    self,
    src_fp=None,
    geo_box=None,
    point_coords=None,
    point_labels=None,
    box=None,
    mask_input=None,
    multimask_output=True,
    return_logits=False,
):
    """Predict masks, optionally deriving the box prompt from a geographic box.

    If both ``geo_box`` and ``src_fp`` are given, ``box`` is replaced by the
    pixel box computed from ``geo_box``, and ``self.crs``, ``self.geo_box``,
    ``self.width``, ``self.height`` and ``self.geo_transform`` are set.

    Args:
        src_fp (str, optional): Path of the source raster. Defaults to None.
        geo_box (sequence, optional): Four coordinates; elements 0-1 and 2-3
            are transformed as two points from "EPSG:4326"
            (presumably south-west and north-east corners -- TODO confirm).
        point_coords: Forwarded to the parent ``predict``.
        point_labels: Forwarded to the parent ``predict``.
        box: Box prompt in pixels; overridden when ``geo_box`` is used.
        mask_input: Forwarded to the parent ``predict``.
        multimask_output (bool, optional): Forwarded to the parent. Defaults to True.
        return_logits (bool, optional): Forwarded to the parent. Defaults to False.

    Returns:
        tuple: The three values returned by the parent ``predict``
        (masks, IoU predictions, low-resolution masks).
    """
    if geo_box and src_fp:
        self.crs = "EPSG:4326"
        raster_crs = get_crs(src_fp)
        first_corner = transform_coords(geo_box[0], geo_box[1], self.crs, raster_crs)
        second_corner = transform_coords(geo_box[2], geo_box[3], self.crs, raster_crs)
        xs = np.array([first_corner[0], second_corner[0]])
        ys = np.array([first_corner[1], second_corner[1]])
        box = get_pixel_coords(src_fp, xs, ys)
        self.geo_box = geo_box
        self.width = box[2] - box[0]
        self.height = box[3] - box[1]
        self.geo_transform = set_transform(geo_box, self.width, self.height)

    parent = super(SamGeoPredictor, self)
    masks, iou_predictions, low_res_masks = parent.predict(
        point_coords,
        point_labels,
        box,
        mask_input,
        multimask_output,
        return_logits,
    )
    return masks, iou_predictions, low_res_masks
+
set_image(self, image)
+
+
+¶Calculates the image embeddings for the provided image, allowing +masks to be predicted with the 'predict' method.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
image |
+ np.ndarray |
+ The image for calculating masks. Expects an +image in HWC uint8 format, with pixel values in [0, 255]. |
+ required | +
image_format |
+ str |
+ The color format of the image, in ['RGB', 'BGR']. |
+ required | +
samgeo/hq_sam.py
def set_image(self, image):
    """Pass ``image`` to the parent ``set_image``.

    Args:
        image: The image handed to the parent predictor unchanged.
    """
    parent = super(SamGeoPredictor, self)
    parent.set_image(image)
+
A Python package for segmenting geospatial data with the Segment Anything Model (SAM) 🗺️
+The segment-geospatial package draws its inspiration from segment-anything-eo repository authored by Aliaksandr Hancharenka. To facilitate the use of the Segment Anything Model (SAM) for geospatial data, I have developed the segment-anything-py and segment-geospatial Python packages, which are now available on PyPI and conda-forge. My primary objective is to simplify the process of leveraging SAM for geospatial data analysis by enabling users to achieve this with minimal coding effort. I have adapted the source code of segment-geospatial from the segment-anything-eo repository, and credit for its original version goes to Aliaksandr Hancharenka.
+Video tutorials are available on my YouTube Channel.
+The Segment Anything Model is computationally intensive, and a powerful GPU is recommended to process large datasets. It is recommended to have a GPU with at least 8 GB of GPU memory. You can utilize the free GPU resources provided by Google Colab. Alternatively, you can apply for AWS Cloud Credit for Research, which offers cloud credits to support academic research. If you are in the Greater China region, apply for the AWS Cloud Credit here.
+This repository and its content are provided for educational purposes only. By using the information and code provided, users acknowledge that they are using the APIs and models at their own risk and agree to comply with any applicable laws and regulations. Users who intend to download a large number of image tiles from any basemap are advised to contact the basemap provider to obtain permission before doing so. Unauthorized use of the basemap or any of its components may be a violation of copyright laws or other applicable laws and regulations.
+This project is based upon work partially supported by the National Aeronautics and Space Administration (NASA) under Grant No. 80NSSC22K1742 issued through the Open Source Tools, Frameworks, and Libraries 2020 Program.
+This project is also supported by Amazon Web Services (AWS). In addition, this package was made possible by the following open source projects. Credit goes to the developers of these projects.
+segment-geospatial is available on PyPI. To install segment-geospatial, run this command in your terminal:
+1 |
|
segment-geospatial is also available on conda-forge. If you have
+Anaconda or Miniconda installed on your computer, you can install segment-geospatial using the following commands. It is recommended to create a fresh conda environment for segment-geospatial. The following commands will create a new conda environment named geo
and install segment-geospatial and its dependencies:
1 +2 +3 +4 |
|
segment-geospatial has some optional dependencies that are not included in the default conda environment. To install these dependencies, run the following command:
+1 |
|
As of July 9th, 2023 Linux systems have also required that libgl1
be installed for segment-geospatial to work. The following command will install that dependency
1 |
|
To install the development version from GitHub using Git, run the following command in your terminal:
+1 |
|
You can also use docker to run segment-geospatial:
+1 |
|
To enable GPU for segment-geospatial, run the following command to run a short benchmark on your GPU:
+1 |
|
The output should be similar to the following:
+1 + 2 + 3 + 4 + 5 + 6 + 7 + 8 + 9 +10 +11 +12 +13 +14 +15 +16 +17 +18 +19 +20 +21 +22 +23 +24 |
|
If you encounter the following error:
+1 |
|
Try adding sudo
to the command:
1 |
|
Once everything is working, you can run the following command to start a Jupyter Notebook server:
+1 |
|
The source code is adapted from https://github.com/aliaksandr960/segment-anything-eo. Credit to the author Aliaksandr Hancharenka.
+ + + +
+SamGeo
+
+
+
+¶The main class for segmenting geospatial data with the Segment Anything Model (SAM). See +https://github.com/facebookresearch/segment-anything for details.
+ +samgeo/samgeo.py
class SamGeo:
    """The main class for segmenting geospatial data with the Segment Anything Model (SAM). See
    https://github.com/facebookresearch/segment-anything for details.
    """

    def __init__(
        self,
        model_type="vit_h",
        automatic=True,
        device=None,
        checkpoint_dir=None,
        sam_kwargs=None,
        **kwargs,
    ):
        """Initialize the class.

        Args:
            model_type (str, optional): The model type. It can be one of the following: vit_h, vit_l, vit_b.
                Defaults to 'vit_h'. See https://bit.ly/3VrpxUh for more details.
            automatic (bool, optional): Whether to use the automatic mask generator or input prompts. Defaults to True.
                The automatic mask generator will segment the entire image, while the input prompts will segment selected objects.
            device (str, optional): The device to use. It can be one of the following: cpu, cuda.
                Defaults to None, which will use cuda if available.
            checkpoint_dir (str, optional): Forwarded to download_checkpoint() together with the model type.
                Defaults to None. See https://bit.ly/3VrpxUh for more details.
            sam_kwargs (dict, optional): Optional arguments for fine-tuning the SAM model. Defaults to None.
                The available arguments with default values are listed below. See https://bit.ly/410RV0v for more details.

                points_per_side: Optional[int] = 32,
                points_per_batch: int = 64,
                pred_iou_thresh: float = 0.88,
                stability_score_thresh: float = 0.95,
                stability_score_offset: float = 1.0,
                box_nms_thresh: float = 0.7,
                crop_n_layers: int = 0,
                crop_nms_thresh: float = 0.7,
                crop_overlap_ratio: float = 512 / 1500,
                crop_n_points_downscale_factor: int = 1,
                point_grids: Optional[List[np.ndarray]] = None,
                min_mask_region_area: int = 0,
                output_mode: str = "binary_mask",
            **kwargs: May contain "checkpoint" (str), the path to an existing checkpoint file.
                If that path does not exist, the checkpoint is downloaded instead.
        """
        hq = False  # Not using HQ-SAM

        if "checkpoint" in kwargs:
            checkpoint = kwargs.pop("checkpoint")
            if not os.path.exists(checkpoint):
                checkpoint = download_checkpoint(model_type, checkpoint_dir, hq)
        else:
            checkpoint = download_checkpoint(model_type, checkpoint_dir, hq)

        # Use cuda if available
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
            if device == "cuda":
                torch.cuda.empty_cache()

        self.checkpoint = checkpoint
        self.model_type = model_type
        self.device = device
        self.sam_kwargs = sam_kwargs  # Optional arguments for fine-tuning the SAM model
        self.source = None  # Store the input image path
        self.image = None  # Store the input image as a numpy array
        # Store the masks as a list of dictionaries. Each mask is a dictionary
        # containing segmentation, area, bbox, predicted_iou, point_coords, stability_score, and crop_box
        self.masks = None
        self.objects = None  # Store the mask objects as a numpy array
        # Store the annotations (objects with random color) as a numpy array.
        self.annotations = None
        # show_masks() reads this flag, so it must exist even before generate() is called.
        self.batch = False

        # Store the predicted masks, iou_predictions, and low_res_masks
        self.prediction = None
        self.scores = None
        self.logits = None

        # Build the SAM model
        self.sam = sam_model_registry[self.model_type](checkpoint=self.checkpoint)
        self.sam.to(device=self.device)
        # Use optional arguments for fine-tuning the SAM model
        sam_kwargs = self.sam_kwargs if self.sam_kwargs is not None else {}

        if automatic:
            # Segment the entire image using the automatic mask generator
            self.mask_generator = SamAutomaticMaskGenerator(self.sam, **sam_kwargs)
        else:
            # Segment selected objects using input prompts
            self.predictor = SamPredictor(self.sam, **sam_kwargs)

    def __call__(
        self,
        image,
        foreground=True,
        erosion_kernel=(3, 3),
        mask_multiplier=255,
        **kwargs,
    ):
        """Generate masks for the input tile. This function originates from the segment-anything-eo repository.
            See https://bit.ly/41pwiHw

        Args:
            image (np.ndarray): The input image as a numpy array.
            foreground (bool, optional): Whether to generate the foreground mask. Defaults to True.
                When False the accumulator starts from all ones, so every pixel is marked before borders are removed.
            erosion_kernel (tuple, optional): The erosion kernel for filtering object masks and extract borders. Defaults to (3, 3).
                Set to None to skip border extraction.
            mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1].
                You can use this parameter to scale the mask to a larger range, for example [0, 255]. Defaults to 255.

        Returns:
            np.ndarray: A uint8 mask (objects minus their borders) scaled by mask_multiplier.
        """
        h, w = image.shape[:2]

        masks = self.mask_generator.generate(image)

        # Accumulate coverage as booleans. Summing uint8 masks wraps around once a
        # pixel is hit by a multiple of 256 masks, which would silently drop it.
        covered = np.full((h, w), not foreground, dtype=bool)
        borders = np.zeros((h, w), dtype=bool)

        for m in masks:
            mask = (m["segmentation"] > 0).astype(np.uint8)
            covered |= mask > 0

            # Apply erosion to the mask
            if erosion_kernel is not None:
                mask_erode = cv2.erode(mask, erosion_kernel, iterations=1)
                mask_erode = (mask_erode > 0).astype(np.uint8)
                edge_mask = mask - mask_erode
                borders |= edge_mask > 0

        resulting_mask = covered.astype(np.uint8)
        resulting_borders = borders.astype(np.uint8)
        resulting_mask_with_borders = resulting_mask - resulting_borders
        return resulting_mask_with_borders * mask_multiplier

    def generate(
        self,
        source,
        output=None,
        foreground=True,
        batch=False,
        erosion_kernel=None,
        mask_multiplier=255,
        unique=True,
        **kwargs,
    ):
        """Generate masks for the input image.

        Args:
            source (str | np.ndarray): The path to the input image or the input image as a numpy array.
                Paths starting with "http" are downloaded first.
            output (str, optional): The path to the output image. Defaults to None.
            foreground (bool, optional): Whether to generate the foreground mask. Defaults to True.
            batch (bool, optional): Whether to generate masks for a batch of image tiles. Defaults to False.
                Only honoured when source is a path.
            erosion_kernel (tuple, optional): The erosion kernel for filtering object masks and extract borders.
                Such as (3, 3) or (5, 5). Set to None to disable it. Defaults to None.
            mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1].
                You can use this parameter to scale the mask to a larger range, for example [0, 255]. Defaults to 255.
                The parameter is ignored if unique is True.
            unique (bool, optional): Whether to assign a unique value to each object. Defaults to True.
                The unique value increases from 1 to the number of objects. The larger the number, the larger the object area.

        Raises:
            ValueError: If the path does not exist, the image cannot be read, or source is
                neither a path nor a numpy array.
        """

        if isinstance(source, str):
            if source.startswith("http"):
                source = download_file(source)

            if not os.path.exists(source):
                raise ValueError(f"Input path {source} does not exist.")

            if batch:  # Subdivide the image into tiles and segment each tile
                self.batch = True
                self.source = source
                # In batch mode self.masks holds the output path; show_masks() reads it back.
                self.masks = output
                return tiff_to_tiff(
                    source,
                    output,
                    self,
                    foreground=foreground,
                    erosion_kernel=erosion_kernel,
                    mask_multiplier=mask_multiplier,
                    **kwargs,
                )

            image = cv2.imread(source)
            if image is None:
                # cv2.imread signals failure by returning None instead of raising.
                raise ValueError(f"Unable to read the input image {source}.")
            # cv2 loads BGR; convert to RGB so this matches set_image() and what SAM expects.
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        elif isinstance(source, np.ndarray):
            image = source
            source = None
        else:
            raise ValueError("Input source must be either a path or a numpy array.")

        self.source = source  # Store the input image path
        self.image = image  # Store the input image as a numpy array
        mask_generator = self.mask_generator  # The automatic mask generator
        masks = mask_generator.generate(image)  # Segment the input image
        self.masks = masks  # Store the masks as a list of dictionaries
        self.batch = False

        if output is not None:
            # Save the masks to the output path. The output is either a binary mask or a mask of objects with unique values.
            self.save_masks(
                output, foreground, unique, erosion_kernel, mask_multiplier, **kwargs
            )

    def save_masks(
        self,
        output=None,
        foreground=True,
        unique=True,
        erosion_kernel=None,
        mask_multiplier=255,
        **kwargs,
    ):
        """Save the masks to the output path. The output is either a binary mask or a mask of objects with unique values.

        The result is always stored in SamGeo.objects; it is written to disk only when output is given.

        Args:
            output (str, optional): The path to the output image. Defaults to None, saving the masks to SamGeo.objects.
            foreground (bool, optional): Whether to generate the foreground mask. Defaults to True.
            unique (bool, optional): Whether to assign a unique value to each object. Defaults to True.
            erosion_kernel (tuple, optional): The erosion kernel for filtering object masks and extract borders.
                Such as (3, 3) or (5, 5). Set to None to disable it. Defaults to None.
            mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1].
                You can use this parameter to scale the mask to a larger range, for example [0, 255]. Defaults to 255.
            **kwargs: Forwarded to array_to_image().

        Raises:
            ValueError: If no masks have been generated yet.
        """

        if self.masks is None:
            raise ValueError("No masks found. Please run generate() first.")

        h, w = self.image.shape[:2]
        masks = self.masks

        # Set output image data type based on the number of objects
        if len(masks) < 255:
            dtype = np.uint8
        elif len(masks) < 65535:
            dtype = np.uint16
        else:
            dtype = np.uint32

        # Generate a mask of objects with unique values
        if unique:
            # Sort the masks by area in ascending order
            sorted_masks = sorted(masks, key=(lambda x: x["area"]), reverse=False)

            # Size the output from the first mask; fall back to the image size when
            # SAM found nothing, so an empty result yields an all-zero raster.
            if sorted_masks:
                shape = sorted_masks[0]["segmentation"].shape[:2]
            else:
                shape = (h, w)
            objects = np.zeros(shape, dtype=dtype)

            # Assign a unique value to each object (larger areas are written last)
            for value, ann in enumerate(sorted_masks, start=1):
                objects[ann["segmentation"]] = value

        # Generate a binary mask
        else:
            if foreground:  # Extract foreground objects only
                resulting_mask = np.zeros((h, w), dtype=dtype)
            else:
                resulting_mask = np.ones((h, w), dtype=dtype)
            resulting_borders = np.zeros((h, w), dtype=dtype)

            for m in masks:
                mask = (m["segmentation"] > 0).astype(dtype)
                resulting_mask += mask

                # Apply erosion to the mask
                if erosion_kernel is not None:
                    mask_erode = cv2.erode(mask, erosion_kernel, iterations=1)
                    mask_erode = (mask_erode > 0).astype(dtype)
                    edge_mask = mask - mask_erode
                    resulting_borders += edge_mask

            resulting_mask = (resulting_mask > 0).astype(dtype)
            resulting_borders = (resulting_borders > 0).astype(dtype)
            objects = resulting_mask - resulting_borders
            objects = objects * mask_multiplier

        objects = objects.astype(dtype)
        self.objects = objects

        if output is not None:  # Save the output image
            array_to_image(self.objects, output, self.source, **kwargs)

    def show_masks(
        self, figsize=(12, 10), cmap="binary_r", axis="off", foreground=True, **kwargs
    ):
        """Show the binary mask or the mask of objects with unique values.

        Args:
            figsize (tuple, optional): The figure size. Defaults to (12, 10).
            cmap (str, optional): The colormap. Defaults to "binary_r".
            axis (str, optional): Whether to show the axis. Defaults to "off".
            foreground (bool, optional): Whether to show the foreground mask only. Defaults to True.
            **kwargs: Other arguments for save_masks().
        """

        import matplotlib.pyplot as plt

        if self.batch:
            # In batch mode self.masks is the path of the raster written by generate().
            self.objects = cv2.imread(self.masks)
        else:
            if self.objects is None:
                self.save_masks(foreground=foreground, **kwargs)

        plt.figure(figsize=figsize)
        plt.imshow(self.objects, cmap=cmap)
        plt.axis(axis)
        plt.show()

    def show_anns(
        self,
        figsize=(12, 10),
        axis="off",
        alpha=0.35,
        output=None,
        blend=True,
        **kwargs,
    ):
        """Show the annotations (objects with random color) on the input image.

        The colored annotation layer is also stored in SamGeo.annotations as a uint8 RGB array.

        Args:
            figsize (tuple, optional): The figure size. Defaults to (12, 10).
            axis (str, optional): Whether to show the axis. Defaults to "off".
            alpha (float, optional): The alpha value for the annotations. Defaults to 0.35.
            output (str, optional): The path to the output image. Defaults to None.
            blend (bool, optional): Whether to show the input image. Defaults to True.
        """

        import matplotlib.pyplot as plt

        anns = self.masks

        if self.image is None:
            print("Please run generate() first.")
            return

        if anns is None or len(anns) == 0:
            return

        plt.figure(figsize=figsize)
        plt.imshow(self.image)

        # Draw the largest objects first so smaller ones stay visible on top.
        sorted_anns = sorted(anns, key=(lambda x: x["area"]), reverse=True)

        ax = plt.gca()
        ax.set_autoscale_on(False)

        # RGBA canvas, fully transparent until an object is painted onto it.
        img = np.ones(
            (
                sorted_anns[0]["segmentation"].shape[0],
                sorted_anns[0]["segmentation"].shape[1],
                4,
            )
        )
        img[:, :, 3] = 0
        for ann in sorted_anns:
            m = ann["segmentation"]
            color_mask = np.concatenate([np.random.random(3), [alpha]])
            img[m] = color_mask
        ax.imshow(img)

        # NOTE(review): these defaults are set but kwargs is not forwarded anywhere below.
        if "dpi" not in kwargs:
            kwargs["dpi"] = 100

        if "bbox_inches" not in kwargs:
            kwargs["bbox_inches"] = "tight"

        plt.axis(axis)

        self.annotations = (img[:, :, 0:3] * 255).astype(np.uint8)

        if output is not None:
            if blend:
                array = blend_images(
                    self.annotations, self.image, alpha=alpha, show=False
                )
            else:
                array = self.annotations
            array_to_image(array, output, self.source)

    def set_image(self, image, image_format="RGB"):
        """Set the input image for prompt-based prediction.

        Args:
            image (str | np.ndarray): The path or URL of the input image, or the image as a numpy array.
                Images loaded from a path are converted from BGR to RGB.
            image_format (str, optional): The image format, can be RGB or BGR. Defaults to "RGB".
                It is forwarded unchanged to the predictor.

        Raises:
            ValueError: If the path does not exist or image is neither a path nor a numpy array.
        """
        if isinstance(image, str):
            if image.startswith("http"):
                image = download_file(image)

            if not os.path.exists(image):
                raise ValueError(f"Input path {image} does not exist.")

            self.source = image

            image = cv2.imread(image)
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            self.image = image
        elif isinstance(image, np.ndarray):
            # Keep the array as well: show_canvas(), tensor_to_numpy() and box
            # prompts with a CRS all read self.image.
            self.image = image
        else:
            raise ValueError("Input image must be either a path or a numpy array.")

        self.predictor.set_image(image, image_format=image_format)

    def save_prediction(
        self,
        output,
        index=None,
        mask_multiplier=255,
        dtype=np.float32,
        vector=None,
        simplify_tolerance=None,
        **kwargs,
    ):
        """Save the predicted mask to the output path.

        Args:
            output (str): The path to the output image.
            index (int, optional): The index of the mask to save. Defaults to None,
                which will save the mask with the highest score.
            mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1].
            dtype (np.dtype, optional): The data type of the output image. Defaults to np.float32.
            vector (str, optional): The path to the output vector file. Defaults to None.
            simplify_tolerance (float, optional): The maximum allowed geometry displacement.
                The higher this value, the smaller the number of vertices in the resulting geometry.
            **kwargs: Forwarded to array_to_image().

        Raises:
            ValueError: If predict() has not produced any scores yet.
        """
        if self.scores is None:
            raise ValueError("No predictions found. Please run predict() first.")

        if index is None:
            index = self.scores.argmax(axis=0)

        array = self.masks[index] * mask_multiplier
        self.prediction = array
        array_to_image(array, output, self.source, dtype=dtype, **kwargs)

        if vector is not None:
            raster_to_vector(output, vector, simplify_tolerance=simplify_tolerance)

    def predict(
        self,
        point_coords=None,
        point_labels=None,
        boxes=None,
        point_crs=None,
        mask_input=None,
        multimask_output=True,
        return_logits=False,
        output=None,
        index=None,
        mask_multiplier=255,
        dtype="float32",
        return_results=False,
        **kwargs,
    ):
        """Predict masks for the given input prompts, using the currently set image.

        Args:
            point_coords (str | dict | list | np.ndarray, optional): A Nx2 array of point prompts to the
                model. Each point is in (X,Y) in pixels. It can be a path to a vector file, a GeoJSON
                dictionary, a list of coordinates [lon, lat], or a numpy array. Defaults to None.
                Points collected with show_canvas() take precedence over this argument.
            point_labels (list | int | np.ndarray, optional): A length N array of labels for the
                point prompts. 1 indicates a foreground point and 0 indicates a background point.
                A single int or a one-element list is repeated for every point.
            boxes (str | dict | list | np.ndarray, optional): A length 4 array given a box prompt to the
                model, in XYXY format, or a list of such boxes. It can also be a path to a vector
                file or a GeoJSON dictionary, in which case the feature bounds are used.
            point_crs (str, optional): The coordinate reference system (CRS) of the point prompts.
            mask_input (np.ndarray, optional): A low resolution mask input to the model, typically
                coming from a previous prediction iteration. Has form 1xHxW, where for SAM, H=W=256.
            multimask_output (bool, optional): If true, the model will return three masks.
                For ambiguous input prompts (such as a single click), this will often
                produce better masks than a single prediction. If only a single
                mask is needed, the model's predicted quality score can be used
                to select the best mask. For non-ambiguous prompts, such as multiple
                input prompts, multimask_output=False can give better results.
                Not used when a list of boxes is given.
            return_logits (bool, optional): If true, returns un-thresholded masks logits
                instead of a binary mask.
            output (str, optional): The path to the output image. Defaults to None.
            index (index, optional): The index of the mask to save. Defaults to None,
                which will save the mask with the highest score.
            mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1].
            dtype (np.dtype, optional): The data type of the output image. Defaults to np.float32.
            return_results (bool, optional): Whether to return the predicted masks, scores, and logits. Defaults to False.
            **kwargs: Forwarded to the function that writes the output image.

        Returns:
            tuple | None: (masks, scores, logits) when return_results is True, otherwise None.

        Raises:
            ValueError: If point_labels and point_coords differ in length.
        """

        # Boxes given as a vector file or GeoJSON dict are reduced to feature bounds.
        if isinstance(boxes, str):
            gdf = gpd.read_file(boxes)
            if gdf.crs is not None:
                gdf = gdf.to_crs("epsg:4326")
            boxes = gdf.geometry.bounds.values.tolist()
        elif isinstance(boxes, dict):
            import json

            geojson = json.dumps(boxes)
            gdf = gpd.read_file(geojson, driver="GeoJSON")
            boxes = gdf.geometry.bounds.values.tolist()

        if isinstance(point_coords, str):
            point_coords = vector_to_geojson(point_coords)

        if isinstance(point_coords, dict):
            point_coords = geojson_to_coords(point_coords)

        # Points collected interactively by show_canvas() override the arguments.
        if hasattr(self, "point_coords"):
            point_coords = self.point_coords

        if hasattr(self, "point_labels"):
            point_labels = self.point_labels

        if (point_crs is not None) and (point_coords is not None):
            point_coords = coords_to_xy(self.source, point_coords, point_crs)

        if isinstance(point_coords, list):
            point_coords = np.array(point_coords)

        if point_coords is not None:
            if point_labels is None:
                point_labels = [1] * len(point_coords)
            elif isinstance(point_labels, int):
                point_labels = [point_labels] * len(point_coords)

        if isinstance(point_labels, list):
            if len(point_labels) != len(point_coords):
                if len(point_labels) == 1:
                    point_labels = point_labels * len(point_coords)
                else:
                    raise ValueError(
                        "The length of point_labels must be equal to the length of point_coords."
                    )
            point_labels = np.array(point_labels)

        predictor = self.predictor

        input_boxes = None
        if isinstance(boxes, list) and (point_crs is not None):
            coords = bbox_to_xy(self.source, boxes, point_crs)
            input_boxes = np.array(coords)
            if isinstance(coords[0], int):
                input_boxes = input_boxes[None, :]
            else:
                input_boxes = torch.tensor(input_boxes, device=self.device)
                input_boxes = predictor.transform.apply_boxes_torch(
                    input_boxes, self.image.shape[:2]
                )
        elif isinstance(boxes, list) and (point_crs is None):
            input_boxes = np.array(boxes)
            if isinstance(boxes[0], int):
                input_boxes = input_boxes[None, :]

        self.boxes = input_boxes

        # A list of boxes goes through the batched torch API; everything else
        # (no box, or one flat box) uses the single-prompt API.
        batched = boxes is not None and isinstance(boxes[0], list)

        if not batched:
            masks, scores, logits = predictor.predict(
                point_coords,
                point_labels,
                input_boxes,
                mask_input,
                multimask_output,
                return_logits,
            )
        else:
            # multimask_output stays True here: tensor_to_numpy() picks mask index 1 by default.
            masks, scores, logits = predictor.predict_torch(
                point_coords=point_coords,
                point_labels=point_labels,
                boxes=input_boxes,
                multimask_output=True,
            )

        self.masks = masks
        self.scores = scores
        self.logits = logits

        if output is not None:
            if not batched:
                self.save_prediction(output, index, mask_multiplier, dtype, **kwargs)
            else:
                self.tensor_to_numpy(
                    index, output, mask_multiplier, dtype, save_args=kwargs
                )

        if return_results:
            return masks, scores, logits

    def tensor_to_numpy(
        self, index=None, output=None, mask_multiplier=255, dtype="uint8", save_args=None
    ):
        """Merge the masks predicted for a batch of boxes into a single binary numpy array.

        Args:
            index (index, optional): The index of the mask to use for every box. Defaults to None,
                which selects index 1.
            output (str, optional): The path to the output image. Defaults to None.
            mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1].
            dtype (np.dtype, optional): The data type of the output image. Defaults to np.uint8.
            save_args (dict, optional): Optional arguments for saving the output image. Defaults to None,
                which is treated as an empty dict.

        Returns:
            np.ndarray | None: The merged mask when output is None. Nothing is returned when the
                mask is written to output or when there are no boxes.
        """
        if save_args is None:
            save_args = {}

        boxes = self.boxes
        masks = self.masks

        image_pil = self.image
        image_np = np.array(image_pil)

        if index is None:
            index = 1

        masks = masks[:, index, :, :]
        masks = masks.squeeze(1)

        if boxes is None or (len(boxes) == 0):  # No "object" instances found
            print("No objects found in the image.")
            return

        # Only "covered or not" survives into the result, so accumulate booleans.
        # Summing per-box ids in a narrow dtype can wrap to 0 (e.g. id 256 as uint8).
        covered = np.zeros(image_np.shape[:2], dtype=bool)

        for _box, mask in zip(boxes, masks):
            # Convert tensor to numpy array if necessary
            if isinstance(mask, torch.Tensor):
                # If mask is on GPU, use .cpu() before .numpy()
                mask = mask.cpu().numpy().astype(dtype)
            covered |= np.asarray(mask) > 0

        # Scale the binary mask, e.g. to [0, 255]
        mask_overlay = covered * mask_multiplier

        if output is not None:
            array_to_image(mask_overlay, output, self.source, dtype=dtype, **save_args)
        else:
            return mask_overlay

    def show_map(self, basemap="SATELLITE", repeat_mode=True, out_dir=None, **kwargs):
        """Show the interactive map.

        Args:
            basemap (str, optional): The basemap. It can be one of the following: SATELLITE, ROADMAP, TERRAIN, HYBRID.
            repeat_mode (bool, optional): Whether to use the repeat mode for draw control. Defaults to True.
            out_dir (str, optional): The path to the output directory. Defaults to None.
            **kwargs: Forwarded to sam_map_gui().

        Returns:
            leafmap.Map: The map object.
        """
        return sam_map_gui(
            self, basemap=basemap, repeat_mode=repeat_mode, out_dir=out_dir, **kwargs
        )

    def show_canvas(self, fg_color=(0, 255, 0), bg_color=(0, 0, 255), radius=5):
        """Show a canvas to collect foreground and background points on the current image.

        The collected points are stored in fg_points and bg_points, and the combined
        prompts in point_coords and point_labels (1 for foreground, 0 for background),
        which predict() uses in place of its own arguments.

        Args:
            fg_color (tuple, optional): The color for the foreground points. Defaults to (0, 255, 0).
            bg_color (tuple, optional): The color for the background points. Defaults to (0, 0, 255).
            radius (int, optional): The radius of the points. Defaults to 5.

        Raises:
            ValueError: If no image has been set.
        """

        if self.image is None:
            raise ValueError("Please run set_image() first.")

        image = self.image
        fg_points, bg_points = show_canvas(image, fg_color, bg_color, radius)
        self.fg_points = fg_points
        self.bg_points = bg_points
        point_coords = fg_points + bg_points
        point_labels = [1] * len(fg_points) + [0] * len(bg_points)
        self.point_coords = point_coords
        self.point_labels = point_labels

    def clear_cuda_cache(self):
        """Clear the CUDA cache."""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def image_to_image(self, image, **kwargs):
        """Call the module-level image_to_image() with this object as the model.

        Args:
            image: The input passed through unchanged to image_to_image().
            **kwargs: Forwarded to image_to_image().

        Returns:
            Whatever image_to_image() returns.
        """
        return image_to_image(image, self, **kwargs)

    def download_tms_as_tiff(self, source, pt1, pt2, zoom, dist):
        """Call draw_tile() for the area spanned by two points.

        Args:
            source: The tile source passed to draw_tile().
            pt1 (sequence): The first point; its two elements are passed separately to draw_tile().
            pt2 (sequence): The second point; its two elements are passed separately to draw_tile().
            zoom: The zoom level passed to draw_tile().
            dist: Passed through to draw_tile().

        Returns:
            The image returned by draw_tile().
        """
        image = draw_tile(source, pt1[0], pt1[1], pt2[0], pt2[1], zoom, dist)
        return image

    def raster_to_vector(self, image, output, simplify_tolerance=None, **kwargs):
        """Save the result to a vector file.

        Args:
            image (str): The path to the image file.
            output (str): The path to the vector file.
            simplify_tolerance (float, optional): The maximum allowed geometry displacement.
                The higher this value, the smaller the number of vertices in the resulting geometry.
        """

        raster_to_vector(image, output, simplify_tolerance=simplify_tolerance, **kwargs)

    def tiff_to_vector(self, tiff_path, output, simplify_tolerance=None, **kwargs):
        """Convert a tiff file to a vector file.

        Args:
            tiff_path (str): The path to the tiff file.
            output (str): The path to the vector file.
            simplify_tolerance (float, optional): The maximum allowed geometry displacement.
                The higher this value, the smaller the number of vertices in the resulting geometry.
        """

        raster_to_vector(
            tiff_path, output, simplify_tolerance=simplify_tolerance, **kwargs
        )

    def tiff_to_gpkg(self, tiff_path, output, simplify_tolerance=None, **kwargs):
        """Convert a tiff file to a gpkg file.

        Args:
            tiff_path (str): The path to the tiff file.
            output (str): The path to the gpkg file.
            simplify_tolerance (float, optional): The maximum allowed geometry displacement.
                The higher this value, the smaller the number of vertices in the resulting geometry.
        """

        raster_to_gpkg(
            tiff_path, output, simplify_tolerance=simplify_tolerance, **kwargs
        )

    def tiff_to_shp(self, tiff_path, output, simplify_tolerance=None, **kwargs):
        """Convert a tiff file to a shapefile.

        Args:
            tiff_path (str): The path to the tiff file.
            output (str): The path to the shapefile.
            simplify_tolerance (float, optional): The maximum allowed geometry displacement.
                The higher this value, the smaller the number of vertices in the resulting geometry.
        """

        raster_to_shp(
            tiff_path, output, simplify_tolerance=simplify_tolerance, **kwargs
        )

    def tiff_to_geojson(self, tiff_path, output, simplify_tolerance=None, **kwargs):
        """Convert a tiff file to a GeoJSON file.

        Args:
            tiff_path (str): The path to the tiff file.
            output (str): The path to the GeoJSON file.
            simplify_tolerance (float, optional): The maximum allowed geometry displacement.
                The higher this value, the smaller the number of vertices in the resulting geometry.
        """

        raster_to_geojson(
            tiff_path, output, simplify_tolerance=simplify_tolerance, **kwargs
        )
+
__call__(self, image, foreground=True, erosion_kernel=(3, 3), mask_multiplier=255, **kwargs)
+
+
+ special
+
+
+¶Generate masks for the input tile. This function originates from the segment-anything-eo repository. + See https://bit.ly/41pwiHw
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
image |
+ np.ndarray |
+ The input image as a numpy array. |
+ required | +
foreground |
+ bool |
+ Whether to generate the foreground mask. Defaults to True. |
+ True |
+
erosion_kernel |
+ tuple |
+ The erosion kernel for filtering object masks and extract borders. Defaults to (3, 3). |
+ (3, 3) |
+
mask_multiplier |
+ int |
+ The mask multiplier for the output mask, which is usually a binary mask [0, 1]. +You can use this parameter to scale the mask to a larger range, for example [0, 255]. Defaults to 255. |
+ 255 |
+
samgeo/samgeo.py
def __call__(
    self,
    image,
    foreground=True,
    erosion_kernel=(3, 3),
    mask_multiplier=255,
    **kwargs,
):
    """Generate masks for the input tile. This function originates from the segment-anything-eo repository.
        See https://bit.ly/41pwiHw

    Args:
        image (np.ndarray): The input image as a numpy array.
        foreground (bool, optional): Whether to generate the foreground mask. Defaults to True.
            When False the accumulator starts from all ones, so every pixel is marked before borders are removed.
        erosion_kernel (tuple, optional): The erosion kernel for filtering object masks and extract borders. Defaults to (3, 3).
            Set to None to skip border extraction.
        mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1].
            You can use this parameter to scale the mask to a larger range, for example [0, 255]. Defaults to 255.

    Returns:
        np.ndarray: A uint8 mask (objects minus their borders) scaled by mask_multiplier.
    """
    h, w = image.shape[:2]

    masks = self.mask_generator.generate(image)

    # Accumulate coverage as booleans. Summing uint8 masks wraps around once a
    # pixel is hit by a multiple of 256 masks, which would silently drop it.
    covered = np.full((h, w), not foreground, dtype=bool)
    borders = np.zeros((h, w), dtype=bool)

    for m in masks:
        mask = (m["segmentation"] > 0).astype(np.uint8)
        covered |= mask > 0

        # Apply erosion to the mask
        if erosion_kernel is not None:
            mask_erode = cv2.erode(mask, erosion_kernel, iterations=1)
            mask_erode = (mask_erode > 0).astype(np.uint8)
            edge_mask = mask - mask_erode
            borders |= edge_mask > 0

    resulting_mask = covered.astype(np.uint8)
    resulting_borders = borders.astype(np.uint8)
    resulting_mask_with_borders = resulting_mask - resulting_borders
    return resulting_mask_with_borders * mask_multiplier
+
__init__(self, model_type='vit_h', automatic=True, device=None, checkpoint_dir=None, sam_kwargs=None, **kwargs)
+
+
+ special
+
+
+¶Initialize the class.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
model_type |
+ str |
+ The model type. It can be one of the following: vit_h, vit_l, vit_b. +Defaults to 'vit_h'. See https://bit.ly/3VrpxUh for more details. |
+ 'vit_h' |
+
automatic |
+ bool |
+ Whether to use the automatic mask generator or input prompts. Defaults to True. +The automatic mask generator will segment the entire image, while the input prompts will segment selected objects. |
+ True |
+
device |
+ str |
+ The device to use. It can be one of the following: cpu, cuda. +Defaults to None, which will use cuda if available. |
+ None |
+
checkpoint_dir |
+ str |
+ The path to the model checkpoint. It can be one of the following: +sam_vit_h_4b8939.pth, sam_vit_l_0b3195.pth, sam_vit_b_01ec64.pth. +Defaults to None. See https://bit.ly/3VrpxUh for more details. |
+ None |
+
sam_kwargs |
+ dict |
+ Optional arguments for fine-tuning the SAM model. Defaults to None. +The available arguments with default values are listed below. See https://bit.ly/410RV0v for more details. +points_per_side: Optional[int] = 32, +points_per_batch: int = 64, +pred_iou_thresh: float = 0.88, +stability_score_thresh: float = 0.95, +stability_score_offset: float = 1.0, +box_nms_thresh: float = 0.7, +crop_n_layers: int = 0, +crop_nms_thresh: float = 0.7, +crop_overlap_ratio: float = 512 / 1500, +crop_n_points_downscale_factor: int = 1, +point_grids: Optional[List[np.ndarray]] = None, +min_mask_region_area: int = 0, +output_mode: str = "binary_mask", |
+ None |
+
samgeo/samgeo.py
def __init__(
    self,
    model_type="vit_h",
    automatic=True,
    device=None,
    checkpoint_dir=None,
    sam_kwargs=None,
    **kwargs,
):
    """Load a SAM model and attach a mask generator or a prompt predictor.

    Args:
        model_type (str, optional): The model type. It can be one of the following: vit_h, vit_l, vit_b.
            Defaults to 'vit_h'. See https://bit.ly/3VrpxUh for more details.
        automatic (bool, optional): When True, ``self.mask_generator`` is created and the whole
            image is segmented; otherwise ``self.predictor`` is created for prompt-based
            segmentation. Only one of the two attributes is set. Defaults to True.
        device (str, optional): The device to use, e.g. cpu or cuda.
            Defaults to None, which selects cuda when it is available and cpu otherwise.
        checkpoint_dir (str, optional): Passed to download_checkpoint() together with the model type.
            The documented checkpoint files are sam_vit_h_4b8939.pth, sam_vit_l_0b3195.pth and
            sam_vit_b_01ec64.pth. Defaults to None. See https://bit.ly/3VrpxUh for more details.
        sam_kwargs (dict, optional): Keyword arguments forwarded to SamAutomaticMaskGenerator
            (or SamPredictor when automatic is False). Defaults to None.
            The available arguments with default values are listed below. See https://bit.ly/410RV0v for more details.

            points_per_side: Optional[int] = 32,
            points_per_batch: int = 64,
            pred_iou_thresh: float = 0.88,
            stability_score_thresh: float = 0.95,
            stability_score_offset: float = 1.0,
            box_nms_thresh: float = 0.7,
            crop_n_layers: int = 0,
            crop_nms_thresh: float = 0.7,
            crop_overlap_ratio: float = 512 / 1500,
            crop_n_points_downscale_factor: int = 1,
            point_grids: Optional[List[np.ndarray]] = None,
            min_mask_region_area: int = 0,
            output_mode: str = "binary_mask",
        **kwargs: Only the "checkpoint" key (a path to a checkpoint file) is read here.
            If that path does not exist the checkpoint is downloaded instead.

    """
    use_hq = False  # This class always uses the standard SAM weights, not HQ-SAM

    # Resolve the checkpoint: a user-supplied path wins only if it exists on disk.
    if "checkpoint" not in kwargs:
        checkpoint = download_checkpoint(model_type, checkpoint_dir, use_hq)
    else:
        checkpoint = kwargs["checkpoint"]
        if not os.path.exists(checkpoint):
            checkpoint = download_checkpoint(model_type, checkpoint_dir, use_hq)
        kwargs.pop("checkpoint")

    # Pick a device automatically; the CUDA cache is only emptied on this auto path.
    if device is None:
        if torch.cuda.is_available():
            device = "cuda"
            torch.cuda.empty_cache()
        else:
            device = "cpu"

    # Model configuration
    self.checkpoint = checkpoint
    self.model_type = model_type
    self.device = device
    self.sam_kwargs = sam_kwargs  # Extra arguments for the generator/predictor

    # Input image state
    self.source = None  # Path of the input image, if it came from a file
    self.image = None  # Input image as a numpy array

    # Segmentation results. ``masks`` is documented as a list of dictionaries holding
    # segmentation, area, bbox, predicted_iou, point_coords, stability_score and crop_box.
    self.masks = None
    self.objects = None  # Mask objects as a numpy array
    self.annotations = None  # Objects rendered with random colors, as a numpy array

    # Prompt-based prediction results: masks, iou predictions and low-res masks
    self.prediction = None
    self.scores = None
    self.logits = None

    # Build the SAM model and move it to the selected device
    model_builder = sam_model_registry[self.model_type]
    self.sam = model_builder(checkpoint=self.checkpoint)
    self.sam.to(device=self.device)

    extra_options = {} if self.sam_kwargs is None else self.sam_kwargs

    if not automatic:
        # Segment selected objects using input prompts
        self.predictor = SamPredictor(self.sam, **extra_options)
    else:
        # Segment the entire image using the automatic mask generator
        self.mask_generator = SamAutomaticMaskGenerator(self.sam, **extra_options)
+
clear_cuda_cache(self)
+
+
+¶Clear the CUDA cache.
+ +samgeo/samgeo.py
def clear_cuda_cache(self):
    """Call torch.cuda.empty_cache() when CUDA is available; otherwise do nothing."""
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
+
generate(self, source, output=None, foreground=True, batch=False, erosion_kernel=None, mask_multiplier=255, unique=True, **kwargs)
+
+
+¶Generate masks for the input image.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
source |
+ str | np.ndarray |
+ The path to the input image or the input image as a numpy array. |
+ required | +
output |
+ str |
+ The path to the output image. Defaults to None. |
+ None |
+
foreground |
+ bool |
+ Whether to generate the foreground mask. Defaults to True. |
+ True |
+
batch |
+ bool |
+ Whether to generate masks for a batch of image tiles. Defaults to False. |
+ False |
+
erosion_kernel |
+ tuple |
+ The erosion kernel for filtering object masks and extract borders. +Such as (3, 3) or (5, 5). Set to None to disable it. Defaults to None. |
+ None |
+
mask_multiplier |
+ int |
+ The mask multiplier for the output mask, which is usually a binary mask [0, 1]. +You can use this parameter to scale the mask to a larger range, for example [0, 255]. Defaults to 255. +The parameter is ignored if unique is True. |
+ 255 |
+
unique |
+ bool |
+ Whether to assign a unique value to each object. Defaults to True. +The unique value increases from 1 to the number of objects. The larger the number, the larger the object area. |
+ True |
+
samgeo/samgeo.py
def generate(
    self,
    source,
    output=None,
    foreground=True,
    batch=False,
    erosion_kernel=None,
    mask_multiplier=255,
    unique=True,
    **kwargs,
):
    """Generate masks for the input image.

    Args:
        source (str | np.ndarray): The path (or http URL) to the input image, or the input
            image as a numpy array. Arrays are passed to the mask generator unchanged.
        output (str, optional): The path to the output image. Defaults to None.
        foreground (bool, optional): Whether to generate the foreground mask. Defaults to True.
        batch (bool, optional): Whether to generate masks for a batch of image tiles. Defaults to False.
            Only honoured when source is a path; the call is then delegated to tiff_to_tiff()
            and its result is returned.
        erosion_kernel (tuple, optional): The erosion kernel for filtering object masks and extract borders.
            Such as (3, 3) or (5, 5). Set to None to disable it. Defaults to None.
        mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1].
            You can use this parameter to scale the mask to a larger range, for example [0, 255]. Defaults to 255.
            The parameter is ignored if unique is True.
        unique (bool, optional): Whether to assign a unique value to each object. Defaults to True.
            The unique value increases from 1 to the number of objects. The larger the number, the larger the object area.
        **kwargs: Forwarded to tiff_to_tiff() in batch mode, otherwise to save_masks().

    Raises:
        ValueError: If the path does not exist, the file cannot be decoded as an image,
            or source is neither a string nor a numpy array.

    """

    if isinstance(source, str):
        if source.startswith("http"):
            source = download_file(source)

        if not os.path.exists(source):
            raise ValueError(f"Input path {source} does not exist.")

        if batch:  # Subdivide the image into tiles and segment each tile
            self.batch = True
            self.source = source
            self.masks = output  # In batch mode ``masks`` holds the output path
            return tiff_to_tiff(
                source,
                output,
                self,
                foreground=foreground,
                erosion_kernel=erosion_kernel,
                mask_multiplier=mask_multiplier,
                **kwargs,
            )

        image = cv2.imread(source)
        if image is None:
            # cv2.imread signals an unreadable/unsupported file by returning None
            raise ValueError(f"Unable to read the input image {source}.")
        # OpenCV decodes to BGR; convert to RGB so the stored image and the model input
        # use the same channel order as set_image() does for file inputs.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    elif isinstance(source, np.ndarray):
        image = source
        source = None
    else:
        raise ValueError("Input source must be either a path or a numpy array.")

    self.source = source  # Store the input image path
    self.image = image  # Store the input image as a numpy array
    mask_generator = self.mask_generator  # The automatic mask generator
    masks = mask_generator.generate(image)  # Segment the input image
    self.masks = masks  # Store the masks as a list of dictionaries
    self.batch = False

    if output is not None:
        # Save the masks to the output path. The output is either a binary mask or a mask of objects with unique values.
        self.save_masks(
            output, foreground, unique, erosion_kernel, mask_multiplier, **kwargs
        )
+
predict(self, point_coords=None, point_labels=None, boxes=None, point_crs=None, mask_input=None, multimask_output=True, return_logits=False, output=None, index=None, mask_multiplier=255, dtype='float32', return_results=False, **kwargs)
+
+
+¶Predict masks for the given input prompts, using the currently set image.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
point_coords |
+ str | dict | list | np.ndarray |
+ A Nx2 array of point prompts to the +model. Each point is in (X,Y) in pixels. It can be a path to a vector file, a GeoJSON +dictionary, a list of coordinates [lon, lat], or a numpy array. Defaults to None. |
+ None |
+
point_labels |
+ list | int | np.ndarray |
+ A length N array of labels for the +point prompts. 1 indicates a foreground point and 0 indicates a background point. |
+ None |
+
point_crs |
+ str |
+ The coordinate reference system (CRS) of the point prompts. |
+ None |
+
boxes |
+ list | np.ndarray |
+ A length 4 array given a box prompt to the +model, in XYXY format. |
+ None |
+
mask_input |
+ np.ndarray |
+ A low resolution mask input to the model, typically +coming from a previous prediction iteration. Has form 1xHxW, where for SAM, H=W=256. +multimask_output (bool, optional): If true, the model will return three masks. +For ambiguous input prompts (such as a single click), this will often +produce better masks than a single prediction. If only a single +mask is needed, the model's predicted quality score can be used +to select the best mask. For non-ambiguous prompts, such as multiple +input prompts, multimask_output=False can give better results. |
+ None |
+
return_logits |
+ bool |
+ If true, returns un-thresholded masks logits +instead of a binary mask. |
+ False |
+
output |
+ str |
+ The path to the output image. Defaults to None. |
+ None |
+
index |
+ int |
+ The index of the mask to save. Defaults to None, +which will save the mask with the highest score. |
+ None |
+
mask_multiplier |
+ int |
+ The mask multiplier for the output mask, which is usually a binary mask [0, 1]. |
+ 255 |
+
dtype |
+ np.dtype |
+ The data type of the output image. Defaults to np.float32. |
+ 'float32' |
+
return_results |
+ bool |
+ Whether to return the predicted masks, scores, and logits. Defaults to False. |
+ False |
+
samgeo/samgeo.py
def predict(
    self,
    point_coords=None,
    point_labels=None,
    boxes=None,
    point_crs=None,
    mask_input=None,
    multimask_output=True,
    return_logits=False,
    output=None,
    index=None,
    mask_multiplier=255,
    dtype="float32",
    return_results=False,
    **kwargs,
):
    """Predict masks for the given input prompts, using the currently set image.

    Args:
        point_coords (str | dict | list | np.ndarray, optional): A Nx2 array of point prompts to the
            model. Each point is in (X,Y) in pixels. It can be a path to a vector file, a GeoJSON
            dictionary, a list of coordinates [lon, lat], or a numpy array. Defaults to None.
            If the instance has a ``point_coords`` attribute (set by show_canvas()), that
            attribute replaces this argument.
        point_labels (list | int | np.ndarray, optional): A length N array of labels for the
            point prompts. 1 indicates a foreground point and 0 indicates a background point.
            A single int (or one-element list) is repeated for every point; None means all
            points are foreground. If the instance has a ``point_labels`` attribute, that
            attribute replaces this argument.
        boxes (str | dict | list | np.ndarray, optional): A length 4 array given a box prompt to the
            model, in XYXY format. A path to a vector file or a GeoJSON dictionary is reduced
            to the bounds of its geometries. A list of lists is treated as a batch of boxes.
        point_crs (str, optional): The coordinate reference system (CRS) of the point prompts.
            When given, points and boxes are converted to pixel coordinates of ``self.source``.
        mask_input (np.ndarray, optional): A low resolution mask input to the model, typically
            coming from a previous prediction iteration. Has form 1xHxW, where for SAM, H=W=256.
        multimask_output (bool, optional): If true, the model will return three masks.
            For ambiguous input prompts (such as a single click), this will often
            produce better masks than a single prediction. If only a single
            mask is needed, the model's predicted quality score can be used
            to select the best mask. For non-ambiguous prompts, such as multiple
            input prompts, multimask_output=False can give better results.
            Not used for a batch of boxes, which always requests multiple masks.
        return_logits (bool, optional): If true, returns un-thresholded masks logits
            instead of a binary mask. Not used for a batch of boxes.
        output (str, optional): The path to the output image. Defaults to None.
        index (int, optional): The index of the mask to save. Defaults to None; the choice
            is then made by save_prediction() or tensor_to_numpy().
        mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1].
        dtype (np.dtype, optional): The data type of the output image. Defaults to "float32".
        return_results (bool, optional): Whether to return the predicted masks, scores, and logits. Defaults to False.
        **kwargs: Extra arguments used when saving the output image.

    Returns:
        tuple | None: (masks, scores, logits) when return_results is True, otherwise None.

    Raises:
        ValueError: If point_labels is a list whose length is neither 1 nor len(point_coords).

    """

    # Reduce vector inputs (file path or GeoJSON dict) to a list of bounding boxes
    if isinstance(boxes, str):
        gdf = gpd.read_file(boxes)
        if gdf.crs is not None:
            gdf = gdf.to_crs("epsg:4326")
        boxes = gdf.geometry.bounds.values.tolist()
    elif isinstance(boxes, dict):
        import json

        geojson = json.dumps(boxes)
        gdf = gpd.read_file(geojson, driver="GeoJSON")
        boxes = gdf.geometry.bounds.values.tolist()

    if isinstance(point_coords, str):
        point_coords = vector_to_geojson(point_coords)

    if isinstance(point_coords, dict):
        point_coords = geojson_to_coords(point_coords)

    # Points collected interactively (see show_canvas) take precedence over arguments
    if hasattr(self, "point_coords"):
        point_coords = self.point_coords

    if hasattr(self, "point_labels"):
        point_labels = self.point_labels

    if (point_crs is not None) and (point_coords is not None):
        point_coords = coords_to_xy(self.source, point_coords, point_crs)

    if isinstance(point_coords, list):
        point_coords = np.array(point_coords)

    if point_coords is not None:
        # Normalise labels to one label per point
        if point_labels is None:
            point_labels = [1] * len(point_coords)
        elif isinstance(point_labels, int):
            point_labels = [point_labels] * len(point_coords)

        if isinstance(point_labels, list):
            if len(point_labels) != len(point_coords):
                if len(point_labels) == 1:
                    point_labels = point_labels * len(point_coords)
                else:
                    raise ValueError(
                        "The length of point_labels must be equal to the length of point_coords."
                    )
            point_labels = np.array(point_labels)

    predictor = self.predictor

    input_boxes = None
    if isinstance(boxes, list) and (point_crs is not None):
        coords = bbox_to_xy(self.source, boxes, point_crs)
        input_boxes = np.array(coords)
        if isinstance(coords[0], int):
            # A single flat box: add the leading batch dimension
            input_boxes = input_boxes[None, :]
        else:
            input_boxes = torch.tensor(input_boxes, device=self.device)
            input_boxes = predictor.transform.apply_boxes_torch(
                input_boxes, self.image.shape[:2]
            )
    elif isinstance(boxes, list) and (point_crs is None):
        # NOTE(review): in this branch a batch of boxes stays a numpy array and is not
        # passed through predictor.transform - confirm predict_torch accepts that.
        input_boxes = np.array(boxes)
        if isinstance(boxes[0], int):
            input_boxes = input_boxes[None, :]

    self.boxes = input_boxes

    # A list of lists is a batch of boxes and goes through the batched torch API
    batched_boxes = boxes is not None and isinstance(boxes[0], list)

    if not batched_boxes:
        masks, scores, logits = predictor.predict(
            point_coords,
            point_labels,
            input_boxes,
            mask_input,
            multimask_output,
            return_logits,
        )
    else:
        masks, scores, logits = predictor.predict_torch(
            point_coords=point_coords,
            # Fixed: this used to pass point_coords as the labels
            point_labels=point_labels,
            boxes=input_boxes,
            multimask_output=True,
        )

    self.masks = masks
    self.scores = scores
    self.logits = logits

    if output is not None:
        if not batched_boxes:
            self.save_prediction(output, index, mask_multiplier, dtype, **kwargs)
        else:
            self.tensor_to_numpy(
                index, output, mask_multiplier, dtype, save_args=kwargs
            )

    if return_results:
        return masks, scores, logits
+
raster_to_vector(self, image, output, simplify_tolerance=None, **kwargs)
+
+
+¶Save the result to a vector file.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
image |
+ str |
+ The path to the image file. |
+ required | +
output |
+ str |
+ The path to the vector file. |
+ required | +
simplify_tolerance |
+ float |
+ The maximum allowed geometry displacement. +The higher this value, the smaller the number of vertices in the resulting geometry. |
+ None |
+
samgeo/samgeo.py
def raster_to_vector(self, image, output, simplify_tolerance=None, **kwargs):
    """Convert a raster image to a vector file.

    This method only forwards its arguments to the module-level raster_to_vector()
    function and does not read any state from the instance.

    Args:
        image (str): The path to the image file.
        output (str): The path to the vector file.
        simplify_tolerance (float, optional): The maximum allowed geometry displacement.
            The higher this value, the smaller the number of vertices in the resulting geometry.
        **kwargs: Additional keyword arguments forwarded unchanged.
    """
    options = dict(kwargs, simplify_tolerance=simplify_tolerance)
    raster_to_vector(image, output, **options)
+
save_masks(self, output=None, foreground=True, unique=True, erosion_kernel=None, mask_multiplier=255, **kwargs)
+
+
+¶Save the masks to the output path. The output is either a binary mask or a mask of objects with unique values.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
output |
+ str |
+ The path to the output image. Defaults to None, saving the masks to SamGeo.objects. |
+ None |
+
foreground |
+ bool |
+ Whether to generate the foreground mask. Defaults to True. |
+ True |
+
unique |
+ bool |
+ Whether to assign a unique value to each object. Defaults to True. |
+ True |
+
erosion_kernel |
+ tuple |
+ The erosion kernel for filtering object masks and extract borders. +Such as (3, 3) or (5, 5). Set to None to disable it. Defaults to None. |
+ None |
+
mask_multiplier |
+ int |
+ The mask multiplier for the output mask, which is usually a binary mask [0, 1]. +You can use this parameter to scale the mask to a larger range, for example [0, 255]. Defaults to 255. |
+ 255 |
+
samgeo/samgeo.py
def save_masks(
    self,
    output=None,
    foreground=True,
    unique=True,
    erosion_kernel=None,
    mask_multiplier=255,
    **kwargs,
):
    """Save the masks to the output path. The output is either a binary mask or a mask of objects with unique values.

    The result is always stored in ``self.objects``; it is written to disk only when
    ``output`` is given.

    Args:
        output (str, optional): The path to the output image. Defaults to None, saving the masks to SamGeo.objects.
        foreground (bool, optional): Whether to generate the foreground mask. Defaults to True.
            Only used when unique is False.
        unique (bool, optional): Whether to assign a unique value to each object. Defaults to True.
            Values run from 1 (smallest area) upwards; larger objects overwrite smaller ones
            where they overlap.
        erosion_kernel (tuple, optional): The erosion kernel for filtering object masks and extract borders.
            Such as (3, 3) or (5, 5). Set to None to disable it. Defaults to None.
            Only used when unique is False.
        mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1].
            You can use this parameter to scale the mask to a larger range, for example [0, 255]. Defaults to 255.
            Only used when unique is False.
        **kwargs: Forwarded to array_to_image() when output is given.

    Raises:
        ValueError: If no masks have been generated yet.

    """

    if self.masks is None:
        raise ValueError("No masks found. Please run generate() first.")

    masks = self.masks

    # Set output image data type based on the number of objects
    if len(masks) < 255:
        dtype = np.uint8
    elif len(masks) < 65535:
        dtype = np.uint16
    else:
        dtype = np.uint32

    # Generate a mask of objects with unique values
    if unique:
        # Sort the masks by area in ascending order
        sorted_masks = sorted(masks, key=(lambda x: x["area"]), reverse=False)

        # The canvas size comes from the masks themselves, so the input image is only
        # needed as a fallback when there are no masks at all.
        if sorted_masks:
            canvas_shape = sorted_masks[0]["segmentation"].shape[:2]
        else:
            canvas_shape = self.image.shape[:2]
        objects = np.zeros(canvas_shape, dtype=dtype)

        # Assign a unique value to each object
        for index, ann in enumerate(sorted_masks):
            objects[ann["segmentation"]] = index + 1

    # Generate a binary mask
    else:
        # Works for both HxW and HxWxC images
        h, w = self.image.shape[:2]

        if foreground:  # Extract foreground objects only
            resulting_mask = np.zeros((h, w), dtype=dtype)
        else:
            resulting_mask = np.ones((h, w), dtype=dtype)
        resulting_borders = np.zeros((h, w), dtype=dtype)

        for m in masks:
            mask = (m["segmentation"] > 0).astype(dtype)
            resulting_mask += mask

            # Apply erosion to the mask
            if erosion_kernel is not None:
                # NOTE(review): erosion_kernel is handed to cv2.erode unchanged; confirm
                # that a size tuple such as (3, 3) is interpreted as intended rather
                # than requiring an explicit structuring element.
                mask_erode = cv2.erode(mask, erosion_kernel, iterations=1)
                mask_erode = (mask_erode > 0).astype(dtype)
                edge_mask = mask - mask_erode
                resulting_borders += edge_mask

        resulting_mask = (resulting_mask > 0).astype(dtype)
        resulting_borders = (resulting_borders > 0).astype(dtype)
        # Remove the extracted borders so that touching objects stay separated
        objects = resulting_mask - resulting_borders
        objects = objects * mask_multiplier

    objects = objects.astype(dtype)
    self.objects = objects

    if output is not None:  # Save the output image
        array_to_image(self.objects, output, self.source, **kwargs)
+
save_prediction(self, output, index=None, mask_multiplier=255, dtype=<class 'numpy.float32'>, vector=None, simplify_tolerance=None, **kwargs)
+
+
+¶Save the predicted mask to the output path.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
output |
+ str |
+ The path to the output image. |
+ required | +
index |
+ int |
+ The index of the mask to save. Defaults to None, +which will save the mask with the highest score. |
+ None |
+
mask_multiplier |
+ int |
+ The mask multiplier for the output mask, which is usually a binary mask [0, 1]. |
+ 255 |
+
vector |
+ str |
+ The path to the output vector file. Defaults to None. |
+ None |
+
dtype |
+ np.dtype |
+ The data type of the output image. Defaults to np.float32. |
+ <class 'numpy.float32'> |
+
simplify_tolerance |
+ float |
+ The maximum allowed geometry displacement. +The higher this value, the smaller the number of vertices in the resulting geometry. |
+ None |
+
samgeo/samgeo.py
def save_prediction(
    self,
    output,
    index=None,
    mask_multiplier=255,
    dtype=np.float32,
    vector=None,
    simplify_tolerance=None,
    **kwargs,
):
    """Write one predicted mask to an image file, and optionally vectorize it.

    The selected mask, scaled by ``mask_multiplier``, is also kept in ``self.prediction``.

    Args:
        output (str): The path to the output image.
        index (int, optional): The index of the mask to save. Defaults to None,
            which selects the argmax of ``self.scores``.
        mask_multiplier (int, optional): Factor applied to the selected mask, which is usually a binary mask [0, 1].
        dtype (np.dtype, optional): The data type of the output image. Defaults to np.float32.
        vector (str, optional): The path to the output vector file. Defaults to None,
            in which case no vector file is written.
        simplify_tolerance (float, optional): The maximum allowed geometry displacement.
            The higher this value, the smaller the number of vertices in the resulting geometry.
        **kwargs: Forwarded to array_to_image().

    Raises:
        ValueError: If ``self.scores`` is None, i.e. predict() has not been run.

    """
    if self.scores is None:
        raise ValueError("No predictions found. Please run predict() first.")

    chosen = self.scores.argmax(axis=0) if index is None else index

    scaled_mask = self.masks[chosen] * mask_multiplier
    self.prediction = scaled_mask
    array_to_image(scaled_mask, output, self.source, dtype=dtype, **kwargs)

    if vector is None:
        return
    raster_to_vector(output, vector, simplify_tolerance=simplify_tolerance)
+
set_image(self, image, image_format='RGB')
+
+
+¶Set the input image as a numpy array.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
image |
+ np.ndarray |
+ The input image as a numpy array. |
+ required | +
image_format |
+ str |
+ The image format, can be RGB or BGR. Defaults to "RGB". |
+ 'RGB' |
+
samgeo/samgeo.py
def set_image(self, image, image_format="RGB"):
    """Set the image that subsequent prompt-based predictions operate on.

    Args:
        image (str | np.ndarray): A path or http URL to an image file, or the image as a
            numpy array. Files are decoded with OpenCV and converted from BGR to RGB;
            arrays are used as provided.
        image_format (str, optional): The image format, can be RGB or BGR. Defaults to "RGB".
            Passed through to the predictor.

    Raises:
        ValueError: If the path does not exist, the file cannot be decoded as an image,
            or image is neither a string nor a numpy array.
    """
    if isinstance(image, str):
        if image.startswith("http"):
            image = download_file(image)

        if not os.path.exists(image):
            raise ValueError(f"Input path {image} does not exist.")

        path = image
        image = cv2.imread(path)
        if image is None:
            # cv2.imread signals an unreadable/unsupported file by returning None
            raise ValueError(f"Unable to read the input image {path}.")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        self.source = path
    elif not isinstance(image, np.ndarray):
        raise ValueError("Input image must be either a path or a numpy array.")

    # Keep the array for both input kinds: show_canvas() and predict() read self.image,
    # and previously it was only stored when the image came from a file.
    self.image = image

    self.predictor.set_image(image, image_format=image_format)
+
show_anns(self, figsize=(12, 10), axis='off', alpha=0.35, output=None, blend=True, **kwargs)
+
+
+¶Show the annotations (objects with random color) on the input image.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
figsize |
+ tuple |
+ The figure size. Defaults to (12, 10). |
+ (12, 10) |
+
axis |
+ str |
+ Whether to show the axis. Defaults to "off". |
+ 'off' |
+
alpha |
+ float |
+ The alpha value for the annotations. Defaults to 0.35. |
+ 0.35 |
+
output |
+ str |
+ The path to the output image. Defaults to None. |
+ None |
+
blend |
+ bool |
+ Whether to show the input image. Defaults to True. |
+ True |
+
samgeo/samgeo.py
def show_anns(
    self,
    figsize=(12, 10),
    axis="off",
    alpha=0.35,
    output=None,
    blend=True,
    **kwargs,
):
    """Draw every mask in a random color on top of the input image.

    The colored overlay (RGB, uint8) is stored in ``self.annotations``. Nothing is drawn
    when no image has been set or when there are no masks.

    Args:
        figsize (tuple, optional): The figure size. Defaults to (12, 10).
        axis (str, optional): Passed to ``plt.axis``. Defaults to "off".
        alpha (float, optional): The alpha value for the annotations. Defaults to 0.35.
        output (str, optional): The path to the output image. Defaults to None.
        blend (bool, optional): Whether the saved image blends the annotations with the
            input image (True) or contains the annotations only (False). Defaults to True.
        **kwargs: "dpi" and "bbox_inches" receive defaults (100 and "tight") when absent;
            nothing else in this method reads them.
    """

    import matplotlib.pyplot as plt

    annotations = self.masks

    if self.image is None:
        print("Please run generate() first.")
        return

    if annotations is None or len(annotations) == 0:
        return

    plt.figure(figsize=figsize)
    plt.imshow(self.image)

    # Largest objects first, so that smaller ones are painted on top of them
    largest_first = sorted(annotations, key=(lambda item: item["area"]), reverse=True)

    axes = plt.gca()
    axes.set_autoscale_on(False)

    # RGBA canvas: white and fully transparent until a mask is painted onto it
    first_shape = largest_first[0]["segmentation"].shape
    overlay = np.ones((first_shape[0], first_shape[1], 4))
    overlay[:, :, 3] = 0
    for item in largest_first:
        region = item["segmentation"]
        overlay[region] = np.concatenate([np.random.random(3), [alpha]])
    axes.imshow(overlay)

    kwargs.setdefault("dpi", 100)
    kwargs.setdefault("bbox_inches", "tight")

    plt.axis(axis)

    # Drop the alpha channel and rescale the colors to 0-255
    self.annotations = (overlay[:, :, 0:3] * 255).astype(np.uint8)

    if output is None:
        return

    if blend:
        array = blend_images(self.annotations, self.image, alpha=alpha, show=False)
    else:
        array = self.annotations
    array_to_image(array, output, self.source)
+
show_canvas(self, fg_color=(0, 255, 0), bg_color=(0, 0, 255), radius=5)
+
+
+¶Show a canvas to collect foreground and background points.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
image |
+ str | np.ndarray |
+ The input image. |
+ required | +
fg_color |
+ tuple |
+ The color for the foreground points. Defaults to (0, 255, 0). |
+ (0, 255, 0) |
+
bg_color |
+ tuple |
+ The color for the background points. Defaults to (0, 0, 255). |
+ (0, 0, 255) |
+
radius |
+ int |
+ The radius of the points. Defaults to 5. |
+ 5 |
+
Returns:
+Type | +Description | +
---|---|
tuple |
+ A tuple of two lists of foreground and background points. |
+
samgeo/samgeo.py
def show_canvas(self, fg_color=(0, 255, 0), bg_color=(0, 0, 255), radius=5):
    """Show a canvas on the current image to collect foreground and background points.

    The collected points are stored on the instance as ``fg_points`` and ``bg_points``,
    together with the combined ``point_coords`` (foreground first) and matching
    ``point_labels`` (1 for foreground, 0 for background). Nothing is returned.

    Args:
        fg_color (tuple, optional): The color for the foreground points. Defaults to (0, 255, 0).
        bg_color (tuple, optional): The color for the background points. Defaults to (0, 0, 255).
        radius (int, optional): The radius of the points. Defaults to 5.

    Raises:
        ValueError: If no image has been set.
    """

    if self.image is None:
        raise ValueError("Please run set_image() first.")

    # Delegates to the module-level show_canvas() helper, which returns the two point lists
    foreground, background = show_canvas(self.image, fg_color, bg_color, radius)
    self.fg_points = foreground
    self.bg_points = background

    combined = foreground + background
    labels = [1] * len(foreground) + [0] * len(background)
    self.point_coords = combined
    self.point_labels = labels
+
show_map(self, basemap='SATELLITE', repeat_mode=True, out_dir=None, **kwargs)
+
+
+¶Show the interactive map.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
basemap |
+ str |
+ The basemap. It can be one of the following: SATELLITE, ROADMAP, TERRAIN, HYBRID. |
+ 'SATELLITE' |
+
repeat_mode |
+ bool |
+ Whether to use the repeat mode for draw control. Defaults to True. |
+ True |
+
out_dir |
+ str |
+ The path to the output directory. Defaults to None. |
+ None |
+
Returns:
+Type | +Description | +
---|---|
leafmap.Map |
+ The map object. |
+
samgeo/samgeo.py
def show_map(self, basemap="SATELLITE", repeat_mode=True, out_dir=None, **kwargs):
    """Show the interactive map.

    Args:
        basemap (str, optional): The basemap. It can be one of the following: SATELLITE, ROADMAP, TERRAIN, HYBRID.
        repeat_mode (bool, optional): Whether to use the repeat mode for draw control. Defaults to True.
        out_dir (str, optional): The path to the output directory. Defaults to None.
        **kwargs: Additional keyword arguments forwarded to sam_map_gui().

    Returns:
        leafmap.Map: The map object, i.e. whatever sam_map_gui() returns for this instance.
    """
    gui_options = {
        "basemap": basemap,
        "repeat_mode": repeat_mode,
        "out_dir": out_dir,
    }
    gui_options.update(kwargs)
    return sam_map_gui(self, **gui_options)
+
show_masks(self, figsize=(12, 10), cmap='binary_r', axis='off', foreground=True, **kwargs)
+
+
+¶Show the binary mask or the mask of objects with unique values.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
figsize |
+ tuple |
+ The figure size. Defaults to (12, 10). |
+ (12, 10) |
+
cmap |
+ str |
+ The colormap. Defaults to "binary_r". |
+ 'binary_r' |
+
axis |
+ str |
+ Whether to show the axis. Defaults to "off". |
+ 'off' |
+
foreground |
+ bool |
+ Whether to show the foreground mask only. Defaults to True. |
+ True |
+
**kwargs |
+ + | Other arguments for save_masks(). |
+ {} |
+
samgeo/samgeo.py
def show_masks(
    self, figsize=(12, 10), cmap="binary_r", axis="off", foreground=True, **kwargs
):
    """Show the binary mask or the mask of objects with unique values.

    In batch mode the mask image is read back from the path stored in ``self.masks``;
    otherwise ``self.objects`` is displayed, computing it with save_masks() if needed.

    Args:
        figsize (tuple, optional): The figure size. Defaults to (12, 10).
        cmap (str, optional): The colormap. Defaults to "binary_r".
        axis (str, optional): Passed to ``plt.axis``. Defaults to "off".
        foreground (bool, optional): Whether to show the foreground mask only. Defaults to True.
        **kwargs: Other arguments for save_masks().
    """

    import matplotlib.pyplot as plt

    # ``batch`` is only assigned by generate(); treat a missing attribute as non-batch so
    # that calling this method too early reaches the explicit check in save_masks()
    # instead of failing with an AttributeError.
    if getattr(self, "batch", False):
        self.objects = cv2.imread(self.masks)
    elif self.objects is None:
        self.save_masks(foreground=foreground, **kwargs)

    plt.figure(figsize=figsize)
    plt.imshow(self.objects, cmap=cmap)
    plt.axis(axis)
    plt.show()
+
tensor_to_numpy(self, index=None, output=None, mask_multiplier=255, dtype='uint8', save_args={})
+
+
+¶Convert the predicted masks from tensors to numpy arrays.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
index |
+ int |
+ The index of the mask to save. Defaults to None, +which will save the mask with the highest score. |
+ None |
+
output |
+ str |
+ The path to the output image. Defaults to None. |
+ None |
+
mask_multiplier |
+ int |
+ The mask multiplier for the output mask, which is usually a binary mask [0, 1]. |
+ 255 |
+
dtype |
+ np.dtype |
+ The data type of the output image. Defaults to np.uint8. |
+ 'uint8' |
+
save_args |
+ dict |
+ Optional arguments for saving the output image. Defaults to {}. |
+ {} |
+
Returns:
+Type | +Description | +
---|---|
np.ndarray |
+ The predicted mask as a numpy array. |
+
samgeo/samgeo.py
def tensor_to_numpy(
    self, index=None, output=None, mask_multiplier=255, dtype="uint8", save_args=None
):
    """Merge the batched box predictions into a single binary mask array.

    For every box, the mask at position ``index`` is taken and all of them are combined
    with a logical OR, then scaled by ``mask_multiplier``.

    Args:
        index (int, optional): The index of the mask to use for every box.
            Defaults to None, which selects index 1.
        output (str, optional): The path to the output image. Defaults to None.
        mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1].
        dtype (np.dtype, optional): The data type of the output image. Defaults to "uint8".
        save_args (dict, optional): Optional arguments for saving the output image.
            Defaults to None, meaning no extra arguments.

    Returns:
        np.ndarray | None: The merged mask when ``output`` is None. None when the mask was
        written to ``output`` or when there are no boxes.
    """
    if save_args is None:
        save_args = {}

    boxes = self.boxes
    masks = self.masks

    image_np = np.array(self.image)

    if index is None:
        index = 1

    masks = masks[:, index, :, :]
    masks = masks.squeeze(1)

    if boxes is None or (len(boxes) == 0):  # No "object" instances found
        print("No objects found in the image.")
        return

    # Accumulate the union as booleans. Summing per-box ids in ``dtype`` (as was done
    # before) wraps around for small integer types, which silently dropped pixels once
    # an id or a sum of overlapping ids reached a multiple of the type's range.
    covered = np.zeros(image_np[..., 0].shape, dtype=bool)

    for _box, mask in zip(boxes, masks):
        # Convert tensor to numpy array if necessary
        if isinstance(mask, torch.Tensor):
            # If mask is on GPU, .cpu() is required before .numpy()
            mask = mask.cpu().numpy().astype(dtype)
        covered |= np.asarray(mask) > 0

    # Binary mask with values 0 and mask_multiplier
    mask_overlay = covered * mask_multiplier

    if output is not None:
        array_to_image(mask_overlay, output, self.source, dtype=dtype, **save_args)
    else:
        return mask_overlay
+
tiff_to_geojson(self, tiff_path, output, simplify_tolerance=None, **kwargs)
+
+
+¶Convert a tiff file to a GeoJSON file.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
tiff_path |
+ str |
+ The path to the tiff file. |
+ required | +
output |
+ str |
+ The path to the GeoJSON file. |
+ required | +
simplify_tolerance |
+ float |
+ The maximum allowed geometry displacement. +The higher this value, the smaller the number of vertices in the resulting geometry. |
+ None |
+
samgeo/samgeo.py
def tiff_to_geojson(self, tiff_path, output, simplify_tolerance=None, **kwargs):
    """Convert a tiff file to a GeoJSON file.

    This method adds no logic of its own; it delegates to ``raster_to_geojson``.

    Args:
        tiff_path (str): The path to the tiff file.
        output (str): The path to the GeoJSON file.
        simplify_tolerance (float, optional): The maximum allowed geometry displacement.
            The higher this value, the smaller the number of vertices in the resulting geometry.
        **kwargs: Additional keyword arguments forwarded to ``raster_to_geojson``.
    """
    # simplify_tolerance is a named parameter, so it can never collide with a key in kwargs.
    options = dict(kwargs, simplify_tolerance=simplify_tolerance)
    raster_to_geojson(tiff_path, output, **options)
+
tiff_to_gpkg(self, tiff_path, output, simplify_tolerance=None, **kwargs)
+
+
+¶Convert a tiff file to a gpkg file.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
tiff_path |
+ str |
+ The path to the tiff file. |
+ required | +
output |
+ str |
+ The path to the gpkg file. |
+ required | +
simplify_tolerance |
+ float |
+ The maximum allowed geometry displacement. +The higher this value, the smaller the number of vertices in the resulting geometry. |
+ None |
+
samgeo/samgeo.py
def tiff_to_gpkg(self, tiff_path, output, simplify_tolerance=None, **kwargs):
    """Convert a tiff file to a gpkg file.

    This method adds no logic of its own; it delegates to ``raster_to_gpkg``.

    Args:
        tiff_path (str): The path to the tiff file.
        output (str): The path to the gpkg file.
        simplify_tolerance (float, optional): The maximum allowed geometry displacement.
            The higher this value, the smaller the number of vertices in the resulting geometry.
        **kwargs: Additional keyword arguments forwarded to ``raster_to_gpkg``.
    """
    # simplify_tolerance is a named parameter, so it can never collide with a key in kwargs.
    options = dict(kwargs, simplify_tolerance=simplify_tolerance)
    raster_to_gpkg(tiff_path, output, **options)
+
tiff_to_shp(self, tiff_path, output, simplify_tolerance=None, **kwargs)
+
+
+¶Convert a tiff file to a shapefile.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
tiff_path |
+ str |
+ The path to the tiff file. |
+ required | +
output |
+ str |
+ The path to the shapefile. |
+ required | +
simplify_tolerance |
+ float |
+ The maximum allowed geometry displacement. +The higher this value, the smaller the number of vertices in the resulting geometry. |
+ None |
+
samgeo/samgeo.py
def tiff_to_shp(self, tiff_path, output, simplify_tolerance=None, **kwargs):
    """Convert a tiff file to a shapefile.

    This method adds no logic of its own; it delegates to ``raster_to_shp``.

    Args:
        tiff_path (str): The path to the tiff file.
        output (str): The path to the shapefile.
        simplify_tolerance (float, optional): The maximum allowed geometry displacement.
            The higher this value, the smaller the number of vertices in the resulting geometry.
        **kwargs: Additional keyword arguments forwarded to ``raster_to_shp``.
    """
    # simplify_tolerance is a named parameter, so it can never collide with a key in kwargs.
    options = dict(kwargs, simplify_tolerance=simplify_tolerance)
    raster_to_shp(tiff_path, output, **options)
+
tiff_to_vector(self, tiff_path, output, simplify_tolerance=None, **kwargs)
+
+
+¶Convert a tiff file to a vector file.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
tiff_path |
+ str |
+ The path to the tiff file. |
+ required | +
output |
+ str |
+ The path to the vector file. |
+ required | +
simplify_tolerance |
+ float |
+ The maximum allowed geometry displacement. +The higher this value, the smaller the number of vertices in the resulting geometry. |
+ None |
+
samgeo/samgeo.py
def tiff_to_vector(self, tiff_path, output, simplify_tolerance=None, **kwargs):
    """Convert a tiff file to a vector file.

    This method adds no logic of its own; it delegates to ``raster_to_vector``.

    Args:
        tiff_path (str): The path to the tiff file.
        output (str): The path to the vector file.
        simplify_tolerance (float, optional): The maximum allowed geometry displacement.
            The higher this value, the smaller the number of vertices in the resulting geometry.
        **kwargs: Additional keyword arguments forwarded to ``raster_to_vector``.
    """

    raster_to_vector(
        tiff_path, output, simplify_tolerance=simplify_tolerance, **kwargs
    )
+
+SamGeoPredictor (SamPredictor)
+
+
+
+
+¶samgeo/samgeo.py
class SamGeoPredictor(SamPredictor):
    """SamPredictor subclass that can take its box prompt from a geographic bounding box.

    When ``predict`` is given both a source raster and a geographic box, the box
    is converted to pixel coordinates of that raster and the georeferencing of
    the box is remembered so the resulting masks can be written with
    ``masks_to_geotiff``.
    """

    def __init__(
        self,
        sam_model,
    ):
        """Create the predictor.

        Args:
            sam_model: The SAM model used for mask prediction.
        """
        # Delegate to the parent initializer so that, besides storing the model
        # and the resize transform, the parent's image state is reset. Without
        # it, calling predict() before set_image() fails on a missing attribute.
        super().__init__(sam_model)

        # Georeferencing of the most recent geo-box prediction; filled in by
        # predict() and consumed by masks_to_geotiff().
        self.crs = None
        self.geo_box = None
        self.width = None
        self.height = None
        self.geo_transform = None

    def set_image(self, image, image_format="RGB"):
        """Compute the image embedding for ``image`` so masks can be predicted.

        Args:
            image (np.ndarray): The image for calculating masks.
            image_format (str, optional): The color format of the image, in
                ['RGB', 'BGR']. Forwarded to the parent class. Defaults to "RGB".
        """
        super().set_image(image, image_format=image_format)

    def predict(
        self,
        src_fp=None,
        geo_box=None,
        point_coords=None,
        point_labels=None,
        box=None,
        mask_input=None,
        multimask_output=True,
        return_logits=False,
    ):
        """Predict masks for the given prompts, using the currently set image.

        Args:
            src_fp (str, optional): Path of the source raster used to convert
                ``geo_box`` to pixel coordinates. Defaults to None.
            geo_box (sequence, optional): Four values whose first pair is used as the
                south-west corner and whose second pair is used as the north-east
                corner, interpreted in EPSG:4326. When given together with ``src_fp``
                it REPLACES ``box``. Defaults to None.
            point_coords (np.ndarray, optional): Point prompts, passed to the parent.
            point_labels (np.ndarray, optional): Labels for the point prompts, passed to the parent.
            box (np.ndarray, optional): Box prompt in pixel coordinates, passed to the parent.
            mask_input (np.ndarray, optional): Low resolution mask input, passed to the parent.
            multimask_output (bool, optional): Passed to the parent. Defaults to True.
            return_logits (bool, optional): Passed to the parent. Defaults to False.

        Returns:
            tuple: ``(masks, iou_predictions, low_res_masks)`` as returned by the parent class.
        """
        # Explicit None/length test: `if geo_box` raises for NumPy arrays because
        # their truth value is ambiguous.
        has_geo_box = geo_box is not None and len(geo_box) > 0
        if has_geo_box and src_fp:
            self.crs = "EPSG:4326"
            dst_crs = get_crs(src_fp)
            sw = transform_coords(geo_box[0], geo_box[1], self.crs, dst_crs)
            ne = transform_coords(geo_box[2], geo_box[3], self.crs, dst_crs)
            xs = np.array([sw[0], ne[0]])
            ys = np.array([sw[1], ne[1]])
            box = get_pixel_coords(src_fp, xs, ys)
            self.geo_box = geo_box
            self.width = box[2] - box[0]
            self.height = box[3] - box[1]
            self.geo_transform = set_transform(geo_box, self.width, self.height)

        # Keyword arguments keep the pass-through correct even if the parent
        # signature gains or reorders optional parameters.
        masks, iou_predictions, low_res_masks = super().predict(
            point_coords=point_coords,
            point_labels=point_labels,
            box=box,
            mask_input=mask_input,
            multimask_output=multimask_output,
            return_logits=return_logits,
        )

        return masks, iou_predictions, low_res_masks

    def masks_to_geotiff(self, src_fp, dst_fp, masks):
        """Write predicted masks to a raster file.

        Uses the profile of ``src_fp`` together with the width, height,
        transform and CRS recorded by the last ``predict`` call that was given
        both ``src_fp`` and ``geo_box``.

        Args:
            src_fp (str): Path of the source raster whose profile is reused.
            dst_fp (str): Path of the raster file to write.
            masks: The masks to write, as returned by ``predict``.

        Raises:
            AttributeError: If no georeferencing has been recorded yet.
        """
        # AttributeError is kept as the exception type (it is what the missing
        # attributes used to raise), but now with an actionable message.
        if getattr(self, "geo_transform", None) is None:
            raise AttributeError(
                "No georeferencing information available. Call predict() with "
                "both src_fp and geo_box before calling masks_to_geotiff()."
            )

        profile = get_profile(src_fp)
        write_raster(
            dst_fp,
            masks,
            profile,
            self.width,
            self.height,
            self.geo_transform,
            self.crs,
        )

    def geotiff_to_geojson(self, src_fp, dst_fp, bidx=1):
        """Extract features from one band of a raster and write them to a file.

        Args:
            src_fp (str): Path of the raster to read.
            dst_fp (str): Path of the output file passed to ``write_features``.
            bidx (int, optional): Band index passed to ``get_features``. Defaults to 1.

        Returns:
            The object returned by ``get_features`` (named ``gdf`` in the code).
        """
        gdf = get_features(src_fp, bidx)
        write_features(gdf, dst_fp)
        return gdf
+
predict(self, src_fp=None, geo_box=None, point_coords=None, point_labels=None, box=None, mask_input=None, multimask_output=True, return_logits=False)
+
+
+¶Predict masks for the given input prompts, using the currently set image.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
point_coords |
+ np.ndarray or None |
+ A Nx2 array of point prompts to the +model. Each point is in (X,Y) in pixels. |
+ None |
+
point_labels |
+ np.ndarray or None |
+ A length N array of labels for the +point prompts. 1 indicates a foreground point and 0 indicates a +background point. |
+ None |
+
box |
+ np.ndarray or None |
+ A length 4 array given a box prompt to the +model, in XYXY format. |
+ None |
+
mask_input |
+ np.ndarray |
+ A low resolution mask input to the model, typically +coming from a previous prediction iteration. Has form 1xHxW, where +for SAM, H=W=256. |
+ None |
+
multimask_output |
+ bool |
+ If true, the model will return three masks. +For ambiguous input prompts (such as a single click), this will often +produce better masks than a single prediction. If only a single +mask is needed, the model's predicted quality score can be used +to select the best mask. For non-ambiguous prompts, such as multiple +input prompts, multimask_output=False can give better results. |
+ True |
+
return_logits |
+ bool |
+ If true, returns un-thresholded masks logits +instead of a binary mask. |
+ False |
+
Returns:
+Type | +Description | +
---|---|
(np.ndarray) |
+ The output masks in CxHxW format, where C is the + number of masks, and (H, W) is the original image size. +(np.ndarray): An array of length C containing the model's + predictions for the quality of each mask. +(np.ndarray): An array of shape CxHxW, where C is the number + of masks and H=W=256. These low resolution logits can be passed to + a subsequent iteration as mask input. |
+
samgeo/samgeo.py
def predict(
    self,
    src_fp=None,
    geo_box=None,
    point_coords=None,
    point_labels=None,
    box=None,
    mask_input=None,
    multimask_output=True,
    return_logits=False,
):
    """Predict masks for the given prompts, using the currently set image.

    Args:
        src_fp (str, optional): Path of the source raster used to convert
            ``geo_box`` to pixel coordinates. Defaults to None.
        geo_box (sequence, optional): Four values whose first pair is used as the
            south-west corner and whose second pair is used as the north-east
            corner, interpreted in EPSG:4326. When given together with ``src_fp``
            it REPLACES ``box``. Defaults to None.
        point_coords (np.ndarray, optional): Point prompts, passed to the parent.
        point_labels (np.ndarray, optional): Labels for the point prompts, passed to the parent.
        box (np.ndarray, optional): Box prompt in pixel coordinates, passed to the parent.
        mask_input (np.ndarray, optional): Low resolution mask input, passed to the parent.
        multimask_output (bool, optional): Passed to the parent. Defaults to True.
        return_logits (bool, optional): Passed to the parent. Defaults to False.

    Returns:
        tuple: ``(masks, iou_predictions, low_res_masks)`` as returned by the parent class.
    """
    # Explicit None/length test: `if geo_box` raises for NumPy arrays because
    # their truth value is ambiguous.
    has_geo_box = geo_box is not None and len(geo_box) > 0
    if has_geo_box and src_fp:
        self.crs = "EPSG:4326"
        dst_crs = get_crs(src_fp)
        sw = transform_coords(geo_box[0], geo_box[1], self.crs, dst_crs)
        ne = transform_coords(geo_box[2], geo_box[3], self.crs, dst_crs)
        xs = np.array([sw[0], ne[0]])
        ys = np.array([sw[1], ne[1]])
        box = get_pixel_coords(src_fp, xs, ys)
        self.geo_box = geo_box
        self.width = box[2] - box[0]
        self.height = box[3] - box[1]
        self.geo_transform = set_transform(geo_box, self.width, self.height)

    # Keyword arguments keep the pass-through correct even if the parent
    # signature gains or reorders optional parameters.
    masks, iou_predictions, low_res_masks = super(SamGeoPredictor, self).predict(
        point_coords=point_coords,
        point_labels=point_labels,
        box=box,
        mask_input=mask_input,
        multimask_output=multimask_output,
        return_logits=return_logits,
    )

    return masks, iou_predictions, low_res_masks
+
set_image(self, image)
+
+
+¶Calculates the image embeddings for the provided image, allowing +masks to be predicted with the 'predict' method.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
image |
+ np.ndarray |
+ The image for calculating masks. Expects an +image in HWC uint8 format, with pixel values in [0, 255]. |
+ required | +
image_format |
+ str |
+ The color format of the image, in ['RGB', 'BGR']. |
+ required | +
samgeo/samgeo.py
def set_image(self, image, image_format="RGB"):
    """Compute the image embedding for ``image`` so masks can be predicted.

    Args:
        image (np.ndarray): The image for calculating masks.
        image_format (str, optional): The color format of the image, in
            ['RGB', 'BGR']. Forwarded to the parent class. Defaults to "RGB".
    """
    # Forward image_format so callers can use the option the parent documents;
    # the default keeps existing single-argument calls unchanged.
    super(SamGeoPredictor, self).set_image(image, image_format=image_format)
+
A Python package for segmenting geospatial data with the Segment Anything Model (SAM) \ud83d\uddfa\ufe0f
"},{"location":"#introduction","title":"Introduction","text":"The segment-geospatial package draws its inspiration from segment-anything-eo repository authored by Aliaksandr Hancharenka. To facilitate the use of the Segment Anything Model (SAM) for geospatial data, I have developed the segment-anything-py and segment-geospatial Python packages, which are now available on PyPI and conda-forge. My primary objective is to simplify the process of leveraging SAM for geospatial data analysis by enabling users to achieve this with minimal coding effort. I have adapted the source code of segment-geospatial from the segment-anything-eo repository, and credit for its original version goes to Aliaksandr Hancharenka.
Video tutorials are available on my YouTube Channel.
The Segment Anything Model is computationally intensive, and a powerful GPU is recommended to process large datasets. It is recommended to have a GPU with at least 8 GB of GPU memory. You can utilize the free GPU resources provided by Google Colab. Alternatively, you can apply for AWS Cloud Credit for Research, which offers cloud credits to support academic research. If you are in the Greater China region, apply for the AWS Cloud Credit here.
"},{"location":"#legal-notice","title":"Legal Notice","text":"This repository and its content are provided for educational purposes only. By using the information and code provided, users acknowledge that they are using the APIs and models at their own risk and agree to comply with any applicable laws and regulations. Users who intend to download a large number of image tiles from any basemap are advised to contact the basemap provider to obtain permission before doing so. Unauthorized use of the basemap or any of its components may be a violation of copyright laws or other applicable laws and regulations.
"},{"location":"#acknowledgements","title":"Acknowledgements","text":"This project is based upon work partially supported by the National Aeronautics and Space Administration (NASA) under Grant No. 80NSSC22K1742 issued through the Open Source Tools, Frameworks, and Libraries 2020 Program.
This project is also supported by Amazon Web Services (AWS). In addition, this package was made possible by the following open source projects. Credit goes to the developers of these projects.
New Features
New Features
Improvements
Improvements
Improvements
New Features
Improvements
New Features
Improvements
Improvements
New Features
Improvements
New Features
Improvements
Contributors
@p-vdp @LucasOsco
"},{"location":"changelog/#v062-may-17-2023","title":"v0.6.2 - May 17, 2023","text":"Improvements
New Features
New Features
Improvements
Demos
New Features
Improvements
Demos
"},{"location":"changelog/#v040-may-6-2023","title":"v0.4.0 - May 6, 2023","text":"New Features
SamGeo
class, including show_masks
, save_masks
, show_anns
, making it much easier to save segmentation results in GeoTIFF and vector formats.common
module, including array_to_image
, show_image
, download_file
, overlay_images
, blend_images
, and update_package
SamGeoPredictor
classImprovements
SamGeo.generate()
methodDemos
Contributors
@darrenwiens
"},{"location":"changelog/#v030-apr-26-2023","title":"v0.3.0 - Apr 26, 2023","text":"New Features
get_basemaps
, reproject
, tiff_to_shp
, and tiff_to_geojson
Improvement
tiff_to_vector
crs bug #12crs
parameter to tms_to_geotiff
New Features
SamGeo.generate
methodSamGeo.tiff_to_vector
methodNew Features
SamGeo
classInitial release
"},{"location":"common/","title":"common module","text":"The source code is adapted from https://github.com/aliaksandr960/segment-anything-eo. Credit to the author Aliaksandr Hancharenka.
"},{"location":"common/#samgeo.common.array_to_image","title":"array_to_image(array, output, source=None, dtype=None, compress='deflate', **kwargs)
","text":"Save a NumPy array as a GeoTIFF using the projection information from an existing GeoTIFF file.
Parameters:
Name Type Description Defaultarray
np.ndarray
The NumPy array to be saved as a GeoTIFF.
requiredoutput
str
The path to the output image.
requiredsource
str
The path to an existing GeoTIFF file with map projection information. Defaults to None.
None
dtype
np.dtype
The data type of the output array. Defaults to None.
None
compress
str
The compression method. Can be one of the following: \"deflate\", \"lzw\", \"packbits\", \"jpeg\". Defaults to \"deflate\".
'deflate'
Source code in samgeo/common.py
def array_to_image(\n array, output, source=None, dtype=None, compress=\"deflate\", **kwargs\n):\n\"\"\"Save a NumPy array as a GeoTIFF using the projection information from an existing GeoTIFF file.\n\n Args:\n array (np.ndarray): The NumPy array to be saved as a GeoTIFF.\n output (str): The path to the output image.\n source (str, optional): The path to an existing GeoTIFF file with map projection information. Defaults to None.\n dtype (np.dtype, optional): The data type of the output array. Defaults to None.\n compress (str, optional): The compression method. Can be one of the following: \"deflate\", \"lzw\", \"packbits\", \"jpeg\". Defaults to \"deflate\".\n \"\"\"\n\n from PIL import Image\n\n if isinstance(array, str) and os.path.exists(array):\n array = cv2.imread(array)\n array = cv2.cvtColor(array, cv2.COLOR_BGR2RGB)\n\n if output.endswith(\".tif\") and source is not None:\n with rasterio.open(source) as src:\n crs = src.crs\n transform = src.transform\n if compress is None:\n compress = src.compression\n\n # Determine the minimum and maximum values in the array\n\n min_value = np.min(array)\n max_value = np.max(array)\n\n if dtype is None:\n # Determine the best dtype for the array\n if min_value >= 0 and max_value <= 1:\n dtype = np.float32\n elif min_value >= 0 and max_value <= 255:\n dtype = np.uint8\n elif min_value >= -128 and max_value <= 127:\n dtype = np.int8\n elif min_value >= 0 and max_value <= 65535:\n dtype = np.uint16\n elif min_value >= -32768 and max_value <= 32767:\n dtype = np.int16\n else:\n dtype = np.float64\n\n # Convert the array to the best dtype\n array = array.astype(dtype)\n\n # Define the GeoTIFF metadata\n if array.ndim == 2:\n metadata = {\n \"driver\": \"GTiff\",\n \"height\": array.shape[0],\n \"width\": array.shape[1],\n \"count\": 1,\n \"dtype\": array.dtype,\n \"crs\": crs,\n \"transform\": transform,\n }\n elif array.ndim == 3:\n metadata = {\n \"driver\": \"GTiff\",\n \"height\": array.shape[0],\n \"width\": 
array.shape[1],\n \"count\": array.shape[2],\n \"dtype\": array.dtype,\n \"crs\": crs,\n \"transform\": transform,\n }\n\n if compress is not None:\n metadata[\"compress\"] = compress\n else:\n raise ValueError(\"Array must be 2D or 3D.\")\n\n # Create a new GeoTIFF file and write the array to it\n with rasterio.open(output, \"w\", **metadata) as dst:\n if array.ndim == 2:\n dst.write(array, 1)\n elif array.ndim == 3:\n for i in range(array.shape[2]):\n dst.write(array[:, :, i], i + 1)\n\n else:\n img = Image.fromarray(array)\n img.save(output, **kwargs)\n
"},{"location":"common/#samgeo.common.bbox_to_xy","title":"bbox_to_xy(src_fp, coords, coord_crs='epsg:4326', **kwargs)
","text":"Converts a list of coordinates to pixel coordinates, i.e., (col, row) coordinates. Note that map bbox coordinates are [minx, miny, maxx, maxy], from bottom left to top right, while rasterio bbox coordinates are [minx, maxy, maxx, miny], from top left to bottom right.
Parameters:
Name Type Description Defaultsrc_fp
str
The source raster file path.
requiredcoords
list
A list of coordinates in the format of [[minx, miny, maxx, maxy], [minx, miny, maxx, maxy], ...]
requiredcoord_crs
str
The coordinate CRS of the input coordinates. Defaults to \"epsg:4326\".
'epsg:4326'
Returns:
Type Descriptionlist
A list of pixel coordinates in the format of [[minx, maxy, maxx, miny], ...] from top left to bottom right.
Source code insamgeo/common.py
def bbox_to_xy(\n src_fp: str, coords: list, coord_crs: str = \"epsg:4326\", **kwargs\n) -> list:\n\"\"\"Converts a list of coordinates to pixel coordinates, i.e., (col, row) coordinates.\n Note that map bbox coords is [minx, miny, maxx, maxy] from bottomleft to topright\n While rasterio bbox coords is [minx, max, maxx, min] from topleft to bottomright\n\n Args:\n src_fp (str): The source raster file path.\n coords (list): A list of coordinates in the format of [[minx, miny, maxx, maxy], [minx, miny, maxx, maxy], ...]\n coord_crs (str, optional): The coordinate CRS of the input coordinates. Defaults to \"epsg:4326\".\n\n Returns:\n list: A list of pixel coordinates in the format of [[minx, maxy, maxx, miny], ...] from top left to bottom right.\n \"\"\"\n\n if isinstance(coords, str):\n gdf = gpd.read_file(coords)\n coords = gdf.geometry.bounds.values.tolist()\n if gdf.crs is not None:\n coord_crs = f\"epsg:{gdf.crs.to_epsg()}\"\n elif isinstance(coords, np.ndarray):\n coords = coords.tolist()\n if isinstance(coords, dict):\n import json\n\n geojson = json.dumps(coords)\n gdf = gpd.read_file(geojson, driver=\"GeoJSON\")\n coords = gdf.geometry.bounds.values.tolist()\n\n elif not isinstance(coords, list):\n raise ValueError(\"coords must be a list of coordinates.\")\n\n if not isinstance(coords[0], list):\n coords = [coords]\n\n new_coords = []\n\n with rasterio.open(src_fp) as src:\n width = src.width\n height = src.height\n\n for coord in coords:\n minx, miny, maxx, maxy = coord\n\n if coord_crs != src.crs:\n minx, miny = transform_coords(minx, miny, coord_crs, src.crs, **kwargs)\n maxx, maxy = transform_coords(maxx, maxy, coord_crs, src.crs, **kwargs)\n\n rows1, cols1 = rasterio.transform.rowcol(\n src.transform, minx, miny, **kwargs\n )\n rows2, cols2 = rasterio.transform.rowcol(\n src.transform, maxx, maxy, **kwargs\n )\n\n new_coords.append([cols1, rows1, cols2, rows2])\n\n else:\n new_coords.append([minx, miny, maxx, maxy])\n\n result = []\n\n for coord in 
new_coords:\n minx, miny, maxx, maxy = coord\n\n if (\n minx >= 0\n and miny >= 0\n and maxx >= 0\n and maxy >= 0\n and minx < width\n and miny < height\n and maxx < width\n and maxy < height\n ):\n # Note that map bbox coords is [minx, miny, maxx, maxy] from bottomleft to topright\n # While rasterio bbox coords is [minx, max, maxx, min] from topleft to bottomright\n result.append([minx, maxy, maxx, miny])\n\n if len(result) == 0:\n print(\"No valid pixel coordinates found.\")\n return None\n elif len(result) == 1:\n return result[0]\n elif len(result) < len(coords):\n print(\"Some coordinates are out of the image boundary.\")\n\n return result\n
"},{"location":"common/#samgeo.common.blend_images","title":"blend_images(img1, img2, alpha=0.5, output=False, show=True, figsize=(12, 10), axis='off', **kwargs)
","text":"Blends two images together using the addWeighted function from the OpenCV library.
Parameters:
Name Type Description Defaultimg1
numpy.ndarray
The first input image on top represented as a NumPy array.
requiredimg2
numpy.ndarray
The second input image at the bottom represented as a NumPy array.
requiredalpha
float
The weighting factor for the first image in the blend. By default, this is set to 0.5.
0.5
output
str
The path to the output image. Defaults to False.
False
show
bool
Whether to display the blended image. Defaults to True.
True
figsize
tuple
The size of the figure. Defaults to (12, 10).
(12, 10)
axis
str
The axis of the figure. Defaults to \"off\".
'off'
**kwargs
Additional keyword arguments to pass to the cv2.addWeighted() function.
{}
Returns:
Type Descriptionnumpy.ndarray
The blended image as a NumPy array.
Source code insamgeo/common.py
def blend_images(\n img1,\n img2,\n alpha=0.5,\n output=False,\n show=True,\n figsize=(12, 10),\n axis=\"off\",\n **kwargs,\n):\n\"\"\"\n Blends two images together using the addWeighted function from the OpenCV library.\n\n Args:\n img1 (numpy.ndarray): The first input image on top represented as a NumPy array.\n img2 (numpy.ndarray): The second input image at the bottom represented as a NumPy array.\n alpha (float): The weighting factor for the first image in the blend. By default, this is set to 0.5.\n output (str, optional): The path to the output image. Defaults to False.\n show (bool, optional): Whether to display the blended image. Defaults to True.\n figsize (tuple, optional): The size of the figure. Defaults to (12, 10).\n axis (str, optional): The axis of the figure. Defaults to \"off\".\n **kwargs: Additional keyword arguments to pass to the cv2.addWeighted() function.\n\n Returns:\n numpy.ndarray: The blended image as a NumPy array.\n \"\"\"\n # Resize the images to have the same dimensions\n if isinstance(img1, str):\n if img1.startswith(\"http\"):\n img1 = download_file(img1)\n\n if not os.path.exists(img1):\n raise ValueError(f\"Input path {img1} does not exist.\")\n\n img1 = cv2.imread(img1)\n\n if isinstance(img2, str):\n if img2.startswith(\"http\"):\n img2 = download_file(img2)\n\n if not os.path.exists(img2):\n raise ValueError(f\"Input path {img2} does not exist.\")\n\n img2 = cv2.imread(img2)\n\n if img1.dtype == np.float32:\n img1 = (img1 * 255).astype(np.uint8)\n\n if img2.dtype == np.float32:\n img2 = (img2 * 255).astype(np.uint8)\n\n if img1.dtype != img2.dtype:\n img2 = img2.astype(img1.dtype)\n\n img1 = cv2.resize(img1, (img2.shape[1], img2.shape[0]))\n\n # Blend the images using the addWeighted function\n beta = 1 - alpha\n blend_img = cv2.addWeighted(img1, alpha, img2, beta, 0, **kwargs)\n\n if output:\n array_to_image(blend_img, output, img2)\n\n if show:\n plt.figure(figsize=figsize)\n plt.imshow(blend_img)\n plt.axis(axis)\n 
plt.show()\n else:\n return blend_img\n
"},{"location":"common/#samgeo.common.boxes_to_vector","title":"boxes_to_vector(coords, src_crs, dst_crs='EPSG:4326', output=None, **kwargs)
","text":"Convert a list of bounding box coordinates to vector data.
Parameters:
Name Type Description Defaultcoords
list
A list of bounding box coordinates in the format [[left, top, right, bottom], [left, top, right, bottom], ...].
requiredsrc_crs
int or str
The EPSG code or proj4 string representing the source coordinate reference system (CRS) of the input coordinates.
requireddst_crs
int or str
The EPSG code or proj4 string representing the destination CRS to reproject the data (default is \"EPSG:4326\").
'EPSG:4326'
output
str or None
The full file path (including the directory and filename without the extension) where the vector data should be saved. If None (default), the function returns the GeoDataFrame without saving it to a file.
None
**kwargs
Additional keyword arguments to pass to geopandas.GeoDataFrame.to_file() when saving the vector data.
{}
Returns:
Type Descriptiongeopandas.GeoDataFrame or None
The GeoDataFrame with the converted vector data if output is None, otherwise None if the data is saved to a file.
Source code insamgeo/common.py
def boxes_to_vector(coords, src_crs, dst_crs=\"EPSG:4326\", output=None, **kwargs):\n\"\"\"\n Convert a list of bounding box coordinates to vector data.\n\n Args:\n coords (list): A list of bounding box coordinates in the format [[left, top, right, bottom], [left, top, right, bottom], ...].\n src_crs (int or str): The EPSG code or proj4 string representing the source coordinate reference system (CRS) of the input coordinates.\n dst_crs (int or str, optional): The EPSG code or proj4 string representing the destination CRS to reproject the data (default is \"EPSG:4326\").\n output (str or None, optional): The full file path (including the directory and filename without the extension) where the vector data should be saved.\n If None (default), the function returns the GeoDataFrame without saving it to a file.\n **kwargs: Additional keyword arguments to pass to geopandas.GeoDataFrame.to_file() when saving the vector data.\n\n Returns:\n geopandas.GeoDataFrame or None: The GeoDataFrame with the converted vector data if output is None, otherwise None if the data is saved to a file.\n \"\"\"\n\n from shapely.geometry import box\n\n # Create a list of Shapely Polygon objects based on the provided coordinates\n polygons = [box(*coord) for coord in coords]\n\n # Create a GeoDataFrame with the Shapely Polygon objects\n gdf = gpd.GeoDataFrame({\"geometry\": polygons}, crs=src_crs)\n\n # Reproject the GeoDataFrame to the specified EPSG code\n gdf_reprojected = gdf.to_crs(dst_crs)\n\n if output is not None:\n gdf_reprojected.to_file(output, **kwargs)\n else:\n return gdf_reprojected\n
"},{"location":"common/#samgeo.common.check_file_path","title":"check_file_path(file_path, make_dirs=True)
","text":"Gets the absolute file path.
Parameters:
Name Type Description Defaultfile_path
str
The path to the file.
requiredmake_dirs
bool
Whether to create the directory if it does not exist. Defaults to True.
True
Exceptions:
Type DescriptionFileNotFoundError
If the directory could not be found.
TypeError
If the input directory path is not a string.
Returns:
Type Descriptionstr
The absolute path to the file.
Source code insamgeo/common.py
def check_file_path(file_path, make_dirs=True):\n\"\"\"Gets the absolute file path.\n\n Args:\n file_path (str): The path to the file.\n make_dirs (bool, optional): Whether to create the directory if it does not exist. Defaults to True.\n\n Raises:\n FileNotFoundError: If the directory could not be found.\n TypeError: If the input directory path is not a string.\n\n Returns:\n str: The absolute path to the file.\n \"\"\"\n if isinstance(file_path, str):\n if file_path.startswith(\"~\"):\n file_path = os.path.expanduser(file_path)\n else:\n file_path = os.path.abspath(file_path)\n\n file_dir = os.path.dirname(file_path)\n if not os.path.exists(file_dir) and make_dirs:\n os.makedirs(file_dir)\n\n return file_path\n\n else:\n raise TypeError(\"The provided file path must be a string.\")\n
"},{"location":"common/#samgeo.common.coords_to_geojson","title":"coords_to_geojson(coords, output=None)
","text":"Convert a list of coordinates (lon, lat) to a GeoJSON string or file.
Parameters:
Name Type Description Defaultcoords
list
A list of coordinates (lon, lat).
requiredoutput
str
The output file path. Defaults to None.
None
Returns:
Type Descriptionstr
A GeoJSON string, returned only when output is None; otherwise the string is written to the output file.
Source code insamgeo/common.py
def coords_to_geojson(coords, output=None):\n\"\"\"Convert a list of coordinates (lon, lat) to a GeoJSON string or file.\n\n Args:\n coords (list): A list of coordinates (lon, lat).\n output (str, optional): The output file path. Defaults to None.\n\n Returns:\n dict: A GeoJSON dictionary.\n \"\"\"\n\n import json\n\n if len(coords) == 0:\n return\n # Create a GeoJSON FeatureCollection object\n feature_collection = {\"type\": \"FeatureCollection\", \"features\": []}\n\n # Iterate through the coordinates list and create a GeoJSON Feature object for each coordinate\n for coord in coords:\n feature = {\n \"type\": \"Feature\",\n \"geometry\": {\"type\": \"Point\", \"coordinates\": coord},\n \"properties\": {},\n }\n feature_collection[\"features\"].append(feature)\n\n # Convert the FeatureCollection object to a JSON string\n geojson_str = json.dumps(feature_collection)\n\n if output is not None:\n with open(output, \"w\") as f:\n f.write(geojson_str)\n else:\n return geojson_str\n
"},{"location":"common/#samgeo.common.coords_to_xy","title":"coords_to_xy(src_fp, coords, coord_crs='epsg:4326', **kwargs)
","text":"Converts a list of coordinates to pixel coordinates, i.e., (col, row) coordinates.
Parameters:
Name Type Description Defaultsrc_fp
str
The source raster file path.
requiredcoords
list
A list of coordinates in the format of [[x1, y1], [x2, y2], ...]
requiredcoord_crs
str
The coordinate CRS of the input coordinates. Defaults to \"epsg:4326\".
'epsg:4326'
**kwargs
Additional keyword arguments to pass to rasterio.transform.rowcol.
{}
Returns:
Type Descriptionlist
A list of pixel coordinates in the format of [[x1, y1], [x2, y2], ...]
Source code insamgeo/common.py
def coords_to_xy(\n src_fp: str, coords: list, coord_crs: str = \"epsg:4326\", **kwargs\n) -> list:\n\"\"\"Converts a list of coordinates to pixel coordinates, i.e., (col, row) coordinates.\n\n Args:\n src_fp: The source raster file path.\n coords: A list of coordinates in the format of [[x1, y1], [x2, y2], ...]\n coord_crs: The coordinate CRS of the input coordinates. Defaults to \"epsg:4326\".\n **kwargs: Additional keyword arguments to pass to rasterio.transform.rowcol.\n\n Returns:\n A list of pixel coordinates in the format of [[x1, y1], [x2, y2], ...]\n \"\"\"\n if isinstance(coords, np.ndarray):\n coords = coords.tolist()\n\n xs, ys = zip(*coords)\n with rasterio.open(src_fp) as src:\n width = src.width\n height = src.height\n if coord_crs != src.crs:\n xs, ys = transform_coords(xs, ys, coord_crs, src.crs, **kwargs)\n rows, cols = rasterio.transform.rowcol(src.transform, xs, ys, **kwargs)\n result = [[col, row] for col, row in zip(cols, rows)]\n\n result = [\n [x, y] for x, y in result if x >= 0 and y >= 0 and x < width and y < height\n ]\n if len(result) == 0:\n print(\"No valid pixel coordinates found.\")\n elif len(result) < len(coords):\n print(\"Some coordinates are out of the image boundary.\")\n\n return result\n
"},{"location":"common/#samgeo.common.download_checkpoint","title":"download_checkpoint(model_type='vit_h', checkpoint_dir=None, hq=False)
","text":"Download the SAM model checkpoint.
Parameters:
Name Type Description Defaultmodel_type
str
The model type. Can be one of ['vit_h', 'vit_l', 'vit_b']. Defaults to 'vit_h'. See https://bit.ly/3VrpxUh for more details.
'vit_h'
checkpoint_dir
str
The checkpoint_dir directory. Defaults to None, \"~/.cache/torch/hub/checkpoints\".
None
hq
bool
Whether to use HQ-SAM model (https://github.com/SysCV/sam-hq). Defaults to False.
False
Source code in samgeo/common.py
def download_checkpoint(model_type=\"vit_h\", checkpoint_dir=None, hq=False):\n\"\"\"Download the SAM model checkpoint.\n\n Args:\n model_type (str, optional): The model type. Can be one of ['vit_h', 'vit_l', 'vit_b'].\n Defaults to 'vit_h'. See https://bit.ly/3VrpxUh for more details.\n checkpoint_dir (str, optional): The checkpoint_dir directory. Defaults to None, \"~/.cache/torch/hub/checkpoints\".\n hq (bool, optional): Whether to use HQ-SAM model (https://github.com/SysCV/sam-hq). Defaults to False.\n \"\"\"\n\n if not hq:\n model_types = {\n \"vit_h\": {\n \"name\": \"sam_vit_h_4b8939.pth\",\n \"url\": \"https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth\",\n },\n \"vit_l\": {\n \"name\": \"sam_vit_l_0b3195.pth\",\n \"url\": \"https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth\",\n },\n \"vit_b\": {\n \"name\": \"sam_vit_b_01ec64.pth\",\n \"url\": \"https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth\",\n },\n }\n else:\n model_types = {\n \"vit_h\": {\n \"name\": \"sam_hq_vit_h.pth\",\n \"url\": \"https://drive.google.com/file/d/1qobFYrI4eyIANfBSmYcGuWRaSIXfMOQ8/view?usp=sharing\",\n },\n \"vit_l\": {\n \"name\": \"sam_hq_vit_l.pth\",\n \"url\": \"https://drive.google.com/file/d/1Uk17tDKX1YAKas5knI4y9ZJCo0lRVL0G/view?usp=sharing\",\n },\n \"vit_b\": {\n \"name\": \"sam_hq_vit_b.pth\",\n \"url\": \"https://drive.google.com/file/d/11yExZLOve38kRZPfRx_MRxfIAKmfMY47/view?usp=sharing\",\n },\n \"vit_tiny\": {\n \"name\": \"sam_hq_vit_tiny.pth\",\n \"url\": \"https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_tiny.pth\",\n },\n }\n\n if model_type not in model_types:\n raise ValueError(\n f\"Invalid model_type: {model_type}. 
It must be one of {', '.join(model_types)}\"\n )\n\n if checkpoint_dir is None:\n checkpoint_dir = os.environ.get(\n \"TORCH_HOME\", os.path.expanduser(\"~/.cache/torch/hub/checkpoints\")\n )\n\n checkpoint = os.path.join(checkpoint_dir, model_types[model_type][\"name\"])\n if not os.path.exists(checkpoint):\n print(f\"Model checkpoint for {model_type} not found.\")\n url = model_types[model_type][\"url\"]\n download_file(url, checkpoint)\n return checkpoint\n
"},{"location":"common/#samgeo.common.download_checkpoint_legacy","title":"download_checkpoint_legacy(url=None, output=None, overwrite=False, **kwargs)
","text":"Download a checkpoint from URL. It can be one of the following: sam_vit_h_4b8939.pth, sam_vit_l_0b3195.pth, sam_vit_b_01ec64.pth.
Parameters:
Name Type Description Defaulturl
str
The checkpoint URL. Defaults to None.
None
output
str
The output file path. Defaults to None.
None
overwrite
bool
Overwrite the file if it already exists. Defaults to False.
False
Returns:
Type Descriptionstr
The output file path.
Source code insamgeo/common.py
def download_checkpoint_legacy(url=None, output=None, overwrite=False, **kwargs):\n\"\"\"Download a checkpoint from URL. It can be one of the following: sam_vit_h_4b8939.pth, sam_vit_l_0b3195.pth, sam_vit_b_01ec64.pth.\n\n Args:\n url (str, optional): The checkpoint URL. Defaults to None.\n output (str, optional): The output file path. Defaults to None.\n overwrite (bool, optional): Overwrite the file if it already exists. Defaults to False.\n\n Returns:\n str: The output file path.\n \"\"\"\n checkpoints = {\n \"sam_vit_h_4b8939.pth\": \"https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth\",\n \"sam_vit_l_0b3195.pth\": \"https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth\",\n \"sam_vit_b_01ec64.pth\": \"https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth\",\n }\n\n if isinstance(url, str) and url in checkpoints:\n url = checkpoints[url]\n\n if url is None:\n url = checkpoints[\"sam_vit_h_4b8939.pth\"]\n\n if output is None:\n output = os.path.basename(url)\n\n return download_file(url, output, overwrite=overwrite, **kwargs)\n
"},{"location":"common/#samgeo.common.download_file","title":"download_file(url=None, output=None, quiet=False, proxy=None, speed=None, use_cookies=True, verify=True, id=None, fuzzy=False, resume=False, unzip=True, overwrite=False, subfolder=False)
","text":"Download a file from URL, including Google Drive shared URL.
Parameters:
Name Type Description Defaulturl
str
The URL of the file to download. Google Drive shared URLs are also supported. Defaults to None.
None
output
str
Output filename. Default is basename of URL.
None
quiet
bool
Suppress terminal output. Default is False.
False
proxy
str
Proxy. Defaults to None.
None
speed
float
Download byte size per second (e.g., 256KB/s = 256 * 1024). Defaults to None.
None
use_cookies
bool
Flag to use cookies. Defaults to True.
True
verify
bool | str
Either a bool, in which case it controls whether the server's TLS certificate is verified, or a string, in which case it must be a path to a CA bundle to use. Defaults to True.
True
id
str
Google Drive's file ID. Defaults to None.
None
fuzzy
bool
Fuzzy extraction of Google Drive's file Id. Defaults to False.
False
resume
bool
Resume the download from existing tmp file if possible. Defaults to False.
False
unzip
bool
Unzip the file. Defaults to True.
True
overwrite
bool
Overwrite the file if it already exists. Defaults to False.
False
subfolder
bool
Create a subfolder with the same name as the file. Defaults to False.
False
Returns:
Type Descriptionstr
The output file path.
Source code insamgeo/common.py
def download_file(\n url=None,\n output=None,\n quiet=False,\n proxy=None,\n speed=None,\n use_cookies=True,\n verify=True,\n id=None,\n fuzzy=False,\n resume=False,\n unzip=True,\n overwrite=False,\n subfolder=False,\n):\n\"\"\"Download a file from URL, including Google Drive shared URL.\n\n Args:\n url (str, optional): Google Drive URL is also supported. Defaults to None.\n output (str, optional): Output filename. Default is basename of URL.\n quiet (bool, optional): Suppress terminal output. Default is False.\n proxy (str, optional): Proxy. Defaults to None.\n speed (float, optional): Download byte size per second (e.g., 256KB/s = 256 * 1024). Defaults to None.\n use_cookies (bool, optional): Flag to use cookies. Defaults to True.\n verify (bool | str, optional): Either a bool, in which case it controls whether the server's TLS certificate is verified, or a string,\n in which case it must be a path to a CA bundle to use. Default is True.. Defaults to True.\n id (str, optional): Google Drive's file ID. Defaults to None.\n fuzzy (bool, optional): Fuzzy extraction of Google Drive's file Id. Defaults to False.\n resume (bool, optional): Resume the download from existing tmp file if possible. Defaults to False.\n unzip (bool, optional): Unzip the file. Defaults to True.\n overwrite (bool, optional): Overwrite the file if it already exists. Defaults to False.\n subfolder (bool, optional): Create a subfolder with the same name as the file. Defaults to False.\n\n Returns:\n str: The output file path.\n \"\"\"\n import zipfile\n\n try:\n import gdown\n except ImportError:\n print(\n \"The gdown package is required for this function. 
Use `pip install gdown` to install it.\"\n )\n return\n\n if output is None:\n if isinstance(url, str) and url.startswith(\"http\"):\n output = os.path.basename(url)\n\n out_dir = os.path.abspath(os.path.dirname(output))\n if not os.path.exists(out_dir):\n os.makedirs(out_dir)\n\n if isinstance(url, str):\n if os.path.exists(os.path.abspath(output)) and (not overwrite):\n print(\n f\"{output} already exists. Skip downloading. Set overwrite=True to overwrite.\"\n )\n return os.path.abspath(output)\n else:\n url = github_raw_url(url)\n\n if \"https://drive.google.com/file/d/\" in url:\n fuzzy = True\n\n output = gdown.download(\n url, output, quiet, proxy, speed, use_cookies, verify, id, fuzzy, resume\n )\n\n if unzip and output.endswith(\".zip\"):\n with zipfile.ZipFile(output, \"r\") as zip_ref:\n if not quiet:\n print(\"Extracting files...\")\n if subfolder:\n basename = os.path.splitext(os.path.basename(output))[0]\n\n output = os.path.join(out_dir, basename)\n if not os.path.exists(output):\n os.makedirs(output)\n zip_ref.extractall(output)\n else:\n zip_ref.extractall(os.path.dirname(output))\n\n return os.path.abspath(output)\n
"},{"location":"common/#samgeo.common.geojson_to_coords","title":"geojson_to_coords(geojson, src_crs='epsg:4326', dst_crs='epsg:4326')
","text":"Converts a geojson file or a dictionary of feature collection to a list of centroid coordinates.
Parameters:
Name Type Description Defaultgeojson
str | dict
The geojson file path or a dictionary of feature collection.
requiredsrc_crs
str
The source CRS. Defaults to \"epsg:4326\".
'epsg:4326'
dst_crs
str
The destination CRS. Defaults to \"epsg:4326\".
'epsg:4326'
Returns:
Type Descriptionlist
A list of centroid coordinates in the format of [[x1, y1], [x2, y2], ...]
Source code insamgeo/common.py
def geojson_to_coords(\n geojson: str, src_crs: str = \"epsg:4326\", dst_crs: str = \"epsg:4326\"\n) -> list:\n\"\"\"Converts a geojson file or a dictionary of feature collection to a list of centroid coordinates.\n\n Args:\n geojson (str | dict): The geojson file path or a dictionary of feature collection.\n src_crs (str, optional): The source CRS. Defaults to \"epsg:4326\".\n dst_crs (str, optional): The destination CRS. Defaults to \"epsg:4326\".\n\n Returns:\n list: A list of centroid coordinates in the format of [[x1, y1], [x2, y2], ...]\n \"\"\"\n\n import json\n import warnings\n\n warnings.filterwarnings(\"ignore\")\n\n if isinstance(geojson, dict):\n geojson = json.dumps(geojson)\n gdf = gpd.read_file(geojson, driver=\"GeoJSON\")\n centroids = gdf.geometry.centroid\n centroid_list = [[point.x, point.y] for point in centroids]\n if src_crs != dst_crs:\n centroid_list = transform_coords(\n [x[0] for x in centroid_list],\n [x[1] for x in centroid_list],\n src_crs,\n dst_crs,\n )\n centroid_list = [[x, y] for x, y in zip(centroid_list[0], centroid_list[1])]\n return centroid_list\n
"},{"location":"common/#samgeo.common.geojson_to_xy","title":"geojson_to_xy(src_fp, geojson, coord_crs='epsg:4326', **kwargs)
","text":"Converts a geojson file or a dictionary of feature collection to a list of pixel coordinates.
Parameters:
Name Type Description Defaultsrc_fp
str
The source raster file path.
requiredgeojson
str
The geojson file path or a dictionary of feature collection.
requiredcoord_crs
str
The coordinate CRS of the input coordinates. Defaults to \"epsg:4326\".
'epsg:4326'
**kwargs
Additional keyword arguments to pass to rasterio.transform.rowcol.
{}
Returns:
Type Descriptionlist
A list of pixel coordinates in the format of [[x1, y1], [x2, y2], ...]
Source code insamgeo/common.py
def geojson_to_xy(\n src_fp: str, geojson: str, coord_crs: str = \"epsg:4326\", **kwargs\n) -> list:\n\"\"\"Converts a geojson file or a dictionary of feature collection to a list of pixel coordinates.\n\n Args:\n src_fp: The source raster file path.\n geojson: The geojson file path or a dictionary of feature collection.\n coord_crs: The coordinate CRS of the input coordinates. Defaults to \"epsg:4326\".\n **kwargs: Additional keyword arguments to pass to rasterio.transform.rowcol.\n\n Returns:\n A list of pixel coordinates in the format of [[x1, y1], [x2, y2], ...]\n \"\"\"\n with rasterio.open(src_fp) as src:\n src_crs = src.crs\n coords = geojson_to_coords(geojson, coord_crs, src_crs)\n return coords_to_xy(src_fp, coords, src_crs, **kwargs)\n
"},{"location":"common/#samgeo.common.get_basemaps","title":"get_basemaps(free_only=True)
","text":"Returns a dictionary of xyz basemaps.
Parameters:
Name Type Description Defaultfree_only
bool
Whether to return only free xyz tile services that do not require an access token. Defaults to True.
True
Returns:
Type Descriptiondict
A dictionary of xyz basemaps.
Source code insamgeo/common.py
def get_basemaps(free_only=True):\n\"\"\"Returns a dictionary of xyz basemaps.\n\n Args:\n free_only (bool, optional): Whether to return only free xyz tile services that do not require an access token. Defaults to True.\n\n Returns:\n dict: A dictionary of xyz basemaps.\n \"\"\"\n\n basemaps = {}\n xyz_dict = get_xyz_dict(free_only=free_only)\n for item in xyz_dict:\n name = xyz_dict[item].name\n url = xyz_dict[item].build_url()\n basemaps[name] = url\n\n return basemaps\n
"},{"location":"common/#samgeo.common.get_vector_crs","title":"get_vector_crs(filename, **kwargs)
","text":"Gets the CRS of a vector file.
Parameters:
Name Type Description Defaultfilename
str
The vector file path.
requiredReturns:
Type Descriptionstr
The CRS of the vector file.
Source code insamgeo/common.py
def get_vector_crs(filename, **kwargs):\n\"\"\"Gets the CRS of a vector file.\n\n Args:\n filename (str): The vector file path.\n\n Returns:\n str: The CRS of the vector file.\n \"\"\"\n gdf = gpd.read_file(filename, **kwargs)\n epsg = gdf.crs.to_epsg()\n if epsg is None:\n return gdf.crs\n else:\n return f\"EPSG:{epsg}\"\n
"},{"location":"common/#samgeo.common.get_xyz_dict","title":"get_xyz_dict(free_only=True)
","text":"Returns a dictionary of xyz services.
Parameters:
Name Type Description Defaultfree_only
bool
Whether to return only free xyz tile services that do not require an access token. Defaults to True.
True
Returns:
Type Descriptiondict
A dictionary of xyz services.
Source code insamgeo/common.py
def get_xyz_dict(free_only=True):\n\"\"\"Returns a dictionary of xyz services.\n\n Args:\n free_only (bool, optional): Whether to return only free xyz tile services that do not require an access token. Defaults to True.\n\n Returns:\n dict: A dictionary of xyz services.\n \"\"\"\n import collections\n import xyzservices.providers as xyz\n\n def _unpack_sub_parameters(var, param):\n temp = var\n for sub_param in param.split(\".\"):\n temp = getattr(temp, sub_param)\n return temp\n\n xyz_dict = {}\n for item in xyz.values():\n try:\n name = item[\"name\"]\n tile = _unpack_sub_parameters(xyz, name)\n if _unpack_sub_parameters(xyz, name).requires_token():\n if free_only:\n pass\n else:\n xyz_dict[name] = tile\n else:\n xyz_dict[name] = tile\n\n except Exception:\n for sub_item in item:\n name = item[sub_item][\"name\"]\n tile = _unpack_sub_parameters(xyz, name)\n if _unpack_sub_parameters(xyz, name).requires_token():\n if free_only:\n pass\n else:\n xyz_dict[name] = tile\n else:\n xyz_dict[name] = tile\n\n xyz_dict = collections.OrderedDict(sorted(xyz_dict.items()))\n return xyz_dict\n
"},{"location":"common/#samgeo.common.github_raw_url","title":"github_raw_url(url)
","text":"Get the raw URL for a GitHub file.
Parameters:
Name Type Description Defaulturl
str
The GitHub URL.
requiredReturns:
Type Descriptionstr
The raw URL.
Source code insamgeo/common.py
def github_raw_url(url):\n\"\"\"Get the raw URL for a GitHub file.\n\n Args:\n url (str): The GitHub URL.\n Returns:\n str: The raw URL.\n \"\"\"\n if isinstance(url, str) and url.startswith(\"https://github.com/\") and \"blob\" in url:\n url = url.replace(\"github.com\", \"raw.githubusercontent.com\").replace(\n \"blob/\", \"\"\n )\n return url\n
"},{"location":"common/#samgeo.common.image_to_cog","title":"image_to_cog(source, dst_path=None, profile='deflate', **kwargs)
","text":"Converts an image to a COG file.
Parameters:
Name Type Description Defaultsource
str
A dataset path, URL or rasterio.io.DatasetReader object.
requireddst_path
str
An output dataset path or a PathLike object. Defaults to None.
None
profile
str
COG profile. More at https://cogeotiff.github.io/rio-cogeo/profile. Defaults to \"deflate\".
'deflate'
Exceptions:
Type DescriptionImportError
If rio-cogeo is not installed.
FileNotFoundError
If the source file could not be found.
Source code insamgeo/common.py
def image_to_cog(source, dst_path=None, profile=\"deflate\", **kwargs):\n\"\"\"Converts an image to a COG file.\n\n Args:\n source (str): A dataset path, URL or rasterio.io.DatasetReader object.\n dst_path (str, optional): An output dataset path or or PathLike object. Defaults to None.\n profile (str, optional): COG profile. More at https://cogeotiff.github.io/rio-cogeo/profile. Defaults to \"deflate\".\n\n Raises:\n ImportError: If rio-cogeo is not installed.\n FileNotFoundError: If the source file could not be found.\n \"\"\"\n try:\n from rio_cogeo.cogeo import cog_translate\n from rio_cogeo.profiles import cog_profiles\n\n except ImportError:\n raise ImportError(\n \"The rio-cogeo package is not installed. Please install it with `pip install rio-cogeo` or `conda install rio-cogeo -c conda-forge`.\"\n )\n\n if not source.startswith(\"http\"):\n source = check_file_path(source)\n\n if not os.path.exists(source):\n raise FileNotFoundError(\"The provided input file could not be found.\")\n\n if dst_path is None:\n if not source.startswith(\"http\"):\n dst_path = os.path.splitext(source)[0] + \"_cog.tif\"\n else:\n dst_path = temp_file_path(extension=\".tif\")\n\n dst_path = check_file_path(dst_path)\n\n dst_profile = cog_profiles.get(profile)\n cog_translate(source, dst_path, dst_profile, **kwargs)\n
"},{"location":"common/#samgeo.common.install_package","title":"install_package(package)
","text":"Install a Python package.
Parameters:
Name Type Description Defaultpackage
str | list
The package name or a GitHub URL or a list of package names or GitHub URLs.
required Source code insamgeo/common.py
def install_package(package):\n\"\"\"Install a Python package.\n\n Args:\n package (str | list): The package name or a GitHub URL or a list of package names or GitHub URLs.\n \"\"\"\n import subprocess\n\n if isinstance(package, str):\n packages = [package]\n\n for package in packages:\n if package.startswith(\"https://github.com\"):\n package = f\"git+{package}\"\n\n # Execute pip install command and show output in real-time\n command = f\"pip install {package}\"\n process = subprocess.Popen(command.split(), stdout=subprocess.PIPE)\n\n # Print output in real-time\n while True:\n output = process.stdout.readline()\n if output == b\"\" and process.poll() is not None:\n break\n if output:\n print(output.decode(\"utf-8\").strip())\n\n # Wait for process to complete\n process.wait()\n
"},{"location":"common/#samgeo.common.is_colab","title":"is_colab()
","text":"Tests if the code is being executed within Google Colab.
Source code insamgeo/common.py
def is_colab():\n\"\"\"Tests if the code is being executed within Google Colab.\"\"\"\n import sys\n\n if \"google.colab\" in sys.modules:\n return True\n else:\n return False\n
"},{"location":"common/#samgeo.common.merge_rasters","title":"merge_rasters(input_dir, output, input_pattern='*.tif', output_format='GTiff', output_nodata=None, output_options=['COMPRESS=DEFLATE'])
","text":"Merge a directory of rasters into a single raster.
Parameters:
Name Type Description Defaultinput_dir
str
The path to the input directory.
requiredoutput
str
The path to the output raster.
requiredinput_pattern
str
The pattern to match the input files. Defaults to \"*.tif\".
'*.tif'
output_format
str
The output format. Defaults to \"GTiff\".
'GTiff'
output_nodata
float
The output nodata value. Defaults to None.
None
output_options
list
A list of output options. Defaults to [\"COMPRESS=DEFLATE\"].
['COMPRESS=DEFLATE']
Exceptions:
Type DescriptionImportError
Raised if GDAL is not installed.
Source code insamgeo/common.py
def merge_rasters(\n input_dir,\n output,\n input_pattern=\"*.tif\",\n output_format=\"GTiff\",\n output_nodata=None,\n output_options=[\"COMPRESS=DEFLATE\"],\n):\n\"\"\"Merge a directory of rasters into a single raster.\n\n Args:\n input_dir (str): The path to the input directory.\n output (str): The path to the output raster.\n input_pattern (str, optional): The pattern to match the input files. Defaults to \"*.tif\".\n output_format (str, optional): The output format. Defaults to \"GTiff\".\n output_nodata (float, optional): The output nodata value. Defaults to None.\n output_options (list, optional): A list of output options. Defaults to [\"COMPRESS=DEFLATE\"].\n\n Raises:\n ImportError: Raised if GDAL is not installed.\n \"\"\"\n\n import glob\n\n try:\n from osgeo import gdal\n except ImportError:\n raise ImportError(\n \"GDAL is required to use this function. Install it with `conda install gdal -c conda-forge`\"\n )\n # Get a list of all the input files\n input_files = glob.glob(os.path.join(input_dir, input_pattern))\n\n # Merge the input files into a single output file\n gdal.Warp(\n output,\n input_files,\n format=output_format,\n dstNodata=output_nodata,\n options=output_options,\n )\n
"},{"location":"common/#samgeo.common.overlay_images","title":"overlay_images(image1, image2, alpha=0.5, backend='TkAgg', height_ratios=[10, 1], show_args1={}, show_args2={})
","text":"Overlays two images using a slider to control the opacity of the top image.
Parameters:
Name Type Description Defaultimage1
str | np.ndarray
The first input image at the bottom represented as a NumPy array or the path to the image.
requiredimage2
str | np.ndarray
The second input image on top represented as a NumPy array or the path to the image.
requiredalpha
float
The alpha value of the top image. Defaults to 0.5.
0.5
backend
str
The backend of the matplotlib plot. Defaults to \"TkAgg\".
'TkAgg'
height_ratios
list
The height ratios of the two subplots. Defaults to [10, 1].
[10, 1]
show_args1
dict
The keyword arguments to pass to the imshow() function for the first image. Defaults to {}.
{}
show_args2
dict
The keyword arguments to pass to the imshow() function for the second image. Defaults to {}.
{}
Source code in samgeo/common.py
def overlay_images(\n image1,\n image2,\n alpha=0.5,\n backend=\"TkAgg\",\n height_ratios=[10, 1],\n show_args1={},\n show_args2={},\n):\n\"\"\"Overlays two images using a slider to control the opacity of the top image.\n\n Args:\n image1 (str | np.ndarray): The first input image at the bottom represented as a NumPy array or the path to the image.\n image2 (_type_): The second input image on top represented as a NumPy array or the path to the image.\n alpha (float, optional): The alpha value of the top image. Defaults to 0.5.\n backend (str, optional): The backend of the matplotlib plot. Defaults to \"TkAgg\".\n height_ratios (list, optional): The height ratios of the two subplots. Defaults to [10, 1].\n show_args1 (dict, optional): The keyword arguments to pass to the imshow() function for the first image. Defaults to {}.\n show_args2 (dict, optional): The keyword arguments to pass to the imshow() function for the second image. Defaults to {}.\n\n \"\"\"\n import sys\n import matplotlib\n import matplotlib.widgets as mpwidgets\n\n if \"google.colab\" in sys.modules:\n backend = \"inline\"\n print(\n \"The TkAgg backend is not supported in Google Colab. 
The overlay_images function will not work on Colab.\"\n )\n return\n\n matplotlib.use(backend)\n\n if isinstance(image1, str):\n if image1.startswith(\"http\"):\n image1 = download_file(image1)\n\n if not os.path.exists(image1):\n raise ValueError(f\"Input path {image1} does not exist.\")\n\n if isinstance(image2, str):\n if image2.startswith(\"http\"):\n image2 = download_file(image2)\n\n if not os.path.exists(image2):\n raise ValueError(f\"Input path {image2} does not exist.\")\n\n # Load the two images\n x = plt.imread(image1)\n y = plt.imread(image2)\n\n # Create the plot\n fig, (ax0, ax1) = plt.subplots(2, 1, gridspec_kw={\"height_ratios\": height_ratios})\n img0 = ax0.imshow(x, **show_args1)\n img1 = ax0.imshow(y, alpha=alpha, **show_args2)\n\n # Define the update function\n def update(value):\n img1.set_alpha(value)\n fig.canvas.draw_idle()\n\n # Create the slider\n slider0 = mpwidgets.Slider(ax=ax1, label=\"alpha\", valmin=0, valmax=1, valinit=alpha)\n slider0.on_changed(update)\n\n # Display the plot\n plt.show()\n
"},{"location":"common/#samgeo.common.random_string","title":"random_string(string_length=6)
","text":"Generates a random string of fixed length.
Parameters:
Name Type Description Defaultstring_length
int
The length of the generated string. Defaults to 6.
6
Returns:
Type Descriptionstr
A random string
Source code insamgeo/common.py
def random_string(string_length=6):\n\"\"\"Generates a random string of fixed length.\n\n Args:\n string_length (int, optional): Fixed length. Defaults to 3.\n\n Returns:\n str: A random string\n \"\"\"\n import random\n import string\n\n # random.seed(1001)\n letters = string.ascii_lowercase\n return \"\".join(random.choice(letters) for i in range(string_length))\n
"},{"location":"common/#samgeo.common.raster_to_geojson","title":"raster_to_geojson(tiff_path, output, simplify_tolerance=None, **kwargs)
","text":"Convert a tiff file to a GeoJSON file.
Parameters:
Name Type Description Defaulttiff_path
str
The path to the tiff file.
requiredoutput
str
The path to the GeoJSON file.
requiredsimplify_tolerance
float
The maximum allowed geometry displacement. The higher this value, the smaller the number of vertices in the resulting geometry.
None
Source code in samgeo/common.py
def raster_to_geojson(tiff_path, output, simplify_tolerance=None, **kwargs):\n\"\"\"Convert a tiff file to a GeoJSON file.\n\n Args:\n tiff_path (str): The path to the tiff file.\n output (str): The path to the GeoJSON file.\n simplify_tolerance (float, optional): The maximum allowed geometry displacement.\n The higher this value, the smaller the number of vertices in the resulting geometry.\n \"\"\"\n\n if not output.endswith(\".geojson\"):\n output += \".geojson\"\n\n raster_to_vector(tiff_path, output, simplify_tolerance=simplify_tolerance, **kwargs)\n
"},{"location":"common/#samgeo.common.raster_to_gpkg","title":"raster_to_gpkg(tiff_path, output, simplify_tolerance=None, **kwargs)
","text":"Convert a tiff file to a gpkg file.
Parameters:
Name Type Description Defaulttiff_path
str
The path to the tiff file.
requiredoutput
str
The path to the gpkg file.
requiredsimplify_tolerance
float
The maximum allowed geometry displacement. The higher this value, the smaller the number of vertices in the resulting geometry.
None
Source code in samgeo/common.py
def raster_to_gpkg(tiff_path, output, simplify_tolerance=None, **kwargs):\n\"\"\"Convert a tiff file to a gpkg file.\n\n Args:\n tiff_path (str): The path to the tiff file.\n output (str): The path to the gpkg file.\n simplify_tolerance (float, optional): The maximum allowed geometry displacement.\n The higher this value, the smaller the number of vertices in the resulting geometry.\n \"\"\"\n\n if not output.endswith(\".gpkg\"):\n output += \".gpkg\"\n\n raster_to_vector(tiff_path, output, simplify_tolerance=simplify_tolerance, **kwargs)\n
"},{"location":"common/#samgeo.common.raster_to_shp","title":"raster_to_shp(tiff_path, output, simplify_tolerance=None, **kwargs)
","text":"Convert a tiff file to a shapefile.
Parameters:
Name Type Description Defaulttiff_path
str
The path to the tiff file.
requiredoutput
str
The path to the shapefile.
requiredsimplify_tolerance
float
The maximum allowed geometry displacement. The higher this value, the smaller the number of vertices in the resulting geometry.
None
Source code in samgeo/common.py
def raster_to_shp(tiff_path, output, simplify_tolerance=None, **kwargs):\n\"\"\"Convert a tiff file to a shapefile.\n\n Args:\n tiff_path (str): The path to the tiff file.\n output (str): The path to the shapefile.\n simplify_tolerance (float, optional): The maximum allowed geometry displacement.\n The higher this value, the smaller the number of vertices in the resulting geometry.\n \"\"\"\n\n if not output.endswith(\".shp\"):\n output += \".shp\"\n\n raster_to_vector(tiff_path, output, simplify_tolerance=simplify_tolerance, **kwargs)\n
"},{"location":"common/#samgeo.common.raster_to_vector","title":"raster_to_vector(source, output, simplify_tolerance=None, **kwargs)
","text":"Vectorize a raster dataset.
Parameters:
Name Type Description Defaultsource
str
The path to the tiff file.
requiredoutput
str
The path to the vector file.
requiredsimplify_tolerance
float
The maximum allowed geometry displacement. The higher this value, the smaller the number of vertices in the resulting geometry.
None
Source code in samgeo/common.py
def raster_to_vector(source, output, simplify_tolerance=None, **kwargs):\n\"\"\"Vectorize a raster dataset.\n\n Args:\n source (str): The path to the tiff file.\n output (str): The path to the vector file.\n simplify_tolerance (float, optional): The maximum allowed geometry displacement.\n The higher this value, the smaller the number of vertices in the resulting geometry.\n \"\"\"\n from rasterio import features\n\n with rasterio.open(source) as src:\n band = src.read()\n\n mask = band != 0\n shapes = features.shapes(band, mask=mask, transform=src.transform)\n\n fc = [\n {\"geometry\": shapely.geometry.shape(shape), \"properties\": {\"value\": value}}\n for shape, value in shapes\n ]\n if simplify_tolerance is not None:\n for i in fc:\n i[\"geometry\"] = i[\"geometry\"].simplify(tolerance=simplify_tolerance)\n\n gdf = gpd.GeoDataFrame.from_features(fc)\n if src.crs is not None:\n gdf.set_crs(crs=src.crs, inplace=True)\n gdf.to_file(output, **kwargs)\n
"},{"location":"common/#samgeo.common.regularize","title":"regularize(source, output=None, crs='EPSG:4326', **kwargs)
","text":"Regularize a polygon GeoDataFrame.
Parameters:
Name Type Description Defaultsource
str | gpd.GeoDataFrame
The input file path or a GeoDataFrame.
requiredoutput
str
The output file path. Defaults to None.
None
Returns:
Type Descriptiongpd.GeoDataFrame
The output GeoDataFrame.
Source code insamgeo/common.py
def regularize(source, output=None, crs=\"EPSG:4326\", **kwargs):\n\"\"\"Regularize a polygon GeoDataFrame.\n\n Args:\n source (str | gpd.GeoDataFrame): The input file path or a GeoDataFrame.\n output (str, optional): The output file path. Defaults to None.\n\n\n Returns:\n gpd.GeoDataFrame: The output GeoDataFrame.\n \"\"\"\n if isinstance(source, str):\n gdf = gpd.read_file(source)\n elif isinstance(source, gpd.GeoDataFrame):\n gdf = source\n else:\n raise ValueError(\"The input source must be a GeoDataFrame or a file path.\")\n\n polygons = gdf.geometry.apply(lambda geom: geom.minimum_rotated_rectangle)\n result = gpd.GeoDataFrame(geometry=polygons, data=gdf.drop(\"geometry\", axis=1))\n\n if crs is not None:\n result.to_crs(crs, inplace=True)\n if output is not None:\n result.to_file(output, **kwargs)\n else:\n return result\n
"},{"location":"common/#samgeo.common.reproject","title":"reproject(image, output, dst_crs='EPSG:4326', resampling='nearest', to_cog=True, **kwargs)
","text":"Reprojects an image.
Parameters:
Name Type Description Defaultimage
str
The input image filepath.
requiredoutput
str
The output image filepath.
requireddst_crs
str
The destination CRS. Defaults to \"EPSG:4326\".
'EPSG:4326'
resampling
Resampling
The resampling method. Defaults to \"nearest\".
'nearest'
to_cog
bool
Whether to convert the output image to a Cloud Optimized GeoTIFF. Defaults to True.
True
**kwargs
Additional keyword arguments to pass to rasterio.open.
{}
Source code in samgeo/common.py
def reproject(\n image, output, dst_crs=\"EPSG:4326\", resampling=\"nearest\", to_cog=True, **kwargs\n):\n\"\"\"Reprojects an image.\n\n Args:\n image (str): The input image filepath.\n output (str): The output image filepath.\n dst_crs (str, optional): The destination CRS. Defaults to \"EPSG:4326\".\n resampling (Resampling, optional): The resampling method. Defaults to \"nearest\".\n to_cog (bool, optional): Whether to convert the output image to a Cloud Optimized GeoTIFF. Defaults to True.\n **kwargs: Additional keyword arguments to pass to rasterio.open.\n\n \"\"\"\n import rasterio as rio\n from rasterio.warp import calculate_default_transform, reproject, Resampling\n\n if isinstance(resampling, str):\n resampling = getattr(Resampling, resampling)\n\n image = os.path.abspath(image)\n output = os.path.abspath(output)\n\n if not os.path.exists(os.path.dirname(output)):\n os.makedirs(os.path.dirname(output))\n\n with rio.open(image, **kwargs) as src:\n transform, width, height = calculate_default_transform(\n src.crs, dst_crs, src.width, src.height, *src.bounds\n )\n kwargs = src.meta.copy()\n kwargs.update(\n {\n \"crs\": dst_crs,\n \"transform\": transform,\n \"width\": width,\n \"height\": height,\n }\n )\n\n with rio.open(output, \"w\", **kwargs) as dst:\n for i in range(1, src.count + 1):\n reproject(\n source=rio.band(src, i),\n destination=rio.band(dst, i),\n src_transform=src.transform,\n src_crs=src.crs,\n dst_transform=transform,\n dst_crs=dst_crs,\n resampling=resampling,\n **kwargs,\n )\n\n if to_cog:\n image_to_cog(output, output)\n
"},{"location":"common/#samgeo.common.rowcol_to_xy","title":"rowcol_to_xy(src_fp, rows=None, cols=None, boxes=None, zs=None, offset='center', output=None, dst_crs='EPSG:4326', **kwargs)
","text":"Converts a list of (row, col) coordinates to (x, y) coordinates.
Parameters:
Name Type Description Default
src_fp
str
The source raster file path.
required
rows
list
A list of row coordinates. Defaults to None.
None
cols
list
A list of col coordinates. Defaults to None.
None
boxes
list
A list of pixel bounding boxes in the format of [[left, top, right, bottom], [left, top, right, bottom], ...]. When provided, rows and cols are derived from the boxes. Defaults to None.
None
zs
Height associated with coordinates (list or float). Primarily used for RPC-based coordinate transformations. Defaults to None.
None
offset
str
Determines if the returned coordinates are for the center of the pixel or for a corner.
'center'
output
str
The output vector file path. Defaults to None.
None
dst_crs
str
The destination CRS. Defaults to \"EPSG:4326\".
'EPSG:4326'
**kwargs
Additional keyword arguments to pass to rasterio.transform.xy.
{}
Returns:
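Type Description
A list of [x, y] coordinates, or a list of [left, bottom, right, top] boxes when boxes is provided. Nothing is returned when the boxes are written to output.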
Source code in samgeo/common.py
def rowcol_to_xy(\n src_fp,\n rows=None,\n cols=None,\n boxes=None,\n zs=None,\n offset=\"center\",\n output=None,\n dst_crs=\"EPSG:4326\",\n **kwargs,\n):\n\"\"\"Converts a list of (row, col) coordinates to (x, y) coordinates.\n\n Args:\n src_fp (str): The source raster file path.\n rows (list, optional): A list of row coordinates. Defaults to None.\n cols (list, optional): A list of col coordinates. Defaults to None.\n boxes (list, optional): A list of (row, col) coordinates in the format of [[left, top, right, bottom], [left, top, right, bottom], ...]\n zs: zs (list or float, optional): Height associated with coordinates. Primarily used for RPC based coordinate transformations.\n offset (str, optional): Determines if the returned coordinates are for the center of the pixel or for a corner.\n output (str, optional): The output vector file path. Defaults to None.\n dst_crs (str, optional): The destination CRS. Defaults to \"EPSG:4326\".\n **kwargs: Additional keyword arguments to pass to rasterio.transform.xy.\n\n Returns:\n A list of (x, y) coordinates.\n \"\"\"\n\n if boxes is not None:\n rows = []\n cols = []\n\n for box in boxes:\n rows.append(box[1])\n rows.append(box[3])\n cols.append(box[0])\n cols.append(box[2])\n\n if rows is None or cols is None:\n raise ValueError(\"rows and cols must be provided.\")\n\n with rasterio.open(src_fp) as src:\n xs, ys = rasterio.transform.xy(src.transform, rows, cols, zs, offset, **kwargs)\n src_crs = src.crs\n\n if boxes is None:\n return [[x, y] for x, y in zip(xs, ys)]\n else:\n result = [[xs[i], ys[i + 1], xs[i + 1], ys[i]] for i in range(0, len(xs), 2)]\n\n if output is not None:\n boxes_to_vector(result, src_crs, dst_crs, output)\n else:\n return result\n
"},{"location":"common/#samgeo.common.sam_map_gui","title":"sam_map_gui(sam, basemap='SATELLITE', repeat_mode=True, out_dir=None, **kwargs)
","text":"Display the SAM Map GUI.
Parameters:
Name Type Description Default
sam
SamGeo
The SamGeo object used to run the segmentation.
required
basemap
str
The basemap to use. Defaults to \"SATELLITE\".
'SATELLITE'
repeat_mode
bool
Whether to use the repeat mode for the draw control. Defaults to True.
True
out_dir
str
The output directory. Defaults to None.
None
Source code in samgeo/common.py
def sam_map_gui(sam, basemap=\"SATELLITE\", repeat_mode=True, out_dir=None, **kwargs):\n\"\"\"Display the SAM Map GUI.\n\n Args:\n sam (SamGeo):\n basemap (str, optional): The basemap to use. Defaults to \"SATELLITE\".\n repeat_mode (bool, optional): Whether to use the repeat mode for the draw control. Defaults to True.\n out_dir (str, optional): The output directory. Defaults to None.\n\n \"\"\"\n try:\n import shutil\n import tempfile\n import leafmap\n import ipyleaflet\n import ipyevents\n import ipywidgets as widgets\n from ipyfilechooser import FileChooser\n except ImportError:\n raise ImportError(\n \"The sam_map function requires the leafmap package. Please install it first.\"\n )\n\n if out_dir is None:\n out_dir = tempfile.gettempdir()\n\n m = leafmap.Map(repeat_mode=repeat_mode, **kwargs)\n m.default_style = {\"cursor\": \"crosshair\"}\n m.add_basemap(basemap, show=False)\n\n # Skip the image layer if localtileserver is not available\n try:\n m.add_raster(sam.source, layer_name=\"Image\")\n except:\n pass\n\n m.fg_markers = []\n m.bg_markers = []\n\n fg_layer = ipyleaflet.LayerGroup(layers=m.fg_markers, name=\"Foreground\")\n bg_layer = ipyleaflet.LayerGroup(layers=m.bg_markers, name=\"Background\")\n m.add(fg_layer)\n m.add(bg_layer)\n m.fg_layer = fg_layer\n m.bg_layer = bg_layer\n\n widget_width = \"280px\"\n button_width = \"90px\"\n padding = \"0px 0px 0px 4px\" # upper, right, bottom, left\n style = {\"description_width\": \"initial\"}\n\n toolbar_button = widgets.ToggleButton(\n value=True,\n tooltip=\"Toolbar\",\n icon=\"gear\",\n layout=widgets.Layout(width=\"28px\", height=\"28px\", padding=padding),\n )\n\n close_button = widgets.ToggleButton(\n value=False,\n tooltip=\"Close the tool\",\n icon=\"times\",\n button_style=\"primary\",\n layout=widgets.Layout(height=\"28px\", width=\"28px\", padding=padding),\n )\n\n plus_button = widgets.ToggleButton(\n value=False,\n tooltip=\"Load foreground points\",\n icon=\"plus-circle\",\n 
button_style=\"primary\",\n layout=widgets.Layout(height=\"28px\", width=\"28px\", padding=padding),\n )\n\n minus_button = widgets.ToggleButton(\n value=False,\n tooltip=\"Load background points\",\n icon=\"minus-circle\",\n button_style=\"primary\",\n layout=widgets.Layout(height=\"28px\", width=\"28px\", padding=padding),\n )\n\n radio_buttons = widgets.RadioButtons(\n options=[\"Foreground\", \"Background\"],\n description=\"Class Type:\",\n disabled=False,\n style=style,\n layout=widgets.Layout(width=widget_width, padding=padding),\n )\n\n fg_count = widgets.IntText(\n value=0,\n description=\"Foreground #:\",\n disabled=True,\n style=style,\n layout=widgets.Layout(width=\"135px\", padding=padding),\n )\n bg_count = widgets.IntText(\n value=0,\n description=\"Background #:\",\n disabled=True,\n style=style,\n layout=widgets.Layout(width=\"135px\", padding=padding),\n )\n\n segment_button = widgets.ToggleButton(\n description=\"Segment\",\n value=False,\n button_style=\"primary\",\n layout=widgets.Layout(padding=padding),\n )\n\n save_button = widgets.ToggleButton(\n description=\"Save\", value=False, button_style=\"primary\"\n )\n\n reset_button = widgets.ToggleButton(\n description=\"Reset\", value=False, button_style=\"primary\"\n )\n segment_button.layout.width = button_width\n save_button.layout.width = button_width\n reset_button.layout.width = button_width\n\n opacity_slider = widgets.FloatSlider(\n description=\"Mask opacity:\",\n min=0,\n max=1,\n value=0.5,\n readout=True,\n continuous_update=True,\n layout=widgets.Layout(width=widget_width, padding=padding),\n style=style,\n )\n\n rectangular = widgets.Checkbox(\n value=False,\n description=\"Regularize\",\n layout=widgets.Layout(width=\"130px\", padding=padding),\n style=style,\n )\n\n colorpicker = widgets.ColorPicker(\n concise=False,\n description=\"Color\",\n value=\"#ffff00\",\n layout=widgets.Layout(width=\"140px\", padding=padding),\n style=style,\n )\n\n buttons = widgets.VBox(\n [\n 
radio_buttons,\n widgets.HBox([fg_count, bg_count]),\n opacity_slider,\n widgets.HBox([rectangular, colorpicker]),\n widgets.HBox(\n [segment_button, save_button, reset_button],\n layout=widgets.Layout(padding=\"0px 4px 0px 4px\"),\n ),\n ]\n )\n\n def opacity_changed(change):\n if change[\"new\"]:\n mask_layer = m.find_layer(\"Masks\")\n if mask_layer is not None:\n mask_layer.interact(opacity=opacity_slider.value)\n\n opacity_slider.observe(opacity_changed, \"value\")\n\n output = widgets.Output(\n layout=widgets.Layout(\n width=widget_width, padding=padding, max_width=widget_width\n )\n )\n\n toolbar_header = widgets.HBox()\n toolbar_header.children = [close_button, plus_button, minus_button, toolbar_button]\n toolbar_footer = widgets.VBox()\n toolbar_footer.children = [\n buttons,\n output,\n ]\n toolbar_widget = widgets.VBox()\n toolbar_widget.children = [toolbar_header, toolbar_footer]\n\n toolbar_event = ipyevents.Event(\n source=toolbar_widget, watched_events=[\"mouseenter\", \"mouseleave\"]\n )\n\n def marker_callback(chooser):\n with output:\n if chooser.selected is not None:\n try:\n gdf = gpd.read_file(chooser.selected)\n centroids = gdf.centroid\n coords = [[point.x, point.y] for point in centroids]\n for coord in coords:\n if plus_button.value:\n if is_colab(): # Colab does not support AwesomeIcon\n marker = ipyleaflet.CircleMarker(\n location=(coord[1], coord[0]),\n radius=2,\n color=\"green\",\n fill_color=\"green\",\n )\n else:\n marker = ipyleaflet.Marker(\n location=[coord[1], coord[0]],\n icon=ipyleaflet.AwesomeIcon(\n name=\"plus-circle\",\n marker_color=\"green\",\n icon_color=\"darkred\",\n ),\n )\n m.fg_layer.add(marker)\n m.fg_markers.append(marker)\n fg_count.value = len(m.fg_markers)\n elif minus_button.value:\n if is_colab():\n marker = ipyleaflet.CircleMarker(\n location=(coord[1], coord[0]),\n radius=2,\n color=\"red\",\n fill_color=\"red\",\n )\n else:\n marker = ipyleaflet.Marker(\n location=[coord[1], coord[0]],\n 
icon=ipyleaflet.AwesomeIcon(\n name=\"minus-circle\",\n marker_color=\"red\",\n icon_color=\"darkred\",\n ),\n )\n m.bg_layer.add(marker)\n m.bg_markers.append(marker)\n bg_count.value = len(m.bg_markers)\n\n except Exception as e:\n print(e)\n\n if m.marker_control in m.controls:\n m.remove_control(m.marker_control)\n delattr(m, \"marker_control\")\n\n plus_button.value = False\n minus_button.value = False\n\n def marker_button_click(change):\n if change[\"new\"]:\n sandbox_path = os.environ.get(\"SANDBOX_PATH\")\n filechooser = FileChooser(\n path=os.getcwd(),\n sandbox_path=sandbox_path,\n layout=widgets.Layout(width=\"454px\"),\n )\n filechooser.use_dir_icons = True\n filechooser.filter_pattern = [\"*.shp\", \"*.geojson\", \"*.gpkg\"]\n filechooser.register_callback(marker_callback)\n marker_control = ipyleaflet.WidgetControl(\n widget=filechooser, position=\"topright\"\n )\n m.add_control(marker_control)\n m.marker_control = marker_control\n else:\n if hasattr(m, \"marker_control\") and m.marker_control in m.controls:\n m.remove_control(m.marker_control)\n m.marker_control.close()\n\n plus_button.observe(marker_button_click, \"value\")\n minus_button.observe(marker_button_click, \"value\")\n\n def handle_toolbar_event(event):\n if event[\"type\"] == \"mouseenter\":\n toolbar_widget.children = [toolbar_header, toolbar_footer]\n elif event[\"type\"] == \"mouseleave\":\n if not toolbar_button.value:\n toolbar_widget.children = [toolbar_button]\n toolbar_button.value = False\n close_button.value = False\n\n toolbar_event.on_dom_event(handle_toolbar_event)\n\n def toolbar_btn_click(change):\n if change[\"new\"]:\n close_button.value = False\n toolbar_widget.children = [toolbar_header, toolbar_footer]\n else:\n if not close_button.value:\n toolbar_widget.children = [toolbar_button]\n\n toolbar_button.observe(toolbar_btn_click, \"value\")\n\n def close_btn_click(change):\n if change[\"new\"]:\n toolbar_button.value = False\n if m.toolbar_control in m.controls:\n 
m.remove_control(m.toolbar_control)\n toolbar_widget.close()\n\n close_button.observe(close_btn_click, \"value\")\n\n def handle_map_interaction(**kwargs):\n try:\n if kwargs.get(\"type\") == \"click\":\n latlon = kwargs.get(\"coordinates\")\n if radio_buttons.value == \"Foreground\":\n if is_colab():\n marker = ipyleaflet.CircleMarker(\n location=tuple(latlon),\n radius=2,\n color=\"green\",\n fill_color=\"green\",\n )\n else:\n marker = ipyleaflet.Marker(\n location=latlon,\n icon=ipyleaflet.AwesomeIcon(\n name=\"plus-circle\",\n marker_color=\"green\",\n icon_color=\"darkred\",\n ),\n )\n fg_layer.add(marker)\n m.fg_markers.append(marker)\n fg_count.value = len(m.fg_markers)\n elif radio_buttons.value == \"Background\":\n if is_colab():\n marker = ipyleaflet.CircleMarker(\n location=tuple(latlon),\n radius=2,\n color=\"red\",\n fill_color=\"red\",\n )\n else:\n marker = ipyleaflet.Marker(\n location=latlon,\n icon=ipyleaflet.AwesomeIcon(\n name=\"minus-circle\",\n marker_color=\"red\",\n icon_color=\"darkred\",\n ),\n )\n bg_layer.add(marker)\n m.bg_markers.append(marker)\n bg_count.value = len(m.bg_markers)\n\n except (TypeError, KeyError) as e:\n print(f\"Error handling map interaction: {e}\")\n\n m.on_interaction(handle_map_interaction)\n\n def segment_button_click(change):\n if change[\"new\"]:\n segment_button.value = False\n with output:\n output.clear_output()\n if len(m.fg_markers) == 0:\n print(\"Please add some foreground markers.\")\n segment_button.value = False\n return\n\n else:\n try:\n fg_points = [\n [marker.location[1], marker.location[0]]\n for marker in m.fg_markers\n ]\n bg_points = [\n [marker.location[1], marker.location[0]]\n for marker in m.bg_markers\n ]\n point_coords = fg_points + bg_points\n point_labels = [1] * len(fg_points) + [0] * len(bg_points)\n\n filename = f\"masks_{random_string()}.tif\"\n filename = os.path.join(out_dir, filename)\n sam.predict(\n point_coords=point_coords,\n point_labels=point_labels,\n 
point_crs=\"EPSG:4326\",\n output=filename,\n )\n if m.find_layer(\"Masks\") is not None:\n m.remove_layer(m.find_layer(\"Masks\"))\n if m.find_layer(\"Regularized\") is not None:\n m.remove_layer(m.find_layer(\"Regularized\"))\n\n if hasattr(sam, \"prediction_fp\") and os.path.exists(\n sam.prediction_fp\n ):\n try:\n os.remove(sam.prediction_fp)\n except:\n pass\n\n # Skip the image layer if localtileserver is not available\n try:\n m.add_raster(\n filename,\n nodata=0,\n cmap=\"Blues\",\n opacity=opacity_slider.value,\n layer_name=\"Masks\",\n zoom_to_layer=False,\n )\n\n if rectangular.value:\n vector = filename.replace(\".tif\", \".gpkg\")\n vector_rec = filename.replace(\".tif\", \"_rect.gpkg\")\n raster_to_vector(filename, vector)\n regularize(vector, vector_rec)\n vector_style = {\"color\": colorpicker.value}\n m.add_vector(\n vector_rec,\n layer_name=\"Regularized\",\n style=vector_style,\n info_mode=None,\n zoom_to_layer=False,\n )\n\n except:\n pass\n output.clear_output()\n segment_button.value = False\n sam.prediction_fp = filename\n except Exception as e:\n segment_button.value = False\n print(e)\n\n segment_button.observe(segment_button_click, \"value\")\n\n def filechooser_callback(chooser):\n with output:\n if chooser.selected is not None:\n try:\n filename = chooser.selected\n shutil.copy(sam.prediction_fp, filename)\n vector = filename.replace(\".tif\", \".gpkg\")\n raster_to_vector(filename, vector)\n if rectangular.value:\n vector_rec = filename.replace(\".tif\", \"_rect.gpkg\")\n regularize(vector, vector_rec)\n\n fg_points = [\n [marker.location[1], marker.location[0]]\n for marker in m.fg_markers\n ]\n bg_points = [\n [marker.location[1], marker.location[0]]\n for marker in m.bg_markers\n ]\n\n coords_to_geojson(\n fg_points, filename.replace(\".tif\", \"_fg_markers.geojson\")\n )\n coords_to_geojson(\n bg_points, filename.replace(\".tif\", \"_bg_markers.geojson\")\n )\n\n except Exception as e:\n print(e)\n\n if hasattr(m, \"save_control\") 
and m.save_control in m.controls:\n m.remove_control(m.save_control)\n delattr(m, \"save_control\")\n save_button.value = False\n\n def save_button_click(change):\n if change[\"new\"]:\n with output:\n sandbox_path = os.environ.get(\"SANDBOX_PATH\")\n filechooser = FileChooser(\n path=os.getcwd(),\n filename=\"masks.tif\",\n sandbox_path=sandbox_path,\n layout=widgets.Layout(width=\"454px\"),\n )\n filechooser.use_dir_icons = True\n filechooser.filter_pattern = [\"*.tif\"]\n filechooser.register_callback(filechooser_callback)\n save_control = ipyleaflet.WidgetControl(\n widget=filechooser, position=\"topright\"\n )\n m.add_control(save_control)\n m.save_control = save_control\n else:\n if hasattr(m, \"save_control\") and m.save_control in m.controls:\n m.remove_control(m.save_control)\n delattr(m, \"save_control\")\n\n save_button.observe(save_button_click, \"value\")\n\n def reset_button_click(change):\n if change[\"new\"]:\n segment_button.value = False\n reset_button.value = False\n opacity_slider.value = 0.5\n rectangular.value = False\n colorpicker.value = \"#ffff00\"\n output.clear_output()\n try:\n m.remove_layer(m.find_layer(\"Masks\"))\n if m.find_layer(\"Regularized\") is not None:\n m.remove_layer(m.find_layer(\"Regularized\"))\n m.clear_drawings()\n if hasattr(m, \"fg_markers\"):\n m.user_rois = None\n m.fg_markers = []\n m.bg_markers = []\n m.fg_layer.clear_layers()\n m.bg_layer.clear_layers()\n fg_count.value = 0\n bg_count.value = 0\n try:\n os.remove(sam.prediction_fp)\n except:\n pass\n except:\n pass\n\n reset_button.observe(reset_button_click, \"value\")\n\n toolbar_control = ipyleaflet.WidgetControl(\n widget=toolbar_widget, position=\"topright\"\n )\n m.add_control(toolbar_control)\n m.toolbar_control = toolbar_control\n\n return m\n
"},{"location":"common/#samgeo.common.show_canvas","title":"show_canvas(image, fg_color=(0, 255, 0), bg_color=(0, 0, 255), radius=5)
","text":"Show a canvas to collect foreground and background points.
Parameters:
Name Type Description Default
image
str | np.ndarray
The input image, given as a file path, an HTTP URL, or a NumPy array.
required
fg_color
tuple
The color for the foreground points. Defaults to (0, 255, 0).
(0, 255, 0)
bg_color
tuple
The color for the background points. Defaults to (0, 0, 255).
(0, 0, 255)
radius
int
The radius of the points. Defaults to 5.
5
Returns:
Type Description
tuple
A tuple of two lists of foreground and background points.
Source code in samgeo/common.py
def show_canvas(image, fg_color=(0, 255, 0), bg_color=(0, 0, 255), radius=5):\n\"\"\"Show a canvas to collect foreground and background points.\n\n Args:\n image (str | np.ndarray): The input image.\n fg_color (tuple, optional): The color for the foreground points. Defaults to (0, 255, 0).\n bg_color (tuple, optional): The color for the background points. Defaults to (0, 0, 255).\n radius (int, optional): The radius of the points. Defaults to 5.\n\n Returns:\n tuple: A tuple of two lists of foreground and background points.\n \"\"\"\n if isinstance(image, str):\n if image.startswith(\"http\"):\n image = download_file(image)\n\n image = cv2.imread(image)\n elif isinstance(image, np.ndarray):\n pass\n else:\n raise ValueError(\"Input image must be a URL or a NumPy array.\")\n\n # Create an empty list to store the mouse click coordinates\n left_clicks = []\n right_clicks = []\n\n # Create a mouse callback function\n def get_mouse_coordinates(event, x, y):\n if event == cv2.EVENT_LBUTTONDOWN:\n # Append the coordinates to the mouse_clicks list\n left_clicks.append((x, y))\n\n # Draw a green circle at the mouse click coordinates\n cv2.circle(image, (x, y), radius, fg_color, -1)\n\n # Show the updated image with the circle\n cv2.imshow(\"Image\", image)\n\n elif event == cv2.EVENT_RBUTTONDOWN:\n # Append the coordinates to the mouse_clicks list\n right_clicks.append((x, y))\n\n # Draw a red circle at the mouse click coordinates\n cv2.circle(image, (x, y), radius, bg_color, -1)\n\n # Show the updated image with the circle\n cv2.imshow(\"Image\", image)\n\n # Create a window to display the image\n cv2.namedWindow(\"Image\")\n\n # Set the mouse callback function for the window\n cv2.setMouseCallback(\"Image\", get_mouse_coordinates)\n\n # Display the image in the window\n cv2.imshow(\"Image\", image)\n\n # Wait for a key press to exit\n cv2.waitKey(0)\n\n # Destroy the window\n cv2.destroyAllWindows()\n\n return left_clicks, right_clicks\n
"},{"location":"common/#samgeo.common.split_raster","title":"split_raster(filename, out_dir, tile_size=256, overlap=0)
","text":"Split a raster into tiles.
Parameters:
Name Type Description Default
filename
str
The path or http URL to the raster file.
required
out_dir
str
The path to the output directory.
required
tile_size
int | tuple
The size of the tiles. Can be an integer or a tuple of (width, height). Defaults to 256.
256
overlap
int
The number of pixels to overlap between tiles. Defaults to 0.
0
Exceptions:
Type Description
ImportError
Raised if GDAL is not installed.
Source code in samgeo/common.py
def split_raster(filename, out_dir, tile_size=256, overlap=0):\n\"\"\"Split a raster into tiles.\n\n Args:\n filename (str): The path or http URL to the raster file.\n out_dir (str): The path to the output directory.\n tile_size (int | tuple, optional): The size of the tiles. Can be an integer or a tuple of (width, height). Defaults to 256.\n overlap (int, optional): The number of pixels to overlap between tiles. Defaults to 0.\n\n Raises:\n ImportError: Raised if GDAL is not installed.\n \"\"\"\n\n try:\n from osgeo import gdal\n except ImportError:\n raise ImportError(\n \"GDAL is required to use this function. Install it with `conda install gdal -c conda-forge`\"\n )\n\n if isinstance(filename, str):\n if filename.startswith(\"http\"):\n output = filename.split(\"/\")[-1]\n download_file(filename, output)\n filename = output\n\n # Open the input GeoTIFF file\n ds = gdal.Open(filename)\n\n if not os.path.exists(out_dir):\n os.makedirs(out_dir)\n\n if isinstance(tile_size, int):\n tile_width = tile_size\n tile_height = tile_size\n elif isinstance(tile_size, tuple):\n tile_width = tile_size[0]\n tile_height = tile_size[1]\n\n # Get the size of the input raster\n width = ds.RasterXSize\n height = ds.RasterYSize\n\n # Calculate the number of tiles needed in both directions, taking into account the overlap\n num_tiles_x = (width - overlap) // (tile_width - overlap) + int(\n (width - overlap) % (tile_width - overlap) > 0\n )\n num_tiles_y = (height - overlap) // (tile_height - overlap) + int(\n (height - overlap) % (tile_height - overlap) > 0\n )\n\n # Get the georeferencing information of the input raster\n geotransform = ds.GetGeoTransform()\n\n # Loop over all the tiles\n for i in range(num_tiles_x):\n for j in range(num_tiles_y):\n # Calculate the pixel coordinates of the tile, taking into account the overlap and clamping to the edge of the raster\n x_min = i * (tile_width - overlap)\n y_min = j * (tile_height - overlap)\n x_max = min(x_min + tile_width, width)\n 
y_max = min(y_min + tile_height, height)\n\n # Adjust the size of the last tile in each row and column to include any remaining pixels\n if i == num_tiles_x - 1:\n x_min = max(x_max - tile_width, 0)\n if j == num_tiles_y - 1:\n y_min = max(y_max - tile_height, 0)\n\n # Calculate the size of the tile, taking into account the overlap\n tile_width = x_max - x_min\n tile_height = y_max - y_min\n\n # Set the output file name\n output_file = f\"{out_dir}/tile_{i}_{j}.tif\"\n\n # Create a new dataset for the tile\n driver = gdal.GetDriverByName(\"GTiff\")\n tile_ds = driver.Create(\n output_file,\n tile_width,\n tile_height,\n ds.RasterCount,\n ds.GetRasterBand(1).DataType,\n )\n\n # Calculate the georeferencing information for the output tile\n tile_geotransform = (\n geotransform[0] + x_min * geotransform[1],\n geotransform[1],\n 0,\n geotransform[3] + y_min * geotransform[5],\n 0,\n geotransform[5],\n )\n\n # Set the geotransform and projection of the tile\n tile_ds.SetGeoTransform(tile_geotransform)\n tile_ds.SetProjection(ds.GetProjection())\n\n # Read the data from the input raster band(s) and write it to the tile band(s)\n for k in range(ds.RasterCount):\n band = ds.GetRasterBand(k + 1)\n tile_band = tile_ds.GetRasterBand(k + 1)\n tile_data = band.ReadAsArray(x_min, y_min, tile_width, tile_height)\n tile_band.WriteArray(tile_data)\n\n # Close the tile dataset\n tile_ds = None\n\n # Close the input dataset\n ds = None\n
"},{"location":"common/#samgeo.common.temp_file_path","title":"temp_file_path(extension)
","text":"Returns a temporary file path.
Parameters:
Name Type Description Default
extension
str
The file extension.
required
Returns:
Type Description
str
The temporary file path.
Source code in samgeo/common.py
def temp_file_path(extension):\n\"\"\"Returns a temporary file path.\n\n Args:\n extension (str): The file extension.\n\n Returns:\n str: The temporary file path.\n \"\"\"\n\n import tempfile\n import uuid\n\n if not extension.startswith(\".\"):\n extension = \".\" + extension\n file_id = str(uuid.uuid4())\n file_path = os.path.join(tempfile.gettempdir(), f\"{file_id}{extension}\")\n\n return file_path\n
"},{"location":"common/#samgeo.common.text_sam_gui","title":"text_sam_gui(sam, basemap='SATELLITE', out_dir=None, box_threshold=0.25, text_threshold=0.25, cmap='viridis', opacity=0.5, **kwargs)
","text":"Display the SAM Map GUI.
Parameters:
Name Type Description Default
sam
SamGeo
required basemap
str
The basemap to use. Defaults to \"SATELLITE\".
'SATELLITE'
out_dir
str
The output directory. Defaults to None.
None
Source code in samgeo/common.py
def text_sam_gui(\n sam,\n basemap=\"SATELLITE\",\n out_dir=None,\n box_threshold=0.25,\n text_threshold=0.25,\n cmap=\"viridis\",\n opacity=0.5,\n **kwargs,\n):\n\"\"\"Display the SAM Map GUI.\n\n Args:\n sam (SamGeo):\n basemap (str, optional): The basemap to use. Defaults to \"SATELLITE\".\n out_dir (str, optional): The output directory. Defaults to None.\n\n \"\"\"\n try:\n import shutil\n import tempfile\n import leafmap\n import ipyleaflet\n import ipyevents\n import ipywidgets as widgets\n import leafmap.colormaps as cm\n from ipyfilechooser import FileChooser\n except ImportError:\n raise ImportError(\n \"The sam_map function requires the leafmap package. Please install it first.\"\n )\n\n if out_dir is None:\n out_dir = tempfile.gettempdir()\n\n m = leafmap.Map(**kwargs)\n m.default_style = {\"cursor\": \"crosshair\"}\n m.add_basemap(basemap, show=False)\n\n # Skip the image layer if localtileserver is not available\n try:\n m.add_raster(sam.source, layer_name=\"Image\")\n except:\n pass\n\n widget_width = \"280px\"\n button_width = \"90px\"\n padding = \"0px 4px 0px 4px\" # upper, right, bottom, left\n style = {\"description_width\": \"initial\"}\n\n toolbar_button = widgets.ToggleButton(\n value=True,\n tooltip=\"Toolbar\",\n icon=\"gear\",\n layout=widgets.Layout(width=\"28px\", height=\"28px\", padding=\"0px 0px 0px 4px\"),\n )\n\n close_button = widgets.ToggleButton(\n value=False,\n tooltip=\"Close the tool\",\n icon=\"times\",\n button_style=\"primary\",\n layout=widgets.Layout(height=\"28px\", width=\"28px\", padding=\"0px 0px 0px 4px\"),\n )\n\n text_prompt = widgets.Text(\n description=\"Text prompt:\",\n style=style,\n layout=widgets.Layout(width=widget_width, padding=padding),\n )\n\n box_slider = widgets.FloatSlider(\n description=\"Box threshold:\",\n min=0,\n max=1,\n value=box_threshold,\n step=0.01,\n readout=True,\n continuous_update=True,\n layout=widgets.Layout(width=widget_width, padding=padding),\n style=style,\n )\n\n text_slider = 
widgets.FloatSlider(\n description=\"Text threshold:\",\n min=0,\n max=1,\n step=0.01,\n value=text_threshold,\n readout=True,\n continuous_update=True,\n layout=widgets.Layout(width=widget_width, padding=padding),\n style=style,\n )\n\n cmap_dropdown = widgets.Dropdown(\n description=\"Palette:\",\n options=cm.list_colormaps(),\n value=cmap,\n style=style,\n layout=widgets.Layout(width=widget_width, padding=padding),\n )\n\n opacity_slider = widgets.FloatSlider(\n description=\"Opacity:\",\n min=0,\n max=1,\n value=opacity,\n readout=True,\n continuous_update=True,\n layout=widgets.Layout(width=widget_width, padding=padding),\n style=style,\n )\n\n def opacity_changed(change):\n if change[\"new\"]:\n if hasattr(m, \"layer_name\"):\n mask_layer = m.find_layer(m.layer_name)\n if mask_layer is not None:\n mask_layer.interact(opacity=opacity_slider.value)\n\n opacity_slider.observe(opacity_changed, \"value\")\n\n rectangular = widgets.Checkbox(\n value=False,\n description=\"Regularize\",\n layout=widgets.Layout(width=\"130px\", padding=padding),\n style=style,\n )\n\n colorpicker = widgets.ColorPicker(\n concise=False,\n description=\"Color\",\n value=\"#ffff00\",\n layout=widgets.Layout(width=\"140px\", padding=padding),\n style=style,\n )\n\n segment_button = widgets.ToggleButton(\n description=\"Segment\",\n value=False,\n button_style=\"primary\",\n layout=widgets.Layout(padding=padding),\n )\n\n save_button = widgets.ToggleButton(\n description=\"Save\", value=False, button_style=\"primary\"\n )\n\n reset_button = widgets.ToggleButton(\n description=\"Reset\", value=False, button_style=\"primary\"\n )\n segment_button.layout.width = button_width\n save_button.layout.width = button_width\n reset_button.layout.width = button_width\n\n output = widgets.Output(\n layout=widgets.Layout(\n width=widget_width, padding=padding, max_width=widget_width\n )\n )\n\n toolbar_header = widgets.HBox()\n toolbar_header.children = [close_button, toolbar_button]\n toolbar_footer = 
widgets.VBox()\n toolbar_footer.children = [\n text_prompt,\n box_slider,\n text_slider,\n cmap_dropdown,\n opacity_slider,\n widgets.HBox([rectangular, colorpicker]),\n widgets.HBox(\n [segment_button, save_button, reset_button],\n layout=widgets.Layout(padding=\"0px 4px 0px 4px\"),\n ),\n output,\n ]\n toolbar_widget = widgets.VBox()\n toolbar_widget.children = [toolbar_header, toolbar_footer]\n\n toolbar_event = ipyevents.Event(\n source=toolbar_widget, watched_events=[\"mouseenter\", \"mouseleave\"]\n )\n\n def handle_toolbar_event(event):\n if event[\"type\"] == \"mouseenter\":\n toolbar_widget.children = [toolbar_header, toolbar_footer]\n elif event[\"type\"] == \"mouseleave\":\n if not toolbar_button.value:\n toolbar_widget.children = [toolbar_button]\n toolbar_button.value = False\n close_button.value = False\n\n toolbar_event.on_dom_event(handle_toolbar_event)\n\n def toolbar_btn_click(change):\n if change[\"new\"]:\n close_button.value = False\n toolbar_widget.children = [toolbar_header, toolbar_footer]\n else:\n if not close_button.value:\n toolbar_widget.children = [toolbar_button]\n\n toolbar_button.observe(toolbar_btn_click, \"value\")\n\n def close_btn_click(change):\n if change[\"new\"]:\n toolbar_button.value = False\n if m.toolbar_control in m.controls:\n m.remove_control(m.toolbar_control)\n toolbar_widget.close()\n\n close_button.observe(close_btn_click, \"value\")\n\n def segment_button_click(change):\n if change[\"new\"]:\n segment_button.value = False\n with output:\n output.clear_output()\n if len(text_prompt.value) == 0:\n print(\"Please enter a text prompt first.\")\n elif sam.source is None:\n print(\"Please run sam.set_image() first.\")\n else:\n print(\"Segmenting...\")\n layer_name = text_prompt.value.replace(\" \", \"_\")\n filename = os.path.join(\n out_dir, f\"{layer_name}_{random_string()}.tif\"\n )\n try:\n sam.predict(\n sam.source,\n text_prompt.value,\n box_slider.value,\n text_slider.value,\n output=filename,\n )\n sam.output 
= filename\n if m.find_layer(layer_name) is not None:\n m.remove_layer(m.find_layer(layer_name))\n if m.find_layer(f\"{layer_name}_rect\") is not None:\n m.remove_layer(m.find_layer(f\"{layer_name} Regularized\"))\n except Exception as e:\n output.clear_output()\n print(e)\n if os.path.exists(filename):\n try:\n m.add_raster(\n filename,\n layer_name=layer_name,\n palette=cmap_dropdown.value,\n opacity=opacity_slider.value,\n nodata=0,\n zoom_to_layer=False,\n )\n m.layer_name = layer_name\n\n if rectangular.value:\n vector = filename.replace(\".tif\", \".gpkg\")\n vector_rec = filename.replace(\".tif\", \"_rect.gpkg\")\n raster_to_vector(filename, vector)\n regularize(vector, vector_rec)\n vector_style = {\"color\": colorpicker.value}\n m.add_vector(\n vector_rec,\n layer_name=f\"{layer_name} Regularized\",\n style=vector_style,\n info_mode=None,\n zoom_to_layer=False,\n )\n\n output.clear_output()\n except Exception as e:\n print(e)\n\n segment_button.observe(segment_button_click, \"value\")\n\n def filechooser_callback(chooser):\n with output:\n if chooser.selected is not None:\n try:\n filename = chooser.selected\n shutil.copy(sam.output, filename)\n vector = filename.replace(\".tif\", \".gpkg\")\n raster_to_vector(filename, vector)\n if rectangular.value:\n vector_rec = filename.replace(\".tif\", \"_rect.gpkg\")\n regularize(vector, vector_rec)\n except Exception as e:\n print(e)\n\n if hasattr(m, \"save_control\") and m.save_control in m.controls:\n m.remove_control(m.save_control)\n delattr(m, \"save_control\")\n save_button.value = False\n\n def save_button_click(change):\n if change[\"new\"]:\n with output:\n output.clear_output()\n if not hasattr(m, \"layer_name\"):\n print(\"Please click the Segment button first.\")\n else:\n sandbox_path = os.environ.get(\"SANDBOX_PATH\")\n filechooser = FileChooser(\n path=os.getcwd(),\n filename=f\"{m.layer_name}.tif\",\n sandbox_path=sandbox_path,\n layout=widgets.Layout(width=\"454px\"),\n )\n 
filechooser.use_dir_icons = True\n filechooser.filter_pattern = [\"*.tif\"]\n filechooser.register_callback(filechooser_callback)\n save_control = ipyleaflet.WidgetControl(\n widget=filechooser, position=\"topright\"\n )\n m.add_control(save_control)\n m.save_control = save_control\n\n else:\n if hasattr(m, \"save_control\") and m.save_control in m.controls:\n m.remove_control(m.save_control)\n delattr(m, \"save_control\")\n\n save_button.observe(save_button_click, \"value\")\n\n def reset_button_click(change):\n if change[\"new\"]:\n segment_button.value = False\n save_button.value = False\n reset_button.value = False\n opacity_slider.value = 0.5\n box_slider.value = 0.25\n text_slider.value = 0.25\n cmap_dropdown.value = \"viridis\"\n text_prompt.value = \"\"\n output.clear_output()\n try:\n if hasattr(m, \"layer_name\") and m.find_layer(m.layer_name) is not None:\n m.remove_layer(m.find_layer(m.layer_name))\n m.clear_drawings()\n except:\n pass\n\n reset_button.observe(reset_button_click, \"value\")\n\n toolbar_control = ipyleaflet.WidgetControl(\n widget=toolbar_widget, position=\"topright\"\n )\n m.add_control(toolbar_control)\n m.toolbar_control = toolbar_control\n\n return m\n
"},{"location":"common/#samgeo.common.tms_to_geotiff","title":"tms_to_geotiff(output, bbox, zoom=None, resolution=None, source='OpenStreetMap', crs='EPSG:3857', to_cog=False, return_image=False, overwrite=False, quiet=False, **kwargs)
","text":"Download TMS tiles and convert them to a GeoTIFF. The source is adapted from https://github.com/gumblex/tms2geotiff. Credits to the GitHub user @gumblex.
Parameters:
Name Type Description Defaultoutput
str
The output GeoTIFF file.
requiredbbox
list
The bounding box [minx, miny, maxx, maxy], e.g., [-122.5216, 37.733, -122.3661, 37.8095]
requiredzoom
int
The map zoom level. Defaults to None.
None
resolution
float
The resolution in meters. Defaults to None.
None
source
str
The tile source. It can be one of the following: \"OPENSTREETMAP\", \"ROADMAP\", \"SATELLITE\", \"TERRAIN\", \"HYBRID\", or an HTTP URL. Defaults to \"OpenStreetMap\".
'OpenStreetMap'
crs
str
The output CRS. Defaults to \"EPSG:3857\".
'EPSG:3857'
to_cog
bool
Convert to Cloud Optimized GeoTIFF. Defaults to False.
False
return_image
bool
Return the image as PIL.Image. Defaults to False.
False
overwrite
bool
Overwrite the output file if it already exists. Defaults to False.
False
quiet
bool
Suppress output. Defaults to False.
False
**kwargs
Additional arguments to pass to gdal.GetDriverByName(\"GTiff\").Create().
{}
Source code in samgeo/common.py
def tms_to_geotiff(\n output,\n bbox,\n zoom=None,\n resolution=None,\n source=\"OpenStreetMap\",\n crs=\"EPSG:3857\",\n to_cog=False,\n return_image=False,\n overwrite=False,\n quiet=False,\n **kwargs,\n):\n\"\"\"Download TMS tiles and convert them to a GeoTIFF. The source is adapted from https://github.com/gumblex/tms2geotiff.\n Credits to the GitHub user @gumblex.\n\n Args:\n output (str): The output GeoTIFF file.\n bbox (list): The bounding box [minx, miny, maxx, maxy], e.g., [-122.5216, 37.733, -122.3661, 37.8095]\n zoom (int, optional): The map zoom level. Defaults to None.\n resolution (float, optional): The resolution in meters. Defaults to None.\n source (str, optional): The tile source. It can be one of the following: \"OPENSTREETMAP\", \"ROADMAP\",\n \"SATELLITE\", \"TERRAIN\", \"HYBRID\", or an HTTP URL. Defaults to \"OpenStreetMap\".\n crs (str, optional): The output CRS. Defaults to \"EPSG:3857\".\n to_cog (bool, optional): Convert to Cloud Optimized GeoTIFF. Defaults to False.\n return_image (bool, optional): Return the image as PIL.Image. Defaults to False.\n overwrite (bool, optional): Overwrite the output file if it already exists. Defaults to False.\n quiet (bool, optional): Suppress output. Defaults to False.\n **kwargs: Additional arguments to pass to gdal.GetDriverByName(\"GTiff\").Create().\n\n \"\"\"\n\n import os\n import io\n import math\n import itertools\n import concurrent.futures\n\n import numpy\n from PIL import Image\n\n try:\n from osgeo import gdal, osr\n except ImportError:\n raise ImportError(\"GDAL is not installed. Install it with pip install GDAL\")\n\n try:\n import httpx\n\n SESSION = httpx.Client()\n except ImportError:\n import requests\n\n SESSION = requests.Session()\n\n if not overwrite and os.path.exists(output):\n print(\n f\"The output file {output} already exists. 
Use `overwrite=True` to overwrite it.\"\n )\n return\n\n xyz_tiles = {\n \"OPENSTREETMAP\": \"https://tile.openstreetmap.org/{z}/{x}/{y}.png\",\n \"ROADMAP\": \"https://mt1.google.com/vt/lyrs=m&x={x}&y={y}&z={z}\",\n \"SATELLITE\": \"https://mt1.google.com/vt/lyrs=s&x={x}&y={y}&z={z}\",\n \"TERRAIN\": \"https://mt1.google.com/vt/lyrs=p&x={x}&y={y}&z={z}\",\n \"HYBRID\": \"https://mt1.google.com/vt/lyrs=y&x={x}&y={y}&z={z}\",\n }\n\n basemaps = get_basemaps()\n\n if isinstance(source, str):\n if source.upper() in xyz_tiles:\n source = xyz_tiles[source.upper()]\n elif source in basemaps:\n source = basemaps[source]\n elif source.startswith(\"http\"):\n pass\n else:\n raise ValueError(\n 'source must be one of \"OpenStreetMap\", \"ROADMAP\", \"SATELLITE\", \"TERRAIN\", \"HYBRID\", or a URL'\n )\n\n def resolution_to_zoom_level(resolution):\n\"\"\"\n Convert map resolution in meters to zoom level for Web Mercator (EPSG:3857) tiles.\n \"\"\"\n # Web Mercator tile size in meters at zoom level 0\n initial_resolution = 156543.03392804097\n\n # Calculate the zoom level\n zoom_level = math.log2(initial_resolution / resolution)\n\n return int(zoom_level)\n\n if isinstance(bbox, list) and len(bbox) == 4:\n west, south, east, north = bbox\n else:\n raise ValueError(\n \"bbox must be a list of 4 coordinates in the format of [xmin, ymin, xmax, ymax]\"\n )\n\n if zoom is None and resolution is None:\n raise ValueError(\"Either zoom or resolution must be provided\")\n elif zoom is not None and resolution is not None:\n raise ValueError(\"Only one of zoom or resolution can be provided\")\n\n if resolution is not None:\n zoom = resolution_to_zoom_level(resolution)\n\n EARTH_EQUATORIAL_RADIUS = 6378137.0\n\n Image.MAX_IMAGE_PIXELS = None\n\n gdal.UseExceptions()\n web_mercator = osr.SpatialReference()\n web_mercator.ImportFromEPSG(3857)\n\n WKT_3857 = web_mercator.ExportToWkt()\n\n def from4326_to3857(lat, lon):\n xtile = math.radians(lon) * EARTH_EQUATORIAL_RADIUS\n ytile = (\n 
math.log(math.tan(math.radians(45 + lat / 2.0))) * EARTH_EQUATORIAL_RADIUS\n )\n return (xtile, ytile)\n\n def deg2num(lat, lon, zoom):\n lat_r = math.radians(lat)\n n = 2**zoom\n xtile = (lon + 180) / 360 * n\n ytile = (1 - math.log(math.tan(lat_r) + 1 / math.cos(lat_r)) / math.pi) / 2 * n\n return (xtile, ytile)\n\n def is_empty(im):\n extrema = im.getextrema()\n if len(extrema) >= 3:\n if len(extrema) > 3 and extrema[-1] == (0, 0):\n return True\n for ext in extrema[:3]:\n if ext != (0, 0):\n return False\n return True\n else:\n return extrema[0] == (0, 0)\n\n def paste_tile(bigim, base_size, tile, corner_xy, bbox):\n if tile is None:\n return bigim\n im = Image.open(io.BytesIO(tile))\n mode = \"RGB\" if im.mode == \"RGB\" else \"RGBA\"\n size = im.size\n if bigim is None:\n base_size[0] = size[0]\n base_size[1] = size[1]\n newim = Image.new(\n mode, (size[0] * (bbox[2] - bbox[0]), size[1] * (bbox[3] - bbox[1]))\n )\n else:\n newim = bigim\n\n dx = abs(corner_xy[0] - bbox[0])\n dy = abs(corner_xy[1] - bbox[1])\n xy0 = (size[0] * dx, size[1] * dy)\n if mode == \"RGB\":\n newim.paste(im, xy0)\n else:\n if im.mode != mode:\n im = im.convert(mode)\n if not is_empty(im):\n newim.paste(im, xy0)\n im.close()\n return newim\n\n def finish_picture(bigim, base_size, bbox, x0, y0, x1, y1):\n xfrac = x0 - bbox[0]\n yfrac = y0 - bbox[1]\n x2 = round(base_size[0] * xfrac)\n y2 = round(base_size[1] * yfrac)\n imgw = round(base_size[0] * (x1 - x0))\n imgh = round(base_size[1] * (y1 - y0))\n retim = bigim.crop((x2, y2, x2 + imgw, y2 + imgh))\n if retim.mode == \"RGBA\" and retim.getextrema()[3] == (255, 255):\n retim = retim.convert(\"RGB\")\n bigim.close()\n return retim\n\n def get_tile(url):\n retry = 3\n while 1:\n try:\n r = SESSION.get(url, timeout=60)\n break\n except Exception:\n retry -= 1\n if not retry:\n raise\n if r.status_code == 404:\n return None\n elif not r.content:\n return None\n r.raise_for_status()\n return r.content\n\n def draw_tile(\n source, lat0, lon0, 
lat1, lon1, zoom, filename, quiet=False, **kwargs\n ):\n x0, y0 = deg2num(lat0, lon0, zoom)\n x1, y1 = deg2num(lat1, lon1, zoom)\n x0, x1 = sorted([x0, x1])\n y0, y1 = sorted([y0, y1])\n corners = tuple(\n itertools.product(\n range(math.floor(x0), math.ceil(x1)),\n range(math.floor(y0), math.ceil(y1)),\n )\n )\n totalnum = len(corners)\n futures = []\n with concurrent.futures.ThreadPoolExecutor(5) as executor:\n for x, y in corners:\n futures.append(\n executor.submit(get_tile, source.format(z=zoom, x=x, y=y))\n )\n bbox = (math.floor(x0), math.floor(y0), math.ceil(x1), math.ceil(y1))\n bigim = None\n base_size = [256, 256]\n for k, (fut, corner_xy) in enumerate(zip(futures, corners), 1):\n bigim = paste_tile(bigim, base_size, fut.result(), corner_xy, bbox)\n if not quiet:\n print(\n f\"Downloaded image {str(k).zfill(len(str(totalnum)))}/{totalnum}\"\n )\n\n if not quiet:\n print(\"Saving GeoTIFF. Please wait...\")\n img = finish_picture(bigim, base_size, bbox, x0, y0, x1, y1)\n imgbands = len(img.getbands())\n driver = gdal.GetDriverByName(\"GTiff\")\n\n if \"options\" not in kwargs:\n kwargs[\"options\"] = [\n \"COMPRESS=DEFLATE\",\n \"PREDICTOR=2\",\n \"ZLEVEL=9\",\n \"TILED=YES\",\n ]\n\n gtiff = driver.Create(\n filename,\n img.size[0],\n img.size[1],\n imgbands,\n gdal.GDT_Byte,\n **kwargs,\n )\n xp0, yp0 = from4326_to3857(lat0, lon0)\n xp1, yp1 = from4326_to3857(lat1, lon1)\n pwidth = abs(xp1 - xp0) / img.size[0]\n pheight = abs(yp1 - yp0) / img.size[1]\n gtiff.SetGeoTransform((min(xp0, xp1), pwidth, 0, max(yp0, yp1), 0, -pheight))\n gtiff.SetProjection(WKT_3857)\n for band in range(imgbands):\n array = numpy.array(img.getdata(band), dtype=\"u8\")\n array = array.reshape((img.size[1], img.size[0]))\n band = gtiff.GetRasterBand(band + 1)\n band.WriteArray(array)\n gtiff.FlushCache()\n\n if not quiet:\n print(f\"Image saved to {filename}\")\n return img\n\n try:\n image = draw_tile(\n source, south, west, north, east, zoom, output, quiet, **kwargs\n )\n if 
return_image:\n return image\n if crs.upper() != \"EPSG:3857\":\n reproject(output, output, crs, to_cog=to_cog)\n elif to_cog:\n image_to_cog(output, output)\n except Exception as e:\n raise Exception(e)\n
"},{"location":"common/#samgeo.common.transform_coords","title":"transform_coords(x, y, src_crs, dst_crs, **kwargs)
","text":"Transform coordinates from one CRS to another.
Parameters:
Name Type Description Defaultx
float
The x coordinate.
requiredy
float
The y coordinate.
requiredsrc_crs
str
The source CRS, e.g., \"EPSG:4326\".
requireddst_crs
str
The destination CRS, e.g., \"EPSG:3857\".
requiredReturns:
Type Descriptiontuple
The transformed coordinates in the format of (x, y)
Source code insamgeo/common.py
def transform_coords(x, y, src_crs, dst_crs, **kwargs):\n\"\"\"Transform coordinates from one CRS to another.\n\n Args:\n x (float): The x coordinate.\n y (float): The y coordinate.\n src_crs (str): The source CRS, e.g., \"EPSG:4326\".\n dst_crs (str): The destination CRS, e.g., \"EPSG:3857\".\n\n Returns:\n dict: The transformed coordinates in the format of (x, y)\n \"\"\"\n transformer = pyproj.Transformer.from_crs(\n src_crs, dst_crs, always_xy=True, **kwargs\n )\n return transformer.transform(x, y)\n
"},{"location":"common/#samgeo.common.update_package","title":"update_package(out_dir=None, keep=False, **kwargs)
","text":"Updates the package from the main branch of the GitHub repository without waiting for a new release on PyPI or conda-forge. The downloaded source is installed with pip.
Parameters:
Name Type Description Defaultout_dir
str
The output directory. Defaults to None.
None
keep
bool
Whether to keep the downloaded package. Defaults to False.
False
**kwargs
Additional keyword arguments to pass to the download_file() function.
{}
Source code in samgeo/common.py
def update_package(out_dir=None, keep=False, **kwargs):\n\"\"\"Updates the package from the GitHub repository without the need to use pip or conda.\n\n Args:\n out_dir (str, optional): The output directory. Defaults to None.\n keep (bool, optional): Whether to keep the downloaded package. Defaults to False.\n **kwargs: Additional keyword arguments to pass to the download_file() function.\n \"\"\"\n\n import shutil\n\n try:\n if out_dir is None:\n out_dir = os.getcwd()\n url = (\n \"https://github.com/opengeos/segment-geospatial/archive/refs/heads/main.zip\"\n )\n filename = \"segment-geospatial-main.zip\"\n download_file(url, filename, **kwargs)\n\n pkg_dir = os.path.join(out_dir, \"segment-geospatial-main\")\n work_dir = os.getcwd()\n os.chdir(pkg_dir)\n\n if shutil.which(\"pip\") is None:\n cmd = \"pip3 install .\"\n else:\n cmd = \"pip install .\"\n\n os.system(cmd)\n os.chdir(work_dir)\n\n if not keep:\n shutil.rmtree(pkg_dir)\n try:\n os.remove(filename)\n except:\n pass\n\n print(\"Package updated successfully.\")\n\n except Exception as e:\n raise Exception(e)\n
"},{"location":"common/#samgeo.common.vector_to_geojson","title":"vector_to_geojson(filename, output=None, **kwargs)
","text":"Converts a vector file to a geojson file.
Parameters:
Name Type Description Defaultfilename
str
The vector file path.
requiredoutput
str
The output geojson file path. Defaults to None.
None
Returns:
Type Descriptiondict
The geojson dictionary.
Source code insamgeo/common.py
def vector_to_geojson(filename, output=None, **kwargs):\n\"\"\"Converts a vector file to a geojson file.\n\n Args:\n filename (str): The vector file path.\n output (str, optional): The output geojson file path. Defaults to None.\n\n Returns:\n dict: The geojson dictionary.\n \"\"\"\n\n if not filename.startswith(\"http\"):\n filename = download_file(filename)\n\n gdf = gpd.read_file(filename, **kwargs)\n if output is None:\n return gdf.__geo_interface__\n else:\n gdf.to_file(output, driver=\"GeoJSON\")\n
"},{"location":"contributing/","title":"Contributing","text":"Contributions are welcome, and they are greatly appreciated! Every little bit helps, and credit will always be given.
You can contribute in many ways:
"},{"location":"contributing/#types-of-contributions","title":"Types of Contributions","text":""},{"location":"contributing/#report-bugs","title":"Report Bugs","text":"Report bugs at https://github.com/giswqs/segment-geospatial/issues.
If you are reporting a bug, please include:
Look through the GitHub issues for bugs. Anything tagged with bug
and help wanted
is open to whoever wants to implement it.
Look through the GitHub issues for features. Anything tagged with enhancement
and help wanted
is open to whoever wants to implement it.
segment-geospatial could always use more documentation, whether as part of the official segment-geospatial docs, in docstrings, or even on the web in blog posts, articles, and such.
"},{"location":"contributing/#submit-feedback","title":"Submit Feedback","text":"The best way to send feedback is to file an issue at https://github.com/giswqs/segment-geospatial/issues.
If you are proposing a feature:
Ready to contribute? Here's how to set up segment-geospatial for local development.
Fork the segment-geospatial repo on GitHub.
Clone your fork locally:
$ git clone git@github.com:your_name_here/segment-geospatial.git\n
Install your local copy into a virtualenv. Assuming you have virtualenvwrapper installed, this is how you set up your fork for local development:
$ mkvirtualenv segment-geospatial\n$ cd segment-geospatial/\n$ python setup.py develop\n
Create a branch for local development:
$ git checkout -b name-of-your-bugfix-or-feature\n
Now you can make your changes locally.
When you're done making changes, check that your changes pass flake8 and the tests, including testing other Python versions with tox:
$ flake8 segment-geospatial tests\n$ python setup.py test or pytest\n$ tox\n
To get flake8 and tox, just pip install them into your virtualenv.
Commit your changes and push your branch to GitHub:
$ git add .\n$ git commit -m \"Your detailed description of your changes.\"\n$ git push origin name-of-your-bugfix-or-feature\n
Submit a pull request through the GitHub website.
Before you submit a pull request, check that it meets these guidelines:
Unit tests are in the tests
folder. If you add new functionality to the package, please add a unit test for it. You can either add the test to an existing test file or create a new one. For example, if you add a new function to samgeo/samgeo.py
, you can add the unit test to tests/test_samgeo.py
. If you add a new module to samgeo/<MODULE-NAME>
, you can create a new test file in tests/test_<MODULE-NAME>
. Please refer to tests/test_samgeo.py
for examples. For more information about unit testing, please refer to this tutorial - Getting Started With Testing in Python.
To run the unit tests, navigate to the root directory of the package and run the following command:
python -m unittest discover tests/\n
"},{"location":"contributing/#add-new-dependencies","title":"Add new dependencies","text":"If your PR involves adding new dependencies, please make sure that the new dependencies are available on both PyPI and conda-forge. Search here to see if the package is available on conda-forge. If the package is not available on conda-forge, it can't be added as a required dependency in requirements.txt
. Instead, it should be added as an optional dependency in requirements_dev.txt
.
If the package is available on both PyPI and conda-forge but is challenging to install on some operating systems, we recommend adding it as an optional dependency in requirements_dev.txt
rather than a required dependency in requirements.txt
.
The dependencies required for building the documentation should be added to requirements_docs.txt
. In most cases, contributors do not need to add new dependencies to requirements_docs.txt
unless the documentation fails to build due to missing dependencies.
Segment remote sensing imagery with HQ-SAM (High Quality Segment Anything Model). See https://github.com/SysCV/sam-hq
"},{"location":"hq_sam/#samgeo.hq_sam.SamGeo","title":" SamGeo
","text":"The main class for segmenting geospatial data with the Segment Anything Model (SAM). See https://github.com/facebookresearch/segment-anything for details.
Source code insamgeo/hq_sam.py
class SamGeo:\n\"\"\"The main class for segmenting geospatial data with the Segment Anything Model (SAM). See\n https://github.com/facebookresearch/segment-anything for details.\n \"\"\"\n\n def __init__(\n self,\n model_type=\"vit_h\",\n automatic=True,\n device=None,\n checkpoint_dir=None,\n hq=False,\n sam_kwargs=None,\n **kwargs,\n ):\n\"\"\"Initialize the class.\n\n Args:\n model_type (str, optional): The model type. It can be one of the following: vit_h, vit_l, vit_b.\n Defaults to 'vit_h'. See https://bit.ly/3VrpxUh for more details.\n automatic (bool, optional): Whether to use the automatic mask generator or input prompts. Defaults to True.\n The automatic mask generator will segment the entire image, while the input prompts will segment selected objects.\n device (str, optional): The device to use. It can be one of the following: cpu, cuda.\n Defaults to None, which will use cuda if available.\n hq (bool, optional): Whether to use the HQ-SAM model. Defaults to False.\n checkpoint_dir (str, optional): The path to the model checkpoint. It can be one of the following:\n sam_vit_h_4b8939.pth, sam_vit_l_0b3195.pth, sam_vit_b_01ec64.pth.\n Defaults to None. See https://bit.ly/3VrpxUh for more details.\n sam_kwargs (dict, optional): Optional arguments for fine-tuning the SAM model. Defaults to None.\n The available arguments with default values are listed below. 
See https://bit.ly/410RV0v for more details.\n\n points_per_side: Optional[int] = 32,\n points_per_batch: int = 64,\n pred_iou_thresh: float = 0.88,\n stability_score_thresh: float = 0.95,\n stability_score_offset: float = 1.0,\n box_nms_thresh: float = 0.7,\n crop_n_layers: int = 0,\n crop_nms_thresh: float = 0.7,\n crop_overlap_ratio: float = 512 / 1500,\n crop_n_points_downscale_factor: int = 1,\n point_grids: Optional[List[np.ndarray]] = None,\n min_mask_region_area: int = 0,\n output_mode: str = \"binary_mask\",\n\n \"\"\"\n\n hq = True # Using HQ-SAM\n if \"checkpoint\" in kwargs:\n checkpoint = kwargs[\"checkpoint\"]\n if not os.path.exists(checkpoint):\n checkpoint = download_checkpoint(model_type, checkpoint_dir, hq)\n kwargs.pop(\"checkpoint\")\n else:\n checkpoint = download_checkpoint(model_type, checkpoint_dir, hq)\n\n # Use cuda if available\n if device is None:\n device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n if device == \"cuda\":\n torch.cuda.empty_cache()\n\n self.checkpoint = checkpoint\n self.model_type = model_type\n self.device = device\n self.sam_kwargs = sam_kwargs # Optional arguments for fine-tuning the SAM model\n self.source = None # Store the input image path\n self.image = None # Store the input image as a numpy array\n # Store the masks as a list of dictionaries. 
Each mask is a dictionary\n # containing segmentation, area, bbox, predicted_iou, point_coords, stability_score, and crop_box\n self.masks = None\n self.objects = None # Store the mask objects as a numpy array\n # Store the annotations (objects with random color) as a numpy array.\n self.annotations = None\n\n # Store the predicted masks, iou_predictions, and low_res_masks\n self.prediction = None\n self.scores = None\n self.logits = None\n\n # Build the SAM model\n self.sam = sam_model_registry[self.model_type](checkpoint=self.checkpoint)\n self.sam.to(device=self.device)\n # Use optional arguments for fine-tuning the SAM model\n sam_kwargs = self.sam_kwargs if self.sam_kwargs is not None else {}\n\n if automatic:\n # Segment the entire image using the automatic mask generator\n self.mask_generator = SamAutomaticMaskGenerator(self.sam, **sam_kwargs)\n else:\n # Segment selected objects using input prompts\n self.predictor = SamPredictor(self.sam, **sam_kwargs)\n\n def __call__(\n self,\n image,\n foreground=True,\n erosion_kernel=(3, 3),\n mask_multiplier=255,\n **kwargs,\n ):\n\"\"\"Generate masks for the input tile. This function originates from the segment-anything-eo repository.\n See https://bit.ly/41pwiHw\n\n Args:\n image (np.ndarray): The input image as a numpy array.\n foreground (bool, optional): Whether to generate the foreground mask. Defaults to True.\n erosion_kernel (tuple, optional): The erosion kernel for filtering object masks and extract borders. Defaults to (3, 3).\n mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1].\n You can use this parameter to scale the mask to a larger range, for example [0, 255]. 
Defaults to 255.\n \"\"\"\n h, w, _ = image.shape\n\n masks = self.mask_generator.generate(image)\n\n if foreground: # Extract foreground objects only\n resulting_mask = np.zeros((h, w), dtype=np.uint8)\n else:\n resulting_mask = np.ones((h, w), dtype=np.uint8)\n resulting_borders = np.zeros((h, w), dtype=np.uint8)\n\n for m in masks:\n mask = (m[\"segmentation\"] > 0).astype(np.uint8)\n resulting_mask += mask\n\n # Apply erosion to the mask\n if erosion_kernel is not None:\n mask_erode = cv2.erode(mask, erosion_kernel, iterations=1)\n mask_erode = (mask_erode > 0).astype(np.uint8)\n edge_mask = mask - mask_erode\n resulting_borders += edge_mask\n\n resulting_mask = (resulting_mask > 0).astype(np.uint8)\n resulting_borders = (resulting_borders > 0).astype(np.uint8)\n resulting_mask_with_borders = resulting_mask - resulting_borders\n return resulting_mask_with_borders * mask_multiplier\n\n def generate(\n self,\n source,\n output=None,\n foreground=True,\n batch=False,\n erosion_kernel=None,\n mask_multiplier=255,\n unique=True,\n **kwargs,\n ):\n\"\"\"Generate masks for the input image.\n\n Args:\n source (str | np.ndarray): The path to the input image or the input image as a numpy array.\n output (str, optional): The path to the output image. Defaults to None.\n foreground (bool, optional): Whether to generate the foreground mask. Defaults to True.\n batch (bool, optional): Whether to generate masks for a batch of image tiles. Defaults to False.\n erosion_kernel (tuple, optional): The erosion kernel for filtering object masks and extract borders.\n Such as (3, 3) or (5, 5). Set to None to disable it. Defaults to None.\n mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1].\n You can use this parameter to scale the mask to a larger range, for example [0, 255]. Defaults to 255.\n The parameter is ignored if unique is True.\n unique (bool, optional): Whether to assign a unique value to each object. 
Defaults to True.\n The unique value increases from 1 to the number of objects. The larger the number, the larger the object area.\n\n \"\"\"\n\n if isinstance(source, str):\n if source.startswith(\"http\"):\n source = download_file(source)\n\n if not os.path.exists(source):\n raise ValueError(f\"Input path {source} does not exist.\")\n\n if batch: # Subdivide the image into tiles and segment each tile\n self.batch = True\n self.source = source\n self.masks = output\n return tiff_to_tiff(\n source,\n output,\n self,\n foreground=foreground,\n erosion_kernel=erosion_kernel,\n mask_multiplier=mask_multiplier,\n **kwargs,\n )\n\n image = cv2.imread(source)\n elif isinstance(source, np.ndarray):\n image = source\n source = None\n else:\n raise ValueError(\"Input source must be either a path or a numpy array.\")\n\n self.source = source # Store the input image path\n self.image = image # Store the input image as a numpy array\n mask_generator = self.mask_generator # The automatic mask generator\n masks = mask_generator.generate(image) # Segment the input image\n self.masks = masks # Store the masks as a list of dictionaries\n self.batch = False\n\n if output is not None:\n # Save the masks to the output path. The output is either a binary mask or a mask of objects with unique values.\n self.save_masks(\n output, foreground, unique, erosion_kernel, mask_multiplier, **kwargs\n )\n\n def save_masks(\n self,\n output=None,\n foreground=True,\n unique=True,\n erosion_kernel=None,\n mask_multiplier=255,\n **kwargs,\n ):\n\"\"\"Save the masks to the output path. The output is either a binary mask or a mask of objects with unique values.\n\n Args:\n output (str, optional): The path to the output image. Defaults to None, saving the masks to SamGeo.objects.\n foreground (bool, optional): Whether to generate the foreground mask. Defaults to True.\n unique (bool, optional): Whether to assign a unique value to each object. 
Defaults to True.\n erosion_kernel (tuple, optional): The erosion kernel for filtering object masks and extract borders.\n Such as (3, 3) or (5, 5). Set to None to disable it. Defaults to None.\n mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1].\n You can use this parameter to scale the mask to a larger range, for example [0, 255]. Defaults to 255.\n\n \"\"\"\n\n if self.masks is None:\n raise ValueError(\"No masks found. Please run generate() first.\")\n\n h, w, _ = self.image.shape\n masks = self.masks\n\n # Set output image data type based on the number of objects\n if len(masks) < 255:\n dtype = np.uint8\n elif len(masks) < 65535:\n dtype = np.uint16\n else:\n dtype = np.uint32\n\n # Generate a mask of objects with unique values\n if unique:\n # Sort the masks by area in ascending order\n sorted_masks = sorted(masks, key=(lambda x: x[\"area\"]), reverse=False)\n\n # Create an output image with the same size as the input image\n objects = np.zeros(\n (\n sorted_masks[0][\"segmentation\"].shape[0],\n sorted_masks[0][\"segmentation\"].shape[1],\n )\n )\n # Assign a unique value to each object\n for index, ann in enumerate(sorted_masks):\n m = ann[\"segmentation\"]\n objects[m] = index + 1\n\n # Generate a binary mask\n else:\n if foreground: # Extract foreground objects only\n resulting_mask = np.zeros((h, w), dtype=dtype)\n else:\n resulting_mask = np.ones((h, w), dtype=dtype)\n resulting_borders = np.zeros((h, w), dtype=dtype)\n\n for m in masks:\n mask = (m[\"segmentation\"] > 0).astype(dtype)\n resulting_mask += mask\n\n # Apply erosion to the mask\n if erosion_kernel is not None:\n mask_erode = cv2.erode(mask, erosion_kernel, iterations=1)\n mask_erode = (mask_erode > 0).astype(dtype)\n edge_mask = mask - mask_erode\n resulting_borders += edge_mask\n\n resulting_mask = (resulting_mask > 0).astype(dtype)\n resulting_borders = (resulting_borders > 0).astype(dtype)\n objects = resulting_mask - 
resulting_borders\n objects = objects * mask_multiplier\n\n objects = objects.astype(dtype)\n self.objects = objects\n\n if output is not None: # Save the output image\n array_to_image(self.objects, output, self.source, **kwargs)\n\n def show_masks(\n self, figsize=(12, 10), cmap=\"binary_r\", axis=\"off\", foreground=True, **kwargs\n ):\n\"\"\"Show the binary mask or the mask of objects with unique values.\n\n Args:\n figsize (tuple, optional): The figure size. Defaults to (12, 10).\n cmap (str, optional): The colormap. Defaults to \"binary_r\".\n axis (str, optional): Whether to show the axis. Defaults to \"off\".\n foreground (bool, optional): Whether to show the foreground mask only. Defaults to True.\n **kwargs: Other arguments for save_masks().\n \"\"\"\n\n import matplotlib.pyplot as plt\n\n if self.batch:\n self.objects = cv2.imread(self.masks)\n else:\n if self.objects is None:\n self.save_masks(foreground=foreground, **kwargs)\n\n plt.figure(figsize=figsize)\n plt.imshow(self.objects, cmap=cmap)\n plt.axis(axis)\n plt.show()\n\n def show_anns(\n self,\n figsize=(12, 10),\n axis=\"off\",\n alpha=0.35,\n output=None,\n blend=True,\n **kwargs,\n ):\n\"\"\"Show the annotations (objects with random color) on the input image.\n\n Args:\n figsize (tuple, optional): The figure size. Defaults to (12, 10).\n axis (str, optional): Whether to show the axis. Defaults to \"off\".\n alpha (float, optional): The alpha value for the annotations. Defaults to 0.35.\n output (str, optional): The path to the output image. Defaults to None.\n blend (bool, optional): Whether to show the input image. 
Defaults to True.\n \"\"\"\n\n import matplotlib.pyplot as plt\n\n anns = self.masks\n\n if self.image is None:\n print(\"Please run generate() first.\")\n return\n\n if anns is None or len(anns) == 0:\n return\n\n plt.figure(figsize=figsize)\n plt.imshow(self.image)\n\n sorted_anns = sorted(anns, key=(lambda x: x[\"area\"]), reverse=True)\n\n ax = plt.gca()\n ax.set_autoscale_on(False)\n\n img = np.ones(\n (\n sorted_anns[0][\"segmentation\"].shape[0],\n sorted_anns[0][\"segmentation\"].shape[1],\n 4,\n )\n )\n img[:, :, 3] = 0\n for ann in sorted_anns:\n m = ann[\"segmentation\"]\n color_mask = np.concatenate([np.random.random(3), [alpha]])\n img[m] = color_mask\n ax.imshow(img)\n\n if \"dpi\" not in kwargs:\n kwargs[\"dpi\"] = 100\n\n if \"bbox_inches\" not in kwargs:\n kwargs[\"bbox_inches\"] = \"tight\"\n\n plt.axis(axis)\n\n self.annotations = (img[:, :, 0:3] * 255).astype(np.uint8)\n\n if output is not None:\n if blend:\n array = blend_images(\n self.annotations, self.image, alpha=alpha, show=False\n )\n else:\n array = self.annotations\n array_to_image(array, output, self.source)\n\n def set_image(self, image, image_format=\"RGB\"):\n\"\"\"Set the input image as a numpy array.\n\n Args:\n image (np.ndarray): The input image as a numpy array.\n image_format (str, optional): The image format, can be RGB or BGR. 
Defaults to \"RGB\".\n \"\"\"\n if isinstance(image, str):\n if image.startswith(\"http\"):\n image = download_file(image)\n\n if not os.path.exists(image):\n raise ValueError(f\"Input path {image} does not exist.\")\n\n self.source = image\n\n image = cv2.imread(image)\n image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n self.image = image\n elif isinstance(image, np.ndarray):\n pass\n else:\n raise ValueError(\"Input image must be either a path or a numpy array.\")\n\n self.predictor.set_image(image, image_format=image_format)\n\n def save_prediction(\n self,\n output,\n index=None,\n mask_multiplier=255,\n dtype=np.float32,\n vector=None,\n simplify_tolerance=None,\n **kwargs,\n ):\n\"\"\"Save the predicted mask to the output path.\n\n Args:\n output (str): The path to the output image.\n index (int, optional): The index of the mask to save. Defaults to None,\n which will save the mask with the highest score.\n mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1].\n vector (str, optional): The path to the output vector file. Defaults to None.\n dtype (np.dtype, optional): The data type of the output image. Defaults to np.float32.\n simplify_tolerance (float, optional): The maximum allowed geometry displacement.\n The higher this value, the smaller the number of vertices in the resulting geometry.\n\n \"\"\"\n if self.scores is None:\n raise ValueError(\"No predictions found. 
Please run predict() first.\")\n\n if index is None:\n index = self.scores.argmax(axis=0)\n\n array = self.masks[index] * mask_multiplier\n self.prediction = array\n array_to_image(array, output, self.source, dtype=dtype, **kwargs)\n\n if vector is not None:\n raster_to_vector(output, vector, simplify_tolerance=simplify_tolerance)\n\n def predict(\n self,\n point_coords=None,\n point_labels=None,\n boxes=None,\n point_crs=None,\n mask_input=None,\n multimask_output=True,\n return_logits=False,\n output=None,\n index=None,\n mask_multiplier=255,\n dtype=\"float32\",\n return_results=False,\n **kwargs,\n ):\n\"\"\"Predict masks for the given input prompts, using the currently set image.\n\n Args:\n point_coords (str | dict | list | np.ndarray, optional): A Nx2 array of point prompts to the\n model. Each point is in (X,Y) in pixels. It can be a path to a vector file, a GeoJSON\n dictionary, a list of coordinates [lon, lat], or a numpy array. Defaults to None.\n point_labels (list | int | np.ndarray, optional): A length N array of labels for the\n point prompts. 1 indicates a foreground point and 0 indicates a background point.\n point_crs (str, optional): The coordinate reference system (CRS) of the point prompts.\n boxes (list | np.ndarray, optional): A length 4 array given a box prompt to the\n model, in XYXY format.\n mask_input (np.ndarray, optional): A low resolution mask input to the model, typically\n coming from a previous prediction iteration. Has form 1xHxW, where for SAM, H=W=256.\n multimask_output (bool, optional): If true, the model will return three masks.\n For ambiguous input prompts (such as a single click), this will often\n produce better masks than a single prediction. If only a single\n mask is needed, the model's predicted quality score can be used\n to select the best mask. 
For non-ambiguous prompts, such as multiple\n input prompts, multimask_output=False can give better results.\n return_logits (bool, optional): If true, returns un-thresholded masks logits\n instead of a binary mask.\n output (str, optional): The path to the output image. Defaults to None.\n index (index, optional): The index of the mask to save. Defaults to None,\n which will save the mask with the highest score.\n mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1].\n dtype (np.dtype, optional): The data type of the output image. Defaults to np.float32.\n return_results (bool, optional): Whether to return the predicted masks, scores, and logits. Defaults to False.\n\n \"\"\"\n\n if isinstance(boxes, str):\n gdf = gpd.read_file(boxes)\n if gdf.crs is not None:\n gdf = gdf.to_crs(\"epsg:4326\")\n boxes = gdf.geometry.bounds.values.tolist()\n elif isinstance(boxes, dict):\n import json\n\n geojson = json.dumps(boxes)\n gdf = gpd.read_file(geojson, driver=\"GeoJSON\")\n boxes = gdf.geometry.bounds.values.tolist()\n\n if isinstance(point_coords, str):\n point_coords = vector_to_geojson(point_coords)\n\n if isinstance(point_coords, dict):\n point_coords = geojson_to_coords(point_coords)\n\n if hasattr(self, \"point_coords\"):\n point_coords = self.point_coords\n\n if hasattr(self, \"point_labels\"):\n point_labels = self.point_labels\n\n if (point_crs is not None) and (point_coords is not None):\n point_coords = coords_to_xy(self.source, point_coords, point_crs)\n\n if isinstance(point_coords, list):\n point_coords = np.array(point_coords)\n\n if point_coords is not None:\n if point_labels is None:\n point_labels = [1] * len(point_coords)\n elif isinstance(point_labels, int):\n point_labels = [point_labels] * len(point_coords)\n\n if isinstance(point_labels, list):\n if len(point_labels) != len(point_coords):\n if len(point_labels) == 1:\n point_labels = point_labels * len(point_coords)\n else:\n raise 
ValueError(\n \"The length of point_labels must be equal to the length of point_coords.\"\n )\n point_labels = np.array(point_labels)\n\n predictor = self.predictor\n\n input_boxes = None\n if isinstance(boxes, list) and (point_crs is not None):\n coords = bbox_to_xy(self.source, boxes, point_crs)\n input_boxes = np.array(coords)\n if isinstance(coords[0], int):\n input_boxes = input_boxes[None, :]\n else:\n input_boxes = torch.tensor(input_boxes, device=self.device)\n input_boxes = predictor.transform.apply_boxes_torch(\n input_boxes, self.image.shape[:2]\n )\n elif isinstance(boxes, list) and (point_crs is None):\n input_boxes = np.array(boxes)\n if isinstance(boxes[0], int):\n input_boxes = input_boxes[None, :]\n\n self.boxes = input_boxes\n\n if boxes is None or (not isinstance(boxes[0], list)):\n masks, scores, logits = predictor.predict(\n point_coords,\n point_labels,\n input_boxes,\n mask_input,\n multimask_output,\n return_logits,\n )\n else:\n masks, scores, logits = predictor.predict_torch(\n point_coords=point_coords,\n point_labels=point_coords,\n boxes=input_boxes,\n multimask_output=True,\n )\n\n self.masks = masks\n self.scores = scores\n self.logits = logits\n\n if output is not None:\n if boxes is None or (not isinstance(boxes[0], list)):\n self.save_prediction(output, index, mask_multiplier, dtype, **kwargs)\n else:\n self.tensor_to_numpy(\n index, output, mask_multiplier, dtype, save_args=kwargs\n )\n\n if return_results:\n return masks, scores, logits\n\n def tensor_to_numpy(\n self, index=None, output=None, mask_multiplier=255, dtype=\"uint8\", save_args={}\n ):\n\"\"\"Convert the predicted masks from tensors to numpy arrays.\n\n Args:\n index (index, optional): The index of the mask to save. Defaults to None,\n which will save the mask with the highest score.\n output (str, optional): The path to the output image. 
Defaults to None.\n mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1].\n dtype (np.dtype, optional): The data type of the output image. Defaults to np.uint8.\n save_args (dict, optional): Optional arguments for saving the output image. Defaults to {}.\n\n Returns:\n np.ndarray: The predicted mask as a numpy array.\n \"\"\"\n\n boxes = self.boxes\n masks = self.masks\n\n image_pil = self.image\n image_np = np.array(image_pil)\n\n if index is None:\n index = 1\n\n masks = masks[:, index, :, :]\n masks = masks.squeeze(1)\n\n if boxes is None or (len(boxes) == 0): # No \"object\" instances found\n print(\"No objects found in the image.\")\n return\n else:\n # Create an empty image to store the mask overlays\n mask_overlay = np.zeros_like(\n image_np[..., 0], dtype=dtype\n ) # Adjusted for single channel\n\n for i, (box, mask) in enumerate(zip(boxes, masks)):\n # Convert tensor to numpy array if necessary and ensure it contains integers\n if isinstance(mask, torch.Tensor):\n mask = (\n mask.cpu().numpy().astype(dtype)\n ) # If mask is on GPU, use .cpu() before .numpy()\n mask_overlay += ((mask > 0) * (i + 1)).astype(\n dtype\n ) # Assign a unique value for each mask\n\n # Normalize mask_overlay to be in [0, 255]\n mask_overlay = (\n mask_overlay > 0\n ) * mask_multiplier # Binary mask in [0, 255]\n\n if output is not None:\n array_to_image(mask_overlay, output, self.source, dtype=dtype, **save_args)\n else:\n return mask_overlay\n\n def show_map(self, basemap=\"SATELLITE\", repeat_mode=True, out_dir=None, **kwargs):\n\"\"\"Show the interactive map.\n\n Args:\n basemap (str, optional): The basemap. It can be one of the following: SATELLITE, ROADMAP, TERRAIN, HYBRID.\n repeat_mode (bool, optional): Whether to use the repeat mode for draw control. Defaults to True.\n out_dir (str, optional): The path to the output directory. 
Defaults to None.\n\n Returns:\n leafmap.Map: The map object.\n \"\"\"\n return sam_map_gui(\n self, basemap=basemap, repeat_mode=repeat_mode, out_dir=out_dir, **kwargs\n )\n\n def show_canvas(self, fg_color=(0, 255, 0), bg_color=(0, 0, 255), radius=5):\n\"\"\"Show a canvas to collect foreground and background points.\n\n Args:\n image (str | np.ndarray): The input image.\n fg_color (tuple, optional): The color for the foreground points. Defaults to (0, 255, 0).\n bg_color (tuple, optional): The color for the background points. Defaults to (0, 0, 255).\n radius (int, optional): The radius of the points. Defaults to 5.\n\n Returns:\n tuple: A tuple of two lists of foreground and background points.\n \"\"\"\n\n if self.image is None:\n raise ValueError(\"Please run set_image() first.\")\n\n image = self.image\n fg_points, bg_points = show_canvas(image, fg_color, bg_color, radius)\n self.fg_points = fg_points\n self.bg_points = bg_points\n point_coords = fg_points + bg_points\n point_labels = [1] * len(fg_points) + [0] * len(bg_points)\n self.point_coords = point_coords\n self.point_labels = point_labels\n\n def clear_cuda_cache(self):\n\"\"\"Clear the CUDA cache.\"\"\"\n if torch.cuda.is_available():\n torch.cuda.empty_cache()\n\n def image_to_image(self, image, **kwargs):\n return image_to_image(image, self, **kwargs)\n\n def download_tms_as_tiff(self, source, pt1, pt2, zoom, dist):\n image = draw_tile(source, pt1[0], pt1[1], pt2[0], pt2[1], zoom, dist)\n return image\n\n def raster_to_vector(self, image, output, simplify_tolerance=None, **kwargs):\n\"\"\"Save the result to a vector file.\n\n Args:\n image (str): The path to the image file.\n output (str): The path to the vector file.\n simplify_tolerance (float, optional): The maximum allowed geometry displacement.\n The higher this value, the smaller the number of vertices in the resulting geometry.\n \"\"\"\n\n raster_to_vector(image, output, simplify_tolerance=simplify_tolerance, **kwargs)\n\n def 
tiff_to_vector(self, tiff_path, output, simplify_tolerance=None, **kwargs):\n\"\"\"Convert a tiff file to a gpkg file.\n\n Args:\n tiff_path (str): The path to the tiff file.\n output (str): The path to the vector file.\n simplify_tolerance (float, optional): The maximum allowed geometry displacement.\n The higher this value, the smaller the number of vertices in the resulting geometry.\n \"\"\"\n\n raster_to_vector(\n tiff_path, output, simplify_tolerance=simplify_tolerance, **kwargs\n )\n\n def tiff_to_gpkg(self, tiff_path, output, simplify_tolerance=None, **kwargs):\n\"\"\"Convert a tiff file to a gpkg file.\n\n Args:\n tiff_path (str): The path to the tiff file.\n output (str): The path to the gpkg file.\n simplify_tolerance (float, optional): The maximum allowed geometry displacement.\n The higher this value, the smaller the number of vertices in the resulting geometry.\n \"\"\"\n\n raster_to_gpkg(\n tiff_path, output, simplify_tolerance=simplify_tolerance, **kwargs\n )\n\n def tiff_to_shp(self, tiff_path, output, simplify_tolerance=None, **kwargs):\n\"\"\"Convert a tiff file to a shapefile.\n\n Args:\n tiff_path (str): The path to the tiff file.\n output (str): The path to the shapefile.\n simplify_tolerance (float, optional): The maximum allowed geometry displacement.\n The higher this value, the smaller the number of vertices in the resulting geometry.\n \"\"\"\n\n raster_to_shp(\n tiff_path, output, simplify_tolerance=simplify_tolerance, **kwargs\n )\n\n def tiff_to_geojson(self, tiff_path, output, simplify_tolerance=None, **kwargs):\n\"\"\"Convert a tiff file to a GeoJSON file.\n\n Args:\n tiff_path (str): The path to the tiff file.\n output (str): The path to the GeoJSON file.\n simplify_tolerance (float, optional): The maximum allowed geometry displacement.\n The higher this value, the smaller the number of vertices in the resulting geometry.\n \"\"\"\n\n raster_to_geojson(\n tiff_path, output, simplify_tolerance=simplify_tolerance, **kwargs\n )\n
"},{"location":"hq_sam/#samgeo.hq_sam.SamGeo.__call__","title":"__call__(self, image, foreground=True, erosion_kernel=(3, 3), mask_multiplier=255, **kwargs)
special
","text":"Generate masks for the input tile. This function originates from the segment-anything-eo repository. See https://bit.ly/41pwiHw
Parameters:
Name Type Description Default image
np.ndarray
The input image as a numpy array.
required foreground
bool
Whether to generate the foreground mask. Defaults to True.
True
erosion_kernel
tuple
The erosion kernel for filtering object masks and extract borders. Defaults to (3, 3).
(3, 3)
mask_multiplier
int
The mask multiplier for the output mask, which is usually a binary mask [0, 1]. You can use this parameter to scale the mask to a larger range, for example [0, 255]. Defaults to 255.
255
Source code in samgeo/hq_sam.py
def __call__(\n self,\n image,\n foreground=True,\n erosion_kernel=(3, 3),\n mask_multiplier=255,\n **kwargs,\n):\n\"\"\"Generate masks for the input tile. This function originates from the segment-anything-eo repository.\n See https://bit.ly/41pwiHw\n\n Args:\n image (np.ndarray): The input image as a numpy array.\n foreground (bool, optional): Whether to generate the foreground mask. Defaults to True.\n erosion_kernel (tuple, optional): The erosion kernel for filtering object masks and extract borders. Defaults to (3, 3).\n mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1].\n You can use this parameter to scale the mask to a larger range, for example [0, 255]. Defaults to 255.\n \"\"\"\n h, w, _ = image.shape\n\n masks = self.mask_generator.generate(image)\n\n if foreground: # Extract foreground objects only\n resulting_mask = np.zeros((h, w), dtype=np.uint8)\n else:\n resulting_mask = np.ones((h, w), dtype=np.uint8)\n resulting_borders = np.zeros((h, w), dtype=np.uint8)\n\n for m in masks:\n mask = (m[\"segmentation\"] > 0).astype(np.uint8)\n resulting_mask += mask\n\n # Apply erosion to the mask\n if erosion_kernel is not None:\n mask_erode = cv2.erode(mask, erosion_kernel, iterations=1)\n mask_erode = (mask_erode > 0).astype(np.uint8)\n edge_mask = mask - mask_erode\n resulting_borders += edge_mask\n\n resulting_mask = (resulting_mask > 0).astype(np.uint8)\n resulting_borders = (resulting_borders > 0).astype(np.uint8)\n resulting_mask_with_borders = resulting_mask - resulting_borders\n return resulting_mask_with_borders * mask_multiplier\n
"},{"location":"hq_sam/#samgeo.hq_sam.SamGeo.__init__","title":"__init__(self, model_type='vit_h', automatic=True, device=None, checkpoint_dir=None, hq=False, sam_kwargs=None, **kwargs)
special
","text":"Initialize the class.
Parameters:
Name Type Description Default model_type
str
The model type. It can be one of the following: vit_h, vit_l, vit_b. Defaults to 'vit_h'. See https://bit.ly/3VrpxUh for more details.
'vit_h'
automatic
bool
Whether to use the automatic mask generator or input prompts. Defaults to True. The automatic mask generator will segment the entire image, while the input prompts will segment selected objects.
True
device
str
The device to use. It can be one of the following: cpu, cuda. Defaults to None, which will use cuda if available.
None
hq
bool
Whether to use the HQ-SAM model. The signature default is False, but this class always overrides it to True.
False
checkpoint_dir
str
The path to the model checkpoint. It can be one of the following: sam_vit_h_4b8939.pth, sam_vit_l_0b3195.pth, sam_vit_b_01ec64.pth. Defaults to None. See https://bit.ly/3VrpxUh for more details.
None
sam_kwargs
dict
Optional arguments for fine-tuning the SAM model. Defaults to None. The available arguments with default values are listed below. See https://bit.ly/410RV0v for more details.
points_per_side: Optional[int] = 32, points_per_batch: int = 64, pred_iou_thresh: float = 0.88, stability_score_thresh: float = 0.95, stability_score_offset: float = 1.0, box_nms_thresh: float = 0.7, crop_n_layers: int = 0, crop_nms_thresh: float = 0.7, crop_overlap_ratio: float = 512 / 1500, crop_n_points_downscale_factor: int = 1, point_grids: Optional[List[np.ndarray]] = None, min_mask_region_area: int = 0, output_mode: str = \"binary_mask\",
None
Source code in samgeo/hq_sam.py
def __init__(\n self,\n model_type=\"vit_h\",\n automatic=True,\n device=None,\n checkpoint_dir=None,\n hq=False,\n sam_kwargs=None,\n **kwargs,\n):\n\"\"\"Initialize the class.\n\n Args:\n model_type (str, optional): The model type. It can be one of the following: vit_h, vit_l, vit_b.\n Defaults to 'vit_h'. See https://bit.ly/3VrpxUh for more details.\n automatic (bool, optional): Whether to use the automatic mask generator or input prompts. Defaults to True.\n The automatic mask generator will segment the entire image, while the input prompts will segment selected objects.\n device (str, optional): The device to use. It can be one of the following: cpu, cuda.\n Defaults to None, which will use cuda if available.\n hq (bool, optional): Whether to use the HQ-SAM model. Defaults to False.\n checkpoint_dir (str, optional): The path to the model checkpoint. It can be one of the following:\n sam_vit_h_4b8939.pth, sam_vit_l_0b3195.pth, sam_vit_b_01ec64.pth.\n Defaults to None. See https://bit.ly/3VrpxUh for more details.\n sam_kwargs (dict, optional): Optional arguments for fine-tuning the SAM model. Defaults to None.\n The available arguments with default values are listed below. 
See https://bit.ly/410RV0v for more details.\n\n points_per_side: Optional[int] = 32,\n points_per_batch: int = 64,\n pred_iou_thresh: float = 0.88,\n stability_score_thresh: float = 0.95,\n stability_score_offset: float = 1.0,\n box_nms_thresh: float = 0.7,\n crop_n_layers: int = 0,\n crop_nms_thresh: float = 0.7,\n crop_overlap_ratio: float = 512 / 1500,\n crop_n_points_downscale_factor: int = 1,\n point_grids: Optional[List[np.ndarray]] = None,\n min_mask_region_area: int = 0,\n output_mode: str = \"binary_mask\",\n\n \"\"\"\n\n hq = True # Using HQ-SAM\n if \"checkpoint\" in kwargs:\n checkpoint = kwargs[\"checkpoint\"]\n if not os.path.exists(checkpoint):\n checkpoint = download_checkpoint(model_type, checkpoint_dir, hq)\n kwargs.pop(\"checkpoint\")\n else:\n checkpoint = download_checkpoint(model_type, checkpoint_dir, hq)\n\n # Use cuda if available\n if device is None:\n device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n if device == \"cuda\":\n torch.cuda.empty_cache()\n\n self.checkpoint = checkpoint\n self.model_type = model_type\n self.device = device\n self.sam_kwargs = sam_kwargs # Optional arguments for fine-tuning the SAM model\n self.source = None # Store the input image path\n self.image = None # Store the input image as a numpy array\n # Store the masks as a list of dictionaries. 
Each mask is a dictionary\n # containing segmentation, area, bbox, predicted_iou, point_coords, stability_score, and crop_box\n self.masks = None\n self.objects = None # Store the mask objects as a numpy array\n # Store the annotations (objects with random color) as a numpy array.\n self.annotations = None\n\n # Store the predicted masks, iou_predictions, and low_res_masks\n self.prediction = None\n self.scores = None\n self.logits = None\n\n # Build the SAM model\n self.sam = sam_model_registry[self.model_type](checkpoint=self.checkpoint)\n self.sam.to(device=self.device)\n # Use optional arguments for fine-tuning the SAM model\n sam_kwargs = self.sam_kwargs if self.sam_kwargs is not None else {}\n\n if automatic:\n # Segment the entire image using the automatic mask generator\n self.mask_generator = SamAutomaticMaskGenerator(self.sam, **sam_kwargs)\n else:\n # Segment selected objects using input prompts\n self.predictor = SamPredictor(self.sam, **sam_kwargs)\n
"},{"location":"hq_sam/#samgeo.hq_sam.SamGeo.clear_cuda_cache","title":"clear_cuda_cache(self)
","text":"Clear the CUDA cache.
Source code in samgeo/hq_sam.py
def clear_cuda_cache(self):\n\"\"\"Clear the CUDA cache.\"\"\"\n if torch.cuda.is_available():\n torch.cuda.empty_cache()\n
"},{"location":"hq_sam/#samgeo.hq_sam.SamGeo.generate","title":"generate(self, source, output=None, foreground=True, batch=False, erosion_kernel=None, mask_multiplier=255, unique=True, **kwargs)
","text":"Generate masks for the input image.
Parameters:
Name Type Description Defaultsource
str | np.ndarray
The path to the input image or the input image as a numpy array.
requiredoutput
str
The path to the output image. Defaults to None.
None
foreground
bool
Whether to generate the foreground mask. Defaults to True.
True
batch
bool
Whether to generate masks for a batch of image tiles. Defaults to False.
False
erosion_kernel
tuple
The erosion kernel used to filter object masks and extract borders, such as (3, 3) or (5, 5). Set to None to disable it. Defaults to None.
None
mask_multiplier
int
The mask multiplier for the output mask, which is usually a binary mask [0, 1]. You can use this parameter to scale the mask to a larger range, for example [0, 255]. Defaults to 255. The parameter is ignored if unique is True.
255
unique
bool
Whether to assign a unique value to each object. Defaults to True. The unique value increases from 1 to the number of objects. The larger the number, the larger the object area.
True
Source code in samgeo/hq_sam.py
def generate(\n self,\n source,\n output=None,\n foreground=True,\n batch=False,\n erosion_kernel=None,\n mask_multiplier=255,\n unique=True,\n **kwargs,\n):\n\"\"\"Generate masks for the input image.\n\n Args:\n source (str | np.ndarray): The path to the input image or the input image as a numpy array.\n output (str, optional): The path to the output image. Defaults to None.\n foreground (bool, optional): Whether to generate the foreground mask. Defaults to True.\n batch (bool, optional): Whether to generate masks for a batch of image tiles. Defaults to False.\n erosion_kernel (tuple, optional): The erosion kernel for filtering object masks and extract borders.\n Such as (3, 3) or (5, 5). Set to None to disable it. Defaults to None.\n mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1].\n You can use this parameter to scale the mask to a larger range, for example [0, 255]. Defaults to 255.\n The parameter is ignored if unique is True.\n unique (bool, optional): Whether to assign a unique value to each object. Defaults to True.\n The unique value increases from 1 to the number of objects. 
The larger the number, the larger the object area.\n\n \"\"\"\n\n if isinstance(source, str):\n if source.startswith(\"http\"):\n source = download_file(source)\n\n if not os.path.exists(source):\n raise ValueError(f\"Input path {source} does not exist.\")\n\n if batch: # Subdivide the image into tiles and segment each tile\n self.batch = True\n self.source = source\n self.masks = output\n return tiff_to_tiff(\n source,\n output,\n self,\n foreground=foreground,\n erosion_kernel=erosion_kernel,\n mask_multiplier=mask_multiplier,\n **kwargs,\n )\n\n image = cv2.imread(source)\n elif isinstance(source, np.ndarray):\n image = source\n source = None\n else:\n raise ValueError(\"Input source must be either a path or a numpy array.\")\n\n self.source = source # Store the input image path\n self.image = image # Store the input image as a numpy array\n mask_generator = self.mask_generator # The automatic mask generator\n masks = mask_generator.generate(image) # Segment the input image\n self.masks = masks # Store the masks as a list of dictionaries\n self.batch = False\n\n if output is not None:\n # Save the masks to the output path. The output is either a binary mask or a mask of objects with unique values.\n self.save_masks(\n output, foreground, unique, erosion_kernel, mask_multiplier, **kwargs\n )\n
"},{"location":"hq_sam/#samgeo.hq_sam.SamGeo.predict","title":"predict(self, point_coords=None, point_labels=None, boxes=None, point_crs=None, mask_input=None, multimask_output=True, return_logits=False, output=None, index=None, mask_multiplier=255, dtype='float32', return_results=False, **kwargs)
","text":"Predict masks for the given input prompts, using the currently set image.
Parameters:
Name Type Description Defaultpoint_coords
str | dict | list | np.ndarray
A Nx2 array of point prompts to the model. Each point is in (X,Y) in pixels. It can be a path to a vector file, a GeoJSON dictionary, a list of coordinates [lon, lat], or a numpy array. Defaults to None.
None
point_labels
list | int | np.ndarray
A length N array of labels for the point prompts. 1 indicates a foreground point and 0 indicates a background point.
None
point_crs
str
The coordinate reference system (CRS) of the point prompts.
None
boxes
list | np.ndarray
A length 4 array given a box prompt to the model, in XYXY format.
None
mask_input
np.ndarray
A low resolution mask input to the model, typically coming from a previous prediction iteration. Has form 1xHxW, where for SAM, H=W=256.
None
multimask_output
bool
If true, the model will return three masks. For ambiguous input prompts (such as a single click), this will often produce better masks than a single prediction. If only a single mask is needed, the model's predicted quality score can be used to select the best mask. For non-ambiguous prompts, such as multiple input prompts, multimask_output=False can give better results.
True
return_logits
bool
If true, returns un-thresholded mask logits instead of a binary mask.
False
output
str
The path to the output image. Defaults to None.
None
index
int
The index of the mask to save. Defaults to None, which will save the mask with the highest score.
None
mask_multiplier
int
The mask multiplier for the output mask, which is usually a binary mask [0, 1].
255
dtype
np.dtype
The data type of the output image. Defaults to np.float32.
'float32'
return_results
bool
Whether to return the predicted masks, scores, and logits. Defaults to False.
False
Source code in samgeo/hq_sam.py
def predict(\n self,\n point_coords=None,\n point_labels=None,\n boxes=None,\n point_crs=None,\n mask_input=None,\n multimask_output=True,\n return_logits=False,\n output=None,\n index=None,\n mask_multiplier=255,\n dtype=\"float32\",\n return_results=False,\n **kwargs,\n):\n\"\"\"Predict masks for the given input prompts, using the currently set image.\n\n Args:\n point_coords (str | dict | list | np.ndarray, optional): A Nx2 array of point prompts to the\n model. Each point is in (X,Y) in pixels. It can be a path to a vector file, a GeoJSON\n dictionary, a list of coordinates [lon, lat], or a numpy array. Defaults to None.\n point_labels (list | int | np.ndarray, optional): A length N array of labels for the\n point prompts. 1 indicates a foreground point and 0 indicates a background point.\n point_crs (str, optional): The coordinate reference system (CRS) of the point prompts.\n boxes (list | np.ndarray, optional): A length 4 array given a box prompt to the\n model, in XYXY format.\n mask_input (np.ndarray, optional): A low resolution mask input to the model, typically\n coming from a previous prediction iteration. Has form 1xHxW, where for SAM, H=W=256.\n multimask_output (bool, optional): If true, the model will return three masks.\n For ambiguous input prompts (such as a single click), this will often\n produce better masks than a single prediction. If only a single\n mask is needed, the model's predicted quality score can be used\n to select the best mask. For non-ambiguous prompts, such as multiple\n input prompts, multimask_output=False can give better results.\n return_logits (bool, optional): If true, returns un-thresholded masks logits\n instead of a binary mask.\n output (str, optional): The path to the output image. Defaults to None.\n index (index, optional): The index of the mask to save. 
Defaults to None,\n which will save the mask with the highest score.\n mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1].\n dtype (np.dtype, optional): The data type of the output image. Defaults to np.float32.\n return_results (bool, optional): Whether to return the predicted masks, scores, and logits. Defaults to False.\n\n \"\"\"\n\n if isinstance(boxes, str):\n gdf = gpd.read_file(boxes)\n if gdf.crs is not None:\n gdf = gdf.to_crs(\"epsg:4326\")\n boxes = gdf.geometry.bounds.values.tolist()\n elif isinstance(boxes, dict):\n import json\n\n geojson = json.dumps(boxes)\n gdf = gpd.read_file(geojson, driver=\"GeoJSON\")\n boxes = gdf.geometry.bounds.values.tolist()\n\n if isinstance(point_coords, str):\n point_coords = vector_to_geojson(point_coords)\n\n if isinstance(point_coords, dict):\n point_coords = geojson_to_coords(point_coords)\n\n if hasattr(self, \"point_coords\"):\n point_coords = self.point_coords\n\n if hasattr(self, \"point_labels\"):\n point_labels = self.point_labels\n\n if (point_crs is not None) and (point_coords is not None):\n point_coords = coords_to_xy(self.source, point_coords, point_crs)\n\n if isinstance(point_coords, list):\n point_coords = np.array(point_coords)\n\n if point_coords is not None:\n if point_labels is None:\n point_labels = [1] * len(point_coords)\n elif isinstance(point_labels, int):\n point_labels = [point_labels] * len(point_coords)\n\n if isinstance(point_labels, list):\n if len(point_labels) != len(point_coords):\n if len(point_labels) == 1:\n point_labels = point_labels * len(point_coords)\n else:\n raise ValueError(\n \"The length of point_labels must be equal to the length of point_coords.\"\n )\n point_labels = np.array(point_labels)\n\n predictor = self.predictor\n\n input_boxes = None\n if isinstance(boxes, list) and (point_crs is not None):\n coords = bbox_to_xy(self.source, boxes, point_crs)\n input_boxes = np.array(coords)\n if 
isinstance(coords[0], int):\n input_boxes = input_boxes[None, :]\n else:\n input_boxes = torch.tensor(input_boxes, device=self.device)\n input_boxes = predictor.transform.apply_boxes_torch(\n input_boxes, self.image.shape[:2]\n )\n elif isinstance(boxes, list) and (point_crs is None):\n input_boxes = np.array(boxes)\n if isinstance(boxes[0], int):\n input_boxes = input_boxes[None, :]\n\n self.boxes = input_boxes\n\n if boxes is None or (not isinstance(boxes[0], list)):\n masks, scores, logits = predictor.predict(\n point_coords,\n point_labels,\n input_boxes,\n mask_input,\n multimask_output,\n return_logits,\n )\n else:\n masks, scores, logits = predictor.predict_torch(\n point_coords=point_coords,\n point_labels=point_coords,\n boxes=input_boxes,\n multimask_output=True,\n )\n\n self.masks = masks\n self.scores = scores\n self.logits = logits\n\n if output is not None:\n if boxes is None or (not isinstance(boxes[0], list)):\n self.save_prediction(output, index, mask_multiplier, dtype, **kwargs)\n else:\n self.tensor_to_numpy(\n index, output, mask_multiplier, dtype, save_args=kwargs\n )\n\n if return_results:\n return masks, scores, logits\n
"},{"location":"hq_sam/#samgeo.hq_sam.SamGeo.raster_to_vector","title":"raster_to_vector(self, image, output, simplify_tolerance=None, **kwargs)
","text":"Save the result to a vector file.
Parameters:
Name Type Description Defaultimage
str
The path to the image file.
requiredoutput
str
The path to the vector file.
requiredsimplify_tolerance
float
The maximum allowed geometry displacement. The higher this value, the smaller the number of vertices in the resulting geometry.
None
Source code in samgeo/hq_sam.py
def raster_to_vector(self, image, output, simplify_tolerance=None, **kwargs):\n\"\"\"Save the result to a vector file.\n\n Args:\n image (str): The path to the image file.\n output (str): The path to the vector file.\n simplify_tolerance (float, optional): The maximum allowed geometry displacement.\n The higher this value, the smaller the number of vertices in the resulting geometry.\n \"\"\"\n\n raster_to_vector(image, output, simplify_tolerance=simplify_tolerance, **kwargs)\n
"},{"location":"hq_sam/#samgeo.hq_sam.SamGeo.save_masks","title":"save_masks(self, output=None, foreground=True, unique=True, erosion_kernel=None, mask_multiplier=255, **kwargs)
","text":"Save the masks to the output path. The output is either a binary mask or a mask of objects with unique values.
Parameters:
Name Type Description Defaultoutput
str
The path to the output image. Defaults to None, saving the masks to SamGeo.objects.
None
foreground
bool
Whether to generate the foreground mask. Defaults to True.
True
unique
bool
Whether to assign a unique value to each object. Defaults to True.
True
erosion_kernel
tuple
The erosion kernel used to filter object masks and extract borders, such as (3, 3) or (5, 5). Set to None to disable it. Defaults to None.
None
mask_multiplier
int
The mask multiplier for the output mask, which is usually a binary mask [0, 1]. You can use this parameter to scale the mask to a larger range, for example [0, 255]. Defaults to 255.
255
Source code in samgeo/hq_sam.py
def save_masks(\n self,\n output=None,\n foreground=True,\n unique=True,\n erosion_kernel=None,\n mask_multiplier=255,\n **kwargs,\n):\n\"\"\"Save the masks to the output path. The output is either a binary mask or a mask of objects with unique values.\n\n Args:\n output (str, optional): The path to the output image. Defaults to None, saving the masks to SamGeo.objects.\n foreground (bool, optional): Whether to generate the foreground mask. Defaults to True.\n unique (bool, optional): Whether to assign a unique value to each object. Defaults to True.\n erosion_kernel (tuple, optional): The erosion kernel for filtering object masks and extract borders.\n Such as (3, 3) or (5, 5). Set to None to disable it. Defaults to None.\n mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1].\n You can use this parameter to scale the mask to a larger range, for example [0, 255]. Defaults to 255.\n\n \"\"\"\n\n if self.masks is None:\n raise ValueError(\"No masks found. 
Please run generate() first.\")\n\n h, w, _ = self.image.shape\n masks = self.masks\n\n # Set output image data type based on the number of objects\n if len(masks) < 255:\n dtype = np.uint8\n elif len(masks) < 65535:\n dtype = np.uint16\n else:\n dtype = np.uint32\n\n # Generate a mask of objects with unique values\n if unique:\n # Sort the masks by area in ascending order\n sorted_masks = sorted(masks, key=(lambda x: x[\"area\"]), reverse=False)\n\n # Create an output image with the same size as the input image\n objects = np.zeros(\n (\n sorted_masks[0][\"segmentation\"].shape[0],\n sorted_masks[0][\"segmentation\"].shape[1],\n )\n )\n # Assign a unique value to each object\n for index, ann in enumerate(sorted_masks):\n m = ann[\"segmentation\"]\n objects[m] = index + 1\n\n # Generate a binary mask\n else:\n if foreground: # Extract foreground objects only\n resulting_mask = np.zeros((h, w), dtype=dtype)\n else:\n resulting_mask = np.ones((h, w), dtype=dtype)\n resulting_borders = np.zeros((h, w), dtype=dtype)\n\n for m in masks:\n mask = (m[\"segmentation\"] > 0).astype(dtype)\n resulting_mask += mask\n\n # Apply erosion to the mask\n if erosion_kernel is not None:\n mask_erode = cv2.erode(mask, erosion_kernel, iterations=1)\n mask_erode = (mask_erode > 0).astype(dtype)\n edge_mask = mask - mask_erode\n resulting_borders += edge_mask\n\n resulting_mask = (resulting_mask > 0).astype(dtype)\n resulting_borders = (resulting_borders > 0).astype(dtype)\n objects = resulting_mask - resulting_borders\n objects = objects * mask_multiplier\n\n objects = objects.astype(dtype)\n self.objects = objects\n\n if output is not None: # Save the output image\n array_to_image(self.objects, output, self.source, **kwargs)\n
"},{"location":"hq_sam/#samgeo.hq_sam.SamGeo.save_prediction","title":"save_prediction(self, output, index=None, mask_multiplier=255, dtype=<class 'numpy.float32'>, vector=None, simplify_tolerance=None, **kwargs)
","text":"Save the predicted mask to the output path.
Parameters:
Name Type Description Defaultoutput
str
The path to the output image.
requiredindex
int
The index of the mask to save. Defaults to None, which will save the mask with the highest score.
None
mask_multiplier
int
The mask multiplier for the output mask, which is usually a binary mask [0, 1].
255
vector
str
The path to the output vector file. Defaults to None.
None
dtype
np.dtype
The data type of the output image. Defaults to np.float32.
<class 'numpy.float32'>
simplify_tolerance
float
The maximum allowed geometry displacement. The higher this value, the smaller the number of vertices in the resulting geometry.
None
Source code in samgeo/hq_sam.py
def save_prediction(\n self,\n output,\n index=None,\n mask_multiplier=255,\n dtype=np.float32,\n vector=None,\n simplify_tolerance=None,\n **kwargs,\n):\n\"\"\"Save the predicted mask to the output path.\n\n Args:\n output (str): The path to the output image.\n index (int, optional): The index of the mask to save. Defaults to None,\n which will save the mask with the highest score.\n mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1].\n vector (str, optional): The path to the output vector file. Defaults to None.\n dtype (np.dtype, optional): The data type of the output image. Defaults to np.float32.\n simplify_tolerance (float, optional): The maximum allowed geometry displacement.\n The higher this value, the smaller the number of vertices in the resulting geometry.\n\n \"\"\"\n if self.scores is None:\n raise ValueError(\"No predictions found. Please run predict() first.\")\n\n if index is None:\n index = self.scores.argmax(axis=0)\n\n array = self.masks[index] * mask_multiplier\n self.prediction = array\n array_to_image(array, output, self.source, dtype=dtype, **kwargs)\n\n if vector is not None:\n raster_to_vector(output, vector, simplify_tolerance=simplify_tolerance)\n
"},{"location":"hq_sam/#samgeo.hq_sam.SamGeo.set_image","title":"set_image(self, image, image_format='RGB')
","text":"Set the input image as a numpy array.
Parameters:
Name Type Description Defaultimage
np.ndarray
The input image as a numpy array.
requiredimage_format
str
The image format, can be RGB or BGR. Defaults to \"RGB\".
'RGB'
Source code in samgeo/hq_sam.py
def set_image(self, image, image_format=\"RGB\"):\n\"\"\"Set the input image as a numpy array.\n\n Args:\n image (np.ndarray): The input image as a numpy array.\n image_format (str, optional): The image format, can be RGB or BGR. Defaults to \"RGB\".\n \"\"\"\n if isinstance(image, str):\n if image.startswith(\"http\"):\n image = download_file(image)\n\n if not os.path.exists(image):\n raise ValueError(f\"Input path {image} does not exist.\")\n\n self.source = image\n\n image = cv2.imread(image)\n image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n self.image = image\n elif isinstance(image, np.ndarray):\n pass\n else:\n raise ValueError(\"Input image must be either a path or a numpy array.\")\n\n self.predictor.set_image(image, image_format=image_format)\n
"},{"location":"hq_sam/#samgeo.hq_sam.SamGeo.show_anns","title":"show_anns(self, figsize=(12, 10), axis='off', alpha=0.35, output=None, blend=True, **kwargs)
","text":"Show the annotations (objects with random color) on the input image.
Parameters:
Name Type Description Defaultfigsize
tuple
The figure size. Defaults to (12, 10).
(12, 10)
axis
str
Whether to show the axis. Defaults to \"off\".
'off'
alpha
float
The alpha value for the annotations. Defaults to 0.35.
0.35
output
str
The path to the output image. Defaults to None.
None
blend
bool
Whether to blend the annotations with the input image when saving the output image. Defaults to True.
True
Source code in samgeo/hq_sam.py
def show_anns(\n self,\n figsize=(12, 10),\n axis=\"off\",\n alpha=0.35,\n output=None,\n blend=True,\n **kwargs,\n):\n\"\"\"Show the annotations (objects with random color) on the input image.\n\n Args:\n figsize (tuple, optional): The figure size. Defaults to (12, 10).\n axis (str, optional): Whether to show the axis. Defaults to \"off\".\n alpha (float, optional): The alpha value for the annotations. Defaults to 0.35.\n output (str, optional): The path to the output image. Defaults to None.\n blend (bool, optional): Whether to show the input image. Defaults to True.\n \"\"\"\n\n import matplotlib.pyplot as plt\n\n anns = self.masks\n\n if self.image is None:\n print(\"Please run generate() first.\")\n return\n\n if anns is None or len(anns) == 0:\n return\n\n plt.figure(figsize=figsize)\n plt.imshow(self.image)\n\n sorted_anns = sorted(anns, key=(lambda x: x[\"area\"]), reverse=True)\n\n ax = plt.gca()\n ax.set_autoscale_on(False)\n\n img = np.ones(\n (\n sorted_anns[0][\"segmentation\"].shape[0],\n sorted_anns[0][\"segmentation\"].shape[1],\n 4,\n )\n )\n img[:, :, 3] = 0\n for ann in sorted_anns:\n m = ann[\"segmentation\"]\n color_mask = np.concatenate([np.random.random(3), [alpha]])\n img[m] = color_mask\n ax.imshow(img)\n\n if \"dpi\" not in kwargs:\n kwargs[\"dpi\"] = 100\n\n if \"bbox_inches\" not in kwargs:\n kwargs[\"bbox_inches\"] = \"tight\"\n\n plt.axis(axis)\n\n self.annotations = (img[:, :, 0:3] * 255).astype(np.uint8)\n\n if output is not None:\n if blend:\n array = blend_images(\n self.annotations, self.image, alpha=alpha, show=False\n )\n else:\n array = self.annotations\n array_to_image(array, output, self.source)\n
"},{"location":"hq_sam/#samgeo.hq_sam.SamGeo.show_canvas","title":"show_canvas(self, fg_color=(0, 255, 0), bg_color=(0, 0, 255), radius=5)
","text":"Show a canvas to collect foreground and background points.
Parameters:
Name Type Description Defaultimage
str | np.ndarray
The input image.
requiredfg_color
tuple
The color for the foreground points. Defaults to (0, 255, 0).
(0, 255, 0)
bg_color
tuple
The color for the background points. Defaults to (0, 0, 255).
(0, 0, 255)
radius
int
The radius of the points. Defaults to 5.
5
Returns:
Type Descriptiontuple
A tuple of two lists of foreground and background points.
Source code in samgeo/hq_sam.py
def show_canvas(self, fg_color=(0, 255, 0), bg_color=(0, 0, 255), radius=5):\n\"\"\"Show a canvas to collect foreground and background points.\n\n Args:\n image (str | np.ndarray): The input image.\n fg_color (tuple, optional): The color for the foreground points. Defaults to (0, 255, 0).\n bg_color (tuple, optional): The color for the background points. Defaults to (0, 0, 255).\n radius (int, optional): The radius of the points. Defaults to 5.\n\n Returns:\n tuple: A tuple of two lists of foreground and background points.\n \"\"\"\n\n if self.image is None:\n raise ValueError(\"Please run set_image() first.\")\n\n image = self.image\n fg_points, bg_points = show_canvas(image, fg_color, bg_color, radius)\n self.fg_points = fg_points\n self.bg_points = bg_points\n point_coords = fg_points + bg_points\n point_labels = [1] * len(fg_points) + [0] * len(bg_points)\n self.point_coords = point_coords\n self.point_labels = point_labels\n
"},{"location":"hq_sam/#samgeo.hq_sam.SamGeo.show_map","title":"show_map(self, basemap='SATELLITE', repeat_mode=True, out_dir=None, **kwargs)
","text":"Show the interactive map.
Parameters:
Name Type Description Defaultbasemap
str
The basemap. It can be one of the following: SATELLITE, ROADMAP, TERRAIN, HYBRID.
'SATELLITE'
repeat_mode
bool
Whether to use the repeat mode for draw control. Defaults to True.
True
out_dir
str
The path to the output directory. Defaults to None.
None
Returns:
Type Descriptionleafmap.Map
The map object.
Source code in samgeo/hq_sam.py
def show_map(self, basemap=\"SATELLITE\", repeat_mode=True, out_dir=None, **kwargs):\n\"\"\"Show the interactive map.\n\n Args:\n basemap (str, optional): The basemap. It can be one of the following: SATELLITE, ROADMAP, TERRAIN, HYBRID.\n repeat_mode (bool, optional): Whether to use the repeat mode for draw control. Defaults to True.\n out_dir (str, optional): The path to the output directory. Defaults to None.\n\n Returns:\n leafmap.Map: The map object.\n \"\"\"\n return sam_map_gui(\n self, basemap=basemap, repeat_mode=repeat_mode, out_dir=out_dir, **kwargs\n )\n
"},{"location":"hq_sam/#samgeo.hq_sam.SamGeo.show_masks","title":"show_masks(self, figsize=(12, 10), cmap='binary_r', axis='off', foreground=True, **kwargs)
","text":"Show the binary mask or the mask of objects with unique values.
Parameters:
Name Type Description Defaultfigsize
tuple
The figure size. Defaults to (12, 10).
(12, 10)
cmap
str
The colormap. Defaults to \"binary_r\".
'binary_r'
axis
str
Whether to show the axis. Defaults to \"off\".
'off'
foreground
bool
Whether to show the foreground mask only. Defaults to True.
True
**kwargs
Other arguments for save_masks().
{}
Source code in samgeo/hq_sam.py
def show_masks(\n self, figsize=(12, 10), cmap=\"binary_r\", axis=\"off\", foreground=True, **kwargs\n):\n\"\"\"Show the binary mask or the mask of objects with unique values.\n\n Args:\n figsize (tuple, optional): The figure size. Defaults to (12, 10).\n cmap (str, optional): The colormap. Defaults to \"binary_r\".\n axis (str, optional): Whether to show the axis. Defaults to \"off\".\n foreground (bool, optional): Whether to show the foreground mask only. Defaults to True.\n **kwargs: Other arguments for save_masks().\n \"\"\"\n\n import matplotlib.pyplot as plt\n\n if self.batch:\n self.objects = cv2.imread(self.masks)\n else:\n if self.objects is None:\n self.save_masks(foreground=foreground, **kwargs)\n\n plt.figure(figsize=figsize)\n plt.imshow(self.objects, cmap=cmap)\n plt.axis(axis)\n plt.show()\n
"},{"location":"hq_sam/#samgeo.hq_sam.SamGeo.tensor_to_numpy","title":"tensor_to_numpy(self, index=None, output=None, mask_multiplier=255, dtype='uint8', save_args={})
","text":"Convert the predicted masks from tensors to numpy arrays.
Parameters:
Name Type Description Defaultindex
int
The index of the mask to save. Defaults to None, which will save the mask with the highest score.
None
output
str
The path to the output image. Defaults to None.
None
mask_multiplier
int
The mask multiplier for the output mask, which is usually a binary mask [0, 1].
255
dtype
np.dtype
The data type of the output image. Defaults to np.uint8.
'uint8'
save_args
dict
Optional arguments for saving the output image. Defaults to {}.
{}
Returns:
Type Descriptionnp.ndarray
The predicted mask as a numpy array.
Source code in samgeo/hq_sam.py
def tensor_to_numpy(\n self, index=None, output=None, mask_multiplier=255, dtype=\"uint8\", save_args={}\n):\n\"\"\"Convert the predicted masks from tensors to numpy arrays.\n\n Args:\n index (index, optional): The index of the mask to save. Defaults to None,\n which will save the mask with the highest score.\n output (str, optional): The path to the output image. Defaults to None.\n mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1].\n dtype (np.dtype, optional): The data type of the output image. Defaults to np.uint8.\n save_args (dict, optional): Optional arguments for saving the output image. Defaults to {}.\n\n Returns:\n np.ndarray: The predicted mask as a numpy array.\n \"\"\"\n\n boxes = self.boxes\n masks = self.masks\n\n image_pil = self.image\n image_np = np.array(image_pil)\n\n if index is None:\n index = 1\n\n masks = masks[:, index, :, :]\n masks = masks.squeeze(1)\n\n if boxes is None or (len(boxes) == 0): # No \"object\" instances found\n print(\"No objects found in the image.\")\n return\n else:\n # Create an empty image to store the mask overlays\n mask_overlay = np.zeros_like(\n image_np[..., 0], dtype=dtype\n ) # Adjusted for single channel\n\n for i, (box, mask) in enumerate(zip(boxes, masks)):\n # Convert tensor to numpy array if necessary and ensure it contains integers\n if isinstance(mask, torch.Tensor):\n mask = (\n mask.cpu().numpy().astype(dtype)\n ) # If mask is on GPU, use .cpu() before .numpy()\n mask_overlay += ((mask > 0) * (i + 1)).astype(\n dtype\n ) # Assign a unique value for each mask\n\n # Normalize mask_overlay to be in [0, 255]\n mask_overlay = (\n mask_overlay > 0\n ) * mask_multiplier # Binary mask in [0, 255]\n\n if output is not None:\n array_to_image(mask_overlay, output, self.source, dtype=dtype, **save_args)\n else:\n return mask_overlay\n
"},{"location":"hq_sam/#samgeo.hq_sam.SamGeo.tiff_to_geojson","title":"tiff_to_geojson(self, tiff_path, output, simplify_tolerance=None, **kwargs)
","text":"Convert a tiff file to a GeoJSON file.
Parameters:
Name Type Description Defaulttiff_path
str
The path to the tiff file.
requiredoutput
str
The path to the GeoJSON file.
requiredsimplify_tolerance
float
The maximum allowed geometry displacement. The higher this value, the smaller the number of vertices in the resulting geometry.
None
Source code in samgeo/hq_sam.py
def tiff_to_geojson(self, tiff_path, output, simplify_tolerance=None, **kwargs):\n\"\"\"Convert a tiff file to a GeoJSON file.\n\n Args:\n tiff_path (str): The path to the tiff file.\n output (str): The path to the GeoJSON file.\n simplify_tolerance (float, optional): The maximum allowed geometry displacement.\n The higher this value, the smaller the number of vertices in the resulting geometry.\n \"\"\"\n\n raster_to_geojson(\n tiff_path, output, simplify_tolerance=simplify_tolerance, **kwargs\n )\n
"},{"location":"hq_sam/#samgeo.hq_sam.SamGeo.tiff_to_gpkg","title":"tiff_to_gpkg(self, tiff_path, output, simplify_tolerance=None, **kwargs)
","text":"Convert a tiff file to a gpkg file.
Parameters:
Name Type Description Defaulttiff_path
str
The path to the tiff file.
requiredoutput
str
The path to the gpkg file.
requiredsimplify_tolerance
float
The maximum allowed geometry displacement. The higher this value, the smaller the number of vertices in the resulting geometry.
None
Source code in samgeo/hq_sam.py
def tiff_to_gpkg(self, tiff_path, output, simplify_tolerance=None, **kwargs):\n\"\"\"Convert a tiff file to a gpkg file.\n\n Args:\n tiff_path (str): The path to the tiff file.\n output (str): The path to the gpkg file.\n simplify_tolerance (float, optional): The maximum allowed geometry displacement.\n The higher this value, the smaller the number of vertices in the resulting geometry.\n \"\"\"\n\n raster_to_gpkg(\n tiff_path, output, simplify_tolerance=simplify_tolerance, **kwargs\n )\n
"},{"location":"hq_sam/#samgeo.hq_sam.SamGeo.tiff_to_shp","title":"tiff_to_shp(self, tiff_path, output, simplify_tolerance=None, **kwargs)
","text":"Convert a tiff file to a shapefile.
Parameters:
Name Type Description Defaulttiff_path
str
The path to the tiff file.
requiredoutput
str
The path to the shapefile.
requiredsimplify_tolerance
float
The maximum allowed geometry displacement. The higher this value, the smaller the number of vertices in the resulting geometry.
None
Source code in samgeo/hq_sam.py
def tiff_to_shp(self, tiff_path, output, simplify_tolerance=None, **kwargs):\n\"\"\"Convert a tiff file to a shapefile.\n\n Args:\n tiff_path (str): The path to the tiff file.\n output (str): The path to the shapefile.\n simplify_tolerance (float, optional): The maximum allowed geometry displacement.\n The higher this value, the smaller the number of vertices in the resulting geometry.\n \"\"\"\n\n raster_to_shp(\n tiff_path, output, simplify_tolerance=simplify_tolerance, **kwargs\n )\n
"},{"location":"hq_sam/#samgeo.hq_sam.SamGeo.tiff_to_vector","title":"tiff_to_vector(self, tiff_path, output, simplify_tolerance=None, **kwargs)
","text":"Convert a tiff file to a gpkg file.
Parameters:
Name Type Description Defaulttiff_path
str
The path to the tiff file.
requiredoutput
str
The path to the vector file.
requiredsimplify_tolerance
float
The maximum allowed geometry displacement. The higher this value, the smaller the number of vertices in the resulting geometry.
None
Source code in samgeo/hq_sam.py
def tiff_to_vector(self, tiff_path, output, simplify_tolerance=None, **kwargs):\n\"\"\"Convert a tiff file to a gpkg file.\n\n Args:\n tiff_path (str): The path to the tiff file.\n output (str): The path to the vector file.\n simplify_tolerance (float, optional): The maximum allowed geometry displacement.\n The higher this value, the smaller the number of vertices in the resulting geometry.\n \"\"\"\n\n raster_to_vector(\n tiff_path, output, simplify_tolerance=simplify_tolerance, **kwargs\n )\n
"},{"location":"hq_sam/#samgeo.hq_sam.SamGeoPredictor","title":" SamGeoPredictor (SamPredictor)
","text":"Source code in samgeo/hq_sam.py
class SamGeoPredictor(SamPredictor):\n def __init__(\n self,\n sam_model,\n ):\n from segment_anything.utils.transforms import ResizeLongestSide\n\n self.model = sam_model\n self.transform = ResizeLongestSide(sam_model.image_encoder.img_size)\n\n def set_image(self, image):\n super(SamGeoPredictor, self).set_image(image)\n\n def predict(\n self,\n src_fp=None,\n geo_box=None,\n point_coords=None,\n point_labels=None,\n box=None,\n mask_input=None,\n multimask_output=True,\n return_logits=False,\n ):\n if geo_box and src_fp:\n self.crs = \"EPSG:4326\"\n dst_crs = get_crs(src_fp)\n sw = transform_coords(geo_box[0], geo_box[1], self.crs, dst_crs)\n ne = transform_coords(geo_box[2], geo_box[3], self.crs, dst_crs)\n xs = np.array([sw[0], ne[0]])\n ys = np.array([sw[1], ne[1]])\n box = get_pixel_coords(src_fp, xs, ys)\n self.geo_box = geo_box\n self.width = box[2] - box[0]\n self.height = box[3] - box[1]\n self.geo_transform = set_transform(geo_box, self.width, self.height)\n\n masks, iou_predictions, low_res_masks = super(SamGeoPredictor, self).predict(\n point_coords, point_labels, box, mask_input, multimask_output, return_logits\n )\n\n return masks, iou_predictions, low_res_masks\n\n def masks_to_geotiff(self, src_fp, dst_fp, masks):\n profile = get_profile(src_fp)\n write_raster(\n dst_fp,\n masks,\n profile,\n self.width,\n self.height,\n self.geo_transform,\n self.crs,\n )\n\n def geotiff_to_geojson(self, src_fp, dst_fp, bidx=1):\n gdf = get_features(src_fp, bidx)\n write_features(gdf, dst_fp)\n return gdf\n
"},{"location":"hq_sam/#samgeo.hq_sam.SamGeoPredictor.predict","title":"predict(self, src_fp=None, geo_box=None, point_coords=None, point_labels=None, box=None, mask_input=None, multimask_output=True, return_logits=False)
","text":"Predict masks for the given input prompts, using the currently set image.
Parameters:
Name Type Description Defaultpoint_coords
np.ndarray or None
A Nx2 array of point prompts to the model. Each point is in (X,Y) in pixels.
None
point_labels
np.ndarray or None
A length N array of labels for the point prompts. 1 indicates a foreground point and 0 indicates a background point.
None
box
np.ndarray or None
A length 4 array given a box prompt to the model, in XYXY format.
None
mask_input
np.ndarray
A low resolution mask input to the model, typically coming from a previous prediction iteration. Has form 1xHxW, where for SAM, H=W=256.
None
multimask_output
bool
If true, the model will return three masks. For ambiguous input prompts (such as a single click), this will often produce better masks than a single prediction. If only a single mask is needed, the model's predicted quality score can be used to select the best mask. For non-ambiguous prompts, such as multiple input prompts, multimask_output=False can give better results.
True
return_logits
bool
If true, returns un-thresholded masks logits instead of a binary mask.
False
Returns:
Type Description(np.ndarray)
The output masks in CxHxW format, where C is the number of masks, and (H, W) is the original image size. (np.ndarray): An array of length C containing the model's predictions for the quality of each mask. (np.ndarray): An array of shape CxHxW, where C is the number of masks and H=W=256. These low resolution logits can be passed to a subsequent iteration as mask input.
Source code in samgeo/hq_sam.py
def predict(\n self,\n src_fp=None,\n geo_box=None,\n point_coords=None,\n point_labels=None,\n box=None,\n mask_input=None,\n multimask_output=True,\n return_logits=False,\n):\n if geo_box and src_fp:\n self.crs = \"EPSG:4326\"\n dst_crs = get_crs(src_fp)\n sw = transform_coords(geo_box[0], geo_box[1], self.crs, dst_crs)\n ne = transform_coords(geo_box[2], geo_box[3], self.crs, dst_crs)\n xs = np.array([sw[0], ne[0]])\n ys = np.array([sw[1], ne[1]])\n box = get_pixel_coords(src_fp, xs, ys)\n self.geo_box = geo_box\n self.width = box[2] - box[0]\n self.height = box[3] - box[1]\n self.geo_transform = set_transform(geo_box, self.width, self.height)\n\n masks, iou_predictions, low_res_masks = super(SamGeoPredictor, self).predict(\n point_coords, point_labels, box, mask_input, multimask_output, return_logits\n )\n\n return masks, iou_predictions, low_res_masks\n
"},{"location":"hq_sam/#samgeo.hq_sam.SamGeoPredictor.set_image","title":"set_image(self, image)
","text":"Calculates the image embeddings for the provided image, allowing masks to be predicted with the 'predict' method.
Parameters:
Name Type Description Defaultimage
np.ndarray
The image for calculating masks. Expects an image in HWC uint8 format, with pixel values in [0, 255].
requiredimage_format
str
The color format of the image, in ['RGB', 'BGR'].
required Source code in samgeo/hq_sam.py
def set_image(self, image):\n super(SamGeoPredictor, self).set_image(image)\n
"},{"location":"installation/","title":"Installation","text":""},{"location":"installation/#install-from-pypi","title":"Install from PyPI","text":"segment-geospatial is available on PyPI. To install segment-geospatial, run this command in your terminal:
pip install segment-geospatial\n
"},{"location":"installation/#install-from-conda-forge","title":"Install from conda-forge","text":"segment-geospatial is also available on conda-forge. If you have Anaconda or Miniconda installed on your computer, you can install segment-geospatial using the following commands. It is recommended to create a fresh conda environment for segment-geospatial. The following commands will create a new conda environment named geo
and install segment-geospatial and its dependencies:
conda create -n geo python=3.10\nconda activate geo\nconda install -c conda-forge mamba\nmamba install -c conda-forge segment-geospatial\n
segment-geospatial has some optional dependencies that are not included in the default conda environment. To install these dependencies, run the following command:
mamba install -c conda-forge groundingdino-py\n
As of July 9th, 2023 Linux systems have also required that libgl1
be installed for segment-geospatial to work. The following command will install that dependency:
apt update; apt install -y libgl1\n
"},{"location":"installation/#install-from-github","title":"Install from GitHub","text":"To install the development version from GitHub using Git, run the following command in your terminal:
pip install git+https://github.com/opengeos/segment-geospatial\n
"},{"location":"installation/#use-docker","title":"Use docker","text":"You can also use docker to run segment-geospatial:
docker run -it -p 8888:8888 giswqs/segment-geospatial:latest\n
To enable GPU for segment-geospatial, first verify that Docker can access your GPU by running a short benchmark:
docker run --rm -it --gpus=all nvcr.io/nvidia/k8s/cuda-sample:nbody nbody -gpu -benchmark\n
The output should be similar to the following:
Run \"nbody -benchmark [-numbodies=<numBodies>]\" to measure performance.\n -fullscreen (run n-body simulation in fullscreen mode)\n -fp64 (use double precision floating point values for simulation)\n -hostmem (stores simulation data in host memory)\n -benchmark (run benchmark to measure performance)\n -numbodies=<N> (number of bodies (>= 1) to run in simulation)\n -device=<d> (where d=0,1,2.... for the CUDA device to use)\n -numdevices=<i> (where i=(number of CUDA devices > 0) to use for simulation)\n -compare (compares simulation results running once on the default GPU and once on the CPU)\n -cpu (run n-body simulation on the CPU)\n -tipsy=<file.bin> (load a tipsy model file for simulation)\n\nNOTE: The CUDA Samples are not meant for performance measurements. Results may vary when GPU Boost is enabled.\n\n> Windowed mode\n> Simulation data stored in video memory\n> Single precision floating point simulation\n> 1 Devices used for simulation\nGPU Device 0: \"Turing\" with compute capability 7.5\n\n> Compute 7.5 CUDA device: [Quadro RTX 5000]\n49152 bodies, total time for 10 iterations: 69.386 ms\n= 348.185 billion interactions per second\n= 6963.703 single-precision GFLOP/s at 20 flops per interaction\n
If you encounter the following error:
nvidia-container-cli: initialization error: load library failed: libnvidia-ml.so.1: cannot open shared object file: no such file or directory: unknown.\n
Try adding sudo
to the command:
sudo docker run --rm -it --gpus=all nvcr.io/nvidia/k8s/cuda-sample:nbody nbody -gpu -benchmark\n
Once everything is working, you can run the following command to start a Jupyter Notebook server:
docker run -it -p 8888:8888 --gpus=all giswqs/segment-geospatial:latest\n
"},{"location":"samgeo/","title":"samgeo module","text":"The source code is adapted from https://github.com/aliaksandr960/segment-anything-eo. Credit to the author Aliaksandr Hancharenka.
"},{"location":"samgeo/#samgeo.samgeo.SamGeo","title":" SamGeo
","text":"The main class for segmenting geospatial data with the Segment Anything Model (SAM). See https://github.com/facebookresearch/segment-anything for details.
Source code in samgeo/samgeo.py
class SamGeo:\n\"\"\"The main class for segmenting geospatial data with the Segment Anything Model (SAM). See\n https://github.com/facebookresearch/segment-anything for details.\n \"\"\"\n\n def __init__(\n self,\n model_type=\"vit_h\",\n automatic=True,\n device=None,\n checkpoint_dir=None,\n sam_kwargs=None,\n **kwargs,\n ):\n\"\"\"Initialize the class.\n\n Args:\n model_type (str, optional): The model type. It can be one of the following: vit_h, vit_l, vit_b.\n Defaults to 'vit_h'. See https://bit.ly/3VrpxUh for more details.\n automatic (bool, optional): Whether to use the automatic mask generator or input prompts. Defaults to True.\n The automatic mask generator will segment the entire image, while the input prompts will segment selected objects.\n device (str, optional): The device to use. It can be one of the following: cpu, cuda.\n Defaults to None, which will use cuda if available.\n checkpoint_dir (str, optional): The path to the model checkpoint. It can be one of the following:\n sam_vit_h_4b8939.pth, sam_vit_l_0b3195.pth, sam_vit_b_01ec64.pth.\n Defaults to None. See https://bit.ly/3VrpxUh for more details.\n sam_kwargs (dict, optional): Optional arguments for fine-tuning the SAM model. Defaults to None.\n The available arguments with default values are listed below. 
See https://bit.ly/410RV0v for more details.\n\n points_per_side: Optional[int] = 32,\n points_per_batch: int = 64,\n pred_iou_thresh: float = 0.88,\n stability_score_thresh: float = 0.95,\n stability_score_offset: float = 1.0,\n box_nms_thresh: float = 0.7,\n crop_n_layers: int = 0,\n crop_nms_thresh: float = 0.7,\n crop_overlap_ratio: float = 512 / 1500,\n crop_n_points_downscale_factor: int = 1,\n point_grids: Optional[List[np.ndarray]] = None,\n min_mask_region_area: int = 0,\n output_mode: str = \"binary_mask\",\n\n \"\"\"\n hq = False # Not using HQ-SAM\n\n if \"checkpoint\" in kwargs:\n checkpoint = kwargs[\"checkpoint\"]\n if not os.path.exists(checkpoint):\n checkpoint = download_checkpoint(model_type, checkpoint_dir, hq)\n kwargs.pop(\"checkpoint\")\n else:\n checkpoint = download_checkpoint(model_type, checkpoint_dir, hq)\n\n # Use cuda if available\n if device is None:\n device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n if device == \"cuda\":\n torch.cuda.empty_cache()\n\n self.checkpoint = checkpoint\n self.model_type = model_type\n self.device = device\n self.sam_kwargs = sam_kwargs # Optional arguments for fine-tuning the SAM model\n self.source = None # Store the input image path\n self.image = None # Store the input image as a numpy array\n # Store the masks as a list of dictionaries. 
Each mask is a dictionary\n # containing segmentation, area, bbox, predicted_iou, point_coords, stability_score, and crop_box\n self.masks = None\n self.objects = None # Store the mask objects as a numpy array\n # Store the annotations (objects with random color) as a numpy array.\n self.annotations = None\n\n # Store the predicted masks, iou_predictions, and low_res_masks\n self.prediction = None\n self.scores = None\n self.logits = None\n\n # Build the SAM model\n self.sam = sam_model_registry[self.model_type](checkpoint=self.checkpoint)\n self.sam.to(device=self.device)\n # Use optional arguments for fine-tuning the SAM model\n sam_kwargs = self.sam_kwargs if self.sam_kwargs is not None else {}\n\n if automatic:\n # Segment the entire image using the automatic mask generator\n self.mask_generator = SamAutomaticMaskGenerator(self.sam, **sam_kwargs)\n else:\n # Segment selected objects using input prompts\n self.predictor = SamPredictor(self.sam, **sam_kwargs)\n\n def __call__(\n self,\n image,\n foreground=True,\n erosion_kernel=(3, 3),\n mask_multiplier=255,\n **kwargs,\n ):\n\"\"\"Generate masks for the input tile. This function originates from the segment-anything-eo repository.\n See https://bit.ly/41pwiHw\n\n Args:\n image (np.ndarray): The input image as a numpy array.\n foreground (bool, optional): Whether to generate the foreground mask. Defaults to True.\n erosion_kernel (tuple, optional): The erosion kernel for filtering object masks and extract borders. Defaults to (3, 3).\n mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1].\n You can use this parameter to scale the mask to a larger range, for example [0, 255]. 
Defaults to 255.\n \"\"\"\n h, w, _ = image.shape\n\n masks = self.mask_generator.generate(image)\n\n if foreground: # Extract foreground objects only\n resulting_mask = np.zeros((h, w), dtype=np.uint8)\n else:\n resulting_mask = np.ones((h, w), dtype=np.uint8)\n resulting_borders = np.zeros((h, w), dtype=np.uint8)\n\n for m in masks:\n mask = (m[\"segmentation\"] > 0).astype(np.uint8)\n resulting_mask += mask\n\n # Apply erosion to the mask\n if erosion_kernel is not None:\n mask_erode = cv2.erode(mask, erosion_kernel, iterations=1)\n mask_erode = (mask_erode > 0).astype(np.uint8)\n edge_mask = mask - mask_erode\n resulting_borders += edge_mask\n\n resulting_mask = (resulting_mask > 0).astype(np.uint8)\n resulting_borders = (resulting_borders > 0).astype(np.uint8)\n resulting_mask_with_borders = resulting_mask - resulting_borders\n return resulting_mask_with_borders * mask_multiplier\n\n def generate(\n self,\n source,\n output=None,\n foreground=True,\n batch=False,\n erosion_kernel=None,\n mask_multiplier=255,\n unique=True,\n **kwargs,\n ):\n\"\"\"Generate masks for the input image.\n\n Args:\n source (str | np.ndarray): The path to the input image or the input image as a numpy array.\n output (str, optional): The path to the output image. Defaults to None.\n foreground (bool, optional): Whether to generate the foreground mask. Defaults to True.\n batch (bool, optional): Whether to generate masks for a batch of image tiles. Defaults to False.\n erosion_kernel (tuple, optional): The erosion kernel for filtering object masks and extract borders.\n Such as (3, 3) or (5, 5). Set to None to disable it. Defaults to None.\n mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1].\n You can use this parameter to scale the mask to a larger range, for example [0, 255]. Defaults to 255.\n The parameter is ignored if unique is True.\n unique (bool, optional): Whether to assign a unique value to each object. 
Defaults to True.\n The unique value increases from 1 to the number of objects. The larger the number, the larger the object area.\n\n \"\"\"\n\n if isinstance(source, str):\n if source.startswith(\"http\"):\n source = download_file(source)\n\n if not os.path.exists(source):\n raise ValueError(f\"Input path {source} does not exist.\")\n\n if batch: # Subdivide the image into tiles and segment each tile\n self.batch = True\n self.source = source\n self.masks = output\n return tiff_to_tiff(\n source,\n output,\n self,\n foreground=foreground,\n erosion_kernel=erosion_kernel,\n mask_multiplier=mask_multiplier,\n **kwargs,\n )\n\n image = cv2.imread(source)\n elif isinstance(source, np.ndarray):\n image = source\n source = None\n else:\n raise ValueError(\"Input source must be either a path or a numpy array.\")\n\n self.source = source # Store the input image path\n self.image = image # Store the input image as a numpy array\n mask_generator = self.mask_generator # The automatic mask generator\n masks = mask_generator.generate(image) # Segment the input image\n self.masks = masks # Store the masks as a list of dictionaries\n self.batch = False\n\n if output is not None:\n # Save the masks to the output path. The output is either a binary mask or a mask of objects with unique values.\n self.save_masks(\n output, foreground, unique, erosion_kernel, mask_multiplier, **kwargs\n )\n\n def save_masks(\n self,\n output=None,\n foreground=True,\n unique=True,\n erosion_kernel=None,\n mask_multiplier=255,\n **kwargs,\n ):\n\"\"\"Save the masks to the output path. The output is either a binary mask or a mask of objects with unique values.\n\n Args:\n output (str, optional): The path to the output image. Defaults to None, saving the masks to SamGeo.objects.\n foreground (bool, optional): Whether to generate the foreground mask. Defaults to True.\n unique (bool, optional): Whether to assign a unique value to each object. 
Defaults to True.\n erosion_kernel (tuple, optional): The erosion kernel for filtering object masks and extract borders.\n Such as (3, 3) or (5, 5). Set to None to disable it. Defaults to None.\n mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1].\n You can use this parameter to scale the mask to a larger range, for example [0, 255]. Defaults to 255.\n\n \"\"\"\n\n if self.masks is None:\n raise ValueError(\"No masks found. Please run generate() first.\")\n\n h, w, _ = self.image.shape\n masks = self.masks\n\n # Set output image data type based on the number of objects\n if len(masks) < 255:\n dtype = np.uint8\n elif len(masks) < 65535:\n dtype = np.uint16\n else:\n dtype = np.uint32\n\n # Generate a mask of objects with unique values\n if unique:\n # Sort the masks by area in ascending order\n sorted_masks = sorted(masks, key=(lambda x: x[\"area\"]), reverse=False)\n\n # Create an output image with the same size as the input image\n objects = np.zeros(\n (\n sorted_masks[0][\"segmentation\"].shape[0],\n sorted_masks[0][\"segmentation\"].shape[1],\n )\n )\n # Assign a unique value to each object\n for index, ann in enumerate(sorted_masks):\n m = ann[\"segmentation\"]\n objects[m] = index + 1\n\n # Generate a binary mask\n else:\n if foreground: # Extract foreground objects only\n resulting_mask = np.zeros((h, w), dtype=dtype)\n else:\n resulting_mask = np.ones((h, w), dtype=dtype)\n resulting_borders = np.zeros((h, w), dtype=dtype)\n\n for m in masks:\n mask = (m[\"segmentation\"] > 0).astype(dtype)\n resulting_mask += mask\n\n # Apply erosion to the mask\n if erosion_kernel is not None:\n mask_erode = cv2.erode(mask, erosion_kernel, iterations=1)\n mask_erode = (mask_erode > 0).astype(dtype)\n edge_mask = mask - mask_erode\n resulting_borders += edge_mask\n\n resulting_mask = (resulting_mask > 0).astype(dtype)\n resulting_borders = (resulting_borders > 0).astype(dtype)\n objects = resulting_mask - 
resulting_borders\n objects = objects * mask_multiplier\n\n objects = objects.astype(dtype)\n self.objects = objects\n\n if output is not None: # Save the output image\n array_to_image(self.objects, output, self.source, **kwargs)\n\n def show_masks(\n self, figsize=(12, 10), cmap=\"binary_r\", axis=\"off\", foreground=True, **kwargs\n ):\n\"\"\"Show the binary mask or the mask of objects with unique values.\n\n Args:\n figsize (tuple, optional): The figure size. Defaults to (12, 10).\n cmap (str, optional): The colormap. Defaults to \"binary_r\".\n axis (str, optional): Whether to show the axis. Defaults to \"off\".\n foreground (bool, optional): Whether to show the foreground mask only. Defaults to True.\n **kwargs: Other arguments for save_masks().\n \"\"\"\n\n import matplotlib.pyplot as plt\n\n if self.batch:\n self.objects = cv2.imread(self.masks)\n else:\n if self.objects is None:\n self.save_masks(foreground=foreground, **kwargs)\n\n plt.figure(figsize=figsize)\n plt.imshow(self.objects, cmap=cmap)\n plt.axis(axis)\n plt.show()\n\n def show_anns(\n self,\n figsize=(12, 10),\n axis=\"off\",\n alpha=0.35,\n output=None,\n blend=True,\n **kwargs,\n ):\n\"\"\"Show the annotations (objects with random color) on the input image.\n\n Args:\n figsize (tuple, optional): The figure size. Defaults to (12, 10).\n axis (str, optional): Whether to show the axis. Defaults to \"off\".\n alpha (float, optional): The alpha value for the annotations. Defaults to 0.35.\n output (str, optional): The path to the output image. Defaults to None.\n blend (bool, optional): Whether to show the input image. 
Defaults to True.\n \"\"\"\n\n import matplotlib.pyplot as plt\n\n anns = self.masks\n\n if self.image is None:\n print(\"Please run generate() first.\")\n return\n\n if anns is None or len(anns) == 0:\n return\n\n plt.figure(figsize=figsize)\n plt.imshow(self.image)\n\n sorted_anns = sorted(anns, key=(lambda x: x[\"area\"]), reverse=True)\n\n ax = plt.gca()\n ax.set_autoscale_on(False)\n\n img = np.ones(\n (\n sorted_anns[0][\"segmentation\"].shape[0],\n sorted_anns[0][\"segmentation\"].shape[1],\n 4,\n )\n )\n img[:, :, 3] = 0\n for ann in sorted_anns:\n m = ann[\"segmentation\"]\n color_mask = np.concatenate([np.random.random(3), [alpha]])\n img[m] = color_mask\n ax.imshow(img)\n\n if \"dpi\" not in kwargs:\n kwargs[\"dpi\"] = 100\n\n if \"bbox_inches\" not in kwargs:\n kwargs[\"bbox_inches\"] = \"tight\"\n\n plt.axis(axis)\n\n self.annotations = (img[:, :, 0:3] * 255).astype(np.uint8)\n\n if output is not None:\n if blend:\n array = blend_images(\n self.annotations, self.image, alpha=alpha, show=False\n )\n else:\n array = self.annotations\n array_to_image(array, output, self.source)\n\n def set_image(self, image, image_format=\"RGB\"):\n\"\"\"Set the input image as a numpy array.\n\n Args:\n image (np.ndarray): The input image as a numpy array.\n image_format (str, optional): The image format, can be RGB or BGR. 
Defaults to \"RGB\".\n \"\"\"\n if isinstance(image, str):\n if image.startswith(\"http\"):\n image = download_file(image)\n\n if not os.path.exists(image):\n raise ValueError(f\"Input path {image} does not exist.\")\n\n self.source = image\n\n image = cv2.imread(image)\n image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n self.image = image\n elif isinstance(image, np.ndarray):\n pass\n else:\n raise ValueError(\"Input image must be either a path or a numpy array.\")\n\n self.predictor.set_image(image, image_format=image_format)\n\n def save_prediction(\n self,\n output,\n index=None,\n mask_multiplier=255,\n dtype=np.float32,\n vector=None,\n simplify_tolerance=None,\n **kwargs,\n ):\n\"\"\"Save the predicted mask to the output path.\n\n Args:\n output (str): The path to the output image.\n index (int, optional): The index of the mask to save. Defaults to None,\n which will save the mask with the highest score.\n mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1].\n vector (str, optional): The path to the output vector file. Defaults to None.\n dtype (np.dtype, optional): The data type of the output image. Defaults to np.float32.\n simplify_tolerance (float, optional): The maximum allowed geometry displacement.\n The higher this value, the smaller the number of vertices in the resulting geometry.\n\n \"\"\"\n if self.scores is None:\n raise ValueError(\"No predictions found. 
Please run predict() first.\")\n\n if index is None:\n index = self.scores.argmax(axis=0)\n\n array = self.masks[index] * mask_multiplier\n self.prediction = array\n array_to_image(array, output, self.source, dtype=dtype, **kwargs)\n\n if vector is not None:\n raster_to_vector(output, vector, simplify_tolerance=simplify_tolerance)\n\n def predict(\n self,\n point_coords=None,\n point_labels=None,\n boxes=None,\n point_crs=None,\n mask_input=None,\n multimask_output=True,\n return_logits=False,\n output=None,\n index=None,\n mask_multiplier=255,\n dtype=\"float32\",\n return_results=False,\n **kwargs,\n ):\n\"\"\"Predict masks for the given input prompts, using the currently set image.\n\n Args:\n point_coords (str | dict | list | np.ndarray, optional): A Nx2 array of point prompts to the\n model. Each point is in (X,Y) in pixels. It can be a path to a vector file, a GeoJSON\n dictionary, a list of coordinates [lon, lat], or a numpy array. Defaults to None.\n point_labels (list | int | np.ndarray, optional): A length N array of labels for the\n point prompts. 1 indicates a foreground point and 0 indicates a background point.\n point_crs (str, optional): The coordinate reference system (CRS) of the point prompts.\n boxes (list | np.ndarray, optional): A length 4 array given a box prompt to the\n model, in XYXY format.\n mask_input (np.ndarray, optional): A low resolution mask input to the model, typically\n coming from a previous prediction iteration. Has form 1xHxW, where for SAM, H=W=256.\n multimask_output (bool, optional): If true, the model will return three masks.\n For ambiguous input prompts (such as a single click), this will often\n produce better masks than a single prediction. If only a single\n mask is needed, the model's predicted quality score can be used\n to select the best mask. 
For non-ambiguous prompts, such as multiple\n input prompts, multimask_output=False can give better results.\n return_logits (bool, optional): If true, returns un-thresholded masks logits\n instead of a binary mask.\n output (str, optional): The path to the output image. Defaults to None.\n index (index, optional): The index of the mask to save. Defaults to None,\n which will save the mask with the highest score.\n mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1].\n dtype (np.dtype, optional): The data type of the output image. Defaults to np.float32.\n return_results (bool, optional): Whether to return the predicted masks, scores, and logits. Defaults to False.\n\n \"\"\"\n\n if isinstance(boxes, str):\n gdf = gpd.read_file(boxes)\n if gdf.crs is not None:\n gdf = gdf.to_crs(\"epsg:4326\")\n boxes = gdf.geometry.bounds.values.tolist()\n elif isinstance(boxes, dict):\n import json\n\n geojson = json.dumps(boxes)\n gdf = gpd.read_file(geojson, driver=\"GeoJSON\")\n boxes = gdf.geometry.bounds.values.tolist()\n\n if isinstance(point_coords, str):\n point_coords = vector_to_geojson(point_coords)\n\n if isinstance(point_coords, dict):\n point_coords = geojson_to_coords(point_coords)\n\n if hasattr(self, \"point_coords\"):\n point_coords = self.point_coords\n\n if hasattr(self, \"point_labels\"):\n point_labels = self.point_labels\n\n if (point_crs is not None) and (point_coords is not None):\n point_coords = coords_to_xy(self.source, point_coords, point_crs)\n\n if isinstance(point_coords, list):\n point_coords = np.array(point_coords)\n\n if point_coords is not None:\n if point_labels is None:\n point_labels = [1] * len(point_coords)\n elif isinstance(point_labels, int):\n point_labels = [point_labels] * len(point_coords)\n\n if isinstance(point_labels, list):\n if len(point_labels) != len(point_coords):\n if len(point_labels) == 1:\n point_labels = point_labels * len(point_coords)\n else:\n raise 
ValueError(\n \"The length of point_labels must be equal to the length of point_coords.\"\n )\n point_labels = np.array(point_labels)\n\n predictor = self.predictor\n\n input_boxes = None\n if isinstance(boxes, list) and (point_crs is not None):\n coords = bbox_to_xy(self.source, boxes, point_crs)\n input_boxes = np.array(coords)\n if isinstance(coords[0], int):\n input_boxes = input_boxes[None, :]\n else:\n input_boxes = torch.tensor(input_boxes, device=self.device)\n input_boxes = predictor.transform.apply_boxes_torch(\n input_boxes, self.image.shape[:2]\n )\n elif isinstance(boxes, list) and (point_crs is None):\n input_boxes = np.array(boxes)\n if isinstance(boxes[0], int):\n input_boxes = input_boxes[None, :]\n\n self.boxes = input_boxes\n\n if boxes is None or (not isinstance(boxes[0], list)):\n masks, scores, logits = predictor.predict(\n point_coords,\n point_labels,\n input_boxes,\n mask_input,\n multimask_output,\n return_logits,\n )\n else:\n masks, scores, logits = predictor.predict_torch(\n point_coords=point_coords,\n point_labels=point_coords,\n boxes=input_boxes,\n multimask_output=True,\n )\n\n self.masks = masks\n self.scores = scores\n self.logits = logits\n\n if output is not None:\n if boxes is None or (not isinstance(boxes[0], list)):\n self.save_prediction(output, index, mask_multiplier, dtype, **kwargs)\n else:\n self.tensor_to_numpy(\n index, output, mask_multiplier, dtype, save_args=kwargs\n )\n\n if return_results:\n return masks, scores, logits\n\n def tensor_to_numpy(\n self, index=None, output=None, mask_multiplier=255, dtype=\"uint8\", save_args={}\n ):\n\"\"\"Convert the predicted masks from tensors to numpy arrays.\n\n Args:\n index (index, optional): The index of the mask to save. Defaults to None,\n which will save the mask with the highest score.\n output (str, optional): The path to the output image. 
Defaults to None.\n mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1].\n dtype (np.dtype, optional): The data type of the output image. Defaults to np.uint8.\n save_args (dict, optional): Optional arguments for saving the output image. Defaults to {}.\n\n Returns:\n np.ndarray: The predicted mask as a numpy array.\n \"\"\"\n\n boxes = self.boxes\n masks = self.masks\n\n image_pil = self.image\n image_np = np.array(image_pil)\n\n if index is None:\n index = 1\n\n masks = masks[:, index, :, :]\n masks = masks.squeeze(1)\n\n if boxes is None or (len(boxes) == 0): # No \"object\" instances found\n print(\"No objects found in the image.\")\n return\n else:\n # Create an empty image to store the mask overlays\n mask_overlay = np.zeros_like(\n image_np[..., 0], dtype=dtype\n ) # Adjusted for single channel\n\n for i, (box, mask) in enumerate(zip(boxes, masks)):\n # Convert tensor to numpy array if necessary and ensure it contains integers\n if isinstance(mask, torch.Tensor):\n mask = (\n mask.cpu().numpy().astype(dtype)\n ) # If mask is on GPU, use .cpu() before .numpy()\n mask_overlay += ((mask > 0) * (i + 1)).astype(\n dtype\n ) # Assign a unique value for each mask\n\n # Normalize mask_overlay to be in [0, 255]\n mask_overlay = (\n mask_overlay > 0\n ) * mask_multiplier # Binary mask in [0, 255]\n\n if output is not None:\n array_to_image(mask_overlay, output, self.source, dtype=dtype, **save_args)\n else:\n return mask_overlay\n\n def show_map(self, basemap=\"SATELLITE\", repeat_mode=True, out_dir=None, **kwargs):\n\"\"\"Show the interactive map.\n\n Args:\n basemap (str, optional): The basemap. It can be one of the following: SATELLITE, ROADMAP, TERRAIN, HYBRID.\n repeat_mode (bool, optional): Whether to use the repeat mode for draw control. Defaults to True.\n out_dir (str, optional): The path to the output directory. 
Defaults to None.\n\n Returns:\n leafmap.Map: The map object.\n \"\"\"\n return sam_map_gui(\n self, basemap=basemap, repeat_mode=repeat_mode, out_dir=out_dir, **kwargs\n )\n\n def show_canvas(self, fg_color=(0, 255, 0), bg_color=(0, 0, 255), radius=5):\n\"\"\"Show a canvas to collect foreground and background points.\n\n Args:\n image (str | np.ndarray): The input image.\n fg_color (tuple, optional): The color for the foreground points. Defaults to (0, 255, 0).\n bg_color (tuple, optional): The color for the background points. Defaults to (0, 0, 255).\n radius (int, optional): The radius of the points. Defaults to 5.\n\n Returns:\n tuple: A tuple of two lists of foreground and background points.\n \"\"\"\n\n if self.image is None:\n raise ValueError(\"Please run set_image() first.\")\n\n image = self.image\n fg_points, bg_points = show_canvas(image, fg_color, bg_color, radius)\n self.fg_points = fg_points\n self.bg_points = bg_points\n point_coords = fg_points + bg_points\n point_labels = [1] * len(fg_points) + [0] * len(bg_points)\n self.point_coords = point_coords\n self.point_labels = point_labels\n\n def clear_cuda_cache(self):\n\"\"\"Clear the CUDA cache.\"\"\"\n if torch.cuda.is_available():\n torch.cuda.empty_cache()\n\n def image_to_image(self, image, **kwargs):\n return image_to_image(image, self, **kwargs)\n\n def download_tms_as_tiff(self, source, pt1, pt2, zoom, dist):\n image = draw_tile(source, pt1[0], pt1[1], pt2[0], pt2[1], zoom, dist)\n return image\n\n def raster_to_vector(self, image, output, simplify_tolerance=None, **kwargs):\n\"\"\"Save the result to a vector file.\n\n Args:\n image (str): The path to the image file.\n output (str): The path to the vector file.\n simplify_tolerance (float, optional): The maximum allowed geometry displacement.\n The higher this value, the smaller the number of vertices in the resulting geometry.\n \"\"\"\n\n raster_to_vector(image, output, simplify_tolerance=simplify_tolerance, **kwargs)\n\n def 
tiff_to_vector(self, tiff_path, output, simplify_tolerance=None, **kwargs):\n\"\"\"Convert a tiff file to a gpkg file.\n\n Args:\n tiff_path (str): The path to the tiff file.\n output (str): The path to the vector file.\n simplify_tolerance (float, optional): The maximum allowed geometry displacement.\n The higher this value, the smaller the number of vertices in the resulting geometry.\n \"\"\"\n\n raster_to_vector(\n tiff_path, output, simplify_tolerance=simplify_tolerance, **kwargs\n )\n\n def tiff_to_gpkg(self, tiff_path, output, simplify_tolerance=None, **kwargs):\n\"\"\"Convert a tiff file to a gpkg file.\n\n Args:\n tiff_path (str): The path to the tiff file.\n output (str): The path to the gpkg file.\n simplify_tolerance (float, optional): The maximum allowed geometry displacement.\n The higher this value, the smaller the number of vertices in the resulting geometry.\n \"\"\"\n\n raster_to_gpkg(\n tiff_path, output, simplify_tolerance=simplify_tolerance, **kwargs\n )\n\n def tiff_to_shp(self, tiff_path, output, simplify_tolerance=None, **kwargs):\n\"\"\"Convert a tiff file to a shapefile.\n\n Args:\n tiff_path (str): The path to the tiff file.\n output (str): The path to the shapefile.\n simplify_tolerance (float, optional): The maximum allowed geometry displacement.\n The higher this value, the smaller the number of vertices in the resulting geometry.\n \"\"\"\n\n raster_to_shp(\n tiff_path, output, simplify_tolerance=simplify_tolerance, **kwargs\n )\n\n def tiff_to_geojson(self, tiff_path, output, simplify_tolerance=None, **kwargs):\n\"\"\"Convert a tiff file to a GeoJSON file.\n\n Args:\n tiff_path (str): The path to the tiff file.\n output (str): The path to the GeoJSON file.\n simplify_tolerance (float, optional): The maximum allowed geometry displacement.\n The higher this value, the smaller the number of vertices in the resulting geometry.\n \"\"\"\n\n raster_to_geojson(\n tiff_path, output, simplify_tolerance=simplify_tolerance, **kwargs\n )\n
"},{"location":"samgeo/#samgeo.samgeo.SamGeo.__call__","title":"__call__(self, image, foreground=True, erosion_kernel=(3, 3), mask_multiplier=255, **kwargs)
special
","text":"Generate masks for the input tile. This function originates from the segment-anything-eo repository. See https://bit.ly/41pwiHw
Parameters:
Name Type Description Default
image
np.ndarray
The input image as a numpy array.
required
foreground
bool
Whether to generate the foreground mask. Defaults to True.
True
erosion_kernel
tuple
The erosion kernel for filtering object masks and extract borders. Defaults to (3, 3).
(3, 3)
mask_multiplier
int
The mask multiplier for the output mask, which is usually a binary mask [0, 1]. You can use this parameter to scale the mask to a larger range, for example [0, 255]. Defaults to 255.
255
Source code in samgeo/samgeo.py
def __call__(\n self,\n image,\n foreground=True,\n erosion_kernel=(3, 3),\n mask_multiplier=255,\n **kwargs,\n):\n\"\"\"Generate masks for the input tile. This function originates from the segment-anything-eo repository.\n See https://bit.ly/41pwiHw\n\n Args:\n image (np.ndarray): The input image as a numpy array.\n foreground (bool, optional): Whether to generate the foreground mask. Defaults to True.\n erosion_kernel (tuple, optional): The erosion kernel for filtering object masks and extract borders. Defaults to (3, 3).\n mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1].\n You can use this parameter to scale the mask to a larger range, for example [0, 255]. Defaults to 255.\n \"\"\"\n h, w, _ = image.shape\n\n masks = self.mask_generator.generate(image)\n\n if foreground: # Extract foreground objects only\n resulting_mask = np.zeros((h, w), dtype=np.uint8)\n else:\n resulting_mask = np.ones((h, w), dtype=np.uint8)\n resulting_borders = np.zeros((h, w), dtype=np.uint8)\n\n for m in masks:\n mask = (m[\"segmentation\"] > 0).astype(np.uint8)\n resulting_mask += mask\n\n # Apply erosion to the mask\n if erosion_kernel is not None:\n mask_erode = cv2.erode(mask, erosion_kernel, iterations=1)\n mask_erode = (mask_erode > 0).astype(np.uint8)\n edge_mask = mask - mask_erode\n resulting_borders += edge_mask\n\n resulting_mask = (resulting_mask > 0).astype(np.uint8)\n resulting_borders = (resulting_borders > 0).astype(np.uint8)\n resulting_mask_with_borders = resulting_mask - resulting_borders\n return resulting_mask_with_borders * mask_multiplier\n
"},{"location":"samgeo/#samgeo.samgeo.SamGeo.__init__","title":"__init__(self, model_type='vit_h', automatic=True, device=None, checkpoint_dir=None, sam_kwargs=None, **kwargs)
special
","text":"Initialize the class.
Parameters:
Name Type Description Default
model_type
str
The model type. It can be one of the following: vit_h, vit_l, vit_b. Defaults to 'vit_h'. See https://bit.ly/3VrpxUh for more details.
'vit_h'
automatic
bool
Whether to use the automatic mask generator or input prompts. Defaults to True. The automatic mask generator will segment the entire image, while the input prompts will segment selected objects.
True
device
str
The device to use. It can be one of the following: cpu, cuda. Defaults to None, which will use cuda if available.
None
checkpoint_dir
str
The path to the model checkpoint. It can be one of the following: sam_vit_h_4b8939.pth, sam_vit_l_0b3195.pth, sam_vit_b_01ec64.pth. Defaults to None. See https://bit.ly/3VrpxUh for more details.
None
sam_kwargs
dict
Optional arguments for fine-tuning the SAM model. Defaults to None. The available arguments with default values are listed below. See https://bit.ly/410RV0v for more details.
points_per_side: Optional[int] = 32, points_per_batch: int = 64, pred_iou_thresh: float = 0.88, stability_score_thresh: float = 0.95, stability_score_offset: float = 1.0, box_nms_thresh: float = 0.7, crop_n_layers: int = 0, crop_nms_thresh: float = 0.7, crop_overlap_ratio: float = 512 / 1500, crop_n_points_downscale_factor: int = 1, point_grids: Optional[List[np.ndarray]] = None, min_mask_region_area: int = 0, output_mode: str = \"binary_mask\",
None
Source code in samgeo/samgeo.py
def __init__(\n self,\n model_type=\"vit_h\",\n automatic=True,\n device=None,\n checkpoint_dir=None,\n sam_kwargs=None,\n **kwargs,\n):\n\"\"\"Initialize the class.\n\n Args:\n model_type (str, optional): The model type. It can be one of the following: vit_h, vit_l, vit_b.\n Defaults to 'vit_h'. See https://bit.ly/3VrpxUh for more details.\n automatic (bool, optional): Whether to use the automatic mask generator or input prompts. Defaults to True.\n The automatic mask generator will segment the entire image, while the input prompts will segment selected objects.\n device (str, optional): The device to use. It can be one of the following: cpu, cuda.\n Defaults to None, which will use cuda if available.\n checkpoint_dir (str, optional): The path to the model checkpoint. It can be one of the following:\n sam_vit_h_4b8939.pth, sam_vit_l_0b3195.pth, sam_vit_b_01ec64.pth.\n Defaults to None. See https://bit.ly/3VrpxUh for more details.\n sam_kwargs (dict, optional): Optional arguments for fine-tuning the SAM model. Defaults to None.\n The available arguments with default values are listed below. 
See https://bit.ly/410RV0v for more details.\n\n points_per_side: Optional[int] = 32,\n points_per_batch: int = 64,\n pred_iou_thresh: float = 0.88,\n stability_score_thresh: float = 0.95,\n stability_score_offset: float = 1.0,\n box_nms_thresh: float = 0.7,\n crop_n_layers: int = 0,\n crop_nms_thresh: float = 0.7,\n crop_overlap_ratio: float = 512 / 1500,\n crop_n_points_downscale_factor: int = 1,\n point_grids: Optional[List[np.ndarray]] = None,\n min_mask_region_area: int = 0,\n output_mode: str = \"binary_mask\",\n\n \"\"\"\n hq = False # Not using HQ-SAM\n\n if \"checkpoint\" in kwargs:\n checkpoint = kwargs[\"checkpoint\"]\n if not os.path.exists(checkpoint):\n checkpoint = download_checkpoint(model_type, checkpoint_dir, hq)\n kwargs.pop(\"checkpoint\")\n else:\n checkpoint = download_checkpoint(model_type, checkpoint_dir, hq)\n\n # Use cuda if available\n if device is None:\n device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n if device == \"cuda\":\n torch.cuda.empty_cache()\n\n self.checkpoint = checkpoint\n self.model_type = model_type\n self.device = device\n self.sam_kwargs = sam_kwargs # Optional arguments for fine-tuning the SAM model\n self.source = None # Store the input image path\n self.image = None # Store the input image as a numpy array\n # Store the masks as a list of dictionaries. 
Each mask is a dictionary\n # containing segmentation, area, bbox, predicted_iou, point_coords, stability_score, and crop_box\n self.masks = None\n self.objects = None # Store the mask objects as a numpy array\n # Store the annotations (objects with random color) as a numpy array.\n self.annotations = None\n\n # Store the predicted masks, iou_predictions, and low_res_masks\n self.prediction = None\n self.scores = None\n self.logits = None\n\n # Build the SAM model\n self.sam = sam_model_registry[self.model_type](checkpoint=self.checkpoint)\n self.sam.to(device=self.device)\n # Use optional arguments for fine-tuning the SAM model\n sam_kwargs = self.sam_kwargs if self.sam_kwargs is not None else {}\n\n if automatic:\n # Segment the entire image using the automatic mask generator\n self.mask_generator = SamAutomaticMaskGenerator(self.sam, **sam_kwargs)\n else:\n # Segment selected objects using input prompts\n self.predictor = SamPredictor(self.sam, **sam_kwargs)\n
"},{"location":"samgeo/#samgeo.samgeo.SamGeo.clear_cuda_cache","title":"clear_cuda_cache(self)
","text":"Clear the CUDA cache.
Source code in samgeo/samgeo.py
def clear_cuda_cache(self):\n\"\"\"Clear the CUDA cache.\"\"\"\n if torch.cuda.is_available():\n torch.cuda.empty_cache()\n
"},{"location":"samgeo/#samgeo.samgeo.SamGeo.generate","title":"generate(self, source, output=None, foreground=True, batch=False, erosion_kernel=None, mask_multiplier=255, unique=True, **kwargs)
","text":"Generate masks for the input image.
Parameters:
Name Type Description Default source
str | np.ndarray
The path to the input image or the input image as a numpy array.
required output
str
The path to the output image. Defaults to None.
None
foreground
bool
Whether to generate the foreground mask. Defaults to True.
True
batch
bool
Whether to generate masks for a batch of image tiles. Defaults to False.
False
erosion_kernel
tuple
The erosion kernel for filtering object masks and extracting borders, such as (3, 3) or (5, 5). Set to None to disable it. Defaults to None.
None
mask_multiplier
int
The mask multiplier for the output mask, which is usually a binary mask [0, 1]. You can use this parameter to scale the mask to a larger range, for example [0, 255]. Defaults to 255. The parameter is ignored if unique is True.
255
unique
bool
Whether to assign a unique value to each object. Defaults to True. The unique value increases from 1 to the number of objects. The larger the number, the larger the object area.
True
Source code in samgeo/samgeo.py
def generate(\n self,\n source,\n output=None,\n foreground=True,\n batch=False,\n erosion_kernel=None,\n mask_multiplier=255,\n unique=True,\n **kwargs,\n):\n\"\"\"Generate masks for the input image.\n\n Args:\n source (str | np.ndarray): The path to the input image or the input image as a numpy array.\n output (str, optional): The path to the output image. Defaults to None.\n foreground (bool, optional): Whether to generate the foreground mask. Defaults to True.\n batch (bool, optional): Whether to generate masks for a batch of image tiles. Defaults to False.\n erosion_kernel (tuple, optional): The erosion kernel for filtering object masks and extract borders.\n Such as (3, 3) or (5, 5). Set to None to disable it. Defaults to None.\n mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1].\n You can use this parameter to scale the mask to a larger range, for example [0, 255]. Defaults to 255.\n The parameter is ignored if unique is True.\n unique (bool, optional): Whether to assign a unique value to each object. Defaults to True.\n The unique value increases from 1 to the number of objects. 
The larger the number, the larger the object area.\n\n \"\"\"\n\n if isinstance(source, str):\n if source.startswith(\"http\"):\n source = download_file(source)\n\n if not os.path.exists(source):\n raise ValueError(f\"Input path {source} does not exist.\")\n\n if batch: # Subdivide the image into tiles and segment each tile\n self.batch = True\n self.source = source\n self.masks = output\n return tiff_to_tiff(\n source,\n output,\n self,\n foreground=foreground,\n erosion_kernel=erosion_kernel,\n mask_multiplier=mask_multiplier,\n **kwargs,\n )\n\n image = cv2.imread(source)\n elif isinstance(source, np.ndarray):\n image = source\n source = None\n else:\n raise ValueError(\"Input source must be either a path or a numpy array.\")\n\n self.source = source # Store the input image path\n self.image = image # Store the input image as a numpy array\n mask_generator = self.mask_generator # The automatic mask generator\n masks = mask_generator.generate(image) # Segment the input image\n self.masks = masks # Store the masks as a list of dictionaries\n self.batch = False\n\n if output is not None:\n # Save the masks to the output path. The output is either a binary mask or a mask of objects with unique values.\n self.save_masks(\n output, foreground, unique, erosion_kernel, mask_multiplier, **kwargs\n )\n
"},{"location":"samgeo/#samgeo.samgeo.SamGeo.predict","title":"predict(self, point_coords=None, point_labels=None, boxes=None, point_crs=None, mask_input=None, multimask_output=True, return_logits=False, output=None, index=None, mask_multiplier=255, dtype='float32', return_results=False, **kwargs)
","text":"Predict masks for the given input prompts, using the currently set image.
Parameters:
Name Type Description Default point_coords
str | dict | list | np.ndarray
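An Nx2 array of point prompts to the model. Each point is given as (X, Y) in pixels; if point_crs is specified, the coordinates are interpreted in that CRS and converted to pixel coordinates. It can be a path to a vector file, a GeoJSON dictionary, a list of coordinates [lon, lat], or a numpy array. Defaults to None.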
None
point_labels
list | int | np.ndarray
A length N array of labels for the point prompts. 1 indicates a foreground point and 0 indicates a background point.
None
point_crs
str
The coordinate reference system (CRS) of the point prompts.
None
boxes
list | np.ndarray
A length 4 array giving a box prompt to the model, in XYXY format.
None
mask_input
np.ndarray
A low resolution mask input to the model, typically coming from a previous prediction iteration. Has form 1xHxW, where for SAM, H=W=256. Defaults to None. Separately, the multimask_output parameter (bool, optional, default True) controls how many masks are returned: if true, the model will return three masks. For ambiguous input prompts (such as a single click), this will often produce better masks than a single prediction. If only a single mask is needed, the model's predicted quality score can be used to select the best mask. For non-ambiguous prompts, such as multiple input prompts, multimask_output=False can give better results.
None
return_logits
bool
If true, returns un-thresholded mask logits instead of a binary mask.
False
output
str
The path to the output image. Defaults to None.
None
index
int
The index of the mask to save. Defaults to None, which will save the mask with the highest score.
None
mask_multiplier
int
The mask multiplier for the output mask, which is usually a binary mask [0, 1].
255
dtype
np.dtype
The data type of the output image. Defaults to np.float32.
'float32'
return_results
bool
Whether to return the predicted masks, scores, and logits. Defaults to False.
False
Source code in samgeo/samgeo.py
def predict(\n self,\n point_coords=None,\n point_labels=None,\n boxes=None,\n point_crs=None,\n mask_input=None,\n multimask_output=True,\n return_logits=False,\n output=None,\n index=None,\n mask_multiplier=255,\n dtype=\"float32\",\n return_results=False,\n **kwargs,\n):\n\"\"\"Predict masks for the given input prompts, using the currently set image.\n\n Args:\n point_coords (str | dict | list | np.ndarray, optional): A Nx2 array of point prompts to the\n model. Each point is in (X,Y) in pixels. It can be a path to a vector file, a GeoJSON\n dictionary, a list of coordinates [lon, lat], or a numpy array. Defaults to None.\n point_labels (list | int | np.ndarray, optional): A length N array of labels for the\n point prompts. 1 indicates a foreground point and 0 indicates a background point.\n point_crs (str, optional): The coordinate reference system (CRS) of the point prompts.\n boxes (list | np.ndarray, optional): A length 4 array given a box prompt to the\n model, in XYXY format.\n mask_input (np.ndarray, optional): A low resolution mask input to the model, typically\n coming from a previous prediction iteration. Has form 1xHxW, where for SAM, H=W=256.\n multimask_output (bool, optional): If true, the model will return three masks.\n For ambiguous input prompts (such as a single click), this will often\n produce better masks than a single prediction. If only a single\n mask is needed, the model's predicted quality score can be used\n to select the best mask. For non-ambiguous prompts, such as multiple\n input prompts, multimask_output=False can give better results.\n return_logits (bool, optional): If true, returns un-thresholded masks logits\n instead of a binary mask.\n output (str, optional): The path to the output image. Defaults to None.\n index (index, optional): The index of the mask to save. 
Defaults to None,\n which will save the mask with the highest score.\n mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1].\n dtype (np.dtype, optional): The data type of the output image. Defaults to np.float32.\n return_results (bool, optional): Whether to return the predicted masks, scores, and logits. Defaults to False.\n\n \"\"\"\n\n if isinstance(boxes, str):\n gdf = gpd.read_file(boxes)\n if gdf.crs is not None:\n gdf = gdf.to_crs(\"epsg:4326\")\n boxes = gdf.geometry.bounds.values.tolist()\n elif isinstance(boxes, dict):\n import json\n\n geojson = json.dumps(boxes)\n gdf = gpd.read_file(geojson, driver=\"GeoJSON\")\n boxes = gdf.geometry.bounds.values.tolist()\n\n if isinstance(point_coords, str):\n point_coords = vector_to_geojson(point_coords)\n\n if isinstance(point_coords, dict):\n point_coords = geojson_to_coords(point_coords)\n\n if hasattr(self, \"point_coords\"):\n point_coords = self.point_coords\n\n if hasattr(self, \"point_labels\"):\n point_labels = self.point_labels\n\n if (point_crs is not None) and (point_coords is not None):\n point_coords = coords_to_xy(self.source, point_coords, point_crs)\n\n if isinstance(point_coords, list):\n point_coords = np.array(point_coords)\n\n if point_coords is not None:\n if point_labels is None:\n point_labels = [1] * len(point_coords)\n elif isinstance(point_labels, int):\n point_labels = [point_labels] * len(point_coords)\n\n if isinstance(point_labels, list):\n if len(point_labels) != len(point_coords):\n if len(point_labels) == 1:\n point_labels = point_labels * len(point_coords)\n else:\n raise ValueError(\n \"The length of point_labels must be equal to the length of point_coords.\"\n )\n point_labels = np.array(point_labels)\n\n predictor = self.predictor\n\n input_boxes = None\n if isinstance(boxes, list) and (point_crs is not None):\n coords = bbox_to_xy(self.source, boxes, point_crs)\n input_boxes = np.array(coords)\n if 
isinstance(coords[0], int):\n input_boxes = input_boxes[None, :]\n else:\n input_boxes = torch.tensor(input_boxes, device=self.device)\n input_boxes = predictor.transform.apply_boxes_torch(\n input_boxes, self.image.shape[:2]\n )\n elif isinstance(boxes, list) and (point_crs is None):\n input_boxes = np.array(boxes)\n if isinstance(boxes[0], int):\n input_boxes = input_boxes[None, :]\n\n self.boxes = input_boxes\n\n if boxes is None or (not isinstance(boxes[0], list)):\n masks, scores, logits = predictor.predict(\n point_coords,\n point_labels,\n input_boxes,\n mask_input,\n multimask_output,\n return_logits,\n )\n else:\n masks, scores, logits = predictor.predict_torch(\n point_coords=point_coords,\n point_labels=point_coords,\n boxes=input_boxes,\n multimask_output=True,\n )\n\n self.masks = masks\n self.scores = scores\n self.logits = logits\n\n if output is not None:\n if boxes is None or (not isinstance(boxes[0], list)):\n self.save_prediction(output, index, mask_multiplier, dtype, **kwargs)\n else:\n self.tensor_to_numpy(\n index, output, mask_multiplier, dtype, save_args=kwargs\n )\n\n if return_results:\n return masks, scores, logits\n
"},{"location":"samgeo/#samgeo.samgeo.SamGeo.raster_to_vector","title":"raster_to_vector(self, image, output, simplify_tolerance=None, **kwargs)
","text":"Save the result to a vector file.
Parameters:
Name Type Description Default image
str
The path to the image file.
required output
str
The path to the vector file.
required simplify_tolerance
float
The maximum allowed geometry displacement. The higher this value, the smaller the number of vertices in the resulting geometry.
None
Source code in samgeo/samgeo.py
def raster_to_vector(self, image, output, simplify_tolerance=None, **kwargs):\n\"\"\"Save the result to a vector file.\n\n Args:\n image (str): The path to the image file.\n output (str): The path to the vector file.\n simplify_tolerance (float, optional): The maximum allowed geometry displacement.\n The higher this value, the smaller the number of vertices in the resulting geometry.\n \"\"\"\n\n raster_to_vector(image, output, simplify_tolerance=simplify_tolerance, **kwargs)\n
"},{"location":"samgeo/#samgeo.samgeo.SamGeo.save_masks","title":"save_masks(self, output=None, foreground=True, unique=True, erosion_kernel=None, mask_multiplier=255, **kwargs)
","text":"Save the masks to the output path. The output is either a binary mask or a mask of objects with unique values.
Parameters:
Name Type Description Default output
str
The path to the output image. Defaults to None, saving the masks to SamGeo.objects.
None
foreground
bool
Whether to generate the foreground mask. Defaults to True.
True
unique
bool
Whether to assign a unique value to each object. Defaults to True.
True
erosion_kernel
tuple
The erosion kernel for filtering object masks and extracting borders, such as (3, 3) or (5, 5). Set to None to disable it. Defaults to None.
None
mask_multiplier
int
The mask multiplier for the output mask, which is usually a binary mask [0, 1]. You can use this parameter to scale the mask to a larger range, for example [0, 255]. Defaults to 255.
255
Source code in samgeo/samgeo.py
def save_masks(\n self,\n output=None,\n foreground=True,\n unique=True,\n erosion_kernel=None,\n mask_multiplier=255,\n **kwargs,\n):\n\"\"\"Save the masks to the output path. The output is either a binary mask or a mask of objects with unique values.\n\n Args:\n output (str, optional): The path to the output image. Defaults to None, saving the masks to SamGeo.objects.\n foreground (bool, optional): Whether to generate the foreground mask. Defaults to True.\n unique (bool, optional): Whether to assign a unique value to each object. Defaults to True.\n erosion_kernel (tuple, optional): The erosion kernel for filtering object masks and extract borders.\n Such as (3, 3) or (5, 5). Set to None to disable it. Defaults to None.\n mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1].\n You can use this parameter to scale the mask to a larger range, for example [0, 255]. Defaults to 255.\n\n \"\"\"\n\n if self.masks is None:\n raise ValueError(\"No masks found. 
Please run generate() first.\")\n\n h, w, _ = self.image.shape\n masks = self.masks\n\n # Set output image data type based on the number of objects\n if len(masks) < 255:\n dtype = np.uint8\n elif len(masks) < 65535:\n dtype = np.uint16\n else:\n dtype = np.uint32\n\n # Generate a mask of objects with unique values\n if unique:\n # Sort the masks by area in ascending order\n sorted_masks = sorted(masks, key=(lambda x: x[\"area\"]), reverse=False)\n\n # Create an output image with the same size as the input image\n objects = np.zeros(\n (\n sorted_masks[0][\"segmentation\"].shape[0],\n sorted_masks[0][\"segmentation\"].shape[1],\n )\n )\n # Assign a unique value to each object\n for index, ann in enumerate(sorted_masks):\n m = ann[\"segmentation\"]\n objects[m] = index + 1\n\n # Generate a binary mask\n else:\n if foreground: # Extract foreground objects only\n resulting_mask = np.zeros((h, w), dtype=dtype)\n else:\n resulting_mask = np.ones((h, w), dtype=dtype)\n resulting_borders = np.zeros((h, w), dtype=dtype)\n\n for m in masks:\n mask = (m[\"segmentation\"] > 0).astype(dtype)\n resulting_mask += mask\n\n # Apply erosion to the mask\n if erosion_kernel is not None:\n mask_erode = cv2.erode(mask, erosion_kernel, iterations=1)\n mask_erode = (mask_erode > 0).astype(dtype)\n edge_mask = mask - mask_erode\n resulting_borders += edge_mask\n\n resulting_mask = (resulting_mask > 0).astype(dtype)\n resulting_borders = (resulting_borders > 0).astype(dtype)\n objects = resulting_mask - resulting_borders\n objects = objects * mask_multiplier\n\n objects = objects.astype(dtype)\n self.objects = objects\n\n if output is not None: # Save the output image\n array_to_image(self.objects, output, self.source, **kwargs)\n
"},{"location":"samgeo/#samgeo.samgeo.SamGeo.save_prediction","title":"save_prediction(self, output, index=None, mask_multiplier=255, dtype=<class 'numpy.float32'>, vector=None, simplify_tolerance=None, **kwargs)
","text":"Save the predicted mask to the output path.
Parameters:
Name Type Description Default output
str
The path to the output image.
required index
int
The index of the mask to save. Defaults to None, which will save the mask with the highest score.
None
mask_multiplier
int
The mask multiplier for the output mask, which is usually a binary mask [0, 1].
255
vector
str
The path to the output vector file. Defaults to None.
None
dtype
np.dtype
The data type of the output image. Defaults to np.float32.
<class 'numpy.float32'>
simplify_tolerance
float
The maximum allowed geometry displacement. The higher this value, the smaller the number of vertices in the resulting geometry.
None
Source code in samgeo/samgeo.py
def save_prediction(\n self,\n output,\n index=None,\n mask_multiplier=255,\n dtype=np.float32,\n vector=None,\n simplify_tolerance=None,\n **kwargs,\n):\n\"\"\"Save the predicted mask to the output path.\n\n Args:\n output (str): The path to the output image.\n index (int, optional): The index of the mask to save. Defaults to None,\n which will save the mask with the highest score.\n mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1].\n vector (str, optional): The path to the output vector file. Defaults to None.\n dtype (np.dtype, optional): The data type of the output image. Defaults to np.float32.\n simplify_tolerance (float, optional): The maximum allowed geometry displacement.\n The higher this value, the smaller the number of vertices in the resulting geometry.\n\n \"\"\"\n if self.scores is None:\n raise ValueError(\"No predictions found. Please run predict() first.\")\n\n if index is None:\n index = self.scores.argmax(axis=0)\n\n array = self.masks[index] * mask_multiplier\n self.prediction = array\n array_to_image(array, output, self.source, dtype=dtype, **kwargs)\n\n if vector is not None:\n raster_to_vector(output, vector, simplify_tolerance=simplify_tolerance)\n
"},{"location":"samgeo/#samgeo.samgeo.SamGeo.set_image","title":"set_image(self, image, image_format='RGB')
","text":"Set the input image as a numpy array.
Parameters:
Name Type Description Default image
np.ndarray
The input image as a numpy array.
required image_format
str
The image format, can be RGB or BGR. Defaults to \"RGB\".
'RGB'
Source code in samgeo/samgeo.py
def set_image(self, image, image_format=\"RGB\"):\n\"\"\"Set the input image as a numpy array.\n\n Args:\n image (np.ndarray): The input image as a numpy array.\n image_format (str, optional): The image format, can be RGB or BGR. Defaults to \"RGB\".\n \"\"\"\n if isinstance(image, str):\n if image.startswith(\"http\"):\n image = download_file(image)\n\n if not os.path.exists(image):\n raise ValueError(f\"Input path {image} does not exist.\")\n\n self.source = image\n\n image = cv2.imread(image)\n image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n self.image = image\n elif isinstance(image, np.ndarray):\n pass\n else:\n raise ValueError(\"Input image must be either a path or a numpy array.\")\n\n self.predictor.set_image(image, image_format=image_format)\n
"},{"location":"samgeo/#samgeo.samgeo.SamGeo.show_anns","title":"show_anns(self, figsize=(12, 10), axis='off', alpha=0.35, output=None, blend=True, **kwargs)
","text":"Show the annotations (objects with random color) on the input image.
Parameters:
Name Type Description Default figsize
tuple
The figure size. Defaults to (12, 10).
(12, 10)
axis
str
Whether to show the axis. Defaults to \"off\".
'off'
alpha
float
The alpha value for the annotations. Defaults to 0.35.
0.35
output
str
The path to the output image. Defaults to None.
None
blend
bool
Whether to show the input image. Defaults to True.
True
Source code in samgeo/samgeo.py
def show_anns(\n self,\n figsize=(12, 10),\n axis=\"off\",\n alpha=0.35,\n output=None,\n blend=True,\n **kwargs,\n):\n\"\"\"Show the annotations (objects with random color) on the input image.\n\n Args:\n figsize (tuple, optional): The figure size. Defaults to (12, 10).\n axis (str, optional): Whether to show the axis. Defaults to \"off\".\n alpha (float, optional): The alpha value for the annotations. Defaults to 0.35.\n output (str, optional): The path to the output image. Defaults to None.\n blend (bool, optional): Whether to show the input image. Defaults to True.\n \"\"\"\n\n import matplotlib.pyplot as plt\n\n anns = self.masks\n\n if self.image is None:\n print(\"Please run generate() first.\")\n return\n\n if anns is None or len(anns) == 0:\n return\n\n plt.figure(figsize=figsize)\n plt.imshow(self.image)\n\n sorted_anns = sorted(anns, key=(lambda x: x[\"area\"]), reverse=True)\n\n ax = plt.gca()\n ax.set_autoscale_on(False)\n\n img = np.ones(\n (\n sorted_anns[0][\"segmentation\"].shape[0],\n sorted_anns[0][\"segmentation\"].shape[1],\n 4,\n )\n )\n img[:, :, 3] = 0\n for ann in sorted_anns:\n m = ann[\"segmentation\"]\n color_mask = np.concatenate([np.random.random(3), [alpha]])\n img[m] = color_mask\n ax.imshow(img)\n\n if \"dpi\" not in kwargs:\n kwargs[\"dpi\"] = 100\n\n if \"bbox_inches\" not in kwargs:\n kwargs[\"bbox_inches\"] = \"tight\"\n\n plt.axis(axis)\n\n self.annotations = (img[:, :, 0:3] * 255).astype(np.uint8)\n\n if output is not None:\n if blend:\n array = blend_images(\n self.annotations, self.image, alpha=alpha, show=False\n )\n else:\n array = self.annotations\n array_to_image(array, output, self.source)\n
"},{"location":"samgeo/#samgeo.samgeo.SamGeo.show_canvas","title":"show_canvas(self, fg_color=(0, 255, 0), bg_color=(0, 0, 255), radius=5)
","text":"Show a canvas to collect foreground and background points.
Parameters:
Name Type Description Default image
str | np.ndarray
The input image.
required fg_color
tuple
The color for the foreground points. Defaults to (0, 255, 0).
(0, 255, 0)
bg_color
tuple
The color for the background points. Defaults to (0, 0, 255).
(0, 0, 255)
radius
int
The radius of the points. Defaults to 5.
5
Returns:
Type Description tuple
A tuple of two lists of foreground and background points.
Source code in samgeo/samgeo.py
def show_canvas(self, fg_color=(0, 255, 0), bg_color=(0, 0, 255), radius=5):\n\"\"\"Show a canvas to collect foreground and background points.\n\n Args:\n image (str | np.ndarray): The input image.\n fg_color (tuple, optional): The color for the foreground points. Defaults to (0, 255, 0).\n bg_color (tuple, optional): The color for the background points. Defaults to (0, 0, 255).\n radius (int, optional): The radius of the points. Defaults to 5.\n\n Returns:\n tuple: A tuple of two lists of foreground and background points.\n \"\"\"\n\n if self.image is None:\n raise ValueError(\"Please run set_image() first.\")\n\n image = self.image\n fg_points, bg_points = show_canvas(image, fg_color, bg_color, radius)\n self.fg_points = fg_points\n self.bg_points = bg_points\n point_coords = fg_points + bg_points\n point_labels = [1] * len(fg_points) + [0] * len(bg_points)\n self.point_coords = point_coords\n self.point_labels = point_labels\n
"},{"location":"samgeo/#samgeo.samgeo.SamGeo.show_map","title":"show_map(self, basemap='SATELLITE', repeat_mode=True, out_dir=None, **kwargs)
","text":"Show the interactive map.
Parameters:
Name Type Description Default basemap
str
The basemap. It can be one of the following: SATELLITE, ROADMAP, TERRAIN, HYBRID.
'SATELLITE'
repeat_mode
bool
Whether to use the repeat mode for draw control. Defaults to True.
True
out_dir
str
The path to the output directory. Defaults to None.
None
Returns:
Type Description leafmap.Map
The map object.
Source code in samgeo/samgeo.py
def show_map(self, basemap=\"SATELLITE\", repeat_mode=True, out_dir=None, **kwargs):\n\"\"\"Show the interactive map.\n\n Args:\n basemap (str, optional): The basemap. It can be one of the following: SATELLITE, ROADMAP, TERRAIN, HYBRID.\n repeat_mode (bool, optional): Whether to use the repeat mode for draw control. Defaults to True.\n out_dir (str, optional): The path to the output directory. Defaults to None.\n\n Returns:\n leafmap.Map: The map object.\n \"\"\"\n return sam_map_gui(\n self, basemap=basemap, repeat_mode=repeat_mode, out_dir=out_dir, **kwargs\n )\n
"},{"location":"samgeo/#samgeo.samgeo.SamGeo.show_masks","title":"show_masks(self, figsize=(12, 10), cmap='binary_r', axis='off', foreground=True, **kwargs)
","text":"Show the binary mask or the mask of objects with unique values.
Parameters:
Name Type Description Default figsize
tuple
The figure size. Defaults to (12, 10).
(12, 10)
cmap
str
The colormap. Defaults to \"binary_r\".
'binary_r'
axis
str
Whether to show the axis. Defaults to \"off\".
'off'
foreground
bool
Whether to show the foreground mask only. Defaults to True.
True
**kwargs
Other arguments for save_masks().
{}
Source code in samgeo/samgeo.py
def show_masks(\n self, figsize=(12, 10), cmap=\"binary_r\", axis=\"off\", foreground=True, **kwargs\n):\n\"\"\"Show the binary mask or the mask of objects with unique values.\n\n Args:\n figsize (tuple, optional): The figure size. Defaults to (12, 10).\n cmap (str, optional): The colormap. Defaults to \"binary_r\".\n axis (str, optional): Whether to show the axis. Defaults to \"off\".\n foreground (bool, optional): Whether to show the foreground mask only. Defaults to True.\n **kwargs: Other arguments for save_masks().\n \"\"\"\n\n import matplotlib.pyplot as plt\n\n if self.batch:\n self.objects = cv2.imread(self.masks)\n else:\n if self.objects is None:\n self.save_masks(foreground=foreground, **kwargs)\n\n plt.figure(figsize=figsize)\n plt.imshow(self.objects, cmap=cmap)\n plt.axis(axis)\n plt.show()\n
"},{"location":"samgeo/#samgeo.samgeo.SamGeo.tensor_to_numpy","title":"tensor_to_numpy(self, index=None, output=None, mask_multiplier=255, dtype='uint8', save_args={})
","text":"Convert the predicted masks from tensors to numpy arrays.
Parameters:
Name Type Description Default index
int
The index of the mask to save. Defaults to None, which will save the mask with the highest score.
None
output
str
The path to the output image. Defaults to None.
None
mask_multiplier
int
The mask multiplier for the output mask, which is usually a binary mask [0, 1].
255
dtype
np.dtype
The data type of the output image. Defaults to np.uint8.
'uint8'
save_args
dict
Optional arguments for saving the output image. Defaults to {}.
{}
Returns:
Type Description np.ndarray
The predicted mask as a numpy array.
Source code in samgeo/samgeo.py
def tensor_to_numpy(\n self, index=None, output=None, mask_multiplier=255, dtype=\"uint8\", save_args={}\n):\n\"\"\"Convert the predicted masks from tensors to numpy arrays.\n\n Args:\n index (index, optional): The index of the mask to save. Defaults to None,\n which will save the mask with the highest score.\n output (str, optional): The path to the output image. Defaults to None.\n mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1].\n dtype (np.dtype, optional): The data type of the output image. Defaults to np.uint8.\n save_args (dict, optional): Optional arguments for saving the output image. Defaults to {}.\n\n Returns:\n np.ndarray: The predicted mask as a numpy array.\n \"\"\"\n\n boxes = self.boxes\n masks = self.masks\n\n image_pil = self.image\n image_np = np.array(image_pil)\n\n if index is None:\n index = 1\n\n masks = masks[:, index, :, :]\n masks = masks.squeeze(1)\n\n if boxes is None or (len(boxes) == 0): # No \"object\" instances found\n print(\"No objects found in the image.\")\n return\n else:\n # Create an empty image to store the mask overlays\n mask_overlay = np.zeros_like(\n image_np[..., 0], dtype=dtype\n ) # Adjusted for single channel\n\n for i, (box, mask) in enumerate(zip(boxes, masks)):\n # Convert tensor to numpy array if necessary and ensure it contains integers\n if isinstance(mask, torch.Tensor):\n mask = (\n mask.cpu().numpy().astype(dtype)\n ) # If mask is on GPU, use .cpu() before .numpy()\n mask_overlay += ((mask > 0) * (i + 1)).astype(\n dtype\n ) # Assign a unique value for each mask\n\n # Normalize mask_overlay to be in [0, 255]\n mask_overlay = (\n mask_overlay > 0\n ) * mask_multiplier # Binary mask in [0, 255]\n\n if output is not None:\n array_to_image(mask_overlay, output, self.source, dtype=dtype, **save_args)\n else:\n return mask_overlay\n
"},{"location":"samgeo/#samgeo.samgeo.SamGeo.tiff_to_geojson","title":"tiff_to_geojson(self, tiff_path, output, simplify_tolerance=None, **kwargs)
","text":"Convert a tiff file to a GeoJSON file.
Parameters:
Name Type Description Defaulttiff_path
str
The path to the tiff file.
requiredoutput
str
The path to the GeoJSON file.
requiredsimplify_tolerance
float
The maximum allowed geometry displacement. The higher this value, the smaller the number of vertices in the resulting geometry.
None
Source code in samgeo/samgeo.py
def tiff_to_geojson(self, tiff_path, output, simplify_tolerance=None, **kwargs):\n\"\"\"Convert a tiff file to a GeoJSON file.\n\n Args:\n tiff_path (str): The path to the tiff file.\n output (str): The path to the GeoJSON file.\n simplify_tolerance (float, optional): The maximum allowed geometry displacement.\n The higher this value, the smaller the number of vertices in the resulting geometry.\n \"\"\"\n\n raster_to_geojson(\n tiff_path, output, simplify_tolerance=simplify_tolerance, **kwargs\n )\n
"},{"location":"samgeo/#samgeo.samgeo.SamGeo.tiff_to_gpkg","title":"tiff_to_gpkg(self, tiff_path, output, simplify_tolerance=None, **kwargs)
","text":"Convert a tiff file to a gpkg file.
Parameters:
Name Type Description Defaulttiff_path
str
The path to the tiff file.
requiredoutput
str
The path to the gpkg file.
requiredsimplify_tolerance
float
The maximum allowed geometry displacement. The higher this value, the smaller the number of vertices in the resulting geometry.
None
Source code in samgeo/samgeo.py
def tiff_to_gpkg(self, tiff_path, output, simplify_tolerance=None, **kwargs):\n\"\"\"Convert a tiff file to a gpkg file.\n\n Args:\n tiff_path (str): The path to the tiff file.\n output (str): The path to the gpkg file.\n simplify_tolerance (float, optional): The maximum allowed geometry displacement.\n The higher this value, the smaller the number of vertices in the resulting geometry.\n \"\"\"\n\n raster_to_gpkg(\n tiff_path, output, simplify_tolerance=simplify_tolerance, **kwargs\n )\n
"},{"location":"samgeo/#samgeo.samgeo.SamGeo.tiff_to_shp","title":"tiff_to_shp(self, tiff_path, output, simplify_tolerance=None, **kwargs)
","text":"Convert a tiff file to a shapefile.
Parameters:
Name Type Description Defaulttiff_path
str
The path to the tiff file.
requiredoutput
str
The path to the shapefile.
requiredsimplify_tolerance
float
The maximum allowed geometry displacement. The higher this value, the smaller the number of vertices in the resulting geometry.
None
Source code in samgeo/samgeo.py
def tiff_to_shp(self, tiff_path, output, simplify_tolerance=None, **kwargs):\n\"\"\"Convert a tiff file to a shapefile.\n\n Args:\n tiff_path (str): The path to the tiff file.\n output (str): The path to the shapefile.\n simplify_tolerance (float, optional): The maximum allowed geometry displacement.\n The higher this value, the smaller the number of vertices in the resulting geometry.\n \"\"\"\n\n raster_to_shp(\n tiff_path, output, simplify_tolerance=simplify_tolerance, **kwargs\n )\n
"},{"location":"samgeo/#samgeo.samgeo.SamGeo.tiff_to_vector","title":"tiff_to_vector(self, tiff_path, output, simplify_tolerance=None, **kwargs)
","text":"Convert a tiff file to a gpkg file.
Parameters:
Name Type Description Defaulttiff_path
str
The path to the tiff file.
requiredoutput
str
The path to the vector file.
requiredsimplify_tolerance
float
The maximum allowed geometry displacement. The higher this value, the smaller the number of vertices in the resulting geometry.
None
Source code in samgeo/samgeo.py
def tiff_to_vector(self, tiff_path, output, simplify_tolerance=None, **kwargs):\n\"\"\"Convert a tiff file to a gpkg file.\n\n Args:\n tiff_path (str): The path to the tiff file.\n output (str): The path to the vector file.\n simplify_tolerance (float, optional): The maximum allowed geometry displacement.\n The higher this value, the smaller the number of vertices in the resulting geometry.\n \"\"\"\n\n raster_to_vector(\n tiff_path, output, simplify_tolerance=simplify_tolerance, **kwargs\n )\n
"},{"location":"samgeo/#samgeo.samgeo.SamGeoPredictor","title":" SamGeoPredictor (SamPredictor)
","text":"Source code in samgeo/samgeo.py
class SamGeoPredictor(SamPredictor):\n def __init__(\n self,\n sam_model,\n ):\n from segment_anything.utils.transforms import ResizeLongestSide\n\n self.model = sam_model\n self.transform = ResizeLongestSide(sam_model.image_encoder.img_size)\n\n def set_image(self, image):\n super(SamGeoPredictor, self).set_image(image)\n\n def predict(\n self,\n src_fp=None,\n geo_box=None,\n point_coords=None,\n point_labels=None,\n box=None,\n mask_input=None,\n multimask_output=True,\n return_logits=False,\n ):\n if geo_box and src_fp:\n self.crs = \"EPSG:4326\"\n dst_crs = get_crs(src_fp)\n sw = transform_coords(geo_box[0], geo_box[1], self.crs, dst_crs)\n ne = transform_coords(geo_box[2], geo_box[3], self.crs, dst_crs)\n xs = np.array([sw[0], ne[0]])\n ys = np.array([sw[1], ne[1]])\n box = get_pixel_coords(src_fp, xs, ys)\n self.geo_box = geo_box\n self.width = box[2] - box[0]\n self.height = box[3] - box[1]\n self.geo_transform = set_transform(geo_box, self.width, self.height)\n\n masks, iou_predictions, low_res_masks = super(SamGeoPredictor, self).predict(\n point_coords, point_labels, box, mask_input, multimask_output, return_logits\n )\n\n return masks, iou_predictions, low_res_masks\n\n def masks_to_geotiff(self, src_fp, dst_fp, masks):\n profile = get_profile(src_fp)\n write_raster(\n dst_fp,\n masks,\n profile,\n self.width,\n self.height,\n self.geo_transform,\n self.crs,\n )\n\n def geotiff_to_geojson(self, src_fp, dst_fp, bidx=1):\n gdf = get_features(src_fp, bidx)\n write_features(gdf, dst_fp)\n return gdf\n
"},{"location":"samgeo/#samgeo.samgeo.SamGeoPredictor.predict","title":"predict(self, src_fp=None, geo_box=None, point_coords=None, point_labels=None, box=None, mask_input=None, multimask_output=True, return_logits=False)
","text":"Predict masks for the given input prompts, using the currently set image.
Parameters:
Name Type Description Defaultpoint_coords
np.ndarray or None
A Nx2 array of point prompts to the model. Each point is in (X,Y) in pixels.
None
point_labels
np.ndarray or None
A length N array of labels for the point prompts. 1 indicates a foreground point and 0 indicates a background point.
None
box
np.ndarray or None
A length 4 array giving a box prompt to the model, in XYXY format.
None
mask_input
np.ndarray
A low resolution mask input to the model, typically coming from a previous prediction iteration. Has form 1xHxW, where for SAM, H=W=256.
None
multimask_output
bool
If true, the model will return three masks. For ambiguous input prompts (such as a single click), this will often produce better masks than a single prediction. If only a single mask is needed, the model's predicted quality score can be used to select the best mask. For non-ambiguous prompts, such as multiple input prompts, multimask_output=False can give better results.
True
return_logits
bool
If true, returns un-thresholded mask logits instead of a binary mask.
False
Returns:
Type Description(np.ndarray)
The output masks in CxHxW format, where C is the number of masks, and (H, W) is the original image size. (np.ndarray): An array of length C containing the model's predictions for the quality of each mask. (np.ndarray): An array of shape CxHxW, where C is the number of masks and H=W=256. These low resolution logits can be passed to a subsequent iteration as mask input.
Source code in samgeo/samgeo.py
def predict(\n self,\n src_fp=None,\n geo_box=None,\n point_coords=None,\n point_labels=None,\n box=None,\n mask_input=None,\n multimask_output=True,\n return_logits=False,\n):\n if geo_box and src_fp:\n self.crs = \"EPSG:4326\"\n dst_crs = get_crs(src_fp)\n sw = transform_coords(geo_box[0], geo_box[1], self.crs, dst_crs)\n ne = transform_coords(geo_box[2], geo_box[3], self.crs, dst_crs)\n xs = np.array([sw[0], ne[0]])\n ys = np.array([sw[1], ne[1]])\n box = get_pixel_coords(src_fp, xs, ys)\n self.geo_box = geo_box\n self.width = box[2] - box[0]\n self.height = box[3] - box[1]\n self.geo_transform = set_transform(geo_box, self.width, self.height)\n\n masks, iou_predictions, low_res_masks = super(SamGeoPredictor, self).predict(\n point_coords, point_labels, box, mask_input, multimask_output, return_logits\n )\n\n return masks, iou_predictions, low_res_masks\n
"},{"location":"samgeo/#samgeo.samgeo.SamGeoPredictor.set_image","title":"set_image(self, image)
","text":"Calculates the image embeddings for the provided image, allowing masks to be predicted with the 'predict' method.
Parameters:
Name Type Description Defaultimage
np.ndarray
The image for calculating masks. Expects an image in HWC uint8 format, with pixel values in [0, 255].
requiredimage_format
str
The color format of the image, in ['RGB', 'BGR'].
required Source code in samgeo/samgeo.py
def set_image(self, image):\n super(SamGeoPredictor, self).set_image(image)\n
"},{"location":"text_sam/","title":"text_sam module","text":"The LangSAM model for segmenting objects from satellite images using text prompts. The source code is adapted from the https://github.com/luca-medeiros/lang-segment-anything repository. Credits to Luca Medeiros for the original implementation.
"},{"location":"text_sam/#samgeo.text_sam.LangSAM","title":" LangSAM
","text":"A Language-based Segment-Anything Model (LangSAM) class which combines GroundingDINO and SAM.
Source code in samgeo/text_sam.py
class LangSAM:\n\"\"\"\n A Language-based Segment-Anything Model (LangSAM) class which combines GroundingDINO and SAM.\n \"\"\"\n\n def __init__(self, model_type=\"vit_h\"):\n\"\"\"Initialize the LangSAM instance.\n\n Args:\n model_type (str, optional): The model type. It can be one of the following: vit_h, vit_l, vit_b.\n Defaults to 'vit_h'. See https://bit.ly/3VrpxUh for more details.\n \"\"\"\n\n self.device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n self.build_groundingdino()\n self.build_sam(model_type)\n\n self.source = None\n self.image = None\n self.masks = None\n self.boxes = None\n self.phrases = None\n self.logits = None\n self.prediction = None\n\n def build_sam(self, model_type):\n\"\"\"Build the SAM model.\n\n Args:\n model_type (str, optional): The model type. It can be one of the following: vit_h, vit_l, vit_b.\n Defaults to 'vit_h'. See https://bit.ly/3VrpxUh for more details.\n \"\"\"\n checkpoint_url = SAM_MODELS[model_type]\n sam = sam_model_registry[model_type]()\n state_dict = torch.hub.load_state_dict_from_url(checkpoint_url)\n sam.load_state_dict(state_dict, strict=True)\n sam.to(device=self.device)\n self.sam = SamPredictor(sam)\n\n def build_groundingdino(self):\n\"\"\"Build the GroundingDINO model.\"\"\"\n ckpt_repo_id = \"ShilongLiu/GroundingDINO\"\n ckpt_filename = \"groundingdino_swinb_cogcoor.pth\"\n ckpt_config_filename = \"GroundingDINO_SwinB.cfg.py\"\n self.groundingdino = load_model_hf(\n ckpt_repo_id, ckpt_filename, ckpt_config_filename, self.device\n )\n\n def predict_dino(self, image, text_prompt, box_threshold, text_threshold):\n\"\"\"\n Run the GroundingDINO model prediction.\n\n Args:\n image (Image): Input PIL Image.\n text_prompt (str): Text prompt for the model.\n box_threshold (float): Box threshold for the prediction.\n text_threshold (float): Text threshold for the prediction.\n\n Returns:\n tuple: Tuple containing boxes, logits, and phrases.\n \"\"\"\n\n image_trans = 
transform_image(image)\n boxes, logits, phrases = predict(\n model=self.groundingdino,\n image=image_trans,\n caption=text_prompt,\n box_threshold=box_threshold,\n text_threshold=text_threshold,\n device=self.device,\n )\n W, H = image.size\n boxes = box_ops.box_cxcywh_to_xyxy(boxes) * torch.Tensor([W, H, W, H])\n\n return boxes, logits, phrases\n\n def predict_sam(self, image, boxes):\n\"\"\"\n Run the SAM model prediction.\n\n Args:\n image (Image): Input PIL Image.\n boxes (torch.Tensor): Tensor of bounding boxes.\n\n Returns:\n Masks tensor.\n \"\"\"\n image_array = np.asarray(image)\n self.sam.set_image(image_array)\n transformed_boxes = self.sam.transform.apply_boxes_torch(\n boxes, image_array.shape[:2]\n )\n masks, _, _ = self.sam.predict_torch(\n point_coords=None,\n point_labels=None,\n boxes=transformed_boxes.to(self.sam.device),\n multimask_output=False,\n )\n return masks.cpu()\n\n def set_image(self, image):\n\"\"\"Set the input image.\n\n Args:\n image (str): The path to the image file or a HTTP URL.\n \"\"\"\n\n if isinstance(image, str):\n if image.startswith(\"http\"):\n image = download_file(image)\n\n if not os.path.exists(image):\n raise ValueError(f\"Input path {image} does not exist.\")\n\n self.source = image\n else:\n self.source = None\n\n def predict(\n self,\n image,\n text_prompt,\n box_threshold,\n text_threshold,\n output=None,\n mask_multiplier=255,\n dtype=np.uint8,\n save_args={},\n return_results=False,\n return_coords=False,\n **kwargs,\n ):\n\"\"\"\n Run both GroundingDINO and SAM model prediction.\n\n Parameters:\n image (Image): Input PIL Image.\n text_prompt (str): Text prompt for the model.\n box_threshold (float): Box threshold for the prediction.\n text_threshold (float): Text threshold for the prediction.\n output (str, optional): Output path for the prediction. Defaults to None.\n mask_multiplier (int, optional): Mask multiplier for the prediction. 
Defaults to 255.\n dtype (np.dtype, optional): Data type for the prediction. Defaults to np.uint8.\n save_args (dict, optional): Save arguments for the prediction. Defaults to {}.\n return_results (bool, optional): Whether to return the results. Defaults to False.\n\n Returns:\n tuple: Tuple containing masks, boxes, phrases, and logits.\n \"\"\"\n\n if isinstance(image, str):\n if image.startswith(\"http\"):\n image = download_file(image)\n\n if not os.path.exists(image):\n raise ValueError(f\"Input path {image} does not exist.\")\n\n self.source = image\n\n # Load the georeferenced image\n with rasterio.open(image) as src:\n image_np = src.read().transpose(\n (1, 2, 0)\n ) # Convert rasterio image to numpy array\n self.transform = src.transform # Save georeferencing information\n self.crs = src.crs # Save the Coordinate Reference System\n image_pil = Image.fromarray(\n image_np[:, :, :3]\n ) # Convert numpy array to PIL image, excluding the alpha channel\n else:\n image_pil = image\n image_np = np.array(image_pil)\n\n self.image = image_pil\n\n boxes, logits, phrases = self.predict_dino(\n image_pil, text_prompt, box_threshold, text_threshold\n )\n masks = torch.tensor([])\n if len(boxes) > 0:\n masks = self.predict_sam(image_pil, boxes)\n masks = masks.squeeze(1)\n\n if boxes.nelement() == 0: # No \"object\" instances found\n print(\"No objects found in the image.\")\n return\n else:\n # Create an empty image to store the mask overlays\n mask_overlay = np.zeros_like(\n image_np[..., 0], dtype=dtype\n ) # Adjusted for single channel\n\n for i, (box, mask) in enumerate(zip(boxes, masks)):\n # Convert tensor to numpy array if necessary and ensure it contains integers\n if isinstance(mask, torch.Tensor):\n mask = (\n mask.cpu().numpy().astype(dtype)\n ) # If mask is on GPU, use .cpu() before .numpy()\n mask_overlay += ((mask > 0) * (i + 1)).astype(\n dtype\n ) # Assign a unique value for each mask\n\n # Normalize mask_overlay to be in [0, 255]\n mask_overlay = (\n 
mask_overlay > 0\n ) * mask_multiplier # Binary mask in [0, 255]\n\n if output is not None:\n array_to_image(mask_overlay, output, self.source, dtype=dtype, **save_args)\n\n self.masks = masks\n self.boxes = boxes\n self.phrases = phrases\n self.logits = logits\n self.prediction = mask_overlay\n\n if return_results:\n return masks, boxes, phrases, logits\n\n if return_coords:\n boxlist = []\n for box in self.boxes:\n box = box.cpu().numpy()\n boxlist.append((box[0], box[1]))\n return boxlist\n\n def predict_batch(\n self,\n images,\n out_dir,\n text_prompt,\n box_threshold,\n text_threshold,\n mask_multiplier=255,\n dtype=np.uint8,\n save_args={},\n merge=True,\n verbose=True,\n **kwargs,\n ):\n\"\"\"\n Run both GroundingDINO and SAM model prediction for a batch of images.\n\n Parameters:\n images (list): List of input PIL Images.\n out_dir (str): Output directory for the prediction.\n text_prompt (str): Text prompt for the model.\n box_threshold (float): Box threshold for the prediction.\n text_threshold (float): Text threshold for the prediction.\n mask_multiplier (int, optional): Mask multiplier for the prediction. Defaults to 255.\n dtype (np.dtype, optional): Data type for the prediction. Defaults to np.uint8.\n save_args (dict, optional): Save arguments for the prediction. Defaults to {}.\n merge (bool, optional): Whether to merge the predictions into a single GeoTIFF file. 
Defaults to True.\n \"\"\"\n\n import glob\n\n if not os.path.exists(out_dir):\n os.makedirs(out_dir)\n\n if isinstance(images, str):\n images = list(glob.glob(os.path.join(images, \"*.tif\")))\n images.sort()\n\n if not isinstance(images, list):\n raise ValueError(\"images must be a list or a directory to GeoTIFF files.\")\n\n for i, image in enumerate(images):\n basename = os.path.splitext(os.path.basename(image))[0]\n if verbose:\n print(\n f\"Processing image {str(i+1).zfill(len(str(len(images))))} of {len(images)}: {image}...\"\n )\n output = os.path.join(out_dir, f\"{basename}_mask.tif\")\n self.predict(\n image,\n text_prompt,\n box_threshold,\n text_threshold,\n output=output,\n mask_multiplier=mask_multiplier,\n dtype=dtype,\n save_args=save_args,\n **kwargs,\n )\n\n if merge:\n output = os.path.join(out_dir, \"merged.tif\")\n merge_rasters(out_dir, output)\n if verbose:\n print(f\"Saved the merged prediction to {output}.\")\n\n def save_boxes(self, output=None, dst_crs=\"EPSG:4326\", **kwargs):\n\"\"\"Save the bounding boxes to a vector file.\n\n Args:\n output (str): The path to the output vector file.\n dst_crs (str, optional): The destination CRS. Defaults to \"EPSG:4326\".\n **kwargs: Additional arguments for boxes_to_vector().\n \"\"\"\n\n if self.boxes is None:\n print(\"Please run predict() first.\")\n return\n else:\n boxes = self.boxes.tolist()\n coords = rowcol_to_xy(self.source, boxes=boxes, dst_crs=dst_crs, **kwargs)\n if output is None:\n return boxes_to_vector(coords, self.crs, dst_crs, output)\n else:\n boxes_to_vector(coords, self.crs, dst_crs, output)\n\n def show_anns(\n self,\n figsize=(12, 10),\n axis=\"off\",\n cmap=\"viridis\",\n alpha=0.4,\n add_boxes=True,\n box_color=\"r\",\n box_linewidth=1,\n title=None,\n output=None,\n blend=True,\n **kwargs,\n ):\n\"\"\"Show the annotations (objects with random color) on the input image.\n\n Args:\n figsize (tuple, optional): The figure size. 
Defaults to (12, 10).\n axis (str, optional): Whether to show the axis. Defaults to \"off\".\n cmap (str, optional): The colormap for the annotations. Defaults to \"viridis\".\n alpha (float, optional): The alpha value for the annotations. Defaults to 0.4.\n add_boxes (bool, optional): Whether to show the bounding boxes. Defaults to True.\n box_color (str, optional): The color for the bounding boxes. Defaults to \"r\".\n box_linewidth (int, optional): The line width for the bounding boxes. Defaults to 1.\n title (str, optional): The title for the image. Defaults to None.\n output (str, optional): The path to the output image. Defaults to None.\n blend (bool, optional): Whether to show the input image. Defaults to True.\n kwargs (dict, optional): Additional arguments for matplotlib.pyplot.savefig().\n \"\"\"\n\n import warnings\n import matplotlib.pyplot as plt\n import matplotlib.patches as patches\n\n warnings.filterwarnings(\"ignore\")\n\n anns = self.prediction\n\n if anns is None:\n print(\"Please run predict() first.\")\n return\n elif len(anns) == 0:\n print(\"No objects found in the image.\")\n return\n\n plt.figure(figsize=figsize)\n plt.imshow(self.image)\n\n if add_boxes:\n for box in self.boxes:\n # Draw bounding box\n box = box.cpu().numpy() # Convert the tensor to a numpy array\n rect = patches.Rectangle(\n (box[0], box[1]),\n box[2] - box[0],\n box[3] - box[1],\n linewidth=box_linewidth,\n edgecolor=box_color,\n facecolor=\"none\",\n )\n plt.gca().add_patch(rect)\n\n if \"dpi\" not in kwargs:\n kwargs[\"dpi\"] = 100\n\n if \"bbox_inches\" not in kwargs:\n kwargs[\"bbox_inches\"] = \"tight\"\n\n plt.imshow(anns, cmap=cmap, alpha=alpha)\n\n if title is not None:\n plt.title(title)\n plt.axis(axis)\n\n if output is not None:\n if blend:\n plt.savefig(output, **kwargs)\n else:\n array_to_image(self.prediction, output, self.source)\n\n def raster_to_vector(self, image, output, simplify_tolerance=None, **kwargs):\n\"\"\"Save the result to a vector file.\n\n 
Args:\n image (str): The path to the image file.\n output (str): The path to the vector file.\n simplify_tolerance (float, optional): The maximum allowed geometry displacement.\n The higher this value, the smaller the number of vertices in the resulting geometry.\n \"\"\"\n\n raster_to_vector(image, output, simplify_tolerance=simplify_tolerance, **kwargs)\n\n def show_map(self, basemap=\"SATELLITE\", out_dir=None, **kwargs):\n\"\"\"Show the interactive map.\n\n Args:\n basemap (str, optional): The basemap. It can be one of the following: SATELLITE, ROADMAP, TERRAIN, HYBRID.\n out_dir (str, optional): The path to the output directory. Defaults to None.\n\n Returns:\n leafmap.Map: The map object.\n \"\"\"\n return text_sam_gui(self, basemap=basemap, out_dir=out_dir, **kwargs)\n
"},{"location":"text_sam/#samgeo.text_sam.LangSAM.__init__","title":"__init__(self, model_type='vit_h')
special
","text":"Initialize the LangSAM instance.
Parameters:
Name Type Description Defaultmodel_type
str
The model type. It can be one of the following: vit_h, vit_l, vit_b. Defaults to 'vit_h'. See https://bit.ly/3VrpxUh for more details.
'vit_h'
Source code in samgeo/text_sam.py
def __init__(self, model_type=\"vit_h\"):\n\"\"\"Initialize the LangSAM instance.\n\n Args:\n model_type (str, optional): The model type. It can be one of the following: vit_h, vit_l, vit_b.\n Defaults to 'vit_h'. See https://bit.ly/3VrpxUh for more details.\n \"\"\"\n\n self.device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n self.build_groundingdino()\n self.build_sam(model_type)\n\n self.source = None\n self.image = None\n self.masks = None\n self.boxes = None\n self.phrases = None\n self.logits = None\n self.prediction = None\n
"},{"location":"text_sam/#samgeo.text_sam.LangSAM.build_groundingdino","title":"build_groundingdino(self)
","text":"Build the GroundingDINO model.
Source code in samgeo/text_sam.py
def build_groundingdino(self):\n\"\"\"Build the GroundingDINO model.\"\"\"\n ckpt_repo_id = \"ShilongLiu/GroundingDINO\"\n ckpt_filename = \"groundingdino_swinb_cogcoor.pth\"\n ckpt_config_filename = \"GroundingDINO_SwinB.cfg.py\"\n self.groundingdino = load_model_hf(\n ckpt_repo_id, ckpt_filename, ckpt_config_filename, self.device\n )\n
"},{"location":"text_sam/#samgeo.text_sam.LangSAM.build_sam","title":"build_sam(self, model_type)
","text":"Build the SAM model.
Parameters:
Name Type Description Defaultmodel_type
str
The model type. It can be one of the following: vit_h, vit_l, vit_b. Defaults to 'vit_h'. See https://bit.ly/3VrpxUh for more details.
required Source code in samgeo/text_sam.py
def build_sam(self, model_type):\n\"\"\"Build the SAM model.\n\n Args:\n model_type (str, optional): The model type. It can be one of the following: vit_h, vit_l, vit_b.\n Defaults to 'vit_h'. See https://bit.ly/3VrpxUh for more details.\n \"\"\"\n checkpoint_url = SAM_MODELS[model_type]\n sam = sam_model_registry[model_type]()\n state_dict = torch.hub.load_state_dict_from_url(checkpoint_url)\n sam.load_state_dict(state_dict, strict=True)\n sam.to(device=self.device)\n self.sam = SamPredictor(sam)\n
"},{"location":"text_sam/#samgeo.text_sam.LangSAM.predict","title":"predict(self, image, text_prompt, box_threshold, text_threshold, output=None, mask_multiplier=255, dtype=<class 'numpy.uint8'>, save_args={}, return_results=False, return_coords=False, **kwargs)
","text":"Run both GroundingDINO and SAM model prediction.
Parameters:
Name Type Description Defaultimage
Image
Input PIL Image, or the path or HTTP URL of a raster image file to open with rasterio.
requiredtext_prompt
str
Text prompt for the model.
requiredbox_threshold
float
Box threshold for the prediction.
requiredtext_threshold
float
Text threshold for the prediction.
requiredoutput
str
Output path for the prediction. Defaults to None.
None
mask_multiplier
int
Mask multiplier for the prediction. Defaults to 255.
255
dtype
np.dtype
Data type for the prediction. Defaults to np.uint8.
<class 'numpy.uint8'>
save_args
dict
Save arguments for the prediction. Defaults to {}.
{}
return_results
bool
Whether to return the results. Defaults to False.
False
Returns:
Type Descriptiontuple
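Tuple containing masks, boxes, phrases, and logits when return_results is True; otherwise, when return_coords is True, a list with the first two coordinates of each bounding box.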
Source code in samgeo/text_sam.py
def predict(\n self,\n image,\n text_prompt,\n box_threshold,\n text_threshold,\n output=None,\n mask_multiplier=255,\n dtype=np.uint8,\n save_args={},\n return_results=False,\n return_coords=False,\n **kwargs,\n):\n\"\"\"\n Run both GroundingDINO and SAM model prediction.\n\n Parameters:\n image (Image): Input PIL Image.\n text_prompt (str): Text prompt for the model.\n box_threshold (float): Box threshold for the prediction.\n text_threshold (float): Text threshold for the prediction.\n output (str, optional): Output path for the prediction. Defaults to None.\n mask_multiplier (int, optional): Mask multiplier for the prediction. Defaults to 255.\n dtype (np.dtype, optional): Data type for the prediction. Defaults to np.uint8.\n save_args (dict, optional): Save arguments for the prediction. Defaults to {}.\n return_results (bool, optional): Whether to return the results. Defaults to False.\n\n Returns:\n tuple: Tuple containing masks, boxes, phrases, and logits.\n \"\"\"\n\n if isinstance(image, str):\n if image.startswith(\"http\"):\n image = download_file(image)\n\n if not os.path.exists(image):\n raise ValueError(f\"Input path {image} does not exist.\")\n\n self.source = image\n\n # Load the georeferenced image\n with rasterio.open(image) as src:\n image_np = src.read().transpose(\n (1, 2, 0)\n ) # Convert rasterio image to numpy array\n self.transform = src.transform # Save georeferencing information\n self.crs = src.crs # Save the Coordinate Reference System\n image_pil = Image.fromarray(\n image_np[:, :, :3]\n ) # Convert numpy array to PIL image, excluding the alpha channel\n else:\n image_pil = image\n image_np = np.array(image_pil)\n\n self.image = image_pil\n\n boxes, logits, phrases = self.predict_dino(\n image_pil, text_prompt, box_threshold, text_threshold\n )\n masks = torch.tensor([])\n if len(boxes) > 0:\n masks = self.predict_sam(image_pil, boxes)\n masks = masks.squeeze(1)\n\n if boxes.nelement() == 0: # No \"object\" instances found\n print(\"No 
objects found in the image.\")\n return\n else:\n # Create an empty image to store the mask overlays\n mask_overlay = np.zeros_like(\n image_np[..., 0], dtype=dtype\n ) # Adjusted for single channel\n\n for i, (box, mask) in enumerate(zip(boxes, masks)):\n # Convert tensor to numpy array if necessary and ensure it contains integers\n if isinstance(mask, torch.Tensor):\n mask = (\n mask.cpu().numpy().astype(dtype)\n ) # If mask is on GPU, use .cpu() before .numpy()\n mask_overlay += ((mask > 0) * (i + 1)).astype(\n dtype\n ) # Assign a unique value for each mask\n\n # Normalize mask_overlay to be in [0, 255]\n mask_overlay = (\n mask_overlay > 0\n ) * mask_multiplier # Binary mask in [0, 255]\n\n if output is not None:\n array_to_image(mask_overlay, output, self.source, dtype=dtype, **save_args)\n\n self.masks = masks\n self.boxes = boxes\n self.phrases = phrases\n self.logits = logits\n self.prediction = mask_overlay\n\n if return_results:\n return masks, boxes, phrases, logits\n\n if return_coords:\n boxlist = []\n for box in self.boxes:\n box = box.cpu().numpy()\n boxlist.append((box[0], box[1]))\n return boxlist\n
"},{"location":"text_sam/#samgeo.text_sam.LangSAM.predict_batch","title":"predict_batch(self, images, out_dir, text_prompt, box_threshold, text_threshold, mask_multiplier=255, dtype=<class 'numpy.uint8'>, save_args={}, merge=True, verbose=True, **kwargs)
","text":"Run both GroundingDINO and SAM model prediction for a batch of images.
Parameters:
Name Type Description Defaultimages
list
List of paths to the input images, or the path to a directory containing GeoTIFF (*.tif) files.
requiredout_dir
str
Output directory for the prediction.
requiredtext_prompt
str
Text prompt for the model.
requiredbox_threshold
float
Box threshold for the prediction.
requiredtext_threshold
float
Text threshold for the prediction.
requiredmask_multiplier
int
Mask multiplier for the prediction. Defaults to 255.
255
dtype
np.dtype
Data type for the prediction. Defaults to np.uint8.
<class 'numpy.uint8'>
save_args
dict
Save arguments for the prediction. Defaults to {}.
{}
merge
bool
Whether to merge the predictions into a single GeoTIFF file. Defaults to True.
True
Source code in samgeo/text_sam.py
def predict_batch(\n self,\n images,\n out_dir,\n text_prompt,\n box_threshold,\n text_threshold,\n mask_multiplier=255,\n dtype=np.uint8,\n save_args={},\n merge=True,\n verbose=True,\n **kwargs,\n):\n\"\"\"\n Run both GroundingDINO and SAM model prediction for a batch of images.\n\n Parameters:\n images (list): List of input PIL Images.\n out_dir (str): Output directory for the prediction.\n text_prompt (str): Text prompt for the model.\n box_threshold (float): Box threshold for the prediction.\n text_threshold (float): Text threshold for the prediction.\n mask_multiplier (int, optional): Mask multiplier for the prediction. Defaults to 255.\n dtype (np.dtype, optional): Data type for the prediction. Defaults to np.uint8.\n save_args (dict, optional): Save arguments for the prediction. Defaults to {}.\n merge (bool, optional): Whether to merge the predictions into a single GeoTIFF file. Defaults to True.\n \"\"\"\n\n import glob\n\n if not os.path.exists(out_dir):\n os.makedirs(out_dir)\n\n if isinstance(images, str):\n images = list(glob.glob(os.path.join(images, \"*.tif\")))\n images.sort()\n\n if not isinstance(images, list):\n raise ValueError(\"images must be a list or a directory to GeoTIFF files.\")\n\n for i, image in enumerate(images):\n basename = os.path.splitext(os.path.basename(image))[0]\n if verbose:\n print(\n f\"Processing image {str(i+1).zfill(len(str(len(images))))} of {len(images)}: {image}...\"\n )\n output = os.path.join(out_dir, f\"{basename}_mask.tif\")\n self.predict(\n image,\n text_prompt,\n box_threshold,\n text_threshold,\n output=output,\n mask_multiplier=mask_multiplier,\n dtype=dtype,\n save_args=save_args,\n **kwargs,\n )\n\n if merge:\n output = os.path.join(out_dir, \"merged.tif\")\n merge_rasters(out_dir, output)\n if verbose:\n print(f\"Saved the merged prediction to {output}.\")\n
"},{"location":"text_sam/#samgeo.text_sam.LangSAM.predict_dino","title":"predict_dino(self, image, text_prompt, box_threshold, text_threshold)
","text":"Run the GroundingDINO model prediction.
Parameters:
Name Type Description Defaultimage
Image
Input PIL Image.
requiredtext_prompt
str
Text prompt for the model.
requiredbox_threshold
float
Box threshold for the prediction.
requiredtext_threshold
float
Text threshold for the prediction.
requiredReturns:
Type Descriptiontuple
Tuple containing boxes, logits, and phrases.
Source code in samgeo/text_sam.py
def predict_dino(self, image, text_prompt, box_threshold, text_threshold):\n\"\"\"\n Run the GroundingDINO model prediction.\n\n Args:\n image (Image): Input PIL Image.\n text_prompt (str): Text prompt for the model.\n box_threshold (float): Box threshold for the prediction.\n text_threshold (float): Text threshold for the prediction.\n\n Returns:\n tuple: Tuple containing boxes, logits, and phrases.\n \"\"\"\n\n image_trans = transform_image(image)\n boxes, logits, phrases = predict(\n model=self.groundingdino,\n image=image_trans,\n caption=text_prompt,\n box_threshold=box_threshold,\n text_threshold=text_threshold,\n device=self.device,\n )\n W, H = image.size\n boxes = box_ops.box_cxcywh_to_xyxy(boxes) * torch.Tensor([W, H, W, H])\n\n return boxes, logits, phrases\n
"},{"location":"text_sam/#samgeo.text_sam.LangSAM.predict_sam","title":"predict_sam(self, image, boxes)
","text":"Run the SAM model prediction.
Parameters:
Name Type Description Defaultimage
Image
Input PIL Image.
requiredboxes
torch.Tensor
Tensor of bounding boxes.
requiredReturns:
Type DescriptionMasks tensor.
Source code in samgeo/text_sam.py
def predict_sam(self, image, boxes):\n\"\"\"\n Run the SAM model prediction.\n\n Args:\n image (Image): Input PIL Image.\n boxes (torch.Tensor): Tensor of bounding boxes.\n\n Returns:\n Masks tensor.\n \"\"\"\n image_array = np.asarray(image)\n self.sam.set_image(image_array)\n transformed_boxes = self.sam.transform.apply_boxes_torch(\n boxes, image_array.shape[:2]\n )\n masks, _, _ = self.sam.predict_torch(\n point_coords=None,\n point_labels=None,\n boxes=transformed_boxes.to(self.sam.device),\n multimask_output=False,\n )\n return masks.cpu()\n
"},{"location":"text_sam/#samgeo.text_sam.LangSAM.raster_to_vector","title":"raster_to_vector(self, image, output, simplify_tolerance=None, **kwargs)
","text":"Save the result to a vector file.
Parameters:
Name Type Description Defaultimage
str
The path to the image file.
requiredoutput
str
The path to the vector file.
requiredsimplify_tolerance
float
The maximum allowed geometry displacement. The higher this value, the smaller the number of vertices in the resulting geometry.
None
Source code in samgeo/text_sam.py
def raster_to_vector(self, image, output, simplify_tolerance=None, **kwargs):\n\"\"\"Save the result to a vector file.\n\n Args:\n image (str): The path to the image file.\n output (str): The path to the vector file.\n simplify_tolerance (float, optional): The maximum allowed geometry displacement.\n The higher this value, the smaller the number of vertices in the resulting geometry.\n \"\"\"\n\n raster_to_vector(image, output, simplify_tolerance=simplify_tolerance, **kwargs)\n
"},{"location":"text_sam/#samgeo.text_sam.LangSAM.save_boxes","title":"save_boxes(self, output=None, dst_crs='EPSG:4326', **kwargs)
","text":"Save the bounding boxes to a vector file.
Parameters:
Name Type Description Defaultoutput
str
The path to the output vector file.
None
dst_crs
str
The destination CRS. Defaults to \"EPSG:4326\".
'EPSG:4326'
**kwargs
Additional arguments for boxes_to_vector().
{}
Source code in samgeo/text_sam.py
def save_boxes(self, output=None, dst_crs=\"EPSG:4326\", **kwargs):\n\"\"\"Save the bounding boxes to a vector file.\n\n Args:\n output (str): The path to the output vector file.\n dst_crs (str, optional): The destination CRS. Defaults to \"EPSG:4326\".\n **kwargs: Additional arguments for boxes_to_vector().\n \"\"\"\n\n if self.boxes is None:\n print(\"Please run predict() first.\")\n return\n else:\n boxes = self.boxes.tolist()\n coords = rowcol_to_xy(self.source, boxes=boxes, dst_crs=dst_crs, **kwargs)\n if output is None:\n return boxes_to_vector(coords, self.crs, dst_crs, output)\n else:\n boxes_to_vector(coords, self.crs, dst_crs, output)\n
"},{"location":"text_sam/#samgeo.text_sam.LangSAM.set_image","title":"set_image(self, image)
","text":"Set the input image.
Parameters:
Name Type Description Defaultimage
str
The path to the image file or a HTTP URL.
required Source code insamgeo/text_sam.py
def set_image(self, image):\n\"\"\"Set the input image.\n\n Args:\n image (str): The path to the image file or a HTTP URL.\n \"\"\"\n\n if isinstance(image, str):\n if image.startswith(\"http\"):\n image = download_file(image)\n\n if not os.path.exists(image):\n raise ValueError(f\"Input path {image} does not exist.\")\n\n self.source = image\n else:\n self.source = None\n
"},{"location":"text_sam/#samgeo.text_sam.LangSAM.show_anns","title":"show_anns(self, figsize=(12, 10), axis='off', cmap='viridis', alpha=0.4, add_boxes=True, box_color='r', box_linewidth=1, title=None, output=None, blend=True, **kwargs)
","text":"Show the annotations (objects with random color) on the input image.
Parameters:
Name Type Description Defaultfigsize
tuple
The figure size. Defaults to (12, 10).
(12, 10)
axis
str
Whether to show the axis. Defaults to \"off\".
'off'
cmap
str
The colormap for the annotations. Defaults to \"viridis\".
'viridis'
alpha
float
The alpha value for the annotations. Defaults to 0.4.
0.4
add_boxes
bool
Whether to show the bounding boxes. Defaults to True.
True
box_color
str
The color for the bounding boxes. Defaults to \"r\".
'r'
box_linewidth
int
The line width for the bounding boxes. Defaults to 1.
1
title
str
The title for the image. Defaults to None.
None
output
str
The path to the output image. Defaults to None.
None
blend
bool
Whether to show the input image. Defaults to True.
True
kwargs
dict
Additional arguments for matplotlib.pyplot.savefig().
{}
Source code in samgeo/text_sam.py
def show_anns(\n self,\n figsize=(12, 10),\n axis=\"off\",\n cmap=\"viridis\",\n alpha=0.4,\n add_boxes=True,\n box_color=\"r\",\n box_linewidth=1,\n title=None,\n output=None,\n blend=True,\n **kwargs,\n):\n\"\"\"Show the annotations (objects with random color) on the input image.\n\n Args:\n figsize (tuple, optional): The figure size. Defaults to (12, 10).\n axis (str, optional): Whether to show the axis. Defaults to \"off\".\n cmap (str, optional): The colormap for the annotations. Defaults to \"viridis\".\n alpha (float, optional): The alpha value for the annotations. Defaults to 0.4.\n add_boxes (bool, optional): Whether to show the bounding boxes. Defaults to True.\n box_color (str, optional): The color for the bounding boxes. Defaults to \"r\".\n box_linewidth (int, optional): The line width for the bounding boxes. Defaults to 1.\n title (str, optional): The title for the image. Defaults to None.\n output (str, optional): The path to the output image. Defaults to None.\n blend (bool, optional): Whether to show the input image. 
Defaults to True.\n kwargs (dict, optional): Additional arguments for matplotlib.pyplot.savefig().\n \"\"\"\n\n import warnings\n import matplotlib.pyplot as plt\n import matplotlib.patches as patches\n\n warnings.filterwarnings(\"ignore\")\n\n anns = self.prediction\n\n if anns is None:\n print(\"Please run predict() first.\")\n return\n elif len(anns) == 0:\n print(\"No objects found in the image.\")\n return\n\n plt.figure(figsize=figsize)\n plt.imshow(self.image)\n\n if add_boxes:\n for box in self.boxes:\n # Draw bounding box\n box = box.cpu().numpy() # Convert the tensor to a numpy array\n rect = patches.Rectangle(\n (box[0], box[1]),\n box[2] - box[0],\n box[3] - box[1],\n linewidth=box_linewidth,\n edgecolor=box_color,\n facecolor=\"none\",\n )\n plt.gca().add_patch(rect)\n\n if \"dpi\" not in kwargs:\n kwargs[\"dpi\"] = 100\n\n if \"bbox_inches\" not in kwargs:\n kwargs[\"bbox_inches\"] = \"tight\"\n\n plt.imshow(anns, cmap=cmap, alpha=alpha)\n\n if title is not None:\n plt.title(title)\n plt.axis(axis)\n\n if output is not None:\n if blend:\n plt.savefig(output, **kwargs)\n else:\n array_to_image(self.prediction, output, self.source)\n
"},{"location":"text_sam/#samgeo.text_sam.LangSAM.show_map","title":"show_map(self, basemap='SATELLITE', out_dir=None, **kwargs)
","text":"Show the interactive map.
Parameters:
Name Type Description Defaultbasemap
str
The basemap. It can be one of the following: SATELLITE, ROADMAP, TERRAIN, HYBRID.
'SATELLITE'
out_dir
str
The path to the output directory. Defaults to None.
None
Returns:
Type Descriptionleafmap.Map
The map object.
Source code in samgeo/text_sam.py
def show_map(self, basemap=\"SATELLITE\", out_dir=None, **kwargs):\n\"\"\"Show the interactive map.\n\n Args:\n basemap (str, optional): The basemap. It can be one of the following: SATELLITE, ROADMAP, TERRAIN, HYBRID.\n out_dir (str, optional): The path to the output directory. Defaults to None.\n\n Returns:\n leafmap.Map: The map object.\n \"\"\"\n return text_sam_gui(self, basemap=basemap, out_dir=out_dir, **kwargs)\n
"},{"location":"text_sam/#samgeo.text_sam.load_model_hf","title":"load_model_hf(repo_id, filename, ckpt_config_filename, device='cpu')
","text":"Loads a model from HuggingFace Model Hub.
Parameters:
Name Type Description Defaultrepo_id
str
Repository ID on HuggingFace Model Hub.
requiredfilename
str
Name of the model file in the repository.
requiredckpt_config_filename
str
Name of the config file for the model in the repository.
requireddevice
str
Device to load the model onto. Default is 'cpu'.
'cpu'
Returns:
Type Descriptiontorch.nn.Module
The loaded model.
Source code in samgeo/text_sam.py
def load_model_hf(\n repo_id: str, filename: str, ckpt_config_filename: str, device: str = \"cpu\"\n) -> torch.nn.Module:\n\"\"\"\n Loads a model from HuggingFace Model Hub.\n\n Args:\n repo_id (str): Repository ID on HuggingFace Model Hub.\n filename (str): Name of the model file in the repository.\n ckpt_config_filename (str): Name of the config file for the model in the repository.\n device (str): Device to load the model onto. Default is 'cpu'.\n\n Returns:\n torch.nn.Module: The loaded model.\n \"\"\"\n\n cache_config_file = hf_hub_download(\n repo_id=repo_id,\n filename=ckpt_config_filename,\n force_filename=ckpt_config_filename,\n )\n args = SLConfig.fromfile(cache_config_file)\n model = build_model(args)\n model.to(device)\n cache_file = hf_hub_download(\n repo_id=repo_id, filename=filename, force_filename=filename\n )\n checkpoint = torch.load(cache_file, map_location=\"cpu\")\n model.load_state_dict(clean_state_dict(checkpoint[\"model\"]), strict=False)\n model.eval()\n return model\n
"},{"location":"text_sam/#samgeo.text_sam.transform_image","title":"transform_image(image)
","text":"Transforms an image using standard transformations for image-based models.
Parameters:
Name Type Description Defaultimage
Image
The PIL Image to be transformed.
requiredReturns:
Type Descriptiontorch.Tensor
The transformed image as a tensor.
Source code in samgeo/text_sam.py
def transform_image(image: Image) -> torch.Tensor:\n\"\"\"\n Transforms an image using standard transformations for image-based models.\n\n Args:\n image (Image): The PIL Image to be transformed.\n\n Returns:\n torch.Tensor: The transformed image as a tensor.\n \"\"\"\n transform = T.Compose(\n [\n T.RandomResize([800], max_size=1333),\n T.ToTensor(),\n T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),\n ]\n )\n image_transformed, _ = transform(image, None)\n return image_transformed\n
"},{"location":"usage/","title":"Usage","text":"To use segment-geospatial in a project:
import samgeo\n
Here is a simple example of using segment-geospatial to generate a segmentation mask from a satellite image:
import os\nimport torch\nfrom samgeo import SamGeo, tms_to_geotiff\n\nbbox = [-95.3704, 29.6762, -95.368, 29.6775]\nimage = 'satellite.tif'\ntms_to_geotiff(output=image, bbox=bbox, zoom=20, source='Satellite')\n\nout_dir = os.path.join(os.path.expanduser('~'), 'Downloads')\ncheckpoint = os.path.join(out_dir, 'sam_vit_h_4b8939.pth')\n\ndevice = 'cuda' if torch.cuda.is_available() else 'cpu'\nsam = SamGeo(\n checkpoint=checkpoint,\n model_type='vit_h',\n device=device,\n erosion_kernel=(3, 3),\n mask_multiplier=255,\n sam_kwargs=None,\n)\n\nmask = 'segment.tif'\nsam.generate(image, mask)\n\nvector = 'segment.gpkg'\nsam.tiff_to_gpkg(mask, vector, simplify_tolerance=None)\n
"},{"location":"examples/arcgis/","title":"Arcgis","text":"In\u00a0[\u00a0]: Copied! import os\nimport leafmap\nfrom samgeo import SamGeo\n\n%matplotlib inline\nimport os import leafmap from samgeo import SamGeo %matplotlib inline In\u00a0[\u00a0]: Copied!
workspace = os.path.dirname(arcpy.env.workspace)\nos.chdir(workspace)\narcpy.env.overwriteOutput = True\nworkspace = os.path.dirname(arcpy.env.workspace) os.chdir(workspace) arcpy.env.overwriteOutput = True In\u00a0[\u00a0]: Copied!
leafmap.download_file(\n url=\"https://github.com/opengeos/data/blob/main/naip/buildings.tif\",\n quiet=True,\n overwrite=True,\n)\nleafmap.download_file( url=\"https://github.com/opengeos/data/blob/main/naip/buildings.tif\", quiet=True, overwrite=True, ) In\u00a0[\u00a0]: Copied!
leafmap.download_file(\n url=\"https://github.com/opengeos/data/blob/main/naip/agriculture.tif\",\n quiet=True,\n overwrite=True,\n)\nleafmap.download_file( url=\"https://github.com/opengeos/data/blob/main/naip/agriculture.tif\", quiet=True, overwrite=True, ) In\u00a0[\u00a0]: Copied!
leafmap.download_file(\n url=\"https://github.com/opengeos/data/blob/main/naip/water.tif\",\n quiet=True,\n overwrite=True,\n)\nleafmap.download_file( url=\"https://github.com/opengeos/data/blob/main/naip/water.tif\", quiet=True, overwrite=True, ) In\u00a0[\u00a0]: Copied!
sam = SamGeo(\n model_type=\"vit_h\",\n sam_kwargs=None,\n)\nsam = SamGeo( model_type=\"vit_h\", sam_kwargs=None, ) In\u00a0[\u00a0]: Copied!
image = \"agriculture.tif\"\nimage = \"agriculture.tif\"
You can also use your own image. Uncomment and run the following cell to use your own image.
In\u00a0[\u00a0]: Copied!# image = '/path/to/your/own/image.tif'\n# image = '/path/to/your/own/image.tif'
Segment the image and save the results to a GeoTIFF file. Set unique=True
to assign a unique ID to each object.
sam.generate(image, output=\"ag_masks.tif\", foreground=True, unique=True)\nsam.generate(image, output=\"ag_masks.tif\", foreground=True, unique=True)
If you run into GPU memory errors, uncomment the following code block and run it to empty the CUDA cache, then rerun the code block above.
In\u00a0[\u00a0]: Copied!# sam.clear_cuda_cache()\n# sam.clear_cuda_cache()
Show the segmentation result as a grayscale image.
In\u00a0[\u00a0]: Copied!sam.show_masks(cmap=\"binary_r\")\nsam.show_masks(cmap=\"binary_r\")
Show the object annotations (objects with random color) on the map.
In\u00a0[\u00a0]: Copied!sam.show_anns(axis=\"off\", alpha=1, output=\"ag_annotations.tif\")\nsam.show_anns(axis=\"off\", alpha=1, output=\"ag_annotations.tif\")
Add layers to ArcGIS Pro.
In\u00a0[\u00a0]: Copied!m = leafmap.arc_active_map()\nm = leafmap.arc_active_map() In\u00a0[\u00a0]: Copied!
m.addDataFromPath(os.path.join(workspace, \"agriculture.tif\"))\nm.addDataFromPath(os.path.join(workspace, \"agriculture.tif\")) In\u00a0[\u00a0]: Copied!
m.addDataFromPath(os.path.join(workspace, \"ag_annotations.tif\"))\nm.addDataFromPath(os.path.join(workspace, \"ag_annotations.tif\"))
Convert the object annotations to vector format, such as GeoPackage, Shapefile, or GeoJSON.
In\u00a0[\u00a0]: Copied!in_raster = os.path.join(workspace, \"ag_masks.tif\")\nout_shp = os.path.join(workspace, \"ag_masks.shp\")\nin_raster = os.path.join(workspace, \"ag_masks.tif\") out_shp = os.path.join(workspace, \"ag_masks.shp\") In\u00a0[\u00a0]: Copied!
arcpy.conversion.RasterToPolygon(in_raster, out_shp)\narcpy.conversion.RasterToPolygon(in_raster, out_shp) In\u00a0[\u00a0]: Copied!
image = \"water.tif\"\nimage = \"water.tif\" In\u00a0[\u00a0]: Copied!
sam.generate(image, output=\"water_masks.tif\", foreground=True, unique=True)\nsam.generate(image, output=\"water_masks.tif\", foreground=True, unique=True) In\u00a0[\u00a0]: Copied!
# sam.clear_cuda_cache()\n# sam.clear_cuda_cache() In\u00a0[\u00a0]: Copied!
sam.show_masks(cmap=\"binary_r\")\nsam.show_masks(cmap=\"binary_r\") In\u00a0[\u00a0]: Copied!
sam.show_anns(axis=\"off\", alpha=1, output=\"water_annotations.tif\")\nsam.show_anns(axis=\"off\", alpha=1, output=\"water_annotations.tif\") In\u00a0[\u00a0]: Copied!
m.addDataFromPath(os.path.join(workspace, \"water.tif\"))\nm.addDataFromPath(os.path.join(workspace, \"water.tif\")) In\u00a0[\u00a0]: Copied!
m.addDataFromPath(os.path.join(workspace, \"water_annotations.tif\"))\nm.addDataFromPath(os.path.join(workspace, \"water_annotations.tif\")) In\u00a0[\u00a0]: Copied!
in_raster = os.path.join(workspace, \"water_masks.tif\")\nout_shp = os.path.join(workspace, \"water_masks.shp\")\nin_raster = os.path.join(workspace, \"water_masks.tif\") out_shp = os.path.join(workspace, \"water_masks.shp\") In\u00a0[\u00a0]: Copied!
arcpy.conversion.RasterToPolygon(in_raster, out_shp)\narcpy.conversion.RasterToPolygon(in_raster, out_shp) In\u00a0[\u00a0]: Copied!
sam_kwargs = {\n \"points_per_side\": 32,\n \"pred_iou_thresh\": 0.86,\n \"stability_score_thresh\": 0.92,\n \"crop_n_layers\": 1,\n \"crop_n_points_downscale_factor\": 2,\n \"min_mask_region_area\": 100,\n}\nsam_kwargs = { \"points_per_side\": 32, \"pred_iou_thresh\": 0.86, \"stability_score_thresh\": 0.92, \"crop_n_layers\": 1, \"crop_n_points_downscale_factor\": 2, \"min_mask_region_area\": 100, } In\u00a0[\u00a0]: Copied!
sam = SamGeo(\n model_type=\"vit_h\",\n sam_kwargs=sam_kwargs,\n)\nsam = SamGeo( model_type=\"vit_h\", sam_kwargs=sam_kwargs, ) In\u00a0[\u00a0]: Copied!
sam.generate('agriculture.tif', output=\"ag_masks2.tif\", foreground=True)\nsam.generate('agriculture.tif', output=\"ag_masks2.tif\", foreground=True) In\u00a0[\u00a0]: Copied!
sam.show_masks(cmap=\"binary_r\")\nsam.show_masks(cmap=\"binary_r\") In\u00a0[\u00a0]: Copied!
sam.show_anns(axis=\"off\", alpha=0.5, output=\"ag_annotations2.tif\")\nsam.show_anns(axis=\"off\", alpha=0.5, output=\"ag_annotations2.tif\")"},{"location":"examples/arcgis/#using-the-segment-geospatial-python-package-with-arcgis-pro","title":"Using the Segment-Geospatial Python Package with ArcGIS Pro\u00b6","text":"
The notebook shows step-by-step instructions for using the Segment Anything Model (SAM) with ArcGIS Pro. Check out the YouTube tutorial here and the Resources for Unlocking the Power of Deep Learning Applications Using ArcGIS. Credit goes to Esri.
"},{"location":"examples/arcgis/#installation","title":"Installation\u00b6","text":"Open Windows Registry Editor (regedit.exe
) and navigate to Computer\\HKEY_LOCAL_MACHINE\\SYSTEM\\CurrentControlSet\\Control\\FileSystem
. Change the value of LongPathsEnabled
to 1
. See this screenshot. This is a known issue with the deep learning libraries for ArcGIS Pro 3.1. A future release might fix this issue.
Navigate to the Start Menu -> All apps -> ArcGIS folder, then open the Python Command Prompt.
Create a new conda environment and install mamba and Python 3.9.x from the Esri Anaconda channel. Mamba is a drop-in replacement for conda that is much faster for installing Python packages and their dependencies.
conda create conda-forge::mamba esri::python --name samgeo
Activate the new conda environment.
conda activate samgeo
Install arcpy, deep-learning-essentials, segment-geospatial, and other dependencies (~4GB download).
mamba install arcpy deep-learning-essentials leafmap localtileserver segment-geospatial -c esri -c conda-forge
Activate the new environment in ArcGIS Pro.
proswap samgeo
Close the Python Command Prompt and open ArcGIS Pro.
Download this notebook and run it in ArcGIS Pro.
In this example, we will use the high-resolution aerial imagery from the USDA National Agriculture Imagery Program (NAIP). You can download NAIP imagery using the USDA Data Gateway or the USDA NRCS Box Drive. I have downloaded some NAIP imagery and clipped them to a smaller area, which are available here.
"},{"location":"examples/arcgis/#initialize-sam-class","title":"Initialize SAM class\u00b6","text":"Specify the file path to the model checkpoint. If it is not specified, the model will be downloaded to the working directory.
"},{"location":"examples/arcgis/#automatic-mask-generation","title":"Automatic mask generation\u00b6","text":"Specify the file path to the image we downloaded earlier.
"},{"location":"examples/arcgis/#segment-waterbodies","title":"Segment waterbodies\u00b6","text":""},{"location":"examples/arcgis/#automatic-mask-generation-options","title":"Automatic mask generation options\u00b6","text":"There are several tunable parameters in automatic mask generation that control how densely points are sampled and what the thresholds are for removing low quality or duplicate masks. Additionally, generation can be automatically run on crops of the image to get improved performance on smaller objects, and post-processing can remove stray pixels and holes. Here is an example configuration that samples more masks:
"},{"location":"examples/automatic_mask_generator/","title":"Automatic mask generator","text":"In\u00a0[\u00a0]: Copied!# %pip install segment-geospatial\n# %pip install segment-geospatial In\u00a0[\u00a0]: Copied!
import os\nimport leafmap\nfrom samgeo import SamGeo, show_image, download_file, overlay_images, tms_to_geotiff\nimport os import leafmap from samgeo import SamGeo, show_image, download_file, overlay_images, tms_to_geotiff In\u00a0[\u00a0]: Copied!
m = leafmap.Map(center=[37.8713, -122.2580], zoom=17, height=\"800px\")\nm.add_basemap(\"SATELLITE\")\nm\nm = leafmap.Map(center=[37.8713, -122.2580], zoom=17, height=\"800px\") m.add_basemap(\"SATELLITE\") m
Pan and zoom the map to select the area of interest. Use the draw tools to draw a polygon or rectangle on the map.
In\u00a0[\u00a0]: Copied!if m.user_roi_bounds() is not None:\n bbox = m.user_roi_bounds()\nelse:\n bbox = [-122.2659, 37.8682, -122.2521, 37.8741]\nif m.user_roi_bounds() is not None: bbox = m.user_roi_bounds() else: bbox = [-122.2659, 37.8682, -122.2521, 37.8741] In\u00a0[\u00a0]: Copied!
image = \"satellite.tif\"\ntms_to_geotiff(output=image, bbox=bbox, zoom=17, source=\"Satellite\", overwrite=True)\nimage = \"satellite.tif\" tms_to_geotiff(output=image, bbox=bbox, zoom=17, source=\"Satellite\", overwrite=True)
You can also use your own image. Uncomment and run the following cell to use your own image.
In\u00a0[\u00a0]: Copied!# image = '/path/to/your/own/image.tif'\n# image = '/path/to/your/own/image.tif'
Display the downloaded image on the map.
In\u00a0[\u00a0]: Copied!m.layers[-1].visible = False\nm.add_raster(image, layer_name=\"Image\")\nm\nm.layers[-1].visible = False m.add_raster(image, layer_name=\"Image\") m In\u00a0[\u00a0]: Copied!
sam = SamGeo(\n model_type=\"vit_h\",\n sam_kwargs=None,\n)\nsam = SamGeo( model_type=\"vit_h\", sam_kwargs=None, ) In\u00a0[\u00a0]: Copied!
sam.generate(image, output=\"masks.tif\", foreground=True, unique=True)\nsam.generate(image, output=\"masks.tif\", foreground=True, unique=True) In\u00a0[\u00a0]: Copied!
sam.show_masks(cmap=\"binary_r\")\nsam.show_masks(cmap=\"binary_r\")
Show the object annotations (objects with random color) on the map.
In\u00a0[\u00a0]: Copied!sam.show_anns(axis=\"off\", alpha=1, output=\"annotations.tif\")\nsam.show_anns(axis=\"off\", alpha=1, output=\"annotations.tif\")
Compare images with a slider.
In\u00a0[\u00a0]: Copied!leafmap.image_comparison(\n \"satellite.tif\",\n \"annotations.tif\",\n label1=\"Satellite Image\",\n label2=\"Image Segmentation\",\n)\nleafmap.image_comparison( \"satellite.tif\", \"annotations.tif\", label1=\"Satellite Image\", label2=\"Image Segmentation\", )
Add image to the map.
In\u00a0[\u00a0]: Copied!m.add_raster(\"annotations.tif\", alpha=0.5, layer_name=\"Masks\")\nm\nm.add_raster(\"annotations.tif\", alpha=0.5, layer_name=\"Masks\") m
Convert the object annotations to vector format, such as GeoPackage, Shapefile, or GeoJSON.
In\u00a0[\u00a0]: Copied!sam.tiff_to_vector(\"masks.tif\", \"masks.gpkg\")\nsam.tiff_to_vector(\"masks.tif\", \"masks.gpkg\") In\u00a0[\u00a0]: Copied!
sam_kwargs = {\n \"points_per_side\": 32,\n \"pred_iou_thresh\": 0.86,\n \"stability_score_thresh\": 0.92,\n \"crop_n_layers\": 1,\n \"crop_n_points_downscale_factor\": 2,\n \"min_mask_region_area\": 100,\n}\nsam_kwargs = { \"points_per_side\": 32, \"pred_iou_thresh\": 0.86, \"stability_score_thresh\": 0.92, \"crop_n_layers\": 1, \"crop_n_points_downscale_factor\": 2, \"min_mask_region_area\": 100, } In\u00a0[\u00a0]: Copied!
sam = SamGeo(\n model_type=\"vit_h\",\n sam_kwargs=sam_kwargs,\n)\nsam = SamGeo( model_type=\"vit_h\", sam_kwargs=sam_kwargs, ) In\u00a0[\u00a0]: Copied!
sam.generate(image, output=\"masks2.tif\", foreground=True)\nsam.generate(image, output=\"masks2.tif\", foreground=True) In\u00a0[\u00a0]: Copied!
sam.show_masks(cmap=\"binary_r\")\nsam.show_masks(cmap=\"binary_r\") In\u00a0[\u00a0]: Copied!
sam.show_anns(axis=\"off\", opacity=1, output=\"annotations2.tif\")\nsam.show_anns(axis=\"off\", opacity=1, output=\"annotations2.tif\")
Compare images with a slider.
In\u00a0[\u00a0]: Copied!leafmap.image_comparison(\n image,\n \"annotations.tif\",\n label1=\"Image\",\n label2=\"Image Segmentation\",\n)\nleafmap.image_comparison( image, \"annotations.tif\", label1=\"Image\", label2=\"Image Segmentation\", )
Overlay the annotations on the image and use the slider to change the opacity interactively.
In\u00a0[\u00a0]: Copied!overlay_images(image, \"annotations2.tif\", backend=\"TkAgg\")\noverlay_images(image, \"annotations2.tif\", backend=\"TkAgg\") "},{"location":"examples/automatic_mask_generator/#automatically-generating-object-masks-with-sam","title":"Automatically generating object masks with SAM\u00b6","text":"
This notebook shows how to segment objects from an image using the Segment Anything Model (SAM) with a few lines of code.
Make sure you use GPU runtime for this notebook. For Google Colab, go to Runtime
-> Change runtime type
and select GPU
as the hardware accelerator.
The notebook is adapted from segment-anything/notebooks/automatic_mask_generator_example.ipynb, but I have made it much easier to save the segmentation results and visualize them.
"},{"location":"examples/automatic_mask_generator/#install-dependencies","title":"Install dependencies\u00b6","text":"Uncomment and run the following cell to install the required dependencies.
"},{"location":"examples/automatic_mask_generator/#create-an-interactive-map","title":"Create an interactive map\u00b6","text":""},{"location":"examples/automatic_mask_generator/#download-a-sample-image","title":"Download a sample image\u00b6","text":""},{"location":"examples/automatic_mask_generator/#initialize-sam-class","title":"Initialize SAM class\u00b6","text":"Specify the file path to the model checkpoint. If it is not specified, the model will be downloaded to the working directory.
"},{"location":"examples/automatic_mask_generator/#automatic-mask-generation","title":"Automatic mask generation\u00b6","text":"Segment the image and save the results to a GeoTIFF file. Set unique=True
to assign a unique ID to each object.
There are several tunable parameters in automatic mask generation that control how densely points are sampled and what the thresholds are for removing low quality or duplicate masks. Additionally, generation can be automatically run on crops of the image to get improved performance on smaller objects, and post-processing can remove stray pixels and holes. Here is an example configuration that samples more masks:
"},{"location":"examples/automatic_mask_generator_hq/","title":"Automatic mask generator hq","text":"In\u00a0[\u00a0]: Copied!# %pip install segment-geospatial\n# %pip install segment-geospatial In\u00a0[\u00a0]: Copied!
import os\nimport leafmap\nfrom samgeo.hq_sam import SamGeo, show_image, download_file, overlay_images, tms_to_geotiff\nimport os import leafmap from samgeo.hq_sam import SamGeo, show_image, download_file, overlay_images, tms_to_geotiff In\u00a0[\u00a0]: Copied!
m = leafmap.Map(center=[37.8713, -122.2580], zoom=17, height=\"800px\")\nm.add_basemap(\"SATELLITE\")\nm\nm = leafmap.Map(center=[37.8713, -122.2580], zoom=17, height=\"800px\") m.add_basemap(\"SATELLITE\") m
Pan and zoom the map to select the area of interest. Use the draw tools to draw a polygon or rectangle on the map
In\u00a0[\u00a0]: Copied!if m.user_roi_bounds() is not None:\n bbox = m.user_roi_bounds()\nelse:\n bbox = [-122.2659, 37.8682, -122.2521, 37.8741]\nif m.user_roi_bounds() is not None: bbox = m.user_roi_bounds() else: bbox = [-122.2659, 37.8682, -122.2521, 37.8741] In\u00a0[\u00a0]: Copied!
image = \"satellite.tif\"\ntms_to_geotiff(output=image, bbox=bbox, zoom=17, source=\"Satellite\", overwrite=True)\nimage = \"satellite.tif\" tms_to_geotiff(output=image, bbox=bbox, zoom=17, source=\"Satellite\", overwrite=True)
You can also use your own image. Uncomment and run the following cell to use your own image.
In\u00a0[\u00a0]: Copied!# image = '/path/to/your/own/image.tif'\n# image = '/path/to/your/own/image.tif'
Display the downloaded image on the map.
In\u00a0[\u00a0]: Copied!m.layers[-1].visible = False\nm.add_raster(image, layer_name=\"Image\")\nm\nm.layers[-1].visible = False m.add_raster(image, layer_name=\"Image\") m In\u00a0[\u00a0]: Copied!
sam = SamGeo(\n model_type=\"vit_h\", # can be vit_h, vit_b, vit_l, vit_tiny\n sam_kwargs=None,\n)\nsam = SamGeo( model_type=\"vit_h\", # can be vit_h, vit_b, vit_l, vit_tiny sam_kwargs=None, ) In\u00a0[\u00a0]: Copied!
sam.generate(image, output=\"masks.tif\", foreground=True, unique=True)\nsam.generate(image, output=\"masks.tif\", foreground=True, unique=True) In\u00a0[\u00a0]: Copied!
sam.show_masks(cmap=\"binary_r\")\nsam.show_masks(cmap=\"binary_r\")
Show the object annotations (objects with random color) on the map.
In\u00a0[\u00a0]: Copied!sam.show_anns(axis=\"off\", alpha=1, output=\"annotations.tif\")\nsam.show_anns(axis=\"off\", alpha=1, output=\"annotations.tif\")
Compare images with a slider.
In\u00a0[\u00a0]: Copied!leafmap.image_comparison(\n \"satellite.tif\",\n \"annotations.tif\",\n label1=\"Satellite Image\",\n label2=\"Image Segmentation\",\n)\nleafmap.image_comparison( \"satellite.tif\", \"annotations.tif\", label1=\"Satellite Image\", label2=\"Image Segmentation\", )
Add image to the map.
In\u00a0[\u00a0]: Copied!m.add_raster(\"annotations.tif\", alpha=0.5, layer_name=\"Masks\")\nm\nm.add_raster(\"annotations.tif\", alpha=0.5, layer_name=\"Masks\") m
Convert the object annotations to vector format, such as GeoPackage, Shapefile, or GeoJSON.
In\u00a0[\u00a0]: Copied!sam.tiff_to_vector(\"masks.tif\", \"masks.gpkg\")\nsam.tiff_to_vector(\"masks.tif\", \"masks.gpkg\") In\u00a0[\u00a0]: Copied!
sam_kwargs = {\n \"points_per_side\": 32,\n \"pred_iou_thresh\": 0.86,\n \"stability_score_thresh\": 0.92,\n \"crop_n_layers\": 1,\n \"crop_n_points_downscale_factor\": 2,\n \"min_mask_region_area\": 100,\n}\nsam_kwargs = { \"points_per_side\": 32, \"pred_iou_thresh\": 0.86, \"stability_score_thresh\": 0.92, \"crop_n_layers\": 1, \"crop_n_points_downscale_factor\": 2, \"min_mask_region_area\": 100, } In\u00a0[\u00a0]: Copied!
sam = SamGeo(\n model_type=\"vit_h\",\n sam_kwargs=sam_kwargs,\n)\nsam = SamGeo( model_type=\"vit_h\", sam_kwargs=sam_kwargs, ) In\u00a0[17]: Copied!
sam.generate(image, output=\"masks2.tif\", foreground=True)\nsam.generate(image, output=\"masks2.tif\", foreground=True) In\u00a0[\u00a0]: Copied!
sam.show_masks(cmap=\"binary_r\")\nsam.show_masks(cmap=\"binary_r\") In\u00a0[\u00a0]: Copied!
sam.show_anns(axis=\"off\", opacity=1, output=\"annotations2.tif\")\nsam.show_anns(axis=\"off\", opacity=1, output=\"annotations2.tif\")
Compare images with a slider.
In\u00a0[\u00a0]: Copied!leafmap.image_comparison(\n image,\n \"annotations.tif\",\n label1=\"Image\",\n label2=\"Image Segmentation\",\n)\nleafmap.image_comparison( image, \"annotations.tif\", label1=\"Image\", label2=\"Image Segmentation\", )
Overlay the annotations on the image and use the slider to change the opacity interactively.
In\u00a0[\u00a0]: Copied!overlay_images(image, \"annotations2.tif\", backend=\"TkAgg\")\noverlay_images(image, \"annotations2.tif\", backend=\"TkAgg\") "},{"location":"examples/automatic_mask_generator_hq/#automatically-generating-object-masks-with-hq-sam","title":"Automatically generating object masks with HQ-SAM\u00b6","text":"
This notebook shows how to segment objects from an image using the High-Quality Segment Anything Model (HQ-SAM) with a few lines of code.
"},{"location":"examples/automatic_mask_generator_hq/#install-dependencies","title":"Install dependencies\u00b6","text":"Uncomment and run the following cell to install the required dependencies.
"},{"location":"examples/automatic_mask_generator_hq/#create-an-interactive-map","title":"Create an interactive map\u00b6","text":""},{"location":"examples/automatic_mask_generator_hq/#download-a-sample-image","title":"Download a sample image\u00b6","text":""},{"location":"examples/automatic_mask_generator_hq/#initialize-sam-class","title":"Initialize SAM class\u00b6","text":"Specify the file path to the model checkpoint. If it is not specified, the model will be downloaded to the working directory.
"},{"location":"examples/automatic_mask_generator_hq/#automatic-mask-generation","title":"Automatic mask generation\u00b6","text":"Segment the image and save the results to a GeoTIFF file. Set unique=True
to assign a unique ID to each object.
There are several tunable parameters in automatic mask generation that control how densely points are sampled and what the thresholds are for removing low quality or duplicate masks. Additionally, generation can be automatically run on crops of the image to get improved performance on smaller objects, and post-processing can remove stray pixels and holes. Here is an example configuration that samples more masks:
"},{"location":"examples/box_prompts/","title":"Box prompts","text":"In\u00a0[\u00a0]: Copied!# %pip install segment-geospatial\n# %pip install segment-geospatial In\u00a0[\u00a0]: Copied!
import leafmap\nfrom samgeo import tms_to_geotiff\nfrom samgeo import SamGeo\nimport leafmap from samgeo import tms_to_geotiff from samgeo import SamGeo In\u00a0[\u00a0]: Copied!
m = leafmap.Map(center=[-22.17615, -51.253043], zoom=18, height=\"800px\")\nm.add_basemap(\"SATELLITE\")\nm\nm = leafmap.Map(center=[-22.17615, -51.253043], zoom=18, height=\"800px\") m.add_basemap(\"SATELLITE\") m In\u00a0[\u00a0]: Copied!
bbox = m.user_roi_bounds()\nif bbox is None:\n bbox = [-51.2565, -22.1777, -51.2512, -22.175]\nbbox = m.user_roi_bounds() if bbox is None: bbox = [-51.2565, -22.1777, -51.2512, -22.175] In\u00a0[\u00a0]: Copied!
image = \"Image.tif\"\ntms_to_geotiff(output=image, bbox=bbox, zoom=19, source=\"Satellite\", overwrite=True)\nimage = \"Image.tif\" tms_to_geotiff(output=image, bbox=bbox, zoom=19, source=\"Satellite\", overwrite=True)
You can also use your own image. Uncomment and run the following cell to use your own image.
In\u00a0[\u00a0]: Copied!# image = '/path/to/your/own/image.tif'\n# image = '/path/to/your/own/image.tif'
Display the downloaded image on the map.
In\u00a0[\u00a0]: Copied!m.layers[-1].visible = False\nm.add_raster(image, layer_name=\"Image\")\nm\nm.layers[-1].visible = False m.add_raster(image, layer_name=\"Image\") m In\u00a0[\u00a0]: Copied!
sam = SamGeo(\n model_type=\"vit_h\",\n automatic=False,\n sam_kwargs=None,\n)\nsam = SamGeo( model_type=\"vit_h\", automatic=False, sam_kwargs=None, )
Specify the image to segment.
In\u00a0[\u00a0]: Copied!sam.set_image(image)\nsam.set_image(image)
Display the map. Use the drawing tools to draw some rectangles around the features you want to extract, such as trees, buildings.
In\u00a0[\u00a0]: Copied!m\nm In\u00a0[\u00a0]: Copied!
if m.user_rois is not None:\n boxes = m.user_rois\nelse:\n boxes = [\n [-51.2546, -22.1771, -51.2541, -22.1767],\n [-51.2538, -22.1764, -51.2535, -22.1761],\n ]\nif m.user_rois is not None: boxes = m.user_rois else: boxes = [ [-51.2546, -22.1771, -51.2541, -22.1767], [-51.2538, -22.1764, -51.2535, -22.1761], ] In\u00a0[\u00a0]: Copied!
sam.predict(boxes=boxes, point_crs=\"EPSG:4326\", output=\"mask.tif\", dtype=\"uint8\")\nsam.predict(boxes=boxes, point_crs=\"EPSG:4326\", output=\"mask.tif\", dtype=\"uint8\") In\u00a0[\u00a0]: Copied!
m.add_raster('mask.tif', cmap='viridis', nodata=0, layer_name='Mask')\nm\nm.add_raster('mask.tif', cmap='viridis', nodata=0, layer_name='Mask') m In\u00a0[\u00a0]: Copied!
url = 'https://opengeos.github.io/data/sam/tree_boxes.geojson'\ngeojson = \"tree_boxes.geojson\"\nleafmap.download_file(url, geojson)\nurl = 'https://opengeos.github.io/data/sam/tree_boxes.geojson' geojson = \"tree_boxes.geojson\" leafmap.download_file(url, geojson)
Display the vector data on the map.
In\u00a0[\u00a0]: Copied!m = leafmap.Map()\nm.add_raster(\"Image.tif\", layer_name=\"image\")\nstyle = {\n \"color\": \"#ffff00\",\n \"weight\": 2,\n \"fillColor\": \"#7c4185\",\n \"fillOpacity\": 0,\n}\nm.add_vector(geojson, style=style, zoom_to_layer=True, layer_name=\"Bounding boxes\")\nm\nm = leafmap.Map() m.add_raster(\"Image.tif\", layer_name=\"image\") style = { \"color\": \"#ffff00\", \"weight\": 2, \"fillColor\": \"#7c4185\", \"fillOpacity\": 0, } m.add_vector(geojson, style=style, zoom_to_layer=True, layer_name=\"Bounding boxes\") m In\u00a0[\u00a0]: Copied!
sam.predict(boxes=geojson, point_crs=\"EPSG:4326\", output=\"mask2.tif\", dtype=\"uint8\")\nsam.predict(boxes=geojson, point_crs=\"EPSG:4326\", output=\"mask2.tif\", dtype=\"uint8\")
Display the segmented masks on the map.
In\u00a0[\u00a0]: Copied!m.add_raster(\"mask2.tif\", cmap=\"Greens\", nodata=0, opacity=0.5, layer_name=\"Tree masks\")\nm\nm.add_raster(\"mask2.tif\", cmap=\"Greens\", nodata=0, opacity=0.5, layer_name=\"Tree masks\") m "},{"location":"examples/box_prompts/#segmenting-remote-sensing-imagery-with-box-prompts","title":"Segmenting remote sensing imagery with box prompts\u00b6","text":"
This notebook shows how to generate object masks from box prompts with the Segment Anything Model (SAM).
Make sure you use GPU runtime for this notebook. For Google Colab, go to Runtime
-> Change runtime type
and select GPU
as the hardware accelerator.
Uncomment and run the following cell to install the required dependencies.
"},{"location":"examples/box_prompts/#create-an-interactive-map","title":"Create an interactive map\u00b6","text":""},{"location":"examples/box_prompts/#download-a-sample-image","title":"Download a sample image\u00b6","text":"Pan and zoom the map to select the area of interest. Use the draw tools to draw a polygon or rectangle on the map
"},{"location":"examples/box_prompts/#initialize-sam-class","title":"Initialize SAM class\u00b6","text":"The initialization of the SamGeo class might take a few minutes. The initialization downloads the model weights and sets up the model for inference.
Specify the file path to the model checkpoint. If it is not specified, the model will be downloaded to the working directory.
Set automatic=False
to disable the SamAutomaticMaskGenerator
and enable the SamPredictor
.
If no rectangles are drawn, the default bounding boxes will be used as follows:
"},{"location":"examples/box_prompts/#segment-the-image","title":"Segment the image\u00b6","text":"Use the predict()
method to segment the image with specified bounding boxes. The boxes
parameter accepts a list of bounding box coordinates in the format of [[left, bottom, right, top], [left, bottom, right, top], ...], a GeoJSON dictionary, or a file path to a GeoJSON file.
Add the segmented image to the map.
"},{"location":"examples/box_prompts/#use-an-existing-vector-file-as-box-prompts","title":"Use an existing vector file as box prompts\u00b6","text":"Alternatively, you can specify a file path to a vector file. Let's download a sample vector file from GitHub.
"},{"location":"examples/box_prompts/#segment-image-with-box-prompts","title":"Segment image with box prompts\u00b6","text":"Segment the image using the specified file path to the vector mask.
"},{"location":"examples/input_prompts/","title":"Input prompts","text":"In\u00a0[\u00a0]: Copied!# %pip install segment-geospatial\n# %pip install segment-geospatial In\u00a0[1]: Copied!
import os\nimport leafmap\nfrom samgeo import SamGeo, tms_to_geotiff\nimport os import leafmap from samgeo import SamGeo, tms_to_geotiff In\u00a0[\u00a0]: Copied!
m = leafmap.Map(center=[37.6412, -122.1353], zoom=15, height=\"800px\")\nm.add_basemap(\"SATELLITE\")\nm\nm = leafmap.Map(center=[37.6412, -122.1353], zoom=15, height=\"800px\") m.add_basemap(\"SATELLITE\") m In\u00a0[\u00a0]: Copied!
if m.user_roi is not None:\n bbox = m.user_roi_bounds()\nelse:\n bbox = [-122.1497, 37.6311, -122.1203, 37.6458]\nif m.user_roi is not None: bbox = m.user_roi_bounds() else: bbox = [-122.1497, 37.6311, -122.1203, 37.6458] In\u00a0[\u00a0]: Copied!
image = \"satellite.tif\"\ntms_to_geotiff(output=image, bbox=bbox, zoom=16, source=\"Satellite\", overwrite=True)\nimage = \"satellite.tif\" tms_to_geotiff(output=image, bbox=bbox, zoom=16, source=\"Satellite\", overwrite=True)
You can also use your own image. Uncomment and run the following cell to use your own image.
In\u00a0[\u00a0]: Copied!# image = '/path/to/your/own/image.tif'\n# image = '/path/to/your/own/image.tif'
Display the downloaded image on the map.
In\u00a0[\u00a0]: Copied!m.layers[-1].visible = False\nm.add_raster(image, layer_name=\"Image\")\nm\nm.layers[-1].visible = False m.add_raster(image, layer_name=\"Image\") m
Set automatic=False
to disable the SamAutomaticMaskGenerator
and enable the SamPredictor
.
sam = SamGeo(\n model_type=\"vit_h\",\n automatic=False,\n sam_kwargs=None,\n)\nsam = SamGeo( model_type=\"vit_h\", automatic=False, sam_kwargs=None, )
Specify the image to segment.
In\u00a0[\u00a0]: Copied!sam.set_image(image)\nsam.set_image(image) In\u00a0[\u00a0]: Copied!
point_coords = [[-122.1419, 37.6383]]\nsam.predict(point_coords, point_labels=1, point_crs=\"EPSG:4326\", output=\"mask1.tif\")\nm.add_raster(\"mask1.tif\", layer_name=\"Mask1\", nodata=0, cmap=\"Blues\", opacity=1)\nm\npoint_coords = [[-122.1419, 37.6383]] sam.predict(point_coords, point_labels=1, point_crs=\"EPSG:4326\", output=\"mask1.tif\") m.add_raster(\"mask1.tif\", layer_name=\"Mask1\", nodata=0, cmap=\"Blues\", opacity=1) m
Try multiple points input:
In\u00a0[\u00a0]: Copied!point_coords = [[-122.1464, 37.6431], [-122.1449, 37.6415], [-122.1451, 37.6395]]\nsam.predict(point_coords, point_labels=1, point_crs=\"EPSG:4326\", output=\"mask2.tif\")\nm.add_raster(\"mask2.tif\", layer_name=\"Mask2\", nodata=0, cmap=\"Greens\", opacity=1)\nm\npoint_coords = [[-122.1464, 37.6431], [-122.1449, 37.6415], [-122.1451, 37.6395]] sam.predict(point_coords, point_labels=1, point_crs=\"EPSG:4326\", output=\"mask2.tif\") m.add_raster(\"mask2.tif\", layer_name=\"Mask2\", nodata=0, cmap=\"Greens\", opacity=1) m In\u00a0[\u00a0]: Copied!
m = sam.show_map()\nm\nm = sam.show_map() m "},{"location":"examples/input_prompts/#generating-object-masks-from-input-prompts-with-sam","title":"Generating object masks from input prompts with SAM\u00b6","text":"
This notebook shows how to generate object masks from input prompts with the Segment Anything Model (SAM).
Make sure you use GPU runtime for this notebook. For Google Colab, go to Runtime
-> Change runtime type
and select GPU
as the hardware accelerator.
The notebook is adapted from segment-anything/notebooks/predictor_example.ipynb, but I have made it much easier to save the segmentation results and visualize them.
"},{"location":"examples/input_prompts/#install-dependencies","title":"Install dependencies\u00b6","text":"Uncomment and run the following cell to install the required dependencies.
"},{"location":"examples/input_prompts/#create-an-interactive-map","title":"Create an interactive map\u00b6","text":""},{"location":"examples/input_prompts/#download-a-sample-image","title":"Download a sample image\u00b6","text":"Pan and zoom the map to select the area of interest. Use the draw tools to draw a polygon or rectangle on the map
"},{"location":"examples/input_prompts/#initialize-sam-class","title":"Initialize SAM class\u00b6","text":"Specify the file path to the model checkpoint. If it is not specified, the model will be downloaded to the working directory.
"},{"location":"examples/input_prompts/#image-segmentation-with-input-points","title":"Image segmentation with input points\u00b6","text":"A single point can be used to segment an object. The point can be specified as a tuple of (x, y), such as (col, row) or (lon, lat). The points can also be specified as a file path to a vector dataset. For non (col, row) input points, specify the point_crs
parameter, which will automatically transform the points to the image column and row coordinates.
Try a single point input:
"},{"location":"examples/input_prompts/#interactive-segmentation","title":"Interactive segmentation\u00b6","text":"Display the interactive map and use the marker tool to draw points on the map. Then click on the Segment
button to segment the objects. The results will be added to the map automatically. Click on the Reset
button to clear the points and the results.
# %pip install segment-geospatial\n# %pip install segment-geospatial In\u00a0[\u00a0]: Copied!
import os\nimport leafmap\nfrom samgeo.hq_sam import SamGeo, tms_to_geotiff\nimport os import leafmap from samgeo.hq_sam import SamGeo, tms_to_geotiff In\u00a0[\u00a0]: Copied!
m = leafmap.Map(center=[37.6412, -122.1353], zoom=15, height=\"800px\")\nm.add_basemap(\"SATELLITE\")\nm\nm = leafmap.Map(center=[37.6412, -122.1353], zoom=15, height=\"800px\") m.add_basemap(\"SATELLITE\") m In\u00a0[\u00a0]: Copied!
if m.user_roi is not None:\n bbox = m.user_roi_bounds()\nelse:\n bbox = [-122.1497, 37.6311, -122.1203, 37.6458]\nif m.user_roi is not None: bbox = m.user_roi_bounds() else: bbox = [-122.1497, 37.6311, -122.1203, 37.6458] In\u00a0[\u00a0]: Copied!
image = \"satellite.tif\"\ntms_to_geotiff(output=image, bbox=bbox, zoom=16, source=\"Satellite\", overwrite=True)\nimage = \"satellite.tif\" tms_to_geotiff(output=image, bbox=bbox, zoom=16, source=\"Satellite\", overwrite=True)
You can also use your own image. Uncomment and run the following cell to use your own image.
In\u00a0[\u00a0]: Copied!# image = '/path/to/your/own/image.tif'\n# image = '/path/to/your/own/image.tif'
Display the downloaded image on the map.
In\u00a0[\u00a0]: Copied!m.layers[-1].visible = False\nm.add_raster(image, layer_name=\"Image\")\nm\nm.layers[-1].visible = False m.add_raster(image, layer_name=\"Image\") m
Set automatic=False
to disable the SamAutomaticMaskGenerator
and enable the SamPredictor
.
sam = SamGeo(\n model_type=\"vit_h\", # can be vit_h, vit_b, vit_l, vit_tiny\n automatic=False,\n sam_kwargs=None,\n)\nsam = SamGeo( model_type=\"vit_h\", # can be vit_h, vit_b, vit_l, vit_tiny automatic=False, sam_kwargs=None, )
Specify the image to segment.
In\u00a0[\u00a0]: Copied!sam.set_image(image)\nsam.set_image(image) In\u00a0[\u00a0]: Copied!
point_coords = [[-122.1419, 37.6383]]\nsam.predict(point_coords, point_labels=1, point_crs=\"EPSG:4326\", output=\"mask1.tif\")\nm.add_raster(\"mask1.tif\", layer_name=\"Mask1\", nodata=0, cmap=\"Blues\", opacity=1)\nm\npoint_coords = [[-122.1419, 37.6383]] sam.predict(point_coords, point_labels=1, point_crs=\"EPSG:4326\", output=\"mask1.tif\") m.add_raster(\"mask1.tif\", layer_name=\"Mask1\", nodata=0, cmap=\"Blues\", opacity=1) m
Try multiple points input:
In\u00a0[\u00a0]: Copied!point_coords = [[-122.1464, 37.6431], [-122.1449, 37.6415], [-122.1451, 37.6395]]\nsam.predict(point_coords, point_labels=1, point_crs=\"EPSG:4326\", output=\"mask2.tif\")\nm.add_raster(\"mask2.tif\", layer_name=\"Mask2\", nodata=0, cmap=\"Greens\", opacity=1)\nm\npoint_coords = [[-122.1464, 37.6431], [-122.1449, 37.6415], [-122.1451, 37.6395]] sam.predict(point_coords, point_labels=1, point_crs=\"EPSG:4326\", output=\"mask2.tif\") m.add_raster(\"mask2.tif\", layer_name=\"Mask2\", nodata=0, cmap=\"Greens\", opacity=1) m In\u00a0[\u00a0]: Copied!
m = sam.show_map()\nm\nm = sam.show_map() m "},{"location":"examples/input_prompts_hq/#generating-object-masks-from-input-prompts-with-hq-sam","title":"Generating object masks from input prompts with HQ-SAM\u00b6","text":"
This notebook shows how to generate object masks from input prompts with the High-Quality Segment Anything Model (HQ-SAM).
Make sure you use GPU runtime for this notebook. For Google Colab, go to Runtime
-> Change runtime type
and select GPU
as the hardware accelerator.
Uncomment and run the following cell to install the required dependencies.
"},{"location":"examples/input_prompts_hq/#create-an-interactive-map","title":"Create an interactive map\u00b6","text":""},{"location":"examples/input_prompts_hq/#download-a-sample-image","title":"Download a sample image\u00b6","text":"Pan and zoom the map to select the area of interest. Use the draw tools to draw a polygon or rectangle on the map
"},{"location":"examples/input_prompts_hq/#initialize-sam-class","title":"Initialize SAM class\u00b6","text":"Specify the file path to the model checkpoint. If it is not specified, the model will be downloaded to the working directory.
"},{"location":"examples/input_prompts_hq/#image-segmentation-with-input-points","title":"Image segmentation with input points\u00b6","text":"A single point can be used to segment an object. The point can be specified as a tuple of (x, y), such as (col, row) or (lon, lat). The points can also be specified as a file path to a vector dataset. For non (col, row) input points, specify the point_crs
parameter, which will automatically transform the points to the image column and row coordinates.
Try a single point input:
"},{"location":"examples/input_prompts_hq/#interactive-segmentation","title":"Interactive segmentation\u00b6","text":"Display the interactive map and use the marker tool to draw points on the map. Then click on the Segment
button to segment the objects. The results will be added to the map automatically. Click on the Reset
button to clear the points and the results.
# %pip install segment-geospatial\n# %pip install segment-geospatial In\u00a0[\u00a0]: Copied!
import os\nimport leafmap\nfrom samgeo import SamGeoPredictor, tms_to_geotiff, get_basemaps\nfrom segment_anything import sam_model_registry\nimport os import leafmap from samgeo import SamGeoPredictor, tms_to_geotiff, get_basemaps from segment_anything import sam_model_registry In\u00a0[\u00a0]: Copied!
zoom = 16\nm = leafmap.Map(center=[45, -123], zoom=zoom)\nm.add_basemap(\"SATELLITE\")\nm\nzoom = 16 m = leafmap.Map(center=[45, -123], zoom=zoom) m.add_basemap(\"SATELLITE\") m
Pan and zoom the map to select the area of interest. Use the draw tools to draw a polygon or rectangle on the map
In\u00a0[\u00a0]: Copied!if m.user_roi_bounds() is not None:\n bbox = m.user_roi_bounds()\nelse:\n bbox = [-123.0127, 44.9957, -122.9874, 45.0045]\nif m.user_roi_bounds() is not None: bbox = m.user_roi_bounds() else: bbox = [-123.0127, 44.9957, -122.9874, 45.0045] In\u00a0[\u00a0]: Copied!
image = \"satellite.tif\"\n# image = '/path/to/your/own/image.tif'\nimage = \"satellite.tif\" # image = '/path/to/your/own/image.tif'
Besides the satellite
basemap, you can use any of the following basemaps returned by the get_basemaps()
function:
# get_basemaps().keys()\n# get_basemaps().keys()
Specify the basemap as the source.
In\u00a0[\u00a0]: Copied!tms_to_geotiff(\n output=image, bbox=bbox, zoom=zoom + 1, source=\"Satellite\", overwrite=True\n)\ntms_to_geotiff( output=image, bbox=bbox, zoom=zoom + 1, source=\"Satellite\", overwrite=True ) In\u00a0[\u00a0]: Copied!
m.add_raster(image, layer_name=\"Image\")\nm\nm.add_raster(image, layer_name=\"Image\") m
Use the draw tools to draw a rectangle from which to subset segmentations on the map
In\u00a0[\u00a0]: Copied!if m.user_roi_bounds() is not None:\n clip_box = m.user_roi_bounds()\nelse:\n clip_box = [-123.0064, 44.9988, -123.0005, 45.0025]\nif m.user_roi_bounds() is not None: clip_box = m.user_roi_bounds() else: clip_box = [-123.0064, 44.9988, -123.0005, 45.0025] In\u00a0[\u00a0]: Copied!
clip_box\nclip_box In\u00a0[\u00a0]: Copied!
out_dir = os.path.join(os.path.expanduser(\"~\"), \"Downloads\")\ncheckpoint = os.path.join(out_dir, \"sam_vit_h_4b8939.pth\")\nout_dir = os.path.join(os.path.expanduser(\"~\"), \"Downloads\") checkpoint = os.path.join(out_dir, \"sam_vit_h_4b8939.pth\") In\u00a0[\u00a0]: Copied!
import cv2\n\nimg_arr = cv2.imread(image)\n\nmodel_type = \"vit_h\"\n\nsam = sam_model_registry[model_type](checkpoint=checkpoint)\n\npredictor = SamGeoPredictor(sam)\n\npredictor.set_image(img_arr)\n\nmasks, _, _ = predictor.predict(src_fp=image, geo_box=clip_box)\nimport cv2 img_arr = cv2.imread(image) model_type = \"vit_h\" sam = sam_model_registry[model_type](checkpoint=checkpoint) predictor = SamGeoPredictor(sam) predictor.set_image(img_arr) masks, _, _ = predictor.predict(src_fp=image, geo_box=clip_box) In\u00a0[\u00a0]: Copied!
masks_img = \"preds.tif\"\npredictor.masks_to_geotiff(image, masks_img, masks.astype(\"uint8\"))\nmasks_img = \"preds.tif\" predictor.masks_to_geotiff(image, masks_img, masks.astype(\"uint8\")) In\u00a0[\u00a0]: Copied!
vector = \"feats.geojson\"\ngdf = predictor.geotiff_to_geojson(masks_img, vector, bidx=1)\ngdf.plot()\nvector = \"feats.geojson\" gdf = predictor.geotiff_to_geojson(masks_img, vector, bidx=1) gdf.plot() In\u00a0[\u00a0]: Copied!
style = {\n \"color\": \"#3388ff\",\n \"weight\": 2,\n \"fillColor\": \"#7c4185\",\n \"fillOpacity\": 0.5,\n}\nm.add_vector(vector, layer_name=\"Vector\", style=style)\nm\nstyle = { \"color\": \"#3388ff\", \"weight\": 2, \"fillColor\": \"#7c4185\", \"fillOpacity\": 0.5, } m.add_vector(vector, layer_name=\"Vector\", style=style) m"},{"location":"examples/satellite-predictor/#segment-anything-model-for-geospatial-data","title":"Segment Anything Model for Geospatial Data\u00b6","text":"
This notebook shows how to segment satellite imagery using the Segment Anything Model (SAM) with a few lines of code.
Make sure you use GPU runtime for this notebook. For Google Colab, go to Runtime
-> Change runtime type
and select GPU
as the hardware accelerator.
Uncomment and run the following cell to install the required dependencies.
"},{"location":"examples/satellite-predictor/#import-libraries","title":"Import libraries\u00b6","text":""},{"location":"examples/satellite-predictor/#create-an-interactive-map","title":"Create an interactive map\u00b6","text":""},{"location":"examples/satellite-predictor/#download-map-tiles","title":"Download map tiles\u00b6","text":"Download map tiles and mosaic them into a single GeoTIFF file
"},{"location":"examples/satellite-predictor/#initialize-samgeopredictor-class","title":"Initialize SamGeoPredictor class\u00b6","text":""},{"location":"examples/satellite-predictor/#visualize-the-results","title":"Visualize the results\u00b6","text":""},{"location":"examples/satellite/","title":"Satellite","text":"In\u00a0[\u00a0]: Copied!# %pip install segment-geospatial\n# %pip install segment-geospatial In\u00a0[\u00a0]: Copied!
import os\nimport leafmap\nfrom samgeo import SamGeo, tms_to_geotiff, get_basemaps\nimport os import leafmap from samgeo import SamGeo, tms_to_geotiff, get_basemaps In\u00a0[\u00a0]: Copied!
m = leafmap.Map(center=[29.676840, -95.369222], zoom=19)\nm.add_basemap(\"SATELLITE\")\nm\nm = leafmap.Map(center=[29.676840, -95.369222], zoom=19) m.add_basemap(\"SATELLITE\") m
Pan and zoom the map to select the area of interest. Use the draw tools to draw a polygon or rectangle on the map
In\u00a0[\u00a0]: Copied!if m.user_roi_bounds() is not None:\n bbox = m.user_roi_bounds()\nelse:\n bbox = [-95.3704, 29.6762, -95.368, 29.6775]\nif m.user_roi_bounds() is not None: bbox = m.user_roi_bounds() else: bbox = [-95.3704, 29.6762, -95.368, 29.6775] In\u00a0[\u00a0]: Copied!
image = \"satellite.tif\"\nimage = \"satellite.tif\"
Besides the satellite
basemap, you can use any of the following basemaps returned by the get_basemaps()
function:
# get_basemaps().keys()\n# get_basemaps().keys()
Specify the basemap as the source.
In\u00a0[\u00a0]: Copied!tms_to_geotiff(output=image, bbox=bbox, zoom=20, source=\"Satellite\", overwrite=True)\ntms_to_geotiff(output=image, bbox=bbox, zoom=20, source=\"Satellite\", overwrite=True)
You can also use your own image. Uncomment and run the following cell to use your own image.
In\u00a0[\u00a0]: Copied!# image = '/path/to/your/own/image.tif'\n# image = '/path/to/your/own/image.tif'
Display the downloaded image on the map.
In\u00a0[\u00a0]: Copied!m.layers[-1].visible = False # turn off the basemap\nm.add_raster(image, layer_name=\"Image\")\nm\nm.layers[-1].visible = False # turn off the basemap m.add_raster(image, layer_name=\"Image\") m In\u00a0[\u00a0]: Copied!
sam = SamGeo(\n model_type=\"vit_h\",\n checkpoint=\"sam_vit_h_4b8939.pth\",\n sam_kwargs=None,\n)\nsam = SamGeo( model_type=\"vit_h\", checkpoint=\"sam_vit_h_4b8939.pth\", sam_kwargs=None, ) In\u00a0[\u00a0]: Copied!
mask = \"segment.tif\"\nsam.generate(\n image, mask, batch=True, foreground=True, erosion_kernel=(3, 3), mask_multiplier=255\n)\nmask = \"segment.tif\" sam.generate( image, mask, batch=True, foreground=True, erosion_kernel=(3, 3), mask_multiplier=255 ) In\u00a0[\u00a0]: Copied!
vector = \"segment.gpkg\"\nsam.tiff_to_gpkg(mask, vector, simplify_tolerance=None)\nvector = \"segment.gpkg\" sam.tiff_to_gpkg(mask, vector, simplify_tolerance=None)
You can also save the segmentation results as any vector data format supported by GeoPandas.
In\u00a0[\u00a0]: Copied!shapefile = \"segment.shp\"\nsam.tiff_to_vector(mask, shapefile)\nshapefile = \"segment.shp\" sam.tiff_to_vector(mask, shapefile) In\u00a0[\u00a0]: Copied!
style = {\n \"color\": \"#3388ff\",\n \"weight\": 2,\n \"fillColor\": \"#7c4185\",\n \"fillOpacity\": 0.5,\n}\nm.add_vector(vector, layer_name=\"Vector\", style=style)\nm\nstyle = { \"color\": \"#3388ff\", \"weight\": 2, \"fillColor\": \"#7c4185\", \"fillOpacity\": 0.5, } m.add_vector(vector, layer_name=\"Vector\", style=style) m "},{"location":"examples/satellite/#segment-anything-model-for-geospatial-data","title":"Segment Anything Model for Geospatial Data\u00b6","text":"
This notebook shows how to segment satellite imagery using the Segment Anything Model (SAM) with a few lines of code.
Make sure you use GPU runtime for this notebook. For Google Colab, go to Runtime
-> Change runtime type
and select GPU
as the hardware accelerator.
Uncomment and run the following cell to install the required dependencies.
"},{"location":"examples/satellite/#import-libraries","title":"Import libraries\u00b6","text":""},{"location":"examples/satellite/#create-an-interactive-map","title":"Create an interactive map\u00b6","text":""},{"location":"examples/satellite/#download-map-tiles","title":"Download map tiles\u00b6","text":"Download map tiles and mosaic them into a single GeoTIFF file
"},{"location":"examples/satellite/#initialize-sam-class","title":"Initialize SAM class\u00b6","text":""},{"location":"examples/satellite/#segment-the-image","title":"Segment the image\u00b6","text":"Set batch=True
to segment the image in batches. This is useful for large images that cannot fit in memory.
Save the segmentation results as a GeoPackage file.
"},{"location":"examples/satellite/#visualize-the-results","title":"Visualize the results\u00b6","text":""},{"location":"examples/swimming_pools/","title":"Swimming pools","text":"In\u00a0[\u00a0]: Copied!# %pip install segment-geospatial groundingdino-py leafmap localtileserver\n# %pip install segment-geospatial groundingdino-py leafmap localtileserver In\u00a0[\u00a0]: Copied!
import leafmap\nfrom samgeo import tms_to_geotiff\nfrom samgeo.text_sam import LangSAM\nimport leafmap from samgeo import tms_to_geotiff from samgeo.text_sam import LangSAM In\u00a0[\u00a0]: Copied!
m = leafmap.Map(center=[34.040984, -118.491668], zoom=19, height=\"600px\")\nm.add_basemap(\"SATELLITE\")\nm\nm = leafmap.Map(center=[34.040984, -118.491668], zoom=19, height=\"600px\") m.add_basemap(\"SATELLITE\") m In\u00a0[\u00a0]: Copied!
bbox = m.user_roi_bounds()\nif bbox is None:\n bbox = [-118.4932, 34.0404, -118.4903, 34.0417]\nbbox = m.user_roi_bounds() if bbox is None: bbox = [-118.4932, 34.0404, -118.4903, 34.0417] In\u00a0[\u00a0]: Copied!
image = \"Image.tif\"\ntms_to_geotiff(output=image, bbox=bbox, zoom=19, source=\"Satellite\", overwrite=True)\nimage = \"Image.tif\" tms_to_geotiff(output=image, bbox=bbox, zoom=19, source=\"Satellite\", overwrite=True)
You can also use your own image. Uncomment and run the following cell to use your own image.
In\u00a0[\u00a0]: Copied!# image = '/path/to/your/own/image.tif'\n# image = '/path/to/your/own/image.tif'
Display the downloaded image on the map.
In\u00a0[\u00a0]: Copied!m.layers[-1].visible = False\nm.add_raster(image, layer_name=\"Image\")\nm\nm.layers[-1].visible = False m.add_raster(image, layer_name=\"Image\") m In\u00a0[\u00a0]: Copied!
sam = LangSAM()\nsam = LangSAM() In\u00a0[\u00a0]: Copied!
text_prompt = \"swimming pool\"\ntext_prompt = \"swimming pool\" In\u00a0[\u00a0]: Copied!
sam.predict(image, text_prompt, box_threshold=0.24, text_threshold=0.24)\nsam.predict(image, text_prompt, box_threshold=0.24, text_threshold=0.24) In\u00a0[\u00a0]: Copied!
sam.show_anns(\n cmap='Blues',\n box_color='red',\n title='Automatic Segmentation of Swimming Pools',\n blend=True,\n)\nsam.show_anns( cmap='Blues', box_color='red', title='Automatic Segmentation of Swimming Pools', blend=True, )
Show the result without bounding boxes on the map.
In\u00a0[\u00a0]: Copied!sam.show_anns(\n cmap='Blues',\n add_boxes=False,\n alpha=0.5,\n title='Automatic Segmentation of Swimming Pools',\n)\nsam.show_anns( cmap='Blues', add_boxes=False, alpha=0.5, title='Automatic Segmentation of Swimming Pools', )
Show the result as a grayscale image.
In\u00a0[\u00a0]: Copied!sam.show_anns(\n cmap='Greys_r',\n add_boxes=False,\n alpha=1,\n title='Automatic Segmentation of Swimming Pools',\n blend=False,\n output='pools.tif',\n)\nsam.show_anns( cmap='Greys_r', add_boxes=False, alpha=1, title='Automatic Segmentation of Swimming Pools', blend=False, output='pools.tif', )
Convert the result to a vector format.
In\u00a0[\u00a0]: Copied!sam.raster_to_vector(\"pools.tif\", \"pools.shp\")\nsam.raster_to_vector(\"pools.tif\", \"pools.shp\")
Show the results on the interactive map.
In\u00a0[\u00a0]: Copied!m.add_raster(\"pools.tif\", layer_name=\"Pools\", palette=\"Blues\", opacity=0.5, nodata=0)\nstyle = {\n \"color\": \"#3388ff\",\n \"weight\": 2,\n \"fillColor\": \"#7c4185\",\n \"fillOpacity\": 0.5,\n}\nm.add_vector(\"pools.shp\", layer_name=\"Vector\", style=style)\nm\nm.add_raster(\"pools.tif\", layer_name=\"Pools\", palette=\"Blues\", opacity=0.5, nodata=0) style = { \"color\": \"#3388ff\", \"weight\": 2, \"fillColor\": \"#7c4185\", \"fillOpacity\": 0.5, } m.add_vector(\"pools.shp\", layer_name=\"Vector\", style=style) m In\u00a0[\u00a0]: Copied!
sam.show_map()\nsam.show_map() "},{"location":"examples/swimming_pools/#mapping-swimming-pools-with-text-prompts","title":"Mapping swimming pools with text prompts\u00b6","text":"
This notebook shows how to map swimming pools with text prompts and the Segment Anything Model (SAM).
Make sure you use GPU runtime for this notebook. For Google Colab, go to Runtime
-> Change runtime type
and select GPU
as the hardware accelerator.
Uncomment and run the following cell to install the required dependencies.
"},{"location":"examples/swimming_pools/#create-an-interactive-map","title":"Create an interactive map\u00b6","text":""},{"location":"examples/swimming_pools/#download-a-sample-image","title":"Download a sample image\u00b6","text":"Pan and zoom the map to select the area of interest. Use the draw tools to draw a polygon or rectangle on the map
"},{"location":"examples/swimming_pools/#initialize-langsam-class","title":"Initialize LangSAM class\u00b6","text":"The initialization of the LangSAM class might take a few minutes. The initialization downloads the model weights and sets up the model for inference.
"},{"location":"examples/swimming_pools/#specify-text-prompts","title":"Specify text prompts\u00b6","text":""},{"location":"examples/swimming_pools/#segment-the-image","title":"Segment the image\u00b6","text":"Part of the model prediction includes setting appropriate thresholds for object detection and text association with the detected objects. These threshold values range from 0 to 1 and are set while calling the predict method of the LangSAM class.
box_threshold
: This value is used for object detection in the image. A higher value makes the model more selective, identifying only the most confident object instances, leading to fewer overall detections. A lower value, conversely, makes the model more tolerant, leading to increased detections, including potentially less confident ones.
text_threshold
: This value is used to associate the detected objects with the provided text prompt. A higher value requires a stronger association between the object and the text prompt, leading to more precise but potentially fewer associations. A lower value allows for looser associations, which could increase the number of associations but also introduce less precise matches.
Remember to test different threshold values on your specific data. The optimal threshold can vary depending on the quality and nature of your images, as well as the specificity of your text prompts. Make sure to choose a balance that suits your requirements, whether that's precision or recall.
"},{"location":"examples/swimming_pools/#visualize-the-results","title":"Visualize the results\u00b6","text":"Show the result with bounding boxes on the map.
"},{"location":"examples/swimming_pools/#interactive-segmentation","title":"Interactive segmentation\u00b6","text":""},{"location":"examples/text_prompts/","title":"Text prompts","text":"In\u00a0[\u00a0]: Copied!# %pip install segment-geospatial groundingdino-py leafmap localtileserver\n# %pip install segment-geospatial groundingdino-py leafmap localtileserver In\u00a0[\u00a0]: Copied!
import leafmap\nfrom samgeo import tms_to_geotiff\nfrom samgeo.text_sam import LangSAM\nimport leafmap from samgeo import tms_to_geotiff from samgeo.text_sam import LangSAM In\u00a0[\u00a0]: Copied!
m = leafmap.Map(center=[-22.17615, -51.253043], zoom=18, height=\"800px\")\nm.add_basemap(\"SATELLITE\")\nm\nm = leafmap.Map(center=[-22.17615, -51.253043], zoom=18, height=\"800px\") m.add_basemap(\"SATELLITE\") m In\u00a0[\u00a0]: Copied!
bbox = m.user_roi_bounds()\nif bbox is None:\n bbox = [-51.2565, -22.1777, -51.2512, -22.175]\nbbox = m.user_roi_bounds() if bbox is None: bbox = [-51.2565, -22.1777, -51.2512, -22.175] In\u00a0[\u00a0]: Copied!
image = \"Image.tif\"\ntms_to_geotiff(output=image, bbox=bbox, zoom=19, source=\"Satellite\", overwrite=True)\nimage = \"Image.tif\" tms_to_geotiff(output=image, bbox=bbox, zoom=19, source=\"Satellite\", overwrite=True)
You can also use your own image. Uncomment and run the following cell to use your own image.
In\u00a0[\u00a0]: Copied!# image = '/path/to/your/own/image.tif'\n# image = '/path/to/your/own/image.tif'
Display the downloaded image on the map.
In\u00a0[\u00a0]: Copied!m.layers[-1].visible = False\nm.add_raster(image, layer_name=\"Image\")\nm\nm.layers[-1].visible = False m.add_raster(image, layer_name=\"Image\") m In\u00a0[\u00a0]: Copied!
sam = LangSAM()\nsam = LangSAM() In\u00a0[\u00a0]: Copied!
text_prompt = \"tree\"\ntext_prompt = \"tree\" In\u00a0[\u00a0]: Copied!
sam.predict(image, text_prompt, box_threshold=0.24, text_threshold=0.24)\nsam.predict(image, text_prompt, box_threshold=0.24, text_threshold=0.24) In\u00a0[\u00a0]: Copied!
sam.show_anns(\n cmap='Greens',\n box_color='red',\n title='Automatic Segmentation of Trees',\n blend=True,\n)\nsam.show_anns( cmap='Greens', box_color='red', title='Automatic Segmentation of Trees', blend=True, )
Show the result without bounding boxes on the map.
In\u00a0[\u00a0]: Copied!sam.show_anns(\n cmap='Greens',\n add_boxes=False,\n alpha=0.5,\n title='Automatic Segmentation of Trees',\n)\nsam.show_anns( cmap='Greens', add_boxes=False, alpha=0.5, title='Automatic Segmentation of Trees', )
Show the result as a grayscale image.
In\u00a0[\u00a0]: Copied!sam.show_anns(\n cmap='Greys_r',\n add_boxes=False,\n alpha=1,\n title='Automatic Segmentation of Trees',\n blend=False,\n output='trees.tif',\n)\nsam.show_anns( cmap='Greys_r', add_boxes=False, alpha=1, title='Automatic Segmentation of Trees', blend=False, output='trees.tif', )
Convert the result to a vector format.
In\u00a0[\u00a0]: Copied!sam.raster_to_vector(\"trees.tif\", \"trees.shp\")\nsam.raster_to_vector(\"trees.tif\", \"trees.shp\")
Show the results on the interactive map.
In\u00a0[\u00a0]: Copied!m.add_raster(\"trees.tif\", layer_name=\"Trees\", palette=\"Greens\", opacity=0.5, nodata=0)\nstyle = {\n \"color\": \"#3388ff\",\n \"weight\": 2,\n \"fillColor\": \"#7c4185\",\n \"fillOpacity\": 0.5,\n}\nm.add_vector(\"trees.shp\", layer_name=\"Vector\", style=style)\nm\nm.add_raster(\"trees.tif\", layer_name=\"Trees\", palette=\"Greens\", opacity=0.5, nodata=0) style = { \"color\": \"#3388ff\", \"weight\": 2, \"fillColor\": \"#7c4185\", \"fillOpacity\": 0.5, } m.add_vector(\"trees.shp\", layer_name=\"Vector\", style=style) m In\u00a0[\u00a0]: Copied!
sam.show_map()\nsam.show_map() "},{"location":"examples/text_prompts/#segmenting-remote-sensing-imagery-with-text-prompts-and-the-segment-anything-model-sam","title":"Segmenting remote sensing imagery with text prompts and the Segment Anything Model (SAM)\u00b6","text":"
This notebook shows how to generate object masks from text prompts with the Segment Anything Model (SAM).
Make sure you use GPU runtime for this notebook. For Google Colab, go to Runtime
-> Change runtime type
and select GPU
as the hardware accelerator.
Uncomment and run the following cell to install the required dependencies.
"},{"location":"examples/text_prompts/#create-an-interactive-map","title":"Create an interactive map\u00b6","text":""},{"location":"examples/text_prompts/#download-a-sample-image","title":"Download a sample image\u00b6","text":"Pan and zoom the map to select the area of interest. Use the draw tools to draw a polygon or rectangle on the map
"},{"location":"examples/text_prompts/#initialize-langsam-class","title":"Initialize LangSAM class\u00b6","text":"The initialization of the LangSAM class might take a few minutes. The initialization downloads the model weights and sets up the model for inference.
"},{"location":"examples/text_prompts/#specify-text-prompts","title":"Specify text prompts\u00b6","text":""},{"location":"examples/text_prompts/#segment-the-image","title":"Segment the image\u00b6","text":"Part of the model prediction includes setting appropriate thresholds for object detection and text association with the detected objects. These threshold values range from 0 to 1 and are set while calling the predict method of the LangSAM class.
box_threshold
: This value is used for object detection in the image. A higher value makes the model more selective, identifying only the most confident object instances, leading to fewer overall detections. A lower value, conversely, makes the model more tolerant, leading to increased detections, including potentially less confident ones.
text_threshold
: This value is used to associate the detected objects with the provided text prompt. A higher value requires a stronger association between the object and the text prompt, leading to more precise but potentially fewer associations. A lower value allows for looser associations, which could increase the number of associations but also introduce less precise matches.
Remember to test different threshold values on your specific data. The optimal threshold can vary depending on the quality and nature of your images, as well as the specificity of your text prompts. Make sure to choose a balance that suits your requirements, whether that's precision or recall.
"},{"location":"examples/text_prompts/#visualize-the-results","title":"Visualize the results\u00b6","text":"Show the result with bounding boxes on the map.
"},{"location":"examples/text_prompts/#interactive-segmentation","title":"Interactive segmentation\u00b6","text":""},{"location":"examples/text_prompts_batch/","title":"Text prompts batch","text":"In\u00a0[\u00a0]: Copied!# %pip install segment-geospatial groundingdino-py leafmap localtileserver\n# %pip install segment-geospatial groundingdino-py leafmap localtileserver In\u00a0[\u00a0]: Copied!
import leafmap\nfrom samgeo import tms_to_geotiff, split_raster\nfrom samgeo.text_sam import LangSAM\nimport leafmap from samgeo import tms_to_geotiff, split_raster from samgeo.text_sam import LangSAM In\u00a0[\u00a0]: Copied!
m = leafmap.Map(center=[-22.1278, -51.4430], zoom=17, height=\"800px\")\nm.add_basemap(\"SATELLITE\")\nm\nm = leafmap.Map(center=[-22.1278, -51.4430], zoom=17, height=\"800px\") m.add_basemap(\"SATELLITE\") m In\u00a0[\u00a0]: Copied!
bbox = m.user_roi_bounds()\nif bbox is None:\n bbox = [-51.4494, -22.1307, -51.4371, -22.1244]\nbbox = m.user_roi_bounds() if bbox is None: bbox = [-51.4494, -22.1307, -51.4371, -22.1244] In\u00a0[\u00a0]: Copied!
image = \"Image.tif\"\ntms_to_geotiff(output=image, bbox=bbox, zoom=19, source=\"Satellite\", overwrite=True)\nimage = \"Image.tif\" tms_to_geotiff(output=image, bbox=bbox, zoom=19, source=\"Satellite\", overwrite=True)
You can also use your own image. Uncomment and run the following cell to use your own image.
In\u00a0[\u00a0]: Copied!# image = '/path/to/your/own/image.tif'\n# image = '/path/to/your/own/image.tif'
Display the downloaded image on the map.
In\u00a0[\u00a0]: Copied!m.layers[-1].visible = False\nm.add_raster(image, layer_name=\"Image\")\nm\nm.layers[-1].visible = False m.add_raster(image, layer_name=\"Image\") m In\u00a0[\u00a0]: Copied!
split_raster(image, out_dir=\"tiles\", tile_size=(1000, 1000), overlap=0)\nsplit_raster(image, out_dir=\"tiles\", tile_size=(1000, 1000), overlap=0) In\u00a0[\u00a0]: Copied!
sam = LangSAM()\nsam = LangSAM() In\u00a0[\u00a0]: Copied!
text_prompt = \"tree\"\ntext_prompt = \"tree\" In\u00a0[\u00a0]: Copied!
sam.predict_batch(\n images='tiles',\n out_dir='masks',\n text_prompt=text_prompt,\n box_threshold=0.24,\n text_threshold=0.24,\n mask_multiplier=255,\n dtype='uint8',\n merge=True,\n verbose=True,\n)\nsam.predict_batch( images='tiles', out_dir='masks', text_prompt=text_prompt, box_threshold=0.24, text_threshold=0.24, mask_multiplier=255, dtype='uint8', merge=True, verbose=True, ) In\u00a0[\u00a0]: Copied!
m.add_raster('masks/merged.tif', cmap='viridis', nodata=0, layer_name='Mask')\nm.add_layer_manager()\nm\nm.add_raster('masks/merged.tif', cmap='viridis', nodata=0, layer_name='Mask') m.add_layer_manager() m "},{"location":"examples/text_prompts_batch/#batch-segmentation-with-text-prompts","title":"Batch segmentation with text prompts\u00b6","text":"
This notebook shows how to generate object masks from text prompts with the Segment Anything Model (SAM).
Make sure you use GPU runtime for this notebook. For Google Colab, go to Runtime
-> Change runtime type
and select GPU
as the hardware accelerator.
Uncomment and run the following cell to install the required dependencies.
"},{"location":"examples/text_prompts_batch/#create-an-interactive-map","title":"Create an interactive map\u00b6","text":""},{"location":"examples/text_prompts_batch/#download-a-sample-image","title":"Download a sample image\u00b6","text":"Pan and zoom the map to select the area of interest. Use the draw tools to draw a polygon or rectangle on the map
"},{"location":"examples/text_prompts_batch/#split-the-image-into-tiles","title":"Split the image into tiles\u00b6","text":""},{"location":"examples/text_prompts_batch/#initialize-langsam-class","title":"Initialize LangSAM class\u00b6","text":"The initialization of the LangSAM class might take a few minutes. The initialization downloads the model weights and sets up the model for inference.
"},{"location":"examples/text_prompts_batch/#specify-text-prompts","title":"Specify text prompts\u00b6","text":""},{"location":"examples/text_prompts_batch/#segment-images","title":"Segment images\u00b6","text":"Part of the model prediction includes setting appropriate thresholds for object detection and text association with the detected objects. These threshold values range from 0 to 1 and are set while calling the predict method of the LangSAM class.
box_threshold
: This value is used for object detection in the image. A higher value makes the model more selective, identifying only the most confident object instances, leading to fewer overall detections. A lower value, conversely, makes the model more tolerant, leading to increased detections, including potentially less confident ones.
text_threshold
: This value is used to associate the detected objects with the provided text prompt. A higher value requires a stronger association between the object and the text prompt, leading to more precise but potentially fewer associations. A lower value allows for looser associations, which could increase the number of associations but also introduce less precise matches.
Remember to test different threshold values on your specific data. The optimal threshold can vary depending on the quality and nature of your images, as well as the specificity of your text prompts. Make sure to choose a balance that suits your requirements, whether that's precision or recall.
"},{"location":"examples/text_prompts_batch/#visualize-the-results","title":"Visualize the results\u00b6","text":""}]} \ No newline at end of file diff --git a/sitemap.xml b/sitemap.xml new file mode 100644 index 00000000..ea9392c4 --- /dev/null +++ b/sitemap.xml @@ -0,0 +1,108 @@ + +The LangSAM model for segmenting objects from satellite images using text prompts. +The source code is adapted from the https://github.com/luca-medeiros/lang-segment-anything repository. +Credits to Luca Medeiros for the original implementation.
+ + + +
+LangSAM
+
+
+
+¶A Language-based Segment-Anything Model (LangSAM) class which combines GroundingDINO and SAM.
+ +samgeo/text_sam.py
class LangSAM:
    """
    A Language-based Segment-Anything Model (LangSAM) class which combines GroundingDINO and SAM.

    GroundingDINO turns a text prompt into bounding boxes and SAM turns those
    boxes into masks. The results of the most recent ``predict()`` call are
    stored on the instance (``masks``, ``boxes``, ``phrases``, ``logits`` and
    ``prediction``) so that they can be plotted or exported afterwards.
    """

    def __init__(self, model_type="vit_h"):
        """Initialize the LangSAM instance.

        Selects the compute device, builds the GroundingDINO and SAM models and
        resets the attributes that hold prediction results.

        Args:
            model_type (str, optional): The model type. It can be one of the following: vit_h, vit_l, vit_b.
                Defaults to 'vit_h'. See https://bit.ly/3VrpxUh for more details.
        """

        # Use the GPU when CUDA is available, otherwise fall back to the CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.build_groundingdino()
        self.build_sam(model_type)

        # Nothing has been predicted yet.
        self.source = None
        self.image = None
        self.masks = None
        self.boxes = None
        self.phrases = None
        self.logits = None
        self.prediction = None

    def build_sam(self, model_type):
        """Build the SAM model and wrap it in a ``SamPredictor``.

        Args:
            model_type (str): The model type. It can be one of the following: vit_h, vit_l, vit_b.
                It is used as the key into both ``SAM_MODELS`` (checkpoint URL)
                and ``sam_model_registry`` (model constructor).
                See https://bit.ly/3VrpxUh for more details.
        """
        checkpoint_url = SAM_MODELS[model_type]
        sam = sam_model_registry[model_type]()
        state_dict = torch.hub.load_state_dict_from_url(checkpoint_url)
        # strict=True: the checkpoint must match the architecture exactly.
        sam.load_state_dict(state_dict, strict=True)
        sam.to(device=self.device)
        self.sam = SamPredictor(sam)

    def build_groundingdino(self):
        """Build the GroundingDINO model and store it in ``self.groundingdino``."""
        ckpt_repo_id = "ShilongLiu/GroundingDINO"
        ckpt_filename = "groundingdino_swinb_cogcoor.pth"
        ckpt_config_filename = "GroundingDINO_SwinB.cfg.py"
        self.groundingdino = load_model_hf(
            ckpt_repo_id, ckpt_filename, ckpt_config_filename, self.device
        )

    def predict_dino(self, image, text_prompt, box_threshold, text_threshold):
        """
        Run the GroundingDINO model prediction.

        Args:
            image (Image): Input PIL Image.
            text_prompt (str): Text prompt for the model.
            box_threshold (float): Box threshold for the prediction.
            text_threshold (float): Text threshold for the prediction.

        Returns:
            tuple: Tuple containing boxes, logits, and phrases. The boxes are
                converted from (cx, cy, w, h) to (x0, y0, x1, y1) and scaled
                by the image width and height.
        """

        image_trans = transform_image(image)
        boxes, logits, phrases = predict(
            model=self.groundingdino,
            image=image_trans,
            caption=text_prompt,
            box_threshold=box_threshold,
            text_threshold=text_threshold,
            device=self.device,
        )
        W, H = image.size
        boxes = box_ops.box_cxcywh_to_xyxy(boxes) * torch.Tensor([W, H, W, H])

        return boxes, logits, phrases

    def predict_sam(self, image, boxes):
        """
        Run the SAM model prediction.

        Args:
            image (Image): Input PIL Image.
            boxes (torch.Tensor): Tensor of bounding boxes.

        Returns:
            Masks tensor, moved to the CPU. One mask is requested per box
            (``multimask_output=False``).
        """
        image_array = np.asarray(image)
        self.sam.set_image(image_array)
        transformed_boxes = self.sam.transform.apply_boxes_torch(
            boxes, image_array.shape[:2]
        )
        masks, _, _ = self.sam.predict_torch(
            point_coords=None,
            point_labels=None,
            boxes=transformed_boxes.to(self.sam.device),
            multimask_output=False,
        )
        return masks.cpu()

    def set_image(self, image):
        """Set the input image.

        Only ``self.source`` is updated: it becomes the local path for a string
        input and None for any other input.

        Args:
            image (str): The path to the image file or a HTTP URL.

        Raises:
            ValueError: If the (downloaded) path does not exist.
        """

        if isinstance(image, str):
            if image.startswith("http"):
                image = download_file(image)

            if not os.path.exists(image):
                raise ValueError(f"Input path {image} does not exist.")

            self.source = image
        else:
            self.source = None

    def predict(
        self,
        image,
        text_prompt,
        box_threshold,
        text_threshold,
        output=None,
        mask_multiplier=255,
        dtype=np.uint8,
        save_args={},
        return_results=False,
        return_coords=False,
        **kwargs,
    ):
        """
        Run both GroundingDINO and SAM model prediction.

        Parameters:
            image (str | Image): Path or HTTP URL of a raster readable by
                rasterio, or an already loaded image (e.g. a PIL Image).
            text_prompt (str): Text prompt for the model.
            box_threshold (float): Box threshold for the prediction.
            text_threshold (float): Text threshold for the prediction.
            output (str, optional): Output path for the prediction. Defaults to None.
            mask_multiplier (int, optional): Mask multiplier for the prediction. Defaults to 255.
            dtype (np.dtype, optional): Data type for the prediction. Defaults to np.uint8.
            save_args (dict, optional): Save arguments for the prediction. Defaults to {}.
                The dict is only unpacked, never modified.
            return_results (bool, optional): Whether to return the results. Defaults to False.
            return_coords (bool, optional): Whether to return the (x, y) pixel
                coordinates of the first corner of every box. Only consulted
                when return_results is False. Defaults to False.
            **kwargs: Accepted for compatibility; not used by this method.

        Returns:
            tuple: Tuple containing masks, boxes, phrases, and logits when
                return_results is True; a list of (x, y) tuples when
                return_coords is True; otherwise None. None is also returned
                when no object is found.

        Raises:
            ValueError: If image is a path that does not exist.
        """

        if isinstance(image, str):
            if image.startswith("http"):
                image = download_file(image)

            if not os.path.exists(image):
                raise ValueError(f"Input path {image} does not exist.")

            self.source = image

            # Load the georeferenced image
            with rasterio.open(image) as src:
                image_np = src.read().transpose(
                    (1, 2, 0)
                )  # Convert rasterio image to numpy array
                self.transform = src.transform  # Save georeferencing information
                self.crs = src.crs  # Save the Coordinate Reference System
                image_pil = Image.fromarray(
                    image_np[:, :, :3]
                )  # Convert numpy array to PIL image, excluding the alpha channel
        else:
            image_pil = image
            image_np = np.array(image_pil)

        self.image = image_pil

        boxes, logits, phrases = self.predict_dino(
            image_pil, text_prompt, box_threshold, text_threshold
        )
        masks = torch.tensor([])
        if len(boxes) > 0:
            masks = self.predict_sam(image_pil, boxes)
            masks = masks.squeeze(1)

        if boxes.nelement() == 0:  # No "object" instances found
            print("No objects found in the image.")
            # Store the empty results so that show_anns()/save_boxes() do not
            # silently reuse the masks and boxes of a previous call.
            self.masks = masks
            self.boxes = boxes
            self.phrases = phrases
            self.logits = logits
            self.prediction = np.zeros((0,), dtype=dtype)
            return

        # Union of all object masks, one value per pixel. A boolean OR is used
        # instead of summing per-object ids into a `dtype` array: with uint8
        # the 256th id, or overlapping ids adding up to 256, wrap around to 0
        # and the covered pixel would be lost.
        covered = np.zeros_like(image_np[..., 0], dtype=bool)
        for mask in masks:
            # Convert tensor to numpy array if necessary
            if isinstance(mask, torch.Tensor):
                mask = mask.cpu().numpy().astype(dtype)
            covered |= np.asarray(mask) > 0

        # Binary mask holding either 0 or mask_multiplier
        mask_overlay = covered * mask_multiplier

        if output is not None:
            array_to_image(mask_overlay, output, self.source, dtype=dtype, **save_args)

        self.masks = masks
        self.boxes = boxes
        self.phrases = phrases
        self.logits = logits
        self.prediction = mask_overlay

        if return_results:
            return masks, boxes, phrases, logits

        if return_coords:
            boxlist = []
            for box in self.boxes:
                box = box.cpu().numpy()
                boxlist.append((box[0], box[1]))
            return boxlist

    def predict_batch(
        self,
        images,
        out_dir,
        text_prompt,
        box_threshold,
        text_threshold,
        mask_multiplier=255,
        dtype=np.uint8,
        save_args={},
        merge=True,
        verbose=True,
        **kwargs,
    ):
        """
        Run both GroundingDINO and SAM model prediction for a batch of images.

        Parameters:
            images (list | str): List of image paths, or a directory whose
                ``*.tif`` files are processed in sorted order.
            out_dir (str): Output directory for the prediction. Created if missing.
            text_prompt (str): Text prompt for the model.
            box_threshold (float): Box threshold for the prediction.
            text_threshold (float): Text threshold for the prediction.
            mask_multiplier (int, optional): Mask multiplier for the prediction. Defaults to 255.
            dtype (np.dtype, optional): Data type for the prediction. Defaults to np.uint8.
            save_args (dict, optional): Save arguments for the prediction. Defaults to {}.
            merge (bool, optional): Whether to merge the predictions into a single GeoTIFF file. Defaults to True.
            verbose (bool, optional): Whether to print progress messages. Defaults to True.
            **kwargs: Additional keyword arguments forwarded to predict().

        Raises:
            ValueError: If images is neither a list nor a directory path.
        """

        import glob

        if not os.path.exists(out_dir):
            os.makedirs(out_dir)

        if isinstance(images, str):
            images = list(glob.glob(os.path.join(images, "*.tif")))
            images.sort()

        if not isinstance(images, list):
            raise ValueError("images must be a list or a directory to GeoTIFF files.")

        for i, image in enumerate(images):
            basename = os.path.splitext(os.path.basename(image))[0]
            if verbose:
                print(
                    f"Processing image {str(i+1).zfill(len(str(len(images))))} of {len(images)}: {image}..."
                )
            # Each input gets its own "<name>_mask.tif" in out_dir.
            output = os.path.join(out_dir, f"{basename}_mask.tif")
            self.predict(
                image,
                text_prompt,
                box_threshold,
                text_threshold,
                output=output,
                mask_multiplier=mask_multiplier,
                dtype=dtype,
                save_args=save_args,
                **kwargs,
            )

        if merge:
            output = os.path.join(out_dir, "merged.tif")
            merge_rasters(out_dir, output)
            if verbose:
                print(f"Saved the merged prediction to {output}.")

    def save_boxes(self, output=None, dst_crs="EPSG:4326", **kwargs):
        """Save the bounding boxes to a vector file.

        Args:
            output (str): The path to the output vector file. When None, the
                result of boxes_to_vector() is returned instead.
            dst_crs (str, optional): The destination CRS. Defaults to "EPSG:4326".
            **kwargs: Additional arguments for rowcol_to_xy().
        """

        if self.boxes is None:
            print("Please run predict() first.")
            return
        if len(self.boxes) == 0:
            # predict() ran but detected nothing, so there is nothing to export.
            print("No objects found in the image.")
            return

        boxes = self.boxes.tolist()
        coords = rowcol_to_xy(self.source, boxes=boxes, dst_crs=dst_crs, **kwargs)
        if output is None:
            return boxes_to_vector(coords, self.crs, dst_crs, output)
        else:
            boxes_to_vector(coords, self.crs, dst_crs, output)

    def show_anns(
        self,
        figsize=(12, 10),
        axis="off",
        cmap="viridis",
        alpha=0.4,
        add_boxes=True,
        box_color="r",
        box_linewidth=1,
        title=None,
        output=None,
        blend=True,
        **kwargs,
    ):
        """Show the annotations (the predicted mask drawn with a colormap) on the input image.

        Args:
            figsize (tuple, optional): The figure size. Defaults to (12, 10).
            axis (str, optional): Whether to show the axis. Defaults to "off".
            cmap (str, optional): The colormap for the annotations. Defaults to "viridis".
            alpha (float, optional): The alpha value for the annotations. Defaults to 0.4.
            add_boxes (bool, optional): Whether to show the bounding boxes. Defaults to True.
            box_color (str, optional): The color for the bounding boxes. Defaults to "r".
            box_linewidth (int, optional): The line width for the bounding boxes. Defaults to 1.
            title (str, optional): The title for the image. Defaults to None.
            output (str, optional): The path to the output image. Defaults to None.
            blend (bool, optional): When saving, whether to save the rendered
                figure (True) or only the raw prediction array via
                array_to_image() (False). Defaults to True.
            kwargs (dict, optional): Additional arguments for matplotlib.pyplot.savefig().
        """

        import warnings

        # NOTE: this silences warnings for the whole process, not only for
        # this call; kept because callers may rely on the quiet output.
        warnings.filterwarnings("ignore")

        anns = self.prediction

        if anns is None:
            print("Please run predict() first.")
            return
        elif len(anns) == 0:
            print("No objects found in the image.")
            return

        # Imported only once there is something to draw.
        import matplotlib.pyplot as plt
        import matplotlib.patches as patches

        plt.figure(figsize=figsize)
        plt.imshow(self.image)

        if add_boxes:
            for box in self.boxes:
                # Draw bounding box
                box = box.cpu().numpy()  # Convert the tensor to a numpy array
                rect = patches.Rectangle(
                    (box[0], box[1]),
                    box[2] - box[0],
                    box[3] - box[1],
                    linewidth=box_linewidth,
                    edgecolor=box_color,
                    facecolor="none",
                )
                plt.gca().add_patch(rect)

        # Defaults for savefig() unless the caller supplied their own.
        if "dpi" not in kwargs:
            kwargs["dpi"] = 100

        if "bbox_inches" not in kwargs:
            kwargs["bbox_inches"] = "tight"

        plt.imshow(anns, cmap=cmap, alpha=alpha)

        if title is not None:
            plt.title(title)
        plt.axis(axis)

        if output is not None:
            if blend:
                plt.savefig(output, **kwargs)
            else:
                array_to_image(self.prediction, output, self.source)

    def raster_to_vector(self, image, output, simplify_tolerance=None, **kwargs):
        """Save the result to a vector file.

        Args:
            image (str): The path to the image file.
            output (str): The path to the vector file.
            simplify_tolerance (float, optional): The maximum allowed geometry displacement.
                The higher this value, the smaller the number of vertices in the resulting geometry.
            **kwargs: Additional arguments forwarded to the module-level raster_to_vector().
        """

        raster_to_vector(image, output, simplify_tolerance=simplify_tolerance, **kwargs)

    def show_map(self, basemap="SATELLITE", out_dir=None, **kwargs):
        """Show the interactive map.

        Args:
            basemap (str, optional): The basemap. It can be one of the following: SATELLITE, ROADMAP, TERRAIN, HYBRID.
            out_dir (str, optional): The path to the output directory. Defaults to None.
            **kwargs: Additional arguments forwarded to text_sam_gui().

        Returns:
            leafmap.Map: The map object.
        """
        return text_sam_gui(self, basemap=basemap, out_dir=out_dir, **kwargs)
+
__init__(self, model_type='vit_h')
+
+
+ special
+
+
+¶Initialize the LangSAM instance.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
model_type |
+ str |
+ The model type. It can be one of the following: vit_h, vit_l, vit_b. +Defaults to 'vit_h'. See https://bit.ly/3VrpxUh for more details. |
+ 'vit_h' |
+
samgeo/text_sam.py
def __init__(self, model_type="vit_h"):
+    """Initialize the LangSAM instance.
+
+    Selects the compute device, builds the GroundingDINO and SAM models and
+    sets the attributes that hold prediction results to None.
+
+    Args:
+        model_type (str, optional): The model type. It can be one of the following: vit_h, vit_l, vit_b.
+            Defaults to 'vit_h'. It is passed on to build_sam().
+            See https://bit.ly/3VrpxUh for more details.
+    """
+
+    # Use the GPU when CUDA is available, otherwise fall back to the CPU.
+    self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    self.build_groundingdino()
+    self.build_sam(model_type)
+
+    # Nothing has been predicted yet, so every result attribute starts as None.
+    self.source = None
+    self.image = None
+    self.masks = None
+    self.boxes = None
+    self.phrases = None
+    self.logits = None
+    self.prediction = None
+
build_groundingdino(self)
+
+
+¶Build the GroundingDINO model.
+ +samgeo/text_sam.py
def build_groundingdino(self):
+    """Build the GroundingDINO model.
+
+    Calls load_model_hf() with the "ShilongLiu/GroundingDINO" repository id,
+    the Swin-B checkpoint and config file names below and self.device, and
+    stores the returned model in self.groundingdino.
+    """
+    ckpt_repo_id = "ShilongLiu/GroundingDINO"
+    ckpt_filename = "groundingdino_swinb_cogcoor.pth"
+    ckpt_config_filename = "GroundingDINO_SwinB.cfg.py"
+    self.groundingdino = load_model_hf(
+        ckpt_repo_id, ckpt_filename, ckpt_config_filename, self.device
+    )
+
build_sam(self, model_type)
+
+
+¶Build the SAM model.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
model_type |
+ str |
+ The model type. It can be one of the following: vit_h, vit_l, vit_b. +Defaults to 'vit_h'. See https://bit.ly/3VrpxUh for more details. |
+ required | +
samgeo/text_sam.py
def build_sam(self, model_type):
    """Build the SAM model and wrap it in a predictor stored on ``self.sam``.

    Args:
        model_type (str): The model type. It can be one of the following: vit_h, vit_l, vit_b.
            See https://bit.ly/3VrpxUh for more details.
    """
    # Resolve the checkpoint URL first so an unknown model type fails before
    # any model is constructed.
    weights_url = SAM_MODELS[model_type]
    network = sam_model_registry[model_type]()
    weights = torch.hub.load_state_dict_from_url(weights_url)
    network.load_state_dict(weights, strict=True)
    network.to(device=self.device)
    self.sam = SamPredictor(network)
+
predict(self, image, text_prompt, box_threshold, text_threshold, output=None, mask_multiplier=255, dtype=<class 'numpy.uint8'>, save_args={}, return_results=False, return_coords=False, **kwargs)
+
+
+¶Run both GroundingDINO and SAM model prediction.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
image |
+ Image |
+ Input PIL Image. |
+ required | +
text_prompt |
+ str |
+ Text prompt for the model. |
+ required | +
box_threshold |
+ float |
+ Box threshold for the prediction. |
+ required | +
text_threshold |
+ float |
+ Text threshold for the prediction. |
+ required | +
output |
+ str |
+ Output path for the prediction. Defaults to None. |
+ None |
+
mask_multiplier |
+ int |
+ Mask multiplier for the prediction. Defaults to 255. |
+ 255 |
+
dtype |
+ np.dtype |
+ Data type for the prediction. Defaults to np.uint8. |
+ <class 'numpy.uint8'> |
+
save_args |
+ dict |
+ Save arguments for the prediction. Defaults to {}. |
+ {} |
+
return_results |
+ bool |
+ Whether to return the results. Defaults to False. |
+ False |
+
Returns:
+Type | +Description | +
---|---|
tuple |
+ Tuple containing masks, boxes, phrases, and logits. |
+
samgeo/text_sam.py
def predict(
    self,
    image,
    text_prompt,
    box_threshold,
    text_threshold,
    output=None,
    mask_multiplier=255,
    dtype=np.uint8,
    save_args={},
    return_results=False,
    return_coords=False,
    **kwargs,
):
    """
    Run both GroundingDINO and SAM model prediction.

    Parameters:
        image (Image | str): Input PIL Image, or a path / HTTP URL to a raster file
            that is opened with rasterio.
        text_prompt (str): Text prompt for the model.
        box_threshold (float): Box threshold for the prediction.
        text_threshold (float): Text threshold for the prediction.
        output (str, optional): Output path for the prediction. Defaults to None.
        mask_multiplier (int, optional): Value assigned to every masked pixel. Defaults to 255.
        dtype (np.dtype, optional): Data type passed to array_to_image() when saving. Defaults to np.uint8.
        save_args (dict, optional): Extra keyword arguments for array_to_image(). Defaults to {}.
        return_results (bool, optional): Whether to return the results. Defaults to False.
        return_coords (bool, optional): Whether to return the (x, y) of each box's first
            corner. Only consulted when return_results is False. Defaults to False.
        **kwargs: Accepted for call compatibility; not used by this method.

    Returns:
        tuple | list | None: (masks, boxes, phrases, logits) when return_results is True,
        a list of (x, y) pairs when return_coords is True, otherwise None. None is also
        returned, without updating the stored results, when no object is detected.

    Raises:
        ValueError: If image is a path that does not exist.
    """

    if isinstance(image, str):
        if image.startswith("http"):
            image = download_file(image)

        if not os.path.exists(image):
            raise ValueError(f"Input path {image} does not exist.")

        self.source = image

        # Load the georeferenced image as (rows, cols, bands)
        with rasterio.open(image) as src:
            image_np = src.read().transpose((1, 2, 0))
            self.transform = src.transform  # Save georeferencing information
            self.crs = src.crs  # Save the Coordinate Reference System
        # Keep the first three bands only (drops e.g. an alpha channel)
        image_pil = Image.fromarray(image_np[:, :, :3])
    else:
        image_pil = image
        image_np = np.array(image_pil)

    self.image = image_pil

    boxes, logits, phrases = self.predict_dino(
        image_pil, text_prompt, box_threshold, text_threshold
    )

    if boxes.nelement() == 0:  # No "object" instances found
        print("No objects found in the image.")
        return

    masks = self.predict_sam(image_pil, boxes)
    masks = masks.squeeze(1)

    # Union of all instance masks. A boolean accumulator is used on purpose:
    # summing per-instance ids in a narrow integer dtype wraps around (the
    # 256th object becomes 0 in uint8) and silently drops pixels.
    foreground = np.zeros_like(image_np[..., 0], dtype=bool)
    for mask in masks:
        if isinstance(mask, torch.Tensor):
            mask = mask.cpu().numpy()  # .cpu() first in case the mask is on GPU
        foreground |= np.asarray(mask) > 0

    # Binary mask: 0 for background, mask_multiplier for any detected object
    mask_overlay = foreground * mask_multiplier

    if output is not None:
        array_to_image(mask_overlay, output, self.source, dtype=dtype, **save_args)

    self.masks = masks
    self.boxes = boxes
    self.phrases = phrases
    self.logits = logits
    self.prediction = mask_overlay

    if return_results:
        return masks, boxes, phrases, logits

    if return_coords:
        boxlist = []
        for box in self.boxes:
            box = box.cpu().numpy()
            boxlist.append((box[0], box[1]))
        return boxlist
+
predict_batch(self, images, out_dir, text_prompt, box_threshold, text_threshold, mask_multiplier=255, dtype=<class 'numpy.uint8'>, save_args={}, merge=True, verbose=True, **kwargs)
+
+
+¶Run both GroundingDINO and SAM model prediction for a batch of images.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
images |
+ list |
+ List of paths to the input GeoTIFF images, or a directory containing the *.tif files to process. |
+ required | +
out_dir |
+ str |
+ Output directory for the prediction. |
+ required | +
text_prompt |
+ str |
+ Text prompt for the model. |
+ required | +
box_threshold |
+ float |
+ Box threshold for the prediction. |
+ required | +
text_threshold |
+ float |
+ Text threshold for the prediction. |
+ required | +
mask_multiplier |
+ int |
+ Mask multiplier for the prediction. Defaults to 255. |
+ 255 |
+
dtype |
+ np.dtype |
+ Data type for the prediction. Defaults to np.uint8. |
+ <class 'numpy.uint8'> |
+
save_args |
+ dict |
+ Save arguments for the prediction. Defaults to {}. |
+ {} |
+
merge |
+ bool |
+ Whether to merge the predictions into a single GeoTIFF file. Defaults to True. |
+ True |
+
samgeo/text_sam.py
def predict_batch(
    self,
    images,
    out_dir,
    text_prompt,
    box_threshold,
    text_threshold,
    mask_multiplier=255,
    dtype=np.uint8,
    save_args={},
    merge=True,
    verbose=True,
    **kwargs,
):
    """
    Run both GroundingDINO and SAM model prediction for a batch of images.

    Each input is passed to predict() and its mask is written to
    ``<out_dir>/<basename>_mask.tif``.

    Parameters:
        images (list | str): List of image file paths, or a directory whose ``*.tif``
            files are processed in sorted order.
        out_dir (str): Output directory for the prediction. Created if missing.
        text_prompt (str): Text prompt for the model.
        box_threshold (float): Box threshold for the prediction.
        text_threshold (float): Text threshold for the prediction.
        mask_multiplier (int, optional): Mask multiplier for the prediction. Defaults to 255.
        dtype (np.dtype, optional): Data type for the prediction. Defaults to np.uint8.
        save_args (dict, optional): Save arguments for the prediction. Defaults to {}.
        merge (bool, optional): Whether to merge the predictions into a single GeoTIFF file
            (``<out_dir>/merged.tif``). Defaults to True.
        verbose (bool, optional): Whether to print progress messages. Defaults to True.
        **kwargs: Additional keyword arguments forwarded to predict().

    Raises:
        ValueError: If images is neither a list nor a directory path.
    """

    import glob

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(out_dir, exist_ok=True)

    if isinstance(images, str):
        images = sorted(glob.glob(os.path.join(images, "*.tif")))

    if not isinstance(images, list):
        raise ValueError("images must be a list or a directory to GeoTIFF files.")

    total = len(images)
    width = len(str(total))  # zero-pad the counter to the width of the total

    for i, image in enumerate(images):
        basename = os.path.splitext(os.path.basename(image))[0]
        if verbose:
            print(f"Processing image {str(i + 1).zfill(width)} of {total}: {image}...")
        output = os.path.join(out_dir, f"{basename}_mask.tif")
        self.predict(
            image,
            text_prompt,
            box_threshold,
            text_threshold,
            output=output,
            mask_multiplier=mask_multiplier,
            dtype=dtype,
            save_args=save_args,
            **kwargs,
        )

    if merge:
        output = os.path.join(out_dir, "merged.tif")
        merge_rasters(out_dir, output)
        if verbose:
            print(f"Saved the merged prediction to {output}.")
+
predict_dino(self, image, text_prompt, box_threshold, text_threshold)
+
+
+¶Run the GroundingDINO model prediction.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
image |
+ Image |
+ Input PIL Image. |
+ required | +
text_prompt |
+ str |
+ Text prompt for the model. |
+ required | +
box_threshold |
+ float |
+ Box threshold for the prediction. |
+ required | +
text_threshold |
+ float |
+ Text threshold for the prediction. |
+ required | +
Returns:
+Type | +Description | +
---|---|
tuple |
+ Tuple containing boxes, logits, and phrases. |
+
samgeo/text_sam.py
def predict_dino(self, image, text_prompt, box_threshold, text_threshold):
    """
    Detect objects matching a text prompt with the GroundingDINO model.

    Args:
        image (Image): Input PIL Image.
        text_prompt (str): Text prompt for the model.
        box_threshold (float): Box threshold for the prediction.
        text_threshold (float): Text threshold for the prediction.

    Returns:
        tuple: Tuple containing boxes, logits, and phrases. The boxes are converted
        from (cx, cy, w, h) to (x0, y0, x1, y1) and multiplied by the image width/height.
    """

    model_input = transform_image(image)
    boxes, logits, phrases = predict(
        model=self.groundingdino,
        image=model_input,
        caption=text_prompt,
        box_threshold=box_threshold,
        text_threshold=text_threshold,
        device=self.device,
    )

    # Scale the box coordinates by the image size (x by width, y by height).
    width, height = image.size
    pixel_scale = torch.Tensor([width, height, width, height])
    boxes = box_ops.box_cxcywh_to_xyxy(boxes) * pixel_scale

    return boxes, logits, phrases
+
predict_sam(self, image, boxes)
+
+
+¶Run the SAM model prediction.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
image |
+ Image |
+ Input PIL Image. |
+ required | +
boxes |
+ torch.Tensor |
+ Tensor of bounding boxes. |
+ required | +
Returns:
+Type | +Description | +
---|---|
+ | Masks tensor. |
+
samgeo/text_sam.py
def predict_sam(self, image, boxes):
    """
    Segment the content of each bounding box with the SAM predictor.

    Args:
        image (Image): Input PIL Image.
        boxes (torch.Tensor): Tensor of bounding boxes.

    Returns:
        Masks tensor, moved to the CPU.
    """
    pixels = np.asarray(image)
    self.sam.set_image(pixels)

    # Map the boxes into the predictor's input frame using the (rows, cols) size.
    rows_cols = pixels.shape[:2]
    prompt_boxes = self.sam.transform.apply_boxes_torch(boxes, rows_cols)

    mask_batch, _, _ = self.sam.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=prompt_boxes.to(self.sam.device),
        multimask_output=False,
    )
    return mask_batch.cpu()
+
raster_to_vector(self, image, output, simplify_tolerance=None, **kwargs)
+
+
+¶Save the result to a vector file.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
image |
+ str |
+ The path to the image file. |
+ required | +
output |
+ str |
+ The path to the vector file. |
+ required | +
simplify_tolerance |
+ float |
+ The maximum allowed geometry displacement. +The higher this value, the smaller the number of vertices in the resulting geometry. |
+ None |
+
samgeo/text_sam.py
def raster_to_vector(self, image, output, simplify_tolerance=None, **kwargs):
    """Convert a raster file to a vector file.

    Thin wrapper that forwards every argument to the module-level
    raster_to_vector() function.

    Args:
        image (str): The path to the image file.
        output (str): The path to the vector file.
        simplify_tolerance (float, optional): The maximum allowed geometry displacement.
            The higher this value, the smaller the number of vertices in the resulting geometry.
        **kwargs: Additional keyword arguments passed through unchanged.
    """

    raster_to_vector(
        image,
        output,
        simplify_tolerance=simplify_tolerance,
        **kwargs,
    )
+
save_boxes(self, output=None, dst_crs='EPSG:4326', **kwargs)
+
+
+¶Save the bounding boxes to a vector file.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
output |
+ str |
+ The path to the output vector file. |
+ None |
+
dst_crs |
+ str |
+ The destination CRS. Defaults to "EPSG:4326". |
+ 'EPSG:4326' |
+
**kwargs |
+ + | Additional arguments for boxes_to_vector(). |
+ {} |
+
samgeo/text_sam.py
def save_boxes(self, output=None, dst_crs="EPSG:4326", **kwargs):
    """Save the bounding boxes of the last prediction to a vector file.

    Args:
        output (str, optional): The path to the output vector file. Defaults to None.
        dst_crs (str, optional): The destination CRS. Defaults to "EPSG:4326".
        **kwargs: Additional arguments forwarded to rowcol_to_xy().

    Returns:
        The value returned by boxes_to_vector() when output is None; otherwise None.
        None is also returned when no prediction has been made yet.
    """

    # Guard clause: nothing to export before predict() has stored any boxes.
    if self.boxes is None:
        print("Please run predict() first.")
        return

    pixel_boxes = self.boxes.tolist()
    coords = rowcol_to_xy(self.source, boxes=pixel_boxes, dst_crs=dst_crs, **kwargs)
    result = boxes_to_vector(coords, self.crs, dst_crs, output)
    if output is None:
        return result
+
set_image(self, image)
+
+
+¶Set the input image.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
image |
+ str |
+ The path to the image file or an HTTP URL. |
+ required | +
samgeo/text_sam.py
def set_image(self, image):
    """Record the input image source on ``self.source``.

    Args:
        image (str): The path to the image file or an HTTP URL. Any non-string
            value clears ``self.source`` (sets it to None).

    Raises:
        ValueError: If the (possibly downloaded) path does not exist.
    """

    if not isinstance(image, str):
        self.source = None
        return

    # Remote inputs are downloaded first; the local file path is what gets stored.
    if image.startswith("http"):
        image = download_file(image)

    if not os.path.exists(image):
        raise ValueError(f"Input path {image} does not exist.")

    self.source = image
+
show_anns(self, figsize=(12, 10), axis='off', cmap='viridis', alpha=0.4, add_boxes=True, box_color='r', box_linewidth=1, title=None, output=None, blend=True, **kwargs)
+
+
+¶Show the annotations (objects with random color) on the input image.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
figsize |
+ tuple |
+ The figure size. Defaults to (12, 10). |
+ (12, 10) |
+
axis |
+ str |
+ Whether to show the axis. Defaults to "off". |
+ 'off' |
+
cmap |
+ str |
+ The colormap for the annotations. Defaults to "viridis". |
+ 'viridis' |
+
alpha |
+ float |
+ The alpha value for the annotations. Defaults to 0.4. |
+ 0.4 |
+
add_boxes |
+ bool |
+ Whether to show the bounding boxes. Defaults to True. |
+ True |
+
box_color |
+ str |
+ The color for the bounding boxes. Defaults to "r". |
+ 'r' |
+
box_linewidth |
+ int |
+ The line width for the bounding boxes. Defaults to 1. |
+ 1 |
+
title |
+ str |
+ The title for the image. Defaults to None. |
+ None |
+
output |
+ str |
+ The path to the output image. Defaults to None. |
+ None |
+
blend |
+ bool |
+ Whether to show the input image. Defaults to True. |
+ True |
+
kwargs |
+ dict |
+ Additional arguments for matplotlib.pyplot.savefig(). |
+ {} |
+
samgeo/text_sam.py
def show_anns(
    self,
    figsize=(12, 10),
    axis="off",
    cmap="viridis",
    alpha=0.4,
    add_boxes=True,
    box_color="r",
    box_linewidth=1,
    title=None,
    output=None,
    blend=True,
    **kwargs,
):
    """Show the predicted mask (and optionally the boxes) on top of the input image.

    Args:
        figsize (tuple, optional): The figure size. Defaults to (12, 10).
        axis (str, optional): Whether to show the axis. Defaults to "off".
        cmap (str, optional): The colormap for the annotations. Defaults to "viridis".
        alpha (float, optional): The alpha value for the annotations. Defaults to 0.4.
        add_boxes (bool, optional): Whether to show the bounding boxes. Defaults to True.
        box_color (str, optional): The color for the bounding boxes. Defaults to "r".
        box_linewidth (int, optional): The line width for the bounding boxes. Defaults to 1.
        title (str, optional): The title for the image. Defaults to None.
        output (str, optional): The path to the output image. Defaults to None.
        blend (bool, optional): When output is given, save the blended figure (True) or
            only the prediction array via array_to_image() (False). Defaults to True.
        kwargs (dict, optional): Additional arguments for matplotlib.pyplot.savefig().
            "dpi" defaults to 100 and "bbox_inches" to "tight" when not supplied.
    """

    import warnings

    anns = self.prediction

    if anns is None:
        print("Please run predict() first.")
        return
    elif len(anns) == 0:
        print("No objects found in the image.")
        return

    # Imported after the guards so the early-return paths do not need matplotlib.
    import matplotlib.patches as patches
    import matplotlib.pyplot as plt

    # Silence warnings only while drawing; a bare filterwarnings("ignore") would
    # permanently mute every warning for the rest of the process.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        plt.figure(figsize=figsize)
        plt.imshow(self.image)

        if add_boxes:
            for box in self.boxes:
                # Draw bounding box
                box = box.cpu().numpy()  # Convert the tensor to a numpy array
                rect = patches.Rectangle(
                    (box[0], box[1]),
                    box[2] - box[0],
                    box[3] - box[1],
                    linewidth=box_linewidth,
                    edgecolor=box_color,
                    facecolor="none",
                )
                plt.gca().add_patch(rect)

        kwargs.setdefault("dpi", 100)
        kwargs.setdefault("bbox_inches", "tight")

        plt.imshow(anns, cmap=cmap, alpha=alpha)

        if title is not None:
            plt.title(title)
        plt.axis(axis)

        if output is not None:
            if blend:
                plt.savefig(output, **kwargs)
            else:
                array_to_image(self.prediction, output, self.source)
+
show_map(self, basemap='SATELLITE', out_dir=None, **kwargs)
+
+
+¶Show the interactive map.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
basemap |
+ str |
+ The basemap. It can be one of the following: SATELLITE, ROADMAP, TERRAIN, HYBRID. |
+ 'SATELLITE' |
+
out_dir |
+ str |
+ The path to the output directory. Defaults to None. |
+ None |
+
Returns:
+Type | +Description | +
---|---|
leafmap.Map |
+ The map object. |
+
samgeo/text_sam.py
def show_map(self, basemap="SATELLITE", out_dir=None, **kwargs):
    """Open the interactive text-prompt map for this instance.

    Args:
        basemap (str, optional): The basemap. It can be one of the following: SATELLITE, ROADMAP, TERRAIN, HYBRID.
        out_dir (str, optional): The path to the output directory. Defaults to None.
        **kwargs: Additional keyword arguments forwarded to text_sam_gui().

    Returns:
        leafmap.Map: The map object.
    """
    interactive_map = text_sam_gui(
        self,
        basemap=basemap,
        out_dir=out_dir,
        **kwargs,
    )
    return interactive_map
+
load_model_hf(repo_id, filename, ckpt_config_filename, device='cpu')
+
+
+¶Loads a model from HuggingFace Model Hub.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
repo_id |
+ str |
+ Repository ID on HuggingFace Model Hub. |
+ required | +
filename |
+ str |
+ Name of the model file in the repository. |
+ required | +
ckpt_config_filename |
+ str |
+ Name of the config file for the model in the repository. |
+ required | +
device |
+ str |
+ Device to load the model onto. Default is 'cpu'. |
+ 'cpu' |
+
Returns:
+Type | +Description | +
---|---|
torch.nn.Module |
+ The loaded model. |
+
samgeo/text_sam.py
def load_model_hf(
    repo_id: str, filename: str, ckpt_config_filename: str, device: str = "cpu"
) -> torch.nn.Module:
    """
    Loads a model from HuggingFace Model Hub.

    Args:
        repo_id (str): Repository ID on HuggingFace Model Hub.
        filename (str): Name of the model file in the repository.
        ckpt_config_filename (str): Name of the config file for the model in the repository.
        device (str): Device to load the model onto. Default is 'cpu'.

    Returns:
        torch.nn.Module: The loaded model, switched to evaluation mode.
    """

    # NOTE: `force_filename` is intentionally not passed. It is deprecated in
    # huggingface_hub and rejected by newer releases; the default cache layout
    # already returns a path that ends with the requested file name.
    cache_config_file = hf_hub_download(
        repo_id=repo_id,
        filename=ckpt_config_filename,
    )
    args = SLConfig.fromfile(cache_config_file)
    model = build_model(args)
    model.to(device)

    cache_file = hf_hub_download(repo_id=repo_id, filename=filename)
    # Deserialize on the CPU regardless of where the checkpoint was saved;
    # the model itself has already been moved to the target device.
    checkpoint = torch.load(cache_file, map_location="cpu")
    # strict=False: checkpoint keys that do not match the model are tolerated.
    model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
    model.eval()
    return model
+
transform_image(image)
+
+
+¶Transforms an image using standard transformations for image-based models.
+ +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
image |
+ Image |
+ The PIL Image to be transformed. |
+ required | +
Returns:
+Type | +Description | +
---|---|
torch.Tensor |
+ The transformed image as a tensor. |
+
samgeo/text_sam.py
def transform_image(image: Image) -> torch.Tensor:
    """
    Prepare a PIL image as a model input tensor.

    The image is resized (size 800, max_size 1333), converted to a tensor and
    normalized channel-wise with fixed mean and standard deviation values.

    Args:
        image (Image): The PIL Image to be transformed.

    Returns:
        torch.Tensor: The transformed image as a tensor.
    """
    steps = [
        T.RandomResize([800], max_size=1333),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
    pipeline = T.Compose(steps)
    # The pipeline takes and returns an (image, target) pair; no target is used here.
    tensor, _ = pipeline(image, None)
    return tensor
+
To use segment-geospatial in a project:
+1 |
|
Here is a simple example of using segment-geospatial to generate a segmentation mask from a satellite image:
+1 + 2 + 3 + 4 + 5 + 6 + 7 + 8 + 9 +10 +11 +12 +13 +14 +15 +16 +17 +18 +19 +20 +21 +22 +23 +24 +25 +26 |
|