Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

extract_input_target_forcings add option for left-justification of train/eval #56

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 74 additions & 84 deletions graphcast/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

from typing import Any, Mapping, Sequence, Tuple, Union

from graphcast import solar_radiation
import numpy as np
import pandas as pd
import xarray
Expand All @@ -37,15 +36,6 @@

DAY_PROGRESS = "day_progress"
YEAR_PROGRESS = "year_progress"
_DERIVED_VARS = {
DAY_PROGRESS,
f"{DAY_PROGRESS}_sin",
f"{DAY_PROGRESS}_cos",
YEAR_PROGRESS,
f"{YEAR_PROGRESS}_sin",
f"{YEAR_PROGRESS}_cos",
}
TISR = "toa_incident_solar_radiation"


def get_year_progress(seconds_since_epoch: np.ndarray) -> np.ndarray:
Expand Down Expand Up @@ -133,7 +123,10 @@ def featurize_progress(


def add_derived_vars(data: xarray.Dataset) -> None:
"""Adds year and day progress features to `data` in place if missing.
"""Adds year and day progress features to `data` in place.

NOTE: `toa_incident_solar_radiation` needs to be computed in this function
as well.

Args:
data: Xarray dataset to which derived features will be added.
Expand All @@ -154,71 +147,38 @@ def add_derived_vars(data: xarray.Dataset) -> None:
)
batch_dim = ("batch",) if "batch" in data.dims else ()

# Add year progress features if missing.
if YEAR_PROGRESS not in data.data_vars:
year_progress = get_year_progress(seconds_since_epoch)
data.update(
featurize_progress(
name=YEAR_PROGRESS,
dims=batch_dim + ("time",),
progress=year_progress,
)
)

# Add day progress features if missing.
if DAY_PROGRESS not in data.data_vars:
longitude_coord = data.coords["lon"]
day_progress = get_day_progress(seconds_since_epoch, longitude_coord.data)
data.update(
featurize_progress(
name=DAY_PROGRESS,
dims=batch_dim + ("time",) + longitude_coord.dims,
progress=day_progress,
)
)


def add_tisr_var(data: xarray.Dataset) -> None:
"""Adds TISR feature to `data` in place if missing.

Args:
data: Xarray dataset to which TISR feature will be added.

Raises:
ValueError if `datetime`, 'lat', or `lon` are not in `data` coordinates.
"""

if TISR in data.data_vars:
return

for coord in ("datetime", "lat", "lon"):
if coord not in data.coords:
raise ValueError(f"'{coord}' must be in `data` coordinates.")

# Remove `batch` dimension of size one if present. An error will be raised if
# the `batch` dimension exists and has size greater than one.
data_no_batch = data.squeeze("batch") if "batch" in data.dims else data

