diff --git a/great_expectations/compatibility/sqlalchemy.py b/great_expectations/compatibility/sqlalchemy.py
index 057fdce787ef..decd902435f8 100644
--- a/great_expectations/compatibility/sqlalchemy.py
+++ b/great_expectations/compatibility/sqlalchemy.py
@@ -278,6 +278,10 @@
 except (ImportError, AttributeError):
     __version__ = None
 
+try:
+    from sqlalchemy.sql.type_api import TypeEngine
+except (ImportError, AttributeError):
+    TypeEngine = SQLALCHEMY_NOT_IMPORTED  # type: ignore[misc,assignment]
 
 try:
     from sqlalchemy.sql import sqltypes
diff --git a/tests/integration/data_sources_and_expectations/test_canonical_expectations.py b/tests/integration/data_sources_and_expectations/test_canonical_expectations.py
index b79708cdbc72..6cac30389e84 100644
--- a/tests/integration/data_sources_and_expectations/test_canonical_expectations.py
+++ b/tests/integration/data_sources_and_expectations/test_canonical_expectations.py
@@ -3,7 +3,6 @@
 import pandas as pd
 
 import great_expectations.expectations as gxe
-from great_expectations.compatibility.snowflake import SNOWFLAKE_TYPES
 from great_expectations.compatibility.sqlalchemy import sqltypes
 from tests.integration.conftest import parameterize_batch_for_data_sources
 from tests.integration.test_utils.data_source_config import (
@@ -18,8 +17,8 @@
     data_source_configs=[
         PandasDataFrameDatasourceTestConfig(),
         PandasFilesystemCsvDatasourceTestConfig(),
-        PostgreSQLDatasourceTestConfig(column_types={"a": sqltypes.INTEGER}),
-        SnowflakeDatasourceTestConfig(column_types={"a": SNOWFLAKE_TYPES.NUMBER}),
+        PostgreSQLDatasourceTestConfig(),
+        SnowflakeDatasourceTestConfig(),
     ],
     data=pd.DataFrame({"a": [1, 2]}),
 )
@@ -87,8 +86,8 @@ def test_expect_column_max_to_be_between__date(batch_for_datasource) -> None:
     data_source_configs=[
         PandasDataFrameDatasourceTestConfig(),
         PandasFilesystemCsvDatasourceTestConfig(),
-        SnowflakeDatasourceTestConfig(column_types={"a": SNOWFLAKE_TYPES.NUMBER}),
-        PostgreSQLDatasourceTestConfig(column_types={"a": sqltypes.INTEGER}),
+        PostgreSQLDatasourceTestConfig(),
+        SnowflakeDatasourceTestConfig(),
     ],
     data=pd.DataFrame({"a": [1, 2]}),
 )
@@ -102,8 +101,8 @@ def test_expect_column_max_to_be_between(batch_for_datasource) -> None:
     data_source_configs=[
         PandasDataFrameDatasourceTestConfig(),
         PandasFilesystemCsvDatasourceTestConfig(),
-        SnowflakeDatasourceTestConfig(column_types={"a": SNOWFLAKE_TYPES.NUMBER}),
-        PostgreSQLDatasourceTestConfig(column_types={"a": sqltypes.INTEGER}),
+        PostgreSQLDatasourceTestConfig(),
+        SnowflakeDatasourceTestConfig(),
     ],
     data=pd.DataFrame({"a": [1, 2]}),
 )
@@ -117,8 +116,8 @@ def test_expect_column_to_exist(batch_for_datasource):
     data_source_configs=[
         PandasDataFrameDatasourceTestConfig(),
         PandasFilesystemCsvDatasourceTestConfig(),
-        SnowflakeDatasourceTestConfig(column_types={"a": SNOWFLAKE_TYPES.NUMBER}),
-        PostgreSQLDatasourceTestConfig(column_types={"a": sqltypes.INTEGER}),
+        PostgreSQLDatasourceTestConfig(),
+        SnowflakeDatasourceTestConfig(),
     ],
     data=pd.DataFrame({"a": [1, 2]}),
 )
