diff --git a/modulus/datapipes/healpix/coupledtimeseries_dataset.py b/modulus/datapipes/healpix/coupledtimeseries_dataset.py
index 99d0343ad..70db9757f 100644
--- a/modulus/datapipes/healpix/coupledtimeseries_dataset.py
+++ b/modulus/datapipes/healpix/coupledtimeseries_dataset.py
@@ -157,43 +157,10 @@ def _get_scaling_da(self):
         scaling_df.loc["zeros"] = {"mean": 0.0, "std": 1.0}
         scaling_da = scaling_df.to_xarray().astype("float32")

+        # The only thing we do differently here is get the scaling for the coupled values
         for c in self.couplings:
             c.set_scaling(scaling_da)
-        # REMARK: we remove the xarray overhead from these
-        try:
-            self.input_scaling = scaling_da.sel(index=self.input_variables).rename(
-                {"index": "channel_in"}
-            )
-            self.input_scaling = {
-                "mean": np.expand_dims(
-                    self.input_scaling["mean"].to_numpy(), (0, 2, 3, 4)
-                ),
-                "std": np.expand_dims(
-                    self.input_scaling["std"].to_numpy(), (0, 2, 3, 4)
-                ),
-            }
-        except (ValueError, KeyError):
-            raise KeyError(
-                f"one or more of the input data variables f{list(self.ds.channel_in)} not found in the "
-                f"scaling config dict data.scaling ({list(self.scaling.keys())})"
-            )
-        try:
-            self.target_scaling = scaling_da.sel(index=self.input_variables).rename(
-                {"index": "channel_out"}
-            )
-            self.target_scaling = {
-                "mean": np.expand_dims(
-                    self.target_scaling["mean"].to_numpy(), (0, 2, 3, 4)
-                ),
-                "std": np.expand_dims(
-                    self.target_scaling["std"].to_numpy(), (0, 2, 3, 4)
-                ),
-            }
-        except (ValueError, KeyError):
-            raise KeyError(
-                f"one or more of the target data variables f{list(self.ds.channel_out)} not found in the "
-                f"scaling config dict data.scaling ({list(self.scaling.keys())})"
-            )
+        super()._get_scaling_da()

     def __getitem__(self, item):
         # start range
@@ -251,7 +218,6 @@ def __getitem__(self, item):
         torch.cuda.nvtx.range_pop()
         torch.cuda.nvtx.range_push("CoupledTimeSeriesDataset:__getitem__:process_batch")
-        compute_time = time.time()

         # Insolation
         if self.add_insolation:
             sol = insolation(
@@ -305,11 +271,9 @@ def __getitem__(self, item):
             np.transpose(x, axes=(0, 3, 1, 2, 4, 5)) for x in inputs_result
         ]

-        if "constants" in self.ds.data_vars:
+        if self.constants is not None:
             # Add the constants as [F, C, H, W]
-            inputs_result.append(np.swapaxes(self.ds.constants.values, 0, 1))
-            # inputs_result.append(self.ds.constants.values)
-        logger.log(5, "computed batch in %0.2f s", time.time() - compute_time)
+            inputs_result.append(self.constants)

         # append integrated couplings
         inputs_result.append(integrated_couplings)
@@ -328,7 +292,6 @@ def __getitem__(self, item):

         return inputs_result, targets

     def next_integration(self, model_outputs, constants):
-        inputs_result = []

         # grab last few model outputs for re-initialization
diff --git a/modulus/datapipes/healpix/data_modules.py b/modulus/datapipes/healpix/data_modules.py
index 665dc77bf..976529bec 100644
--- a/modulus/datapipes/healpix/data_modules.py
+++ b/modulus/datapipes/healpix/data_modules.py
@@ -119,7 +119,7 @@ def open_time_series_dataset_classic_on_the_fly(
         file_name = _get_file_name(directory, prefix, variable, suffix)
         logger.debug("open nc dataset %s", file_name)

-        ds = xr.open_dataset(file_name, chunks={"sample": batch_size}, autoclose=True)
+        ds = xr.open_dataset(file_name, autoclose=True)

         if "LL" in prefix:
             ds = ds.rename({"lat": "height", "lon": "width"})
@@ -212,7 +212,7 @@ def open_time_series_dataset_classic_prebuilt(
     if not ds_path.exists():
         raise FileNotFoundError(f"Dataset doesn't appear to exist at {ds_path}")

-    result = xr.open_zarr(ds_path, chunks={"time": batch_size})
+    result = xr.open_zarr(ds_path)

     return result

@@ -286,11 +286,9 @@ def create_time_series_dataset_classic(
         file_name = _get_file_name(src_directory, prefix, variable, suffix)
         logger.debug("open nc dataset %s", file_name)
         if "sample" in list(xr.open_dataset(file_name).sizes.keys()):
-            ds = xr.open_dataset(file_name, chunks={"sample": batch_size}).rename(
-                {"sample": "time"}
-            )
+            ds = xr.open_dataset(file_name).rename({"sample": "time"})
         else:
-            ds = xr.open_dataset(file_name, chunks={"time": batch_size})
+            ds = xr.open_dataset(file_name)
         if "varlev" in ds.dims:
             ds = ds.isel(varlev=0)
diff --git a/modulus/datapipes/healpix/timeseries_dataset.py b/modulus/datapipes/healpix/timeseries_dataset.py
index a36eb33ff..21b7c31e5 100644
--- a/modulus/datapipes/healpix/timeseries_dataset.py
+++ b/modulus/datapipes/healpix/timeseries_dataset.py
@@ -189,9 +189,25 @@ def __init__(
         self.input_scaling = None
         self.target_scaling = None
+        self.constant_scaling = None
+        self.constants = None
         if self.scaling:
             self._get_scaling_da()
+
+        # setup constants
+        if "constants" in self.ds.data_vars:
+            # extract from ds:
+            const = self.ds.constants.values
+
+            if self.constant_scaling:
+                const = (const - self.constant_scaling["mean"]) / self.constant_scaling[
+                    "std"
+                ]
+
+            # transpose to match new format:
+            # [C, F, H, W] -> [F, C, H, W]
+            self.constants = np.transpose(const, axes=(1, 0, 2, 3))
+
     def get_constants(self):
         """Returns the constants used in this dataset
@@ -199,14 +215,7 @@ def get_constants(self):

         Returns
         -------
         np.ndarray: The list of constants, None if there are no constants
         """
-        # extract from ds:
-        const = self.ds.constants.values
-
-        # transpose to match new format:
-        # [C, F, H, W] -> [F, C, H, W]
-        const = np.transpose(const, axes=(1, 0, 2, 3))
-
-        return const
+        return self.constants

     @staticmethod
     def _convert_time_step(dt):  # pylint: disable=invalid-name
@@ -244,9 +253,13 @@ def _get_scaling_da(self):
                 ),
             }
         except (ValueError, KeyError):
+            missing = [
+                m
+                for m in self.ds.channel_in.values
+                if m not in list(self.scaling.keys())
+            ]
             raise KeyError(
-                f"one or more of the input data variables f{list(self.ds.channel_in)} not found in the "
-                f"scaling config dict data.scaling ({list(self.scaling.keys())})"
+                f"Input channels {missing} not found in the scaling config dict data.scaling ({list(self.scaling.keys())})"
             )
         try:
             self.target_scaling = scaling_da.sel(
@@ -261,9 +274,37 @@ def _get_scaling_da(self):
                 ),
             }
         except (ValueError, KeyError):
+            missing = [
+                m
+                for m in self.ds.channel_out.values
+                if m not in list(self.scaling.keys())
+            ]
             raise KeyError(
-                f"one or more of the target data variables f{list(self.ds.channel_out)} not found in the "
-                f"scaling config dict data.scaling ({list(self.scaling.keys())})"
+                f"Target channels {missing} not found in the scaling config dict data.scaling ({list(self.scaling.keys())})"
+            )
+
+        try:
+            # not all datasets will have constants
+            if "constants" in self.ds.data_vars:
+                self.constant_scaling = scaling_da.sel(
+                    index=self.ds.channel_c.values
+                ).rename({"index": "channel_out"})
+                self.constant_scaling = {
+                    "mean": np.expand_dims(
+                        self.constant_scaling["mean"].to_numpy(), (1, 2, 3)
+                    ),
+                    "std": np.expand_dims(
+                        self.constant_scaling["std"].to_numpy(), (1, 2, 3)
+                    ),
+                }
+        except (ValueError, KeyError):
+            missing = [
+                m
+                for m in self.ds.channel_c.values
+                if m not in list(self.scaling.keys())
+            ]
+            raise KeyError(
+                f"Constant channels {missing} not found in the scaling config dict data.scaling ({list(self.scaling.keys())})"
             )

     def __len__(self):
@@ -428,9 +469,9 @@ def __getitem__(self, item):
             np.transpose(x, axes=(0, 3, 1, 2, 4, 5)) for x in inputs_result
         ]

-        if "constants" in self.ds.data_vars:
+        if self.constants is not None:
             # Add the constants as [F, C, H, W]
-            inputs_result.append(np.swapaxes(self.ds.constants.values, 0, 1))
+            inputs_result.append(self.constants)

         logger.log(5, "computed batch in %0.2f s", time.time() - compute_time)
         torch.cuda.nvtx.range_pop()
diff --git a/test/datapipes/test_healpix.py b/test/datapipes/test_healpix.py
index 3667cb650..67d0a9f8e 100644
--- a/test/datapipes/test_healpix.py
+++ b/test/datapipes/test_healpix.py
@@ -72,7 +72,10 @@ def scaling_dict():
         "z1000": {"mean": 952.1435546875, "std": 895.7516479492188},
         "z250": {"mean": 101186.28125, "std": 5551.77978515625},
         "z500": {"mean": 55625.9609375, "std": 2681.712890625},
+        "lsm": {"mean": 0, "std": 1},
+        "z": {"mean": 0, "std": 1},
         "tp6": {"mean": 1, "std": 0, "log_epsilon": 1e-6},
+        "extra": {"mean": 1, "std": 0},
     }
     return DictConfig(scaling)
@@ -88,6 +91,9 @@ def scaling_double_dict():
         "z250": {"mean": 0, "std": 2},
         "z500": {"mean": 0, "std": 2},
         "tp6": {"mean": 0, "std": 2, "log_epsilon": 1e-6},
+        "lsm": {"mean": 0, "std": 2},
+        "z": {"mean": 0, "std": 2},
+        "extra": {"mean": 0, "std": 2},
     }
     return DictConfig(scaling)
@@ -212,7 +218,7 @@ def test_TimeSeriesDataset_initialization(
             "bogosity": {"mean": 0, "std": 42},
         }
     )
-    with pytest.raises(KeyError, match=("one or more of the input data variables")):
+    with pytest.raises(KeyError, match=("Input channels ")):
         timeseries_ds = TimeSeriesDataset(
             dataset=zarr_ds,
             data_time_step="3h",
@@ -550,9 +556,14 @@ def test_TimeSeriesDataModule_get_constants(
     # open our test dataset
     ds_path = Path(data_dir, dataset_name + ".zarr")
     zarr_ds = xr.open_zarr(ds_path)
-    expected = np.transpose(
-        zarr_ds.constants.sel(channel_c=list(constants.keys())).values,
-        axes=(1, 0, 2, 3),
+
+    # dividing by 2 due to scaling
+    expected = (
+        np.transpose(
+            zarr_ds.constants.sel(channel_c=list(constants.keys())).values,
+            axes=(1, 0, 2, 3),
+        )
+        / 2.0
     )

     assert np.array_equal(
@@ -620,7 +631,7 @@ def test_TimeSeriesDataModule_get_dataloaders(
     test_dataloader, test_sampler = timeseries_dm.test_dataloader(num_shards=1)
     assert test_sampler is None
     assert isinstance(test_dataloader, DataLoader)
-    print(f"dataset lenght {len}")
+
     # with >1 shard should be distributed sampler
     train_dataloader, train_sampler = timeseries_dm.train_dataloader(num_shards=2)
     assert isinstance(train_sampler, DistributedSampler)
diff --git a/test/datapipes/test_healpix_couple.py b/test/datapipes/test_healpix_couple.py
index 681a943bf..7ee24aacc 100644
--- a/test/datapipes/test_healpix_couple.py
+++ b/test/datapipes/test_healpix_couple.py
@@ -71,6 +71,8 @@ def scaling_dict():
         "z1000": {"mean": 952.1435546875, "std": 895.7516479492188},
         "z250": {"mean": 101186.28125, "std": 5551.77978515625},
         "z500": {"mean": 55625.9609375, "std": 2681.712890625},
+        "lsm": {"mean": 0, "std": 1},
+        "z": {"mean": 0, "std": 1},
         "tp6": {"mean": 1, "std": 0, "log_epsilon": 1e-6},
     }
     return DictConfig(scaling)
@@ -86,6 +88,8 @@ def scaling_double_dict():
         "z1000": {"mean": 0, "std": 2},
         "z250": {"mean": 0, "std": 2},
         "z500": {"mean": 0, "std": 2},
+        "lsm": {"mean": 0, "std": 2},
+        "z": {"mean": 0, "std": 2},
         "tp6": {"mean": 0, "std": 2, "log_epsilon": 1e-6},
     }
     return DictConfig(scaling)
@@ -94,7 +98,7 @@ def scaling_double_dict():
 @nfsdata_or_fail
 def test_ConstantCoupler(data_dir, dataset_name, scaling_dict, pytestconfig):
     variables = ["z500", "z1000"]
-    input_times = ["0H"]
input_times = ["0h"] input_time_dim = 1 output_time_dim = 1 presteps = 0 @@ -116,7 +120,7 @@ def test_ConstantCoupler(data_dir, dataset_name, scaling_dict, pytestconfig): assert isinstance(coupler, ConstantCoupler) interval = 2 - data_time_step = "3H" + data_time_step = "3h" coupler.compute_coupled_indices(interval, data_time_step) coupled_integration_dim = presteps + max(output_time_dim // input_time_dim, 1) expected = np.empty([batch_size, coupled_integration_dim, len(input_times)]) @@ -144,12 +148,12 @@ def test_ConstantCoupler(data_dir, dataset_name, scaling_dict, pytestconfig): @nfsdata_or_fail def test_TrailingAverageCoupler(data_dir, dataset_name, scaling_dict, pytestconfig): variables = ["z500", "z1000"] - input_times = ["6H", "12H"] + input_times = ["6h", "12h"] input_time_dim = 2 output_time_dim = 2 presteps = 0 batch_size = 2 - averaging_window = "6H" + averaging_window = "6h" # open our test dataset ds_path = Path(data_dir, dataset_name + ".zarr") zarr_ds = xr.open_zarr(ds_path) @@ -167,7 +171,7 @@ def test_TrailingAverageCoupler(data_dir, dataset_name, scaling_dict, pytestconf assert isinstance(coupler, TrailingAverageCoupler) interval = 2 - data_time_step = "3H" + data_time_step = "3h" coupler.compute_coupled_indices(interval, data_time_step) coupled_integration_dim = presteps + max(output_time_dim // input_time_dim, 1) expected = np.empty([batch_size, coupled_integration_dim, len(input_times)]) @@ -239,7 +243,7 @@ def test_CoupledTimeSeriesDataset_initialization( "bogosity": {"mean": 0, "std": 42}, } ) - with pytest.raises(KeyError, match=("one or more of the input data variables")): + with pytest.raises(KeyError, match=("Input channels ")): timeseries_ds = CoupledTimeSeriesDataset( dataset=zarr_ds, input_variables=variables, @@ -363,7 +367,7 @@ def test_CoupledTimeSeriesDataset_get_constants( "params": { "batch_size": 1, "variables": ["z250"], - "input_times": ["0H"], + "input_times": ["0"], "input_time_dim": 1, "output_time_dim": 1, "presteps": 0, @@ -405,7 +409,7 @@ def test_CoupledTimeSeriesDataset_len( "params": { "batch_size": 1, "variables": ["z250"], - "input_times": ["0H"], + "input_times": ["0h"], "input_time_dim": 1, "output_time_dim": 1, "presteps": 0, @@ -431,7 +435,7 @@ def test_CoupledTimeSeriesDataset_len( "params": { "batch_size": 2, "variables": ["z250"], - "input_times": ["0H"], + "input_times": ["0h"], "input_time_dim": 1, "output_time_dim": 1, "presteps": 0, @@ -486,7 +490,7 @@ def test_CoupledTimeSeriesDataset_get( "params": { "batch_size": batch_size, "variables": ["z250"], - "input_times": ["0H"], + "input_times": ["0h"], "input_time_dim": 1, "output_time_dim": 1, "presteps": 0, @@ -624,7 +628,7 @@ def test_CoupledTimeSeriesDataModule_initialization( "params": { "batch_size": 1, "variables": ["z250"], - "input_times": ["0H"], + "input_times": ["0h"], "input_time_dim": 1, "output_time_dim": 1, "presteps": 0, @@ -718,7 +722,7 @@ def test_CoupledTimeSeriesDataModule_get_constants( "params": { "batch_size": 1, "variables": ["z250"], - "input_times": ["0H"], + "input_times": ["0h"], "input_time_dim": 1, "output_time_dim": 1, "presteps": 0, @@ -759,9 +763,14 @@ def test_CoupledTimeSeriesDataModule_get_constants( # open our test dataset ds_path = Path(data_dir, dataset_name + ".zarr") zarr_ds = xr.open_zarr(ds_path) - expected = np.transpose( - zarr_ds.constants.sel(channel_c=list(constants.keys())).values, - axes=(1, 0, 2, 3), + + # divide by 2 due to scaling + expected = ( + np.transpose( + zarr_ds.constants.sel(channel_c=list(constants.keys())).values, + 
+            axes=(1, 0, 2, 3),
+        )
+        / 2.0
     )

     assert np.array_equal(
@@ -810,7 +819,7 @@
             "params": {
                 "batch_size": 1,
                 "variables": ["z250"],
-                "input_times": ["0H"],
+                "input_times": ["0h"],
                 "input_time_dim": 1,
                 "output_time_dim": 1,
                 "presteps": 0,
@@ -873,7 +882,7 @@
             "params": {
                 "batch_size": 1,
                 "variables": ["z250"],
-                "input_times": ["0H"],
+                "input_times": ["0h"],
                 "input_time_dim": 1,
                 "output_time_dim": 1,
                 "presteps": 0,
@@ -908,8 +917,8 @@
             "params": {
                 "batch_size": 1,
                 "variables": ["z250"],
-                "input_times": ["6H"],
-                "averaging_window": "6H",
+                "input_times": ["6h"],
+                "averaging_window": "6h",
                 "input_time_dim": 1,
                 "output_time_dim": 1,
                 "presteps": 0,
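
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the patch): the constants pre-scaling that
# this change moves into TimeSeriesDataset.__init__. Shapes and values below
# are synthetic; only the np.expand_dims/np.transpose semantics mirror the
# diff. Under the scaling_double_dict fixture (mean=0, std=2) the cached
# constants come out halved, which is why the updated get_constants tests
# divide the expected array by 2.0.
import numpy as np

C, F, H, W = 2, 12, 8, 8  # constant channels, faces, height, width (synthetic)
const = np.random.rand(C, F, H, W).astype("float32")  # like ds.constants.values

# per-channel mean/std broadcast to (C, 1, 1, 1), as in _get_scaling_da
mean = np.expand_dims(np.zeros(C, dtype="float32"), (1, 2, 3))
std = np.expand_dims(np.full(C, 2.0, dtype="float32"), (1, 2, 3))

# standardize, then transpose [C, F, H, W] -> [F, C, H, W] as __init__ now does
constants = np.transpose((const - mean) / std, axes=(1, 0, 2, 3))

assert constants.shape == (F, C, H, W)
assert np.array_equal(constants, np.transpose(const, axes=(1, 0, 2, 3)) / 2.0)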