Skip to content

Commit

Permalink
resolve #3, update pytests
Browse files Browse the repository at this point in the history
  • Loading branch information
ruiying-ocean committed Sep 18, 2024
1 parent 10177b0 commit f625bff
Show file tree
Hide file tree
Showing 8 changed files with 44 additions and 156 deletions.
Binary file modified test/baseline_images/test_plot/test_line.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified test/baseline_images/test_plot/test_map.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified test/baseline_images/test_plot/test_scatterdatavis.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion test/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,4 @@ def test_search_point():

def test_sel_modern_basin():
data = create_testdata()
assert data.sel_modern_basin(50,norm_lon_method='').mean().data.item() == 0.5019781051132972
assert data.sel_modern_basin(50,norm_lon_method='').mean().data.item() == 0.5019781051132972
18 changes: 16 additions & 2 deletions test/test_colormap.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,22 @@
from cgeniepy.plot import CommunityPalette
import matplotlib.colors

def test_get_palette():

def test_txt_palette():
cmap = CommunityPalette().get_palette('ODV')
assert cmap.colors[0] == '#FEB483'
assert matplotlib.colors.to_hex(cmap.colors[0]) == '#feb483'


def test_spk_palette():
cmap = CommunityPalette().get_palette('medium_rainbow')
assert matplotlib.colors.to_hex(cmap.colors[0]) == '#cc00ff'


def test_xml_palette():
cmap = CommunityPalette().get_palette('Section')
assert matplotlib.colors.to_hex(cmap.colors[0]) == '#c2ccb8'


def test_alt_init():
hex_codes = CommunityPalette('medium_rainbow').to_hex()
assert hex_codes[0] == '#cc00ff'
147 changes: 14 additions & 133 deletions test/test_model.py
Original file line number Diff line number Diff line change
@@ -1,136 +1,17 @@
from cgeniepy.model import GenieModel
from importlib.resources import files
import cgeniepy

def test_model_getvar():
model_path = str(files("data").joinpath("muffin.CBE.worlg4.BASESFeTDTL.SPIN"))
model = GenieModel(model_path, gemflag=['ecogem'])
data= model.get_var('eco2D_Plankton_C_Total').isel(time=-1).mean().data.item()
assert data == 1.690475344657898
model = cgeniepy.sample_model()
assert model.get_var("ocn_sur_temp").mean().data.values.item() == 18.055744171142578

def test_eco_pft():
model = cgeniepy.sample_model(model_type='EcoModel',gemflag=['ecogem'])
## 1st PFT
data = model.get_pft(1).isel(time=-1).mean().data.values.item()
assert data == 0.12259095907211304

def tets_model_ncvardict():
model_path = str(files("data").joinpath("muffin.CBE.worlg4.BASESFeTDTL.SPIN"))
model = GenieModel(model_path, gemflag=['ecogem'])

