Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update sfno to latest version #204

Closed
wants to merge 22 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions modulus/datapipes/climate/sfno/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@ def init_distributed_io(params): # pragma: no cover
num_io_ranks = math.prod(params.io_grid)
if not ((num_io_ranks == 1) or (num_io_ranks == comm.get_size("spatial"))):
raise AssertionError
if not (params.io_grid[1] == comm.get_size("h")) or (params.io_grid[1] == 1):
if not ((params.io_grid[1] == comm.get_size("h")) or (params.io_grid[1] == 1)):
raise AssertionError
if not (params.io_grid[2] == comm.get_size("w")) or (params.io_grid[2] == 1):
if not ((params.io_grid[2] == comm.get_size("w")) or (params.io_grid[2] == 1)):
raise AssertionError

params.io_rank = [0, 0, 0]
Expand All @@ -79,9 +79,10 @@ def get_dataloader(
return zarr.get_data_loader(params, files_pattern, train)

elif params.get("multifiles", False):
from utils.dataloaders.data_loader_multifiles import (
from modulus.datapipes.climate.sfno.dataloaders.data_loader_multifiles import (
MultifilesDataset as MultifilesDataset2D,
)
from torch.utils.data.distributed import DistributedSampler

# multifiles dataset
dataset = MultifilesDataset2D(params, files_pattern, train)
Expand All @@ -101,12 +102,16 @@ def get_dataloader(
dataset,
batch_size=int(params.batch_size),
num_workers=params.num_data_workers,
shuffle=False, # (sampler is None),
shuffle=False,
sampler=sampler if train else None,
drop_last=True,
pin_memory=torch.cuda.is_available(),
)

# for compatibility with the DALI dataloader
dataloader.get_output_normalization = dataset.get_output_normalization
dataloader.get_input_normalization = dataset.get_input_normalization

elif params.enable_synthetic_data:
from modulus.datapipes.climate.sfno.dataloaders.data_loader_dummy import (
DummyLoader,
Expand All @@ -127,6 +132,8 @@ def get_dataloader(
img_local_shape_y=dataloader.img_local_shape_y,
img_local_offset_x=dataloader.img_local_offset_x,
img_local_offset_y=dataloader.img_local_offset_y,
img_local_pad_x=dataloader.img_local_pad_x,
img_local_pad_y=dataloader.img_local_pad_y,
)

# not needed for the no multifiles case
Expand Down Expand Up @@ -154,6 +161,8 @@ def get_dataloader(
img_local_shape_y=dataloader.img_local_shape_y,
img_local_offset_x=dataloader.img_local_offset_x,
img_local_offset_y=dataloader.img_local_offset_y,
img_local_pad_x=dataloader.img_local_pad_x,
img_local_pad_y=dataloader.img_local_pad_y,
)

if params.enable_benchy and train:
Expand Down
151 changes: 53 additions & 98 deletions modulus/datapipes/climate/sfno/dataloaders/dali_es_helper_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@
import torch
import zarr

# pytz provides the UTC tzinfo used below; on Python 3.9+ it can be replaced with zoneinfo from the standard library.
import pytz
mnabian marked this conversation as resolved.
Show resolved Hide resolved

from modulus.utils.zenith_angle import cos_zenith_angle


Expand All @@ -50,6 +53,7 @@ def __init__(
train,
batch_size,
dt,
dhours,
n_history,
n_future,
in_channels,
Expand All @@ -64,10 +68,10 @@ def __init__(
truncate_old=True,
enable_logging=True,
zenith_angle=True,
lat_lon=None,
dataset_path="fields",
seed=333,
is_parallel=True,
host_prefetch_buffers=False,
timestep_hours=6,
): # pragma: no cover
self.batch_size = batch_size
self.location = location
Expand All @@ -76,6 +80,7 @@ def __init__(
self.truncate_old = truncate_old
self.train = train
self.dt = dt
self.dhours = dhours
self.n_history = n_history
self.n_future = n_future
self.in_channels = in_channels
Expand All @@ -89,9 +94,8 @@ def __init__(
self.device_id = device_id
self.shard_id = shard_id
self.is_parallel = is_parallel
self.host_prefetch_buffers = host_prefetch_buffers
self.zenith_angle = zenith_angle
self.timestep_hours = timestep_hours
self.dataset_path = dataset_path

# set the read slices
# we do not support channel parallelism yet
Expand All @@ -110,8 +114,11 @@ def __init__(

# we need some additional static fields in this case
if self.zenith_angle:
longitude = np.linspace(0, 360, self.img_shape[1], endpoint=False)
latitude = np.linspace(90, -90, self.img_shape[0])
if lat_lon is not None:
latitude, longitude = lat_lon
else:
longitude = np.linspace(0, 360, self.img_shape[1], endpoint=False)
latitude = np.linspace(90, -90, self.img_shape[0])
self.lon_grid, self.lat_grid = np.meshgrid(longitude, latitude)
self.lat_grid_local = self.lat_grid[
self.read_anchor[0] : self.read_anchor[0] + self.read_shape[0],
Expand All @@ -122,30 +129,25 @@ def __init__(
self.read_anchor[1] : self.read_anchor[1] + self.read_shape[1],
]

# these things we want to read from a descriptor file ultimately:
self.dt_samples = 6

# HDF5 routines
def _get_stats_h5(self, enable_logging): # pragma: no cover
with h5py.File(self.files_paths[0], "r") as _f:
if enable_logging:
logging.info("Getting file stats from {}".format(self.files_paths[0]))
# original image shape (before padding)
self.img_shape = _f["fields"].shape[
2:4
] # - 1 #just get rid of one of the pixels
self.total_channels = _f["fields"].shape[1]
self.img_shape = _f[self.dataset_path].shape[2:4]
self.total_channels = _f[self.dataset_path].shape[1]

# get all sample counts
self.n_samples_year = []
for filename in self.files_paths:
with h5py.File(filename, "r") as _f:
self.n_samples_year.append(_f["fields"].shape[0])
self.n_samples_year.append(_f[self.dataset_path].shape[0])
return

def _get_year_h5(self, year_idx): # pragma: no cover
self.files[year_idx] = h5py.File(self.files_paths[year_idx], "r")
self.dsets[year_idx] = self.files[year_idx]["fields"]
self.dsets[year_idx] = self.files[year_idx][self.dataset_path]
return

def _get_data_h5(
Expand All @@ -155,7 +157,6 @@ def _get_data_h5(
for slice_in in self.in_channels_slices:
start = off
end = start + (slice_in.stop - slice_in.start)
# inp[:, start:end, ...] = dset[(local_idx-self.dt*self.n_history):(local_idx+1):self.dt, slice_in, start_x:end_x, start_y:end_y]
dset.read_direct(
inp,
np.s_[
Expand All @@ -172,7 +173,6 @@ def _get_data_h5(
for slice_out in self.out_channels_slices:
start = off
end = start + (slice_out.stop - slice_out.start)
# tar[:, start:end, ...] = dset[(local_idx + self.dt):(local_idx + self.dt * (self.n_future + 1) + 1):self.dt, slice_out, start_x:end_x, start_y:end_y]
dset.read_direct(
tar,
np.s_[
Expand All @@ -195,21 +195,19 @@ def _get_stats_zarr(self, enable_logging): # pragma: no cover
if enable_logging:
logging.info("Getting file stats from {}".format(self.files_paths[0]))
# original image shape (before padding)
self.img_shape = _f["/fields"].shape[
2:4
] # - 1 #just get rid of one of the pixels
self.total_channels = _f["/fields"].shape[1]
self.img_shape = _f[f"/{self.dataset_path}"].shape[2:4]
self.total_channels = _f[f"/{self.dataset_path}"].shape[1]

self.n_samples_year = []
for filename in self.files_paths:
with zarr.convenience.open(filename, "r") as _f:
self.n_samples_year.append(_f["/fields"].shape[0])
self.n_samples_year.append(_f[f"/{self.dataset_path}"].shape[0])

return

def _get_year_zarr(self, year_idx): # pragma: no cover
self.files[year_idx] = zarr.convenience.open(self.files_paths[year_idx], "r")
self.dsets[year_idx] = self.files[year_idx]["/fields"]
self.dsets[year_idx] = self.files[year_idx][f"/{self.dataset_path}"]
return

def _get_data_zarr(
Expand Down Expand Up @@ -255,10 +253,13 @@ def _get_files_stats(self, enable_logging): # pragma: no cover
)
self.file_format = "h5"

# # TODO: probably requires fix to re-enable zarr
# if not self.files_paths:
# self.files_paths = glob.glob(os.path.join(self.location, "*.zarr"))
# self.file_format = "zarr"
# check for zarr files if no hdf5 files are found
if not self.files_paths:
for location in self.location:
self.files_paths = self.files_paths + glob.glob(
os.path.join(location, "????.zarr")
)
self.file_format = "zarr"

if not self.files_paths:
raise IOError(
Expand Down Expand Up @@ -368,17 +369,17 @@ def _get_files_stats(self, enable_logging): # pragma: no cover
self.batch_size,
)
)
logging.info("Delta t: {} hours".format(self.timestep_hours * self.dt))
logging.info("Delta t: {} hours".format(self.dhours * self.dt))
logging.info(
"Including {} hours of past history in training at a frequency of {} hours".format(
self.timestep_hours * self.dt * self.n_history,
self.timestep_hours * self.dt,
self.dhours * self.dt * (self.n_history + 1),
self.dhours * self.dt,
)
)
logging.info(
"Including {} hours of future targets in training at a frequency of {} hours".format(
self.timestep_hours * self.dt * self.n_future,
self.timestep_hours * self.dt,
self.dhours * self.dt * (self.n_future + 1),
self.dhours * self.dt,
)
)

Expand All @@ -390,30 +391,7 @@ def _get_files_stats(self, enable_logging): # pragma: no cover
if not self.is_parallel:
self._init_buffers()

def _init_double_buff_host(self, n_tsteps, n_channels): # pragma: no cover
buffs = [
np.zeros(
(
n_tsteps,
n_channels,
self.read_shape[0],
self.read_shape[1],
),
dtype=np.float32,
),
np.zeros(
(
n_tsteps,
n_channels,
self.read_shape[0],
self.read_shape[1],
),
dtype=np.float32,
),
]
return buffs

def _init_double_buff_gpu(self, n_tsteps, n_channels): # pragma: no cover
def _init_double_buff(self, n_tsteps, n_channels): # pragma: no cover
buffs = [
cpx.zeros_pinned(
(
Expand Down Expand Up @@ -441,27 +419,11 @@ def _init_buffers(self): # pragma: no cover
self.device = cp.cuda.Device(self.device_id)
self.device.use()
self.current_buffer = 0
if self.host_prefetch_buffers:
self.inp_buffs = self._init_double_buff_host(
self.n_history + 1, self.n_in_channels
)
self.tar_buffs = self._init_double_buff_host(
self.n_future + 1, self.n_out_channels
)
else:
self.inp_buffs = self._init_double_buff_gpu(
self.n_history + 1, self.n_in_channels
)
self.tar_buffs = self._init_double_buff_gpu(
self.n_future + 1, self.n_out_channels
)
self.inp_buffs = self._init_double_buff(self.n_history + 1, self.n_in_channels)
self.tar_buffs = self._init_double_buff(self.n_future + 1, self.n_out_channels)
if self.zenith_angle:
if self.host_prefetch_buffers:
self.zen_inp_buffs = self._init_double_buff_host(self.n_history + 1, 1)
self.zen_tar_buffs = self._init_double_buff_host(self.n_future + 1, 1)
else:
self.zen_inp_buffs = self._init_double_buff_gpu(self.n_history + 1, 1)
self.zen_tar_buffs = self._init_double_buff_gpu(self.n_future + 1, 1)
self.zen_inp_buffs = self._init_double_buff(self.n_history + 1, 1)
self.zen_tar_buffs = self._init_double_buff(self.n_future + 1, 1)
return

def _compute_zenith_angle(
Expand All @@ -471,51 +433,47 @@ def _compute_zenith_angle(

# compute hours into the year
year = self.years[year_idx]
jan_01_epoch = datetime.datetime(year, 1, 1, 0, 0, 0)
jan_01_epoch = datetime.datetime(year, 1, 1, 0, 0, 0, tzinfo=pytz.utc)

# zenith angle for input
inp_times = np.asarray(
[
jan_01_epoch + datetime.timedelta(hours=idx * self.timestep_hours)
jan_01_epoch + datetime.timedelta(hours=idx * self.dhours)
for idx in range(
local_idx - self.dt * self.n_history, local_idx + 1, self.dt
)
]
)
cos_zenith_inp = np.asarray(
cos_zenith_inp = np.expand_dims(
[
np.expand_dims(
cos_zenith_angle(
inp_time, self.lon_grid_local, self.lat_grid_local
).astype(np.float32),
axis=0,
)
cos_zenith_angle(
inp_time, self.lon_grid_local, self.lat_grid_local
).astype(np.float32)
for inp_time in inp_times
]
],
axis=0,
)
zen_inp[...] = cos_zenith_inp[...]

# zenith angle for target
tar_times = np.asarray(
[
jan_01_epoch + datetime.timedelta(hours=idx * self.timestep_hours)
jan_01_epoch + datetime.timedelta(hours=idx * self.dhours)
for idx in range(
local_idx + self.dt,
local_idx + self.dt * (self.n_future + 1) + 1,
self.dt,
)
]
)
cos_zenith_tar = np.asarray(
cos_zenith_tar = np.expand_dims(
[
np.expand_dims(
cos_zenith_angle(
tar_time, self.lon_grid_local, self.lat_grid_local
).astype(np.float32),
axis=0,
)
cos_zenith_angle(
tar_time, self.lon_grid_local, self.lat_grid_local
).astype(np.float32)
for tar_time in tar_times
]
],
axis=1,
)
zen_tar[...] = cos_zenith_tar[...]

Expand Down Expand Up @@ -554,11 +512,8 @@ def __call__(self, sample_info): # pragma: no cover
cycle_sample_idx = global_sample_idx % self.num_samples_per_cycle_shard
cycle_epoch_idx = global_sample_idx // self.num_samples_per_cycle_shard

# print(f'{"TRAIN" if self.train else "VALIDATION"} ITER INFO:', sample_info.idx_in_epoch, self.num_samples_per_epoch_shard)

# check if epoch is done
if sample_info.iteration >= self.num_steps_per_epoch:
# print(f'{"TRAIN" if self.train else "VALIDATION"} END OF EPOCH TRIGGERED FOR', sample_info.idx_in_epoch, self.num_samples_per_epoch_shard, sample_info.iteration, self.num_steps_per_epoch)
raise StopIteration

torch.cuda.nvtx.range_push("GeneralES:__call__")
Expand Down
Loading