Skip to content

Commit

Permalink
[BUGFIX] Fix GXCloudStoreBackend updates by name (#8116)
Browse files Browse the repository at this point in the history
Co-authored-by: Gabriel <gabriel59kg@gmail.com>
  • Loading branch information
roblim and Kilo59 authored Jun 15, 2023
1 parent 823620c commit 71b15d0
Show file tree
Hide file tree
Showing 2 changed files with 180 additions and 25 deletions.
70 changes: 59 additions & 11 deletions great_expectations/data_context/store/gx_cloud_store_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class PayloadDataField(TypedDict):


class ResponsePayload(TypedDict):
data: PayloadDataField
data: PayloadDataField | list[PayloadDataField]


AnyPayload = Union[ResponsePayload, ErrorPayload]
Expand Down Expand Up @@ -222,7 +222,7 @@ def __init__( # noqa: PLR0913
}
filter_properties_dict(properties=self._config, inplace=True)

def _get(self, key: Tuple[str, ...]) -> ResponsePayload: # type: ignore[override]
def _get(self, key: Tuple[GXCloudRESTResource, str | None, str | None]) -> ResponsePayload: # type: ignore[override]
ge_cloud_url = self.get_url_for_key(key=key)
params: Optional[dict] = None
try:
Expand Down Expand Up @@ -387,7 +387,9 @@ def _post(self, value: Any, **kwargs) -> GXCloudResourceRef:
response_json = response.json()

object_id = response_json["data"]["id"]
object_url = self.get_url_for_key((self.ge_cloud_resource_type, object_id))
object_url = self.get_url_for_key(
(self.ge_cloud_resource_type, object_id, None)
)
# This method is where posts get made for all cloud store endpoints. We pass
# the response_json back up to the caller because the specific resource may
# want to parse resource specific data out of the response.
Expand Down Expand Up @@ -471,7 +473,9 @@ def list_keys(self, prefix: Tuple = ()) -> List[Tuple[GXCloudRESTResource, str,
)

def get_url_for_key( # type: ignore[override]
self, key: Tuple[str, ...], protocol: Optional[Any] = None
self,
key: Tuple[GXCloudRESTResource, str | None, str | None],
protocol: Optional[Any] = None,
) -> str:
id = key[1]
url = construct_url(
Expand Down Expand Up @@ -542,26 +546,70 @@ def remove_key(self, key):
f"Unable to delete object in GX Cloud Store Backend: {repr(e)}"
)

def _update(self, key, value, **kwargs):
existing = self._get(key)
def _get_one_or_none_from_response_data(
self,
response_data: list[PayloadDataField] | PayloadDataField,
key: tuple[GXCloudRESTResource, str | None, str | None],
) -> PayloadDataField | None:
"""
GET requests to cloud can either return response data that is a single object (get by id) or a
list of objects with length >= 0 (get by name). This method takes this response data and returns a single
object or None.
"""
if not isinstance(response_data, list):
return response_data
if len(response_data) == 0:
return None
if len(response_data) == 1:
return response_data[0]
raise StoreBackendError(
f"Unable to update object in GX Cloud Store Backend: the provided key ({key}) maps "
f"to more than one object."
)

def _update(
self,
key: tuple[GXCloudRESTResource, str | None, str | None],
value: dict,
**kwargs,
):
response_data = self._get(key)["data"]
# if the provided key does not contain id (only name), cloud will return a list of resources filtered
# by name, with length >= 0, instead of a single object (or error if not found)
existing = self._get_one_or_none_from_response_data(
response_data=response_data, key=key
)

if existing is None:
raise StoreBackendError(
f"Unable to update object in GX Cloud Store Backend: could not find object associated with key {key}."
)

if key[1] is None:
key = (key[0], existing["data"]["id"], key[2])
key = (key[0], existing["id"], key[2])

return self.set(key=key, value=value, **kwargs)

def _add_or_update(self, key, value, **kwargs):
try:
existing = self._get(key)
response_data = self._get(key)["data"]
except StoreBackendError as e:
logger.info(f"Could not find object associated with key {key}: {e}")
existing = None
response_data = None

# if the provided key does not contain id (only name), cloud will return a list of resources filtered
# by name, with length >= 0, instead of a single object (or error if not found)
existing = self._get_one_or_none_from_response_data(
response_data=response_data, key=key
)

if existing is not None:
id = key[1] if key[1] is not None else existing["data"]["id"]
id = key[1] if key[1] is not None else existing["id"]
key = (key[0], id, key[2])
return self.set(key=key, value=value, **kwargs)
return self.add(key=key, value=value, **kwargs)

def _has_key(self, key: Tuple[str, ...]) -> bool:
def _has_key(self, key: Tuple[GXCloudRESTResource, str | None, str | None]) -> bool:
try:
_ = self._get(key)
except StoreBackendTransientError:
Expand Down
135 changes: 121 additions & 14 deletions tests/data_context/cloud_data_context/test_checkpoint_crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
checkpointConfigSchema,
)
from great_expectations.data_context.types.resource_identifiers import GXCloudIdentifier
from great_expectations.exceptions import StoreBackendError
from great_expectations.exceptions import CheckpointNotFoundError, StoreBackendError
from great_expectations.util import get_context
from tests.data_context.conftest import MockResponse

Expand Down Expand Up @@ -137,7 +137,7 @@ def _mocked_get_response(*args, **kwargs):


@pytest.fixture
def mocked_get_by_name_response(
def mocked_get_by_name_response_1_result(
mock_response_factory: Callable,
checkpoint_config_with_ids: dict,
checkpoint_id: str,
Expand Down Expand Up @@ -173,14 +173,29 @@ def _mocked_get_by_name_response(*args, **kwargs):
return _mocked_get_by_name_response


@pytest.fixture
def mocked_get_by_name_response_0_results(
    mock_response_factory: Callable,
) -> Callable[[], MockResponse]:
    """
    Provide a callable that produces a 200 response whose ``data`` is an empty list,
    i.e. a get-by-name lookup that matched nothing.
    """

    def _mocked_get_by_name_response(*args, **kwargs):
        # Arguments are accepted only so the callable can serve as a mock side_effect.
        empty_payload = {"data": []}
        return mock_response_factory(empty_payload, 200)

    return _mocked_get_by_name_response


@pytest.mark.cloud
@pytest.mark.integration
def test_cloud_backed_data_context_get_checkpoint_by_name(
empty_cloud_data_context: CloudDataContext,
checkpoint_id: str,
validation_ids: Tuple[str, str],
checkpoint_config: dict,
mocked_get_by_name_response: Callable[[], MockResponse],
mocked_get_by_name_response_1_result: Callable[[], MockResponse],
ge_cloud_base_url: str,
ge_cloud_organization_id: str,
) -> None:
Expand All @@ -193,7 +208,9 @@ def test_cloud_backed_data_context_get_checkpoint_by_name(
validation_id_1, validation_id_2 = validation_ids

with mock.patch(
"requests.Session.get", autospec=True, side_effect=mocked_get_by_name_response
"requests.Session.get",
autospec=True,
side_effect=mocked_get_by_name_response_1_result,
) as mock_get:
checkpoint = context.get_checkpoint(name=checkpoint_config["name"])

Expand Down Expand Up @@ -250,7 +267,7 @@ def test_cloud_backed_data_context_add_checkpoint(
"great_expectations.data_context.store.GXCloudStoreBackend._has_key",
autospec=True,
return_value=False,
) as _:
):
checkpoint = context.add_checkpoint(**checkpoint_config)

# Round trip through schema to mimic updates made during store serialization process
Expand Down Expand Up @@ -365,7 +382,65 @@ def test_cloud_backed_data_context_add_or_update_checkpoint_adds(
"great_expectations.data_context.store.GXCloudStoreBackend._get",
autospec=True,
side_effect=StoreBackendError("Does not exist."),
) as _:
):
checkpoint = context.add_or_update_checkpoint(**checkpoint_config)

# Round trip through schema to mimic updates made during store serialization process
expected_checkpoint_config = checkpointConfigSchema.dump(
CheckpointConfig(**checkpoint_config)
)

mock_post.assert_called_with(
mock.ANY, # requests.Session object
f"{ge_cloud_base_url}/organizations/{ge_cloud_organization_id}/checkpoints",
json={
"data": {
"type": "checkpoint",
"attributes": {
"checkpoint_config": expected_checkpoint_config,
"organization_id": ge_cloud_organization_id,
},
},
},
)

assert checkpoint.ge_cloud_id == checkpoint_id
assert checkpoint.config.ge_cloud_id == checkpoint_id

assert checkpoint.config.validations[0]["id"] == validation_id_1
assert checkpoint.validations[0]["id"] == validation_id_1

assert checkpoint.config.validations[1]["id"] == validation_id_2
assert checkpoint.validations[1]["id"] == validation_id_2


@pytest.mark.cloud
@pytest.mark.integration
def test_cloud_backed_data_context_add_or_update_checkpoint_adds_when_id_not_present(
empty_cloud_data_context: CloudDataContext,
checkpoint_id: str,
validation_ids: Tuple[str, str],
checkpoint_config: dict,
mocked_post_response: Callable[[], MockResponse],
mocked_get_by_name_response_0_results: Callable[[], MockResponse],
ge_cloud_base_url: str,
ge_cloud_organization_id: str,
) -> None:
"""
A Cloud-backed context should save to a Cloud-backed CheckpointStore when calling `add_or_update_checkpoint`.
When saving, it should use the id from the response to create the checkpoint.
"""
context = empty_cloud_data_context

validation_id_1, validation_id_2 = validation_ids

with mock.patch(
"requests.Session.post", autospec=True, side_effect=mocked_post_response
) as mock_post, mock.patch(
"requests.Session.get",
autospec=True,
side_effect=mocked_get_by_name_response_0_results,
):
checkpoint = context.add_or_update_checkpoint(**checkpoint_config)

# Round trip through schema to mimic updates made during store serialization process
Expand Down Expand Up @@ -423,7 +498,7 @@ def test_cloud_backed_data_context_add_or_update_checkpoint_updates_when_id_pres
"requests.Session.get",
autospec=True,
side_effect=mocked_get_response,
) as _:
):
checkpoint = context.add_or_update_checkpoint(**checkpoint_config_with_ids)