tisr = solar_radiation.get_toa_incident_solar_radiation_for_xarray(
data_no_batch, use_jit=True
# Add year progress features.
year_progress = get_year_progress(seconds_since_epoch)
data.update(
featurize_progress(
name=YEAR_PROGRESS, dims=batch_dim + ("time",), progress=year_progress
)
)

if "batch" in data.dims:
tisr = tisr.expand_dims("batch", axis=0)

data.update({TISR: tisr})
# Add day progress features.
longitude_coord = data.coords["lon"]
day_progress = get_day_progress(seconds_since_epoch, longitude_coord.data)
data.update(
featurize_progress(
name=DAY_PROGRESS,
dims=batch_dim + ("time",) + longitude_coord.dims,
progress=day_progress,
)
)


def extract_input_target_times(
dataset: xarray.Dataset,
input_duration: TimedeltaLike,
target_lead_times: TargetLeadTimes,
justify: str
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is probably better as an enum instead of string.

) -> Tuple[xarray.Dataset, xarray.Dataset]:
"""Extracts inputs and targets for prediction, from a Dataset with a time dim.

The input period is assumed to be contiguous (specified by a duration), but
the targets can be a list of arbitrary lead times.


Examples:

# Use 18 hours of data as inputs, and two specific lead times as targets:
Expand Down Expand Up @@ -256,6 +216,16 @@ def extract_input_target_times(
(inclusive) lead times, or a sequence of lead times. Lead times should be
Timedeltas (or something convertible to). They are given relative to the
final input timestep, and should be positive.
justify: Defines whether inputs and targets are extracted from the beginning
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any chance you can add a test verifying that this does what you'd expect?

or end of the example batch.
When using 'left' justification (default), the final input is defined as the
second time element of the example batch. Targets follow immediately
thereafter, as defined by the target lead times.
Alternatively, with 'right' justification the targets are taken from the end
of the example batch (ending at the last time element), and the inputs are the
two time elements preceding the first target.
Note: It is important to realize that the first prediction can be no
earlier than 12Z based on current construction of example batches. 00Z and
06Z are inputs.

Returns:
inputs:
Expand All @@ -270,23 +240,43 @@ def extract_input_target_times(
(target_lead_times, target_duration
) = _process_target_lead_times_and_get_duration(target_lead_times)

# Shift the coordinates for the time axis so that a timedelta of zero
# corresponds to the forecast reference time. That is, the final timestep
# that's available as input to the forecast, with all following timesteps
# forming the target period which needs to be predicted.
# This means the time coordinates are now forecast lead times.
input_duration = pd.Timedelta(input_duration)
time = dataset.coords["time"]
dataset = dataset.assign_coords(time=time + target_duration - time[-1])

# Slice out targets:
targets = dataset.sel({"time": target_lead_times})
# Slice out inputs and targets:
if justify == 'left':
# Inputs correspond to the first time elements within the input duration.
# Targets follow immediately after, per the target lead times.
target_start_time = int(input_duration.total_seconds()/3600/6)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This shouldn't be 6h-specific.

target_end_time = int(input_duration.total_seconds()/3600/6) + int(target_duration.total_seconds()/3600/6)

inputs = dataset.isel(time=slice(int(target_start_time)))
inputs['time'] = inputs['time'] - input_duration + time[1]

targets = dataset.isel(time=slice(target_start_time,target_end_time))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please take a look at the Google Python Style Guide. Some general comments:

  • There should be spaces between function arguments, no trailing spaces, etc.
  • Comments should be in English and use punctuation (e.g., missing periods above).
  • Remove debugging statements (commented-out code, print statements below).

targets['time'] = targets['time'] - input_duration + time[1]

# targets = targets.assign_coords(time=time[1:target_end_time+1])

elif justify == 'right':
# Shift the coordinates for the time axis so that a timedelta of zero
# corresponds to the forecast reference time. That is, the final timestep
# that's available as input to the forecast, with all following timesteps
# forming the target period which needs to be predicted.
# This means the time coordinates are now forecast lead times.
dataset = dataset.assign_coords(time=time + target_duration - time[-1])

targets = dataset.sel({"time": target_lead_times})
# Both endpoints are inclusive with label-based slicing, so we offset by a
# small epsilon to make one of the endpoints non-inclusive:
zero = pd.Timedelta(0)
epsilon = pd.Timedelta(1, "ns")
inputs = dataset.sel({"time": slice(-input_duration + epsilon, zero)})
else:
raise ValueError(
"justify must either be 'left' or 'right'"
)

input_duration = pd.Timedelta(input_duration)
# Both endpoints are inclusive with label-based slicing, so we offset by a
# small epsilon to make one of the endpoints non-inclusive:
zero = pd.Timedelta(0)
epsilon = pd.Timedelta(1, "ns")
inputs = dataset.sel({"time": slice(-input_duration + epsilon, zero)})
return inputs, targets


Expand All @@ -311,8 +301,9 @@ def _process_target_lead_times_and_get_duration(

# A list of multiple (not necessarily contiguous) lead times:
target_lead_times = [pd.Timedelta(x) for x in target_lead_times]
target_lead_times.sort()
target_lead_times.sort()
target_duration = target_lead_times[-1]
print(target_lead_times,target_duration)
return target_lead_times, target_duration


Expand All @@ -325,25 +316,24 @@ def extract_inputs_targets_forcings(
pressure_levels: Tuple[int, ...],
input_duration: TimedeltaLike,
target_lead_times: TargetLeadTimes,
justify: str = 'left'
) -> Tuple[xarray.Dataset, xarray.Dataset, xarray.Dataset]:
"""Extracts inputs, targets and forcings according to requirements."""
dataset = dataset.sel(level=list(pressure_levels))

# "Forcings" include derived variables that do not exist in the original ERA5
# or HRES datasets, as well as other variables (e.g. tisr) that need to be
# computed manually for the target lead times. Compute the requested ones.
if set(forcing_variables) & _DERIVED_VARS:
# "Forcings" are derived variables and do not exist in the original ERA5 or
# HRES datasets. Compute them if they are not in `dataset`.
if not set(forcing_variables).issubset(set(dataset.data_vars)):
add_derived_vars(dataset)
if set(forcing_variables) & {TISR}:
add_tisr_var(dataset)

# `datetime` is needed by add_derived_vars but breaks autoregressive rollouts.
dataset = dataset.drop_vars("datetime")

inputs, targets = extract_input_target_times(
dataset,
input_duration=input_duration,
target_lead_times=target_lead_times)
target_lead_times=target_lead_times,
justify=justify)

if set(forcing_variables) & set(target_variables):
raise ValueError(
Expand Down