@@ -132,8 +131,8 @@ def test_expect_column_values_to_not_be_null(batch_for_datasource):
     data_source_configs=[
         PandasDataFrameDatasourceTestConfig(),
         PandasFilesystemCsvDatasourceTestConfig(),
-        SnowflakeDatasourceTestConfig(column_types={"a": SNOWFLAKE_TYPES.NUMBER}),
-        PostgreSQLDatasourceTestConfig(column_types={"a": sqltypes.INTEGER}),
+        PostgreSQLDatasourceTestConfig(),
+        SnowflakeDatasourceTestConfig(),
     ],
     data=pd.DataFrame({"a": [1, 2, 3, 4]}),
 )
@@ -145,12 +144,7 @@ def test_expect_column_mean_to_be_between(batch_for_datasource):
 
 class TestExpectTableRowCountToEqualOtherTable:
     @parameterize_batch_for_data_sources(
-        data_source_configs=[
-            PostgreSQLDatasourceTestConfig(
-                column_types={"col_a": sqltypes.INTEGER},
-                extra_assets={"test_table_two": {"col_b": sqltypes.VARCHAR}},
-            ),
-        ],
+        data_source_configs=[PostgreSQLDatasourceTestConfig(), SnowflakeDatasourceTestConfig()],
         data=pd.DataFrame({"a": [1, 2, 3, 4]}),
         extra_data={"test_table_two": pd.DataFrame({"col_b": ["a", "b", "c", "d"]})},
     )
@@ -161,10 +155,8 @@ def test_success(self, batch_for_datasource):
     @parameterize_batch_for_data_sources(
         data_source_configs=[
-            PostgreSQLDatasourceTestConfig(
-                column_types={"col_a": sqltypes.INTEGER},
-                extra_assets={"test_table_two": {"col_b": sqltypes.VARCHAR}},
-            ),
+            PostgreSQLDatasourceTestConfig(),
+            SnowflakeDatasourceTestConfig(),
         ],
         data=pd.DataFrame({"a": [1, 2, 3, 4]}),
         extra_data={"test_table_two": pd.DataFrame({"col_b": ["just_this_one!"]})},
     )
@@ -180,9 +172,8 @@ def test_different_counts(self, batch_for_datasource):
 
     @parameterize_batch_for_data_sources(
         data_source_configs=[
-            PostgreSQLDatasourceTestConfig(
-                column_types={"col_a": sqltypes.INTEGER},
-            ),
+            PostgreSQLDatasourceTestConfig(),
+            SnowflakeDatasourceTestConfig(),
         ],
         data=pd.DataFrame({"a": [1, 2, 3, 4]}),
     )
diff --git a/tests/integration/test_utils/data_source_config/base.py b/tests/integration/test_utils/data_source_config/base.py
index a1faff8bc120..406f6844127f 100644
--- a/tests/integration/test_utils/data_source_config/base.py
+++ b/tests/integration/test_utils/data_source_config/base.py
@@ -3,7 +3,7 @@
 import random
 import string
 from abc import ABC, abstractmethod
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from functools import cached_property
 from typing import TYPE_CHECKING, Generic, Mapping, Optional, TypeVar
@@ -23,7 +23,7 @@ class DataSourceTestConfig(ABC, Generic[_ColumnTypes]):
     name: Optional[str] = None
     column_types: Optional[Mapping[str, _ColumnTypes]] = None
-    extra_assets: Optional[Mapping[str, Mapping[str, _ColumnTypes]]] = None
+    extra_column_types: Mapping[str, Mapping[str, _ColumnTypes]] = field(default_factory=dict)
 
     @property
     @abstractmethod
diff --git a/tests/integration/test_utils/data_source_config/postgres.py b/tests/integration/test_utils/data_source_config/postgres.py
index bdb385c0bbd3..b815ac5e177a 100644
--- a/tests/integration/test_utils/data_source_config/postgres.py
+++ b/tests/integration/test_utils/data_source_config/postgres.py
@@ -1,61 +1,7 @@
-from random import randint
 from typing import Mapping, Union
 
 import pandas as pd
 import pytest
