Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add scaling for HPX dataloaders, code cleanup #721

Merged
merged 13 commits into from
Dec 5, 2024
45 changes: 4 additions & 41 deletions modulus/datapipes/healpix/coupledtimeseries_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,43 +157,10 @@ def _get_scaling_da(self):
scaling_df.loc["zeros"] = {"mean": 0.0, "std": 1.0}
scaling_da = scaling_df.to_xarray().astype("float32")

# only thing we do different here is get the scaling for the coupled values
for c in self.couplings:
c.set_scaling(scaling_da)
# REMARK: we remove the xarray overhead from these
try:
self.input_scaling = scaling_da.sel(index=self.input_variables).rename(
{"index": "channel_in"}
)
self.input_scaling = {
"mean": np.expand_dims(
self.input_scaling["mean"].to_numpy(), (0, 2, 3, 4)
),
"std": np.expand_dims(
self.input_scaling["std"].to_numpy(), (0, 2, 3, 4)
),
}
except (ValueError, KeyError):
raise KeyError(
f"one or more of the input data variables f{list(self.ds.channel_in)} not found in the "
f"scaling config dict data.scaling ({list(self.scaling.keys())})"
)
try:
self.target_scaling = scaling_da.sel(index=self.input_variables).rename(
{"index": "channel_out"}
)
self.target_scaling = {
"mean": np.expand_dims(
self.target_scaling["mean"].to_numpy(), (0, 2, 3, 4)
),
"std": np.expand_dims(
self.target_scaling["std"].to_numpy(), (0, 2, 3, 4)
),
}
except (ValueError, KeyError):
raise KeyError(
f"one or more of the target data variables f{list(self.ds.channel_out)} not found in the "
f"scaling config dict data.scaling ({list(self.scaling.keys())})"
)
super()._get_scaling_da()

def __getitem__(self, item):
# start range
Expand Down Expand Up @@ -251,7 +218,6 @@ def __getitem__(self, item):
torch.cuda.nvtx.range_pop()

torch.cuda.nvtx.range_push("CoupledTimeSeriesDataset:__getitem__:process_batch")
compute_time = time.time()
# Insolation
if self.add_insolation:
sol = insolation(
Expand Down Expand Up @@ -305,11 +271,9 @@ def __getitem__(self, item):
np.transpose(x, axes=(0, 3, 1, 2, 4, 5)) for x in inputs_result
]

if "constants" in self.ds.data_vars:
if self.constants is not None:
# Add the constants as [F, C, H, W]
inputs_result.append(np.swapaxes(self.ds.constants.values, 0, 1))
# inputs_result.append(self.ds.constants.values)
logger.log(5, "computed batch in %0.2f s", time.time() - compute_time)
inputs_result.append(self.constants)

# append integrated couplings
inputs_result.append(integrated_couplings)
Expand All @@ -328,7 +292,6 @@ def __getitem__(self, item):
return inputs_result, targets

def next_integration(self, model_outputs, constants):

inputs_result = []

