Skip to content

Commit

Permalink
Fixing container
Browse files Browse the repository at this point in the history
  • Loading branch information
jordancaraballo committed Sep 30, 2024
1 parent 43670d1 commit b6d8ad0
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 19 deletions.
37 changes: 19 additions & 18 deletions above_shrubs/pipelines/chm_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,33 +78,34 @@ def preprocess(self):
self._set_train_test_dirs()

# Calculate mean and std values for training
data_filenames = self.get_dataset_filenames(self.train_images_dir)
data_filenames = self.get_dataset_filenames(
self.train_data_dir, , ext='*.tif')
logging.info(f'Mean and std values from {len(data_filenames)} files.')

# Temporarily disable standardization and augmentation
current_standardization = self.conf.standardization
self.conf.standardization = None
metadata_output_filename = os.path.join(
self.metadata_dir, 'mean-std-values.csv')
#current_standardization = self.conf.standardization
#self.conf.standardization = None
#metadata_output_filename = os.path.join(
# self.metadata_dir, 'mean-std-values.csv')

# Set main data loader
chm_train_dataset = CHMDataset(
os.path.join(self.conf.train_tiles_dir, 'images'),
os.path.join(self.conf.train_tiles_dir, 'labels'),
img_size=(self.conf.tile_size, self.conf.tile_size),
)
train_dataloader = DataLoader(
chm_train_dataset,
batch_size=self.conf.batch_size, shuffle=False
)
#chm_train_dataset = CHMDataset(
# os.path.join(self.conf.train_tiles_dir, 'images'),
# os.path.join(self.conf.train_tiles_dir, 'labels'),
# img_size=(self.conf.tile_size, self.conf.tile_size),
#)
#train_dataloader = DataLoader(
# chm_train_dataset,
# batch_size=self.conf.batch_size, shuffle=False
#)

# Get mean and std array
mean, std = self.get_mean_std_dataset(
train_dataloader, metadata_output_filename)
logging.info(f'Mean: {mean.numpy()}, Std: {std.numpy()}')
#mean, std = self.get_mean_std_dataset(
# train_dataloader, metadata_output_filename)
#logging.info(f'Mean: {mean.numpy()}, Std: {std.numpy()}')

# Re-enable standardization for next pipeline step
self.conf.standardization = current_standardization
#self.conf.standardization = current_standardization

logging.info('Done with preprocessing stage')

Expand Down
2 changes: 1 addition & 1 deletion requirements/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ RUN apt-get update && \

# Pip
RUN pip --no-cache-dir install --ignore-installed omegaconf \
terratorch \
#terratorch \
pytorch-lightning \
Lightning \
transformers \
Expand Down

0 comments on commit b6d8ad0

Please sign in to comment.