diff --git a/CHANGELOG.md b/CHANGELOG.md index e0dbad65..86742a5d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,11 @@ # Changelog +## [unreleased] +* bug fixes: + * Fix spelling mistake in `access` variable assignment (`direc` -> `direct`) + in `earthaccess.store._get_granules`. + * Pass `threads` arg to `_open_urls_https` in + `earthaccess.store._open_urls`, replacing the hard-coded value of 8. ## [v0.6.0] 2023-09-20 * bug fixes: diff --git a/earthaccess/api.py b/earthaccess/api.py index 5432bff7..5961b540 100644 --- a/earthaccess/api.py +++ b/earthaccess/api.py @@ -6,6 +6,7 @@ from fsspec import AbstractFileSystem from .auth import Auth +from .results import DataGranule from .search import CollectionQuery, DataCollections, DataGranules, GranuleQuery from .store import Store from .utils import _validation as validate @@ -150,8 +151,8 @@ def login(strategy: str = "all", persist: bool = False) -> Auth: def download( - granules: Union[List[earthaccess.results.DataGranule], List[str]], - local_path: Optional[str], + granules: Union[DataGranule, List[DataGranule], List[str]], + local_path: Union[str, None], provider: Optional[str] = None, threads: int = 8, ) -> List[str]: @@ -161,7 +162,7 @@ def download( * If we run it outside AWS (us-west-2 region) and the dataset is cloud hostes we'll use HTTP links Parameters: - granules: a list of granules(DataGranule) instances or a list of granule links (HTTP) + granules: a granule, list of granules, or a list of granule links (HTTP) local_path: local directory to store the remote data granules provider: if we download a list of URLs we need to specify the provider. 
threads: parallel number of threads to use to download the files, adjust as necessary, default = 8 @@ -169,6 +170,8 @@ Returns: List of downloaded files """ + if isinstance(granules, DataGranule): + granules = [granules] try: results = earthaccess.__store__.get(granules, local_path, provider, threads) except AttributeError as err: diff --git a/earthaccess/store.py b/earthaccess/store.py index ef1ee72d..b8cab700 100644 --- a/earthaccess/store.py +++ b/earthaccess/store.py @@ -453,7 +453,9 @@ def _open_urls( "We cannot open S3 links when we are not in-region, try using HTTPS links" ) return None + fileset = self._open_urls_https(data_links, granules, threads, sizes) + return fileset def get( @@ -480,6 +482,12 @@ def get( Returns: List of downloaded files """ + if local_path is None: + local_path = os.path.join( + ".", + "data", + f"{datetime.datetime.today().strftime('%Y-%m-%d')}-{uuid4().hex[:6]}", + ) if len(granules): files = self._get(granules, local_path, provider, threads) return files @@ -491,7 +499,7 @@ def _get( self, granules: Union[List[DataGranule], List[str]], - local_path: Optional[str] = None, + local_path: str, provider: Optional[str] = None, threads: int = 8, ) -> Union[None, List[str]]: @@ -519,7 +527,7 @@ def _get_urls( self, granules: List[str], - local_path: Optional[str] = None, + local_path: str, provider: Optional[str] = None, threads: int = 8, ) -> Union[None, List[str]]: @@ -536,22 +544,21 @@ def _get_urls( s3_fs = self.get_s3fs_session(provider=provider) # TODO: make this parallel or concurrent for file in data_links: - file_name = file.split("/")[-1] s3_fs.get(file, local_path) - print(f"Retrieved: {file} to {local_path}") + file_name = os.path.join(local_path, os.path.basename(file)) + print(f"Downloaded: {file_name}") downloaded_files.append(file_name) return downloaded_files else: # if we are not in AWS return self._download_onprem_granules(data_links, local_path, threads) - return None @_get.register def 
_get_granules( self, granules: List[DataGranule], - local_path: Optional[str] = None, + local_path: str, provider: Optional[str] = None, threads: int = 8, ) -> Union[None, List[str]]: @@ -560,7 +567,7 @@ def _get_granules( provider = granules[0]["meta"]["provider-id"] endpoint = self._own_s3_credentials(granules[0]["umm"]["RelatedUrls"]) cloud_hosted = granules[0].cloud_hosted - access = "direc" if (cloud_hosted and self.running_in_aws) else "external" + access = "direct" if (cloud_hosted and self.running_in_aws) else "external" data_links = list( # we are not in region chain.from_iterable( @@ -584,14 +591,13 @@ def _get_granules( # TODO: make this async for file in data_links: s3_fs.get(file, local_path) - file_name = file.split("/")[-1] - print(f"Retrieved: {file} to {local_path}") + file_name = os.path.join(local_path, os.path.basename(file)) + print(f"Downloaded: {file_name}") downloaded_files.append(file_name) return downloaded_files else: # if the data is cloud based bu we are not in AWS it will be downloaded as if it was on prem return self._download_onprem_granules(data_links, local_path, threads) - return None def _download_file(self, url: str, directory: str) -> str: """ @@ -625,10 +631,10 @@ def _download_file(self, url: str, directory: str) -> str: raise Exception else: print(f"File {local_filename} already downloaded") - return local_filename + return local_path def _download_onprem_granules( - self, urls: List[str], directory: Optional[str] = None, threads: int = 8 + self, urls: List[str], directory: str, threads: int = 8 ) -> List[Any]: """ downloads a list of URLS into the data directory. 
@@ -645,14 +651,10 @@ def _download_onprem_granules( "We need to be logged into NASA EDL in order to download data granules" ) return [] - if directory is None: - directory_prefix = f"./data/{datetime.datetime.today().strftime('%Y-%m-%d')}-{uuid4().hex[:6]}" - else: - directory_prefix = directory - if not os.path.exists(directory_prefix): - os.makedirs(directory_prefix) + if not os.path.exists(directory): + os.makedirs(directory) - arguments = [(url, directory_prefix) for url in urls] + arguments = [(url, directory) for url in urls] results = pqdm( arguments, self._download_file, diff --git a/tests/integration/test_api.py b/tests/integration/test_api.py index d4848d60..71745ff5 100644 --- a/tests/integration/test_api.py +++ b/tests/integration/test_api.py @@ -1,7 +1,6 @@ # package imports import logging import os -import shutil import unittest import earthaccess @@ -69,16 +68,15 @@ def test_granules_search_returns_valid_results(kwargs): assertions.assertTrue(len(results) <= 10) -def test_earthaccess_api_can_download_granules(): +@pytest.mark.parametrize("selection", [0, slice(None)]) +def test_earthaccess_api_can_download_granules(tmp_path, selection): results = earthaccess.search_data( count=2, short_name="ATL08", cloud_hosted=True, bounding_box=(-92.86, 16.26, -91.58, 16.97), ) - local_path = "./tests/integration/data/ATL08" - assertions.assertIsInstance(results, list) - assertions.assertTrue(len(results) <= 2) - files = earthaccess.download(results, local_path=local_path) + result = results[selection] + files = earthaccess.download(result, str(tmp_path)) assertions.assertIsInstance(files, list) - shutil.rmtree(local_path) + assert all(os.path.exists(f) for f in files)