-from sqlalchemy import Column, MetaData, Table, create_engine, insert
-
-# commented out types are present in SqlAlchemy 2.x but not 1.4
-from sqlalchemy.dialects.postgresql import (
-    ARRAY,
-    BIGINT,
-    BIT,
-    BOOLEAN,
-    BYTEA,
-    CHAR,
-    CIDR,
-    # CITEXT,
-    DATE,
-    # DATEMULTIRANGE,
-    DATERANGE,
-    # DOMAIN,
-    DOUBLE_PRECISION,
-    ENUM,
-    FLOAT,
-    HSTORE,
-    INET,
-    # INT4MULTIRANGE,
-    INT4RANGE,
-    # INT8MULTIRANGE,
-    INT8RANGE,
-    INTEGER,
-    INTERVAL,
-    JSON,
-    JSONB,
-    # JSONPATH,
-    MACADDR,
-    # MACADDR8,
-    MONEY,
-    NUMERIC,
-    # NUMMULTIRANGE,
-    NUMRANGE,
-    OID,
-    REAL,
-    REGCLASS,
-    # REGCONFIG,
-    SMALLINT,
-    TEXT,
-    TIME,
-    TIMESTAMP,
-    # TSMULTIRANGE,
-    # TSQUERY,
-    # TSRANGE,
-    # TSTZMULTIRANGE,
-    TSTZRANGE,
-    TSVECTOR,
-    UUID,
-    VARCHAR,
-)
 
 from great_expectations.compatibility.typing_extensions import override
 from great_expectations.datasource.fluent.interfaces import Batch
@@ -63,63 +9,10 @@
     BatchTestSetup,
     DataSourceTestConfig,
 )
-
-# Sqlalchemy follows the convention of exporting all known valid types for a given dialect
-# as uppercase types from the namespace `sqlalchemy.dialects.
-# commented out types are present in SqlAlchemy 2.x but not 1.4
-PostgresColumnType = Union[
-    type[ARRAY],
-    type[BIGINT],
-    type[BIT],
-    type[BOOLEAN],
-    type[BYTEA],
-    type[CHAR],
-    type[CIDR],
-    # type[CITEXT],
-    type[DATE],
-    # type[DATEMULTIRANGE],
-    type[DATERANGE],
-    # type[DOMAIN],
-    type[DOUBLE_PRECISION],
-    type[ENUM],
-    type[FLOAT],
-    type[HSTORE],
-    type[INET],
-    # type[INT4MULTIRANGE],
-    type[INT4RANGE],
-    # type[INT8MULTIRANGE],
-    type[INT8RANGE],
-    type[INTEGER],
-    type[INTERVAL],
-    type[JSON],
-    type[JSONB],
-    # type[JSONPATH],
-    type[MACADDR],
-    # type[MACADDR8],
-    type[MONEY],
-    type[NUMERIC],
-    # type[NUMMULTIRANGE],
-    type[NUMRANGE],
-    type[OID],
-    type[REAL],
-    type[REGCLASS],
-    # type[REGCONFIG],
-    type[SMALLINT],
-    type[TEXT],
-    type[TIME],
-    type[TIMESTAMP],
-    # type[TSMULTIRANGE],
-    # type[TSQUERY],
-    # type[TSRANGE],
-    # type[TSTZMULTIRANGE],
-    type[TSTZRANGE],
-    type[TSVECTOR],
-    type[UUID],
-    type[VARCHAR],
-]
+from tests.integration.test_utils.data_source_config.sql import SQLBatchTestSetup
 
 
-class PostgreSQLDatasourceTestConfig(DataSourceTestConfig[PostgresColumnType]):
+class PostgreSQLDatasourceTestConfig(DataSourceTestConfig):
     @property
     @override
     def label(self) -> str:
@@ -144,21 +37,16 @@ def create_batch_setup(
         )
 
 
-class PostgresBatchTestSetup(BatchTestSetup[PostgreSQLDatasourceTestConfig]):
-    def __init__(
-        self,
-        config: PostgreSQLDatasourceTestConfig,
-        data: pd.DataFrame,
-        extra_data: Mapping[str, pd.DataFrame],
-    ) -> None:
-        self.table_name = f"postgres_expectation_test_table_{randint(0, 1000000)}"
-        self.connection_string = "postgresql+psycopg2://postgres@localhost:5432/test_ci"
-        self.engine = create_engine(url=self.connection_string)
-        self.metadata = MetaData()
-        self.tables: Union[list[Table], None] = None
-        self.schema = "public"
-        self.extra_data = extra_data
-        super().__init__(config=config, data=data)
+class PostgresBatchTestSetup(SQLBatchTestSetup[PostgreSQLDatasourceTestConfig]):
+    @override
+    @property
+    def connection_string(self) -> str:
+        return "postgresql+psycopg2://postgres@localhost:5432/test_ci"
+
+    @override
+    @property
+    def schema(self) -> Union[str, None]:
+        return "public"
 
     @override
     def make_batch(self) -> Batch:
