Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
SarahAlidoost committed Mar 18, 2024
1 parent 90ef61d commit 99844c8
Showing 1 changed file with 157 additions and 31 deletions.
188 changes: 157 additions & 31 deletions tests/test_stm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import numpy as np
import pytest
import xarray as xr
import pandas as pd
from shapely import geometry

from stmtools.stm import _validate_coords
Expand Down Expand Up @@ -254,22 +255,69 @@ def stmat_lonlat_morton():
).unify_chunks()

@pytest.fixture
def meteo():
lon_values = np.array([0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5])
lat_values = np.array([0.25, 1.25, 2.25, 3.25, 4.25, 5.25, 6.25, 7.25, 8.25, 9.25])
def meteo_points():
lon_values = np.array([0.5, 1.5, 2.5, 3.5, 4.5])
lat_values = np.array([0.25, 1.25, 2.25, 3.25, 4.25])
time_values = pd.date_range(start='2021-01-01', periods=6)

return xr.Dataset(
data_vars=dict(
temperature=(["space", "time"], da.arange(10 * 5).reshape((10, 5))),
humidity=(["space", "time"], da.arange(10 * 5).reshape((10, 5))),
temperature=(["space", "time"], da.arange(5 * 6).reshape((5, 6))),
humidity=(["space", "time"], da.arange(5 * 6).reshape((5, 6))),
),
coords=dict(
lon=(["space"], lon_values),
lat=(["space"], lat_values),
time=(["time"], np.arange(5)),
time=(["time"], time_values),
),
).unify_chunks()

@pytest.fixture
def meteo_raster():
# create a raster with 5x5 grid
lon_values = np.array([0, 1, 2, 3, 4])
lat_values = np.array([0, 1, 2, 3, 4])
time_values = pd.date_range(start='2021-01-01', periods=6)

return xr.Dataset(
data_vars=dict(
temperature=(["lon", "lat", "time"], da.arange(5 * 5 * 6).reshape((5, 5, 6))),
humidity=(["lon", "lat", "time"], da.arange(5 * 5 * 6).reshape((5, 5, 6))),
),
coords=dict(
lon=(["lon"], lon_values),
lat=(["lat"], lat_values),
time=(["time"], time_values),
),
).unify_chunks()

@pytest.fixture
def stmat():
npoints = 10
ntime = 5
return xr.Dataset(
data_vars=dict(
amplitude=(
["space", "time"],
da.arange(npoints * ntime).reshape((npoints, ntime)),
),
phase=(
["space", "time"],
da.arange(npoints * ntime).reshape((npoints, ntime)),
),
pnt_height=(
["space"],
da.arange(npoints),
),
),
coords=dict(
lon=(["space"], da.arange(npoints)),
lat=(["space"], da.arange(npoints)),
time=(["time"], pd.date_range(start='2021-01-02', periods=ntime)),
),
).unify_chunks()


class TestRegulateDims:
def test_time_dim_exists(self, stmat_only_point):
stm_reg = stmat_only_point.stm.regulate_dims()
Expand Down Expand Up @@ -478,61 +526,139 @@ def test_reorder_lonlat(self, stmat_lonlat, stmat_lonlat_morton):
assert stmat.range.equals(stmat_lonlat_morton.range)


class TestEnrichmentFromDataset:
def test_enrich_from_dataset_one_filed(self, stmat, meteo):
stmat_enriched = stmat.stm.enrich_from_dataset(meteo, "temperature")
class TestEnrichmentFromPointDataset:
def test_enrich_from_dataset_one_filed(self, stmat, meteo_points):
stmat_enriched = stmat.stm.enrich_from_dataset(meteo_points, "temperature")
assert "temperature" in stmat_enriched.data_vars

# check if the linear interpolation is correct
assert stmat_enriched.temperature[0, 0] == meteo.temperature[0, 0]
# check if the nearest method is correct
assert stmat_enriched.temperature[0, 0] == meteo_points.temperature[0, 1]

