diff --git a/rohmu/factory.py b/rohmu/factory.py index ee1ec7c1..ad855736 100644 --- a/rohmu/factory.py +++ b/rohmu/factory.py @@ -66,19 +66,23 @@ def get_transfer_model(storage_config: Config) -> StorageModel: return storage_class.config_model(**storage_config) -def get_transfer(storage_config: Config) -> BaseTransfer[Any]: +def get_transfer(storage_config: Config, create_if_missing: bool = True) -> BaseTransfer[Any]: storage_config = storage_config.copy() notifier_config = storage_config.pop("notifier", None) notifier = None if notifier_config is not None: notifier = get_notifier(notifier_config) model = get_transfer_model(storage_config) - return get_transfer_from_model(model, notifier) + return get_transfer_from_model(model, notifier, create_if_missing=create_if_missing) -def get_transfer_from_model(model: StorageModelT, notifier: Optional[Notifier] = None) -> BaseTransfer[StorageModelT]: +def get_transfer_from_model( + model: StorageModelT, + notifier: Optional[Notifier] = None, + create_if_missing: bool = True, +) -> BaseTransfer[StorageModelT]: storage_class = get_class_for_storage_driver(model.storage_type) - return storage_class.from_model(model, notifier) + return storage_class.from_model(model, notifier, create_if_missing=create_if_missing) def _to_storage_driver(storage_type: str) -> StorageDriver: diff --git a/rohmu/object_storage/azure.py b/rohmu/object_storage/azure.py index 3e15e4af..b1ac7992 100644 --- a/rohmu/object_storage/azure.py +++ b/rohmu/object_storage/azure.py @@ -61,9 +61,10 @@ def __init__( proxy_info: Optional[dict[str, Union[str, int]]] = None, notifier: Optional[Notifier] = None, statsd_info: Optional[StatsdConfig] = None, + create_if_missing: bool = True, ) -> None: prefix = prefix.lstrip("/") if prefix else "" - super().__init__(prefix=prefix, notifier=notifier, statsd_info=statsd_info) + super().__init__(prefix=prefix, notifier=notifier, statsd_info=statsd_info, create_if_missing=create_if_missing) if not account_key and not 
sas_token: raise InvalidConfigurationError("One of account_key or sas_token must be specified to authenticate") @@ -409,6 +410,7 @@ def get_or_create_container(self, container_name: str) -> str: container_name = container_name.value start_time = time.monotonic() try: + # TODO respect the self.create_if_missing flag self.get_blob_service_client().create_container(container_name) except ResourceExistsError: pass diff --git a/rohmu/object_storage/base.py b/rohmu/object_storage/base.py index 3a18eb73..6d5cfe37 100644 --- a/rohmu/object_storage/base.py +++ b/rohmu/object_storage/base.py @@ -69,9 +69,15 @@ class BaseTransfer(Generic[StorageModelT]): is_thread_safe: bool = False supports_concurrent_upload: bool = False + # Set to true if the storage is ready to be used (e.g. the bucket is created and access has been verified) + _initialized: bool = False def __init__( - self, prefix: Optional[str], notifier: Optional[Notifier] = None, statsd_info: Optional[StatsdConfig] = None + self, + prefix: Optional[str], + notifier: Optional[Notifier] = None, + statsd_info: Optional[StatsdConfig] = None, + create_if_missing: bool = True, ) -> None: self.log = logging.getLogger(self.__class__.__name__) if not prefix: @@ -81,6 +87,12 @@ def __init__( self.prefix = prefix self.notifier = notifier or NullNotifier() self.stats = StatsClient(statsd_info) + self.create_if_missing = create_if_missing + self._initialized = True + + @property + def initialized(self): + return self._initialized def close(self) -> None: """Release all resources associated with the Transfer object.""" @@ -138,8 +150,10 @@ def _should_multipart( return int(size) > chunk_size @classmethod - def from_model(cls, model: StorageModelT, notifier: Optional[Notifier] = None) -> Self: - return cls(**model.dict(by_alias=True, exclude={"storage_type"}), notifier=notifier) + def from_model(cls, model: StorageModelT, notifier: Optional[Notifier] = None, create_if_missing: bool = True) -> Self: + return cls( + 
**model.dict(by_alias=True, exclude={"storage_type"}), notifier=notifier, create_if_missing=create_if_missing + ) def copy_file( self, *, source_key: str, destination_key: str, metadata: Optional[Metadata] = None, **_kwargs: Any diff --git a/rohmu/object_storage/google.py b/rohmu/object_storage/google.py index 81643a2e..91978a06 100644 --- a/rohmu/object_storage/google.py +++ b/rohmu/object_storage/google.py @@ -6,7 +6,8 @@ from __future__ import annotations from contextlib import contextmanager -from googleapiclient.discovery import build, Resource +from googleapiclient._apis.storage.v1 import StorageResource +from googleapiclient.discovery import build from googleapiclient.errors import HttpError from googleapiclient.http import ( build_http, @@ -182,12 +183,13 @@ def __init__( proxy_info: Optional[dict[str, Union[str, int]]] = None, notifier: Optional[Notifier] = None, statsd_info: Optional[StatsdConfig] = None, + create_if_missing: bool = True, ) -> None: - super().__init__(prefix=prefix, notifier=notifier, statsd_info=statsd_info) + super().__init__(prefix=prefix, notifier=notifier, statsd_info=statsd_info, create_if_missing=create_if_missing) self.project_id = project_id self.proxy_info = proxy_info self.google_creds = get_credentials(credential_file=credential_file, credentials=credentials) - self.gs: Optional[Resource] = self._init_google_client() + self.gs: Optional[StorageResource] = self._init_google_client() self.gs_object_client: Any = None self.bucket_name = self.get_or_create_bucket(bucket_name) self.log.debug("GoogleTransfer initialized") @@ -200,7 +202,7 @@ def close(self) -> None: self.gs.close() self.gs = None - def _init_google_client(self) -> Resource: + def _init_google_client(self) -> StorageResource: start_time = time.monotonic() delay = 2 while True: @@ -598,16 +600,16 @@ def get_or_create_bucket(self, bucket_name: str) -> str: invalid bucket names ("Invalid bucket name") as well as for invalid project ("Invalid argument"), try to handle 
both gracefully.""" start_time = time.time() - gs_buckets = self.gs.buckets() # type: ignore[union-attr] + gs = self.gs + if gs is None: + raise RuntimeError("This method should not be called after closing the transfer") + gs_buckets = gs.buckets() try: - request = gs_buckets.get(bucket=bucket_name) - reporter = Reporter(StorageOperation.head_request) - self._retry_on_reset(request, request.execute, retry_reporter=reporter) - reporter.report(self.stats) + self._try_get_bucket(bucket_name, gs_buckets) self.log.debug("Bucket: %r already exists, took: %.3fs", bucket_name, time.time() - start_time) except HttpError as ex: if ex.resp["status"] == "404": - pass # we need to create it + pass # we may need to create it, depending on the create_if_missing flag elif ex.resp["status"] == "403": raise InvalidConfigurationError(f"Bucket {repr(bucket_name)} exists but isn't accessible") else: @@ -615,11 +617,13 @@ def get_or_create_bucket(self, bucket_name: str) -> str: else: return bucket_name + if not self.create_if_missing: + # Mark the object as un-initialized so we don't attempt to use it for transfers when we did not create the bucket + self._initialized = False + return bucket_name + try: - req = gs_buckets.insert(project=self.project_id, body={"name": bucket_name}) - reporter = Reporter(StorageOperation.create_bucket) - self._retry_on_reset(req, req.execute, retry_reporter=reporter) - reporter.report(self.stats) + self._try_create_bucket(bucket_name, gs_buckets) self.log.debug("Created bucket: %r successfully, took: %.3fs", bucket_name, time.time() - start_time) except HttpError as ex: error = json.loads(ex.content.decode("utf-8"))["error"] @@ -634,6 +638,18 @@ def get_or_create_bucket(self, bucket_name: str) -> str: return bucket_name + def _try_get_bucket(self, bucket_name: str, gs_buckets: StorageResource.BucketsResource) -> None: + request = gs_buckets.get(bucket=bucket_name) + reporter = Reporter(StorageOperation.head_request) + self._retry_on_reset(request,
request.execute, retry_reporter=reporter) + reporter.report(self.stats) + + def _try_create_bucket(self, bucket_name: str, gs_buckets: StorageResource.BucketsResource) -> None: + req = gs_buckets.insert(project=self.project_id, body={"name": bucket_name}) + reporter = Reporter(StorageOperation.create_bucket) + self._retry_on_reset(req, req.execute, retry_reporter=reporter) + reporter.report(self.stats) + class MediaStreamUpload(MediaUpload): """Support streaming arbitrary amount of data from non-seekable object supporting read method.""" diff --git a/rohmu/object_storage/local.py b/rohmu/object_storage/local.py index 7c8695a5..a9a6f235 100644 --- a/rohmu/object_storage/local.py +++ b/rohmu/object_storage/local.py @@ -50,9 +50,10 @@ def __init__( prefix: Optional[str] = None, notifier: Optional[Notifier] = None, statsd_info: Optional[StatsdConfig] = None, + create_if_missing: bool = True, ) -> None: prefix = os.path.join(directory, (prefix or "").strip("/")) - super().__init__(prefix=prefix, notifier=notifier, statsd_info=statsd_info) + super().__init__(prefix=prefix, notifier=notifier, statsd_info=statsd_info, create_if_missing=create_if_missing) self.log.debug("LocalTransfer initialized") def copy_file( diff --git a/rohmu/object_storage/s3.py b/rohmu/object_storage/s3.py index 04ee1073..ae46fdbd 100644 --- a/rohmu/object_storage/s3.py +++ b/rohmu/object_storage/s3.py @@ -124,8 +124,9 @@ def __init__( aws_session_token: Optional[str] = None, use_dualstack_endpoint: Optional[bool] = True, statsd_info: Optional[StatsdConfig] = None, + create_if_missing: bool = True, ) -> None: - super().__init__(prefix=prefix, notifier=notifier, statsd_info=statsd_info) + super().__init__(prefix=prefix, notifier=notifier, statsd_info=statsd_info, create_if_missing=create_if_missing) self.bucket_name = bucket_name self.region = region self.aws_access_key_id = aws_access_key_id @@ -605,6 +606,10 @@ def check_or_create_bucket(self) -> None: raise if create_bucket: + if not 
self.create_if_missing: + self._initialized = False + return + self.log.debug("Creating bucket: %r in location: %r", self.bucket_name, self.region) args: dict[str, Any] = { "Bucket": self.bucket_name, diff --git a/rohmu/object_storage/sftp.py b/rohmu/object_storage/sftp.py index 97cdc0ad..89ab5f2a 100644 --- a/rohmu/object_storage/sftp.py +++ b/rohmu/object_storage/sftp.py @@ -41,8 +41,9 @@ def __init__( prefix: Optional[str] = None, notifier: Optional[Notifier] = None, statsd_info: Optional[StatsdConfig] = None, + create_if_missing: bool = True, ) -> None: - super().__init__(prefix=prefix, notifier=notifier, statsd_info=statsd_info) + super().__init__(prefix=prefix, notifier=notifier, statsd_info=statsd_info, create_if_missing=create_if_missing) self.server = server self.port = port self.username = username diff --git a/rohmu/object_storage/swift.py b/rohmu/object_storage/swift.py index dfa6650d..6661932d 100644 --- a/rohmu/object_storage/swift.py +++ b/rohmu/object_storage/swift.py @@ -80,9 +80,10 @@ def __init__( endpoint_type: Optional[str] = None, notifier: Optional[Notifier] = None, statsd_info: Optional[StatsdConfig] = None, + create_if_missing: bool = True, ) -> None: prefix = prefix.lstrip("/") if prefix else "" - super().__init__(prefix=prefix, notifier=notifier, statsd_info=statsd_info) + super().__init__(prefix=prefix, notifier=notifier, statsd_info=statsd_info, create_if_missing=create_if_missing) self.container_name = container_name if auth_version == "3.0": diff --git a/test/object_storage/test_google.py b/test/object_storage/test_google.py index 1b57754e..6aed02cd 100644 --- a/test/object_storage/test_google.py +++ b/test/object_storage/test_google.py @@ -13,6 +13,8 @@ from unittest.mock import ANY, call, MagicMock, Mock, patch import base64 +import googleapiclient.errors +import httplib2 import pytest @@ -36,6 +38,84 @@ def test_close() -> None: assert transfer.gs is None +def _mock_403_response_from_google_api() -> Exception: + resp = 
httplib2.Response({"status": "403", "reason": "Unused"}) + uri = "https://storage.googleapis.com/storage/v1/b?project=project&alt=json" + content = ( + b'{\n "error": {\n "code": 403,\n "message": "account@project.iam.gserviceaccount.com does not have stor' + b"age.buckets.create access to the Google Cloud project. Permission 'storage.buckets.create' denied on resource " + b'(or it may not exist).",\n "errors": [\n {\n "message": "account@project.iam.gserviceaccount.com ' + b"does not have storage.buckets.create access to the Google Cloud project. Permission 'storage.buckets.create' " + b'denied on resource (or it may not exist).",\n "domain": "global",\n "reason": "forbidden"' + b"\n }\n ]\n }\n}\n" + ) + return googleapiclient.errors.HttpError(resp, content, uri) + + +def _mock_404_response_from_google_api() -> Exception: + resp = httplib2.Response({"status": "404", "reason": "Unused"}) + uri = "https://storage.googleapis.com/storage/v1/b?project=project&alt=json" + content = b"""{"error": {"code": 404, "message": "Does not matter"}}""" + return googleapiclient.errors.HttpError(resp, content, uri) + + +@pytest.mark.parametrize( + "create_if_missing,bucket_exists,sabotage_create,expect_create_call", + [ + # Happy path + pytest.param(True, True, False, False, id="happy-path-exists"), + pytest.param(True, False, False, True, id="happy-path-not-exists"), + # Happy path - without attempting to create buckets + pytest.param(False, True, False, False, id="no-create-exists"), + pytest.param(False, False, False, False, id="no-create-not-exists"), + # 403 failures when trying to create should not matter with create_if_missing=False + pytest.param(False, False, True, False, id="error-behaviour"), + # 403 failures when trying to create should crash with create_if_missing=True + pytest.param(True, False, True, True, id="graceful-403-handling"), + ], +) +def test_handle_missing_bucket( + create_if_missing: bool, bucket_exists: bool, sabotage_create: bool, expect_create_call:
bool +) -> None: + with ExitStack() as stack: + stack.enter_context(patch("rohmu.object_storage.google.get_credentials")) + + _try_get_bucket = stack.enter_context(patch("rohmu.object_storage.google.GoogleTransfer._try_get_bucket")) + if not bucket_exists: + # If the bucket exists, the return value is ignored. This simulates a missing bucket. + _try_get_bucket.side_effect = _mock_404_response_from_google_api() + + _try_create_bucket = stack.enter_context(patch("rohmu.object_storage.google.GoogleTransfer._try_create_bucket")) + if sabotage_create: + _try_create_bucket.side_effect = _mock_403_response_from_google_api() + + if expect_create_call and sabotage_create: + with pytest.raises(googleapiclient.errors.HttpError): + _ = GoogleTransfer( + project_id="test-project-id", + bucket_name="test-bucket", + create_if_missing=create_if_missing, + ) + else: + transfer = GoogleTransfer( + project_id="test-project-id", + bucket_name="test-bucket", + create_if_missing=create_if_missing, + ) + if bucket_exists or expect_create_call: + # Either the bucket already exists, or we wanted to create it and creation was not sabotaged with a 403 + assert transfer.initialized is True + else: + # The bucket is missing, we don't want to create it and/or the code path for a 403 was not exercised at all + assert transfer.initialized is False + + _try_get_bucket.assert_called_once() + if expect_create_call: + _try_create_bucket.assert_called_once() + else: + _try_create_bucket.assert_not_called() + + def test_store_file_from_memory() -> None: notifier = MagicMock() with ExitStack() as stack: