diff --git a/src/toil/jobStores/aws/jobStore.py b/src/toil/jobStores/aws/jobStore.py index 74958a2d77..6a4d40eac9 100644 --- a/src/toil/jobStores/aws/jobStore.py +++ b/src/toil/jobStores/aws/jobStore.py @@ -64,6 +64,7 @@ from toil.jobStores.utils import ReadablePipe, ReadableTransformingPipe, WritablePipe from toil.lib.aws import build_tag_dict_from_env from toil.lib.aws.session import establish_boto3_session +from toil.lib.aws.s3 import head_s3_object from toil.lib.aws.utils import ( NoBucketLocationError, boto3_pager, @@ -1418,15 +1419,11 @@ def readFrom(self, readable): if info.version is None: # Somehow we don't know the version. Try and get it. - for attempt in retry_s3(predicate=lambda e: retryable_s3_errors(e) or isinstance(e, AssertionError)): - with attempt: - version = client.head_object(Bucket=bucket_name, - Key=compat_bytes(info.fileID), - **headerArgs).get('VersionId', None) - logger.warning('Loaded key for upload with no version and got version %s', - str(version)) - info.version = version - assert info.version is not None + info.version = head_s3_object( + bucket=bucket_name, + key=compat_bytes(info.fileID), + header=headerArgs + ).get('VersionId', None) # Make sure we actually wrote something, even if an empty file assert (bool(info.version) or info.content is not None) @@ -1467,23 +1464,21 @@ def readFrom(self, readable): ExtraArgs=headerArgs) # use head_object with the SSE headers to access versionId and content_length attributes - headObj = client.head_object(Bucket=bucket_name, - Key=compat_bytes(info.fileID), - **headerArgs) - assert dataLength == headObj.get('ContentLength', None) - info.version = headObj.get('VersionId', None) + resp = head_s3_object( + bucket=bucket_name, + key=compat_bytes(info.fileID), + header=headerArgs + ) + assert dataLength == resp.get('ContentLength', None) + info.version = resp.get('VersionId', None) logger.debug('Upload received version %s', str(info.version)) if info.version is None: # Somehow we don't know the version - 
for attempt in retry_s3(predicate=lambda e: retryable_s3_errors(e) or isinstance(e, AssertionError)): - with attempt: - headObj = client.head_object(Bucket=bucket_name, - Key=compat_bytes(info.fileID), - **headerArgs) - info.version = headObj.get('VersionId', None) - logger.warning('Reloaded key with no version and got version %s', str(info.version)) - assert info.version is not None + resp = head_s3_object(bucket=bucket_name, key=compat_bytes(info.fileID), header=headerArgs) + info.version = resp.get('VersionId', None) + logger.warning('Reloaded key with no version and got version %s', str(info.version)) + assert info.version is not None # Make sure we actually wrote something, even if an empty file assert (bool(info.version) or info.content is not None) diff --git a/src/toil/lib/aws/s3.py b/src/toil/lib/aws/s3.py index 77cb94d56e..78cc5cc461 100644 --- a/src/toil/lib/aws/s3.py +++ b/src/toil/lib/aws/s3.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import List +from typing import Dict, Any, Optional, List -from mypy_boto3_s3.type_defs import ListMultipartUploadsOutputTypeDef +from mypy_boto3_s3.type_defs import ListMultipartUploadsOutputTypeDef, HeadObjectOutputTypeDef from toil.lib.aws import session, AWSServerErrors from toil.lib.retry import retry @@ -23,6 +23,21 @@  @retry(errors=[AWSServerErrors]) +def head_s3_object(bucket: str, key: str, header: Dict[str, Any], region: Optional[str] = None) -> HeadObjectOutputTypeDef: +    """ +    Attempt to HEAD an s3 object and return its response. + +    :param bucket: AWS bucket name +    :param key: AWS Key name for the s3 object +    :param header: Headers to include (mostly for encryption). 
+    See: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3/client/head_object.html +    :param region: Region that we want to look for the bucket in +    """ +    s3_client = session.client("s3", region_name=region) +    return s3_client.head_object(Bucket=bucket, Key=key, **header) + + +@retry(errors=[AWSServerErrors]) def list_multipart_uploads(bucket: str, region: str, prefix: str, max_uploads: int = 1) -> ListMultipartUploadsOutputTypeDef: s3_client = session.client("s3", region_name=region) return s3_client.list_multipart_uploads(Bucket=bucket, MaxUploads=max_uploads, Prefix=prefix)