@@ -175,52 +63,3 @@ def make_batch(self) -> Batch:
             .add_batch_definition_whole_table(name=name)
             .get_batch()
         )
-
-    @override
-    def setup(self) -> None:
-        main_table = self._create_table(name=self.table_name, columns=self.get_column_types())
-        extra_tables = {
-            table_name: self._create_table(
-                name=table_name,
-                columns=self.get_extra_column_types(table_name),
-            )
-            for table_name in self.extra_data
-        }
-        self.tables = [main_table, *extra_tables.values()]
-
-        self.metadata.create_all(self.engine)
-        with self.engine.connect() as conn:
-            # pd.DataFrame(...).to_dict("index") returns a dictionary where the keys are the row
-            # index and the values are a dict of column names mapped to column values.
-            # Then we pass that list of dicts in as parameters to our insert statement.
-            # INSERT INTO test_table (my_int_column, my_str_column) VALUES (?, ?)
-            # [...] [('1', 'foo'), ('2', 'bar')]
-            with conn.begin():
-                conn.execute(insert(main_table), list(self.data.to_dict("index").values()))
-                for table_name, table_data in self.extra_data.items():
-                    conn.execute(
-                        insert(extra_tables[table_name]),
-                        list(table_data.to_dict("index").values()),
-                    )
-
-    @override
-    def teardown(self) -> None:
-        if self.tables:
-            for table in self.tables:
-                table.drop(self.engine)
-
-    def _create_table(self, name: str, columns: Mapping[str, PostgresColumnType]) -> Table:
-        column_list = [Column(col_name, col_type) for col_name, col_type in columns.items()]
-        return Table(name, self.metadata, *column_list, schema=self.schema)
-
-    def get_column_types(self) -> Mapping[str, PostgresColumnType]:
-        if self.config.column_types is None:
-            return {}
-        return self.config.column_types
-
-    def get_extra_column_types(self, table_name: str) -> Mapping[str, PostgresColumnType]:
-        extra_assets = self.config.extra_assets
-        if not extra_assets:
-            return {}
-        else:
-            return extra_assets[table_name]
diff --git a/tests/integration/test_utils/data_source_config/snowflake.py b/tests/integration/test_utils/data_source_config/snowflake.py
index a24acfa8864f..ed96c5b99db5 100644
--- a/tests/integration/test_utils/data_source_config/snowflake.py
+++ b/tests/integration/test_utils/data_source_config/snowflake.py
@@ -1,23 +1,16 @@
-from random import randint
-from typing import Any, Mapping, Union
+from typing import Mapping, Union
 
 import pandas as pd
 import pytest
 
 from great_expectations.compatibility.pydantic import BaseSettings
-from great_expectations.compatibility.sqlalchemy import (
-    Column,
-    MetaData,
-    Table,
-    create_engine,
-    insert,
-)
 from great_expectations.compatibility.typing_extensions import override
 from great_expectations.datasource.fluent.interfaces import Batch
 from tests.integration.test_utils.data_source_config.base import (
     BatchTestSetup,
     DataSourceTestConfig,
 )
+from tests.integration.test_utils.data_source_config.sql import SQLBatchTestSetup
 
 
 class SnowflakeDatasourceTestConfig(DataSourceTestConfig):
@@ -68,20 +61,25 @@ def connection_string(self) -> str:
         )
 
 
-class SnowflakeBatchTestSetup(BatchTestSetup[SnowflakeDatasourceTestConfig]):
+class SnowflakeBatchTestSetup(SQLBatchTestSetup[SnowflakeDatasourceTestConfig]):
+    @override
+    @property
+    def connection_string(self) -> str:
+        return self.snowflake_connection_config.connection_string
+
+    @override
+    @property
+    def schema(self) -> Union[str, None]:
+        return self.snowflake_connection_config.SNOWFLAKE_SCHEMA
+
     def __init__(
         self,
         config: SnowflakeDatasourceTestConfig,
         data: pd.DataFrame,
         extra_data: Mapping[str, pd.DataFrame],
     ) -> None:
