diff --git a/docs/data_datacube.md b/docs/data_datacube.md index ceec98e9..cd638700 100644 --- a/docs/data_datacube.md +++ b/docs/data_datacube.md @@ -4,13 +4,13 @@ The `datacube.py` script collects Sentinel-2, Sentinel-1, and DEM data over individual MGRS tiles. The source list of the MGRS tiles to be processed is provided in an input file with MGRS geometries. Each run of the script will collect data for one of the MGRS tiles in the source file. The tile to be processed is based on the row index number provided as input. The MGRS tile ID is expected to be in the `name` property of the input file. -For the target MGRS tile, the script loops through the years between 2017 and 2023 in random order. For each year, it will search for the least cloudy Sentinel-2 scene. Based on the date of the selected Sentinel-2 scene, it will search for the Sentinel-1 scenes that are the closest match to that date, with a maximum of +/- 3 days of difference. It will include multiple Sentinel-1 scenes until the full MGRS tile is covered. There are cases where no matching Sentinel-1 scenes can be found, in which case the script moves to the next year. The script stops when 3 matching datasets were collected for 3 different years. Finally, the script will also select the intersecting part of the Copernicus Digital Elevation Model (DEM). +For the target MGRS tile, the script loops through the years between 2017 and 2023 in random order. For each year, it will search for the least cloudy Sentinel-2 scene. Based on the date of the selected Sentinel-2 scene, it will search for the Sentinel-1 scenes that are the closest match to that date, with a maximum of +/- 3 days of difference. It will include multiple Sentinel-1 scenes until the full MGRS tile is covered. If no matching Sentinel-1 scenes can be found, the script moves to the next year. The script stops when 3 matching datasets have been collected for 3 different years. Finally, the script will also select the intersecting part of the Copernicus Digital Elevation Model (DEM). -The script will then download all of the Sentinel-2 scene, and match the data cube with the corresponding Sentinel-1 and DEM data. The scene level data is then split into smaller chips of a fixed size of 512x512 pixels. The Sentinel2, Sentinel-1 and DEM bands are then packed together in a single TIFF file for each chip. These are saved locally and synced to a S3 bucket at the end of the script. The bucket name can be specified as input. +The script will then download the Sentinel-2 scene and match the data cube with the corresponding Sentinel-1 and DEM data. The scene-level data is then split into smaller chips of a fixed size of 512x512 pixels. The Sentinel-2, Sentinel-1 and DEM bands are then packed together in a single TIFF file for each chip. These are saved locally and synced to a S3 bucket at the end of the script. The bucket name can be specified as input. For testing and debugging, the data size can be reduced by specifying a pixel window using the `subset` parameter. Data will then be requested only for the specified pixel window. This will reduce the data size considerably which speeds up the processing during testing. -The example run below will search for data for the geometry with row index 1 in a with a local MGRS sample file, for a 1000x1000 pixel window. +The example run below will search for data for the geometry with row index 1 in a local MGRS sample file for a 1000x1000 pixel window. ```bash python datacube.py --sample /home/user/Desktop/mgrs_sample.fgb --bucket "my-bucket" --subset "1000,1000,2000,2000" --index 1 @@ -38,14 +38,14 @@ docker push $ecr_repo_id.dkr.ecr.us-east-1.amazonaws.com/fetch-and-run ### Prepare AWS batch -To prepare batch, we need to create a compute environment, job queue, and job +To prepare a batch, we need to create a compute environment, job queue, and job definition. Example configurations for the compute environment and the job definition are provided in the `batch` directory. The `submit.py` script contains a loop for submitting jobs to the queue. An -alternative to this individual job submissions would be to use array jobs, but +alternative to these individual job submissions would be to use array jobs, but for now the individual submissions are simpler and failures are easier to track. ### Create ZIP file with the package to execute @@ -54,7 +54,7 @@ Package the model and the inference script into a zip file. The `datacube.py` script is the one that will be executed on the instances. Put the scripts in a zip file and upload the zip package into S3 so that -the batch fetch and run can use it. +the batch fetch-and-run can use it. ```bash zip -FSrj "batch-fetch-and-run.zip" ./scripts/pipeline* -x "scripts/pipeline*.pyc" diff --git a/docs/data_labels.md b/docs/data_labels.md index a68c2338..be097d7b 100644 --- a/docs/data_labels.md +++ b/docs/data_labels.md @@ -1,10 +1,10 @@ # Benchmark dataset labels A benchmark dataset is a collection of data used for evaluating the performance -of algorithms, models or systems in a specific field of study. These datasets -are crucial in providing a common ground for comparing different approaches, +of algorithms, models, or systems in a specific field of study. These datasets +are crucial for providing common ground for comparing different approaches, allowing researchers to assess the strengths and weaknesses of various methods. -For Clay, we evaluate our model on benchmark datasets with suitable downstream +For Clay, we evaluate our model on benchmark datasets that have suitable downstream tasks. For our initial benchmark dataset, we've implemented the @@ -14,21 +14,20 @@ evaluation of finetuning on a downstream task. The task itself is [segmentation](https://paperswithcode.com/task/semantic-segmentation) of water pixels associated with recorded flood events. -The original dataset consists of 2/3 of our Foundation model's datacube inputs +The original dataset consists of two out of three of our Foundation model's datacube inputs (Sentinel-1 and Sentinel-2) along with raster water mask labels for both -sensors. Each image is 512x512 pixels in terms of width and height. The -original Sentinel-2 images are L1C, which is Top-of-Atmosphere reflectance. We -are training Clay with surface reflectance, however, so we ultimately used the -geospatial bounds from the GeoTIFF and image timestamp (from the granule name) -to query +sensors. Each image is 512x512 pixels. The +original Sentinel-2 images are L1C, which is Top-of-Atmosphere reflectance. We train +Clay with surface reflectance, however, so we ultimately used the geospatial bounds +from the GeoTIFF and image timestamp (from the granule name) to query [Microsoft Planetary Computer's STAC API for L2A (Bottom-of-Atmosphere a.k.a. "surface reflectance") Sentinel-2](https://planetarycomputer.microsoft.com/dataset/sentinel-2-l2a) scenes in the same time and space, with the same channels expected by Clay. We then followed the same `datacube` creation logic to generate datacubes with -Sentinel-1 VV and VH and the Copernicus digital elevation model (DEM). We also +Sentinel-1 VV and VH and the Copernicus Digital Elevation Model (DEM). We also ensured that the Sentinel-1 data was within a +/- 3 day interval of each reference Sentinel-2 scene (same method used by the benchmark dataset authors) and that the Sentinel-1 data was indeed already included in the bechmark -datasets list of granules. The datacubes generated have all three inputs +dataset's list of granules. The datacubes generated have all three inputs matching the exact specs of the Foundation model's training data, at 512x512 pixels. @@ -36,18 +35,18 @@ Here is an example of a datacube we generated for the dataset: ![datacube](https://github.com/Clay-foundation/model/assets/23487320/94dffcf5-4075-4c17-ac96-01c11bcb299b) -The images left to right show a true color representation of the Sentinel-2 -scene, the Sentinel-1 VH polarization and the digital elevation model. +The images, left to right, show a true-color representation of the Sentinel-2 +scene, the Sentinel-1 VH polarization, and the Digital Elevation Model. ![gt](https://github.com/Clay-foundation/model/assets/23487320/4ac92af7-6931-4249-a920-7d29453b9b31) Here we have something similar, but this time just the Sentinel-1 and Sentinel-2 scenes with the Sentinel-1 water mask (ground truth) overlaid. -Last note on this benchmark dataset that we've adapted for Clay, we made sure +Last note on this benchmark dataset that we've adapted for Clay: we made sure to preserve the metadata for timestamp and geospatial coordinates in the datacube such that we can embed information in the way that the Clay Foundation -model expects. We also preserve the flood event information too, for analysis +model expects. We also preserve the flood event information for analysis during finetuning. The script for generating these datacubes is at diff --git a/docs/index.md b/docs/index.md index e9b0f9c8..8b0a31b2 100644 --- a/docs/index.md +++ b/docs/index.md @@ -2,15 +2,20 @@ ## An open source AI model for Earth -Clay is a foundational model of Earth. It uses a Vision Transformer architecture adapted -to understand geospatial and temporal relations on Earth Observation data. The model is -trained via Self-supervised learning (SSL) using a Masked Autoencoder (MAE) method. +Clay is a [foundation model](https://www.adalovelaceinstitute.org/resource/foundation-models-explainer/) of Earth. Foundation models trained on earth observation (EO) data can efficiently distill and synthesize vast amounts of environmental data, allowing them to generalize this knowledge to specific, downstream applications. This makes them versatile and powerful tools for nature and climate applications. + +Clay’s model takes satellite imagery, along with information about location and time, as an input, and outputs embeddings, which are mathematical representations of a given area at a certain time on Earth’s surface. It uses a Vision Transformer architecture adapted to understand geospatial and temporal relations on Earth Observation data. The model is trained via Self-supervised learning (SSL) using a Masked Autoencoder (MAE) method. The Clay model can be used in three main ways: -- Generate semantic embeddings for any location and time. -- Fine-tune the model for downstream tasks such as classification, regression, and generative tasks. -- Use the model as a backbone for other models. +- **Generate semantic embeddings for any location and time.** You can use embeddings for a variety of tasks, including to: + - _Find features:_ Locate objects or features, such as surface mines, aquaculture, or concentrated animal feeding operations. + +- **Fine-tune the model for downstream tasks such as classification, regression, and generative tasks.** Fine-tuning the model takes advantage of its pre-training to more efficiently classify types, predict values, or detect change than from-scratch methods. Embeddings can also be used to do the following, which require fine-tuning: + - _Classify types or predict values of interest:_ Identify the types or classes of a given feature, such as crop type or land cover, or predict values of variables of interest, such as above ground biomass or agricultural productivity. + - _Detect changes over time:_ Find areas that have experienced changes such as deforestation, wildfires, destruction from human conflict, flooding, or urban development. + - This can be done by training a downstream model to take embeddings as input and output predicted classes/values. This could also include fine-tuning model weights to update the embeddings themselves. +- **Use the model as a backbone for other models.** ## Where is what @@ -22,11 +27,11 @@ The Clay model can be used in three main ways: License: [OpenRAIL-M](https://github.com/Clay-foundation/model/blob/main/LICENSE-MODEL.md). - The Clay **documentation** [lives on this site](https://clay-foundation.github.io/model/index.html). License: [CC-BY](http://creativecommons.org/licenses/by/4.0/). -- We release the **embeddings** of the used trainning data on [Source Cooperative](https://beta.source.coop/repositories/clay/clay-model-v0-embeddings). +- We release the **embeddings** of the used training data on [Source Cooperative](https://beta.source.coop/repositories/clay/clay-model-v0-embeddings). License: [ODC-BY](https://opendatacommons.org/licenses/by/). CLAY is a fiscal sponsored project of the 501c3 non-profit -[Radiant Earth Foundation](https://www.radiant.earth). +[Radiant Earth](https://www.radiant.earth). --- ### Table of Contents diff --git a/docs/installation.md b/docs/installation.md index 9069fa05..556c31a4 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -55,7 +55,7 @@ usage details. To generate a unified [`conda-lock.yml`](https://github.com/conda/conda-lock) file based on the dependency specification in `environment.yml`, run: - conda-lock lock --mamba --file environment.yml --platform linux-64 --with-cuda=12.0 + conda-lock lock --mamba --file environment.yml --with-cuda=12.0 Use this only when creating a new `conda-lock.yml` file or refreshing an existing one. ``` diff --git a/docs/model_embeddings.md b/docs/model_embeddings.md index dd20edfa..f34c0232 100644 --- a/docs/model_embeddings.md +++ b/docs/model_embeddings.md @@ -1,13 +1,13 @@ # Generating vector embeddings -Once you have a pretrained model, it is now possible to pass some input images -into the encoder part of the Vision Transformer, and produce vector embeddings +Once you have a pretrained model, it is possible to pass some input images +into the encoder part of the Vision Transformer and produce vector embeddings which contain a semantic representation of the image. ## Producing embeddings from the pretrained model -Step by step instructions to create embeddings for a single MGRS tile location -(e.g. 27WXN). +Step-by-step instructions to create embeddings for a single MGRS tile location +(e.g. 27WXN): 1. Ensure that you can access the 13-band GeoTIFF data files. @@ -15,11 +15,11 @@ Step by step instructions to create embeddings for a single MGRS tile location aws s3 ls s3://clay-tiles-02/02/27WXN/ ``` - This should report a list of filepaths if you have the correct permissions, - otherwise, please set up authentication before continuing. + This should report a list of filepaths if you have the correct permissions. + Otherwise, please set up authentication before continuing. -2. Download the pretrained model weights, and put them in the `checkpoints/` - folder. +2. Download the pretrained model weights and put them in the `checkpoints/` + folder: ```bash aws s3 cp s3://clay-model-ckpt/v0/clay-small-70MT-1100T-10E.ckpt checkpoints/ @@ -37,7 +37,7 @@ Step by step instructions to create embeddings for a single MGRS tile location For example, an AWS g5.4xlarge instance would be a cost effective option. ``` -3. Run model inference to generate the embeddings. +3. Run model inference to generate the embeddings: ```bash python trainer.py predict --ckpt_path=checkpoints/clay-small-70MT-1100T-10E.ckpt \ @@ -51,7 +51,7 @@ Step by step instructions to create embeddings for a single MGRS tile location This should output a GeoParquet file containing the embeddings for MGRS tile 27WXN (recall that each 10000x10000 pixel MGRS tile contains hundreds of smaller 512x512 chips), saved to the `data/embeddings/` folder. See the next - sub-section for details about the embeddings file. + subsection for details about the embeddings file. The `embeddings_level` flag determines how the embeddings are calculated. The default is `mean`, resulting in one average embedding per MGRS tile of @@ -61,9 +61,9 @@ Step by step instructions to create embeddings for a single MGRS tile location dimensionality of the encoder output, including the band group dimension. The array size of those embeddings is 6 * 16 * 16 * 768. - The embeddings are flattened into one dimensional arrays because pandas + The embeddings are flattened into one-dimensional arrays because pandas does not allow for multidimensional arrays. This makes it necessary to - reshape the flattened arrays to access the patch level embeddings. + reshape the flattened arrays to access the patch-level embeddings. ```{note} For those interested in how the embeddings were computed, the predict step @@ -113,7 +113,7 @@ Example: `27WXN_20200101_20231231_v001.gpq` ### Table schema -Each row within the GeoParquet table is generated from a 512x512 pixel image, +Each row within the GeoParquet table is generated from a 512x512 pixel image and contains a record of the embeddings, spatiotemporal metadata, and a link to the GeoTIFF file used as the source image for the embedding. The table looks something like this: @@ -161,9 +161,9 @@ Further reading: - https://cloudnativegeo.org/blog/2023/10/the-geoparquet-ecosystem-at-1.0.0 ``` -## Converting to patch level embeddings +## Converting to patch-level embeddings -In the case where patch level embeddings are requested, the resulting array +In the case where patch-level embeddings are requested, the resulting array will have all patch embeddings ravelled in one row. Each row represents a 512x512 pixel image, and contains 16x16 patch embeddings. diff --git a/docs/run_region.md b/docs/run_region.md index 02874d7f..cb9d5826 100644 --- a/docs/run_region.md +++ b/docs/run_region.md @@ -63,26 +63,26 @@ mgrs_aoi.to_file("data/mgrs/mgrs_aoi.fgb") This will select the MGRS tiles that intersect with your AOI. The processing will then happen for each of the MGRS tiles. This will most likely provide -slightly more data than the AOI itself, as the whole tile data will downloaded +slightly more data than the AOI itself, as the whole tile data will be downloaded for each matched MGRS tile. -Each run of th datacube script will take an index as input, which is the index +Each run of the datacube script will take an index as input, which is the index of the MGRS tile within the input file. This is why we need to download the data in a loop. A list of date ranges can be specified. The script will look for the least -cloudy Sentinel-2 scene for each date range, and match Sentinel-1 dates near +cloudy Sentinel-2 scene for each date range and match Sentinel-1 dates near the identified Sentinel-2 dates. -The output folder can be specified as a local folder, or a bucket can be -specified to upload the data to S3. +The output folder can be specified as a local folder or a bucket can be +specified if you want to upload the data to S3. Note that for the script to run, a Microsoft Planetary Computer token needs -to be set up, consult the [Planetary Computer SDK](https://github.com/microsoft/planetary-computer-sdk-for-python) +to be set up. Consult the [Planetary Computer SDK](https://github.com/microsoft/planetary-computer-sdk-for-python) documentation on how to set up the token. By default, the datacube script will download all the data available for each -MGRS tile it processes. So the output might include imagery chips that are +MGRS tile it processes, so the output might include imagery chips that are outside of the AOI specified. To speed up processing in the example below, we use the subset argument to @@ -95,7 +95,7 @@ be downloaded for each MGRS tile. ```bash for i in {0..5}; do -python scripts/datacube.py \ +python scripts/pipeline/datacube.py \ --sample data/mgrs/mgrs_aoi.fgb \ --localpath data/chips \ --index $i \ @@ -110,7 +110,7 @@ done The checkpoints can be accessed directly from Hugging Face at https://huggingface.co/made-with-clay/Clay. -The following command will run the model to create the embeddings, +The following command will run the model to create the embeddings and automatically download and cache the model weights. ```bash diff --git a/docs/specification.md b/docs/specification.md index a3003060..7f638622 100644 --- a/docs/specification.md +++ b/docs/specification.md @@ -9,7 +9,9 @@ Model weights released on 2024/01/12. ### Summary -Clay v0 is a self-supervised modified vision transfer model trained on stacks of Sentinel-2, Sentinel-1 & DEM data. It is trained as a Masked Autoencoder (MAE) to reconstruct the original image from a masked image. +Clay v0 is a self-supervised modified vision transformer model trained on stacks of Sentinel-2, Sentinel-1 & DEM data. It is trained as a Masked Autoencoder (MAE) to reconstruct the original image from a masked image. + +With the pre-trained model, you can input stacks of geospatial data and output vector embeddings, which capture spatial, temporal, and spectral information about Earth and represent these relationships numerically in high-dimensional space. Each embedding is representative of a certain area of Earth at a certain point in time. Each data entry is a stack of 10 bands of Sentinel-2, 2 bands of Sentinel-1 & 1 band of DEM data. The model is trained with 3 timesteps of data for each location, with a total of 1203 MGRS tiles globally distributed, each of size 10km x 10km. The data was collected from the Microsoft Planetary Computer. @@ -143,7 +145,7 @@ inputs it will be necessary to subset these as shown in the partial input tutori * Training Time: * `25` epochs, each taking ~`15h` to train. * Carbon Emissions: - * *Report not yet available from provider, expected March'24* + * According to the "Customer Carbon Emission Tool", there were no Scope 1 or Scope 2 carbon emissions. Following the [documentation](https://docs.aws.amazon.com/awsaccountbilling/latest/aboutv2/ccft-estimation.html), we believe this is due to the usage of renewable energy sources. We are aware that Scope 3 emissions might be significant for data centers and that these are not included in the estimate. * Training stages: * While developing the model we run small tests locally and on the cloud. We estimate that all testing and development compute is less than the compute used for 1 epoch of training. * QA of the model is also done locally and on the cloud, and we estimate that it is less than the compute used for 1 epoch of training. diff --git a/src/model_clay.py b/src/model_clay.py new file mode 100644 index 00000000..74590a2d --- /dev/null +++ b/src/model_clay.py @@ -0,0 +1,1022 @@ +import os +import re +from typing import Literal + +import geopandas as gpd +import lightning as L +import numpy as np +import pandas as pd +import pyarrow as pa +import shapely +import torch +import torch.nn.functional as F +from einops import rearrange, reduce, repeat +from torch import nn +from vit_pytorch.vit import Transformer + +from src.utils import posemb_sincos_1d, posemb_sincos_2d + +torch.set_float32_matmul_precision(precision="medium") + + +# %% +class Patchify(nn.Module): + """ + Patchify the input cube & create embeddings per patch + """ + + def __init__(self, in_chans, embed_dim, patch_size): + """ + Define layers of patch stem. + + Parameters + ---------- + in_chans : int + Number of input channels + embed_dim : int + Embedding dimension + patch_size : int + Patch size + """ + super().__init__() + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=patch_size, stride=patch_size + ) + self.norm = nn.LayerNorm(embed_dim) + + def forward(self, xb): + b, c, h, w = xb.shape + xb = self.proj(xb) + xb = rearrange(xb, "b d p1 p2 -> b (p1 p2) d") + return self.norm(xb) + + +class Encoder(nn.Module): + def __init__( # noqa: PLR0913 + self, + mask_ratio, + image_size, + patch_size, + shuffle, + dim, + depth, + heads, + dim_head, + mlp_ratio, + band_groups, + dropout, + emb_dropout, + ): + super().__init__() + assert ( + image_size % patch_size == 0 + ), "Image dimensions must be divisible by the patch size." + self.mask_ratio = mask_ratio + self.image_size = image_size + self.patch_size = patch_size + self.shuffle = shuffle + self.dim = dim + self.band_groups = band_groups + self.num_spatial_patches = (image_size // patch_size) ** 2 + self.num_group_patches = len(band_groups) + self.num_patches = self.num_spatial_patches * self.num_group_patches + + # Split the embedding dimensions between spatial & band patches equally + pos_dim = band_dim = dim // 2 + + self.latlon_embedding = nn.Linear(2, dim) + self.time_embedding = nn.Linear(3, dim) + self.patch_embedding = nn.ModuleDict( + { + name: Patchify(len(bands), dim, patch_size) + for name, bands in self.band_groups.items() + } + ) + + # Fix the position & band embedding to sine & cosine functions + self.register_buffer( + name="pos_encoding", + tensor=posemb_sincos_2d( + h=image_size // patch_size, w=image_size // patch_size, dim=pos_dim + ), # [L D/2] + persistent=False, + ) + self.register_buffer( + name="band_encoding", + tensor=posemb_sincos_1d( + length=self.num_group_patches, dim=band_dim + ), # [G D/2] + persistent=False, + ) + + # Freeze the weights of position & band encoding + self.pos_encoding = self.pos_encoding.requires_grad_(False) + self.band_encoding = self.band_encoding.requires_grad_(False) + + self.dropout = nn.Dropout(emb_dropout) + + self.transformer = Transformer( + dim=dim, + depth=depth, + heads=heads, + dim_head=dim_head, + mlp_dim=dim * mlp_ratio, + dropout=dropout, + ) + + def to_patch_embed(self, cube): + """ + Patchify the input cube & create embeddings per patch + + Parameters + ---------- + cube : torch.Tensor + A tensor of shape (B, C, H, W) containing the pixels of the + datacube. + + Returns + ------- + patches : torch.Tensor + A tensor of shape (B, G, L, D) containing the embeddings of the + patches. + """ + patches = [] + for name, bands in self.band_groups.items(): + cubeslice = cube[:, bands, :, :] # [B C H W] -> [B C[slice[...]] H W] + patches.append(self.patch_embedding[name](cubeslice)) + + patches = rearrange(patches, "G B L D -> B G L D") # [B G L D] + return patches # [B G L D] + + def add_encodings(self, patches): + """ + Add position & band encoding to the patches + + Parameters + ---------- + patches : torch.Tensor + A tensor of shape (B, G, L, D) containing the embeddings of the + patches. + + Returns + ------- + patches : torch.Tensor + A tensor of shape (B, G, L, D) containing the embeddings of the + patches + position & band encoding. + """ + self.B, G, L, D = patches.shape + + # Align position & band embeddings across patches + pos_encoding = repeat( + self.pos_encoding, "L D -> 1 repeat L D", repeat=G + ) # [1 G L D/2] + + band_encoding = repeat( + self.band_encoding, "G D -> 1 G repeat D", repeat=L + ) # [1 G L D/2] + + pos_band_encoding = torch.cat( + (pos_encoding, band_encoding), dim=-1 + ) # [1 G L D] + + # Add position & band encoding to the input feature vector + patches = patches + pos_band_encoding # [B G L D] + [1 G L D] - broadcasting + patches = self.dropout(patches) # [B G L D] + return patches # [B G L D] + + def embed_metadata(self, patches, latlon, time): + """ + Add timestep & latlon embedding to the patches + + Parameters + ---------- + patches : torch.Tensor + A tensor of shape (B, GL, D) containing the embeddings of the + patches + position & band encoding. + latlon : torch.Tensor + A tensor of shape (B, 2) containing the latlon of the datacube. + time : torch.Tensor + A tensor of shape (B, 2) containing the timestep of the datacube. + + Returns + ------- + patches : torch.Tensor + A tensor of shape (B, GL, D) containing the embeddings of the + patches + position & band encoding + timestep & latlon embedding. + """ + latlon_embedding = rearrange( + self.latlon_embedding(latlon), "B D -> B 1 D" + ) # [B D] -> [B 1 D] + time_embedding = rearrange( + self.time_embedding(time), "B D -> B 1 D" + ) # [B D] -> [B 1 D] + patches = torch.cat( + [patches, latlon_embedding, time_embedding], dim=1 + ) # [B GL D] + [B 1 D] + [B 1 D] -> [B (GL + 2) D] + return patches # [B (GL + 2) D] + + def mask_out(self, patches): + """ + Mask out patches randomly by shuffling the patches & masking out the + first N patches + + Parameters + ---------- + patches : torch.Tensor + A tensor of shape (B, GL, D) containing the embeddings of the + patches + position & band encoding + timestep & latlon embedding. + + Returns + ------- + unmasked_patches : torch.Tensor + A tensor of shape (B, GL:(1 - mask_ratio), D) containing the + embeddings of the unmasked patches. + unmasked_indices : torch.Tensor + A tensor of shape (B, (1 - mask_ratio)) containing the indices of + the unmasked patches. + masked_indices : torch.Tensor + A tensor of shape (B, mask_ratio) containing the indices of the + masked patches. + masked_matrix : torch.Tensor + A tensor of shape (B, G, L) containing the mask matrix. + """ + B, GL, D = patches.shape + assert ( + GL == self.num_patches + ), f"Expected {self.num_patches} patches, got {GL} patches." + + if self.shuffle: # Shuffle the patches + noise = torch.randn((B, GL), device=patches.device) # [B GL] + else: # Don't shuffle useful for interpolation & inspection of embeddings + noise = rearrange( + torch.arange(B * GL, device=patches.device), "(B GL) -> B GL", B=B + ) + + random_indices = torch.argsort(noise, dim=-1) # [B GL] + reverse_indices = torch.argsort(random_indices, dim=-1) # [B GL] + + num_masked_patches = int( + self.mask_ratio * self.num_patches + ) # Number of patches to be masked out + masked_indices, unmasked_indices = ( + random_indices[:, :num_masked_patches], # [B mask_ratio * GL] + random_indices[:, num_masked_patches:], # [B (1 - mask_ratio) * GL] + ) + + # create a mask of shape B G L, where 1 indicates a masked patch + # and 0 indicates an unmasked patch + masked_matrix = torch.zeros((B, GL), device=patches.device) # [B GL] = 0 + masked_matrix[:, :num_masked_patches] = 1 # [B mask_ratio * GL] = 1 + masked_matrix = torch.gather( + masked_matrix, dim=1, index=reverse_indices + ) # [B GL] -> [B GL] - reorder the patches + masked_matrix = rearrange( + masked_matrix, + "B (G L) -> B G L", + G=self.num_group_patches, # [B G L] + ) + + # mask out the patches + batch_indices = rearrange( + torch.arange(B, device=patches.device), "B -> B 1" + ) # [B 1] + unmasked_patches = patches[ + batch_indices, unmasked_indices, : + ] # [B GL:(1 - mask_ratio) D] + _ = patches[batch_indices, masked_indices, :] # [B GL:mask_ratio D] + + return ( + unmasked_patches, + unmasked_indices, + masked_indices, + masked_matrix, + ) # [B GL:(1 - mask_ratio) D], [(1-mask_ratio)], [mask_ratio], [B G L] + + def forward(self, datacube): + cube, time, latlon = ( + datacube["pixels"], + datacube["timestep"], + datacube["latlon"], + ) # [B C H W] + + B, C, H, W = cube.shape + + patches = self.to_patch_embed( + cube + ) # [B G L D] - patchify & create embeddings per patch + + patches = self.add_encodings( + patches + ) # [B G L D] - add position & band encoding to the embeddings + + patches = rearrange(patches, "B G L D -> B (G L) D") # [B (GL) D] + patches = self.dropout(patches) # [B (GL) D] + + # mask out patches + ( + unmasked_patches, + unmasked_indices, + masked_indices, + masked_matrix, + ) = self.mask_out( + patches + ) # [B GL:(1 - mask_ratio) D], [(1-mask_ratio)], [mask_ratio], [B G L] + + # add timestep & latlon embedding to only the unmasked patches + unmasked_patches = self.embed_metadata( + unmasked_patches, latlon, time + ) # [B (GL:(1 - mask_ratio) + 2) D] + + # pass the unmasked patches through the transformer + encoded_unmasked_patches = self.transformer( + unmasked_patches + ) # [B (GL:(1 - mask_ratio) + 2) D] + + return ( + encoded_unmasked_patches, + unmasked_indices, + masked_indices, + masked_matrix, + ) # [B (GL:(1 - mask_ratio) + 2) D], [(1-mask_ratio)], [mask_ratio], [B G L] + + +class Decoder(nn.Module): + def __init__( # noqa: PLR0913 + self, + mask_ratio, + image_size, + patch_size, + encoder_dim, + dim, + depth, + heads, + dim_head, + mlp_ratio, + band_groups, + dropout, + ): + super().__init__() + self.mask_ratio = mask_ratio + self.image_size = image_size + self.patch_size = patch_size + self.encoder_dim = encoder_dim + self.dim = dim + self.band_groups = band_groups + self.num_spatial_patches = (image_size // patch_size) ** 2 + self.num_group_patches = len(band_groups) + self.num_patches = self.num_spatial_patches * self.num_group_patches + + self.enc_to_dec = ( + nn.Linear(encoder_dim, dim) if encoder_dim != dim else nn.Identity() + ) + self.mask_patch = nn.Parameter(torch.randn(dim)) + self.transformer = Transformer( + dim=dim, + depth=depth, + heads=heads, + dim_head=dim_head, + mlp_dim=dim * mlp_ratio, + dropout=dropout, + ) + + # Split the embedding dimensions between spatial & band patches equally + pos_dim = band_dim = dim // 2 + + # Fix the position & band embedding to sine & cosine functions + self.register_buffer( + name="pos_encoding", + tensor=posemb_sincos_2d( + h=image_size // patch_size, w=image_size // patch_size, dim=pos_dim + ), # [L D/2] + persistent=False, + ) + self.register_buffer( + name="band_encoding", + tensor=posemb_sincos_1d( + length=self.num_group_patches, dim=band_dim + ), # [G D/2] + persistent=False, + ) + + # Freeze the weights of position & band encoding + self.pos_encoding = self.pos_encoding.requires_grad_(False) + self.band_encoding = self.band_encoding.requires_grad_(False) + + self.embed_to_pixels = nn.ModuleDict( + { + name: nn.Linear(dim, (patch_size**2) * len(bands)) + for name, bands in self.band_groups.items() + } + ) + + def reconstruct_and_add_encoding( + self, unmasked_patches, unmasked_indices, masked_indices + ): + """ + Reconstruct the input patches from the random mask patch & add position + & band encoding to them. + + Parameters + ---------- + unmasked_patches : torch.Tensor + A tensor of shape (B, GL:(1 - mask_ratio), D) containing the + embeddings of the unmasked patches. + unmasked_indices : torch.Tensor + A tensor of shape (B, (1 - mask_ratio)) containing the indices of + the unmasked patches. + masked_indices : torch.Tensor + A tensor of shape (B, mask_ratio) containing the indices of the + masked patches. + + Returns + ------- + decoder_patches : torch.Tensor + A tensor of shape (B, GL, D) containing the embeddings for the + decoder part of the model. + """ + B, *_ = unmasked_patches.shape + + # Align position & band embeddings across patches + pos_encoding = repeat( + self.pos_encoding, "L D -> 1 repeat L D", repeat=self.num_group_patches + ) # [1 G L D/2] + band_encoding = repeat( + self.band_encoding, "G D -> 1 G repeat D", repeat=self.num_spatial_patches + ) # [1 G L D/2] + + pos_band_encoding = torch.cat( + (pos_encoding, band_encoding), dim=-1 + ) # [1 G L D] + pos_band_encoding = rearrange( + pos_band_encoding, "1 G L D -> 1 (G L) D" + ) # [1 (GL) D] + pos_band_encoding = repeat( + pos_band_encoding, "1 (GL) D -> B (GL) D", B=B + ) # [B (GL) D] + + batch_indices = rearrange( + torch.arange(B, device=unmasked_patches.device), "B -> B 1" + ) # [B 1] + unmasked_pos_band_encoding = pos_band_encoding[ + batch_indices, unmasked_indices, : + ] # [B (GL:(1 - mask_ratio)) D] + masked_pos_band_encoding = pos_band_encoding[ + batch_indices, masked_indices, : + ] # [B (GL:mask_ratio) D] + + # Reconstruct the masked patches from the random mask patch & + # add position & band encoding to them + num_masked_patches = int(self.mask_ratio * self.num_patches) + masked_patches = repeat( + self.mask_patch, "D -> B GL D", B=B, GL=num_masked_patches + ) # [B GL:mask_ratio D] + masked_patches = ( + masked_patches + masked_pos_band_encoding + ) # [B GL:mask_ratio D] + [B GL:mask_ratio D] + + # Add position & band encoding to the unmasked patches + unmasked_patches = ( + unmasked_patches + unmasked_pos_band_encoding + ) # [B GL:(1 - masked_ratio) D] + [B GL:(1 - mask_ratio) D] + + # Concatenate the masked & unmasked patches + decoder_patches = torch.zeros( + (B, self.num_patches, self.dim), device=unmasked_patches.device + ) # [B GL D] + decoder_patches[batch_indices, unmasked_indices, :] = ( + unmasked_patches # [B GL:(1 - mask_ratio) D] + ) + decoder_patches[batch_indices, masked_indices, :] = ( + masked_patches # [B GL:mask_ratio D] + ) + + return decoder_patches # [B GL D] + + def pixelify(self, patches): + """ + Convert the patches into pixel space to compute the loss + + Parameters + ---------- + patches : torch.Tensor + A tensor of shape (B, GL, D) containing the embeddings from the + decoder part of the model. + + Returns + ------- + pixels : torch.Tensor + A tensor of shape (B, C, L, PP) containing the pixels of the + datacube. + """ + patches = rearrange( + patches, "B (G L) D -> B G L D", G=len(self.band_groups) + ) # [B G L D] + pixels = [] + for i, (name, bands) in enumerate(self.band_groups.items()): + group_embeddings = patches[:, i, :, :] # [B L D] + group_pixels = self.embed_to_pixels[name](group_embeddings) # [B L (P P C)] + group_pixels = rearrange( + group_pixels, + "B L (PP C) -> B C L PP", + PP=(self.patch_size**2), + ) # [B C L PP] + pixels.append(group_pixels) # [B C L PP] + + pixels = torch.cat(pixels, dim=1) # [B C L PP] + return pixels # [B C L PP] + + def forward(self, encoded_unmasked_patches, unmasked_indices, masked_indices): + # Change the embedding dimension from encoder to decoder + encoded_unmasked_patches = self.enc_to_dec(encoded_unmasked_patches) + + # Split the patches into encoded unmasked patches & meta patches + encoded_unmasked_patches, encoded_unmasked_meta_patches = ( + encoded_unmasked_patches[:, :-2, :], + encoded_unmasked_patches[:, -2:, :], + ) # [B (GL:(1 - mask_ratio)) D], [B 2 D] + + # Reconstruct the patches to feed into the decoder transformer + decoder_patches = self.reconstruct_and_add_encoding( + encoded_unmasked_patches, unmasked_indices, masked_indices + ) # [B GL D] + + # Add the metadata patches back to the decoder patches + decoder_patches = torch.cat( + [decoder_patches, encoded_unmasked_meta_patches], dim=1 + ) # [B (GL + 2) D] + + # Pass the decoder patches through the transformer + decoded_patches = self.transformer(decoder_patches) # [B (GL + 2) D] + + # Remove the metadata patches from the decoded patches + decoded_patches = decoded_patches[:, :-2, :] # [B GL D] + + # Piixelify the decoded patches + pixels = self.pixelify(decoded_patches) # [B C L PP] + return pixels + + +class CLAY(nn.Module): + def __init__( # noqa: PLR0913 + self, + mask_ratio, + image_size, + patch_size, + shuffle, + # ENCODER + dim, + depth, + heads, + dim_head, + mlp_ratio, + dropout, + emb_dropout, + # DECODER + decoder_dim, + decoder_depth, + decoder_heads, + decoder_dim_head, + decoder_mlp_ratio, + decoder_dropout, + # EO + band_groups={ + "rgb": (2, 1, 0), + "rededge": (3, 4, 5, 7), + "nir": (6,), + "swir": (8, 9), + "sar": (10, 11), + "dem": (12,), + }, + **kwargs, + ): + super().__init__() + self.mask_ratio = mask_ratio + self.image_size = image_size + self.patch_size = patch_size + self.shuffle = shuffle + self.band_groups = band_groups + + self.encoder = Encoder( + mask_ratio=mask_ratio, + image_size=image_size, + patch_size=patch_size, + shuffle=shuffle, + dim=dim, + depth=depth, + heads=heads, + dim_head=dim_head, + mlp_ratio=mlp_ratio, + band_groups=band_groups, + dropout=dropout, + emb_dropout=emb_dropout, + ) + + self.decoder = Decoder( + mask_ratio=mask_ratio, + image_size=image_size, + patch_size=patch_size, + encoder_dim=dim, + dim=decoder_dim, + depth=decoder_depth, + heads=decoder_heads, + dim_head=decoder_dim_head, + mlp_ratio=decoder_mlp_ratio, + band_groups=band_groups, + dropout=decoder_dropout, + ) + + def per_pixel_loss(self, cube, pixels, masked_matrix): + """ + Compute the per pixel loss + + Parameters + ---------- + cube : torch.Tensor + A tensor of shape (B, C, H, W) containing the pixels of the + datacube. + pixels : torch.Tensor + A tensor of shape (B, C, L, PP) containing the pixels per patch of + the datacube. + masked_matrix : torch.Tensor + A tensor of shape (B, G, L) containing the mask matrix. + + Returns + ------- + loss + """ + patches = rearrange( + cube, + "B C (h p1) (w p2) -> B C (h w) (p1 p2)", + p1=self.patch_size, + p2=self.patch_size, + ) # [B C L PP] + + # loss = (patches - pixels) ** 2 # loss per pixel + loss = F.mse_loss(patches, pixels, reduction="none") # loss per pixel + loss = reduce(loss, "B C L PP -> B C L", reduction="mean") # loss per patch + + # mask out the loss for unmasked patches + actual_loss, masked_patches_in_group = 0.0, 0.0 + for i, (name, group) in enumerate(self.band_groups.items()): + group_loss = reduce( + loss[:, group, :], "B G L -> B L", "mean" + ) # (B, L) - loss per group + actual_loss += ( + group_loss * masked_matrix[:, i] + ).sum() # (B, L) * (B, L) -> (B, L) -> (B) -> scalar + masked_patches_in_group += masked_matrix[ + :, i + ].sum() # (B, L) -> (B) -> scalar + + return actual_loss / masked_patches_in_group + + def forward(self, datacube): + # ENCODER + ( + encoded_unmasked_patches, + unmasked_indices, + masked_indices, + masked_matrix, + ) = self.encoder( + datacube + ) # [B (GL:(1 - mask_ratio) + 2) D], [(1-mask_ratio)], [mask_ratio], [B G L] + + # DECODER + pixels = self.decoder( + encoded_unmasked_patches, unmasked_indices, masked_indices + ) # [B C L PP] + + # LOSS + loss = self.per_pixel_loss(datacube["pixels"], pixels, masked_matrix) + + return loss + + +def clay_tiny(**kwargs): + args = { + # ENCODER + "dim": 256, + "depth": 4, + "heads": 4, + "dim_head": 64, + "mlp_ratio": 2, + "dropout": 0.0, + "emb_dropout": 0.0, + # DECODER + "decoder_dim": 128, + "decoder_depth": 2, + "decoder_heads": 2, + "decoder_dim_head": 64, + "decoder_mlp_ratio": 2, + "decoder_dropout": 0.0, + } + args.update(kwargs) + model = CLAY(**args) + return model + + +def clay_small(**kwargs): + args = { + # ENCODER + "dim": 768, + "depth": 12, + "heads": 12, + "dim_head": 64, + "mlp_ratio": 4, + "dropout": 0.0, + "emb_dropout": 0.0, + # DECODER + "decoder_dim": 512, + "decoder_depth": 8, + "decoder_heads": 8, + "decoder_dim_head": 64, + "decoder_mlp_ratio": 4, + "decoder_dropout": 0.0, + } + args.update(kwargs) + model = CLAY(**args) + return model + + +def clay_medium(**kwargs): + args = { + # ENCODER + "dim": 1024, + "depth": 24, + "heads": 16, + "dim_head": 64, + "mlp_ratio": 4, + "dropout": 0.0, + "emb_dropout": 0.0, + # DECODER + "decoder_dim": 512, + "decoder_depth": 8, + "decoder_heads": 16, + "decoder_dim_head": 64, + "decoder_mlp_ratio": 4, + "decoder_dropout": 0.0, + } + args.update(kwargs) + model = CLAY(**args) + return model + + +def clay_large(**kwargs): + args = { + # ENCODER + "dim": 1280, + "depth": 32, + "heads": 16, + "dim_head": 64, + "mlp_ratio": 4, + "dropout": 0.0, + "emb_dropout": 0.0, + # DECODER + "decoder_dim": 512, + "decoder_depth": 8, + "decoder_heads": 16, + "decoder_dim_head": 64, + "decoder_mlp_ratio": 4, + "decoder_dropout": 0.0, + } + args.update(kwargs) + model = CLAY(**args) + return model + + +class CLAYModule(L.LightningModule): + def __init__( # noqa: PLR0913 + self, + model_size="small", + mask_ratio=0.75, + image_size=512, + patch_size=32, + shuffle=False, + lr=1e-4, + wd=0.05, + b1=0.9, + b2=0.95, + embeddings_level: Literal["mean", "patch", "group"] = "mean", + band_groups={ + "rgb": (2, 1, 0), + "rededge": (3, 4, 5, 7), + "nir": (6,), + "swir": (8, 9), + "sar": (10, 11), + "dem": (12,), + }, + ): + super().__init__() + self.save_hyperparameters(logger=True) + model_map = { + "tiny": clay_tiny, + "small": clay_small, + "medium": clay_medium, + "large": clay_large, + } + if model_size in model_map: + model_args = { + "mask_ratio": mask_ratio, + "image_size": image_size, + "patch_size": patch_size, + "shuffle": shuffle, + } + if band_groups: + model_args["band_groups"] = band_groups + self.model = model_map[model_size](**model_args) + else: + raise ValueError( + f"Invalid model size {model_size}. Expected one of {model_map.keys()}" + ) + + def forward(self, cube: dict[str, torch.Tensor]): + return self.model(cube) + + def configure_optimizers(self): + optimizer = torch.optim.AdamW( + self.parameters(), + lr=self.hparams.lr, + weight_decay=self.hparams.wd, + betas=(self.hparams.b1, self.hparams.b2), + ) + scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts( + optimizer, T_0=1000, T_mult=2, eta_min=self.hparams.lr * 100, last_epoch=-1 + ) + + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": scheduler, + "interval": "step", + }, + } + + def shared_step(self, batch: dict[str, torch.Tensor], batch_idx: int, phase: str): + cube = batch + loss = self(cube) + self.log( + name=f"{phase}/loss", + value=loss, + on_step=True, + on_epoch=True, + prog_bar=True, + logger=True, + sync_dist=True, + ) + return loss + + def training_step(self, batch: dict[str, torch.Tensor], batch_idx: int): + return self.shared_step(batch, batch_idx, phase="train") + + def validation_step(self, batch: dict[str, torch.Tensor], batch_idx: int): + return self.shared_step(batch, batch_idx, phase="val") + + def predict_step( + self, batch: dict[str, torch.Tensor | list[str]], batch_idx: int + ) -> gpd.GeoDataFrame: + """ + Logic for the neural network's prediction loop. + """ + # Get image, bounding box, EPSG code, and date inputs + # x: torch.Tensor = batch["pixels"] # image of shape (1, 13, 512, 512) # BCHW + bboxes: np.ndarray = batch["bbox"].cpu().__array__() # bounding boxes + epsgs: torch.Tensor = batch["epsg"] # coordinate reference systems as EPSG code + dates: list[str] = batch["date"] # dates, e.g. ['2022-12-12', '2022-12-12'] + source_urls: list[str] = batch[ # URLs, e.g. ['s3://1.tif', 's3://2.tif'] + "source_url" + ] + + # Forward encoder + self.model.encoder.mask_ratio = 0.0 # disable masking + outputs_encoder: dict = self.model.encoder( + datacube=batch # input (pixels, timestep, latlon) + ) + + # Get embeddings generated from encoder + # (encoded_unmasked_patches, _, _, _) = outputs_encoder + embeddings_raw: torch.Tensor = outputs_encoder[0] + assert embeddings_raw.shape == torch.Size( + [self.model.encoder.B, 1538, 768] # (batch_size, seq_length, hidden_size) + ) + assert not torch.isnan(embeddings_raw).any() # ensure no NaNs in embedding + + if self.hparams.embeddings_level == "mean": + # Take the mean of the embeddings along the sequence_length dimension + # excluding the last two latlon_ and time_ embeddings, i.e. compute + # mean over patch embeddings only + embeddings_output: torch.Tensor = embeddings_raw[:, :-2, :].mean(dim=1) + expected_size = [self.model.encoder.B, 768] # (batch_size, hidden_size) + elif self.hparams.embeddings_level in ["patch", "group"]: + # Take the mean of the embeddings along the group dimension + # excluding the last two latlon_ and time_ embeddings. This + # results in one embedding per patch. + embeddings_output = rearrange( + embeddings_raw[:, :-2, :], "b (g h w) d -> b g h w d", w=16, h=16, g=6 + ) + if self.hparams.embeddings_level == "patch": + embeddings_output = reduce( + embeddings_output, "b g h w d -> b h w d", "mean" + ) + expected_size = [ + self.model.encoder.B, + 16, + 16, + 768, + ] + else: + expected_size = [ + self.model.encoder.B, + 6, + 16, + 16, + 768, + ] + else: + raise ValueError( + f"Value {self.hparams.embeddings_level} no allowed. " + "Choose one from mean, patch, or group" + ) + + assert embeddings_output.shape == torch.Size(expected_size) + + # Create table to store the embeddings with spatiotemporal metadata + unique_epsg_codes = set(int(epsg) for epsg in epsgs) + if len(unique_epsg_codes) == 1: # check that there's only 1 unique EPSG + epsg: int = batch["epsg"][0] + else: + raise NotImplementedError( + f"More than 1 EPSG code detected: {unique_epsg_codes}" + ) + + gdf = gpd.GeoDataFrame( + data={ + "source_url": pd.Series(data=source_urls, dtype="string[pyarrow]"), + "date": pd.to_datetime(arg=dates, format="%Y-%m-%d").astype( + dtype="date32[day][pyarrow]" + ), + "embeddings": pa.FixedShapeTensorArray.from_numpy_ndarray( + np.ascontiguousarray(embeddings_output.cpu().detach().__array__()) + ), + }, + geometry=shapely.box( + xmin=bboxes[:, 0], + ymin=bboxes[:, 1], + xmax=bboxes[:, 2], + ymax=bboxes[:, 3], + ), + crs=f"EPSG:{epsg}", + ) + gdf = gdf.to_crs(crs="OGC:CRS84") # reproject from UTM to lonlat coordinates + + return gdf + + def on_predict_epoch_end(self) -> gpd.GeoDataFrame: + """ + Logic to gather all the results from one epoch in a prediction loop. + """ + # Combine list of geopandas.GeoDataFrame objects + results: list[gpd.GeoDataFrame] = self.trainer.predict_loop.predictions + if results: + gdf: gpd.GeoDataFrame = pd.concat( + objs=results, axis="index", ignore_index=True + ) + else: + print( + "No embeddings generated, " + f"possibly no GeoTIFF files in {self.trainer.datamodule.data_dir}" + ) + return + + # Save embeddings in GeoParquet format, one file for each MGRS code + outfolder: str = f"{self.trainer.default_root_dir}/data/embeddings" + os.makedirs(name=outfolder, exist_ok=True) + + # Find unique MGRS names (e.g. '12ABC'), e.g. + # from 's3://.../.../claytile_12ABC_20201231_v02_0001.tif', get 12ABC + mgrs_codes = gdf.source_url.str.split("/").str[-1].str.split("_").str[1] + unique_mgrs_codes = mgrs_codes.unique() + for mgrs_code in unique_mgrs_codes: + if re.match(pattern=r"(\d{2}[A-Z]{3})", string=mgrs_code) is None: + raise ValueError( + "MGRS code should have 2 numbers and 3 letters (e.g. 12ABC), " + f"but got {mgrs_code} instead" + ) + + # Subset GeoDataFrame to a single MGRS code + _gdf: gpd.GeoDataFrame = gdf.loc[mgrs_codes == mgrs_code].reset_index() + + # Get min/max date from GeoDataFrame + minmax_date: pd.Series = _gdf.date.agg(func=["min", "max"]) + min_date: str = minmax_date["min"].strftime("%Y%m%d") + max_date: str = minmax_date["max"].strftime("%Y%m%d") + + # Output to a GeoParquet filename like + # {MGRS:5}_{MINDATE:8}_{MAXDATE:8}_v{VERSION:3}.gpq + outpath = f"{outfolder}/{mgrs_code}_{min_date}_{max_date}_v001.gpq" + _gdf.to_parquet(path=outpath, compression="ZSTD", schema_version="1.0.0") + print( + f"Saved {len(_gdf)} rows of embeddings of " + f"shape {gdf.embeddings.iloc[0].shape} to {outpath}" + ) + + return gdf