Upgrade to keras3 models
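Swap the custom CategoricalDNN loading path for Keras 3's native keras.models.load_model, recover the input feature order from the bridgescaler's groups_ instead of model.yml, rename Herbie_latest to HerbieLatest, forward model-specific Herbie options (e.g. product, member) through download_data via **kwargs, parameterize the interpolated atmospheric variables, and fold uncertainty estimation into model.predict(..., return_uncertainties=...).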
charlie-becker committed May 1, 2024
1 parent 2f85c6b commit bb3296c
Showing 2 changed files with 48 additions and 56 deletions.
81 changes: 36 additions & 45 deletions ptype/inference.py
@@ -1,20 +1,19 @@
import pandas as pd
from herbie import Herbie, Herbie_latest
from herbie import Herbie, HerbieLatest
from metpy.units import units
from metpy.calc import dewpoint_from_relative_humidity, dewpoint_from_specific_humidity
import numpy as np
import xarray as xr
import cfgrib
import os
from bridgescaler import load_scaler
from mlguess.keras.models import CategoricalDNN
import yaml
import numba
from numba import jit
import glob
import zarr
import pygrib
from pyproj import Proj, CRS, Transformer
from keras.models import load_model

def df_flatten(ds, varsP, vertical_level_name='isobaricInhPa'):
""" Split pressure level variables by pressure level, reassign and return as flattened Dataframe.
@@ -44,7 +43,7 @@ def convert_longitude(lon):
""" Convert longitude from -180-180 to 0-360"""
return lon % 360

def download_data(date, model, product, save_dir, forecast_hour):
def download_data(date, model, save_dir, forecast_hour, **kwargs):
""" Download data use Herbie for specified dates, model and forecast range.
Args:
date (pandas datetime or "most_recent"): Model initialization time
@@ -55,17 +54,19 @@ def download_data(date, model, product, save_dir, forecast_hour):
Returns:
Local path to the downloaded GRIB file.
"""
if "member" not in kwargs.keys():
kwargs["member"] = None
if date == "most_recent":
h = Herbie_latest(model=model,
product=product,
save_dir=save_dir,
fxx=[forecast_hour])
h = HerbieLatest(model=model,
save_dir=save_dir,
fxx=forecast_hour,
**kwargs)
else:
h = Herbie(date=date,
model=model,
product=product,
save_dir=save_dir,
fxx=forecast_hour)
fxx=forecast_hour,
**kwargs)
h.download()

return h.get_localFilePath()
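Since product (and any other Herbie option, such as an ensemble member) now arrives via **kwargs, a call might look like the sketch below; the model, product string, date, and save directory are illustrative assumptions, not values from this repo's config.

```python
# Hypothetical invocation; "sfc" is Herbie's HRRR surface product name,
# and the date, save_dir, and forecast hour are illustrative.
grib_path = download_data(date="2024-05-01 00:00",
                          model="hrrr",
                          save_dir="/tmp/ptype",
                          forecast_hour=6,
                          product="sfc")
# date="most_recent" would route through HerbieLatest instead.
```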
@@ -93,19 +94,21 @@ def load_data(var_dict, file, model, extent, drop):
grib = cfgrib.open_dataset(file, backend_kwargs={
"filter_by_keys": {'typeOfLevel': key, 'cfVarName': var, 'stepType': 'instant'}})
if len(grib) == 0:
grib = cfgrib.open_dataset(file, backend_kwargs={
"filter_by_keys": {'typeOfLevel': key, 'shortName': var, 'stepType': 'instant'}})
if var == "tp":
grib = cfgrib.open_dataset(file, backend_kwargs={
"filter_by_keys": {'typeOfLevel': key, 'shortName': var, 'stepType': 'accum'}})
if "heightAboveGround" in grib.variables.keys():
grib = grib.drop_vars("heightAboveGround")
grib_data.append(grib)

for idx in glob.glob(str(file) + '*.idx'):
os.remove(idx) # delete index files that are created when opening grib
nwp_dataset = xr.merge(grib_data).load()

nwp_dataset['t'].values = kelvin_to_celsius(nwp_dataset['t'].values)
nwp_dataset['t'].attrs["units"] = 'degC'
if model == "rap":
if (model == "rap") | (model == "gefs"):
nwp_dataset['dpt'] = dewpoint_from_relative_humidity(nwp_dataset['t'] * units.degC,
nwp_dataset['r'].values / 100)
elif model == "gfs":
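For reference, metpy's dewpoint_from_relative_humidity (used in the RAP/GEFS branch above) expects relative humidity as a dimensionless fraction, which is why the percent values are divided by 100. A quick standalone check with illustrative values:

```python
from metpy.calc import dewpoint_from_relative_humidity
from metpy.units import units

# 20 degC at 50% relative humidity -> dewpoint of roughly 9.3 degC
print(dewpoint_from_relative_humidity(20 * units.degC, 0.5))
```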
@@ -227,7 +230,7 @@ def add_coord_data(file_path, grib_data, extent):
return grib_data


def convert_and_interpolate(data, surface_data, pressure_levels, height_levels):
def convert_and_interpolate(data, surface_data, pressure_levels, height_levels, variables):
"""
Convert pressure-level data to height above the surface and interpolate across the specified height levels.
Args:
@@ -239,6 +242,7 @@ def convert_and_interpolate(data, surface_data, pressure_levels, height_levels):
Returns:
Pandas Dataframe of interpolated data at height above the surface.
"""
surface_match = {"t": "t2m", "dpt": "d2m", "u": "u10", "v": "v10"}
cols = {}
height_levels = np.arange(start=height_levels["low"],
stop=height_levels["high"] + height_levels["interval"],
@@ -247,14 +251,12 @@ def convert_and_interpolate(data, surface_data, pressure_levels, height_levels):
cols[var] = [f"{var}_{int(x)}" for x in pressure_levels]

var_arrays = []
variables = ['t', 'dpt', 'u', 'v']
surface_variables = ['t2m', 'd2m', 'u10', 'v10']
height_data = data[cols['hgt_above_sfc']].values

for v, sv in zip(variables, surface_variables):
for v in variables:
pressure_level_data = data[cols[v]].values
height_interp_data = interpolate(height_data, pressure_level_data, height_levels)
height_interp_data[:, 0] = surface_data[sv]
height_interp_data[:, 0] = surface_data[surface_match[v]]
var_arrays.append(height_interp_data)

pl_array = np.tile(pressure_levels, len(height_data)).reshape(len(height_data), len(pressure_levels))
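The new surface_match mapping lets the loop overwrite the lowest interpolated level with the model's own 2 m / 10 m analysis instead of a value extrapolated from pressure levels. A minimal sketch of that pattern, assuming the repo's numba-jitted interpolate behaves like column-wise np.interp (that helper is not shown in this diff):

```python
import numpy as np

def interp_columns(heights, values, levels):
    """Interpolate each row of pressure-level data onto fixed height levels.

    heights, values: (n_points, n_pressure_levels), heights ascending per row.
    levels: (n_height_levels,) target heights above the surface.
    """
    out = np.empty((values.shape[0], levels.size))
    for i in range(values.shape[0]):
        out[i] = np.interp(levels, heights[i], values[i])
    return out

# t_interp = interp_columns(hgt_above_sfc, t_on_pressure_levels, height_levels)
# t_interp[:, 0] = surface_data["t2m"]   # surface_match["t"], as in the loop above
```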
@@ -293,7 +295,7 @@ def transform_data(input_data, transformer, input_features):



def load_model(model_path, input_scaler_file):
def load_saved_model(model_path, scaler_path):
"""
Load ML model and bridgescaler object.
Args:
@@ -302,24 +304,15 @@ def load_model(model_path, input_scaler_file):
Returns:
Loaded Keras model, bridgescaler object, and list of input feature names
"""
config = os.path.join(model_path, "model.yml")
with open(config) as cf:
conf = yaml.load(cf, Loader=yaml.FullLoader)
conf['batch_size'] = 1000

x_transformer = load_scaler(os.path.join(model_path, "scalers", input_scaler_file))

with open(os.path.join(model_path, "model.yml")) as f:
conf = yaml.safe_load(f)

conf["input_features"] = conf["TEMP_C"] + conf["T_DEWPOINT_C"] + conf["UGRD_m/s"] + conf["VGRD_m/s"]
conf["output_features"] = conf["ptypes"]

model = CategoricalDNN().load_model(conf)
model = load_model(model_path)
scaler = load_scaler(scaler_path)
groups = scaler.groups_
input_features = [x for y in groups for x in y]

return model, x_transformer, conf["input_features"]
return model, scaler, input_features

def grid_predictions(data, predictions, interp_df=None, interpolated_pl=None, height_levels=None,
def grid_predictions(data, predictions, interp_df=None, interpolated_pl=None, variables=None, height_levels=None,
add_interp_data=False, evidential=False):
"""
Populate gridded xarray dataset with ML probabilities and categorical predictions as separate variables.
@@ -372,7 +365,7 @@ def grid_predictions(data, predictions, interp_df=None, interpolated_pl=None, he
drop_vars.append(v)
if add_interp_data:

interpolated_gridded = add_interp_gridded(data, interp_df, height_levels)
interpolated_gridded = add_interp_gridded(data, interp_df, height_levels, variables)
interpolated_gridded['isobaricInhPa_h'] = (['heightAboveGround', 'y', 'x'],
np.moveaxis(interpolated_pl.reshape(data['y'].size,
data['x'].size,
@@ -386,13 +379,12 @@ def save_data(dataset, out_path, date, model, forecast_hour, save_format):
return data.drop(drop_vars)


def save_data(dataset, out_path, date, model, forecast_hour, save_format):
def save_data(dataset, out_path, model_name, forecast_hour, save_format):
"""
Save ML predictions and surface data as a netCDF file.
Args:
dataset: Xarray dataset with ML predictions and surface data.
out_path: Path to save data.
date: Datetime object for predictions.
model_name: NWP model name.
forecast_hour: Forecast hour of ML predictions.
@@ -402,9 +394,9 @@ def save_data(dataset, out_path, date, model, forecast_hour, save_format):
date_str = dataset.time.dt.strftime("%Y-%m-%d").values
dir_str = dataset.time.dt.strftime("%Y%m%d").values
model_run_str = dataset.time.dt.strftime("%H%M").values
os.makedirs(os.path.join(out_path, model, str(dir_str), str(model_run_str)), exist_ok=True)
file_str = f"MILES_ptype_{model}_{date_str}_{model_run_str}_f{forecast_hour:02}"
full_path = os.path.join(out_path, model, str(dir_str), str(model_run_str), file_str)
os.makedirs(os.path.join(out_path, model_name, str(dir_str), str(model_run_str)), exist_ok=True)
file_str = f"MILES_ptype_{model_name}_{date_str}_{model_run_str}_f{forecast_hour:02}"
full_path = os.path.join(out_path, model_name, str(dir_str), str(model_run_str), file_str)

dataset = dataset.expand_dims('time')

@@ -421,7 +413,7 @@ def save_data(dataset, out_path, date, model, forecast_hour, save_format):
return


def add_interp_gridded(nwp_data, interp_data, height_levels):
def add_interp_gridded(nwp_data, interp_data, height_levels, variables):
"""
Convert height interpolated data to xarray format and merge with main dataset.
Args:
@@ -432,10 +424,10 @@ def add_interp_gridded(nwp_data, interp_data, height_levels):
Returns:
Merged xr.Dataset of NWP/ML data and height interpolated data
"""
var_names = {"t": "temperature", "dpt": "dewpoint", "u": "u-component of wind", "v": "v-component of wind"}
x = nwp_data.stack(d=['y', 'x'])['x'].values.astype('int16')
y = nwp_data.stack(d=['y', 'x'])['y'].values.astype('int16')

variables = ['t', 'dpt', 'u', 'v']
height_levels = np.arange(height_levels['high'], height_levels['low'] - height_levels['interval'],
-height_levels['interval'])
columns = [f"{var}_{level}" for var in variables for level in height_levels]
@@ -447,11 +439,10 @@ def add_interp_gridded(nwp_data, interp_data, height_levels):
gridded = new_df.to_xarray()

datasets = []
for var, long_name in zip(['t', 'dpt', 'u', 'v'],
['temperature', 'dewpoint', 'u-component of wind', 'v-component of wind']):
for var in variables:
dataset = xr.concat([gridded[f"{var}_{i}"] for i in height_levels],
dim='heightAboveGround').to_dataset().rename({f"{var}_{int(height_levels.max())}": f"{var}_h"})
dataset[f"{var}_h"].attrs['Description'] = f"Height interpolated {long_name}"
dataset[f"{var}_h"].attrs['Description'] = f"Height interpolated {var_names[var]}"
dataset[f"{var}_h"] = dataset[f"{var}_h"].astype('float32')
datasets.append(dataset)

23 changes: 12 additions & 11 deletions scripts/run_inference.py
@@ -3,7 +3,7 @@
import yaml
import pandas as pd
from ptype.inference import (download_data, load_data, convert_and_interpolate,
load_model, transform_data, grid_predictions, save_data)
load_saved_model, transform_data, grid_predictions, save_data)
import itertools
from multiprocessing import Pool
from dask.distributed import Client
@@ -13,46 +13,47 @@
def main(config, username, date, forecast_hour):
print("starting", date, "for forecast hour: ", forecast_hour)
out_path = config["out_path"].replace("username", username)
print(out_path)
nwp_model = config["model"]
model, transformer, input_features = load_model(model_path=config["ML_model_path"],
input_scaler_file=config["input_scaler_file"])
model, transformer, input_features = load_saved_model(model_path=config["ML_model_path"],
scaler_path=config["input_scaler_file"])
mod_file = download_data(date=date,
model=config["model"],
product=config["variables"]["model"][nwp_model]["product"],
save_dir=out_path,
forecast_hour=forecast_hour)
forecast_hour=forecast_hour,
**config["variables"]["model"][nwp_model]["kwargs"])

ds, df, surface_vars = load_data(var_dict=config["variables"]["model"][nwp_model],
file=mod_file,
model=nwp_model,
extent=config["extent"],
drop=config["drop_input_data"])

data, interpolated_pl = convert_and_interpolate(data=df,
surface_data=surface_vars,
pressure_levels=ds["isobaricInhPa"],
height_levels=config["height_levels"])
height_levels=config["height_levels"],
variables=config["ml_atm_varaibles"])

x_data = transform_data(input_data=data,
transformer=transformer,
input_features=input_features)

if config["evidential"]:
predictions = model.predict_uncertainty(x_data)
predictions = model.predict(x_data, return_uncertainties=True, batch_size=2048)
else:
predictions = model.predict(x_data, batch_size=2048)
predictions = model.predict(x_data, return_uncertainties=False, batch_size=2048)

gridded_preds = grid_predictions(data=ds,
predictions=predictions,
interp_df=data,
interpolated_pl=interpolated_pl,
variables=config["ml_atm_varaibles"],
height_levels=config['height_levels'],
add_interp_data=config["add_interp_data"],
evidential=config["evidential"])
save_data(dataset=gridded_preds,
out_path=out_path,
date=date,
model=config["model"],
model_name=config["model"],
forecast_hour=forecast_hour,
save_format=config["save_format"])
os.remove(str(mod_file)) # delete grib file
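For orientation, the YAML config that main loads deserializes to a dict shaped roughly as below. The values are illustrative assumptions rather than the project's shipped settings, and the "ml_atm_varaibles" key spelling deliberately matches the script:

```python
# Shape of the config main() expects; every value here is an assumption.
config = {
    "out_path": "/data/username/ptype",   # "username" is substituted at runtime
    "model": "gfs",
    "ML_model_path": "models/ptype.keras",           # assumed layout
    "input_scaler_file": "models/input_scaler.json", # assumed layout
    "extent": [230.0, 300.0, 20.0, 55.0],            # assumed ordering
    "drop_input_data": False,
    "add_interp_data": True,
    "evidential": True,
    "save_format": "netcdf",                         # assumed option name
    "height_levels": {"low": 0, "high": 5000, "interval": 250},
    "ml_atm_varaibles": ["t", "dpt", "u", "v"],
    "variables": {"model": {"gfs": {
        "kwargs": {"product": "pgrb2.0p25"},         # forwarded to Herbie
        # ...plus the variable/level mapping consumed by load_data (omitted)
    }}},
}
```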
