Skip to content

Commit

Permalink
Handle S3 credential expiration more gracefully (#354)
Browse files (browse the repository at this point in the history)
  • Loading branch information
jrbourbeau authored Nov 30, 2023
1 parent d6fe7f8 commit cbc3bc4
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 46 deletions.
84 changes: 49 additions & 35 deletions earthaccess/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@
import os
import shutil
import traceback
from copy import deepcopy
from functools import lru_cache
from itertools import chain
from pathlib import Path
from pickle import dumps, loads
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Tuple, Union
from uuid import uuid4

import fsspec
Expand Down Expand Up @@ -98,8 +97,9 @@ def __init__(self, auth: Any, pre_authorize: bool = False) -> None:
"""
if auth.authenticated is True:
self.auth = auth
self.s3_fs = None
self.initial_ts = datetime.datetime.now()
self._s3_credentials: Dict[
Tuple, Tuple[datetime.datetime, Dict[str, str]]
] = {}
oauth_profile = "https://urs.earthdata.nasa.gov/profile"
# sets the initial URS cookie
self._requests_cookies: Dict[str, Any] = {}
Expand Down Expand Up @@ -183,7 +183,6 @@ def set_requests_session(
elif resp.status_code >= 500:
resp.raise_for_status()

@lru_cache
def get_s3fs_session(
self,
daac: Optional[str] = None,
Expand All @@ -201,39 +200,54 @@ def get_s3fs_session(
Returns:
an s3fs file system instance
"""
if self.auth is not None:
if not any([concept_id, daac, provider, endpoint]):
raise ValueError(
"At least one of the concept_id, daac, provider or endpoint"
"parameters must be specified. "
)
if endpoint is not None:
s3_credentials = self.auth.get_s3_credentials(endpoint=endpoint)
elif concept_id is not None:
provider = self._derive_concept_provider(concept_id)
s3_credentials = self.auth.get_s3_credentials(provider=provider)
elif daac is not None:
s3_credentials = self.auth.get_s3_credentials(daac=daac)
elif provider is not None:
s3_credentials = self.auth.get_s3_credentials(provider=provider)
now = datetime.datetime.now()
delta_minutes = now - self.initial_ts
# TODO: test this by mocking the time, or use https://github.com/dbader/schedule
# if we exceed 1 hour
if (
self.s3_fs is None or round(delta_minutes.seconds / 60, 2) > 59
) and s3_credentials is not None:
self.s3_fs = s3fs.S3FileSystem(
key=s3_credentials["accessKeyId"],
secret=s3_credentials["secretAccessKey"],
token=s3_credentials["sessionToken"],
)
self.initial_ts = datetime.datetime.now()
return deepcopy(self.s3_fs)
else:
if self.auth is None:
raise ValueError(
"A valid Earthdata login instance is required to retrieve S3 credentials"
)
if not any([concept_id, daac, provider, endpoint]):
raise ValueError(
"At least one of the concept_id, daac, provider or endpoint"
"parameters must be specified. "
)

if concept_id is not None:
provider = self._derive_concept_provider(concept_id)

# Get existing S3 credentials if we already have them
location = (
daac,
provider,
endpoint,
) # Identifier for where to get S3 credentials from
need_new_creds = False
try:
dt_init, creds = self._s3_credentials[location]
except KeyError:
need_new_creds = True
else:
# If cached credentials are expired, invalidate the cache
delta = datetime.datetime.now() - dt_init
if round(delta.seconds / 60, 2) > 55:
need_new_creds = True
self._s3_credentials.pop(location)

if need_new_creds:
# Don't have existing valid S3 credentials, so get new ones
now = datetime.datetime.now()
if endpoint is not None:
creds = self.auth.get_s3_credentials(endpoint=endpoint)
elif daac is not None:
creds = self.auth.get_s3_credentials(daac=daac)
elif provider is not None:
creds = self.auth.get_s3_credentials(provider=provider)
# Include new credentials in the cache
self._s3_credentials[location] = now, creds

return s3fs.S3FileSystem(
key=creds["accessKeyId"],
secret=creds["secretAccessKey"],
token=creds["sessionToken"],
)

@lru_cache
def get_fsspec_session(self) -> fsspec.AbstractFileSystem:
Expand Down
32 changes: 21 additions & 11 deletions tests/unit/test_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import fsspec
import pytest
import responses
import s3fs
from earthaccess import Auth, Store


Expand Down Expand Up @@ -60,12 +61,22 @@ def test_store_can_create_s3_fsspec_session(self):
"https://api.giovanni.earthdata.nasa.gov/s3credentials",
"https://data.laadsdaac.earthdatacloud.nasa.gov/s3credentials",
]
mock_creds = {
"accessKeyId": "sure",
"secretAccessKey": "correct",
"sessionToken": "whynot",
}
expected_storage_options = {
"key": mock_creds["accessKeyId"],
"secret": mock_creds["secretAccessKey"],
"token": mock_creds["sessionToken"],
}

for endpoint in custom_endpoints:
responses.add(
responses.GET,
endpoint,
json={},
json=mock_creds,
status=200,
)

Expand All @@ -74,40 +85,39 @@ def test_store_can_create_s3_fsspec_session(self):
responses.add(
responses.GET,
daac["s3-credentials"],
json={
"accessKeyId": "sure",
"secretAccessKey": "correct",
"sessionToken": "whynot",
},
json=mock_creds,
status=200,
)
responses.add(
responses.GET,
"https://urs.earthdata.nasa.gov/profile",
json={},
json=mock_creds,
status=200,
)

store = Store(self.auth)
self.assertTrue(isinstance(store.auth, Auth))
for daac in ["NSIDC", "PODAAC", "LPDAAC", "ORNLDAAC", "GES_DISC", "ASF"]:
s3_fs = store.get_s3fs_session(daac=daac)
self.assertEqual(type(s3_fs), type(fsspec.filesystem("s3")))
assert isinstance(s3_fs, s3fs.S3FileSystem)
assert s3_fs.storage_options == expected_storage_options

for endpoint in custom_endpoints:
s3_fs = store.get_s3fs_session(endpoint=endpoint)
self.assertEqual(type(s3_fs), type(fsspec.filesystem("s3")))
assert isinstance(s3_fs, s3fs.S3FileSystem)
assert s3_fs.storage_options == expected_storage_options

for provider in [
"NSIDC_CPRD",
"POCLOUD",
"LPCLOUD",
"ORNLCLOUD",
"ORNL_CLOUD",
"GES_DISC",
"ASF",
]:
s3_fs = store.get_s3fs_session(provider=provider)
assert isinstance(s3_fs, fsspec.AbstractFileSystem)
assert isinstance(s3_fs, s3fs.S3FileSystem)
assert s3_fs.storage_options == expected_storage_options

# Ensure informative error is raised
with pytest.raises(ValueError, match="parameters must be specified"):
Expand Down

0 comments on commit cbc3bc4

Please sign in to comment.