# Round trip through schema to mimic updates made during store serialization process
Expand Down Expand Up @@ -466,7 +541,7 @@ def test_cloud_backed_data_context_add_or_update_checkpoint_updates_when_id_not_
validation_ids: Tuple[str, str],
checkpoint_config_with_ids: dict,
mocked_put_response: Callable[[], MockResponse],
mocked_get_response: Callable[[], MockResponse],
mocked_get_by_name_response_1_result: Callable[[], MockResponse],
ge_cloud_base_url: str,
ge_cloud_organization_id: str,
) -> None:
Expand All @@ -483,8 +558,8 @@ def test_cloud_backed_data_context_add_or_update_checkpoint_updates_when_id_not_
) as mock_put, mock.patch(
"requests.Session.get",
autospec=True,
side_effect=mocked_get_response,
) as _:
side_effect=mocked_get_by_name_response_1_result,
):
checkpoint_config = copy.deepcopy(checkpoint_config_with_ids)
checkpoint_config.pop("id")
checkpoint = context.add_or_update_checkpoint(**checkpoint_config)
Expand Down Expand Up @@ -541,7 +616,7 @@ def test_cloud_backed_data_context_update_checkpoint_updates_when_id_present(
"requests.Session.get",
autospec=True,
side_effect=mocked_get_response,
) as _:
):
ge_cloud_id = checkpoint_config_with_ids.pop("id")
checkpoint = context.update_checkpoint(
Checkpoint.instantiate_from_config_with_runtime_args(
Expand Down Expand Up @@ -585,6 +660,38 @@ def test_cloud_backed_data_context_update_checkpoint_updates_when_id_present(
assert checkpoint.validations[1]["id"] == validation_id_2


@pytest.mark.cloud
@pytest.mark.integration
def test_cloud_backed_data_context_update_non_existent_checkpoint_when_id_not_present(
    empty_cloud_data_context: CloudDataContext,
    checkpoint_config_with_ids: dict,
    mocked_get_by_name_response_0_results: Callable[[], MockResponse],
    ge_cloud_base_url: str,
    ge_cloud_organization_id: str,
) -> None:
    """
    A Cloud-backed context should raise CheckpointNotFoundError when calling `update_checkpoint`
    with a Checkpoint that carries no id and the mocked GET returns an empty list of results.
    """
    context = empty_cloud_data_context

    with mock.patch(
        "requests.Session.get",
        autospec=True,
        side_effect=mocked_get_by_name_response_0_results,
    ):
        # Work on a copy so the fixture dict is not mutated when the id is removed.
        checkpoint_config = copy.deepcopy(checkpoint_config_with_ids)
        checkpoint_config.pop("id")
        with pytest.raises(CheckpointNotFoundError):
            context.update_checkpoint(
                Checkpoint.instantiate_from_config_with_runtime_args(
                    checkpoint_config=CheckpointConfig(**checkpoint_config),
                    data_context=context,
                    name=checkpoint_config["name"],
                )
            )


@pytest.mark.cloud
@pytest.mark.integration
def test_cloud_backed_data_context_update_checkpoint_updates_when_id_not_present(
Expand All @@ -593,7 +700,7 @@ def test_cloud_backed_data_context_update_checkpoint_updates_when_id_not_present
validation_ids: Tuple[str, str],
checkpoint_config_with_ids: dict,
mocked_put_response: Callable[[], MockResponse],
mocked_get_response: Callable[[], MockResponse],
mocked_get_by_name_response_1_result: Callable[[], MockResponse],
ge_cloud_base_url: str,
ge_cloud_organization_id: str,
) -> None:
Expand All @@ -609,8 +716,8 @@ def test_cloud_backed_data_context_update_checkpoint_updates_when_id_not_present
) as mock_put, mock.patch(
"requests.Session.get",
autospec=True,
side_effect=mocked_get_response,
) as _:
side_effect=mocked_get_by_name_response_1_result,
):
checkpoint_config = copy.deepcopy(checkpoint_config_with_ids)
checkpoint_config.pop("id")

Expand Down

0 comments on commit 71b15d0

Please sign in to comment.