added single recon example notebook, routine pep8 work
cssprad1 committed Nov 13, 2024
1 parent 7fe8753 commit d596fbc
Showing 27 changed files with 680 additions and 122 deletions.
368 changes: 368 additions & 0 deletions
notebooks/satvision_toa_modis_reconstruction_example_notebook.ipynb
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,368 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"id": "5facdc34-efbd-4082-91ef-e70a4f34c441", | ||
"metadata": {}, | ||
"source": [ | ||
"# SatVision-TOA Reconstruction Example Notebook\n", | ||
"\n", | ||
"This notebook demonstrates the reconstruction capabilities of the SatVision-TOA model, designed to process and reconstruct MODIS TOA (Top of Atmosphere) imagery using Masked Image Modeling (MIM) for Earth observation tasks.\n", | ||
"\n", | ||
"Follow this step-by-step guide to install necessary dependencies, load model weights, transform data, make predictions, and visualize the results.\n", | ||
"\n", | ||
"## 1. Setup and Install Dependencies\n", | ||
"\n", | ||
"The following packages are required to run the notebook:\n", | ||
"- `yacs` – for handling configuration\n", | ||
"- `timm` – for Transformer and Image Models in PyTorch\n", | ||
"- `segmentation-models-pytorch` – for segmentation utilities\n", | ||
"- `termcolor` – for colored terminal text\n", | ||
"- `webdataset==0.2.86` – for handling datasets from web sources" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "e5e08cd1-d8df-4dd8-b884-d452ef90943b", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# !pip install yacs timm segmentation-models-pytorch termcolor webdataset==0.2.86" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "b4506576-5e30-417d-96de-8953d71c76c2", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import os\n", | ||
"import sys\n", | ||
"import time\n", | ||
"import random\n", | ||
"import datetime\n", | ||
"from tqdm import tqdm\n", | ||
"import numpy as np\n", | ||
"import logging\n", | ||
"\n", | ||
"import torch\n", | ||
"\n", | ||
"import matplotlib.pyplot as plt\n", | ||
"from matplotlib.backends.backend_pdf import PdfPages\n", | ||
"\n", | ||
"import warnings\n", | ||
"\n", | ||
"warnings.filterwarnings('ignore') " | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "775cb720-5151-49fa-a7d5-7291ef663d45", | ||
"metadata": {}, | ||
"source": [ | ||
"## 2. Model and Configuration Imports\n", | ||
"\n", | ||
"We load necessary modules from the pytorch-caney library, including the model, transformations, and plotting utilities." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "edf47149-f489-497b-8601-89a7e8dbd9b9", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"sys.path.append('../../pytorch-caney')\n", | ||
"\n", | ||
"from pytorch_caney.models.mim import build_mim_model\n", | ||
"from pytorch_caney.transforms.mim_modis_toa import MimTransform\n", | ||
"from pytorch_caney.configs.config import _C, _update_config_from_file\n", | ||
"from pytorch_caney.plotting.modis_toa import plot_export_pdf" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "fe00e78e-fca3-4221-86dd-da205fed4192", | ||
"metadata": {}, | ||
"source": [ | ||
"## 2. Fetching the model\n", | ||
"\n", | ||
"### 2.1 Clone model ckpt from huggingface\n", | ||
"\n", | ||
"Model repo: https://huggingface.co/nasa-cisto-data-science-group/satvision-toa-giant-patch8-window8-128\n", | ||
"\n", | ||
"```bash\n", | ||
"# On prism/explore system\n", | ||
"module load git-lfs\n", | ||
"\n", | ||
"git lfs install\n", | ||
"\n", | ||
"git clone git@hf.co:nasa-cisto-data-science-group/satvision-toa-giant-patch8-window8-128\n", | ||
"```\n", | ||
"\n", | ||
"Note: If using git w/ ssh, make sure you have ssh keys enabled to clone using ssh auth.\n", | ||
"https://huggingface.co/docs/hub/security-git-ssh\n", | ||
"\n", | ||
"```bash\n", | ||
"# If this outputs as anon, follow the next steps.\n", | ||
"ssh -T git@hf.co\n", | ||
"```\n", | ||
"\n", | ||
"\n", | ||
"```bash\n", | ||
"eval $(ssh-agent)\n", | ||
"\n", | ||
"# Check if ssh-agent is using the proper key\n", | ||
"ssh-add -l\n", | ||
"\n", | ||
"# If not\n", | ||
"ssh-add ~/.ssh/your-key\n", | ||
"\n", | ||
"# Or if you want to use the default id_* key, just do\n", | ||
"ssh-add\n", | ||
"\n", | ||
"```\n", | ||
"\n", | ||
"## 3. Fetching the validation dataset\n", | ||
"\n", | ||
"### 3.1 Clone dataset repo from huggingface\n", | ||
"\n", | ||
"Dataset repo: https://huggingface.co/datasets/nasa-cisto-data-science-group/modis_toa_cloud_reconstruction_validation\n", | ||
"\n", | ||
"\n", | ||
"```bash\n", | ||
"# On prims/explore system\n", | ||
"module load git-lfs\n", | ||
"\n", | ||
"git lfs install\n", | ||
"\n", | ||
"git clone git@hf.co:datasets/nasa-cisto-data-science-group/modis_toa_cloud_reconstruction_validation\n", | ||
"\n", | ||
"```" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "abb754ff-1753-4a4c-804e-8e3e5461fd0a", | ||
"metadata": {}, | ||
"source": [ | ||
"## 4. Define Model and Data Paths\n", | ||
"\n", | ||
"Specify paths to model checkpoint, configuration file, and the validation dataset. Customize these paths as needed for your environment." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "7ec267ce-ded1-40e6-8443-e1037297f710", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"MODEL_PATH: str = '../../satvision-toa-giant-patch8-window8-128/mp_rank_00_model_states.pt'\n", | ||
"CONFIG_PATH: str = '../../satvision-toa-giant-patch8-window8-128/mim_pretrain_swinv2_satvision_giant_128_window08_50ep.yaml'\n", | ||
"\n", | ||
"OUTPUT: str = '.'\n", | ||
"DATA_PATH: str = '../../modis_toa_cloud_reconstruction_validation/sv_toa_128_chip_validation_04_24.npy'" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "bd7d0b93-7fd3-49cb-ab9e-7536820ec5f2", | ||
"metadata": {}, | ||
"source": [ | ||
"## 5. Configure Model\n", | ||
"\n", | ||
"Load and update the configuration for the SatVision-TOA model, specifying model and data paths." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "aac43f0e-dc4b-49ba-a482-933b5bab4b79", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Update config given configurations\n", | ||
"\n", | ||
"config = _C.clone()\n", | ||
"_update_config_from_file(config, CONFIG_PATH)\n", | ||
"\n", | ||
"config.defrost()\n", | ||
"config.MODEL.PRETRAINED = MODEL_PATH\n", | ||
"config.DATA.DATA_PATHS = [DATA_PATH]\n", | ||
"config.OUTPUT = OUTPUT\n", | ||
"config.freeze()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "1d596904-d1df-4f6d-8e88-4c647ac26924", | ||
"metadata": {}, | ||
"source": [ | ||
"## 6. Load Model Weights from Checkpoint\n", | ||
"\n", | ||
"Build and initialize the model from the checkpoint to prepare for evaluation." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "bfe245f7-589e-4b02-9990-15cb1733f6cb", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"print('Building un-initialized model')\n", | ||
"model = build_mim_model(config)\n", | ||
"print('Successfully built uninitialized model')\n", | ||
"\n", | ||
"print(f'Attempting to load checkpoint from {config.MODEL.PRETRAINED}')\n", | ||
"checkpoint = torch.load(config.MODEL.PRETRAINED)\n", | ||
"model.load_state_dict(checkpoint['module'])\n", | ||
"print('Successfully applied checkpoint')\n", | ||
"model.cuda()\n", | ||
"model.eval()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "20c26d1e-125a-4b4c-a21e-ab07d6977222", | ||
"metadata": {}, | ||
"source": [ | ||
"## 7. Transform Validation Data\n", | ||
"\n", | ||
"The MODIS TOA dataset is loaded and transformed using MimTransform, generating a masked dataset for reconstruction." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "9b3b47b1-0690-4ef9-bed6-ec243b5d42cb", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Use the Masked-Image-Modeling transform specific to MODIS TOA data\n", | ||
"transform = MimTransform(config)\n", | ||
"\n", | ||
"# The reconstruction evaluation set is a single numpy file\n", | ||
"validation_dataset_path = config.DATA.DATA_PATHS[0]\n", | ||
"validation_dataset = np.load(validation_dataset_path)\n", | ||
"len_batch = range(validation_dataset.shape[0])\n", | ||
"\n", | ||
"# Apply transform to each image in the batch\n", | ||
"# A mask is auto-generated in the transform\n", | ||
"imgMasks = [transform(validation_dataset[idx]) for idx \\\n", | ||
" in len_batch]\n", | ||
"\n", | ||
"# Seperate img and masks, cast masks to torch tensor\n", | ||
"img = torch.stack([imgMask[0] for imgMask in imgMasks])\n", | ||
"mask = torch.stack([torch.from_numpy(imgMask[1]) for \\\n", | ||
" imgMask in imgMasks])" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "8b2148e4-da6d-4ae0-a194-c7adb62728a0", | ||
"metadata": { | ||
"tags": [] | ||
}, | ||
"source": [ | ||
"## 8. Prediction\n", | ||
"\n", | ||
"Run predictions on each sample and calculate reconstruction losses. Each image is processed individually to track individual losses." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "a3814751-f352-456e-850c-fe1d289b1d6b", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"inputs = []\n", | ||
"outputs = []\n", | ||
"masks = []\n", | ||
"losses = []\n", | ||
"\n", | ||
"# We could do this in a single batch however we\n", | ||
"# want to report the loss per-image, in place of\n", | ||
"# loss per-batch.\n", | ||
"for i in tqdm(range(img.shape[0])):\n", | ||
" single_img = img[i].unsqueeze(0)\n", | ||
" single_mask = mask[i].unsqueeze(0)\n", | ||
" single_img = single_img.cuda(non_blocking=True)\n", | ||
" single_mask = single_mask.cuda(non_blocking=True)\n", | ||
"\n", | ||
" with torch.no_grad():\n", | ||
" z = model.encoder(single_img, single_mask)\n", | ||
" img_recon = model.decoder(z)\n", | ||
" loss = model(single_img, single_mask)\n", | ||
"\n", | ||
" inputs.extend(single_img.cpu())\n", | ||
" masks.extend(single_mask.cpu())\n", | ||
" outputs.extend(img_recon.cpu())\n", | ||
" losses.append(loss.cpu()) " | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "22329bb4-5c6e-42dc-a492-8863fc2bf672", | ||
"metadata": {}, | ||
"source": [ | ||
"## 9. Export Reconstruction Results to PDF\n", | ||
"\n", | ||
"Save and visualize the reconstruction results. The output PDF will contain reconstructed images with original and masked versions." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "8ac6a09d-5fe2-4aa9-ac37-f235d5a8020a", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"pdfPath = '../../satvision-toa-reconstruction-validation-giant-example.pdf'\n", | ||
"rgbIndex = [0, 2, 1] # Indices of [Red band, Blue band, Green band]\n", | ||
"plot_export_pdf(pdfPath, inputs, outputs, masks, rgbIndex)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "1e0eb426-c7b4-47d4-aefa-2199ecfce2ab", | ||
"metadata": {}, | ||
"source": [ | ||
"This notebook provides an end-to-end example for reconstructing satellite images with the SatVision-TOA model, from setup through prediction and output visualization." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "62065e24-ddf2-4bf1-8362-90dc0c9bf49e", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "ILAB Kernel (Pytorch)", | ||
"language": "python", | ||
"name": "pytorch-kernel" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.9.15" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
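Section 8 of the notebook runs images through the model one at a time so that each image gets its own loss. The same per-image numbers can also come from a batched computation if the reduction skips the batch axis. The following is a minimal NumPy sketch under stated assumptions: it uses a SimMIM-style masked L1 loss and a mask already expanded to pixel resolution. `per_image_masked_l1` is a hypothetical helper, not part of pytorch-caney, and the loss computed inside the SatVision-TOA model may differ.

```python
import numpy as np


def per_image_masked_l1(images, recons, masks):
    """Return one masked L1 loss per image (hypothetical helper).

    images, recons: arrays of shape (N, C, H, W)
    masks: array of shape (N, H, W), 1 where a pixel was masked out
    """
    # Add a channel axis so the mask broadcasts over all bands.
    masks = masks[:, None, :, :].astype(images.dtype)
    abs_err = np.abs(images - recons) * masks
    n_channels = images.shape[1]
    # Reduce over channel and spatial axes only, keeping the batch axis.
    masked_values = masks.sum(axis=(1, 2, 3)) * n_channels
    return abs_err.sum(axis=(1, 2, 3)) / np.maximum(masked_values, 1.0)


# Synthetic data: 4 images, 3 bands, 8x8 pixels.
rng = np.random.default_rng(0)
images = rng.normal(size=(4, 3, 8, 8))
recons = images.copy()
recons[1] += 0.5          # image 1 is off by 0.5 everywhere
masks = np.zeros((4, 8, 8))
masks[:, :4, :] = 1       # top half of every image is masked

losses = per_image_masked_l1(images, recons, masks)
print(losses)
```

Only the second image has a non-zero loss here, which a single batch-mean value would hide. That is the motivation the notebook gives for reporting losses per image.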
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,9 @@ | ||
from .model_factory import ModelFactory | ||
from .mim import MiMModel | ||
from .heads import SegmentationHead | ||
from .decoders import FcnDecoder | ||
from .encoders import SatVision, SwinTransformerV2, FcnEncoder | ||
|
||
|
||
__all__ = ['ModelFactory', 'MiMModel', 'SegmentationHead', | ||
           'FcnDecoder', 'SatVision', 'SwinTransformerV2', 'FcnEncoder'] |
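Review note: Python expects `__all__` to list exported names as strings. When it holds the class objects themselves, `from package import *` raises a `TypeError`. The sketch below reproduces this with a throwaway in-memory module; `demo_pkg` and `star_import` are hypothetical names used only for illustration and are not part of pytorch-caney.

```python
import sys
import types

# Throwaway in-memory module; `demo_pkg` is a hypothetical name.
demo = types.ModuleType("demo_pkg")
demo.FcnDecoder = type("FcnDecoder", (), {})
sys.modules["demo_pkg"] = demo


def star_import(module_name):
    """Run `from <module_name> import *` and return the resulting namespace."""
    namespace = {}
    exec(f"from {module_name} import *", namespace)
    return namespace


# __all__ holding the class object itself: the star-import fails.
demo.__all__ = [demo.FcnDecoder]
try:
    star_import("demo_pkg")
    objects_ok = True
except TypeError:
    objects_ok = False

# __all__ holding the name as a string: the star-import succeeds.
demo.__all__ = ["FcnDecoder"]
names_ok = "FcnDecoder" in star_import("demo_pkg")

print(objects_ok, names_ok)
```

Explicit imports such as `from pytorch_caney.models import MiMModel` are unaffected either way; only star-imports and tools that read `__all__` see the difference.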
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,4 @@ | ||
from .fcn_decoder import FcnDecoder | ||
|
||
|
||
__all__ = ['FcnDecoder'] |