From e8b937d9c3cbf5db63876617e3aa7f10dcb48964 Mon Sep 17 00:00:00 2001
From: jordancaraballo
Date: Tue, 1 Oct 2024 23:38:13 -0400
Subject: [PATCH] Adding datamodule

---
 above_shrubs/datamodules/chm_datamodule.py |  72 ++++++++++-
 above_shrubs/pipelines/base_pipeline.py    |   7 +-
 above_shrubs/pipelines/chm_pipeline.py     | 141 +++++++++++++++------
 projects/chm/configs/above_shrubs_chm.yaml |   2 +
 requirements/Dockerfile                    |   4 +-
 5 files changed, 180 insertions(+), 46 deletions(-)

diff --git a/above_shrubs/datamodules/chm_datamodule.py b/above_shrubs/datamodules/chm_datamodule.py
index 16e7c04..4a8368e 100644
--- a/above_shrubs/datamodules/chm_datamodule.py
+++ b/above_shrubs/datamodules/chm_datamodule.py
@@ -1,13 +1,80 @@
 import torch
 import lightning as L
+#import kornia.augmentation as K  # noqa: N812
+
+from typing import Any
 from torchvision import transforms
+from torchgeo.datamodules import NonGeoDataModule
 from torch.utils.data import random_split, DataLoader
+from torchgeo.transforms import AugmentationSequential
 from above_shrubs.datasets.chm_dataset import CHMDataset
-# Note - you must have torchvision installed for this example
-from torchvision.datasets import MNIST
+
+MEANS = [325.04178, 518.01135, 393.07028, 2660.147, 343.5341]
+STDS = [80.556175, 133.02502, 135.68076, 822.97205, 116.81135]
+
+
+class CHMDataModule(NonGeoDataModule):
+    """NonGeo CHM data module implementation."""
+
+    def __init__(
+        self,
+        train_data_dir: str,
+        train_label_dir: str,
+        test_data_dir: str = None,
+        test_label_dir: str = None,
+        batch_size: int = 16,
+        num_workers: int = 8,
+        tile_size: tuple = (224, 224),
+        **kwargs: Any
+    ) -> None:
+
+        super().__init__(CHMDataset, batch_size, num_workers, **kwargs)
+        # applied for training
+        #self.train_aug = AugmentationSequential(
+        #    K.Normalize(MEANS, STDS),
+        #    K.RandomCrop(tile_size),
+        #    data_keys=["image", "mask"],
+        #)
+        #self.aug = AugmentationSequential(
+        #    K.Normalize(MEANS, STDS),
+        #    data_keys=["image", "mask"]
+        #)
+
+        # tile size
+        self.tile_size = tile_size
+
+        # training paths
+        self.train_data_dir = train_data_dir
+        self.train_label_dir = train_label_dir
+
+        # test paths
+        self.test_data_dir = test_data_dir
+        self.test_label_dir = test_label_dir
+
+    def setup(self, stage: str) -> None:
+        if stage in ["fit"]:
+            self.train_dataset = self.dataset_class(
+                self.train_data_dir,
+                self.train_label_dir,
+                img_size=self.tile_size,
+            )
+        if stage in ["fit", "validate"]:
+            self.val_dataset = self.dataset_class(
+                self.test_data_dir,
+                self.test_label_dir,
+                img_size=self.tile_size,
+            )
+        if stage in ["test"]:
+            self.test_dataset = self.dataset_class(
+                self.test_data_dir,
+                self.test_label_dir,
+                img_size=self.tile_size,
+            )
+
+"""
 
 
 class CHMDataModule(L.LightningDataModule):
@@ -47,3 +114,4 @@ def test_dataloader(self):
 
     def predict_dataloader(self):
         return DataLoader(self.mnist_predict, batch_size=32)
+"""
\ No newline at end of file
diff --git a/above_shrubs/pipelines/base_pipeline.py b/above_shrubs/pipelines/base_pipeline.py
index d62cb7f..a9c1b44 100644
--- a/above_shrubs/pipelines/base_pipeline.py
+++ b/above_shrubs/pipelines/base_pipeline.py
@@ -168,11 +168,12 @@ def seed_everything(self, seed: int = 42) -> None:
         Returns:
             None.
""" - np.random.seed(seed) + np.random.seed(int(seed)) if HAS_GPU: try: - cp.random.seed(seed) - except RuntimeError: + cp.random.seed(int(seed)) + except (RuntimeError, TypeError): + logging.warning('Seed could not be fixed for cupy.') return return diff --git a/above_shrubs/pipelines/chm_pipeline.py b/above_shrubs/pipelines/chm_pipeline.py index b75d53a..6ed72a9 100644 --- a/above_shrubs/pipelines/chm_pipeline.py +++ b/above_shrubs/pipelines/chm_pipeline.py @@ -3,15 +3,33 @@ import timm import torch -import terratorch import numpy as np +import pandas as pd +from tqdm import tqdm from itertools import repeat from torch.utils.data import DataLoader from multiprocessing import Pool, cpu_count from above_shrubs.datasets.chm_dataset import CHMDataset +from above_shrubs.datamodules.chm_datamodule import CHMDataModule from above_shrubs.pipelines.base_pipeline import BasePipeline + +# temporary additions + +import os +import tempfile + +import torch +from lightning.pytorch import Trainer +from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint +from lightning.pytorch.loggers import TensorBoardLogger + +from torchgeo.datamodules import EuroSAT100DataModule +from torchgeo.models import ResNet18_Weights +from torchgeo.trainers import PixelwiseRegressionTask + + CHUNKS = {'band': 'auto', 'x': 'auto', 'y': 'auto'} xp = np @@ -79,33 +97,34 @@ def preprocess(self): # Calculate mean and std values for training data_filenames = self.get_dataset_filenames( - self.train_data_dir, , ext='*.tif') + self.conf.train_data_dir, ext='*.tif') logging.info(f'Mean and std values from {len(data_filenames)} files.') # Temporarily disable standardization and augmentation - #current_standardization = self.conf.standardization - #self.conf.standardization = None - #metadata_output_filename = os.path.join( - # self.metadata_dir, 'mean-std-values.csv') + current_standardization = self.conf.standardization + self.conf.standardization = None + metadata_output_filename = os.path.join( + self.metadata_dir, 'mean-std-values.csv') # Set main data loader - #chm_train_dataset = CHMDataset( - # os.path.join(self.conf.train_tiles_dir, 'images'), - # os.path.join(self.conf.train_tiles_dir, 'labels'), - # img_size=(self.conf.tile_size, self.conf.tile_size), - #) - #train_dataloader = DataLoader( - # chm_train_dataset, - # batch_size=self.conf.batch_size, shuffle=False - #) + chm_train_dataset = CHMDataset( + self.conf.train_data_dir, + self.conf.train_label_dir, + img_size=(self.conf.tile_size, self.conf.tile_size), + ) + + train_dataloader = DataLoader( + chm_train_dataset, + batch_size=self.conf.batch_size, shuffle=False + ) # Get mean and std array - #mean, std = self.get_mean_std_dataset( - # train_dataloader, metadata_output_filename) - #logging.info(f'Mean: {mean.numpy()}, Std: {std.numpy()}') + mean, std = self.get_mean_std_dataset( + train_dataloader, metadata_output_filename) + logging.info(f'Mean: {mean.numpy()}, Std: {std.numpy()}') # Re-enable standardization for next pipeline step - #self.conf.standardization = current_standardization + self.conf.standardization = current_standardization logging.info('Done with preprocessing stage') @@ -120,31 +139,45 @@ def train(self): self._set_train_test_dirs() - # Set main data loader - chm_train_dataset = CHMDataset( - os.path.join(self.conf.train_tiles_dir, 'images'), - os.path.join(self.conf.train_tiles_dir, 'labels'), - img_size=(self.conf.tile_size, self.conf.tile_size), + batch_size = 10 + num_workers = 2 + max_epochs = 50 + fast_dev_run = False + + datamodule = 
+            train_data_dir=self.conf.train_data_dir,
+            train_label_dir=self.conf.train_label_dir,
+            test_data_dir=self.conf.test_data_dir,
+            test_label_dir=self.conf.test_label_dir,
+            batch_size=batch_size,
+            num_workers=num_workers,
         )
 
-        chm_test_dataset = CHMDataset(
-            os.path.join(self.conf.test_tiles_dir, 'images'),
-            os.path.join(self.conf.test_tiles_dir, 'labels'),
-            img_size=(self.conf.tile_size, self.conf.tile_size),
-        )
+        # Set main data loader
+        #chm_train_dataset = CHMDataset(
+        #    os.path.join(self.conf.train_tiles_dir, 'images'),
+        #    os.path.join(self.conf.train_tiles_dir, 'labels'),
+        #    img_size=(self.conf.tile_size, self.conf.tile_size),
+        #)
+
+        #chm_test_dataset = CHMDataset(
+        #    os.path.join(self.conf.test_tiles_dir, 'images'),
+        #    os.path.join(self.conf.test_tiles_dir, 'labels'),
+        #    img_size=(self.conf.tile_size, self.conf.tile_size),
+        #)
 
         # start dataloader
-        train_dataloader = DataLoader(
-            chm_train_dataset,
-            batch_size=self.conf.batch_size, shuffle=True
-        )
+        #train_dataloader = DataLoader(
+        #    chm_train_dataset,
+        #    batch_size=self.conf.batch_size, shuffle=True
+        #)
 
-        test_dataloader = DataLoader(
-            chm_test_dataset,
-            batch_size=self.conf.batch_size, shuffle=False
-        )
+        #test_dataloader = DataLoader(
+        #    chm_test_dataset,
+        #    batch_size=self.conf.batch_size, shuffle=False
+        #)
 
-        print(timm.list_models("prithvi*"))
+        #print(timm.list_models("prithvi*"))
 
         # Build model
         # model = build_finetune_model(config, logger)
@@ -154,4 +187,34 @@ def train(self):
         #     is_pretrain=False,
         #     logger=logger)
 
-        return
\ No newline at end of file
+        return
+
+    # -------------------------------------------------------------------------
+    # get_mean_std_dataset
+    # -------------------------------------------------------------------------
+    def get_mean_std_dataset(self, dataloader, output_filename: str):
+
+        # Per-channel reduction dims: channels-first (B, C, H, W) reduces
+        # over (0, 2, 3), channels-last (B, H, W, C) over (0, 1, 2).
+        test_sample_shape = next(iter(dataloader))['image'].shape
+        if test_sample_shape[1] < test_sample_shape[2]:
+            mean_dims = (0, 2, 3)
+        else:
+            mean_dims = (0, 1, 2)
+
+        # Running sums for per-channel mean/std across all batches
+        channels_sum, channels_squared_sum, num_batches = 0, 0, 0
+        for data_dict in tqdm(dataloader):
+            data = data_dict['image'].cuda()
+            channels_sum += torch.mean(data, dim=mean_dims)
+            channels_squared_sum += torch.mean(data**2, dim=mean_dims)
+            num_batches += 1
+
+        mean = channels_sum / num_batches
+        std = (channels_squared_sum / num_batches - mean ** 2) ** 0.5
+
+        mean, std = mean.cpu(), std.cpu()
+
+        if output_filename is not None:
+            mean_std = np.stack(
+                [mean.numpy(), std.numpy()], axis=0)
+            pd.DataFrame(mean_std).to_csv(
+                output_filename, header=False, index=False)
+        return mean, std
diff --git a/projects/chm/configs/above_shrubs_chm.yaml b/projects/chm/configs/above_shrubs_chm.yaml
index e849219..4e361f8 100644
--- a/projects/chm/configs/above_shrubs_chm.yaml
+++ b/projects/chm/configs/above_shrubs_chm.yaml
@@ -21,6 +21,8 @@ test_tiles_dir: '${work_dir}/test_data'
 #dtm_path: /explore/nobackup/projects/dem/AK_IFSAR/alaska_ifsar_dtm_20221222.tif
 #dsm_path: /explore/nobackup/projects/dem/AK_IFSAR/alaska_ifsar_dsm_20221222.tif
 
+batch_size: 512
+
 #------------------------------------- General -------------------------------------#
 
 #model_dir: '/explore/nobackup/people/${oc.env:USER}/projects/SRLite/regression-models'
diff --git a/requirements/Dockerfile b/requirements/Dockerfile
index cbdfccf..f3c5426 100644
--- a/requirements/Dockerfile
+++ b/requirements/Dockerfile
@@ -1,5 +1,5 @@
 # Arguments to pass to the image
-ARG VERSION_DATE=24.08
+ARG VERSION_DATE=24.09
 ARG FROM_IMAGE=nvcr.io/nvidia/pytorch
 
 # Import RAPIDS container as the BASE Image (cuda base image)
@@ -18,7 +18,7 @@ RUN apt-get update && \
 RUN pip --no-cache-dir install --ignore-installed omegaconf \
     #terratorch \
     pytorch-lightning \
-    Lightning \
+    #Lightning \
    transformers \
     datasets \
     webdataset \
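
The patch imports PixelwiseRegressionTask and Lightning's Trainer but does not yet wire them to the new CHMDataModule. Below is a minimal smoke-test sketch of one possible wiring; the directory paths, backbone choice, and hyperparameter values are illustrative assumptions, not values taken from this patch:

# sketch_chm_training.py -- hypothetical usage, not part of the patch
from lightning.pytorch import Trainer
from torchgeo.trainers import PixelwiseRegressionTask

from above_shrubs.datamodules.chm_datamodule import CHMDataModule

# Placeholder tile directories; substitute the configured
# train_data_dir / train_label_dir / test_data_dir / test_label_dir.
datamodule = CHMDataModule(
    train_data_dir='train_data/images',
    train_label_dir='train_data/labels',
    test_data_dir='test_data/images',
    test_label_dir='test_data/labels',
    batch_size=16,
    num_workers=8,
)

# Assumed task settings: in_channels=5 matches the 5-element MEANS/STDS
# defined in the datamodule; the U-Net/ResNet-18 choice is arbitrary here.
task = PixelwiseRegressionTask(
    model='unet',
    backbone='resnet18',
    in_channels=5,
    loss='mse',
)

# fast_dev_run pushes a single batch through fit/validate as a sanity check
# before committing to a full training run.
trainer = Trainer(max_epochs=50, fast_dev_run=True)
trainer.fit(task, datamodule=datamodule)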