Skip to content

Commit

Permalink
added single recon example notebook, routine pep8 work
Browse files Browse the repository at this point in the history
  • Loading branch information
cssprad1 committed Nov 13, 2024
1 parent 7fe8753 commit d596fbc
Show file tree
Hide file tree
Showing 27 changed files with 680 additions and 122 deletions.
368 changes: 368 additions & 0 deletions notebooks/satvision_toa_modis_reconstruction_example_notebook.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,368 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "5facdc34-efbd-4082-91ef-e70a4f34c441",
"metadata": {},
"source": [
"# SatVision-TOA Reconstruction Example Notebook\n",
"\n",
"This notebook demonstrates the reconstruction capabilities of the SatVision-TOA model, designed to process and reconstruct MODIS TOA (Top of Atmosphere) imagery using Masked Image Modeling (MIM) for Earth observation tasks.\n",
"\n",
"Follow this step-by-step guide to install necessary dependencies, load model weights, transform data, make predictions, and visualize the results.\n",
"\n",
"## 1. Setup and Install Dependencies\n",
"\n",
"The following packages are required to run the notebook:\n",
"- `yacs` – for handling configuration\n",
"- `timm` – for Transformer and Image Models in PyTorch\n",
"- `segmentation-models-pytorch` – for segmentation utilities\n",
"- `termcolor` – for colored terminal text\n",
"- `webdataset==0.2.86` – for handling datasets from web sources"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e5e08cd1-d8df-4dd8-b884-d452ef90943b",
"metadata": {},
"outputs": [],
"source": [
"# !pip install yacs timm segmentation-models-pytorch termcolor webdataset==0.2.86"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b4506576-5e30-417d-96de-8953d71c76c2",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import sys\n",
"import time\n",
"import random\n",
"import datetime\n",
"from tqdm import tqdm\n",
"import numpy as np\n",
"import logging\n",
"\n",
"import torch\n",
"\n",
"import matplotlib.pyplot as plt\n",
"from matplotlib.backends.backend_pdf import PdfPages\n",
"\n",
"import warnings\n",
"\n",
"warnings.filterwarnings('ignore') "
]
},
{
"cell_type": "markdown",
"id": "775cb720-5151-49fa-a7d5-7291ef663d45",
"metadata": {},
"source": [
"## 2. Model and Configuration Imports\n",
"\n",
"We load necessary modules from the pytorch-caney library, including the model, transformations, and plotting utilities."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "edf47149-f489-497b-8601-89a7e8dbd9b9",
"metadata": {},
"outputs": [],
"source": [
"sys.path.append('../../pytorch-caney')\n",
"\n",
"from pytorch_caney.models.mim import build_mim_model\n",
"from pytorch_caney.transforms.mim_modis_toa import MimTransform\n",
"from pytorch_caney.configs.config import _C, _update_config_from_file\n",
"from pytorch_caney.plotting.modis_toa import plot_export_pdf"
]
},
{
"cell_type": "markdown",
"id": "fe00e78e-fca3-4221-86dd-da205fed4192",
"metadata": {},
"source": [
"## 2. Fetching the model\n",
"\n",
"### 2.1 Clone model ckpt from huggingface\n",
"\n",
"Model repo: https://huggingface.co/nasa-cisto-data-science-group/satvision-toa-giant-patch8-window8-128\n",
"\n",
"```bash\n",
"# On prism/explore system\n",
"module load git-lfs\n",
"\n",
"git lfs install\n",
"\n",
"git clone git@hf.co:nasa-cisto-data-science-group/satvision-toa-giant-patch8-window8-128\n",
"```\n",
"\n",
"Note: If using git w/ ssh, make sure you have ssh keys enabled to clone using ssh auth.\n",
"https://huggingface.co/docs/hub/security-git-ssh\n",
"\n",
"```bash\n",
"# If this outputs as anon, follow the next steps.\n",
"ssh -T git@hf.co\n",
"```\n",
"\n",
"\n",
"```bash\n",
"eval $(ssh-agent)\n",
"\n",
"# Check if ssh-agent is using the proper key\n",
"ssh-add -l\n",
"\n",
"# If not\n",
"ssh-add ~/.ssh/your-key\n",
"\n",
"# Or if you want to use the default id_* key, just do\n",
"ssh-add\n",
"\n",
"```\n",
"\n",
"## 3. Fetching the validation dataset\n",
"\n",
"### 3.1 Clone dataset repo from huggingface\n",
"\n",
"Dataset repo: https://huggingface.co/datasets/nasa-cisto-data-science-group/modis_toa_cloud_reconstruction_validation\n",
"\n",
"\n",
"```bash\n",
"# On prims/explore system\n",
"module load git-lfs\n",
"\n",
"git lfs install\n",
"\n",
"git clone git@hf.co:datasets/nasa-cisto-data-science-group/modis_toa_cloud_reconstruction_validation\n",
"\n",
"```"
]
},
{
"cell_type": "markdown",
"id": "abb754ff-1753-4a4c-804e-8e3e5461fd0a",
"metadata": {},
"source": [
"## 4. Define Model and Data Paths\n",
"\n",
"Specify paths to model checkpoint, configuration file, and the validation dataset. Customize these paths as needed for your environment."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7ec267ce-ded1-40e6-8443-e1037297f710",
"metadata": {},
"outputs": [],
"source": [
"MODEL_PATH: str = '../../satvision-toa-giant-patch8-window8-128/mp_rank_00_model_states.pt'\n",
"CONFIG_PATH: str = '../../satvision-toa-giant-patch8-window8-128/mim_pretrain_swinv2_satvision_giant_128_window08_50ep.yaml'\n",
"\n",
"OUTPUT: str = '.'\n",
"DATA_PATH: str = '../../modis_toa_cloud_reconstruction_validation/sv_toa_128_chip_validation_04_24.npy'"
]
},
{
"cell_type": "markdown",
"id": "bd7d0b93-7fd3-49cb-ab9e-7536820ec5f2",
"metadata": {},
"source": [
"## 5. Configure Model\n",
"\n",
"Load and update the configuration for the SatVision-TOA model, specifying model and data paths."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "aac43f0e-dc4b-49ba-a482-933b5bab4b79",
"metadata": {},
"outputs": [],
"source": [
"# Update config given configurations\n",
"\n",
"config = _C.clone()\n",
"_update_config_from_file(config, CONFIG_PATH)\n",
"\n",
"config.defrost()\n",
"config.MODEL.PRETRAINED = MODEL_PATH\n",
"config.DATA.DATA_PATHS = [DATA_PATH]\n",
"config.OUTPUT = OUTPUT\n",
"config.freeze()"
]
},
{
"cell_type": "markdown",
"id": "1d596904-d1df-4f6d-8e88-4c647ac26924",
"metadata": {},
"source": [
"## 6. Load Model Weights from Checkpoint\n",
"\n",
"Build and initialize the model from the checkpoint to prepare for evaluation."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bfe245f7-589e-4b02-9990-15cb1733f6cb",
"metadata": {},
"outputs": [],
"source": [
"print('Building un-initialized model')\n",
"model = build_mim_model(config)\n",
"print('Successfully built uninitialized model')\n",
"\n",
"print(f'Attempting to load checkpoint from {config.MODEL.PRETRAINED}')\n",
"checkpoint = torch.load(config.MODEL.PRETRAINED)\n",
"model.load_state_dict(checkpoint['module'])\n",
"print('Successfully applied checkpoint')\n",
"model.cuda()\n",
"model.eval()"
]
},
{
"cell_type": "markdown",
"id": "20c26d1e-125a-4b4c-a21e-ab07d6977222",
"metadata": {},
"source": [
"## 7. Transform Validation Data\n",
"\n",
"The MODIS TOA dataset is loaded and transformed using MimTransform, generating a masked dataset for reconstruction."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9b3b47b1-0690-4ef9-bed6-ec243b5d42cb",
"metadata": {},
"outputs": [],
"source": [
"# Use the Masked-Image-Modeling transform specific to MODIS TOA data\n",
"transform = MimTransform(config)\n",
"\n",
"# The reconstruction evaluation set is a single numpy file\n",
"validation_dataset_path = config.DATA.DATA_PATHS[0]\n",
"validation_dataset = np.load(validation_dataset_path)\n",
"len_batch = range(validation_dataset.shape[0])\n",
"\n",
"# Apply transform to each image in the batch\n",
"# A mask is auto-generated in the transform\n",
"imgMasks = [transform(validation_dataset[idx]) for idx \\\n",
" in len_batch]\n",
"\n",
"# Seperate img and masks, cast masks to torch tensor\n",
"img = torch.stack([imgMask[0] for imgMask in imgMasks])\n",
"mask = torch.stack([torch.from_numpy(imgMask[1]) for \\\n",
" imgMask in imgMasks])"
]
},
{
"cell_type": "markdown",
"id": "8b2148e4-da6d-4ae0-a194-c7adb62728a0",
"metadata": {
"tags": []
},
"source": [
"## 8. Prediction\n",
"\n",
"Run predictions on each sample and calculate reconstruction losses. Each image is processed individually to track individual losses."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a3814751-f352-456e-850c-fe1d289b1d6b",
"metadata": {},
"outputs": [],
"source": [
"inputs = []\n",
"outputs = []\n",
"masks = []\n",
"losses = []\n",
"\n",
"# We could do this in a single batch however we\n",
"# want to report the loss per-image, in place of\n",
"# loss per-batch.\n",
"for i in tqdm(range(img.shape[0])):\n",
" single_img = img[i].unsqueeze(0)\n",
" single_mask = mask[i].unsqueeze(0)\n",
" single_img = single_img.cuda(non_blocking=True)\n",
" single_mask = single_mask.cuda(non_blocking=True)\n",
"\n",
" with torch.no_grad():\n",
" z = model.encoder(single_img, single_mask)\n",
" img_recon = model.decoder(z)\n",
" loss = model(single_img, single_mask)\n",
"\n",
" inputs.extend(single_img.cpu())\n",
" masks.extend(single_mask.cpu())\n",
" outputs.extend(img_recon.cpu())\n",
" losses.append(loss.cpu()) "
]
},
{
"cell_type": "markdown",
"id": "22329bb4-5c6e-42dc-a492-8863fc2bf672",
"metadata": {},
"source": [
"## 9. Export Reconstruction Results to PDF\n",
"\n",
"Save and visualize the reconstruction results. The output PDF will contain reconstructed images with original and masked versions."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8ac6a09d-5fe2-4aa9-ac37-f235d5a8020a",
"metadata": {},
"outputs": [],
"source": [
"pdfPath = '../../satvision-toa-reconstruction-validation-giant-example.pdf'\n",
"rgbIndex = [0, 2, 1] # Indices of [Red band, Blue band, Green band]\n",
"plot_export_pdf(pdfPath, inputs, outputs, masks, rgbIndex)"
]
},
{
"cell_type": "markdown",
"id": "1e0eb426-c7b4-47d4-aefa-2199ecfce2ab",
"metadata": {},
"source": [
"This notebook provides an end-to-end example for reconstructing satellite images with the SatVision-TOA model, from setup through prediction and output visualization."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "62065e24-ddf2-4bf1-8362-90dc0c9bf49e",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "ILAB Kernel (Pytorch)",
"language": "python",
"name": "pytorch-kernel"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
9 changes: 8 additions & 1 deletion pytorch_caney/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,9 @@
from .model_factory import ModelFactory
from .mim import MiMModel
from .mim import MiMModel
from .heads import SegmentationHead
from .decoders import FcnDecoder
from .encoders import SatVision, SwinTransformerV2, FcnEncoder


__all__ = [ModelFactory, MiMModel, SegmentationHead,
FcnDecoder, SatVision, SwinTransformerV2, FcnEncoder]
5 changes: 4 additions & 1 deletion pytorch_caney/models/decoders/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,4 @@
from .fcn_decoder import FcnDecoder
from .fcn_decoder import FcnDecoder


__all__ = [FcnDecoder]
Loading

0 comments on commit d596fbc

Please sign in to comment.