Skip to content

Commit

Permalink
Merge pull request #66 from s-scherrer/master
Browse files Browse the repository at this point in the history
Several small fixes
  • Loading branch information
wpreimes authored Nov 30, 2020
2 parents 83f02bc + 4679d8f commit 764d503
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 17 deletions.
26 changes: 22 additions & 4 deletions src/pygeogrids/grids.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

import numpy as np
import numpy.testing as nptest
import warnings
try:
from osgeo import ogr
ogr_installed = True
Expand Down Expand Up @@ -93,6 +94,11 @@ class BasicGrid(object):
The shape has to be given as (lat2d, lon2d)
It it is not given the shape is set to the length of the input
lon and lat arrays.
transform_lon : bool or None, optional (default: None)
Whether to transform longitudes to values between -180 and 180.
By default values are transformed, but a warning is issued.
To turn off the warning, set this to ``True``, to turn of
transformation set this to ``False``.
Attributes
----------
Expand Down Expand Up @@ -148,7 +154,7 @@ class BasicGrid(object):
"""

def __init__(self, lon, lat, gpis=None, geodatum='WGS84', subset=None,
setup_kdTree=True, shape=None):
setup_kdTree=True, shape=None, transform_lon=None):
"""
init method, prepares lon and lat arrays for _transform_lonlats if
necessary
Expand All @@ -167,6 +173,18 @@ def __init__(self, lon, lat, gpis=None, geodatum='WGS84', subset=None,

self.n_gpi = len(lon)

# transfrom longitudes to be between -180 and 180 if they are between 0
# and 360
if transform_lon or transform_lon is None:
if np.any(lon > 180):
lon[lon > 180] -= 360
if transform_lon is None:
warnings.warn(
"Longitude values have been transformed to be in"
" (-180, 180]. If this was not intended or to suppress"
" this warning set the transform_lon keyword argument"
)

self.arrlon = lon
self.arrlat = lat

Expand Down Expand Up @@ -368,7 +386,7 @@ def find_nearest_gpi(self, lon, lat, max_dist=np.Inf):
lat : float or iterable
Latitude of point.
max_dist : float, optional
Maximum distance to consider for search (default: np.Inf).
Maximum distance [m] to consider for search (default: np.Inf).
Returns
-------
Expand All @@ -379,9 +397,9 @@ def find_nearest_gpi(self, lon, lat, max_dist=np.Inf):
At the moment not on a great circle but in spherical
cartesian coordinates.
"""
gpi, distance = self.find_k_nearest_gpi(lon, lat, max_dist=np.Inf, k=1)
gpi, distance = self.find_k_nearest_gpi(lon, lat, max_dist=max_dist, k=1)

if not _element_iterable(lon):
if not _element_iterable(lon) and len(gpi) > 0:
gpi = gpi[0]
distance = distance[0]

Expand Down
24 changes: 19 additions & 5 deletions src/pygeogrids/nearest_neighbor.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,12 +178,22 @@ def find_nearest_index(self, lon, lat, max_dist=np.Inf, k=1):
great circle distance at the moment. This should be OK for
most applications that look for the nearest neighbor which
should not be hundreds of kilometers away.
If no point was found within the maximum distance to consider, an
empty array is returned.
ind : int, numpy.array
indices of nearest neighbor
If ``self.grid`` is ``False`` indices of nearest neighbor.
If no point was found within the maximum distance to consider, an
empty array is returned.
index_lon : numpy.array, optional
if self.grid is True then return index into lon array of grid definition
If ``self.grid`` is ``True`` then return index into lon array of
grid definition.
If no point was found within the maximum distance to consider, an
empty array is returned.
index_lat : numpy.array, optional
if self.grid is True then return index into lat array of grid definition
If ``self.grid`` is ``True`` then return index into lat array of
grid definition.
If no point was found within the maximum distance to consider, an
empty array is returned.
"""
if self.kdtree is None:
self._build_kdtree()
Expand All @@ -193,11 +203,15 @@ def find_nearest_index(self, lon, lat, max_dist=np.Inf, k=1):
d, ind = self.kdtree.query(
query_coords, distance_upper_bound=max_dist, k=k)

# if no point was found, d == inf
if not np.all(np.isfinite(d)):
d, ind = np.array([]), np.array([])

if not self.grid:
return d, ind
else:
# calculate index position in grid definition arrays assuming row-major
# flattening of arrays after numpy.meshgrid
# calculate index position in grid definition arrays assuming
# row-major flattening of arrays after numpy.meshgrid
index_lat = ind / self.lon_size
index_lon = ind % self.lon_size
return d, index_lon.astype(np.int32), index_lat.astype(np.int32)
4 changes: 2 additions & 2 deletions src/pygeogrids/netcdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@ def save_lonlat(filename, arrlon, arrlat, geodatum, arrcell=None,
type(global_attrs['shape']) is not int and
len(global_attrs['shape']) == 2):

latsize = global_attrs['shape'][1]
lonsize = global_attrs['shape'][0]
latsize = global_attrs['shape'][0]
lonsize = global_attrs['shape'][1]
ncfile.createDimension("lat", latsize)
ncfile.createDimension("lon", lonsize)
gpisize = global_attrs['shape'][0] * global_attrs['shape'][1]
Expand Down
42 changes: 41 additions & 1 deletion tests/test_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,10 @@
import numpy.testing as nptest
import numpy as np
from osgeo import ogr
import pytest
import warnings

from pygeogrids.grids import lonlat2cell
from pygeogrids.grids import lonlat2cell, BasicGrid
import pygeogrids as grids


Expand Down Expand Up @@ -135,6 +137,19 @@ def test_k_nearest_neighbor_list(self):
assert gpi[1, 0] == 38430
assert gpi[1, 1] == 38429

def test_nearest_neighbor_max_dist(self):
# test with maxdist higher than nearest point
gpi, dist = self.grid.find_nearest_gpi(14.3, 18.5, max_dist=100e3)
assert gpi == 25754
assert len([dist]) == 1
lon, lat = self.grid.gpi2lonlat(gpi)
assert lon == 14.5
assert lat == 18.5

# test with maxdist lower than nearest point
gpi, dist = self.grid.find_nearest_gpi(14.3, 18.5, max_dist=10e3)
assert len(gpi) == 0
assert len(dist) == 0

class TestCellGridNotGpiDirect(unittest.TestCase):

Expand Down Expand Up @@ -659,5 +674,30 @@ def test_shpgrid(self):
assert subgrid.activearrlat == 46


@pytest.mark.filterwarnings("error")
def test_BasicGrid_transform_lon():
"""
Tests whether transforming longitudes works as expected.
"""

lat = np.asarray([10, -10, 5, 42])
lon_pos = np.asarray([0, 90, 180, 270])
lon_centered = np.asarray([0, 90, 180, -90])

# case 1: warning and transformation
with pytest.warns(UserWarning):
grid = BasicGrid(lon_pos, lat)
assert np.all(grid.arrlon == lon_centered)

# case 2: no warning and transform
grid = BasicGrid(lon_pos, lat, transform_lon=True)
assert np.all(grid.arrlon == lon_centered)

# case 3: no warning and no transform
grid = BasicGrid(lon_pos, lat, transform_lon=False)
assert np.all(grid.arrlon == lon_pos)



if __name__ == "__main__":
unittest.main()
10 changes: 5 additions & 5 deletions tests/test_netcdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,12 @@ def setUp(self):
self.subset = np.sort(np.random.choice(np.arange(self.lats.size),
size=500, replace=False))
self.basic = grids.BasicGrid(self.lons, self.lats, subset=self.subset,
shape=(360, 180))
shape=(180, 360))

