From 71b15d0ad268382111f9a13fcdf604ba9ec394a1 Mon Sep 17 00:00:00 2001 From: Rob Lim Date: Thu, 15 Jun 2023 11:30:24 -0700 Subject: [PATCH] [BUGFIX] Fix GXCloudStoreBackend updates by name (#8116) Co-authored-by: Gabriel --- .../store/gx_cloud_store_backend.py | 70 +++++++-- .../test_checkpoint_crud.py | 135 ++++++++++++++++-- 2 files changed, 180 insertions(+), 25 deletions(-) diff --git a/great_expectations/data_context/store/gx_cloud_store_backend.py b/great_expectations/data_context/store/gx_cloud_store_backend.py index 2d3ffad8b7a4..c9bd279ff841 100644 --- a/great_expectations/data_context/store/gx_cloud_store_backend.py +++ b/great_expectations/data_context/store/gx_cloud_store_backend.py @@ -44,7 +44,7 @@ class PayloadDataField(TypedDict): class ResponsePayload(TypedDict): - data: PayloadDataField + data: PayloadDataField | list[PayloadDataField] AnyPayload = Union[ResponsePayload, ErrorPayload] @@ -222,7 +222,7 @@ def __init__( # noqa: PLR0913 } filter_properties_dict(properties=self._config, inplace=True) - def _get(self, key: Tuple[str, ...]) -> ResponsePayload: # type: ignore[override] + def _get(self, key: Tuple[GXCloudRESTResource, str | None, str | None]) -> ResponsePayload: # type: ignore[override] ge_cloud_url = self.get_url_for_key(key=key) params: Optional[dict] = None try: @@ -387,7 +387,9 @@ def _post(self, value: Any, **kwargs) -> GXCloudResourceRef: response_json = response.json() object_id = response_json["data"]["id"] - object_url = self.get_url_for_key((self.ge_cloud_resource_type, object_id)) + object_url = self.get_url_for_key( + (self.ge_cloud_resource_type, object_id, None) + ) # This method is where posts get made for all cloud store endpoints. We pass # the response_json back up to the caller because the specific resource may # want to parse resource specific data out of the response. 
@@ -471,7 +473,9 @@ def list_keys(self, prefix: Tuple = ()) -> List[Tuple[GXCloudRESTResource, str, ) def get_url_for_key( # type: ignore[override] - self, key: Tuple[str, ...], protocol: Optional[Any] = None + self, + key: Tuple[GXCloudRESTResource, str | None, str | None], + protocol: Optional[Any] = None, ) -> str: id = key[1] url = construct_url( @@ -542,26 +546,70 @@ def remove_key(self, key): f"Unable to delete object in GX Cloud Store Backend: {repr(e)}" ) - def _update(self, key, value, **kwargs): - existing = self._get(key) + def _get_one_or_none_from_response_data( + self, + response_data: list[PayloadDataField] | PayloadDataField, + key: tuple[GXCloudRESTResource, str | None, str | None], + ) -> PayloadDataField | None: + """ + GET requests to cloud can either return response data that is a single object (get by id) or a + list of objects with length >= 0 (get by name). This method takes this response data and returns a single + object or None. + """ + if not isinstance(response_data, list): + return response_data + if len(response_data) == 0: + return None + if len(response_data) == 1: + return response_data[0] + raise StoreBackendError( + f"Unable to update object in GX Cloud Store Backend: the provided key ({key}) maps " + f"to more than one object." + ) + + def _update( + self, + key: tuple[GXCloudRESTResource, str | None, str | None], + value: dict, + **kwargs, + ): + response_data = self._get(key)["data"] + # if the provided key does not contain id (only name), cloud will return a list of resources filtered + # by name, with length >= 0, instead of a single object (or error if not found) + existing = self._get_one_or_none_from_response_data( + response_data=response_data, key=key + ) + + if existing is None: + raise StoreBackendError( + f"Unable to update object in GX Cloud Store Backend: could not find object associated with key {key}." 
+ ) + if key[1] is None: - key = (key[0], existing["data"]["id"], key[2]) + key = (key[0], existing["id"], key[2]) return self.set(key=key, value=value, **kwargs) def _add_or_update(self, key, value, **kwargs): try: - existing = self._get(key) + response_data = self._get(key)["data"] except StoreBackendError as e: logger.info(f"Could not find object associated with key {key}: {e}") - existing = None + response_data = None + + # if the provided key does not contain id (only name), cloud will return a list of resources filtered + # by name, with length >= 0, instead of a single object (or error if not found) + existing = self._get_one_or_none_from_response_data( + response_data=response_data, key=key + ) + if existing is not None: - id = key[1] if key[1] is not None else existing["data"]["id"] + id = key[1] if key[1] is not None else existing["id"] key = (key[0], id, key[2]) return self.set(key=key, value=value, **kwargs) return self.add(key=key, value=value, **kwargs) - def _has_key(self, key: Tuple[str, ...]) -> bool: + def _has_key(self, key: Tuple[GXCloudRESTResource, str | None, str | None]) -> bool: try: _ = self._get(key) except StoreBackendTransientError: diff --git a/tests/data_context/cloud_data_context/test_checkpoint_crud.py b/tests/data_context/cloud_data_context/test_checkpoint_crud.py index 263c51c1343f..a6818ad8b7e4 100644 --- a/tests/data_context/cloud_data_context/test_checkpoint_crud.py +++ b/tests/data_context/cloud_data_context/test_checkpoint_crud.py @@ -19,7 +19,7 @@ checkpointConfigSchema, ) from great_expectations.data_context.types.resource_identifiers import GXCloudIdentifier -from great_expectations.exceptions import StoreBackendError +from great_expectations.exceptions import CheckpointNotFoundError, StoreBackendError from great_expectations.util import get_context from tests.data_context.conftest import MockResponse @@ -137,7 +137,7 @@ def _mocked_get_response(*args, **kwargs): @pytest.fixture -def mocked_get_by_name_response( +def 
mocked_get_by_name_response_1_result( mock_response_factory: Callable, checkpoint_config_with_ids: dict, checkpoint_id: str, @@ -173,6 +173,21 @@ def _mocked_get_by_name_response(*args, **kwargs): return _mocked_get_by_name_response +@pytest.fixture +def mocked_get_by_name_response_0_results( + mock_response_factory: Callable, +) -> Callable[[], MockResponse]: + def _mocked_get_by_name_response(*args, **kwargs): + return mock_response_factory( + { + "data": [], + }, + 200, + ) + + return _mocked_get_by_name_response + + @pytest.mark.cloud @pytest.mark.integration def test_cloud_backed_data_context_get_checkpoint_by_name( @@ -180,7 +195,7 @@ def test_cloud_backed_data_context_get_checkpoint_by_name( checkpoint_id: str, validation_ids: Tuple[str, str], checkpoint_config: dict, - mocked_get_by_name_response: Callable[[], MockResponse], + mocked_get_by_name_response_1_result: Callable[[], MockResponse], ge_cloud_base_url: str, ge_cloud_organization_id: str, ) -> None: @@ -193,7 +208,9 @@ def test_cloud_backed_data_context_get_checkpoint_by_name( validation_id_1, validation_id_2 = validation_ids with mock.patch( - "requests.Session.get", autospec=True, side_effect=mocked_get_by_name_response + "requests.Session.get", + autospec=True, + side_effect=mocked_get_by_name_response_1_result, ) as mock_get: checkpoint = context.get_checkpoint(name=checkpoint_config["name"]) @@ -250,7 +267,7 @@ def test_cloud_backed_data_context_add_checkpoint( "great_expectations.data_context.store.GXCloudStoreBackend._has_key", autospec=True, return_value=False, - ) as _: + ): checkpoint = context.add_checkpoint(**checkpoint_config) # Round trip through schema to mimic updates made during store serialization process @@ -365,7 +382,65 @@ def test_cloud_backed_data_context_add_or_update_checkpoint_adds( "great_expectations.data_context.store.GXCloudStoreBackend._get", autospec=True, side_effect=StoreBackendError("Does not exist."), - ) as _: + ): + checkpoint = 
context.add_or_update_checkpoint(**checkpoint_config) + + # Round trip through schema to mimic updates made during store serialization process + expected_checkpoint_config = checkpointConfigSchema.dump( + CheckpointConfig(**checkpoint_config) + ) + + mock_post.assert_called_with( + mock.ANY, # requests.Session object + f"{ge_cloud_base_url}/organizations/{ge_cloud_organization_id}/checkpoints", + json={ + "data": { + "type": "checkpoint", + "attributes": { + "checkpoint_config": expected_checkpoint_config, + "organization_id": ge_cloud_organization_id, + }, + }, + }, + ) + + assert checkpoint.ge_cloud_id == checkpoint_id + assert checkpoint.config.ge_cloud_id == checkpoint_id + + assert checkpoint.config.validations[0]["id"] == validation_id_1 + assert checkpoint.validations[0]["id"] == validation_id_1 + + assert checkpoint.config.validations[1]["id"] == validation_id_2 + assert checkpoint.validations[1]["id"] == validation_id_2 + + +@pytest.mark.cloud +@pytest.mark.integration +def test_cloud_backed_data_context_add_or_update_checkpoint_adds_when_id_not_present( + empty_cloud_data_context: CloudDataContext, + checkpoint_id: str, + validation_ids: Tuple[str, str], + checkpoint_config: dict, + mocked_post_response: Callable[[], MockResponse], + mocked_get_by_name_response_0_results: Callable[[], MockResponse], + ge_cloud_base_url: str, + ge_cloud_organization_id: str, +) -> None: + """ + A Cloud-backed context should save to a Cloud-backed CheckpointStore when calling `add_or_update_checkpoint`. + When saving, it should use the id from the response to create the checkpoint. 
+ """ + context = empty_cloud_data_context + + validation_id_1, validation_id_2 = validation_ids + + with mock.patch( + "requests.Session.post", autospec=True, side_effect=mocked_post_response + ) as mock_post, mock.patch( + "requests.Session.get", + autospec=True, + side_effect=mocked_get_by_name_response_0_results, + ): checkpoint = context.add_or_update_checkpoint(**checkpoint_config) # Round trip through schema to mimic updates made during store serialization process @@ -423,7 +498,7 @@ def test_cloud_backed_data_context_add_or_update_checkpoint_updates_when_id_pres "requests.Session.get", autospec=True, side_effect=mocked_get_response, - ) as _: + ): checkpoint = context.add_or_update_checkpoint(**checkpoint_config_with_ids) # Round trip through schema to mimic updates made during store serialization process @@ -466,7 +541,7 @@ def test_cloud_backed_data_context_add_or_update_checkpoint_updates_when_id_not_ validation_ids: Tuple[str, str], checkpoint_config_with_ids: dict, mocked_put_response: Callable[[], MockResponse], - mocked_get_response: Callable[[], MockResponse], + mocked_get_by_name_response_1_result: Callable[[], MockResponse], ge_cloud_base_url: str, ge_cloud_organization_id: str, ) -> None: @@ -483,8 +558,8 @@ def test_cloud_backed_data_context_add_or_update_checkpoint_updates_when_id_not_ ) as mock_put, mock.patch( "requests.Session.get", autospec=True, - side_effect=mocked_get_response, - ) as _: + side_effect=mocked_get_by_name_response_1_result, + ): checkpoint_config = copy.deepcopy(checkpoint_config_with_ids) checkpoint_config.pop("id") checkpoint = context.add_or_update_checkpoint(**checkpoint_config) @@ -541,7 +616,7 @@ def test_cloud_backed_data_context_update_checkpoint_updates_when_id_present( "requests.Session.get", autospec=True, side_effect=mocked_get_response, - ) as _: + ): ge_cloud_id = checkpoint_config_with_ids.pop("id") checkpoint = context.update_checkpoint( Checkpoint.instantiate_from_config_with_runtime_args( @@ -585,6 
+660,38 @@ def test_cloud_backed_data_context_update_checkpoint_updates_when_id_present( assert checkpoint.validations[1]["id"] == validation_id_2 +@pytest.mark.cloud +@pytest.mark.integration +def test_cloud_backed_data_context_update_non_existent_checkpoint_when_id_not_present( + empty_cloud_data_context: CloudDataContext, + checkpoint_config_with_ids: dict, + mocked_get_by_name_response_0_results: Callable[[], MockResponse], + ge_cloud_base_url: str, + ge_cloud_organization_id: str, +) -> None: + """ + A Cloud-backed context should raise CheckpointNotFoundError when calling `update_checkpoint` and the + referenced Checkpoint does not exist. + """ + context = empty_cloud_data_context + + with mock.patch( + "requests.Session.get", + autospec=True, + side_effect=mocked_get_by_name_response_0_results, + ): + checkpoint_config = copy.deepcopy(checkpoint_config_with_ids) + checkpoint_config.pop("id") + with pytest.raises(CheckpointNotFoundError): + context.update_checkpoint( + Checkpoint.instantiate_from_config_with_runtime_args( + checkpoint_config=CheckpointConfig(**checkpoint_config), + data_context=context, + name=checkpoint_config["name"], + ) + ) + + @pytest.mark.cloud @pytest.mark.integration def test_cloud_backed_data_context_update_checkpoint_updates_when_id_not_present( @@ -593,7 +700,7 @@ def test_cloud_backed_data_context_update_checkpoint_updates_when_id_not_present validation_ids: Tuple[str, str], checkpoint_config_with_ids: dict, mocked_put_response: Callable[[], MockResponse], - mocked_get_response: Callable[[], MockResponse], + mocked_get_by_name_response_1_result: Callable[[], MockResponse], ge_cloud_base_url: str, ge_cloud_organization_id: str, ) -> None: @@ -609,8 +716,8 @@ def test_cloud_backed_data_context_update_checkpoint_updates_when_id_not_present ) as mock_put, mock.patch( "requests.Session.get", autospec=True, - side_effect=mocked_get_response, - ) as _: + side_effect=mocked_get_by_name_response_1_result, + ): checkpoint_config = 
copy.deepcopy(checkpoint_config_with_ids) checkpoint_config.pop("id")