Commit
Replaced fail_fast with pqdm_kwargs in the code and added a test
Sherwin-14 committed Oct 28, 2024
1 parent 87cd8fc commit 374f303
Showing 3 changed files with 58 additions and 24 deletions.
7 changes: 7 additions & 0 deletions earthaccess/api.py
@@ -218,6 +218,9 @@ def download(
local_path: local directory to store the remote data granules
provider: if we download a list of URLs, we need to specify the provider.
threads: number of parallel threads to use to download the files, adjust as necessary, default = 8
pqdm_kwargs: Additional keyword arguments to pass to pqdm, a parallel processing library.
See pqdm documentation for available options. Default is to use immediate exception behavior
and the number of jobs specified by the `threads` parameter.
Returns:
List of downloaded files
@@ -226,6 +229,7 @@ def download(
Exception: A file download failed.
"""
provider = _normalize_location(provider)
pqdm_kwargs = dict(pqdm_kwargs) if pqdm_kwargs is not None else {}
pqdm_kwargs = {
"exception_behavior": "immediate",
"n_jobs": threads,
@@ -259,6 +263,9 @@ def open(
granules: a list of granule instances **or** list of URLs, e.g. `s3://some-granule`.
If a list of URLs is passed, we need to specify the data provider.
provider: e.g. POCLOUD, NSIDC_CPRD, etc.
pqdm_kwargs: Additional keyword arguments to pass to pqdm, a parallel processing library.
See pqdm documentation for available options. Default is to use immediate exception behavior
and the number of jobs specified by the `threads` parameter.
Returns:
A list of "file pointers" to remote (i.e. s3 or https) files.
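For context, here is a sketch of how a caller might exercise the new parameter through the public API, assuming the signatures shown in this diff; the dataset, count, and paths are illustrative, and pqdm itself documents the exception option as `exception_behaviour`:

```python
import earthaccess

earthaccess.login()

granules = earthaccess.search_data(short_name="ATL06", count=4)

# Let failed granules defer rather than abort the whole batch;
# pqdm documents this option as "exception_behaviour".
files = earthaccess.download(
    granules,
    "./data",
    threads=4,
    pqdm_kwargs={"exception_behaviour": "deferred"},
)

# The same pass-through applies when opening granules remotely.
handles = earthaccess.open(
    granules,
    pqdm_kwargs={"exception_behaviour": "deferred"},
)
```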
50 changes: 26 additions & 24 deletions earthaccess/store.py
@@ -68,7 +68,12 @@ def _open_files(
) -> List[fsspec.AbstractFileSystem]:
def multi_thread_open(data: tuple[str, Optional[DataGranule]]) -> EarthAccessFile:
urls, granule = data
return EarthAccessFile(fs.open(urls), granule) # type: ignore

pqdm_kwargs = {
"exception_behavior": "immediate",
**(pqdm_kwargs or {}),
}

fileset = pqdm(
url_mapping.items(), multi_thread_open, n_jobs=threads, **pqdm_kwargs
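The merge above gives caller-supplied options the last word, since a later `**` expansion overwrites earlier keys. A standalone sketch of the pattern (the helper name is illustrative):

```python
# Standalone sketch of the default-merge pattern used in _open_files:
# later ** expansions overwrite earlier keys, so callers can override
# the "immediate" default.
def merge_defaults(pqdm_kwargs=None):
    return {
        "exception_behavior": "immediate",
        **(pqdm_kwargs or {}),
    }

assert merge_defaults() == {"exception_behavior": "immediate"}
assert merge_defaults({"exception_behavior": "deferred"}) == {
    "exception_behavior": "deferred"
}
```

One caveat: `pqdm` is also called with an explicit `n_jobs=threads` here, so passing `n_jobs` inside `pqdm_kwargs` to `_open_files` would raise a duplicate-keyword `TypeError`.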
@@ -348,6 +353,9 @@ def open(
granules: a list of granule instances **or** list of URLs, e.g. `s3://some-granule`.
If a list of URLs is passed, we need to specify the data provider.
provider: e.g. POCLOUD, NSIDC_CPRD, etc.
pqdm_kwargs: Additional keyword arguments to pass to pqdm, a parallel processing library.
See pqdm documentation for available options. Default is to use immediate exception behavior
and the number of jobs specified by the `threads` parameter.
Returns:
A list of "file pointers" to remote (i.e. s3 or https) files.
@@ -371,7 +379,6 @@ def _open_granules(
granules: List[DataGranule],
provider: Optional[str] = None,
threads: int = 8,
pqdm_kwargs: Optional[Mapping[str, Any]] = None,
) -> List[Any]:
fileset: List = []
total_size = round(sum([granule.size() for granule in granules]) / 1024, 2)
@@ -397,22 +404,14 @@
else:
access = "on_prem"
s3_fs = None
access = "direct"
provider = granules[0]["meta"]["provider-id"]
# if the data has its own S3 credentials endpoint, we will use it
endpoint = self._own_s3_credentials(granules[0]["umm"]["RelatedUrls"])
if endpoint is not None:
logger.info(f"using endpoint: {endpoint}")
s3_fs = self.get_s3_filesystem(endpoint=endpoint)
else:
logger.info(f"using provider: {provider}")
s3_fs = self.get_s3_filesystem(provider=provider)

url_mapping = _get_url_granule_mapping(granules, access)
if s3_fs is not None:
try:
fileset = _open_files(
url_mapping, fs=s3_fs, threads=threads, pqdm_kwargs=pqdm_kwargs
url_mapping,
fs=s3_fs,
threads=threads,
)
except Exception as e:
raise RuntimeError(
@@ -421,15 +420,11 @@
f"Exception: {traceback.format_exc()}"
) from e
else:
fileset = self._open_urls_https(
url_mapping, threads=threads, pqdm_kwargs=pqdm_kwargs
)
fileset = self._open_urls_https(url_mapping, threads=threads)
return fileset
else:
url_mapping = _get_url_granule_mapping(granules, access="on_prem")
fileset = self._open_urls_https(
url_mapping, threads=threads, pqdm_kwargs=pqdm_kwargs
)
fileset = self._open_urls_https(url_mapping, threads=threads)
return fileset

@_open.register
@@ -512,6 +507,9 @@ def get(
provider: a valid cloud provider, each DAAC has a provider code for their cloud distributions
threads: Number of parallel threads to use to download the files;
adjust as necessary, default = 8.
pqdm_kwargs: Additional keyword arguments to pass to pqdm, a parallel processing library.
See pqdm documentation for available options. Default is to use immediate exception behavior
and the number of jobs specified by the `threads` parameter.
Returns:
List of downloaded files
@@ -553,6 +551,9 @@ def _get(
provider: a valid cloud provider, each DAAC has a provider code for their cloud distributions
threads: Number of parallel threads to use to download the files;
adjust as necessary, default = 8.
pqdm_kwargs: Additional keyword arguments to pass to pqdm, a parallel processing library.
See pqdm documentation for available options. Default is to use immediate exception behavior
and the number of jobs specified by the `threads` parameter.
Returns:
None
@@ -693,9 +694,9 @@ def _download_onprem_granules(
directory: local directory to store the downloaded files
threads: number of parallel threads to use to download the files;
adjust as necessary, default = 8
fail_fast: if set to True, the download process will stop immediately
upon encountering the first error. If set to False, errors will be
deferred, allowing the download of remaining files to continue.
pqdm_kwargs: Additional keyword arguments to pass to pqdm, a parallel processing library.
See pqdm documentation for available options. Default is to use immediate exception behavior
and the number of jobs specified by the `threads` parameter.
Returns:
A list of local filepaths to which the files were downloaded.
@@ -715,19 +716,20 @@ def _download_onprem_granules(
self._download_file,
n_jobs=threads,
argument_type="args",
**pqdm_kwargs,
)
return results

def _open_urls_https(
self,
url_mapping: Mapping[str, Union[DataGranule, None]],
threads: int = 8,
pqdm_kwargs: Optional[Mapping[str, Any]] = None,
) -> List[fsspec.AbstractFileSystem]:
https_fs = self.get_fsspec_session()

try:
return _open_files(url_mapping, https_fs, threads, **pqdm_kwargs)
return _open_files(url_mapping, https_fs, threads, pqdm_kwargs)
except Exception:
logger.exception(
"An exception occurred while trying to access remote files via HTTPS"
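Downstream, `_download_onprem_granules` forwards `**pqdm_kwargs` straight into the `pqdm` call. A minimal sketch of what that call sees, using pqdm's threaded variant (the worker and inputs are stand-ins, and the spelling `exception_behaviour` follows pqdm's documentation):

```python
from pqdm.threads import pqdm

def fetch(url):
    # Stand-in for the per-file download worker; fails for one input.
    if url == "bad":
        raise RuntimeError(f"failed: {url}")
    return f"ok: {url}"

# "immediate" (the earthaccess default) aborts at the first error, while
# "deferred" lets the remaining jobs finish before surfacing failures.
results = pqdm(["a", "bad", "b"], fetch, n_jobs=2, exception_behaviour="deferred")
```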
25 changes: 25 additions & 0 deletions tests/unit/test_api.py
@@ -0,0 +1,25 @@
from unittest.mock import Mock

import earthaccess
import pytest


def test_download(monkeypatch):
    earthaccess.login()

    results = earthaccess.search_data(
        short_name="ATL06",
        bounding_box=(-10, 20, 10, 50),
        temporal=("1999-02", "2019-03"),
        count=10,
    )

    def mock_get(*args, **kwargs):
        raise Exception("Download failed")

    # Swap the module-level store for a mock whose `get` always raises,
    # so no real download is attempted.
    mock_store = Mock()
    monkeypatch.setattr(earthaccess, "__store__", mock_store)
    monkeypatch.setattr(mock_store, "get", mock_get)

    # With the default immediate exception behavior, the failure propagates.
    with pytest.raises(Exception, match="Download failed"):
        earthaccess.download(results, "/home/download-folder")
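A possible companion test, not part of this commit, could assert the pass-through itself. This sketch assumes `download` forwards `pqdm_kwargs` to `Store.get` as a keyword argument, which the api.py hunk above does not show:

```python
def test_download_forwards_pqdm_kwargs(monkeypatch):
    # Hypothetical: capture the kwargs that download() hands to the store.
    captured = {}

    def mock_get(*args, **kwargs):
        captured.update(kwargs)
        return []

    mock_store = Mock()
    monkeypatch.setattr(earthaccess, "__store__", mock_store)
    monkeypatch.setattr(mock_store, "get", mock_get)

    earthaccess.download(
        [Mock()],
        "/home/download-folder",
        pqdm_kwargs={"exception_behaviour": "deferred"},
    )

    assert "pqdm_kwargs" in captured
```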