# check if coordinates are correct
assert stmat_enriched.lon.equals(stmat.lon)

def test_enrich_from_dataset_multi_filed(self, stmat, meteo):
stmat_enriched = stmat.stm.enrich_from_dataset(meteo, ["temperature", "humidity"])
def test_enrich_from_dataset_multi_filed(self, stmat, meteo_points):
stmat_enriched = stmat.stm.enrich_from_dataset(meteo_points, ["temperature", "humidity"])
assert "temperature" in stmat_enriched.data_vars
assert "humidity" in stmat_enriched.data_vars

# check if the linear interpolation is correct
assert stmat_enriched.temperature[0, 0] == meteo.temperature[0, 0]
assert stmat_enriched.humidity[0, 0] == meteo.humidity[0, 0]
assert stmat_enriched.temperature[0, 0] == meteo_points.temperature[0, 1]
assert stmat_enriched.humidity[0, 0] == meteo_points.humidity[0, 1]

def test_enrich_from_dataset_exceptions(self, stmat, meteo):
def test_enrich_from_dataset_exceptions(self, stmat, meteo_points):
# valid fileds
with pytest.raises(ValueError) as excinfo:
field = "non_exist_field"
stmat.stm.enrich_from_dataset(meteo, field)
stmat.stm.enrich_from_dataset(meteo_points, field)
assert f'Field "{field}" not found' in str(excinfo.value)

# valid dtype of "time"
meteo["time"] = meteo["time"].astype("float64")
another_meteo_points = meteo_points.copy(deep=True)
another_meteo_points["time"] = another_meteo_points["time"].astype("float64")
with pytest.raises(ValueError) as excinfo:
stmat.stm.enrich_from_dataset(meteo, "temperature")
stmat.stm.enrich_from_dataset(another_meteo_points, "temperature")
assert "different time dtype" in str(excinfo.value)

# "time" dimension should exist in the meteo
meteo = meteo.drop_vars("time")
# "time" dimension should exist in the meteo_points
another_meteo_points = meteo_points.copy(deep=True)
another_meteo_points = another_meteo_points.drop_vars("time")
with pytest.raises(ValueError) as excinfo:
stmat.stm.enrich_from_dataset(meteo, "temperature")
stmat.stm.enrich_from_dataset(another_meteo_points, "temperature")
assert 'Missing dimension: "time"' in str(excinfo.value)

# shapes of "space" and "time" should be the same
meteo = meteo.sel(space=range(5))
# keys of coordinates should be the same
another_meteo_points = meteo_points.copy(deep=True)
another_meteo_points = another_meteo_points.rename({"lon": "long"})
with pytest.raises(ValueError) as excinfo:
stmat.stm.enrich_from_dataset(another_meteo_points, "temperature")
assert 'Coordinate label "long" was not found' in str(excinfo.value)

# dimensions either space or lon/lat should exist
another_meteo_points = meteo_points.copy(deep=True)
another_meteo_points = another_meteo_points.drop_dims("space")
with pytest.raises(ValueError) as excinfo:
stmat.stm.enrich_from_dataset(another_meteo_points, "temperature")
assert 'Missing dimension: "space" or "lon" and "lat"' in str(excinfo.value)

# field already exists
another_stmat = stmat.stm.enrich_from_dataset(meteo_points, "temperature")
with pytest.raises(ValueError) as excinfo:
another_stmat.stm.enrich_from_dataset(meteo_points, "temperature")
assert 'Field "temperature" already exists' in str(excinfo.value)

def test_enrich_from_dataarray_one_filed(self, stmat, meteo_points):
stmat_enriched = stmat.stm.enrich_from_dataset(meteo_points.temperature, "temperature")
assert "temperature" in stmat_enriched.data_vars

# check if the linear interpolation is correct
assert stmat_enriched.temperature[0, 0] == meteo_points.temperature[0, 1]


