Skip to content

Commit

Permalink
Adding new torchgeo dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
jordancaraballo committed Sep 30, 2024
1 parent 436a661 commit 43670d1
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 19 deletions.
61 changes: 43 additions & 18 deletions above_shrubs/datasets/chm_dataset.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import os
import sys
import numpy as np
from torch.utils.data import Dataset
import rioxarray as rxr
from typing import Any
from pathlib import Path
from torchgeo.datasets import NonGeoDataset


class CHMDataset(Dataset):
class CHMDataset(NonGeoDataset):
"""
CHM Regression dataset.
CHM Regression dataset from NonGeoDataset.
"""

def __init__(
Expand All @@ -14,7 +18,8 @@ def __init__(
mask_paths: list,
img_size: tuple = (256, 256),
transform=None,
):
) -> None:
super().__init__()

# image size
self.image_size = img_size
Expand Down Expand Up @@ -42,22 +47,42 @@ def __init__(
self.image_list.extend(self.get_filenames(image_path))
self.mask_list.extend(self.get_filenames(mask_path))

def __len__(self):
return len(self.image_list)

def __getitem__(self, idx, transpose=True):

# load image
img = np.load(self.image_list[idx])
# rgb indices for some plots
self.rgb_indices = [0, 1, 2]

# load mask
mask = np.load(self.mask_list[idx])

# perform transformations
if self.transform is not None:
img = self.transform(img)
def __len__(self) -> int:
return len(self.image_list)

return img, mask
# def __getitem__(self, idx, transpose=True):
#
# # load image
# img = np.load(self.image_list[idx])
#
# # load mask
# mask = np.load(self.mask_list[idx])
# # perform transformations
# if self.transform is not None:
# img = self.transform(img)
#
# return img, mask

def __getitem__(self, index: int) -> dict[str, Any]:
output = {
"image": self._load_file(
self.image_list[index]).astype(np.float32),
"mask": self._load_file(
self.mask_list[index]).astype(np.int64),
}
return output

def _load_file(self, path: Path):
if Path(path).suffix == '.npy':
data = np.load(path)
elif Path(path).suffix == '.tif':
data = rxr.open_rasterio(path)
else:
sys.exit('Non-recognized dataset format. Expects npy or tif.')
return data.to_numpy()

def get_filenames(self, path):
"""
Expand Down
2 changes: 1 addition & 1 deletion above_shrubs/pipelines/chm_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import timm
import torch
import terratorch
import terratorch
import numpy as np

from itertools import repeat
Expand Down

0 comments on commit 43670d1

Please sign in to comment.