diff --git a/notebooks/satvision_toa_modis_reconstruction_example_notebook.ipynb b/notebooks/satvision_toa_modis_reconstruction_example_notebook.ipynb new file mode 100644 index 0000000..853d0db --- /dev/null +++ b/notebooks/satvision_toa_modis_reconstruction_example_notebook.ipynb @@ -0,0 +1,368 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "5facdc34-efbd-4082-91ef-e70a4f34c441", + "metadata": {}, + "source": [ + "# SatVision-TOA Reconstruction Example Notebook\n", + "\n", + "This notebook demonstrates the reconstruction capabilities of the SatVision-TOA model, designed to process and reconstruct MODIS TOA (Top of Atmosphere) imagery using Masked Image Modeling (MIM) for Earth observation tasks.\n", + "\n", + "Follow this step-by-step guide to install necessary dependencies, load model weights, transform data, make predictions, and visualize the results.\n", + "\n", + "## 1. Setup and Install Dependencies\n", + "\n", + "The following packages are required to run the notebook:\n", + "- `yacs` – for handling configuration\n", + "- `timm` – PyTorch Image Models\n", + "- `segmentation-models-pytorch` – for segmentation utilities\n", + "- `termcolor` – for colored terminal text\n", + "- `webdataset==0.2.86` – for handling datasets from web sources" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e5e08cd1-d8df-4dd8-b884-d452ef90943b", + "metadata": {}, + "outputs": [], + "source": [ + "# !pip install yacs timm segmentation-models-pytorch termcolor webdataset==0.2.86" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b4506576-5e30-417d-96de-8953d71c76c2", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import sys\n", + "import time\n", + "import random\n", + "import datetime\n", + "from tqdm import tqdm\n", + "import numpy as np\n", + "import logging\n", + "\n", + "import torch\n", + "\n", + "import matplotlib.pyplot as plt\n", + "from matplotlib.backends.backend_pdf import PdfPages\n", + "\n", + "import warnings\n", + "\n", + "warnings.filterwarnings('ignore') " + ] + }, + { + "cell_type": "markdown", + "id": "775cb720-5151-49fa-a7d5-7291ef663d45", + "metadata": {}, + "source": [ + "## 2. Model and Configuration Imports\n", + "\n", + "We load necessary modules from the pytorch-caney library, including the model, transformations, and plotting utilities." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "edf47149-f489-497b-8601-89a7e8dbd9b9", + "metadata": {}, + "outputs": [], + "source": [ + "sys.path.append('../../pytorch-caney')\n", + "\n", + "from pytorch_caney.models.mim import build_mim_model\n", + "from pytorch_caney.transforms.mim_modis_toa import MimTransform\n", + "from pytorch_caney.configs.config import _C, _update_config_from_file\n", + "from pytorch_caney.plotting.modis_toa import plot_export_pdf" + ] + }, + { + "cell_type": "markdown", + "id": "fe00e78e-fca3-4221-86dd-da205fed4192", + "metadata": {}, + "source": [ + "## 3. 
Fetching the model\n", + "\n", + "### 3.1 Clone the model checkpoint from Hugging Face\n", + "\n", + "Model repo: https://huggingface.co/nasa-cisto-data-science-group/satvision-toa-giant-patch8-window8-128\n", + "\n", + "```bash\n", + "# On prism/explore system\n", + "module load git-lfs\n", + "\n", + "git lfs install\n", + "\n", + "git clone git@hf.co:nasa-cisto-data-science-group/satvision-toa-giant-patch8-window8-128\n", + "```\n", + "\n", + "Note: If cloning with git over SSH, make sure your SSH keys are set up for authentication.\n", + "https://huggingface.co/docs/hub/security-git-ssh\n", + "\n", + "```bash\n", + "# If this identifies you as anonymous, follow the next steps.\n", + "ssh -T git@hf.co\n", + "```\n", + "\n", + "\n", + "```bash\n", + "eval $(ssh-agent)\n", + "\n", + "# Check if ssh-agent is using the proper key\n", + "ssh-add -l\n", + "\n", + "# If not\n", + "ssh-add ~/.ssh/your-key\n", + "\n", + "# Or if you want to use the default id_* key, just do\n", + "ssh-add\n", + "\n", + "```\n", + "\n", + "## 4. Fetching the validation dataset\n", + "\n", + "### 4.1 Clone the dataset repo from Hugging Face\n", + "\n", + "Dataset repo: https://huggingface.co/datasets/nasa-cisto-data-science-group/modis_toa_cloud_reconstruction_validation\n", + "\n", + "\n", + "```bash\n", + "# On prism/explore system\n", + "module load git-lfs\n", + "\n", + "git lfs install\n", + "\n", + "git clone git@hf.co:datasets/nasa-cisto-data-science-group/modis_toa_cloud_reconstruction_validation\n", + "\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "abb754ff-1753-4a4c-804e-8e3e5461fd0a", + "metadata": {}, + "source": [ + "## 5. Define Model and Data Paths\n", + "\n", + "Specify paths to the model checkpoint, configuration file, and validation dataset. Customize these paths as needed for your environment." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7ec267ce-ded1-40e6-8443-e1037297f710", + "metadata": {}, + "outputs": [], + "source": [ + "MODEL_PATH: str = '../../satvision-toa-giant-patch8-window8-128/mp_rank_00_model_states.pt'\n", + "CONFIG_PATH: str = '../../satvision-toa-giant-patch8-window8-128/mim_pretrain_swinv2_satvision_giant_128_window08_50ep.yaml'\n", + "\n", + "OUTPUT: str = '.'\n", + "DATA_PATH: str = '../../modis_toa_cloud_reconstruction_validation/sv_toa_128_chip_validation_04_24.npy'" + ] + }, + { + "cell_type": "markdown", + "id": "bd7d0b93-7fd3-49cb-ab9e-7536820ec5f2", + "metadata": {}, + "source": [ + "## 6. Configure Model\n", + "\n", + "Load and update the configuration for the SatVision-TOA model, specifying model and data paths." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "aac43f0e-dc4b-49ba-a482-933b5bab4b79", + "metadata": {}, + "outputs": [], + "source": [ + "# Update config given configurations\n", + "\n", + "config = _C.clone()\n", + "_update_config_from_file(config, CONFIG_PATH)\n", + "\n", + "config.defrost()\n", + "config.MODEL.PRETRAINED = MODEL_PATH\n", + "config.DATA.DATA_PATHS = [DATA_PATH]\n", + "config.OUTPUT = OUTPUT\n", + "config.freeze()" + ] + }, + { + "cell_type": "markdown", + "id": "1d596904-d1df-4f6d-8e88-4c647ac26924", + "metadata": {}, + "source": [ + "## 7. Load Model Weights from Checkpoint\n", + "\n", + "Build and initialize the model from the checkpoint to prepare for evaluation."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bfe245f7-589e-4b02-9990-15cb1733f6cb", + "metadata": {}, + "outputs": [], + "source": [ + "print('Building uninitialized model')\n", + "model = build_mim_model(config)\n", + "print('Successfully built uninitialized model')\n", + "\n", + "print(f'Attempting to load checkpoint from {config.MODEL.PRETRAINED}')\n", + "checkpoint = torch.load(config.MODEL.PRETRAINED)\n", + "model.load_state_dict(checkpoint['module'])\n", + "print('Successfully applied checkpoint')\n", + "model.cuda()\n", + "model.eval()" + ] + }, + { + "cell_type": "markdown", + "id": "20c26d1e-125a-4b4c-a21e-ab07d6977222", + "metadata": {}, + "source": [ + "## 8. Transform Validation Data\n", + "\n", + "The MODIS TOA dataset is loaded and transformed using MimTransform, generating a masked dataset for reconstruction." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9b3b47b1-0690-4ef9-bed6-ec243b5d42cb", + "metadata": {}, + "outputs": [], + "source": [ + "# Use the Masked-Image-Modeling transform specific to MODIS TOA data\n", + "transform = MimTransform(config)\n", + "\n", + "# The reconstruction evaluation set is a single numpy file\n", + "validation_dataset_path = config.DATA.DATA_PATHS[0]\n", + "validation_dataset = np.load(validation_dataset_path)\n", + "len_batch = range(validation_dataset.shape[0])\n", + "\n", + "# Apply transform to each image in the batch\n", + "# A mask is auto-generated in the transform\n", + "imgMasks = [transform(validation_dataset[idx]) for idx \\\n", + " in len_batch]\n", + "\n", + "# Separate imgs and masks, cast masks to torch tensor\n", + "img = torch.stack([imgMask[0] for imgMask in imgMasks])\n", + "mask = torch.stack([torch.from_numpy(imgMask[1]) for \\\n", + " imgMask in imgMasks])" + ] + }, + { + "cell_type": "markdown", + "id": "8b2148e4-da6d-4ae0-a194-c7adb62728a0", + "metadata": { + "tags": [] + }, + "source": [ + "## 9. Prediction\n", + "\n", + "Run predictions on each sample and calculate reconstruction losses. Each image is processed individually so that a per-image loss can be reported." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a3814751-f352-456e-850c-fe1d289b1d6b", + "metadata": {}, + "outputs": [], + "source": [ + "inputs = []\n", + "outputs = []\n", + "masks = []\n", + "losses = []\n", + "\n", + "# We could do this in a single batch; however, we\n", + "# want to report the loss per image instead of\n", + "# per batch.\n", + "for i in tqdm(range(img.shape[0])):\n", + " single_img = img[i].unsqueeze(0)\n", + " single_mask = mask[i].unsqueeze(0)\n", + " single_img = single_img.cuda(non_blocking=True)\n", + " single_mask = single_mask.cuda(non_blocking=True)\n", + "\n", + " with torch.no_grad():\n", + " z = model.encoder(single_img, single_mask)\n", + " img_recon = model.decoder(z)\n", + " loss = model(single_img, single_mask)\n", + "\n", + " inputs.extend(single_img.cpu())\n", + " masks.extend(single_mask.cpu())\n", + " outputs.extend(img_recon.cpu())\n", + " losses.append(loss.cpu()) " + ] + }, + { + "cell_type": "markdown", + "id": "22329bb4-5c6e-42dc-a492-8863fc2bf672", + "metadata": {}, + "source": [ + "## 10. Export Reconstruction Results to PDF\n", + "\n", + "Save and visualize the reconstruction results. The output PDF will contain the original, masked, and reconstructed version of each image."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8ac6a09d-5fe2-4aa9-ac37-f235d5a8020a", + "metadata": {}, + "outputs": [], + "source": [ + "pdfPath = '../../satvision-toa-reconstruction-validation-giant-example.pdf'\n", + "rgbIndex = [0, 2, 1] # Indices of [Red band, Blue band, Green band]\n", + "plot_export_pdf(pdfPath, inputs, outputs, masks, rgbIndex)" + ] + }, + { + "cell_type": "markdown", + "id": "1e0eb426-c7b4-47d4-aefa-2199ecfce2ab", + "metadata": {}, + "source": [ + "This notebook provides an end-to-end example for reconstructing satellite images with the SatVision-TOA model, from setup through prediction and output visualization." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "62065e24-ddf2-4bf1-8362-90dc0c9bf49e", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "ILAB Kernel (Pytorch)", + "language": "python", + "name": "pytorch-kernel" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.15" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/pytorch_caney/models/__init__.py b/pytorch_caney/models/__init__.py index 32eb253..e381289 100644 --- a/pytorch_caney/models/__init__.py +++ b/pytorch_caney/models/__init__.py @@ -1,2 +1,9 @@ from .model_factory import ModelFactory -from .mim import MiMModel \ No newline at end of file +from .mim import MiMModel +from .heads import SegmentationHead +from .decoders import FcnDecoder +from .encoders import SatVision, SwinTransformerV2, FcnEncoder + + +__all__ = [ModelFactory, MiMModel, SegmentationHead, + FcnDecoder, SatVision, SwinTransformerV2, FcnEncoder] diff --git a/pytorch_caney/models/decoders/__init__.py b/pytorch_caney/models/decoders/__init__.py index fafea15..303682d 100644 --- a/pytorch_caney/models/decoders/__init__.py +++ b/pytorch_caney/models/decoders/__init__.py @@ -1 +1,4 @@ -from .fcn_decoder import FcnDecoder \ No newline at end of file +from .fcn_decoder import FcnDecoder + + +__all__ = [FcnDecoder] diff --git a/pytorch_caney/models/decoders/fcn_decoder.py b/pytorch_caney/models/decoders/fcn_decoder.py index 232147d..cdb3d04 100644 --- a/pytorch_caney/models/decoders/fcn_decoder.py +++ b/pytorch_caney/models/decoders/fcn_decoder.py @@ -9,17 +9,17 @@ def __init__(self, num_features: int = 1024): super(FcnDecoder, self).__init__() self.output_channels = 64 self.decoder = nn.Sequential( - nn.ConvTranspose2d(num_features, 2048, kernel_size=3, stride=2, padding=1, output_padding=1), # 16x16x512 + nn.ConvTranspose2d(num_features, 2048, kernel_size=3, stride=2, padding=1, output_padding=1), # 16x16x512 # noqa: E501 nn.ReLU(), - nn.ConvTranspose2d(2048, 512, kernel_size=3, stride=2, padding=1, output_padding=1), # 32x32x256 + nn.ConvTranspose2d(2048, 512, kernel_size=3, stride=2, padding=1, output_padding=1), # 32x32x256 # noqa: E501 nn.ReLU(), - nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1), # 64x64x128 + nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1), # 64x64x128 # noqa: E501 nn.ReLU(), - nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1), # 64x64x128 + nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1), # 64x64x128 # noqa: E501 nn.ReLU(), - nn.ConvTranspose2d(128, 
self.output_channels, kernel_size=3, stride=2, padding=1, output_padding=1), # 128x128x64 + nn.ConvTranspose2d(128, self.output_channels, kernel_size=3, stride=2, padding=1, output_padding=1), # 128x128x64 # noqa: E501 nn.ReLU() ) def forward(self, x): - return self.decoder(x) \ No newline at end of file + return self.decoder(x) diff --git a/pytorch_caney/models/encoders/__init__.py b/pytorch_caney/models/encoders/__init__.py index c699db0..ac897ad 100644 --- a/pytorch_caney/models/encoders/__init__.py +++ b/pytorch_caney/models/encoders/__init__.py @@ -1,3 +1,6 @@ from .fcn_encoder import FcnEncoder from .satvision import SatVision -from .swinv2 import SwinTransformerV2 \ No newline at end of file +from .swinv2 import SwinTransformerV2 + + +__all__ = [FcnEncoder, SatVision, SwinTransformerV2] diff --git a/pytorch_caney/models/encoders/fcn_encoder.py b/pytorch_caney/models/encoders/fcn_encoder.py index 0c1e20f..3f77cc0 100644 --- a/pytorch_caney/models/encoders/fcn_encoder.py +++ b/pytorch_caney/models/encoders/fcn_encoder.py @@ -11,16 +11,16 @@ def __init__(self, config): self.num_input_channels = self.config.MODEL.IN_CHANS self.num_features = 1024 self.encoder = nn.Sequential( - nn.Conv2d(self.num_input_channels, 64, kernel_size=3, stride=1, padding=1), # 128x128x64 + nn.Conv2d(self.num_input_channels, 64, kernel_size=3, stride=1, padding=1), # 128x128x64 # noqa: E501 nn.ReLU(), - nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1), # 64x64x128 + nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1), # 64x64x128 # noqa: E501 nn.ReLU(), - nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1), # 32x32x256 + nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1), # 32x32x256 # noqa: E501 nn.ReLU(), - nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1), # 16x16x512 + nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1), # 16x16x512 # noqa: E501 nn.ReLU(), - nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1) # 8x8x1024 + nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1) # 8x8x1024 # noqa: E501 ) def forward(self, x): - return self.encoder(x) \ No newline at end of file + return self.encoder(x) diff --git a/pytorch_caney/models/encoders/satvision.py b/pytorch_caney/models/encoders/satvision.py index f19247d..06141d0 100644 --- a/pytorch_caney/models/encoders/satvision.py +++ b/pytorch_caney/models/encoders/satvision.py @@ -5,7 +5,7 @@ # ----------------------------------------------------------------------------- -# SatVision +# SatVision # ----------------------------------------------------------------------------- @ModelFactory.encoder("satvision") class SatVision(nn.Module): @@ -42,16 +42,17 @@ def __init__(self, config): if self.config.MODEL.PRETRAINED: self.load_pretrained() - self.num_classes = self.model.num_classes - self.num_layers = self.model.num_layers - self.num_features = self.model.num_features + self.num_classes = self.model.num_classes + self.num_layers = self.model.num_layers + self.num_features = self.model.num_features # ------------------------------------------------------------------------- # __init__ # ------------------------------------------------------------------------- def load_pretrained(self): - checkpoint = torch.load(self.config.MODEL.PRETRAINED, map_location='cpu') + checkpoint = torch.load( + self.config.MODEL.PRETRAINED, map_location='cpu') checkpoint_model = checkpoint['module'] @@ -77,14 +78,14 @@ def load_pretrained(self): torch.cuda.empty_cache() - print(f">>>>>>>>>> loaded successfully '{self.config.MODEL.PRETRAINED}'") + 
print(f">>>>>>> loaded successfully '{self.config.MODEL.PRETRAINED}'") # ------------------------------------------------------------------------- # forward # ------------------------------------------------------------------------- def forward(self, x): return self.model.forward(x) - + # ------------------------------------------------------------------------- # forward_features # ------------------------------------------------------------------------- diff --git a/pytorch_caney/models/encoders/swinv2.py b/pytorch_caney/models/encoders/swinv2.py index 1ceb257..fd4ed8f 100644 --- a/pytorch_caney/models/encoders/swinv2.py +++ b/pytorch_caney/models/encoders/swinv2.py @@ -219,7 +219,7 @@ def flops(self, N): # ----------------------------------------------------------------------------- -# Mlp +# Mlp # ----------------------------------------------------------------------------- class Mlp(nn.Module): def __init__(self, in_features, hidden_features=None, @@ -242,7 +242,7 @@ def forward(self, x): # ----------------------------------------------------------------------------- -# window_partition +# window_partition # ----------------------------------------------------------------------------- def window_partition(x, window_size): """ @@ -262,7 +262,7 @@ def window_partition(x, window_size): # ----------------------------------------------------------------------------- -# window_reverse +# window_reverse # ----------------------------------------------------------------------------- def window_reverse(windows, window_size, H, W): """ @@ -283,7 +283,7 @@ def window_reverse(windows, window_size, H, W): # ----------------------------------------------------------------------------- -# SwinTransformerBlock +# SwinTransformerBlock # ----------------------------------------------------------------------------- class SwinTransformerBlock(nn.Module): r""" Swin Transformer Block. @@ -494,7 +494,7 @@ def flops(self): # ----------------------------------------------------------------------------- -# BasicLayer +# BasicLayer # ----------------------------------------------------------------------------- class BasicLayer(nn.Module): """ A basic Swin Transformer layer for one stage. 
@@ -595,7 +595,7 @@ def _init_respostnorm(self): # ----------------------------------------------------------------------------- -# PatchEmbed +# PatchEmbed # ----------------------------------------------------------------------------- class PatchEmbed(nn.Module): r""" Image to Patch Embedding @@ -656,7 +656,7 @@ def flops(self): # ----------------------------------------------------------------------------- -# SwinTransformerV2 +# SwinTransformerV2 # ----------------------------------------------------------------------------- @ModelFactory.encoder("swinv2") class SwinTransformerV2(nn.Module): diff --git a/pytorch_caney/models/heads/__init__.py b/pytorch_caney/models/heads/__init__.py index df60bcd..fcf4565 100644 --- a/pytorch_caney/models/heads/__init__.py +++ b/pytorch_caney/models/heads/__init__.py @@ -1 +1,4 @@ -from .segmentation_head import SegmentationHead \ No newline at end of file +from .segmentation_head import SegmentationHead + + +__all__ = [SegmentationHead] diff --git a/pytorch_caney/models/heads/segmentation_head.py b/pytorch_caney/models/heads/segmentation_head.py index 86308e6..5561bac 100644 --- a/pytorch_caney/models/heads/segmentation_head.py +++ b/pytorch_caney/models/heads/segmentation_head.py @@ -2,15 +2,20 @@ from ..model_factory import ModelFactory + @ModelFactory.head("segmentation_head") class SegmentationHead(nn.Module): - def __init__(self, decoder_channels=128, num_classes=4, head_dropout=0.2, output_shape=(91, 40)): + def __init__(self, decoder_channels=128, num_classes=4, + head_dropout=0.2, output_shape=(91, 40)): super(SegmentationHead, self).__init__() self.head = nn.Sequential( - nn.Conv2d(decoder_channels, num_classes, kernel_size=3, stride=1, padding=1), + nn.Conv2d(decoder_channels, num_classes, + kernel_size=3, stride=1, padding=1), nn.Dropout(head_dropout), - nn.Upsample(size=output_shape, mode='bilinear', align_corners=False) + nn.Upsample(size=output_shape, + mode='bilinear', + align_corners=False) ) def forward(self, x): - return self.head(x) \ No newline at end of file + return self.head(x) diff --git a/pytorch_caney/models/mim.py b/pytorch_caney/models/mim.py index 6e41647..2d421cf 100644 --- a/pytorch_caney/models/mim.py +++ b/pytorch_caney/models/mim.py @@ -7,7 +7,7 @@ # ----------------------------------------------------------------------------- -# SwinTransformerV2ForMiM +# SwinTransformerV2ForMiM # ----------------------------------------------------------------------------- class SwinTransformerV2ForSimMIM(SwinTransformerV2): def __init__(self, **kwargs): @@ -48,7 +48,7 @@ def no_weight_decay(self): # ----------------------------------------------------------------------------- -# MiMModel +# MiMModel # ----------------------------------------------------------------------------- class MiMModel(nn.Module): """ @@ -101,7 +101,7 @@ def no_weight_decay_keywords(self): # ----------------------------------------------------------------------------- -# build_mim_model +# build_mim_model # ----------------------------------------------------------------------------- def build_mim_model(config): """Builds the masked-image-modeling model. diff --git a/pytorch_caney/optimizers/build.py b/pytorch_caney/optimizers/build.py index e2be339..c5d6c17 100644 --- a/pytorch_caney/optimizers/build.py +++ b/pytorch_caney/optimizers/build.py @@ -39,7 +39,7 @@ def get_optimizer_from_dict(optimizer_name, config): error_msg = f"{optimizer_name} is not an implemented optimizer" - error_msg = f"{error_msg}. 
Available optimizer functions: {OPTIMIZERS.keys()}" + error_msg = f"{error_msg}. Available optimizer functions: {OPTIMIZERS.keys()}" # noqa: E501 raise KeyError(error_msg) @@ -98,10 +98,10 @@ def build_optimizer(config, model, is_pretrain=False, logger=None): optimizer = None optimizer = optimizer_to_use(parameters, - eps=config.TRAIN.OPTIMIZER.EPS, - betas=config.TRAIN.OPTIMIZER.BETAS, - lr=config.TRAIN.BASE_LR, - weight_decay=config.TRAIN.WEIGHT_DECAY) + eps=config.TRAIN.OPTIMIZER.EPS, + betas=config.TRAIN.OPTIMIZER.BETAS, + lr=config.TRAIN.BASE_LR, + weight_decay=config.TRAIN.WEIGHT_DECAY) if logger: logger.info(optimizer) @@ -246,4 +246,3 @@ def get_swin_layer(name, num_layers, depths): else: return num_layers - 1 - diff --git a/pytorch_caney/optimizers/lamb.py b/pytorch_caney/optimizers/lamb.py index de6aebb..e51466d 100644 --- a/pytorch_caney/optimizers/lamb.py +++ b/pytorch_caney/optimizers/lamb.py @@ -1,14 +1,14 @@ """ PyTorch Lamb optimizer w/ behaviour similar to NVIDIA FusedLamb This optimizer code was adapted from the following (starting with latest) -* https://github.com/HabanaAI/Model-References/blob/2b435114fe8e31f159b1d3063b8280ae37af7423/PyTorch/nlp/bert/pretraining/lamb.py -* https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py +* https://github.com/HabanaAI/Model-References/blob/2b435114fe8e31f159b1d3063b8280ae37af7423/PyTorch/nlp/bert/pretraining/lamb.py # noqa: E501 +* https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py # noqa: E501 * https://github.com/cybertronai/pytorch-lamb -Use FusedLamb if you can (GPU). The reason for including this variant of Lamb is to have a version that is -similar in behaviour to APEX FusedLamb if you aren't using NVIDIA GPUs or cannot install/use APEX. +Use FusedLamb if you can (GPU). The reason for including this variant of Lamb is to have a version that is # noqa: E501 +similar in behaviour to APEX FusedLamb if you aren't using NVIDIA GPUs or cannot install/use APEX. # noqa: E501 -In addition to some cleanup, this Lamb impl has been modified to support PyTorch XLA and has been tested on TPU. +In addition to some cleanup, this Lamb impl has been modified to support PyTorch XLA and has been tested on TPU. # noqa: E501 Original copyrights for above sources are below. @@ -41,7 +41,7 @@ # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # -# The above copyright notice and this permission notice shall be included in all +# The above copyright notice and this permission notice shall be included in all # noqa: E501 # copies or substantial portions of the Software. 
# # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR @@ -60,7 +60,9 @@ from torch.utils.tensorboard import SummaryWriter -def log_lamb_rs(optimizer: Optimizer, event_writer: SummaryWriter, token_count: int): +def log_lamb_rs(optimizer: Optimizer, + event_writer: SummaryWriter, + token_count: int): """Log a histogram of trust ratio scalars in across layers.""" results = collections.defaultdict(list) for group in optimizer.param_groups: @@ -75,13 +77,13 @@ def log_lamb_rs(optimizer: Optimizer, event_writer: SummaryWriter, token_count: class Lamb(Optimizer): - """Implements a pure pytorch variant of FuseLAMB (NvLamb variant) optimizer from apex.optimizers.FusedLAMB - reference: https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py + """Implements a pure pytorch variant of FuseLAMB (NvLamb variant) optimizer from apex.optimizers.FusedLAMB # noqa: E501 + reference: https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py # noqa: E501 - LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_. + LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_. # noqa: E501 Arguments: - params (iterable): iterable of parameters to optimize or dicts defining parameter groups. + params (iterable): iterable of parameters to optimize or dicts defining parameter groups. # noqa: E501 lr (float, optional): learning rate. (default: 1e-3) betas (Tuple[float, float], optional): coefficients used for computing running averages of gradient and its norm. (default: (0.9, 0.999)) @@ -102,12 +104,17 @@ class Lamb(Optimizer): """ def __init__( - self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-6, - weight_decay=0.01, grad_averaging=True, max_grad_norm=1.0, trust_clip=False, always_adapt=False): + self, params, lr=1e-3, bias_correction=True, + betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01, + grad_averaging=True, max_grad_norm=1.0, + trust_clip=False, always_adapt=False): + defaults = dict( - lr=lr, bias_correction=bias_correction, betas=betas, eps=eps, weight_decay=weight_decay, + lr=lr, bias_correction=bias_correction, + betas=betas, eps=eps, weight_decay=weight_decay, grad_averaging=grad_averaging, max_grad_norm=max_grad_norm, trust_clip=trust_clip, always_adapt=always_adapt) + super().__init__(params, defaults) @torch.no_grad() @@ -123,7 +130,8 @@ def step(self, closure=None): loss = closure() device = self.param_groups[0]['params'][0].device - one_tensor = torch.tensor(1.0, device=device) # because torch.where doesn't handle scalars correctly + # because torch.where doesn't handle scalars correctly + one_tensor = torch.tensor(1.0, device=device) global_grad_norm = torch.zeros(1, device=device) for group in self.param_groups: for p in group['params']: @@ -131,13 +139,15 @@ def step(self, closure=None): continue grad = p.grad if grad.is_sparse: - raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instad.') + raise RuntimeError( + 'Lamb does not support sparse gradients, consider SparseAdam instad.') # noqa: E501 global_grad_norm.add_(grad.pow(2).sum()) global_grad_norm = torch.sqrt(global_grad_norm) - # FIXME it'd be nice to remove explicit tensor conversion of scalars when torch.where promotes - # scalar types properly https://github.com/pytorch/pytorch/issues/9190 - max_grad_norm = torch.tensor(self.defaults['max_grad_norm'], device=device) + # 
FIXME it'd be nice to remove explicit tensor conversion of scalars when torch.where promotes # noqa: E501 + # scalar types properly https://github.com/pytorch/pytorch/issues/9190 # noqa: E501 + max_grad_norm = torch.tensor(self.defaults['max_grad_norm'], + device=device) clip_global_grad_norm = torch.where( global_grad_norm > max_grad_norm, global_grad_norm / max_grad_norm, @@ -150,7 +160,7 @@ def step(self, closure=None): beta3 = 1 - beta1 if grad_averaging else 1.0 # assume same step across group now to simplify things - # per parameter step can be easily support by making it tensor, or pass list into kernel + # per parameter step can be easily support by making it tensor, or pass list into kernel # noqa: E501 if 'step' in group: group['step'] += 1 else: @@ -179,9 +189,10 @@ def step(self, closure=None): # Decay the first and second moment running average coefficient exp_avg.mul_(beta1).add_(grad, alpha=beta3) # m_t - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) # v_t + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) # v_t # noqa: E501 - denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) + denom = (exp_avg_sq.sqrt() / + math.sqrt(bias_correction2)).add_(group['eps']) update = (exp_avg / bias_correction1).div_(denom) weight_decay = group['weight_decay'] @@ -189,11 +200,11 @@ def step(self, closure=None): update.add_(p, alpha=weight_decay) if weight_decay != 0 or group['always_adapt']: - # Layer-wise LR adaptation. By default, skip adaptation on parameters that are - # excluded from weight decay, unless always_adapt == True, then always enabled. + # Layer-wise LR adaptation. By default, skip adaptation on parameters that are # noqa: E501 + # excluded from weight decay, unless always_adapt == True, then always enabled. 
# noqa: E501 w_norm = p.norm(2.0) g_norm = update.norm(2.0) - # FIXME nested where required since logical and/or not working in PT XLA + # FIXME nested where required since logical and/or not working in PT XLA # noqa: E501 trust_ratio = torch.where( w_norm > 0, torch.where(g_norm > 0, w_norm / g_norm, one_tensor), diff --git a/pytorch_caney/pipelines/__init__.py b/pytorch_caney/pipelines/__init__.py index f008267..911c274 100644 --- a/pytorch_caney/pipelines/__init__.py +++ b/pytorch_caney/pipelines/__init__.py @@ -1,10 +1,12 @@ from .satvision_toa_pretrain_pipeline import SatVisionToaPretrain -from .three_d_cloud_pipeline import ThreeDCloudTask +from .three_d_cloud_pipeline import ThreeDCloudTask + PIPELINES = { 'satvisiontoapretrain': SatVisionToaPretrain, '3dcloud': ThreeDCloudTask } + def get_available_pipelines(): return {name: cls for name, cls in PIPELINES.items()} diff --git a/pytorch_caney/pipelines/satvision_toa_pretrain_pipeline.py b/pytorch_caney/pipelines/satvision_toa_pretrain_pipeline.py index 07e3bb5..c5461fe 100644 --- a/pytorch_caney/pipelines/satvision_toa_pretrain_pipeline.py +++ b/pytorch_caney/pipelines/satvision_toa_pretrain_pipeline.py @@ -46,13 +46,13 @@ def __init__(self, config): batch_size=self.batch_size).dataset() # ------------------------------------------------------------------------- - # load_checkpoint + # load_checkpoint # ------------------------------------------------------------------------- def load_checkpoint(self): - print(f'Attempting to load checkpoint from {self.config.MODEL.PRETRAINED}') + print(f'Loading checkpoint from {self.config.MODEL.PRETRAINED}') checkpoint = torch.load(self.config.MODEL.PRETRAINED) self.model.load_state_dict(checkpoint['module']) - print(f'Successfully applied checkpoint') + print('Successfully applied checkpoint') # ------------------------------------------------------------------------- # forward @@ -73,16 +73,15 @@ def training_step(self, batch, batch_idx): self.train_loss_avg.compute(), rank_zero_only=True, batch_size=self.batch_size, - prog_bar=True - ) - + prog_bar=True) + return loss # ------------------------------------------------------------------------- # configure_optimizers # ------------------------------------------------------------------------- def configure_optimizers(self): - optimizer = build_optimizer(self.config, self.model, is_pretrain=True) + optimizer = build_optimizer(self.config, self.model, is_pretrain=True) return optimizer # ------------------------------------------------------------------------- diff --git a/pytorch_caney/pipelines/three_d_cloud_pipeline.py b/pytorch_caney/pipelines/three_d_cloud_pipeline.py index 492db5b..717c705 100644 --- a/pytorch_caney/pipelines/three_d_cloud_pipeline.py +++ b/pytorch_caney/pipelines/three_d_cloud_pipeline.py @@ -4,21 +4,20 @@ import lightning.pytorch as pl -from pytorch_caney.models.mim import build_mim_model from pytorch_caney.optimizers.build import build_optimizer from pytorch_caney.transforms.abi_toa import AbiToaTransform from pytorch_caney.models import ModelFactory -from pytorch_caney.models.decoders import FcnDecoder -from pytorch_caney.models.heads import SegmentationHead -from typing import Any, Tuple +from typing import Tuple + # ----------------------------------------------------------------------------- # ThreeDCloudTask # ----------------------------------------------------------------------------- class ThreeDCloudTask(pl.LightningModule): - NUM_CLASSES: int = 1 + NUM_CLASSES: int = 1 OUTPUT_SHAPE: Tuple[int, int] = (91, 40) + #
------------------------------------------------------------------------- # __init__ # ------------------------------------------------------------------------- @@ -32,7 +31,7 @@ def __init__(self, config): self.transform = AbiToaTransform(self.config) # ------------------------------------------------------------------------- - # configure_models + # configure_models # ------------------------------------------------------------------------- def configure_models(self): factory = ModelFactory() @@ -41,9 +40,10 @@ def configure_models(self): name=self.config.MODEL.ENCODER, config=self.config) - self.decoder = factory.get_component(component_type="decoder", - name=self.config.MODEL.DECODER, - num_features=self.encoder.num_features) + self.decoder = factory.get_component( + component_type="decoder", + name=self.config.MODEL.DECODER, + num_features=self.encoder.num_features) self.segmentation_head = factory.get_component( component_type="head", @@ -59,7 +59,7 @@ def configure_models(self): print(self.model) # ------------------------------------------------------------------------- - # configure_losses + # configure_losses # ------------------------------------------------------------------------- def configure_losses(self): loss: str = self.config.LOSS.NAME @@ -72,7 +72,7 @@ def configure_losses(self): ) # ------------------------------------------------------------------------- - # configure_metrics + # configure_metrics # ------------------------------------------------------------------------- def configure_metrics(self): num_classes = 2 @@ -90,7 +90,7 @@ def configure_metrics(self): # forward # ------------------------------------------------------------------------- def forward(self, x): - return self.model(x) + return self.model(x) # ------------------------------------------------------------------------- # training_step @@ -108,7 +108,7 @@ def training_step(self, batch, batch_idx): self.log('train_loss', self.train_loss_avg.compute(), on_step=True, on_epoch=True, prog_bar=True) self.log('train_iou', self.train_iou_avg.compute(), - on_step=True, on_epoch=True, prog_bar=True) + on_step=True, on_epoch=True, prog_bar=True) return loss # ------------------------------------------------------------------------- @@ -127,14 +127,14 @@ def validation_step(self, batch, batch_idx): on_step=True, on_epoch=True, prog_bar=True, sync_dist=True) self.log('val_iou', self.val_iou_avg.compute(), on_step=True, on_epoch=True, prog_bar=True, sync_dist=True) - + return val_loss # ------------------------------------------------------------------------- # configure_optimizers # ------------------------------------------------------------------------- def configure_optimizers(self): - optimizer = build_optimizer(self.config, self.model, is_pretrain=True) + optimizer = build_optimizer(self.config, self.model, is_pretrain=True) print(f'Using optimizer: {optimizer}') return optimizer diff --git a/pytorch_caney/plotting/__init__.py b/pytorch_caney/plotting/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pytorch_caney/plotting/modis_toa.py b/pytorch_caney/plotting/modis_toa.py new file mode 100644 index 0000000..37671b6 --- /dev/null +++ b/pytorch_caney/plotting/modis_toa.py @@ -0,0 +1,152 @@ +import numpy as np +import matplotlib.pyplot as plt +from matplotlib.backends.backend_pdf import PdfPages + +from ..transforms.modis_toa_scale import MinMaxEmissiveScaleReflectance + + +# ----------------------------------------------------------------------------- +# MODIS Reconstruction Visualization 
Pipeline +# ----------------------------------------------------------------------------- +# This script processes MODIS TOA images and model reconstructions, generating +# comparison visualizations in a PDF format. It contains several functions that +# interact to prepare, transform, and visualize MODIS image data, applying +# necessary transformations for reflective and emissive band scaling, masking, +# and normalization. The flow is as follows: +# +# 1. `plot_export_pdf`: Main function that generates PDF visualizations. +# It uses other functions to process and organize data. +# 2. `process_reconstruction_prediction`: Prepares images and masks for +# visualization, applying transformations and normalization. +# 3. `minmax_norm`: Scales image arrays to 0-255 range for display. +# 4. `process_mask`: Prepares mask images to match the input image dimensions. +# 5. `reverse_transform`: Applies band-specific scaling to MODIS data. +# +# ASCII Diagram: +# +# plot_export_pdf +# └── process_reconstruction_prediction +# ├── minmax_norm +# ├── process_mask +# └── reverse_transform +# +# ----------------------------------------------------------------------------- + + +# ----------------------------------------------------------------------------- +# plot_export_pdf +# ----------------------------------------------------------------------------- +# Generates a multi-page PDF with visualizations of original, reconstructed, +# and masked MODIS images. Uses the `process_reconstruction_prediction` funct +# to prepare data for display and organizes subplots for easy comparison. +# ----------------------------------------------------------------------------- +def plot_export_pdf(path, inputs, outputs, masks, rgb_index): + pdf_plot_obj = PdfPages(path) + + for idx in range(len(inputs)): + # prediction processing + image = inputs[idx] + img_recon = outputs[idx] + mask = masks[idx] + rgb_image, rgb_image_masked, rgb_recon_masked, mask = \ + process_reconstruction_prediction( + image, img_recon, mask, rgb_index) + + # matplotlib code + fig, (ax01, ax23) = plt.subplots(2, 2, figsize=(40, 30)) + ax0, ax1 = ax01 + ax2, ax3 = ax23 + ax2.imshow(rgb_image) + ax2.set_title(f"Idx: {idx} MOD021KM v6.1 Bands: {rgb_index}") + + ax0.imshow(rgb_recon_masked) + ax0.set_title(f"Idx: {idx} Model reconstruction") + + ax1.imshow(rgb_image_masked) + ax1.set_title(f"Idx: {idx} MOD021KM Bands: {rgb_index}, masked") + + ax3.matshow(mask[:, :, 0]) + ax3.set_title(f"Idx: {idx} Reconstruction Mask") + pdf_plot_obj.savefig() + + pdf_plot_obj.close() + + +# ----------------------------------------------------------------------------- +# process_reconstruction_prediction +# ----------------------------------------------------------------------------- +# Prepares RGB images, reconstructions, and masked versions by extracting and +# normalizing specific bands based on the provided RGB indices. Returns masked +# images and the processed mask for visualization in the PDF. 
+# ----------------------------------------------------------------------------- +def process_reconstruction_prediction(image, img_recon, mask, rgb_index): + + mask = process_mask(mask) + + red_idx = rgb_index[0] + blue_idx = rgb_index[1] + green_idx = rgb_index[2] + + image = reverse_transform(image.numpy()) + + img_recon = reverse_transform(img_recon.numpy()) + + rgb_image = np.stack((image[red_idx, :, :], + image[blue_idx, :, :], + image[green_idx, :, :]), axis=-1) + rgb_image = minmax_norm(rgb_image) + + rgb_image_recon = np.stack((img_recon[red_idx, :, :], + img_recon[blue_idx, :, :], + img_recon[green_idx, :, :]), axis=-1) + rgb_image_recon = minmax_norm(rgb_image_recon) + + rgb_masked = np.where(mask == 0, rgb_image, rgb_image_recon) + rgb_image_masked = np.where(mask == 1, 0, rgb_image) + rgb_recon_masked = rgb_masked + + return rgb_image, rgb_image_masked, rgb_recon_masked, mask + + +# ----------------------------------------------------------------------------- +# minmax_norm +# ----------------------------------------------------------------------------- +# Normalizes an image array to a range of 0-255 for consistent display. +# ----------------------------------------------------------------------------- +def minmax_norm(img_arr): + arr_min = img_arr.min() + arr_max = img_arr.max() + img_arr_scaled = (img_arr - arr_min) / (arr_max - arr_min) + img_arr_scaled = img_arr_scaled * 255 + img_arr_scaled = img_arr_scaled.astype(np.uint8) + return img_arr_scaled + + +# ----------------------------------------------------------------------------- +# process_mask +# ----------------------------------------------------------------------------- +# Adjusts the dimensions of a binary mask to match the input image shape, +# replicating mask values across the image. +# ----------------------------------------------------------------------------- +def process_mask(mask): + mask_img = mask.unsqueeze(0) + mask_img = mask_img.repeat_interleave(4, 1).repeat_interleave(4, 2) + mask_img = mask_img.unsqueeze(1).contiguous()[0, 0] + return np.stack([mask_img] * 3, axis=-1) + + +# ----------------------------------------------------------------------------- +# reverse_transform +# ----------------------------------------------------------------------------- +# Reverses scaling transformations applied to the original MODIS data to +# prepare the image for RGB visualization. 
+# ----------------------------------------------------------------------------- +def reverse_transform(image): + minMaxTransform = MinMaxEmissiveScaleReflectance() + image = image.transpose((1, 2, 0)) + image[:, :, minMaxTransform.reflectance_indices] *= 100 + emis_min, emis_max = \ + minMaxTransform.emissive_mins, minMaxTransform.emissive_maxs + image[:, :, minMaxTransform.emissive_indices] *= (emis_max - emis_min) + image[:, :, minMaxTransform.emissive_indices] += emis_min + return image.transpose((2, 0, 1)) diff --git a/pytorch_caney/transforms/abi_radiance_conversion.py b/pytorch_caney/transforms/abi_radiance_conversion.py index 71b5e3d..4470b75 100644 --- a/pytorch_caney/transforms/abi_radiance_conversion.py +++ b/pytorch_caney/transforms/abi_radiance_conversion.py @@ -11,7 +11,7 @@ def vis_calibrate(data): factor = np.pi * esd * esd / solar_irradiance return data * np.float32(factor) * 100 - + # ----------------------------------------------------------------------------- # ir_calibrate @@ -42,17 +42,18 @@ class ConvertABIToReflectanceBT(object): """ def __init__(self): - + self.reflectance_indices = [0, 1, 2, 3, 4, 6] self.emissive_indices = [5, 7, 8, 9, 10, 11, 12, 13] def __call__(self, img): - + # Reflectance % to reflectance units img[:, :, self.reflectance_indices] = \ vis_calibrate(img[:, :, self.reflectance_indices]) - + # Brightness temp scaled to (0,1) range - img[:, :, self.emissive_indices] = ir_calibrate(img[:, :, self.emissive_indices]) - - return img \ No newline at end of file + img[:, :, self.emissive_indices] = ir_calibrate( + img[:, :, self.emissive_indices]) + + return img diff --git a/pytorch_caney/transforms/abi_toa.py b/pytorch_caney/transforms/abi_toa.py index f762b25..30afb9c 100644 --- a/pytorch_caney/transforms/abi_toa.py +++ b/pytorch_caney/transforms/abi_toa.py @@ -17,7 +17,7 @@ def __init__(self, img_size): self.transform_img = \ T.Compose([ - ConvertABIToReflectanceBT(), # New transform for MinMax + ConvertABIToReflectanceBT(), MinMaxEmissiveScaleReflectance(), T.ToTensor(), T.Resize((img_size, img_size), antialias=True), diff --git a/pytorch_caney/transforms/abi_toa_scale.py b/pytorch_caney/transforms/abi_toa_scale.py index 0d4cf1e..852aafd 100644 --- a/pytorch_caney/transforms/abi_toa_scale.py +++ b/pytorch_caney/transforms/abi_toa_scale.py @@ -9,7 +9,7 @@ class MinMaxEmissiveScaleReflectance(object): """ def __init__(self): - + self.reflectance_indices = [0, 1, 2, 3, 4, 6] self.emissive_indices = [5, 7, 8, 9, 10, 11, 12, 13] @@ -24,14 +24,14 @@ def __init__(self): dtype=np.float32) def __call__(self, img): - + # Reflectance % to reflectance units img[:, :, self.reflectance_indices] = \ img[:, :, self.reflectance_indices] * 0.01 - + # Brightness temp scaled to (0,1) range img[:, :, self.emissive_indices] = \ (img[:, :, self.emissive_indices] - self.emissive_mins) / \ - (self.emissive_maxs - self.emissive_mins) - - return img \ No newline at end of file + (self.emissive_maxs - self.emissive_mins) + + return img diff --git a/pytorch_caney/transforms/mim_mask_generator.py b/pytorch_caney/transforms/mim_mask_generator.py index c101d3c..530b1ca 100644 --- a/pytorch_caney/transforms/mim_mask_generator.py +++ b/pytorch_caney/transforms/mim_mask_generator.py @@ -30,7 +30,7 @@ def __init__(self, def __call__(self): mask = make_mim_mask(self.token_count, self.mask_count, - self.rand_size, self.scale) + self.rand_size, self.scale) mask = mask.repeat(self.scale, axis=0).repeat(self.scale, axis=1) return mask diff --git a/pytorch_caney/transforms/mim_modis_toa.py 
b/pytorch_caney/transforms/mim_modis_toa.py index c111600..1d168d9 100644 --- a/pytorch_caney/transforms/mim_modis_toa.py +++ b/pytorch_caney/transforms/mim_modis_toa.py @@ -1,7 +1,7 @@ import torchvision.transforms as T from .random_resize_crop import RandomResizedCropNP -from .mim_mask_generator import MimMaskGenerator +from .mim_mask_generator import MimMaskGenerator from .modis_toa_scale import MinMaxEmissiveScaleReflectance @@ -18,7 +18,7 @@ def __init__(self, config): self.transform_img = \ T.Compose([ - MinMaxEmissiveScaleReflectance(), # New transform for MinMax + MinMaxEmissiveScaleReflectance(), RandomResizedCropNP(scale=(0.67, 1.), ratio=(3. / 4., 4. / 3.)), T.ToTensor(), diff --git a/pytorch_caney/transforms/modis_toa.py b/pytorch_caney/transforms/modis_toa.py index fd52805..24fdb1f 100644 --- a/pytorch_caney/transforms/modis_toa.py +++ b/pytorch_caney/transforms/modis_toa.py @@ -15,7 +15,7 @@ def __init__(self, config): self.transform_img = \ T.Compose([ - MinMaxEmissiveScaleReflectance(), # New transform for MinMax + MinMaxEmissiveScaleReflectance(), T.ToTensor(), T.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE)), ]) @@ -24,4 +24,4 @@ def __call__(self, img): img = self.transform_img(img) - return img \ No newline at end of file + return img diff --git a/pytorch_caney/transforms/modis_toa_scale.py b/pytorch_caney/transforms/modis_toa_scale.py index b256a79..1eb5a30 100644 --- a/pytorch_caney/transforms/modis_toa_scale.py +++ b/pytorch_caney/transforms/modis_toa_scale.py @@ -12,7 +12,7 @@ class MinMaxEmissiveScaleReflectance(object): """ def __init__(self): - + self.reflectance_indices = [0, 1, 2, 3, 4, 6] self.emissive_indices = [5, 7, 8, 9, 10, 11, 12, 13] @@ -27,14 +27,14 @@ def __init__(self): dtype=np.float32) def __call__(self, img): - + # Reflectance % to reflectance units img[:, :, self.reflectance_indices] = \ img[:, :, self.reflectance_indices] * 0.01 - + # Brightness temp scaled to (0,1) range img[:, :, self.emissive_indices] = \ (img[:, :, self.emissive_indices] - self.emissive_mins) / \ - (self.emissive_maxs - self.emissive_mins) - + (self.emissive_maxs - self.emissive_mins) + return img diff --git a/pytorch_caney/transforms/random_resize_crop.py b/pytorch_caney/transforms/random_resize_crop.py index 8eab062..06609ec 100644 --- a/pytorch_caney/transforms/random_resize_crop.py +++ b/pytorch_caney/transforms/random_resize_crop.py @@ -60,4 +60,4 @@ def __call__(self, img): align_corners=False) cropped_squeezed_numpy = cropped_resized.squeeze().numpy() cropped_squeezed_numpy = np.moveaxis(cropped_squeezed_numpy, 0, -1) - return cropped_squeezed_numpy \ No newline at end of file + return cropped_squeezed_numpy diff --git a/pytorch_caney/utils.py b/pytorch_caney/utils.py index 9cde7cf..fc7e46d 100644 --- a/pytorch_caney/utils.py +++ b/pytorch_caney/utils.py @@ -15,21 +15,24 @@ def get_strategy(config): "zero_allow_untested_optimizer": True, "zero_optimization": { "stage": config.DEEPSPEED.STAGE, - "contiguous_gradients": config.DEEPSPEED.CONTIGUOUS_GRADIENTS, + "contiguous_gradients": + config.DEEPSPEED.CONTIGUOUS_GRADIENTS, "overlap_comm": config.DEEPSPEED.OVERLAP_COMM, "reduce_bucket_size": config.DEEPSPEED.REDUCE_BUCKET_SIZE, - "allgather_bucket_size": config.DEEPSPEED.ALLGATHER_BUCKET_SIZE, + "allgather_bucket_size": + config.DEEPSPEED.ALLGATHER_BUCKET_SIZE, }, "activation_checkpointing": { - "partition_activations": config.TRAIN.USE_CHECKPOINT, + "partition_activations": config.TRAIN.USE_CHECKPOINT, }, } - + return DeepSpeedStrategy(config=deepspeed_config) 
else: # These may be return as strings - return strategy + return strategy + # ----------------------------------------------------------------------------- # get_distributed_train_batches @@ -38,4 +41,5 @@ def get_distributed_train_batches(config, trainer): if config.TRAIN.NUM_TRAIN_BATCHES: return config.TRAIN.NUM_TRAIN_BATCHES else: - return config.DATA.LENGTH // (config.DATA.BATCH_SIZE * trainer.world_size) + return config.DATA.LENGTH // \ + (config.DATA.BATCH_SIZE * trainer.world_size)
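
A minimal sketch of the mask handling used by the new `pytorch_caney/plotting/modis_toa.py` module, for readers who want to sanity-check `process_mask` and the compositing in `process_reconstruction_prediction` outside the full pipeline. It uses toy shapes (a 2x2 patch mask and 8x8 three-band chips, with random arrays standing in for imagery) in place of the real MIM masks and 128x128 MODIS chips:

```python
import numpy as np
import torch

# Toy patch-level MIM mask: 1 marks a patch the model had to reconstruct.
patch_mask = torch.tensor([[0, 1],
                           [1, 0]])

# Upsample to pixel resolution the way process_mask() does, by repeating
# each mask value across its patch footprint, then stack to 3 channels.
pixel_mask = patch_mask.repeat_interleave(4, 0).repeat_interleave(4, 1)
pixel_mask = np.stack([pixel_mask.numpy()] * 3, axis=-1)   # (8, 8, 3)

rgb_image = np.random.rand(8, 8, 3)   # stand-in for the original chip
rgb_recon = np.random.rand(8, 8, 3)   # stand-in for the model output

# Composite as in process_reconstruction_prediction(): keep original pixels
# where the mask is 0, fill masked patches from the reconstruction.
rgb_recon_masked = np.where(pixel_mask == 0, rgb_image, rgb_recon)

# Masked-input view: zero out the patches the model did not see.
rgb_image_masked = np.where(pixel_mask == 1, 0, rgb_image)

print(pixel_mask.shape, rgb_recon_masked.shape, rgb_image_masked.shape)
```

In the notebook itself, the same logic runs on the transformed validation chips, with the per-patch mask auto-generated by `MimTransform`.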