Adding datamodule
jordancaraballo committed Oct 2, 2024
1 parent b6d8ad0 commit e8b937d
Showing 5 changed files with 180 additions and 46 deletions.
72 changes: 70 additions & 2 deletions above_shrubs/datamodules/chm_datamodule.py
@@ -1,13 +1,80 @@
import torch
import lightning as L
#import kornia.augmentation as K # noqa: N812

from typing import Any, Optional
from torchvision import transforms
from torchgeo.datamodules import NonGeoDataModule
from torch.utils.data import random_split, DataLoader
from torchgeo.transforms import AugmentationSequential
from above_shrubs.datasets.chm_dataset import CHMDataset

# NOTE: MNIST is only used by the commented-out example module below
# from torchvision.datasets import MNIST

MEANS = [325.04178, 518.01135, 393.07028, 2660.147, 343.5341]
STDS = [80.556175, 133.02502, 135.68076, 822.97205, 116.81135]


class CHMDataModule(NonGeoDataModule):
"""NonGeo Fire Scars data module implementation"""

def __init__(
self,
train_data_dir: str,
train_label_dir: str,
        test_data_dir: Optional[str] = None,
        test_label_dir: Optional[str] = None,
batch_size: int = 16,
num_workers: int = 8,
tile_size: tuple = (224, 224),
**kwargs: Any
) -> None:

super().__init__(CHMDataset, batch_size, num_workers, **kwargs)
# applied for training
#self.train_aug = AugmentationSequential(
# K.Normalize(MEANS, STDS),
# K.RandomCrop(tile_size),
# data_keys=["image", "mask"],
#)
#self.aug = AugmentationSequential(
# K.Normalize(MEANS, STDS),
# data_keys=["image", "mask"]
#)

# tile size
self.tile_size = tile_size

# training paths
self.train_data_dir = train_data_dir
self.train_label_dir = train_label_dir

# test paths
self.test_data_dir = test_data_dir
self.test_label_dir = test_label_dir

def setup(self, stage: str) -> None:
if stage in ["fit"]:
self.train_dataset = self.dataset_class(
self.train_data_dir,
self.train_label_dir,
img_size=self.tile_size,
)
if stage in ["fit", "validate"]:
self.val_dataset = self.dataset_class(
self.test_data_dir,
self.test_label_dir,
img_size=self.tile_size,
)
if stage in ["test"]:
self.test_dataset = self.dataset_class(
self.test_data_dir,
self.test_label_dir,
img_size=self.tile_size,
)



"""
class CHMDataModule(L.LightningDataModule):
@@ -47,3 +114,4 @@ def test_dataloader(self):
def predict_dataloader(self):
return DataLoader(self.mnist_predict, batch_size=32)
"""
7 changes: 4 additions & 3 deletions above_shrubs/pipelines/base_pipeline.py
@@ -168,11 +168,12 @@ def seed_everything(self, seed: int = 42) -> None:
Returns:
None.
"""
-        np.random.seed(seed)
+        np.random.seed(int(seed))
        if HAS_GPU:
            try:
-                cp.random.seed(seed)
-            except RuntimeError:
+                cp.random.seed(int(seed))
+            except (RuntimeError, TypeError):
                logging.warning('Seed could not be fixed for cupy.')
                return
+        return
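
The `int(seed)` coercion and the wider `except` matter because seeds often arrive from YAML or CLI configs as strings or floats, which `np.random.seed` and `cp.random.seed` reject with a `TypeError`. A standalone sketch of the pattern (illustrative only, not repo code; cupy treated as optional):

```python
# Sketch of the hardened seeding pattern above -- illustrative only.
import logging

import numpy as np

def safe_seed(seed=42) -> None:
    # int() accepts "42" or 42.0; np.random.seed alone would raise TypeError
    np.random.seed(int(seed))
    try:
        import cupy as cp  # optional GPU dependency
        cp.random.seed(int(seed))
    except (ImportError, RuntimeError, TypeError):
        logging.warning('Seed could not be fixed for cupy.')
```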

141 changes: 102 additions & 39 deletions above_shrubs/pipelines/chm_pipeline.py
@@ -3,15 +3,33 @@

import timm
import torch
import terratorch
import numpy as np
import pandas as pd

from tqdm import tqdm
from itertools import repeat
from torch.utils.data import DataLoader
from multiprocessing import Pool, cpu_count
from above_shrubs.datasets.chm_dataset import CHMDataset
from above_shrubs.datamodules.chm_datamodule import CHMDataModule
from above_shrubs.pipelines.base_pipeline import BasePipeline


# temporary additions

import os
import tempfile

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger

from torchgeo.datamodules import EuroSAT100DataModule
from torchgeo.models import ResNet18_Weights
from torchgeo.trainers import PixelwiseRegressionTask


CHUNKS = {'band': 'auto', 'x': 'auto', 'y': 'auto'}
xp = np

@@ -79,33 +97,34 @@ def preprocess(self):

# Calculate mean and std values for training
data_filenames = self.get_dataset_filenames(
-            self.train_data_dir, , ext='*.tif')
+            self.conf.train_data_dir, ext='*.tif')
logging.info(f'Mean and std values from {len(data_filenames)} files.')

# Temporarily disable standardization and augmentation
-        #current_standardization = self.conf.standardization
-        #self.conf.standardization = None
-        #metadata_output_filename = os.path.join(
-        #    self.metadata_dir, 'mean-std-values.csv')
+        current_standardization = self.conf.standardization
+        self.conf.standardization = None
+        metadata_output_filename = os.path.join(
+            self.metadata_dir, 'mean-std-values.csv')

# Set main data loader
-        #chm_train_dataset = CHMDataset(
-        #    os.path.join(self.conf.train_tiles_dir, 'images'),
-        #    os.path.join(self.conf.train_tiles_dir, 'labels'),
-        #    img_size=(self.conf.tile_size, self.conf.tile_size),
-        #)
-        #train_dataloader = DataLoader(
-        #    chm_train_dataset,
-        #    batch_size=self.conf.batch_size, shuffle=False
-        #)
+        chm_train_dataset = CHMDataset(
+            self.conf.train_data_dir,
+            self.conf.train_label_dir,
+            img_size=(self.conf.tile_size, self.conf.tile_size),
+        )
+
+        train_dataloader = DataLoader(
+            chm_train_dataset,
+            batch_size=self.conf.batch_size, shuffle=False
+        )

# Get mean and std array
-        #mean, std = self.get_mean_std_dataset(
-        #    train_dataloader, metadata_output_filename)
-        #logging.info(f'Mean: {mean.numpy()}, Std: {std.numpy()}')
+        mean, std = self.get_mean_std_dataset(
+            train_dataloader, metadata_output_filename)
+        logging.info(f'Mean: {mean.numpy()}, Std: {std.numpy()}')

        # Re-enable standardization for next pipeline step
-        #self.conf.standardization = current_standardization
+        self.conf.standardization = current_standardization

logging.info('Done with preprocessing stage')
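
preprocess() now writes the statistics as a 2 x n_bands CSV (row 0 = per-band means, row 1 = per-band stds). A hypothetical helper, not in this commit, for loading them back into lists like the MEANS/STDS constants in chm_datamodule.py:

```python
# Hypothetical helper -- not in this commit -- to reload the statistics
# written by preprocess() (row 0: per-band means, row 1: per-band stds).
import pandas as pd

def load_band_stats(csv_path: str):
    stats = pd.read_csv(csv_path, header=None).to_numpy()
    return stats[0].tolist(), stats[1].tolist()

# usage sketch:
# MEANS, STDS = load_band_stats('metadata/mean-std-values.csv')
```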

@@ -120,31 +139,45 @@ def train(self):

self._set_train_test_dirs()

-        # Set main data loader
-        chm_train_dataset = CHMDataset(
-            os.path.join(self.conf.train_tiles_dir, 'images'),
-            os.path.join(self.conf.train_tiles_dir, 'labels'),
-            img_size=(self.conf.tile_size, self.conf.tile_size),
+        batch_size = 10
+        num_workers = 2
+        max_epochs = 50
+        fast_dev_run = False
+
+        datamodule = CHMDataModule(
+            train_data_dir=self.conf.train_data_dir,
+            train_label_dir=self.conf.train_label_dir,
+            test_data_dir=self.conf.test_data_dir,
+            test_label_dir=self.conf.test_label_dir,
+            batch_size=16,
+            num_workers=8,
        )

-        chm_test_dataset = CHMDataset(
-            os.path.join(self.conf.test_tiles_dir, 'images'),
-            os.path.join(self.conf.test_tiles_dir, 'labels'),
-            img_size=(self.conf.tile_size, self.conf.tile_size),
-        )
+        # Set main data loader
+        #chm_train_dataset = CHMDataset(
+        #    os.path.join(self.conf.train_tiles_dir, 'images'),
+        #    os.path.join(self.conf.train_tiles_dir, 'labels'),
+        #    img_size=(self.conf.tile_size, self.conf.tile_size),
+        #)
+
+        #chm_test_dataset = CHMDataset(
+        #    os.path.join(self.conf.test_tiles_dir, 'images'),
+        #    os.path.join(self.conf.test_tiles_dir, 'labels'),
+        #    img_size=(self.conf.tile_size, self.conf.tile_size),
+        #)

        # start dataloader
-        train_dataloader = DataLoader(
-            chm_train_dataset,
-            batch_size=self.conf.batch_size, shuffle=True
-        )
+        #train_dataloader = DataLoader(
+        #    chm_train_dataset,
+        #    batch_size=self.conf.batch_size, shuffle=True
+        #)

-        test_dataloader = DataLoader(
-            chm_test_dataset,
-            batch_size=self.conf.batch_size, shuffle=False
-        )
+        #test_dataloader = DataLoader(
+        #    chm_test_dataset,
+        #    batch_size=self.conf.batch_size, shuffle=False
+        #)

-        print(timm.list_models("prithvi*"))
+        #print(timm.list_models("prithvi*"))
# Build model
# model = build_finetune_model(config, logger)

@@ -154,4 +187,34 @@
# is_pretrain=False,
# logger=logger)

        return
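
Note that the local batch_size/num_workers/max_epochs/fast_dev_run values are currently unused: the datamodule hardcodes batch_size=16 and num_workers=8, bypassing both them and the new `batch_size: 512` config entry, and train() returns before building a model. The "temporary" imports (Trainer, EarlyStopping, ModelCheckpoint, TensorBoardLogger, PixelwiseRegressionTask) suggest wiring along these lines; a hedged sketch continuing train(), with purely illustrative task arguments:

```python
# Hedged sketch of the Trainer wiring the temporary imports point to;
# not committed code, and the task arguments are illustrative assumptions.
task = PixelwiseRegressionTask(
    model="unet",
    backbone="resnet18",
    in_channels=5,    # assumption: five bands, matching MEANS/STDS
    loss="mse",
)

trainer = Trainer(
    max_epochs=max_epochs,
    fast_dev_run=fast_dev_run,
    logger=TensorBoardLogger("tb_logs", name="chm"),
    callbacks=[
        EarlyStopping(monitor="val_loss", patience=10),
        ModelCheckpoint(monitor="val_loss", save_top_k=1),
    ],
)
trainer.fit(task, datamodule=datamodule)
```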

# -------------------------------------------------------------------------
# get_mean_std_dataset
# -------------------------------------------------------------------------
    def get_mean_std_dataset(self, dataloader, output_filename: str):
        """Compute per-band mean and std across a dataloader."""
        # Guess channel position from the first batch: channels-first
        # (B, C, H, W) when C < H, otherwise assume channels-last
        test_sample_shape = next(iter(dataloader))['image'].shape
        if test_sample_shape[1] < test_sample_shape[2]:
            mean_dims = (0, 2, 3)
        else:
            mean_dims = (0, 1, 2)

        # accumulate per-batch channel statistics across the full dataset
        # (initialized once, outside the loop)
        channels_sum, channels_squared_sum, num_batches = 0, 0, 0
        for data_dict in tqdm(dataloader):
            data = data_dict['image'].cuda()
            channels_sum += torch.mean(data, dim=mean_dims)
            channels_squared_sum += torch.mean(data**2, dim=mean_dims)
            num_batches += 1

        mean = channels_sum / num_batches
        std = (channels_squared_sum / num_batches - mean ** 2) ** 0.5

        mean, std = mean.cpu(), std.cpu()

        if output_filename is not None:
            mean_std = np.stack(
                [mean.numpy(), std.numpy()], axis=0)
            pd.DataFrame(mean_std).to_csv(
                output_filename, header=False, index=False)
        return mean, std
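
get_mean_std_dataset relies on the identity Var(X) = E[X²] − E[X]², with the caveat that averaging per-batch means is exact only when all batches are equal-sized (a smaller final batch introduces a slight bias). A quick numpy check of the identity (illustrative, equal batches):

```python
# Quick numpy check of the E[x^2] - E[x]^2 identity used above.
import numpy as np

x = np.random.rand(8, 5, 16, 16)          # 8 samples, 5 bands
batches = x.reshape(2, 4, 5, 16, 16)      # two equal-sized batches
mean = np.stack([b.mean(axis=(0, 2, 3)) for b in batches]).mean(axis=0)
sq = np.stack([(b ** 2).mean(axis=(0, 2, 3)) for b in batches]).mean(axis=0)
std = np.sqrt(sq - mean ** 2)

assert np.allclose(mean, x.mean(axis=(0, 2, 3)))
assert np.allclose(std, x.std(axis=(0, 2, 3)))
```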
2 changes: 2 additions & 0 deletions projects/chm/configs/above_shrubs_chm.yaml
@@ -21,6 +21,8 @@ test_tiles_dir: '${work_dir}/test_data'
#dtm_path: /explore/nobackup/projects/dem/AK_IFSAR/alaska_ifsar_dtm_20221222.tif
#dsm_path: /explore/nobackup/projects/dem/AK_IFSAR/alaska_ifsar_dsm_20221222.tif

batch_size: 512

#------------------------------------- General -------------------------------------#

#model_dir: '/explore/nobackup/people/${oc.env:USER}/projects/SRLite/regression-models'
4 changes: 2 additions & 2 deletions requirements/Dockerfile
@@ -1,5 +1,5 @@
# Arguments to pass to the image
-ARG VERSION_DATE=24.08
+ARG VERSION_DATE=24.09
ARG FROM_IMAGE=nvcr.io/nvidia/pytorch

# Import RAPIDS container as the BASE Image (cuda base image)
@@ -18,7 +18,7 @@ RUN apt-get update && \
RUN pip --no-cache-dir install --ignore-installed omegaconf \
#terratorch \
pytorch-lightning \
-    Lightning \
+    #Lightning \
transformers \
datasets \
webdataset \
Expand Down
