Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added parameters and exception behaviour to pqdm #792

Merged
merged 11 commits into from
Nov 4, 2024
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@ and this project uses [Semantic Versioning](https://semver.org/spec/v2.0.0.html)

## [Unreleased]

- Fix `earthaccess.download` to not ignore errors by default
([#581](https://github.com/nsidc/earthaccess/issues/581))
([**@Sherwin-14**](https://github.com/Sherwin-14),
[**@chuckwondo**](https://github.com/chuckwondo),
[**@mfisher87**](https://github.com/mfisher87))

### Changed

- Use built-in `assert` statements instead of `unittest` assertions in
Expand Down
27 changes: 24 additions & 3 deletions earthaccess/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import requests
import s3fs
from fsspec import AbstractFileSystem
from typing_extensions import Any, Dict, List, Optional, Union, deprecated
from typing_extensions import Any, Dict, List, Mapping, Optional, Union, deprecated

import earthaccess
from earthaccess.services import DataServices
Expand Down Expand Up @@ -205,6 +205,7 @@ def download(
local_path: Optional[str],
provider: Optional[str] = None,
threads: int = 8,
pqdm_kwargs: Optional[Mapping[str, Any]] = None,
chuckwondo marked this conversation as resolved.
Show resolved Hide resolved
) -> List[str]:
"""Retrieves data granules from a remote storage system.

Expand All @@ -217,6 +218,9 @@ def download(
local_path: local directory to store the remote data granules
provider: if we download a list of URLs, we need to specify the provider.
threads: parallel number of threads to use to download the files, adjust as necessary, default = 8
pqdm_kwargs: Additional keyword arguments to pass to pqdm, a parallel processing library.
See pqdm documentation for available options. Default is to use immediate exception behavior
and the number of jobs specified by the `threads` parameter.

Returns:
List of downloaded files
Expand All @@ -225,12 +229,19 @@ def download(
Exception: A file download failed.
"""
provider = _normalize_location(provider)
pqdm_kwargs = {
"exception_behavior": "immediate",
"n_jobs": threads,
**(pqdm_kwargs or {}),
}
if isinstance(granules, DataGranule):
granules = [granules]
elif isinstance(granules, str):
granules = [granules]
try:
results = earthaccess.__store__.get(granules, local_path, provider, threads)
results = earthaccess.__store__.get(
granules, local_path, provider, threads, pqdm_kwargs
)
except AttributeError as err:
logger.error(
f"{err}: You must call earthaccess.login() before you can download data"
Expand All @@ -242,6 +253,7 @@ def download(
def open(
granules: Union[List[str], List[DataGranule]],
provider: Optional[str] = None,
pqdm_kwargs: Optional[Mapping[str, Any]] = None,
chuckwondo marked this conversation as resolved.
Show resolved Hide resolved
) -> List[AbstractFileSystem]:
"""Returns a list of file-like objects that can be used to access files
hosted on S3 or HTTPS by third party libraries like xarray.
Expand All @@ -250,12 +262,21 @@ def open(
granules: a list of granule instances **or** list of URLs, e.g. `s3://some-granule`.
If a list of URLs is passed, we need to specify the data provider.
provider: e.g. POCLOUD, NSIDC_CPRD, etc.
pqdm_kwargs: Additional keyword arguments to pass to pqdm, a parallel processing library.
See pqdm documentation for available options. Default is to use immediate exception behavior.

Returns:
A list of "file pointers" to remote (i.e. s3 or https) files.
"""
provider = _normalize_location(provider)
results = earthaccess.__store__.open(granules=granules, provider=provider)
pqdm_kwargs = {
"exception_behavior": "immediate",
**(pqdm_kwargs or {}),
}
results = earthaccess.__store__.open(
granules=granules, provider=provider, pqdm_kwargs=pqdm_kwargs
)
return results


Expand Down
55 changes: 47 additions & 8 deletions earthaccess/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,20 @@ def _open_files(
url_mapping: Mapping[str, Union[DataGranule, None]],
fs: fsspec.AbstractFileSystem,
threads: int = 8,
pqdm_kwargs: Optional[Mapping[str, Any]] = None,
) -> List[fsspec.spec.AbstractBufferedFile]:
def multi_thread_open(data: tuple[str, Optional[DataGranule]]) -> EarthAccessFile:
url, granule = data
return EarthAccessFile(fs.open(url), granule) # type: ignore

fileset = pqdm(url_mapping.items(), multi_thread_open, n_jobs=threads)
pqdm_kwargs = {
"exception_behavior": "immediate",
**(pqdm_kwargs or {}),
}

fileset = pqdm(
url_mapping.items(), multi_thread_open, n_jobs=threads, **pqdm_kwargs
)
chuckwondo marked this conversation as resolved.
Show resolved Hide resolved
return fileset


Expand Down Expand Up @@ -336,6 +344,7 @@ def open(
self,
granules: Union[List[str], List[DataGranule]],
provider: Optional[str] = None,
pqdm_kwargs: Optional[Mapping[str, Any]] = None,
chuckwondo marked this conversation as resolved.
Show resolved Hide resolved
) -> List[fsspec.spec.AbstractBufferedFile]:
"""Returns a list of file-like objects that can be used to access files
hosted on S3 or HTTPS by third party libraries like xarray.
Expand All @@ -344,19 +353,23 @@ def open(
granules: a list of granule instances **or** list of URLs, e.g. `s3://some-granule`.
If a list of URLs is passed, we need to specify the data provider.
provider: e.g. POCLOUD, NSIDC_CPRD, etc.
pqdm_kwargs: Additional keyword arguments to pass to pqdm, a parallel processing library.
See pqdm documentation for available options. Default is to use immediate exception behavior.

Returns:
A list of "file pointers" to remote (i.e. s3 or https) files.
"""
if len(granules):
return self._open(granules, provider)
return self._open(granules, provider, pqdm_kwargs)
return []

@singledispatchmethod
def _open(
self,
granules: Union[List[str], List[DataGranule]],
provider: Optional[str] = None,
pqdm_kwargs: Optional[Mapping[str, Any]] = None,
) -> List[Any]:
raise NotImplementedError("granules should be a list of DataGranule or URLs")

Expand Down Expand Up @@ -420,6 +433,7 @@ def _open_urls(
granules: List[str],
provider: Optional[str] = None,
threads: int = 8,
pqdm_kwargs: Optional[Mapping[str, Any]] = None,
) -> List[Any]:
fileset: List = []

Expand Down Expand Up @@ -447,6 +461,7 @@ def _open_urls(
url_mapping,
fs=s3_fs,
threads=threads,
pqdm_kwargs=pqdm_kwargs,
)
except Exception as e:
raise RuntimeError(
Expand All @@ -466,7 +481,7 @@ def _open_urls(
raise ValueError(
"We cannot open S3 links when we are not in-region, try using HTTPS links"
)
fileset = self._open_urls_https(url_mapping, threads)
fileset = self._open_urls_https(url_mapping, threads, pqdm_kwargs)
return fileset

def get(
Expand All @@ -475,6 +490,7 @@ def get(
local_path: Union[Path, str, None] = None,
provider: Optional[str] = None,
threads: int = 8,
pqdm_kwargs: Optional[Mapping[str, Any]] = None,
) -> List[str]:
"""Retrieves data granules from a remote storage system.

Expand All @@ -491,6 +507,9 @@ def get(
provider: a valid cloud provider, each DAAC has a provider code for their cloud distributions
threads: Parallel number of threads to use to download the files;
adjust as necessary, default = 8.
pqdm_kwargs: Additional keyword arguments to pass to pqdm, a parallel processing library.
See pqdm documentation for available options. Default is to use immediate exception behavior
and the number of jobs specified by the `threads` parameter.

Returns:
List of downloaded files
Expand All @@ -503,7 +522,7 @@ def get(
local_path = Path(local_path)

if len(granules):
files = self._get(granules, local_path, provider, threads)
files = self._get(granules, local_path, provider, threads, pqdm_kwargs)
return files
else:
raise ValueError("List of URLs or DataGranule instances expected")
Expand All @@ -515,6 +534,7 @@ def _get(
local_path: Path,
provider: Optional[str] = None,
threads: int = 8,
pqdm_kwargs: Optional[Mapping[str, Any]] = None,
) -> List[str]:
"""Retrieves data granules from a remote storage system.

Expand All @@ -531,6 +551,9 @@ def _get(
provider: a valid cloud provider, each DAAC has a provider code for their cloud distributions
threads: Parallel number of threads to use to download the files;
adjust as necessary, default = 8.
pqdm_kwargs: Additional keyword arguments to pass to pqdm, a parallel processing library.
See pqdm documentation for available options. Default is to use immediate exception behavior
and the number of jobs specified by the `threads` parameter.

Returns:
None
Expand All @@ -544,6 +567,7 @@ def _get_urls(
local_path: Path,
provider: Optional[str] = None,
threads: int = 8,
pqdm_kwargs: Optional[Mapping[str, Any]] = None,
) -> List[str]:
data_links = granules
downloaded_files: List = []
Expand All @@ -565,7 +589,9 @@ def _get_urls(

else:
# if we are not in AWS
return self._download_onprem_granules(data_links, local_path, threads)
return self._download_onprem_granules(
data_links, local_path, threads, pqdm_kwargs
)

@_get.register
def _get_granules(
Expand All @@ -574,6 +600,7 @@ def _get_granules(
local_path: Path,
provider: Optional[str] = None,
threads: int = 8,
pqdm_kwargs: Optional[Mapping[str, Any]] = None,
) -> List[str]:
data_links: List = []
downloaded_files: List = []
Expand Down Expand Up @@ -614,7 +641,9 @@ def _get_granules(
else:
# if the data are cloud-based, but we are not in AWS,
# it will be downloaded as if it was on prem
return self._download_onprem_granules(data_links, local_path, threads)
return self._download_onprem_granules(
data_links, local_path, threads, pqdm_kwargs
)

def _download_file(self, url: str, directory: Path) -> str:
"""Download a single file from an on-prem location, a DAAC data center.
Expand Down Expand Up @@ -652,7 +681,11 @@ def _download_file(self, url: str, directory: Path) -> str:
return str(path)

def _download_onprem_granules(
self, urls: List[str], directory: Path, threads: int = 8
self,
urls: List[str],
directory: Path,
threads: int = 8,
pqdm_kwargs: Optional[Mapping[str, Any]] = None,
) -> List[Any]:
"""Downloads a list of URLS into the data directory.

Expand All @@ -661,6 +694,9 @@ def _download_onprem_granules(
directory: local directory to store the downloaded files
threads: parallel number of threads to use to download the files;
adjust as necessary, default = 8
pqdm_kwargs: Additional keyword arguments to pass to pqdm, a parallel processing library.
See pqdm documentation for available options. Default is to use immediate exception behavior
and the number of jobs specified by the `threads` parameter.

Returns:
A list of local filepaths to which the files were downloaded.
Expand All @@ -674,23 +710,26 @@ def _download_onprem_granules(
directory.mkdir(parents=True, exist_ok=True)

arguments = [(url, directory) for url in urls]

results = pqdm(
arguments,
self._download_file,
n_jobs=threads,
argument_type="args",
**pqdm_kwargs,
)
return results

def _open_urls_https(
self,
url_mapping: Mapping[str, Union[DataGranule, None]],
threads: int = 8,
pqdm_kwargs: Optional[Mapping[str, Any]] = None,
) -> List[fsspec.AbstractFileSystem]:
https_fs = self.get_fsspec_session()

try:
return _open_files(url_mapping, https_fs, threads)
return _open_files(url_mapping, https_fs, threads, pqdm_kwargs)
except Exception:
logger.exception(
"An exception occurred while trying to access remote files via HTTPS"
Expand Down
50 changes: 50 additions & 0 deletions tests/unit/test_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from unittest.mock import Mock

import earthaccess
import pytest


def test_download_immediate_failure(monkeypatch):
    """`earthaccess.download` re-raises immediately when the store's `get` fails.

    The store is fully mocked, so no login or network access is required:
    the previous version called `earthaccess.login()` and `search_data()`
    against live services, but the search results were irrelevant because
    the mocked `get` raises unconditionally. A local list of fake granule
    URLs exercises the same code path.
    """
    granules = [f"https://example.com/granule-{i}" for i in range(10)]

    def mock_get(*args, **kwargs):
        # Simulate pqdm's "immediate" exception behaviour: the first
        # failure propagates out of the store layer.
        raise Exception("Download failed")

    mock_store = Mock()
    monkeypatch.setattr(earthaccess, "__store__", mock_store)
    monkeypatch.setattr(mock_store, "get", mock_get)

    with pytest.raises(Exception, match="Download failed"):
        earthaccess.download(granules, "/home/download-folder")


Sherwin-14 marked this conversation as resolved.
Show resolved Hide resolved
def test_download_deferred_failure(monkeypatch):
    """With deferred exception behaviour, failures are returned, not raised.

    Uses a fully mocked store and a local list of fake granule URLs, so no
    login or network access is needed (the previous version hit live
    services via `earthaccess.login()` and `search_data()` even though the
    mocked `get` ignored the search results). Distinct names for the input
    granules and the download results avoid the confusing rebinding of
    `results` present in the original.
    """
    granules = [f"https://example.com/granule-{i}" for i in range(10)]

    def mock_get(*args, **kwargs):
        # Deferred behaviour: pqdm collects one exception per input
        # instead of raising on the first failure.
        return [Exception("Download failed")] * len(granules)

    mock_store = Mock()
    monkeypatch.setattr(earthaccess, "__store__", mock_store)
    monkeypatch.setattr(mock_store, "get", mock_get)

    # NOTE(review): pqdm spells this kwarg "exception_behaviour" (British
    # spelling) — confirm the key below matches what api.download merges
    # and what pqdm actually accepts.
    results = earthaccess.download(
        granules, "/home/download-folder", None, 8, {"exception_behavior": "deferred"}
    )

    assert all(isinstance(e, Exception) for e in results)
    assert len(results) == 10
Loading