[MAINTENANCE] Improve DatasourceDict performance (#8711)
Co-authored-by: Gabriel <gabriel59kg@gmail.com>
cdkini and Kilo59 authored Sep 21, 2023
1 parent 1c59e51 commit f7fc291
Showing 7 changed files with 251 additions and 67 deletions.
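The performance change centers on how DatasourceDict.data is built: the removed _names property and per-name lookups meant one list_keys() call plus one retrieval per datasource, whereas the new implementation makes a single DatasourceStore.get_all() call and hydrates each returned config in memory. A minimal, self-contained sketch of the before/after access pattern; ToyStore, its methods, and the call counts are simplified stand-ins for illustration, not the real great_expectations API:

# Illustrative only: a toy store contrasting the old N+1 access pattern with a
# single bulk fetch. Class and method names are simplified stand-ins.
class ToyStore:
    def __init__(self, configs: dict) -> None:
        self._configs = configs
        self.calls = 0  # counts round trips to the backing store

    def list_keys(self) -> list:
        self.calls += 1
        return list(self._configs)

    def retrieve_by_name(self, name: str) -> dict:
        self.calls += 1
        return self._configs[name]

    def get_all(self) -> list:
        self.calls += 1
        return list(self._configs.values())


store = ToyStore({f"ds_{i}": {"name": f"ds_{i}"} for i in range(100)})

# Before: list_keys() plus one retrieval per datasource -> 101 store calls.
before = {name: store.retrieve_by_name(name) for name in store.list_keys()}
assert store.calls == 101

store.calls = 0
# After: one bulk fetch -> 1 store call; hydration happens in memory.
after = {cfg["name"]: cfg for cfg in store.get_all()}
assert store.calls == 1
assert before == after

The same idea drives the Store.get_all() and StoreBackend._get_all() plumbing added further down in this commit.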
107 changes: 57 additions & 50 deletions great_expectations/core/datasource_dict.py
@@ -57,12 +57,6 @@ def __init__(
self._datasource_store = datasource_store
self._in_memory_data_assets: dict[str, DataAsset] = {}

@property
def _names(self) -> set[str]:
# The contents of the store may change between uses so we constantly refresh when requested
keys = self._datasource_store.list_keys()
return {key.resource_name for key in keys} # type: ignore[attr-defined] # list_keys() is annotated with generic DataContextKey instead of subclass

@staticmethod
def _get_in_memory_data_asset_name(
datasource_name: str, data_asset_name: str
@@ -78,40 +72,49 @@ def data(self) -> dict[str, FluentDatasource | BaseDatasource]: # type: ignore[
This is generated just-in-time as the contents of the store may have changed.
"""
datasources = {}
for name in self._names:
datasources: dict[str, FluentDatasource | BaseDatasource] = {}

configs = self._datasource_store.get_all()
for config in configs:
try:
datasources[name] = self.__getitem__(name)
if isinstance(config, FluentDatasource):
name = config.name
datasources[name] = self._init_fluent_datasource(
name=name, ds=config
)
else:
name = config["name"]
datasources[name] = self._init_block_style_datasource(
name=name, config=config
)
except gx_exceptions.DatasourceInitializationError as e:
logger.warning(f"Cannot initialize datasource {name}: {e}")

return datasources

@override
def __contains__(self, name: object) -> bool:
# Minor optimization - only pulls names instead of building all datasources in self.data
return name in self._names

@override
def __setitem__(self, name: str, ds: FluentDatasource | BaseDatasource) -> None:
config: FluentDatasource | DatasourceConfig
if isinstance(ds, FluentDatasource):
if isinstance(ds, SupportsInMemoryDataAssets):
for asset in ds.assets:
if asset.type == _IN_MEMORY_DATA_ASSET_TYPE:
in_memory_asset_name: str = (
DatasourceDict._get_in_memory_data_asset_name(
datasource_name=name,
data_asset_name=asset.name,
)
)
self._in_memory_data_assets[in_memory_asset_name] = asset
config = ds
config = self._prep_fds_config(name=name, ds=ds)
else:
config = self._prep_legacy_datasource_config(name=name, ds=ds)

self._datasource_store.set(key=None, value=config)

def _prep_fds_config(self, name: str, ds: FluentDatasource) -> FluentDatasource:
if isinstance(ds, SupportsInMemoryDataAssets):
for asset in ds.assets:
if asset.type == _IN_MEMORY_DATA_ASSET_TYPE:
in_memory_asset_name: str = (
DatasourceDict._get_in_memory_data_asset_name(
datasource_name=name,
data_asset_name=asset.name,
)
)
self._in_memory_data_assets[in_memory_asset_name] = asset
return ds

def _prep_legacy_datasource_config(
self, name: str, ds: BaseDatasource
) -> DatasourceConfig:
@@ -122,42 +125,44 @@ def _prep_legacy_datasource_config(
config["class_name"] = ds.__class__.__name__
return datasourceConfigSchema.load(config)

@override
def __delitem__(self, name: str) -> None:
if not self.__contains__(name):
def _get_ds_from_store(self, name: str) -> FluentDatasource | DatasourceConfig:
try:
return self._datasource_store.retrieve_by_name(name)
except ValueError:
raise KeyError(f"Could not find a datasource named '{name}'")

ds = self._datasource_store.retrieve_by_name(name)
@override
def __delitem__(self, name: str) -> None:
ds = self._get_ds_from_store(name)
self._datasource_store.delete(ds)

@override
def __getitem__(self, name: str) -> FluentDatasource | BaseDatasource:
if not self.__contains__(name):
raise KeyError(f"Could not find a datasource named '{name}'")
ds = self._get_ds_from_store(name)

ds = self._datasource_store.retrieve_by_name(name)
if isinstance(ds, FluentDatasource):
hydrated_ds = self._init_fluent_datasource(ds)
if isinstance(hydrated_ds, SupportsInMemoryDataAssets):
for asset in hydrated_ds.assets:
if asset.type == _IN_MEMORY_DATA_ASSET_TYPE:
in_memory_asset_name: str = (
DatasourceDict._get_in_memory_data_asset_name(
datasource_name=name,
data_asset_name=asset.name,
)
)
cached_data_asset = self._in_memory_data_assets.get(
in_memory_asset_name
)
if cached_data_asset:
asset.dataframe = cached_data_asset.dataframe
return hydrated_ds
return self._init_fluent_datasource(name=name, ds=ds)
return self._init_block_style_datasource(name=name, config=ds)

def _init_fluent_datasource(self, ds: FluentDatasource) -> FluentDatasource:
def _init_fluent_datasource(
self, name: str, ds: FluentDatasource
) -> FluentDatasource:
ds._data_context = self._context
ds._rebuild_asset_data_connectors()
if isinstance(ds, SupportsInMemoryDataAssets):
for asset in ds.assets:
if asset.type == _IN_MEMORY_DATA_ASSET_TYPE:
in_memory_asset_name: str = (
DatasourceDict._get_in_memory_data_asset_name(
datasource_name=name,
data_asset_name=asset.name,
)
)
cached_data_asset = self._in_memory_data_assets.get(
in_memory_asset_name
)
if cached_data_asset:
asset.dataframe = cached_data_asset.dataframe
return ds

# To be removed once block-style is fully removed (deprecated as of v0.17.2)
@@ -198,7 +203,9 @@ def __contains__(self, name: object) -> bool:
if name in self.data:
return True
try:
return super().__contains__(name)
# Resort to store only if not in cache
_ = self._get_ds_from_store(str(name))
return True
except KeyError:
return False

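CacheableDatasourceDict.__contains__ above now checks the in-memory cache first and only falls back to a single store lookup on a miss (via _get_ds_from_store), instead of listing every key. A rough, self-contained sketch of that cache-then-store membership check; ToyCacheableDict and its fields are hypothetical, simplified stand-ins:

# Illustrative sketch of a cache-first membership check, assuming a plain dict
# cache and a backing store that supports a direct key lookup.
class ToyCacheableDict:
    def __init__(self, store: dict) -> None:
        self.data: dict = {}   # in-memory cache
        self._store = store    # backing store (a key/value mapping here)

    def __contains__(self, name: object) -> bool:
        if name in self.data:       # cache hit: no store traffic
            return True
        return name in self._store  # cache miss: one store lookup


d = ToyCacheableDict(store={"my_pandas_ds": {"type": "pandas"}})
assert "my_pandas_ds" in d   # resolved by the store
d.data["cached_ds"] = object()
assert "cached_ds" in d      # resolved by the cache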
26 changes: 24 additions & 2 deletions great_expectations/data_context/store/datasource_store.py
@@ -136,15 +136,37 @@ def gx_cloud_response_json_to_object_dict(
data = response_json["data"]
if isinstance(data, list):
if len(data) > 1:
# TODO: handle larger arrays of Datasources
# Larger arrays of datasources should be handled by `gx_cloud_response_json_to_object_collection`
raise TypeError(
f"GX Cloud returned {len(data)} Datasources but expected 1"
)
data = data[0]

return DatasourceStore._convert_raw_json_to_object_dict(data)

@override
@staticmethod
def gx_cloud_response_json_to_object_collection(
response_json: CloudResponsePayloadTD, # type: ignore[override]
) -> list[dict]:
"""
This method takes full json response from GX cloud and outputs a list of dicts appropriate for
deserialization into a collection of GX objects
"""
logger.debug(f"GE Cloud Response JSON ->\n{pf(response_json, depth=3)}")
data = response_json["data"]
if not isinstance(data, list):
raise TypeError(
"GX Cloud did not return a collection of Datasources when expected"
)

return [DatasourceStore._convert_raw_json_to_object_dict(d) for d in data]

@staticmethod
def _convert_raw_json_to_object_dict(data: DataPayload) -> dict:
datasource_ge_cloud_id: str = data["id"]
datasource_config_dict: dict = data["attributes"]["datasource_config"]
datasource_config_dict["id"] = datasource_ge_cloud_id

return datasource_config_dict

def retrieve_by_name(
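The new gx_cloud_response_json_to_object_collection and the shared _convert_raw_json_to_object_dict helper above assume a Cloud payload whose "data" entry is a list of objects, each carrying an "id" and an "attributes" -> "datasource_config" mapping. A hedged, self-contained sketch of that conversion; the payload below is fabricated for illustration, not a verbatim GX Cloud response:

# Fabricated payload illustrating the shape the new collection parser expects.
response_json = {
    "data": [
        {"id": "a1", "attributes": {"datasource_config": {"name": "ds_a", "type": "pandas"}}},
        {"id": "b2", "attributes": {"datasource_config": {"name": "ds_b", "type": "sql"}}},
    ]
}


def convert(entry: dict) -> dict:
    # Mirrors _convert_raw_json_to_object_dict: copy the Cloud id into the config dict.
    config = entry["attributes"]["datasource_config"]
    config["id"] = entry["id"]
    return config


configs = [convert(d) for d in response_json["data"]]
assert [c["id"] for c in configs] == ["a1", "b2"]
assert [c["name"] for c in configs] == ["ds_a", "ds_b"]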
@@ -59,7 +59,9 @@ def _get(self, key):

@override
def _get_all(self) -> list[Any]:
raise NotImplementedError
return [
val for key, val in self._store.items() if key != self.STORE_BACKEND_ID_KEY
]

@override
def _set(self, key, value, **kwargs) -> None:
21 changes: 19 additions & 2 deletions great_expectations/data_context/store/store.py
@@ -1,7 +1,16 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Tuple, Type
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Dict,
List,
Optional,
Tuple,
Type,
)

from typing_extensions import TypedDict

@@ -101,6 +110,14 @@ def gx_cloud_response_json_to_object_dict(response_json: Dict) -> Dict:
"""
return response_json

@staticmethod
def gx_cloud_response_json_to_object_collection(response_json: Dict) -> List[Dict]:
"""
This method takes full json response from GX cloud and outputs a list of dicts appropriate for
deserialization into a collection of GX objects
"""
raise NotImplementedError

def _validate_key(self, key: DataContextKey) -> None:
# STORE_BACKEND_ID_KEY always validated
if key == StoreBackend.STORE_BACKEND_ID_KEY:
@@ -202,7 +219,7 @@ def get(
def get_all(self) -> list[Any]:
objs = self._store_backend.get_all()
if self.cloud_mode:
objs = self.gx_cloud_response_json_to_object_dict(objs)
objs = self.gx_cloud_response_json_to_object_collection(objs)

return list(map(self.deserialize, objs))

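In the base Store, gx_cloud_response_json_to_object_collection is left as a NotImplementedError hook for subclasses such as DatasourceStore to override, and get_all only invokes it in cloud mode. A toy sketch of that template-method shape; class names, the payload, and the "deserialize" stand-in are illustrative, not the real API:

# Illustrative template-method sketch: the base class defines get_all() around a
# hook that concrete stores override for their payload shape.
from typing import Any, Dict, List


class ToyStore:
    cloud_mode = True

    @staticmethod
    def response_json_to_object_collection(response_json: Dict) -> List[Dict]:
        raise NotImplementedError  # concrete stores supply their payload parsing

    def _backend_get_all(self) -> Dict:
        # Stand-in for the store backend / Cloud round trip.
        return {"data": [{"attributes": {"datasource_config": {"name": "ds_a"}}}]}

    def get_all(self) -> List[Any]:
        objs: Any = self._backend_get_all()
        if self.cloud_mode:
            objs = self.response_json_to_object_collection(objs)
        return [dict(o) for o in objs]  # "deserialize" stand-in


class ToyDatasourceStore(ToyStore):
    @staticmethod
    def response_json_to_object_collection(response_json: Dict) -> List[Dict]:
        return [d["attributes"]["datasource_config"] for d in response_json["data"]]


print(ToyDatasourceStore().get_all())  # -> [{'name': 'ds_a'}]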
4 changes: 2 additions & 2 deletions tests/conftest.py
@@ -3674,11 +3674,11 @@ def empty_cloud_data_context(

@pytest.fixture
@mock.patch(
"great_expectations.data_context.store.DatasourceStore.list_keys",
"great_expectations.data_context.store.DatasourceStore.get_all",
return_value=[],
)
def empty_base_data_context_in_cloud_mode_custom_base_url(
mock_list_keys: mock.MagicMock, # Avoid making a call to Cloud backend during datasource instantiation
mock_get_all: mock.MagicMock, # Avoid making a call to Cloud backend during datasource instantiation
tmp_path: pathlib.Path,
empty_ge_cloud_data_context_config: DataContextConfig,
ge_cloud_config: GXCloudConfig,
41 changes: 31 additions & 10 deletions tests/core/test_datasource_dict.py
@@ -26,6 +26,7 @@
class DatasourceStoreSpy(DatasourceStore):
def __init__(self, datasource_configs: list[dict] | None = None) -> None:
self.list_keys_count = 0
self.has_key_count = 0
self.set_count = 0
self.get_count = 0
self.remove_key_count = 0
@@ -38,6 +39,7 @@ def __init__(self, datasource_configs: list[dict] | None = None) -> None:

# Reset counters
self.list_keys_count = 0
self.has_key_count = 0
self.set_count = 0
self.get_count = 0
self.remove_key_count = 0
@@ -54,6 +56,10 @@ def list_keys(self):
self.list_keys_count += 1
return super().list_keys()

def has_key(self, key) -> bool:
self.has_key_count += 1
return super().has_key(key)

def remove_key(self, key):
self.remove_key_count += 1
return super().remove_key(key)
@@ -238,12 +244,20 @@ def build_cacheable_datasource_dict_with_store_spy(
) -> Callable:
def _build_cacheable_datasource_dict_with_store_spy(
datasource_configs: list[dict] | None = None,
populate_cache: bool = True,
) -> CacheableDatasourceDict:
return CacheableDatasourceDict(
datasource_dict = CacheableDatasourceDict(
context=in_memory_runtime_context,
datasource_store=DatasourceStoreSpy(datasource_configs=datasource_configs),
)

# Populate cache
if populate_cache and datasource_configs:
for ds in datasource_configs:
datasource_dict.data[ds.name] = ds

return datasource_dict

return _build_cacheable_datasource_dict_with_store_spy


@@ -259,9 +273,10 @@ def cacheable_datasource_dict_with_fds(
build_cacheable_datasource_dict_with_store_spy: Callable,
pandas_fds: PandasDatasource,
) -> CacheableDatasourceDict:
return build_cacheable_datasource_dict_with_store_spy(
datasource_dict = build_cacheable_datasource_dict_with_store_spy(
datasource_configs=[pandas_fds]
)
return datasource_dict


@pytest.mark.unit
@@ -283,12 +298,12 @@ def test_cacheable_datasource_dict___contains___requests_store_upon_cache_miss(
store = cacheable_datasource_dict_with_fds._datasource_store

assert store.get_count == 0
assert store.list_keys_count == 0
assert store.has_key_count == 0

# Lookup will check store due to lack of presence in cache (but won't retrieve value)
assert "my_fake_name" not in cacheable_datasource_dict_with_fds
assert store.get_count == 0
assert store.list_keys_count == 1
assert store.has_key_count == 1


@pytest.mark.unit
@@ -320,17 +335,22 @@ def test_cacheable_datasource_dict___setitem___with_block_datasource(

@pytest.mark.unit
def test_cacheable_datasource_dict___delitem__updates_both_cache_and_store(
cacheable_datasource_dict_with_fds: CacheableDatasourceDict, pandas_fds_name: str
build_cacheable_datasource_dict_with_store_spy: Callable,
pandas_block_datasource_config: dict,
):
store = cacheable_datasource_dict_with_fds._datasource_store
datasource_dict = build_cacheable_datasource_dict_with_store_spy(
datasource_configs=[pandas_block_datasource_config], populate_cache=True
)
store = datasource_dict._datasource_store
name = pandas_block_datasource_config["name"]
assert store.remove_key_count == 0

# Deletion will go down to the store level
del cacheable_datasource_dict_with_fds[pandas_fds_name]
del datasource_dict[name]
assert store.remove_key_count == 1

# Should also impact the cache
assert pandas_fds_name not in cacheable_datasource_dict_with_fds.data
assert name not in datasource_dict.data


@pytest.mark.unit
@@ -363,7 +383,8 @@ def test_cacheable_datasource_dict___getitem___with_fds(
pandas_fds: PandasDatasource,
):
datasource_dict = build_cacheable_datasource_dict_with_store_spy(
datasource_configs=[pandas_fds]
datasource_configs=[pandas_fds],
populate_cache=False,
)
store = datasource_dict._datasource_store
assert store.get_count == 0
Expand All @@ -379,7 +400,7 @@ def test_cacheable_datasource_dict___getitem___with_block_datasource(
pandas_block_datasource_config: dict,
):
datasource_dict = build_cacheable_datasource_dict_with_store_spy(
datasource_configs=[pandas_block_datasource_config]
datasource_configs=[pandas_block_datasource_config], populate_cache=False
)
store = datasource_dict._datasource_store
assert store.get_count == 0