From 374f3033bc1dc2698e21975a2b59d6161c4effc7 Mon Sep 17 00:00:00 2001 From: Sherwin-14 Date: Mon, 28 Oct 2024 19:47:49 +0530 Subject: [PATCH] Replaced fail_fast with pqdm_kwargs in the code and added a test --- earthaccess/api.py | 7 ++++++ earthaccess/store.py | 50 ++++++++++++++++++++++-------------------- tests/unit/test_api.py | 25 +++++++++++++++++++++ 3 files changed, 58 insertions(+), 24 deletions(-) create mode 100644 tests/unit/test_api.py diff --git a/earthaccess/api.py b/earthaccess/api.py index 3f3a5001..a6e87352 100644 --- a/earthaccess/api.py +++ b/earthaccess/api.py @@ -218,6 +218,9 @@ def download( local_path: local directory to store the remote data granules provider: if we download a list of URLs, we need to specify the provider. threads: parallel number of threads to use to download the files, adjust as necessary, default = 8 + pqdm_kwargs: Additional keyword arguments to pass to pqdm, a parallel processing library. + See pqdm documentation for available options. Default is to use immediate exception behavior + and the number of jobs specified by the `threads` parameter. Returns: List of downloaded files @@ -226,6 +229,7 @@ def download( Exception: A file download failed. """ provider = _normalize_location(provider) + pqdm_kwargs = dict(pqdm_kwargs) if pqdm_kwargs is not None else {} pqdm_kwargs = { "exception_behavior": "immediate", "n_jobs": threads, @@ -259,6 +263,9 @@ def open( granules: a list of granule instances **or** list of URLs, e.g. `s3://some-granule`. If a list of URLs is passed, we need to specify the data provider. provider: e.g. POCLOUD, NSIDC_CPRD, etc. + pqdm_kwargs: Additional keyword arguments to pass to pqdm, a parallel processing library. + See pqdm documentation for available options. Default is to use immediate exception behavior + and the number of jobs specified by the `threads` parameter. Returns: A list of "file pointers" to remote (i.e. s3 or https) files. 
diff --git a/earthaccess/store.py b/earthaccess/store.py index 3e7dd48f..bbc1c7d1 100644 --- a/earthaccess/store.py +++ b/earthaccess/store.py @@ -68,7 +68,12 @@ def _open_files( ) -> List[fsspec.AbstractFileSystem]: def multi_thread_open(data: tuple[str, Optional[DataGranule]]) -> EarthAccessFile: urls, granule = data - return EarthAccessFile(fs.open(urls), granule) # type: ignore + return EarthAccessFile(fs.open(urls), granule) # type: ignore + + pqdm_kwargs = { + "exception_behavior": "immediate", + **(pqdm_kwargs or {}), + } fileset = pqdm( url_mapping.items(), multi_thread_open, n_jobs=threads, **pqdm_kwargs @@ -348,6 +353,9 @@ def open( granules: a list of granule instances **or** list of URLs, e.g. `s3://some-granule`. If a list of URLs is passed, we need to specify the data provider. provider: e.g. POCLOUD, NSIDC_CPRD, etc. + pqdm_kwargs: Additional keyword arguments to pass to pqdm, a parallel processing library. + See pqdm documentation for available options. Default is to use immediate exception behavior + and the number of jobs specified by the `threads` parameter. Returns: A list of "file pointers" to remote (i.e. s3 or https) files. 
@@ -371,7 +379,6 @@ def _open_granules( granules: List[DataGranule], provider: Optional[str] = None, threads: int = 8, - pqdm_kwargs: Optional[Mapping[str, Any]] = None, ) -> List[Any]: fileset: List = [] total_size = round(sum([granule.size() for granule in granules]) / 1024, 2) @@ -397,22 +404,14 @@ def _open_granules( else: access = "on_prem" s3_fs = None - access = "direct" - provider = granules[0]["meta"]["provider-id"] - # if the data has its own S3 credentials endpoint, we will use it - endpoint = self._own_s3_credentials(granules[0]["umm"]["RelatedUrls"]) - if endpoint is not None: - logger.info(f"using endpoint: {endpoint}") - s3_fs = self.get_s3_filesystem(endpoint=endpoint) - else: - logger.info(f"using provider: {provider}") - s3_fs = self.get_s3_filesystem(provider=provider) url_mapping = _get_url_granule_mapping(granules, access) if s3_fs is not None: try: fileset = _open_files( - url_mapping, fs=s3_fs, threads=threads, pqdm_kwargs=pqdm_kwargs + url_mapping, + fs=s3_fs, + threads=threads, ) except Exception as e: raise RuntimeError( @@ -421,15 +420,11 @@ def _open_granules( f"Exception: {traceback.format_exc()}" ) from e else: - fileset = self._open_urls_https( - url_mapping, threads=threads, pqdm_kwargs=pqdm_kwargs - ) + fileset = self._open_urls_https(url_mapping, threads=threads) return fileset else: url_mapping = _get_url_granule_mapping(granules, access="on_prem") - fileset = self._open_urls_https( - url_mapping, threads=threads, pqdm_kwargs=pqdm_kwargs - ) + fileset = self._open_urls_https(url_mapping, threads=threads) return fileset @_open.register @@ -512,6 +507,9 @@ def get( provider: a valid cloud provider, each DAAC has a provider code for their cloud distributions threads: Parallel number of threads to use to download the files; adjust as necessary, default = 8. + pqdm_kwargs: Additional keyword arguments to pass to pqdm, a parallel processing library. + See pqdm documentation for available options. 
Default is to use immediate exception behavior + and the number of jobs specified by the `threads` parameter. Returns: List of downloaded files @@ -553,6 +551,9 @@ def _get( provider: a valid cloud provider, each DAAC has a provider code for their cloud distributions threads: Parallel number of threads to use to download the files; adjust as necessary, default = 8. + pqdm_kwargs: Additional keyword arguments to pass to pqdm, a parallel processing library. + See pqdm documentation for available options. Default is to use immediate exception behavior + and the number of jobs specified by the `threads` parameter. Returns: None @@ -693,9 +694,9 @@ def _download_onprem_granules( directory: local directory to store the downloaded files threads: parallel number of threads to use to download the files; adjust as necessary, default = 8 - fail_fast: if set to True, the download process will stop immediately - upon encountering the first error. If set to False, errors will be - deferred, allowing the download of remaining files to continue. + pqdm_kwargs: Additional keyword arguments to pass to pqdm, a parallel processing library. + See pqdm documentation for available options. Default is to use immediate exception behavior + and the number of jobs specified by the `threads` parameter. Returns: A list of local filepaths to which the files were downloaded. 
@@ -715,7 +716,7 @@ def _download_onprem_granules( self._download_file, n_jobs=threads, argument_type="args", - **pqdm_kwargs + **pqdm_kwargs, ) return results @@ -723,11 +724,12 @@ def _open_urls_https( self, url_mapping: Mapping[str, Union[DataGranule, None]], threads: int = 8, + pqdm_kwargs: Optional[Mapping[str, Any]] = None, ) -> List[fsspec.AbstractFileSystem]: https_fs = self.get_fsspec_session() try: - return _open_files(url_mapping, https_fs, threads, **pqdm_kwargs) + return _open_files(url_mapping, https_fs, threads, pqdm_kwargs) except Exception: logger.exception( "An exception occurred while trying to access remote files via HTTPS" diff --git a/tests/unit/test_api.py b/tests/unit/test_api.py new file mode 100644 index 00000000..27a7717e --- /dev/null +++ b/tests/unit/test_api.py @@ -0,0 +1,25 @@ +from unittest.mock import Mock + +import earthaccess +import pytest + + +def test_download(monkeypatch): + earthaccess.login() + + results = earthaccess.search_data( + short_name="ATL06", + bounding_box=(-10, 20, 10, 50), + temporal=("1999-02", "2019-03"), + count=10, + ) + + def mock_get(*args, **kwargs): + raise Exception("Download failed") + + mock_store = Mock() + monkeypatch.setattr(earthaccess, "__store__", mock_store) + monkeypatch.setattr(mock_store, "get", mock_get) + + with pytest.raises(Exception, match="Download failed"): + earthaccess.download(results, "/home/download-folder")