diff --git a/test/neuralgcm_test.py b/test/neuralgcm_test.py index 9db4aac1..a9f3b2e9 100644 --- a/test/neuralgcm_test.py +++ b/test/neuralgcm_test.py @@ -93,10 +93,16 @@ def setUp(self): eval_era5 = xarray_utils.regrid(sliced_era5, regridder) eval_era5 = xarray_utils.fill_nan_with_nearest(eval_era5) - # inner_steps = 24 # save model outputs once every 24 hours - # outer_steps = 4 * 24 // inner_steps # total of 4 days - inner_steps = 4 # save model outputs once every 24 hours - outer_steps = 4 * 4 // inner_steps # total of 4 days + import os + if os.getenv("NEURALGCM_LARGE") is not None: + inner_steps = 24 # save model outputs once every 24 hours + outer_steps = 4 * 24 // inner_steps # total of 4 days + elif s.getenv("NEURALGCM_MEDIUM") is not None: + inner_steps = 4 # save model outputs once every 24 hours + outer_steps = 4 * 4 // inner_steps # total of 4 days + else: + inner_steps = 1 # save model outputs once every 24 hours + outer_steps = 1 * 1 // inner_steps # total of 4 days timedelta = np.timedelta64(1, "h") * inner_steps # times = (np.arange(outer_steps) * inner_steps) # time axis in hours