data = model._ncvar_dict
baseline = {'./muffin.CBE.worlg4.BASESFeTDTL.SPIN/ecogem/fields_ecogem_2d.nc': ['time',
'year',
'lon',
'lat',
'zt',
'lon_edges',
'lat_edges',
'zt_edges',
'grid_level',
'grid_mask',
'grid_topo',
'eco2D_Plankton_C_001',
'eco2D_Plankton_C_002',
'eco2D_Plankton_C_003',
'eco2D_Plankton_C_004',
'eco2D_Plankton_C_005',
'eco2D_Plankton_C_006',
'eco2D_Plankton_C_007',
'eco2D_Plankton_C_008',
'eco2D_Plankton_C_009',
'eco2D_Plankton_C_010',
'eco2D_Plankton_C_011',
'eco2D_Plankton_C_012',
'eco2D_Plankton_C_013',
'eco2D_Plankton_C_014',
'eco2D_Plankton_C_015',
'eco2D_Plankton_C_016',
'eco2D_Plankton_C_017',
'eco2D_Plankton_C_018',
'eco2D_Plankton_C_019',
'eco2D_Plankton_C_Total',
'eco2D_Uptake_Fluxes_C',
'eco2D_Plankton_P_001',
'eco2D_xGamma_P_001',
'eco2D_Plankton_P_002',
'eco2D_xGamma_P_002',
'eco2D_Plankton_P_003',
'eco2D_xGamma_P_003',
'eco2D_Plankton_P_004',
'eco2D_xGamma_P_004',
'eco2D_Plankton_P_005',
'eco2D_xGamma_P_005',
'eco2D_Plankton_P_006',
'eco2D_xGamma_P_006',
'eco2D_Plankton_P_007',
'eco2D_xGamma_P_007',
'eco2D_Plankton_P_008',
'eco2D_xGamma_P_008',
'eco2D_Plankton_P_009',
'eco2D_Plankton_P_010',
'eco2D_Plankton_P_011',
'eco2D_Plankton_P_012',
'eco2D_Plankton_P_013',
'eco2D_Plankton_P_014',
'eco2D_Plankton_P_015',
'eco2D_Plankton_P_016',
'eco2D_Plankton_P_017',
'eco2D_Plankton_P_018',
'eco2D_xGamma_P_018',
'eco2D_Plankton_P_019',
'eco2D_xGamma_P_019',
'eco2D_Plankton_P_Total',
'eco2D_Uptake_Fluxes_P',
'eco2D_Plankton_Fe_001',
'eco2D_xGamma_Fe_001',
'eco2D_Plankton_Fe_002',
'eco2D_xGamma_Fe_002',
'eco2D_Plankton_Fe_003',
'eco2D_xGamma_Fe_003',
'eco2D_Plankton_Fe_004',
'eco2D_xGamma_Fe_004',
'eco2D_Plankton_Fe_005',
'eco2D_xGamma_Fe_005',
'eco2D_Plankton_Fe_006',
'eco2D_xGamma_Fe_006',
'eco2D_Plankton_Fe_007',
'eco2D_xGamma_Fe_007',
'eco2D_Plankton_Fe_008',
'eco2D_xGamma_Fe_008',
'eco2D_Plankton_Fe_009',
'eco2D_Plankton_Fe_010',
'eco2D_Plankton_Fe_011',
'eco2D_Plankton_Fe_012',
'eco2D_Plankton_Fe_013',
'eco2D_Plankton_Fe_014',
'eco2D_Plankton_Fe_015',
'eco2D_Plankton_Fe_016',
'eco2D_Plankton_Fe_017',
'eco2D_Plankton_Fe_018',
'eco2D_xGamma_Fe_018',
'eco2D_Plankton_Fe_019',
'eco2D_xGamma_Fe_019',
'eco2D_Plankton_Fe_Total',
'eco2D_Uptake_Fluxes_Fe',
'eco2D_Plankton_Chl_001',
'eco2D_Plankton_Chl_002',
'eco2D_Plankton_Chl_003',
'eco2D_Plankton_Chl_004',
'eco2D_Plankton_Chl_005',
'eco2D_Plankton_Chl_006',
'eco2D_Plankton_Chl_007',
'eco2D_Plankton_Chl_008',
'eco2D_Plankton_Chl_018',
'eco2D_Plankton_Chl_019',
'eco2D_Plankton_Chl_Total',
'eco2D_xGamma_T',
'eco2D_Nutrients_DIC',
'eco2D_Nutrients_PO4',
'eco2D_Nutrients_Fe',
'eco2D_Size_Mean',
'eco2D_Size_Stdev',
'eco2D_Size_Minimum',
'eco2D_Size_Maximum',
'eco2D_Diversity_Threshold',
'eco2D_Diversity_Shannon',
'eco2D_Diversity_Simpson',
'eco2D_Diversity_Berger',
'eco2D_Size_Frac_Pico_Chl',
'eco2D_Size_Frac_Nano_Chl',
'eco2D_Size_Frac_Micro_Chl']}
assert data == baseline

def test_eco_multipft():
model = cgeniepy.sample_model(model_type='EcoModel',gemflag=['ecogem'])
## 1st and 2nd PFT
data = model.get_pft([1,2]).isel(time=-1).sum(dim='variable').mean().data.values.item()
assert data == 0.16739805042743683
27 changes: 10 additions & 17 deletions test/test_plot.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,27 @@
from cgeniepy.array import GriddedData
import numpy as np
import xarray as xr
import cartopy.crs as ccrs
from matplotlib.testing.decorators import image_comparison
import matplotlib.pyplot as plt
from cgeniepy.table import ScatterData
from importlib.resources import files
import cgeniepy
import matplotlib.pyplot as plt
import cartopy.crs as ccrs


def create_testdata():
lat = np.linspace(-89.5,89.5,180)
lon = np.linspace(0,359,360)
np.random.seed(12349)
data = np.random.rand(lat.size,lon.size)
xdata = xr.DataArray(data, coords=[('lat',lat),('lon',lon)],
attrs={'long_name':'random data', 'units':'uniteless'})
return GriddedData(xdata,False, attrs=xdata.attrs)
def create_sample_data():
model = cgeniepy.sample_model()
return model.get_var('ocn_sur_temp').isel(time=-1)

@image_comparison(baseline_images=['test_map'], remove_text=True,
extensions=['png'], style='mpl20')
def test_map():
data = create_testdata()
def test_map():
fig, ax = plt.subplots(subplot_kw={'projection': ccrs.Mollweide()})
data = create_sample_data()
data.plot(ax=ax)

return fig

@image_comparison(baseline_images=['test_line'], remove_text=True,
extensions=['png'], style='mpl20')
def test_line():
data = create_testdata()
data = create_sample_data()
fig, ax = plt.subplots()
data.mean(dim='lon').plot(ax=ax)
return fig
Expand Down
6 changes: 3 additions & 3 deletions test/test_skill.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,19 @@ def create_testdata():
## calculate skill score
return ArrComparison(x, y)


def test_mscore():
ac = create_testdata()
assert ac.mscore()==1.0

def test_pearson_r():
ac = create_testdata()
assert ac.pearson_r()==0.9999999999999999
diff = ac.pearson_r().item() - 1.0
assert diff < 1E-8

def test_cos_sim():
ac = create_testdata()
assert ac.cos_similarity()==1.0

def test_rmse():
ac = create_testdata()
assert ac.rmse()==0.0
assert ac.rmse()==0.0

0 comments on commit f625bff

Please sign in to comment.