From 43670d1f63ccd39dad597287984395544a5a86a2 Mon Sep 17 00:00:00 2001
From: jordancaraballo
Date: Mon, 30 Sep 2024 12:09:37 -0400
Subject: [PATCH] Adding new torchgeo dataset

---
 above_shrubs/datasets/chm_dataset.py   | 76 ++++++++++++++++++++------
 above_shrubs/pipelines/chm_pipeline.py |  2 +-
 2 files changed, 59 insertions(+), 19 deletions(-)

diff --git a/above_shrubs/datasets/chm_dataset.py b/above_shrubs/datasets/chm_dataset.py
index 217a08e..7344be4 100644
--- a/above_shrubs/datasets/chm_dataset.py
+++ b/above_shrubs/datasets/chm_dataset.py
@@ -1,11 +1,15 @@
 import os
+import sys
 import numpy as np
-from torch.utils.data import Dataset
+import rioxarray as rxr
+from typing import Any
+from pathlib import Path
+from torchgeo.datasets import NonGeoDataset
 
 
-class CHMDataset(Dataset):
+class CHMDataset(NonGeoDataset):
     """
-    CHM Regression dataset.
+    CHM Regression dataset from NonGeoDataset.
     """
 
     def __init__(
@@ -14,7 +18,8 @@ def __init__(
         mask_paths: list,
         img_size: tuple = (256, 256),
         transform=None,
-    ):
+    ) -> None:
+        super().__init__()
 
         # image size
         self.image_size = img_size
@@ -42,22 +47,57 @@ def __init__(
             self.image_list.extend(self.get_filenames(image_path))
             self.mask_list.extend(self.get_filenames(mask_path))
 
-    def __len__(self):
-        return len(self.image_list)
-
-    def __getitem__(self, idx, transpose=True):
-
-        # load image
-        img = np.load(self.image_list[idx])
+        # rgb indices for some plots
+        self.rgb_indices = [0, 1, 2]
 
-        # load mask
-        mask = np.load(self.mask_list[idx])
-
-        # perform transformations
-        if self.transform is not None:
-            img = self.transform(img)
+    def __len__(self) -> int:
+        return len(self.image_list)
 
-        return img, mask
+    # def __getitem__(self, idx, transpose=True):
+    #
+    #     # load image
+    #     img = np.load(self.image_list[idx])
+    #
+    #     # load mask
+    #     mask = np.load(self.mask_list[idx])
+    #     # perform transformations
+    #     if self.transform is not None:
+    #         img = self.transform(img)
+    #
+    #     return img, mask
+
+    def __getitem__(self, index: int) -> dict[str, Any]:
+        """
+        Return the sample at ``index`` as an image/mask dictionary.
+
+        The image is cast to float32 and the mask to int64.
+        """
+        output = {
+            "image": self._load_file(
+                self.image_list[index]).astype(np.float32),
+            # NOTE(review): int64 drops fractional values; confirm this
+            # is intended for regression targets.
+            "mask": self._load_file(
+                self.mask_list[index]).astype(np.int64),
+        }
+        return output
+
+    def _load_file(self, path: Path) -> np.ndarray:
+        """
+        Load an image or mask file into a NumPy array.
+
+        Files ending in .npy are read with np.load and files ending in
+        .tif with rioxarray; any other suffix ends the process through
+        sys.exit.
+        """
+        suffix = Path(path).suffix
+        if suffix == '.npy':
+            # np.load already returns an ndarray, which has no to_numpy()
+            return np.load(path)
+        if suffix == '.tif':
+            # convert the xarray object from rioxarray to an ndarray
+            return rxr.open_rasterio(path).to_numpy()
+        sys.exit('Non-recognized dataset format. Expects npy or tif.')
 
     def get_filenames(self, path):
         """
diff --git a/above_shrubs/pipelines/chm_pipeline.py b/above_shrubs/pipelines/chm_pipeline.py
index 9729cd4..71e2d4d 100644
--- a/above_shrubs/pipelines/chm_pipeline.py
+++ b/above_shrubs/pipelines/chm_pipeline.py
@@ -3,7 +3,7 @@
 
 import timm
 import torch
-import terratorch 
+import terratorch
 import numpy as np
 
 from itertools import repeat