-        self.table_name = f"snowflake_expectation_test_table_{randint(0, 1000000)}"
         self.snowflake_connection_config = SnowflakeConnectionConfig()  # type: ignore[call-arg]  # retrieves env vars
-        self.engine = create_engine(url=self.snowflake_connection_config.connection_string)
-        self.metadata = MetaData()
-        self.tables: Union[list[Table], None] = None
-        self.extra_data = extra_data
-        super().__init__(config=config, data=data)
+        super().__init__(config=config, data=data, extra_data=extra_data)
 
     @override
     def make_batch(self) -> Batch:
@@ -104,59 +102,3 @@ def make_batch(self) -> Batch:
             .add_batch_definition_whole_table(name=name)
             .get_batch()
         )
-
-    @override
-    def setup(self) -> None:
-        main_table = self._create_table(name=self.table_name, columns=self.get_column_types())
-        extra_tables = {
-            table_name: self._create_table(
-                name=table_name,
-                columns=self.get_extra_column_types(table_name),
-            )
-            for table_name in self.extra_data
-        }
-        self.tables = [main_table, *extra_tables.values()]
-
-        self.metadata.create_all(self.engine)
-        with self.engine.connect() as conn:
-            # pd.DataFrame(...).to_dict("index") returns a dictionary where the keys are the row
-            # index and the values are a dict of column names mapped to column values.
-            # Then we pass that list of dicts in as parameters to our insert statement.
-            # INSERT INTO test_table (my_int_column, my_str_column) VALUES (?, ?)
-            # [...] [('1', 'foo'), ('2', 'bar')]
-            with conn.begin():
-                conn.execute(insert(main_table), list(self.data.to_dict("index").values()))
-                for table_name, table_data in self.extra_data.items():
-                    conn.execute(
-                        insert(extra_tables[table_name]),
-                        list(table_data.to_dict("index").values()),
-                    )
-
-    @override
-    def teardown(self) -> None:
-        if self.tables:
-            for table in self.tables:
-                table.drop(self.engine)
-
-    def _create_table(self, name: str, columns: Mapping[str, Any]) -> Table:
-        column_list: list[Column] = [
-            Column(col_name, col_type) for col_name, col_type in columns.items()
-        ]
-        return Table(
-            name,
-            self.metadata,
-            *column_list,
-            schema=self.snowflake_connection_config.SNOWFLAKE_SCHEMA,
-        )
-
-    def get_column_types(self) -> Mapping[str, Any]:
-        if self.config.column_types is None:
-            return {}
-        return self.config.column_types
-
-    def get_extra_column_types(self, table_name: str) -> Mapping[str, Any]:
-        extra_assets = self.config.extra_assets
-        if not extra_assets:
-            return {}
-        else:
-            return extra_assets[table_name]
diff --git a/tests/integration/test_utils/data_source_config/sql.py b/tests/integration/test_utils/data_source_config/sql.py
new file mode 100644
index 000000000000..611ffc8f7043
--- /dev/null
+++ b/tests/integration/test_utils/data_source_config/sql.py
@@ -0,0 +1,141 @@
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Dict, Generic, List, Mapping, Type, Union
+
+from typing_extensions import override
+
+from great_expectations.compatibility.sqlalchemy import (
+    Column,
+    MetaData,
+    Table,
+    create_engine,
+    insert,
+    sqltypes,
+)
+from tests.integration.test_utils.data_source_config.base import BatchTestSetup, _ConfigT
+
+if TYPE_CHECKING:
+    import pandas as pd
+
+    from great_expectations.compatibility.sqlalchemy import TypeEngine
+
+
+@dataclass
+class TableData:
+    name: str
+    df: pd.DataFrame
+    column_types: Dict[str, TypeEngine]
+    table: Union[Table, None] = None
+
+
+class SQLBatchTestSetup(BatchTestSetup, ABC, Generic[_ConfigT]):
+    @property
+    @abstractmethod
+    def connection_string(self) -> str:
+        """Connection string used to connect to SQL backend."""
+
+    @property
+    @abstractmethod
+    def schema(self) -> Union[str, None]:
+        """Schema -- if any -- to use when connecting to SQL backend."""
+
+    @property
+    def inferrable_types_lookup(self) -> Dict[Type, TypeEngine]:
+        """Dict of Python type keys mapped to SQL dialect-specific SqlAlchemy types."""
+        # implementations of the class can override this if more specific types are required
+        return {
+            str: sqltypes.VARCHAR,  # type: ignore[dict-item]
+            int: sqltypes.INTEGER,  # type: ignore[dict-item]
+            float: sqltypes.DECIMAL,  # type: ignore[dict-item]
+            bool: sqltypes.BOOLEAN,  # type: ignore[dict-item]
+        }
+
+    def __init__(
+        self,
+        config: _ConfigT,
+        data: pd.DataFrame,
+        extra_data: Mapping[str, pd.DataFrame],
+    ) -> None:
+        self.extra_data = extra_data
+        self.table_name = f"{config.label}_expectation_test_table_{self._random_resource_name()}"
+        self.engine = create_engine(url=self.connection_string)
+        self.metadata = MetaData()
+        self.tables: List[Table] = []
+        super().__init__(config, data)
+
+    @override
+    def setup(self) -> None:
+        main_table_data = TableData(
+            name=self.table_name, df=self.data, column_types=self.config.column_types or {}
+        )
+        extra_table_data = [
+            TableData(name=name, df=df, column_types=self.config.extra_column_types.get(name, {}))
+            for name, df in self.extra_data.items()
+        ]
+        all_table_data = [main_table_data, *extra_table_data]
+
+        # create tables
+        for table_data in all_table_data:
+            columns = self.get_column_types(table_data)
+            table = self.create_table(table_data.name, columns=columns)
+            self.tables.append(table)
+            table_data.table = table
+        self.metadata.create_all(self.engine)
+
+        # insert data
+        with self.engine.connect() as conn, conn.begin():
+            for table_data in all_table_data:
+                if table_data.table is None:
+                    raise RuntimeError("Table must be created before data can be loaded.")
+                # pd.DataFrame(...).to_dict("index") returns a dictionary where the keys are the row
+                # index and the values are a dict of column names mapped to column values.
+                # Then we pass that list of dicts in as parameters to our insert statement.
+                # INSERT INTO test_table (my_int_column, my_str_column) VALUES (?, ?)
+                # [...] [('1', 'foo'), ('2', 'bar')]
+                conn.execute(
+                    insert(table_data.table), list(table_data.df.to_dict("index").values())
+                )
+
+    @override
+    def teardown(self) -> None:
+        for table in self.tables:
+            table.drop(self.engine)
+
+    def create_table(self, name: str, columns: Mapping[str, TypeEngine]) -> Table:
+        column_list = [Column(col_name, col_type) for col_name, col_type in columns.items()]
+        return Table(name, self.metadata, *column_list, schema=self.schema)
+
+    def get_column_types(
+        self,
+        table_data: TableData,
+    ) -> Mapping[str, TypeEngine]:
+        column_types = self.infer_column_types(table_data.df)
+        # prefer explicit types if they're provided
+        column_types.update(table_data.column_types)
+        untyped_columns = set(table_data.df.columns) - set(column_types.keys())
+        if untyped_columns:
+            config_class_name = self.config.__class__.__name__
+            message = (
+                f"Unable to infer types for the following column(s): "
+                f"{', '.join(untyped_columns)}.\n"
+                f"Please provide the missing types as the `column_types` "
+                f"parameter when instantiating {config_class_name}."
+            )
+            raise RuntimeError(message)
+        return column_types
+
+    def infer_column_types(self, data: pd.DataFrame) -> Dict[str, TypeEngine]:
+        inferred_column_types: Dict[str, TypeEngine] = {}
+        for column, value_list in data.to_dict("list").items():
+            python_type = type(value_list[0])
+            if not all(isinstance(val, python_type) for val in value_list):
+                raise RuntimeError(
+                    f"Cannot infer type of column {column}. "
+                    "Please provide an explicit column type in the test config."
+                )
+            inferred_type = self.inferrable_types_lookup.get(python_type)
+            if inferred_type:
+                inferred_column_types[str(column)] = inferred_type
+        return inferred_column_types