class TestEnrichmentFromRasterDataset:
def test_enrich_from_dataset_one_filed(self, stmat, meteo_raster):
stmat_enriched = stmat.stm.enrich_from_dataset(meteo_raster, "temperature")
assert "temperature" in stmat_enriched.data_vars

# check if the nearest method is correct
assert stmat_enriched.temperature[0, 0] == meteo_raster.temperature[0, 0, 1]

# check if coordinates are correct
assert stmat_enriched.lon.equals(stmat.lon)

def test_enrich_from_dataset_multi_filed(self, stmat, meteo_raster):
stmat_enriched = stmat.stm.enrich_from_dataset(meteo_raster, ["temperature", "humidity"])
assert "temperature" in stmat_enriched.data_vars
assert "humidity" in stmat_enriched.data_vars

# check if the linear interpolation is correct
assert stmat_enriched.temperature[0, 0] == meteo_raster.temperature[0, 0, 1]
assert stmat_enriched.humidity[0, 0] == meteo_raster.humidity[0, 0, 1]

def test_enrich_from_dataset_exceptions(self, stmat, meteo_raster):
# valid fileds
with pytest.raises(ValueError) as excinfo:
stmat.stm.enrich_from_dataset(meteo, "temperature")
assert "different space shapes" in str(excinfo.value)
field = "non_exist_field"
stmat.stm.enrich_from_dataset(meteo_raster, field)
assert f'Field "{field}" not found' in str(excinfo.value)

# valid dtype of "time"
another_meteo_raster = meteo_raster.copy(deep=True)
another_meteo_raster["time"] = another_meteo_raster["time"].astype("float64")
with pytest.raises(ValueError) as excinfo:
stmat.stm.enrich_from_dataset(another_meteo_raster, "temperature")
assert "different time dtype" in str(excinfo.value)

# "time" dimension should exist in the meteo_raster
another_meteo_raster = meteo_raster.copy(deep=True)
another_meteo_raster = another_meteo_raster.drop_vars("time")
with pytest.raises(ValueError) as excinfo:
stmat.stm.enrich_from_dataset(another_meteo_raster, "temperature")
assert 'Missing dimension: "time"' in str(excinfo.value)

# keys of coordinates should be the same
meteo = meteo.rename({"lon": "long"})
another_meteo_raster = meteo_raster.copy(deep=True)
another_meteo_raster = another_meteo_raster.rename({"lon": "long"})
with pytest.raises(ValueError) as excinfo:
stmat.stm.enrich_from_dataset(meteo, "temperature")
stmat.stm.enrich_from_dataset(another_meteo_raster, "temperature")
assert 'Coordinate label "long" was not found' in str(excinfo.value)

# dimensions either space or lon/lat should exist
another_meteo_raster = meteo_raster.copy(deep=True)
another_meteo_raster = another_meteo_raster.drop_dims("lat")
with pytest.raises(ValueError) as excinfo:
stmat.stm.enrich_from_dataset(another_meteo_raster, "temperature")
assert 'Missing dimension: "space" or "lon" and "lat"' in str(excinfo.value)

# field already exists
another_stmat = stmat.stm.enrich_from_dataset(meteo_raster, "temperature")
with pytest.raises(ValueError) as excinfo:
another_stmat.stm.enrich_from_dataset(meteo_raster, "temperature")
assert 'Field "temperature" already exists' in str(excinfo.value)

def test_enrich_from_dataarray_one_filed(self, stmat, meteo):
stmat_enriched = stmat.stm.enrich_from_dataset(meteo.temperature, "temperature")
def test_enrich_from_dataarray_one_filed(self, stmat, meteo_raster):
stmat_enriched = stmat.stm.enrich_from_dataset(meteo_raster.temperature, "temperature")
assert "temperature" in stmat_enriched.data_vars

# check if the linear interpolation is correct
assert stmat_enriched.temperature[0, 0] == meteo.temperature[0, 0]
assert stmat_enriched.temperature[0, 0] == meteo_raster.temperature[0, 0, 1]

0 comments on commit 99844c8

Please sign in to comment.