self.basic_shape_gpis = grids.BasicGrid(self.lons, self.lats,
gpis=np.arange(self.lats.size),
subset=self.subset,
shape=(360, 180))
shape=(180, 360))
self.basic_generated = grids.genreg_grid(1, 1)
self.basic_irregular = grids.BasicGrid(np.random.random(360 * 180) * 360 - 180,
np.random.random(
Expand All @@ -63,7 +63,7 @@ def setUp(self):

self.cellgrid_shape = grids.CellGrid(self.lons, self.lats, self.cells,
subset=self.subset,
shape=(360, 180))
shape=(180, 360))

self.testfile = tempfile.NamedTemporaryFile().name

Expand Down Expand Up @@ -106,8 +106,8 @@ def test_save_basicgrid_generated(self):
nptest.assert_array_equal(sorted(self.basic.gpis[self.subset]),
sorted(nc_data.variables['gpi'][:].flatten()[stored_subset]))
assert nc_data.test == 'test_attribute'
assert nc_data.shape[1] == 180
assert nc_data.shape[0] == 360
assert nc_data.shape[0] == 180
assert nc_data.shape[1] == 360

def test_save_basicgrid_irregular_nc(self):
grid_nc.save_grid(self.testfile,
Expand Down

0 comments on commit 764d503

Please sign in to comment.