Skip to content

Commit

Permalink
Merge pull request #317 from jrbourbeau/download-minor
Browse files Browse the repository at this point in the history
Minor `earthaccess.download` updates
  • Loading branch information
betolink authored Oct 13, 2023
2 parents 25a38ea + bbe0aff commit 250848d
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 28 deletions.
9 changes: 6 additions & 3 deletions earthaccess/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from fsspec import AbstractFileSystem

from .auth import Auth
from .results import DataGranule
from .search import CollectionQuery, DataCollections, DataGranules, GranuleQuery
from .store import Store
from .utils import _validation as validate
Expand Down Expand Up @@ -150,8 +151,8 @@ def login(strategy: str = "all", persist: bool = False) -> Auth:


def download(
granules: Union[List[earthaccess.results.DataGranule], List[str]],
local_path: Optional[str],
granules: Union[DataGranule, List[DataGranule], List[str]],
local_path: Union[str, None],
provider: Optional[str] = None,
threads: int = 8,
) -> List[str]:
Expand All @@ -161,14 +162,16 @@ def download(
    * If we run it outside AWS (us-west-2 region) and the dataset is cloud hosted we'll use HTTP links
Parameters:
granules: a list of granules(DataGranule) instances or a list of granule links (HTTP)
granules: a granule, list of granules, or a list of granule links (HTTP)
local_path: local directory to store the remote data granules
provider: if we download a list of URLs we need to specify the provider.
threads: parallel number of threads to use to download the files, adjust as necessary, default = 8
Returns:
List of downloaded files
"""
if isinstance(granules, DataGranule):
granules = [granules]
try:
results = earthaccess.__store__.get(granules, local_path, provider, threads)
except AttributeError as err:
Expand Down
36 changes: 18 additions & 18 deletions earthaccess/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,12 @@ def get(
Returns:
List of downloaded files
"""
if local_path is None:
local_path = os.path.join(
".",
"data",
f"{datetime.datetime.today().strftime('%Y-%m-%d')}-{uuid4().hex[:6]}",
)
if len(granules):
files = self._get(granules, local_path, provider, threads)
return files
Expand All @@ -464,7 +470,7 @@ def get(
def _get(
self,
granules: Union[List[DataGranule], List[str]],
local_path: Optional[str] = None,
local_path: str,
provider: Optional[str] = None,
threads: int = 8,
) -> Union[None, List[str]]:
Expand Down Expand Up @@ -492,7 +498,7 @@ def _get(
def _get_urls(
self,
granules: List[str],
local_path: Optional[str] = None,
local_path: str,
provider: Optional[str] = None,
threads: int = 8,
) -> Union[None, List[str]]:
Expand All @@ -509,22 +515,21 @@ def _get_urls(
s3_fs = self.get_s3fs_session(provider=provider)
# TODO: make this parallel or concurrent
for file in data_links:
file_name = file.split("/")[-1]
s3_fs.get(file, local_path)
print(f"Retrieved: {file} to {local_path}")
file_name = os.path.join(local_path, os.path.basename(file))
print(f"Downloaded: {file_name}")
downloaded_files.append(file_name)
return downloaded_files

else:
# if we are not in AWS
return self._download_onprem_granules(data_links, local_path, threads)
return None

@_get.register
def _get_granules(
self,
granules: List[DataGranule],
local_path: Optional[str] = None,
local_path: str,
provider: Optional[str] = None,
threads: int = 8,
) -> Union[None, List[str]]:
Expand Down Expand Up @@ -557,14 +562,13 @@ def _get_granules(
# TODO: make this async
for file in data_links:
s3_fs.get(file, local_path)
file_name = file.split("/")[-1]
print(f"Retrieved: {file} to {local_path}")
file_name = os.path.join(local_path, os.path.basename(file))
print(f"Downloaded: {file_name}")
downloaded_files.append(file_name)
return downloaded_files
else:
            # if the data is cloud based but we are not in AWS it will be downloaded as if it was on prem
return self._download_onprem_granules(data_links, local_path, threads)
return None

def _download_file(self, url: str, directory: str) -> str:
"""
Expand Down Expand Up @@ -598,10 +602,10 @@ def _download_file(self, url: str, directory: str) -> str:
raise Exception
else:
print(f"File {local_filename} already downloaded")
return local_filename
return local_path

def _download_onprem_granules(
self, urls: List[str], directory: Optional[str] = None, threads: int = 8
self, urls: List[str], directory: str, threads: int = 8
) -> List[Any]:
"""
downloads a list of URLS into the data directory.
Expand All @@ -618,14 +622,10 @@ def _download_onprem_granules(
"We need to be logged into NASA EDL in order to download data granules"
)
return []
if directory is None:
directory_prefix = f"./data/{datetime.datetime.today().strftime('%Y-%m-%d')}-{uuid4().hex[:6]}"
else:
directory_prefix = directory
if not os.path.exists(directory_prefix):
os.makedirs(directory_prefix)
if not os.path.exists(directory):
os.makedirs(directory)

arguments = [(url, directory_prefix) for url in urls]
arguments = [(url, directory) for url in urls]
results = pqdm(
arguments,
self._download_file,
Expand Down
12 changes: 5 additions & 7 deletions tests/integration/test_api.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# package imports
import logging
import os
import shutil
import unittest

import earthaccess
Expand Down Expand Up @@ -69,16 +68,15 @@ def test_granules_search_returns_valid_results(kwargs):
assertions.assertTrue(len(results) <= 10)


def test_earthaccess_api_can_download_granules():
@pytest.mark.parametrize("selection", [0, slice(None)])
def test_earthaccess_api_can_download_granules(tmp_path, selection):
results = earthaccess.search_data(
count=2,
short_name="ATL08",
cloud_hosted=True,
bounding_box=(-92.86, 16.26, -91.58, 16.97),
)
local_path = "./tests/integration/data/ATL08"
assertions.assertIsInstance(results, list)
assertions.assertTrue(len(results) <= 2)
files = earthaccess.download(results, local_path=local_path)
result = results[selection]
files = earthaccess.download(result, str(tmp_path))
assertions.assertIsInstance(files, list)
shutil.rmtree(local_path)
assert all(os.path.exists(f) for f in files)

0 comments on commit 250848d

Please sign in to comment.