From 8068f94172b59cfed553cac45702830d60fef09e Mon Sep 17 00:00:00 2001 From: "Soroosh.Mani" Date: Thu, 16 Nov 2023 17:49:53 -0500 Subject: [PATCH] WIP: Fix tempfile removal for windows --- ocsmesh/hfun/raster.py | 15 ++++++++++--- ocsmesh/ops/combine_geom.py | 8 ++++--- ocsmesh/raster.py | 42 +++++++++++++++++++++++++++++-------- tests/api/__init__.py | 13 ++++++------ tests/api/hfun.py | 1 - tests/api/mesh.py | 9 +++++--- tests/api/raster.py | 39 ++++++++++++++++++++++++---------- tests/api/utils.py | 41 +++++++++++++++++++++++++----------- 8 files changed, 120 insertions(+), 48 deletions(-) diff --git a/ocsmesh/hfun/raster.py b/ocsmesh/hfun/raster.py index 18430e5e..0be62bf5 100644 --- a/ocsmesh/hfun/raster.py +++ b/ocsmesh/hfun/raster.py @@ -4,8 +4,10 @@ import functools import gc import logging +import os from multiprocessing import cpu_count, Pool import operator +import pathlib import tempfile from time import time from typing import Union, List, Callable, Optional, Iterable, Tuple @@ -234,6 +236,11 @@ def __init__(self, self._constraints = [] + def __del__(self): + for _, memfile_path in self._xy_cache.items(): + pathlib.Path(memfile_path).unlink() + + def msh_t( self, window: Optional[rasterio.windows.Window] = None, @@ -1296,15 +1303,17 @@ def get_xy_memcache( transformer = Transformer.from_crs( self.src.crs, dst_crs, always_xy=True) # pylint: disable=R1732 - tmpfile = tempfile.NamedTemporaryFile() +# tmpfile = tempfile.NamedTemporaryFile() + tmpfd, tmppath = tempfile.mkstemp() xy = self.get_xy(window) - fp = np.memmap(tmpfile, dtype='float32', mode='w+', shape=xy.shape) + fp = np.memmap(tmppath, dtype='float32', mode='w+', shape=xy.shape) + os.close(tmpfd) fp[:] = np.vstack( transformer.transform(xy[:, 0], xy[:, 1])).T _logger.info('Saving values to memcache...') fp.flush() _logger.info('Done!') - self._xy_cache[f'{window}{dst_crs}'] = tmpfile + self._xy_cache[f'{window}{dst_crs}'] = tmppath return fp[:] _logger.info('Loading values from memcache...') diff --git a/ocsmesh/ops/combine_geom.py b/ocsmesh/ops/combine_geom.py index 30412c17..5e6ea652 100644 --- a/ocsmesh/ops/combine_geom.py +++ b/ocsmesh/ops/combine_geom.py @@ -163,16 +163,17 @@ def run(self): poly_files_coll = [] _logger.info(f"Number of processes: {nprocs}") - with tempfile.TemporaryDirectory(dir=out_dir) as temp_dir, \ - tempfile.NamedTemporaryFile() as base_file: + with tempfile.TemporaryDirectory(dir=out_dir) as temp_dir: + tmpfd, tmppath = tempfile.mkstemp() if base_mult_poly: - base_mesh_path = base_file.name + base_mesh_path = tmppath self._multipolygon_to_disk( base_mesh_path, base_mult_poly, fix=False) else: base_mesh_path = None base_mult_poly = None + os.close(tmpfd) _logger.info("Processing DEM priorities ...") @@ -235,6 +236,7 @@ def run(self): ], ignore_index=True ) + pathlib.Path(tmppath).unlink() # The assumption is this returns polygon or multipolygon diff --git a/ocsmesh/raster.py b/ocsmesh/raster.py index 16d6a439..ad294769 100644 --- a/ocsmesh/raster.py +++ b/ocsmesh/raster.py @@ -13,6 +13,7 @@ import pathlib import tempfile import warnings +import platform from time import time from contextlib import contextmanager, ExitStack from typing import ( @@ -144,15 +145,23 @@ class TemporaryFile: cleanup capabities on object destruction. """ - def __set__(self, obj, val: tempfile.NamedTemporaryFile): + def __set__(self, obj, val: Optional[os.PathLike]): + tmpfile = obj.__dict__.get('tmpfile') + if tmpfile is not None: + obj._src = None + pathlib.Path(tmpfile).unlink() + obj.__dict__['tmpfile'] = val - obj._src = rasterio.open(val.name) + if val is None: + obj._src = None + else: + obj._src = rasterio.open(val) def __get__(self, obj, objtype=None) -> pathlib.Path: tmpfile = obj.__dict__.get('tmpfile') if tmpfile is None: return obj.path - return pathlib.Path(tmpfile.name) + return pathlib.Path(tmpfile) class SourceRaster: @@ -165,7 +174,10 @@ class SourceRaster: opening it everytime need arises. """ - def __set__(self, obj, val: rasterio.DatasetReader): + def __set__(self, obj, val: Optional[rasterio.DatasetReader]): + source = obj.__dict__.get('source') + if source is not None: + source.close() obj.__dict__['source'] = val def __get__(self, obj, objtype=None) -> rasterio.DatasetReader: @@ -345,6 +357,9 @@ def __init__( self._path = path self._crs = crs + def __del__(self): + self._tmpfile = None + def __iter__(self, chunk_size: int = None, overlap: int = None): for window in self.iter_windows(chunk_size, overlap): yield window, self.get_window_bounds(window) @@ -382,14 +397,15 @@ def modifying_raster( no_except = False try: # pylint: disable=R1732 - tmpfile = tempfile.NamedTemporaryFile(prefix=tmpdir) +# tmpfile = tempfile.NamedTemporaryFile(prefix=tmpdir, mode='w') + tmpfd, tmppath = tempfile.mkstemp(prefix=tmpdir) new_meta = kwargs # Flag to workaround cases where "src" is NOT set yet if use_src_meta: new_meta = self.src.meta.copy() new_meta.update(**kwargs) - with rasterio.open(tmpfile.name, 'w', **new_meta) as dst: + with rasterio.open(tmppath, 'w', **new_meta) as dst: if use_src_meta: for i, desc in enumerate(self.src.descriptions): dst.set_band_description(i+1, desc) @@ -399,9 +415,12 @@ def modifying_raster( finally: if no_except: - # So that tmpfile is NOT destroyed when it locally - # goes out of scope - self._tmpfile = tmpfile + self._tmpfile = tmppath + + # We don't need to keep the descriptor open, we kept it + # open # so that there's no race condition on the temp + # file up to now + os.close(tmpfd) @@ -944,6 +963,8 @@ def average_filter( # in other parts of the code. Thorough testing is needed for # modifying the raster (e.g. hfun add_contour is affected) + if platform.system() == 'Windows': + raise ImplementationError('Not supported on Windows!') bands = apply_on_bands if bands is None: bands = range(1, self.src.count + 1) @@ -1002,6 +1023,9 @@ def generic_filter(self, function, **kwargs: Any) -> None: None """ + if platform.system() == 'Windows': + raise ImplementationError('Not supported on Windows!') + # TODO: Don't overwrite; add additoinal bands for filtered values # NOTE: Adding new bands in this function can result in issues diff --git a/tests/api/__init__.py b/tests/api/__init__.py index 21716b3a..2595fac3 100644 --- a/tests/api/__init__.py +++ b/tests/api/__init__.py @@ -14,9 +14,10 @@ ) TEST_FILE = os.path.join(tempfile.gettempdir(), 'test_dem.tif') if not Path(TEST_FILE).exists(): - with tempfile.NamedTemporaryFile() as tfp: - urllib.request.urlretrieve(tif_url, filename=tfp.name) - r = Raster(tfp.name) - r.resampling_method = Resampling.average - r.resample(scaling_factor=0.01) - r.save(TEST_FILE) + tmpfd, tmppath = tempfile.mkstemp() + urllib.request.urlretrieve(tif_url, filename=tmppath) + os.close(tmpfd) + r = Raster(tmppath) + r.resampling_method = Resampling.average + r.resample(scaling_factor=0.01) + r.save(TEST_FILE) diff --git a/tests/api/hfun.py b/tests/api/hfun.py index 65c39d4e..dbd4274b 100755 --- a/tests/api/hfun.py +++ b/tests/api/hfun.py @@ -907,7 +907,6 @@ def setUp(self): ) mesh = ocsmesh.Mesh(msh_t) mesh.write(str(self.mesh1), format='grd', overwrite=False) - mesh.write('/tmp/ocsmesh/mytest2.2dm', format='2dm', overwrite=True) def tearDown(self): diff --git a/tests/api/mesh.py b/tests/api/mesh.py index 65a3a612..68b18ca9 100644 --- a/tests/api/mesh.py +++ b/tests/api/mesh.py @@ -1,4 +1,5 @@ #! python +import os import tempfile import unittest import warnings @@ -279,9 +280,11 @@ def test_specified_boundary_order_withmerge(self): def test_specify_boundary_on_imported_mesh_with_boundary(self): self.mesh.boundaries.auto_generate() - with tempfile.NamedTemporaryFile(suffix='.grd') as fo: - self.mesh.write(fo.name, format='grd', overwrite=True) - imported_mesh = Mesh.open(fo.name) + tmpfd, tmppath = tempfile.mkstemp(suffix='.grd') + self.mesh.write(tmppath, format='grd', overwrite=True) + imported_mesh = Mesh.open(tmppath) + os.close(tmpfd) + os.unlink(tmppath) bdry = imported_mesh.boundaries diff --git a/tests/api/raster.py b/tests/api/raster.py index ad6f248a..487c148a 100644 --- a/tests/api/raster.py +++ b/tests/api/raster.py @@ -1,6 +1,7 @@ import shutil import tempfile import unittest +import platform from pathlib import Path import numpy as np @@ -9,6 +10,10 @@ from ocsmesh.utils import raster_from_numpy +IS_WINDOWS = platform.system() == 'Windows' + + + class Raster(unittest.TestCase): def setUp(self): self.tdir = Path(tempfile.mkdtemp()) @@ -38,21 +43,33 @@ def tearDown(self): shutil.rmtree(self.tdir) + @unittest.skipIf(IS_WINDOWS, 'Not supported due to LowLevelFunction int') def test_avg_filter_nomask(self): - rast = ocsmesh.Raster(self.rast1) - rast.average_filter(size=17) - self.assertTrue(np.all(rast.get_values() == 10)) + try: + rast = ocsmesh.Raster(self.rast1) + rast.average_filter(size=17) + self.assertTrue(np.all(rast.get_values() == 10)) + finally: + del rast + @unittest.skipIf(IS_WINDOWS, 'Not supported due to LowLevelFunction int') def test_avg_filter_masked_nanfill(self): - rast = ocsmesh.Raster(self.rast2) - rast.average_filter(size=17) - self.assertTrue( - np.all(rast.values[~np.isnan(rast.values)] == 10)) + try: + rast = ocsmesh.Raster(self.rast2) + rast.average_filter(size=17) + self.assertTrue( + np.all(rast.values[~np.isnan(rast.values)] == 10)) + finally: + del rast + @unittest.skipIf(IS_WINDOWS, 'Not supported due to LowLevelFunction int') def test_avg_filter_masked_nonnanfill(self): - rast = ocsmesh.Raster(self.rast3) - rast.average_filter(size=17) - self.assertTrue( - np.all(rast.values[rast.values != rast.nodata] == 10)) + try: + rast = ocsmesh.Raster(self.rast3) + rast.average_filter(size=17) + self.assertTrue( + np.all(rast.values[rast.values != rast.nodata] == 10)) + finally: + del rast diff --git a/tests/api/utils.py b/tests/api/utils.py index b5611844..2050f5da 100644 --- a/tests/api/utils.py +++ b/tests/api/utils.py @@ -3,6 +3,7 @@ import tempfile import unittest from copy import deepcopy +from pathlib import Path import numpy as np import geopandas as gpd @@ -638,9 +639,11 @@ def test_basic_create(self): in_rast_xy = np.mgrid[1:3:0.1, -1:1:0.1] in_rast_z = np.random.random(in_rast_xy[0].shape) - with tempfile.NamedTemporaryFile(suffix='.tiff') as tf: + with tempfile.TemporaryDirectory() as tdir: +# with tempfile.NamedTemporaryFile(suffix='.tiff', mode='w') as tf: retval = utils.raster_from_numpy( - tf.name, + Path(tdir) / 'test_rast.tiff', +# tf.name, data=in_rast_z, mgrid=in_rast_xy, crs=4326 @@ -648,7 +651,8 @@ def test_basic_create(self): self.assertEqual(retval, None) - rast = Raster(tf.name) + rast = Raster(Path(tdir) / 'test_rast.tiff') +# rast = Raster(tf.name) self.assertTrue(np.all(np.isclose( in_rast_xy.transpose([2,1,0]).reshape(-1, 2), @@ -656,6 +660,7 @@ def test_basic_create(self): ))) self.assertTrue(np.all(in_rast_z == rast.values)) self.assertEqual(rast.crs, CRS.from_epsg(4326)) + del rast def test_diff_extent_x_n_y(self): # TODO: Test when x and y extent are different @@ -672,27 +677,34 @@ def test_data_masking(self): fill_value=fill_value ) - with tempfile.NamedTemporaryFile(suffix='.tiff') as tf: +# with tempfile.NamedTemporaryFile(suffix='.tiff') as tf: + with tempfile.TemporaryDirectory() as tdir: + tf_name = Path(tdir) / 'tiff1.tiff' utils.raster_from_numpy( - tf.name, +# tf.name, + tf_name, data=in_rast_z_nomask, mgrid=in_rast_xy, crs=4326 ) - rast = Raster(tf.name) + rast = Raster(tf_name) self.assertEqual(rast.src.nodata, None) + del rast - with tempfile.NamedTemporaryFile(suffix='.tiff') as tf: +# with tempfile.NamedTemporaryFile(suffix='.tiff') as tf: + tf_name = Path(tdir) / 'tiff2.tiff' utils.raster_from_numpy( - tf.name, +# tf.name, + tf_name, data=in_rast_z_mask, mgrid=in_rast_xy, crs=4326 ) - rast = Raster(tf.name) + rast = Raster(tf_name) self.assertEqual(rast.src.nodata, fill_value) + del rast def test_multiband_raster_data(self): @@ -701,20 +713,25 @@ def test_multiband_raster_data(self): for i in range(nbands): in_data[:, :, i] *= i in_rast_xy = np.mgrid[-74:-71:1, 40.5:40.9:0.1] - with tempfile.NamedTemporaryFile(suffix='.tiff') as tf: +# with tempfile.NamedTemporaryFile(suffix='.tiff') as tf: + with tempfile.TemporaryDirectory() as tdir: + tf_name = Path(tdir) / 'tiff3.tiff' utils.raster_from_numpy( - tf.name, + tf_name, +# tf.name, data=in_data, mgrid=in_rast_xy, crs=4326 ) - rast = Raster(tf.name) + rast = Raster(tf_name) +# rast = Raster(tf.name) self.assertEqual(rast.count, nbands) for i in range(nbands): with self.subTest(band_number=i): self.assertTrue( (rast.get_values(band=i+1) == i).all() ) + del rast def test_multiband_raster_invalid_io(self):