diff --git a/snowexsql/api.py b/snowexsql/api.py index f4548fb..2457a2a 100644 --- a/snowexsql/api.py +++ b/snowexsql/api.py @@ -484,25 +484,33 @@ class LayerMeasurements(BaseDataset): @classmethod def _filter_campaign(cls, qry, v): - - qry = qry.join(cls.MODEL.site).join( - Site.campaign).filter(Campaign.name == v) - return qry - + return qry.join( + cls.MODEL.site + ).join( + Site.campaign + ).filter( + Campaign.name == v + ) + @classmethod def _filter_observers(cls, qry, v): - qry = qry.join(cls.MODEL.site).join( - Site.observers).filter(Observer.name == v) - return qry - + return qry.join( + cls.MODEL.site + ).join( + Site.observers + ).filter( + Observer.name == v + ) + @property def all_sites(self): """ Return all specific site names """ with db_session(self.DB_NAME) as (session, engine): - qry = session.query(Site.name).distinct() - result = qry.all() + result = session.query( + Site.name + ).distinct().all() return self.retrieve_single_value_result(result) @property @@ -511,8 +519,9 @@ def all_dates(self): Return all distinct dates in the data """ with db_session(self.DB_NAME) as (session, engine): - qry = session.query(Site.date).distinct() - result = qry.all() + result = session.query( + Site.date + ).distinct().all() return self.retrieve_single_value_result(result) diff --git a/tests/api/test_layer_measurements.py b/tests/api/test_layer_measurements.py index 0c600f3..68c2a95 100644 --- a/tests/api/test_layer_measurements.py +++ b/tests/api/test_layer_measurements.py @@ -1,111 +1,105 @@ -from datetime import date - -import geopandas as gpd -import numpy as np +""" +Test the Layer Measurement class +""" import pytest +import datetime + +from geoalchemy2.shape import to_shape +from geoalchemy2.elements import WKTElement from snowexsql.api import LayerMeasurements -from tests import DBConnection +from snowexsql.tables import LayerData -class TestLayerMeasurements(DBConnection): - """ - Test the Layer Measurement class - """ - CLZ = LayerMeasurements +@pytest.fixture +def layer_data(layer_data_factory, db_session): + layer_data_factory.create() + return db_session.query(LayerData).all() - def test_all_types(self, clz): - result = clz().all_types - assert result == ["density"] +@pytest.mark.usefixtures("db_test_session") +@pytest.mark.usefixtures("db_test_connection") +@pytest.mark.usefixtures("layer_data") +class TestLayerMeasurements: + @pytest.fixture(autouse=True) + def setup_method(self, layer_data): + self.subject = LayerMeasurements() + self.db_data = layer_data - def test_all_campaigns(self, clz): - result = clz().all_campaigns - assert result == ['Grand Mesa'] + def test_all_campaigns(self): + result = self.subject.all_campaigns + assert result == [ + record.site.campaign.name + for record in self.db_data + ] + + def test_all_observers(self): + result = self.subject.all_observers + assert result == [ + observer.name + for record in self.db_data + for observer in record.site.observers + ] - def test_all_sites(self, clz): - result = clz().all_sites - assert result == ['Fakepit1'] + def test_all_sites(self): + result = self.subject.all_sites + assert result == [ + record.site.name + for record in self.db_data + ] - def test_all_dates(self, clz): - result = clz().all_dates - assert result == [date(2020, 1, 28)] + def test_all_dates(self): + result = self.subject.all_dates + assert result == [ + record.site.date + for record in self.db_data + ] - def test_all_observers(self, clz): - result = clz().all_observers - assert result == ['TEST'] +@pytest.fixture +def layer_data(layer_density_factory, db_session): + layer_density_factory.create() + return db_session.query(LayerData).all() - def test_all_instruments(self, clz): - result = clz().all_instruments - assert result == ['fakeinstrument'] +@pytest.mark.usefixtures("db_test_session") +@pytest.mark.usefixtures("db_test_connection") +@pytest.mark.usefixtures("layer_data") +class TestDensityLayerMeasurement: + @pytest.fixture(autouse=True) + def setup_method(self, layer_data): + self.subject = LayerMeasurements() + # Pick the first record for this test case + self.db_data = layer_data[0] - @pytest.mark.parametrize( - "kwargs, expected_length, mean_value", [ - ({ - "date": date(2020, 3, 12), "type": "density", - "site": "COERIB_20200312_0938" - }, 0, np.nan), # filter to 1 pit - ({"instrument": "IRIS", "limit": 10}, 0, np.nan), # limit works - ({ - "date": date(2020, 5, 28), - "instrument": 'IRIS' - }, 0, np.nan), # nothing returned - ({ - "date_less_equal": date(2019, 12, 15), - "type": 'density' - }, 0, np.nan), - ({ - "date_greater_equal": date(2020, 5, 13), - "type": 'density' - }, 0, np.nan), - ({ - "type": 'density', - "campaign": 'Grand Mesa' - }, 1, 42.5), - ({ - "observer": 'TEST', - "campaign": 'Grand Mesa' - }, 1, 42.5), - ] - ) - def test_from_filter(self, clz, kwargs, expected_length, mean_value): - result = clz.from_filter(**kwargs) - assert len(result) == expected_length - if expected_length > 0: - assert pytest.approx( - result["value"].astype("float").mean() - ) == mean_value + def test_date_and_instrument(self): + result = self.subject.from_filter( + date=self.db_data.site.datetime.date(), + instrument=self.db_data.instrument.name, + ) + assert len(result) == 1 + assert result.loc[0].value == self.db_data.value - @pytest.mark.parametrize( - "kwargs, expected_error", [ - ({"notakey": "value"}, ValueError), - # ({"date": date(2020, 3, 12)}, LargeQueryCheckException), - ({"date": [date(2020, 5, 28), date(2019, 10, 3)]}, ValueError), - ] - ) - def test_from_filter_fails(self, clz, kwargs, expected_error): - """ - Test failure on not-allowed key and too many returns - """ - with pytest.raises(expected_error): - clz.from_filter(**kwargs) + def test_instrument_and_limit(self, layer_density_factory): + # Create 10 more records, but only fetch five + layer_density_factory.create_batch(10) + + result = self.subject.from_filter( + instrument=self.db_data.instrument.name, + limit=5 + ) + assert len(result) == 5 + assert pytest.approx(result["value"].astype("float").mean()) == \ + float(self.db_data.value) - def test_from_area(self, clz): - df = gpd.GeoDataFrame( - geometry=gpd.points_from_xy( - [743766.4794971556], [4321444.154620216], crs="epsg:26912" - ).buffer(1000.0) - ).set_crs("epsg:26912") - result = clz.from_area( - type="density", - shp=df.iloc[0].geometry, + def test_date_and_measurement_type(self): + result = self.subject.from_filter( + date=self.db_data.site.datetime.date(), + type=self.db_data.measurement_type.name, ) - assert len(result) == 0 + assert len(result) == 1 + assert result.loc[0].value == self.db_data.value - def test_from_area_point(self, clz): - pts = gpd.points_from_xy([743766.4794971556], [4321444.154620216]) - crs = "26912" - result = clz.from_area( - pt=pts[0], buffer=1000, crs=crs, - type="density", + def test_doi(self): + result = self.subject.from_filter( + doi=self.db_data.doi.doi, ) - assert len(result) == 0 + assert len(result) == 1 + assert result.loc[0].value == self.db_data.value diff --git a/tests/conftest.py b/tests/conftest.py index a55d362..da824b4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,7 +10,8 @@ DB_CONNECTION_OPTIONS, db_connection_string, initialize ) from tests.factories import (CampaignFactory, DOIFactory, InstrumentFactory, - LayerDataFactory, MeasurementTypeFactory, + LayerDataFactory, LayerDensityFactory, + MeasurementTypeFactory, ObserverFactory, PointDataFactory, PointObservationFactory, SiteFactory) from .db_setup import CREDENTIAL_FILE, DB_INFO, SESSION @@ -20,13 +21,13 @@ register(DOIFactory) register(InstrumentFactory) register(LayerDataFactory) +register(LayerDensityFactory) register(MeasurementTypeFactory) register(ObserverFactory) register(PointDataFactory) register(PointObservationFactory) register(SiteFactory) - # Add this factory to a test if you would like to debug the SQL statement # It will print the query from the BaseDataset.from_filter() method @pytest.fixture(scope='session') diff --git a/tests/factories/__init__.py b/tests/factories/__init__.py index e4a9dad..b203baf 100644 --- a/tests/factories/__init__.py +++ b/tests/factories/__init__.py @@ -2,6 +2,7 @@ from .doi import DOIFactory from .instrument import InstrumentFactory from .layer_data import LayerDataFactory +from .layer_data import LayerDensityFactory from .measurement_type import MeasurementTypeFactory from .observer import ObserverFactory from .point_data import PointDataFactory @@ -13,6 +14,7 @@ "DOIFactory", "InstrumentFactory", "LayerDataFactory", + "LayerDensityFactory", "MeasurementTypeFactory", "ObserverFactory", "PointDataFactory", diff --git a/tests/factories/layer_data.py b/tests/factories/layer_data.py index eeae66d..c94124e 100644 --- a/tests/factories/layer_data.py +++ b/tests/factories/layer_data.py @@ -6,7 +6,9 @@ from .instrument import InstrumentFactory from .measurement_type import MeasurementTypeFactory from .site import SiteFactory - +from geoalchemy2.elements import WKTElement +from datetime import datetime +from .campaign import CampaignFactory class LayerDataFactory(BaseFactory): class Meta: @@ -23,3 +25,19 @@ class Meta: instrument = factory.SubFactory(InstrumentFactory, name='Density Cutter') doi = factory.SubFactory(DOIFactory) site = factory.SubFactory(SiteFactory) + + +class LayerDensityFactory(LayerDataFactory): + + depth = 15.0 + bottom_depth = 5.0 + value = '236.0' + comments = 'Sample_A' + + site = factory.SubFactory( + SiteFactory, + name = 'IN20', + datetime=datetime(2020, 2, 15, 13, 30), + geom = WKTElement("POINT(743281 4324005)", srid=32612), + campaign = factory.SubFactory(CampaignFactory, name='Grand Mesa') + )