Skip to content

Commit

Permalink
Update GraphEM
Browse files Browse the repository at this point in the history
  • Loading branch information
fzhu2e committed Jun 26, 2023
1 parent 94c4b6c commit dfdd56f
Show file tree
Hide file tree
Showing 65 changed files with 5,058 additions and 6,632 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,4 @@ jobs:
- name: Test with pytest
run: |
conda activate cfr-env
pytest --nbmake -k 'proxy and not psm and not test' -n=auto --nbmake-timeout=3000 --overwrite ./docsrc/notebooks/*.ipynb
# pytest --nbmake -k 'proxy and not psm and not test' -n=auto --nbmake-timeout=3000 --overwrite ./docsrc/notebooks/*.ipynb
65 changes: 56 additions & 9 deletions cfr/reconjob.py
Original file line number Diff line number Diff line change
Expand Up @@ -823,7 +823,7 @@ def prep_da_cfg(self, cfg_path, seeds=None, save_job=False, verbose=False):
t_used = t_e - t_s
p_success(f'>>> DONE! Total time used: {t_used/60:.2f} mins.')

def prep_graphem(self, recon_time=None, calib_time=None, recon_period=None, recon_timescale=None, calib_period=None, verbose=False):
def prep_graphem(self, recon_time=None, calib_time=None, recon_period=None, recon_timescale=None, calib_period=None, uniform_pdb=None, verbose=False):
''' A shortcut of the steps for GraphEM data preparation
Args:
Expand All @@ -846,13 +846,30 @@ def prep_graphem(self, recon_time=None, calib_time=None, recon_period=None, rec
else:
recon_time = self.io_cfg('recon_time', recon_time, verbose=verbose)
calib_time = self.io_cfg('calib_time', calib_time, verbose=verbose)

recon_period = self.io_cfg('recon_period', [np.min(recon_time), np.max(recon_time)], verbose=verbose)
calib_period = self.io_cfg('calib_period', [np.min(calib_time), np.max(calib_time)], verbose=verbose)
recon_timescale = self.io_cfg('recon_timescale', np.median(np.diff(recon_time)), verbose=verbose) # unit: yr

uniform_pdb = self.io_cfg('uniform_pdb', uniform_pdb, default=True, verbose=verbose)
if uniform_pdb:
pobj_list = []
for pobj in self.proxydb:
if np.min(pobj.time) <= recon_period[0]:
pobj_list.append(pobj)
new_pdb = ProxyDatabase()
new_pdb += pobj_list
self.proxydb = new_pdb
if verbose: p_success(f'>>> ProxyDatabase filtered to be more uniform. {self.proxydb.nrec} records remaining.')

self.center_proxydb(ref_period=calib_period, verbose=verbose)

self.graphem_params = {}
self.graphem_params['recon_time'] = recon_time
self.graphem_params['calib_time'] = calib_time
if verbose: p_success(f'>>> job.graphem_params["recon_time"] created')
if verbose: p_success(f'>>> job.graphem_params["calib_time"] created')


vn = list(self.obs.keys())[0]
obs = self.obs[vn]
obs_nt = obs.da.shape[0]
Expand Down Expand Up @@ -910,7 +927,7 @@ def graphem_kcv(self, cv_time, ctrl_params, graph_type='neighborhood', stat='MSE
---------
cv_time : array-like, 1d
explain how it differs from recon_time or calib_time
cross validation time points
ctrl_params : array-like, 1d
array of control parameters to try
Expand Down Expand Up @@ -991,6 +1008,9 @@ def run_graphem(self, save_recon=True, save_dirpath=None, save_filename=None,
load_precalculated (bool, optional): load the precalculated `Graph` object. Defaults to False.
verbose (bool, optional): print verbose information. Defaults to False.
fit_kws (dict): the arguments for :py:meth: `GraphEM.solver.GraphEM.fit`
The most important one is "graph_method"; available options include "neighborhood", "glasso", and "hybrid", where
"hybrid" means running "neighborhood" first with the default `cutoff_radius=1500` to infill the data matrix and then
running "glasso" with the default `sp_FF=3, sp_FP=3` to further improve the result.
See also:
cfr.graphem.solver.GraphEM.fit : fitting the GraphEM method
Expand All @@ -1016,14 +1036,41 @@ def run_graphem(self, save_recon=True, save_dirpath=None, save_filename=None,
fit_kwargs = {
'lonlat': self.graphem_params['lonlat'],
'graph_method': 'neighborhood',
'cutoff_radius': 1500,
'sp_FF': 3,
'sp_FP': 3,
}
fit_kwargs.update(fit_kws)
self.graphem_solver.fit(
self.graphem_params['field'],
self.graphem_params['proxy'],
self.graphem_params['calib_idx'],
verbose=verbose,
**fit_kwargs)
if fit_kwargs['graph_method'] in ['neighborhood', 'glasso']:
self.graphem_solver.fit(
self.graphem_params['field'],
self.graphem_params['proxy'],
self.graphem_params['calib_idx'],
verbose=verbose,
**fit_kwargs)
elif fit_kwargs['graph_method'] == 'hybrid':
fit_kwargs.update({'graph_method': 'neighborhood'})
self.graphem_solver.fit(
self.graphem_params['field'],
self.graphem_params['proxy'],
self.graphem_params['calib_idx'],
verbose=verbose,
**fit_kwargs)

inst = self.graphem_params['calib_idx']
G_L = Graph(
lonlat = self.graphem_params['lonlat'],
field = self.graphem_solver.field_r[inst],
proxy = self.graphem_solver.proxy_r[inst,:])

G_L.glasso_adj(target_FF=fit_kwargs['sp_FF'], target_FP=fit_kwargs['sp_FP'])
fit_kwargs.update({'estimate_graph': False, 'graph': G_L.adj})
self.graphem_solver.fit(
self.graphem_params['field'],
self.graphem_params['proxy'],
self.graphem_params['calib_idx'],
verbose=verbose,
**fit_kwargs)

if verbose: p_success(f'job.graphem_solver created and saved to: {solver_save_path}')

Expand Down
50 changes: 48 additions & 2 deletions cfr/reconres.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
p_fail,
p_warning,
)
import matplotlib.pyplot as plt
from matplotlib import gridspec
from .visual import CartopySettings

class ReconRes:
''' The class for reconstruction results
Expand Down Expand Up @@ -73,5 +76,48 @@ def load(self, vn_list, verbose=False):
p_success(f'>>> ReconRes.recons["{vn}"] created')
p_success(f'>>> ReconRes.da["{vn}"] created')

def compare(self, res2, verbose=False):
vn_list = self.recons.keys()

def valid(self, target_dict, stat=['corr'], timespan=None,
          verbose=False):
    ''' Validate the reconstructions against target datasets

    The variables named by the keys of `target_dict` are loaded with `self.load()`, and each
    loaded reconstruction is compared with its target through the reconstruction's own
    `compare()` method.

    Args:
        target_dict (dict): maps a variable name to the target object handed to `compare()`.
        stat (str or list or tuple, optional): the statistic(s) passed, one at a time, as `stat`
            to `compare()` for `ClimateField` reconstructions. Defaults to ['corr'].
        timespan (optional): forwarded unchanged to `compare()`. Defaults to None.
        verbose (bool, optional): print verbose information. Defaults to False.

    The results are stored in `self.valid_fd` (keyed by f'{vn}_{stat}', for `ClimateField`
    reconstructions) and `self.valid_ts` (keyed by the variable name, for `EnsTS` reconstructions).
    '''
    # Accept a tuple as well as a list, and copy so the shared default list is never aliased.
    stats = list(stat) if isinstance(stat, (list, tuple)) else [stat]
    vn_list = list(target_dict)
    self.load(vn_list, verbose=verbose)

    valid_fd, valid_ts = {}, {}
    for vn in vn_list:
        p_header(f'>>> Validating variable: {vn} ...')
        recon = self.recons[vn]
        target = target_dict[vn]
        if isinstance(recon, ClimateField):
            for st in stats:
                key = f'{vn}_{st}'
                valid_fd[key] = recon.compare(target, stat=st, timespan=timespan)
                valid_fd[key].plot_kwargs.update({'cbar_orientation': 'horizontal', 'cbar_pad': 0.1})
                if verbose: p_success(f'>>> ReconRes.valid_fd[{vn}_{st}] created')
        elif isinstance(recon, EnsTS):
            valid_ts[vn] = recon.compare(target, timespan=timespan)
            if verbose: p_success(f'>>> ReconRes.valid_ts[{vn}] created')
        else:
            # Make the skip visible instead of silently producing no result for this variable.
            p_warning(f'>>> Skipping {vn}: unsupported reconstruction type {type(recon).__name__}')

    self.valid_fd = valid_fd
    self.valid_ts = valid_ts


def plot_valid(self, recon_name_dict=None, target_name_dict=None,
               valid_ts_kws=None, valid_fd_kws=None):
    ''' Plot the validation results stored in `self.valid_fd` and `self.valid_ts`

    Args:
        recon_name_dict (dict, optional): maps a variable name to the display name of the
            reconstruction; a missing entry falls back to the variable name itself.
        target_name_dict (dict, optional): maps a variable name to the display name of the
            target; a missing entry falls back to 'obs'.
        valid_ts_kws (dict, optional): extra keyword arguments for the `plot_qs()`/`plot()` call
            made on each item of `self.valid_ts`.
        valid_fd_kws (dict, optional): extra keyword arguments for the `plot()` call made on
            each item of `self.valid_fd`.

    Returns:
        tuple: `(fig, ax)`, two dictionaries that share the keys of `self.valid_fd` and
        `self.valid_ts` and hold what the corresponding plotting call returned.
    '''
    valid_fd_kws = {} if valid_fd_kws is None else valid_fd_kws
    valid_ts_kws = {} if valid_ts_kws is None else valid_ts_kws
    # Work on copies so the caller's dictionaries are never modified.
    target_names = dict(target_name_dict or {})
    recon_names = dict(recon_name_dict or {})

    fig, ax = {}, {}
    for k, v in self.valid_fd.items():
        # Split on the LAST underscore only, so a variable name that itself
        # contains underscores (e.g. 'tas_gm_corr') still unpacks into two parts.
        vn, st = k.rsplit('_', 1)
        recon_name = recon_names.get(vn, vn)
        target_name = target_names.get(vn, 'obs')
        fig[k], ax[k] = v.plot(
            title=f'{st}({recon_name}, {target_name}), mean={v.geo_mean().value[0,0]:.2f}',
            **valid_fd_kws)

    for k, v in self.valid_ts.items():
        # More than one entry along the last axis of `value` -> quantile plot.
        if v.value.shape[-1] > 1:
            fig[k], ax[k] = v.plot_qs(**valid_ts_kws)
        else:
            fig[k], ax[k] = v.plot(label='recon', **valid_ts_kws)
        ax[k].set_ylabel(recon_names.get(k, k))

    return fig, ax
2 changes: 1 addition & 1 deletion cfr/ts.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,7 +504,7 @@ def plot_qs(self, figsize=[12, 4], qs=[0.025, 0.25, 0.5, 0.75, 0.975], color='in
ax.set_xlim(xlim)

if ylim is not None:
ax.set_xlim(ylim)
ax.set_ylim(ylim)


_legend_kwargs = {'ncol': len(qs)//2+1+n_ref, 'loc': 'upper left'}
Expand Down
4 changes: 2 additions & 2 deletions cfr/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,8 +515,8 @@ def coefficient_efficiency(ref, test, valid=None):
error = test - ref

# CE
numer = np.nansum(np.power(error,2),axis=0)
denom = np.nansum(np.power(ref-np.nanmean(ref,axis=0),2),axis=0)
numer = np.sum(np.power(error,2),axis=0)
denom = np.sum(np.power(ref-np.nanmean(ref,axis=0),2),axis=0)
CE = 1. - np.divide(numer,denom)

if valid:
Expand Down
24 changes: 23 additions & 1 deletion cfr/visual.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from matplotlib.legend_handler import HandlerLine2D
import pathlib
import pandas as pd
import string

from cartopy import util as cutil
from . import utils
Expand Down Expand Up @@ -1396,4 +1397,25 @@ def plot_eof(eof, pc, lat, lon, time, eof_title='EOF', pc_title='PC'):
ax['pc'].spines.top.set_visible(False)
ax['pc'].spines.right.set_visible(False)

return fig, ax
return fig, ax

def add_annotation(ax, fs=20, loc_x=0, loc_y=1.03, start=0, style=None):
    ''' Label each axes with a bold lowercase letter (a, b, c, ...)

    Args:
        ax (dict or iterable or single axes): the axes to annotate; for a dict its values are
            used, and a single non-iterable object is treated as one axes. Each axes must
            provide `text()` and `transAxes`.
        fs (number or list or tuple, optional): the text size; a single value is applied to
            every axes, a list/tuple is indexed per axes. Defaults to 20.
        loc_x (float, optional): x location of the letter, in axes coordinates. Defaults to 0.
        loc_y (float, optional): y location of the letter, in axes coordinates. Defaults to 1.03.
        start (int, optional): index into `string.ascii_lowercase` of the first letter.
            Defaults to 0.
        style (str, optional): ')' renders 'a)', '()' renders '(a)'; any other value renders the
            bare letter. Defaults to None.

    Raises:
        IndexError: if `start` plus the number of axes runs past the 26 available letters.
    '''
    if isinstance(ax, dict):
        axes = list(ax.values())
    elif hasattr(ax, '__iter__'):
        axes = list(ax)
    else:
        # A single axes object: annotate just that one.
        axes = [ax]

    sizes = fs if isinstance(fs, (list, tuple)) else [fs] * len(axes)

    for i, v in enumerate(axes):
        letter_str = string.ascii_lowercase[i+start]

        if style == ')':
            letter_str = f'{letter_str})'
        elif style == '()':
            letter_str = f'({letter_str})'

        v.text(
            loc_x, loc_y, letter_str,
            transform=v.transAxes,
            size=sizes[i], weight='bold',
        )
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file removed docs/_images/notebooks_graphem-real-pages2k_18_0.png
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file removed docs/_images/notebooks_graphem-real-pages2k_40_0.png
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Binary file removed docs/_images/notebooks_graphem-real-pages2k_9_1.png
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file removed docs/_images/notebooks_test-graphem_13_0.png
Diff not rendered.
Binary file removed docs/_images/notebooks_test-graphem_15_0.png
Diff not rendered.
4 changes: 3 additions & 1 deletion docs/_sources/notebooks/graphem-cli.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# Running GraphEM with the Command Line Interface (CLI)"
"# Running GraphEM with the Command Line Interface (CLI)\n",
"\n",
"Note that the case setup in this tutorial is just for illustrating the CLI, and the reconstruction result is not necessarily reliable."
]
},
{
Expand Down
962 changes: 345 additions & 617 deletions docs/_sources/notebooks/graphem-real-pages2k.ipynb

Large diffs are not rendered by default.

Large diffs are not rendered by default.

678 changes: 678 additions & 0 deletions docs/_sources/notebooks/test-graphem-real-pages2k-hybrid.ipynb

Large diffs are not rendered by default.

Loading

0 comments on commit dfdd56f

Please sign in to comment.