# grab last few model outputs for re-initialization
Expand Down
10 changes: 4 additions & 6 deletions modulus/datapipes/healpix/data_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def open_time_series_dataset_classic_on_the_fly(
file_name = _get_file_name(directory, prefix, variable, suffix)
logger.debug("open nc dataset %s", file_name)

ds = xr.open_dataset(file_name, chunks={"sample": batch_size}, autoclose=True)
ds = xr.open_dataset(file_name, autoclose=True)

if "LL" in prefix:
ds = ds.rename({"lat": "height", "lon": "width"})
Expand Down Expand Up @@ -212,7 +212,7 @@ def open_time_series_dataset_classic_prebuilt(
if not ds_path.exists():
raise FileNotFoundError(f"Dataset doesn't appear to exist at {ds_path}")

result = xr.open_zarr(ds_path, chunks={"time": batch_size})
result = xr.open_zarr(ds_path)
daviddpruitt marked this conversation as resolved.
Show resolved Hide resolved
return result


Expand Down Expand Up @@ -286,11 +286,9 @@ def create_time_series_dataset_classic(
file_name = _get_file_name(src_directory, prefix, variable, suffix)
logger.debug("open nc dataset %s", file_name)
if "sample" in list(xr.open_dataset(file_name).sizes.keys()):
ds = xr.open_dataset(file_name, chunks={"sample": batch_size}).rename(
{"sample": "time"}
)
ds = xr.open_dataset(file_name).rename({"sample": "time"})
else:
ds = xr.open_dataset(file_name, chunks={"time": batch_size})
ds = xr.open_dataset(file_name)
if "varlev" in ds.dims:
ds = ds.isel(varlev=0)

Expand Down
71 changes: 57 additions & 14 deletions modulus/datapipes/healpix/timeseries_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,24 +189,35 @@ def __init__(

self.input_scaling = None
self.target_scaling = None
self.constant_scaling = None
self.constants = None
if self.scaling:
self._get_scaling_da()

# setup constants
if "constants" in self.ds.data_vars:
# extract from ds:
const = self.ds.constants.values

if self.scaling:
daviddpruitt marked this conversation as resolved.
Show resolved Hide resolved
const = (const - self.constant_scaling["mean"]) / self.constant_scaling[
"std"
]

# transpose to match new format:
# [C, F, H, W] -> [F, C, H, W]
self.constants = np.transpose(const, axes=(1, 0, 2, 3))

self.get_constants()
daviddpruitt marked this conversation as resolved.
Show resolved Hide resolved

def get_constants(self):
"""Returns the constants used in this dataset

Returns
-------
np.ndarray: The list of constants, None if there are no constants
"""
# extract from ds:
const = self.ds.constants.values

# transpose to match new format:
# [C, F, H, W] -> [F, C, H, W]
const = np.transpose(const, axes=(1, 0, 2, 3))

return const
return self.constants

@staticmethod
def _convert_time_step(dt): # pylint: disable=invalid-name
Expand Down Expand Up @@ -244,9 +255,13 @@ def _get_scaling_da(self):
),
}
except (ValueError, KeyError):
missing = [
m
for m in self.ds.channel_in.values
if m not in list(self.scaling.keys())
]
raise KeyError(
f"one or more of the input data variables f{list(self.ds.channel_in)} not found in the "
f"scaling config dict data.scaling ({list(self.scaling.keys())})"
f"Input channels {missing} not found in the scaling config dict data.scaling ({list(self.scaling.keys())})"
)
try:
self.target_scaling = scaling_da.sel(
Expand All @@ -261,9 +276,37 @@ def _get_scaling_da(self):
),
}
except (ValueError, KeyError):
missing = [
m
for m in self.ds.channel_out.values
if m not in list(self.scaling.keys())
]
raise KeyError(
f"one or more of the target data variables f{list(self.ds.channel_out)} not found in the "
f"scaling config dict data.scaling ({list(self.scaling.keys())})"
f"Target channels {missing} not found in the scaling config dict data.scaling ({list(self.scaling.keys())})"
)

try:
# not all datasets will have constants
if "constants" in self.ds.data_vars:
self.constant_scaling = scaling_da.sel(
index=self.ds.channel_c.values
).rename({"index": "channel_out"})
self.constant_scaling = {
"mean": np.expand_dims(
self.constant_scaling["mean"].to_numpy(), (1, 2, 3)
),
"std": np.expand_dims(
self.constant_scaling["std"].to_numpy(), (1, 2, 3)
),
}
except (ValueError, KeyError):
missing = [
m
for m in self.ds.channel_c.values
if m not in list(self.scaling.keys())
]
raise KeyError(
f"Constant channels {missing} not found in the scaling config dict data.scaling ({list(self.scaling.keys())})"
)

def __len__(self):
Expand Down Expand Up @@ -428,9 +471,9 @@ def __getitem__(self, item):
np.transpose(x, axes=(0, 3, 1, 2, 4, 5)) for x in inputs_result
]

if "constants" in self.ds.data_vars:
if self.constants is not None:
# Add the constants as [F, C, H, W]
inputs_result.append(np.swapaxes(self.ds.constants.values, 0, 1))
inputs_result.append(self.constants)

logger.log(5, "computed batch in %0.2f s", time.time() - compute_time)
torch.cuda.nvtx.range_pop()
Expand Down
13 changes: 10 additions & 3 deletions test/datapipes/test_healpix.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,10 @@ def scaling_dict():
"z1000": {"mean": 952.1435546875, "std": 895.7516479492188},
"z250": {"mean": 101186.28125, "std": 5551.77978515625},
"z500": {"mean": 55625.9609375, "std": 2681.712890625},
"lsm": {"mean": 0, "std": 1},
"z": {"mean": 0, "std": 1},
"tp6": {"mean": 1, "std": 0, "log_epsilon": 1e-6},
"extra": {"mean": 1, "std": 0},
}
return DictConfig(scaling)

Expand All @@ -88,6 +91,9 @@ def scaling_double_dict():
"z250": {"mean": 0, "std": 2},
"z500": {"mean": 0, "std": 2},
"tp6": {"mean": 0, "std": 2, "log_epsilon": 1e-6},
"lsm": {"mean": 0, "std": 2},
"z": {"mean": 0, "std": 2},
"extra": {"mean": 0, "std": 2},
}
return DictConfig(scaling)

Expand Down Expand Up @@ -212,7 +218,7 @@ def test_TimeSeriesDataset_initialization(
"bogosity": {"mean": 0, "std": 42},
}
)
with pytest.raises(KeyError, match=("one or more of the input data variables")):
with pytest.raises(KeyError, match=("Input channels ")):
timeseries_ds = TimeSeriesDataset(
dataset=zarr_ds,
data_time_step="3h",
Expand Down Expand Up @@ -550,7 +556,8 @@ def test_TimeSeriesDataModule_get_constants(
# open our test dataset
ds_path = Path(data_dir, dataset_name + ".zarr")
zarr_ds = xr.open_zarr(ds_path)
expected = np.transpose(zarr_ds.constants.values, axes=(1, 0, 2, 3))
# dividing by 2 due to scaling
expected = np.transpose(zarr_ds.constants.values, axes=(1, 0, 2, 3)) / 2

assert np.array_equal(
timeseries_dm.get_constants(),
Expand Down Expand Up @@ -617,7 +624,7 @@ def test_TimeSeriesDataModule_get_dataloaders(
test_dataloader, test_sampler = timeseries_dm.test_dataloader(num_shards=1)
assert test_sampler is None
assert isinstance(test_dataloader, DataLoader)
print(f"dataset lenght {len}")

# with >1 shard should be distributed sampler
train_dataloader, train_sampler = timeseries_dm.train_dataloader(num_shards=2)
assert isinstance(train_sampler, DistributedSampler)
Expand Down
Loading