From bfccfd3ddf25cfff3e2c39162c7b2b1dd33c12b9 Mon Sep 17 00:00:00 2001 From: Abe Gong Date: Thu, 24 Aug 2023 12:20:47 -0600 Subject: [PATCH] [MAINTENANCE] Add mode param to get_context (#8593) Co-authored-by: Gabriel Co-authored-by: Chetan Kini --- great_expectations/util.py | 96 ++++++++++++++++++++- tests/data_context/test_get_data_context.py | 30 +++++++ 2 files changed, 125 insertions(+), 1 deletion(-) diff --git a/great_expectations/util.py b/great_expectations/util.py index 70bc0bd7ac41..de65d45faa4e 100644 --- a/great_expectations/util.py +++ b/great_expectations/util.py @@ -45,6 +45,7 @@ Set, SupportsFloat, Tuple, + Type, Union, cast, overload, @@ -78,7 +79,7 @@ if TYPE_CHECKING: # needed until numpy min version 1.20 import numpy.typing as npt - from typing_extensions import TypeGuard + from typing_extensions import TypeAlias, TypeGuard from great_expectations.alias_types import PathStr from great_expectations.data_context import ( @@ -1739,6 +1740,9 @@ def convert_pandas_series_decimal_to_float_dtype( return data +ContextModes: TypeAlias = Literal["file", "cloud", "ephemeral"] + + @overload def get_context( # type: ignore[misc] # noqa: PLR0913 project_config: DataContextConfig | Mapping | None = ..., @@ -1763,6 +1767,7 @@ def get_context( # type: ignore[misc] # noqa: PLR0913 cloud_access_token: None = ..., cloud_organization_id: None = ..., cloud_mode: Literal[False] | None = ..., + mode: Literal["file"] | None = ..., ) -> FileDataContext: ... @@ -1777,6 +1782,7 @@ def get_context( # noqa: PLR0913 cloud_access_token: str | None = ..., cloud_organization_id: str | None = ..., cloud_mode: Literal[True] = ..., + mode: Literal["cloud"] | None = ..., ) -> CloudDataContext: ... @@ -1791,6 +1797,7 @@ def get_context( # noqa: PLR0913 cloud_access_token: str | None = ..., cloud_organization_id: str | None = ..., cloud_mode: bool | None = ..., + mode: None = ..., ) -> AbstractDataContext: ... 
@@ -1805,6 +1812,7 @@ def get_context( # noqa: PLR0913 cloud_access_token: str | None = None, cloud_organization_id: str | None = None, cloud_mode: bool | None = None, + mode: ContextModes | None = None, ) -> AbstractDataContext: """Method to return the appropriate Data Context depending on parameters and environment. @@ -1866,6 +1874,8 @@ def get_context( # noqa: PLR0913 cloud_organization_id: org_id for GX Cloud account. cloud_mode: whether to run GX in Cloud mode (default is None). If None, cloud mode is assumed if cloud credentials are set up. Set to False to override. + mode: which mode to use. One of: ephemeral, file, cloud. + Note: if mode is specified, cloud_mode is ignored. Returns: A Data Context. Either a FileDataContext, EphemeralDataContext, or @@ -1877,6 +1887,89 @@ def get_context( # noqa: PLR0913 """ project_config = _prepare_project_config(project_config) + param_lookup: dict[ContextModes | None, dict] = { + "ephemeral": dict( + project_config=project_config, + runtime_environment=runtime_environment, + cloud_mode=False, + ), + "file": dict( + project_config=project_config, + context_root_dir=context_root_dir, + project_root_dir=project_root_dir or Path.cwd(), + runtime_environment=runtime_environment, + cloud_mode=False, + ), + "cloud": dict( + project_config=project_config, + context_root_dir=context_root_dir, + project_root_dir=project_root_dir, + runtime_environment=runtime_environment, + cloud_base_url=cloud_base_url, + cloud_access_token=cloud_access_token, + cloud_organization_id=cloud_organization_id, + cloud_mode=True, + ), + None: dict( + project_config=project_config, + context_root_dir=context_root_dir, + project_root_dir=project_root_dir, + runtime_environment=runtime_environment, + cloud_base_url=cloud_base_url, + cloud_access_token=cloud_access_token, + cloud_organization_id=cloud_organization_id, + cloud_mode=cloud_mode, + ), + } + try: + kwargs = param_lookup[mode] + except KeyError: + raise ValueError( + f"Unknown mode {mode}. 
Please choose one of: ephemeral, file, cloud." + ) + + from great_expectations.data_context.data_context import ( + AbstractDataContext, + CloudDataContext, + EphemeralDataContext, + FileDataContext, + ) + + expected_ctx_types: dict[ + ContextModes | None, + Type[CloudDataContext] + | Type[EphemeralDataContext] + | Type[FileDataContext] + | Type[AbstractDataContext], + ] = { + "ephemeral": EphemeralDataContext, + "file": FileDataContext, + "cloud": CloudDataContext, + None: AbstractDataContext, + } + context = _get_context(**kwargs) + + expected_type = expected_ctx_types[mode] + if not isinstance(context, expected_type): + # example I want an ephemeral context but the presence of a GX_CLOUD env var gives me a cloud context + # this kind of thing should not be possible but there may be some edge cases + raise ValueError( + f"Provided mode {mode} returned context of type {type(context).__name__} instead of {expected_type.__name__}; please check your input arguments." + ) + + return context + + +def _get_context( # noqa: PLR0913 + project_config: DataContextConfig | None = None, + context_root_dir: PathStr | None = None, + project_root_dir: PathStr | None = None, + runtime_environment: dict | None = None, + cloud_base_url: str | None = None, + cloud_access_token: str | None = None, + cloud_organization_id: str | None = None, + cloud_mode: bool | None = None, +) -> AbstractDataContext: # First, check for GX Cloud conditions cloud_context = _get_cloud_context( project_config=project_config, @@ -1888,6 +1981,7 @@ def get_context( # noqa: PLR0913 cloud_access_token=cloud_access_token, cloud_organization_id=cloud_organization_id, ) + if cloud_context: return cloud_context diff --git a/tests/data_context/test_get_data_context.py b/tests/data_context/test_get_data_context.py index fa9acfbdce21..e7ce8482b93f 100644 --- a/tests/data_context/test_get_data_context.py +++ b/tests/data_context/test_get_data_context.py @@ -248,6 +248,36 @@ def 
test_get_context_with_no_arguments_returns_ephemeral_with_sensible_defaults( assert context.config.stores == defaults.stores +@pytest.mark.unit +def test_get_context_with_mode_equals_ephemeral_returns_ephemeral_data_context(): + context = gx.get_context(mode="ephemeral") + assert isinstance(context, EphemeralDataContext) + + +@pytest.mark.unit +def test_get_context_with_mode_equals_file_returns_file_data_context( + tmp_path: pathlib.Path, +): + with working_directory(tmp_path): + context = gx.get_context(mode="file") + assert isinstance(context, FileDataContext) + + +@pytest.mark.cloud +def test_get_context_with_mode_equals_cloud_returns_cloud_data_context( + empty_ge_cloud_data_context_config: DataContextConfig, set_up_cloud_envs +): + with mock.patch.object( + CloudDataContext, + "retrieve_data_context_config_from_cloud", + return_value=empty_ge_cloud_data_context_config, + ) as mock_retrieve_config: + context = gx.get_context(mode="cloud") + + mock_retrieve_config.assert_called_once() + assert isinstance(context, CloudDataContext) + + @pytest.mark.parametrize("ge_cloud_mode", [True, None]) @pytest.mark.cloud def test_cloud_context_include_rendered_content(