[FEATURE] Expectations tests against SQL backends infer column types (#…
joshua-stauffer authored Nov 4, 2024
1 parent 424d9cf commit 61a4b3d
Showing 6 changed files with 188 additions and 271 deletions.
4 changes: 4 additions & 0 deletions great_expectations/compatibility/sqlalchemy.py
@@ -278,6 +278,10 @@
 except (ImportError, AttributeError):
     __version__ = None

+try:
+    from sqlalchemy.sql.type_api import TypeEngine
+except (ImportError, AttributeError):
+    TypeEngine = SQLALCHEMY_NOT_IMPORTED  # type: ignore[misc,assignment]

 try:
     from sqlalchemy.sql import sqltypes
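The new block follows this compatibility module's existing pattern: each optional SQLAlchemy symbol is bound to a sentinel when the import fails, so downstream code can feature-test instead of crashing at import time. TypeEngine is the base class of every SQLAlchemy column type, which is what makes it useful for inference work. A minimal sketch of the kind of check this enables (the helper is illustrative, not part of the commit):

import great_expectations.compatibility.sqlalchemy as sa_compat

def is_sqlalchemy_type(obj: object) -> bool:
    # Every concrete SQLAlchemy type (sqltypes.INTEGER(), sqltypes.VARCHAR(255), ...)
    # subclasses TypeEngine, so a single isinstance check covers them all.
    return isinstance(obj, sa_compat.TypeEngine)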
@@ -3,7 +3,6 @@
 import pandas as pd

 import great_expectations.expectations as gxe
-from great_expectations.compatibility.snowflake import SNOWFLAKE_TYPES
 from great_expectations.compatibility.sqlalchemy import sqltypes
 from tests.integration.conftest import parameterize_batch_for_data_sources
 from tests.integration.test_utils.data_source_config import (
@@ -18,8 +17,8 @@
     data_source_configs=[
         PandasDataFrameDatasourceTestConfig(),
         PandasFilesystemCsvDatasourceTestConfig(),
-        PostgreSQLDatasourceTestConfig(column_types={"a": sqltypes.INTEGER}),
-        SnowflakeDatasourceTestConfig(column_types={"a": SNOWFLAKE_TYPES.NUMBER}),
+        PostgreSQLDatasourceTestConfig(),
+        SnowflakeDatasourceTestConfig(),
     ],
     data=pd.DataFrame({"a": [1, 2]}),
 )
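Explicit annotations appear to remain supported, since column_types is unchanged on DataSourceTestConfig (see base.py below); inference merely becomes the default. A test that needs to pin a type rather than rely on inference could presumably still write:

PostgreSQLDatasourceTestConfig(column_types={"a": sqltypes.INTEGER})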
@@ -87,8 +86,8 @@ def test_expect_column_max_to_be_between__date(batch_for_datasource) -> None:
     data_source_configs=[
         PandasDataFrameDatasourceTestConfig(),
         PandasFilesystemCsvDatasourceTestConfig(),
-        SnowflakeDatasourceTestConfig(column_types={"a": SNOWFLAKE_TYPES.NUMBER}),
-        PostgreSQLDatasourceTestConfig(column_types={"a": sqltypes.INTEGER}),
+        PostgreSQLDatasourceTestConfig(),
+        SnowflakeDatasourceTestConfig(),
     ],
     data=pd.DataFrame({"a": [1, 2]}),
 )
@@ -102,8 +101,8 @@ def test_expect_column_max_to_be_between(batch_for_datasource) -> None:
     data_source_configs=[
         PandasDataFrameDatasourceTestConfig(),
         PandasFilesystemCsvDatasourceTestConfig(),
-        SnowflakeDatasourceTestConfig(column_types={"a": SNOWFLAKE_TYPES.NUMBER}),
-        PostgreSQLDatasourceTestConfig(column_types={"a": sqltypes.INTEGER}),
+        PostgreSQLDatasourceTestConfig(),
+        SnowflakeDatasourceTestConfig(),
     ],
     data=pd.DataFrame({"a": [1, 2]}),
 )
@@ -117,8 +116,8 @@ def test_expect_column_to_exist(batch_for_datasource):
     data_source_configs=[
         PandasDataFrameDatasourceTestConfig(),
         PandasFilesystemCsvDatasourceTestConfig(),
-        SnowflakeDatasourceTestConfig(column_types={"a": SNOWFLAKE_TYPES.NUMBER}),
-        PostgreSQLDatasourceTestConfig(column_types={"a": sqltypes.INTEGER}),
+        PostgreSQLDatasourceTestConfig(),
+        SnowflakeDatasourceTestConfig(),
     ],
     data=pd.DataFrame({"a": [1, 2]}),
 )
@@ -132,8 +131,8 @@ def test_expect_column_values_to_not_be_null(batch_for_datasource):
     data_source_configs=[
         PandasDataFrameDatasourceTestConfig(),
         PandasFilesystemCsvDatasourceTestConfig(),
-        SnowflakeDatasourceTestConfig(column_types={"a": SNOWFLAKE_TYPES.NUMBER}),
-        PostgreSQLDatasourceTestConfig(column_types={"a": sqltypes.INTEGER}),
+        PostgreSQLDatasourceTestConfig(),
+        SnowflakeDatasourceTestConfig(),
     ],
     data=pd.DataFrame({"a": [1, 2, 3, 4]}),
 )
@@ -145,12 +144,7 @@ def test_expect_column_mean_to_be_between(batch_for_datasource):

 class TestExpectTableRowCountToEqualOtherTable:
     @parameterize_batch_for_data_sources(
-        data_source_configs=[
-            PostgreSQLDatasourceTestConfig(
-                column_types={"col_a": sqltypes.INTEGER},
-                extra_assets={"test_table_two": {"col_b": sqltypes.VARCHAR}},
-            ),
-        ],
+        data_source_configs=[PostgreSQLDatasourceTestConfig(), SnowflakeDatasourceTestConfig()],
         data=pd.DataFrame({"a": [1, 2, 3, 4]}),
         extra_data={"test_table_two": pd.DataFrame({"col_b": ["a", "b", "c", "d"]})},
     )
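The removed extra_assets argument duplicated schema information that the extra_data DataFrames already carry: each extra table's column types are now inferred from its DataFrame just like the main table's, and the Snowflake backend picks up these tests in the process. The renamed extra_column_types field shown in base.py below presumably remains as an explicit escape hatch, e.g.:

PostgreSQLDatasourceTestConfig(
    extra_column_types={"test_table_two": {"col_b": sqltypes.VARCHAR}},
)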
@@ -161,10 +155,8 @@ def test_success(self, batch_for_datasource):

     @parameterize_batch_for_data_sources(
         data_source_configs=[
-            PostgreSQLDatasourceTestConfig(
-                column_types={"col_a": sqltypes.INTEGER},
-                extra_assets={"test_table_two": {"col_b": sqltypes.VARCHAR}},
-            ),
+            PostgreSQLDatasourceTestConfig(),
+            SnowflakeDatasourceTestConfig(),
         ],
         data=pd.DataFrame({"a": [1, 2, 3, 4]}),
         extra_data={"test_table_two": pd.DataFrame({"col_b": ["just_this_one!"]})},
@@ -180,9 +172,8 @@ def test_different_counts(self, batch_for_datasource):

     @parameterize_batch_for_data_sources(
         data_source_configs=[
-            PostgreSQLDatasourceTestConfig(
-                column_types={"col_a": sqltypes.INTEGER},
-            ),
+            PostgreSQLDatasourceTestConfig(),
+            SnowflakeDatasourceTestConfig(),
         ],
         data=pd.DataFrame({"a": [1, 2, 3, 4]}),
     )
4 changes: 2 additions & 2 deletions tests/integration/test_utils/data_source_config/base.py
@@ -3,7 +3,7 @@
 import random
 import string
 from abc import ABC, abstractmethod
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from functools import cached_property
 from typing import TYPE_CHECKING, Generic, Mapping, Optional, TypeVar

@@ -23,7 +23,7 @@
 class DataSourceTestConfig(ABC, Generic[_ColumnTypes]):
     name: Optional[str] = None
     column_types: Optional[Mapping[str, _ColumnTypes]] = None
-    extra_assets: Optional[Mapping[str, Mapping[str, _ColumnTypes]]] = None
+    extra_column_types: Mapping[str, Mapping[str, _ColumnTypes]] = field(default_factory=dict)

     @property
     @abstractmethod
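The switch from Optional[...] = None to field(default_factory=dict) is more than cosmetic: dataclasses reject a bare {} default because it would be shared by every instance, and default_factory builds a fresh dict per instance while sparing consumers the "if x is None" guard. A standalone illustration of the behavior:

from dataclasses import dataclass, field
from typing import Mapping

@dataclass
class Example:
    # `extra: Mapping[str, str] = {}` would raise ValueError at class-creation
    # time (mutable default); default_factory gives each instance its own dict.
    extra: Mapping[str, str] = field(default_factory=dict)

a, b = Example(), Example()
assert a.extra == {} and a.extra is not b.extra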
185 changes: 12 additions & 173 deletions tests/integration/test_utils/data_source_config/postgres.py
@@ -1,125 +1,18 @@
-from random import randint
 from typing import Mapping, Union

 import pandas as pd
 import pytest
-from sqlalchemy import Column, MetaData, Table, create_engine, insert
-
-# commented out types are present in SqlAlchemy 2.x but not 1.4
-from sqlalchemy.dialects.postgresql import (
-    ARRAY,
-    BIGINT,
-    BIT,
-    BOOLEAN,
-    BYTEA,
-    CHAR,
-    CIDR,
-    # CITEXT,
-    DATE,
-    # DATEMULTIRANGE,
-    DATERANGE,
-    # DOMAIN,
-    DOUBLE_PRECISION,
-    ENUM,
-    FLOAT,
-    HSTORE,
-    INET,
-    # INT4MULTIRANGE,
-    INT4RANGE,
-    # INT8MULTIRANGE,
-    INT8RANGE,
-    INTEGER,
-    INTERVAL,
-    JSON,
-    JSONB,
-    # JSONPATH,
-    MACADDR,
-    # MACADDR8,
-    MONEY,
-    NUMERIC,
-    # NUMMULTIRANGE,
-    NUMRANGE,
-    OID,
-    REAL,
-    REGCLASS,
-    # REGCONFIG,
-    SMALLINT,
-    TEXT,
-    TIME,
-    TIMESTAMP,
-    # TSMULTIRANGE,
-    # TSQUERY,
-    # TSRANGE,
-    # TSTZMULTIRANGE,
-    TSTZRANGE,
-    TSVECTOR,
-    UUID,
-    VARCHAR,
-)

 from great_expectations.compatibility.typing_extensions import override
 from great_expectations.datasource.fluent.interfaces import Batch
 from tests.integration.test_utils.data_source_config.base import (
     BatchTestSetup,
     DataSourceTestConfig,
 )

-# Sqlalchemy follows the convention of exporting all known valid types for a given dialect
-# as uppercase types from the namespace `sqlalchemy.dialects.<dialect>
-# commented out types are present in SqlAlchemy 2.x but not 1.4
-PostgresColumnType = Union[
-    type[ARRAY],
-    type[BIGINT],
-    type[BIT],
-    type[BOOLEAN],
-    type[BYTEA],
-    type[CHAR],
-    type[CIDR],
-    # type[CITEXT],
-    type[DATE],
-    # type[DATEMULTIRANGE],
-    type[DATERANGE],
-    # type[DOMAIN],
-    type[DOUBLE_PRECISION],
-    type[ENUM],
-    type[FLOAT],
-    type[HSTORE],
-    type[INET],
-    # type[INT4MULTIRANGE],
-    type[INT4RANGE],
-    # type[INT8MULTIRANGE],
-    type[INT8RANGE],
-    type[INTEGER],
-    type[INTERVAL],
-    type[JSON],
-    type[JSONB],
-    # type[JSONPATH],
-    type[MACADDR],
-    # type[MACADDR8],
-    type[MONEY],
-    type[NUMERIC],
-    # type[NUMMULTIRANGE],
-    type[NUMRANGE],
-    type[OID],
-    type[REAL],
-    type[REGCLASS],
-    # type[REGCONFIG],
-    type[SMALLINT],
-    type[TEXT],
-    type[TIME],
-    type[TIMESTAMP],
-    # type[TSMULTIRANGE],
-    # type[TSQUERY],
-    # type[TSRANGE],
-    # type[TSTZMULTIRANGE],
-    type[TSTZRANGE],
-    type[TSVECTOR],
-    type[UUID],
-    type[VARCHAR],
-]
+from tests.integration.test_utils.data_source_config.sql import SQLBatchTestSetup


-class PostgreSQLDatasourceTestConfig(DataSourceTestConfig[PostgresColumnType]):
+class PostgreSQLDatasourceTestConfig(DataSourceTestConfig):
     @property
     @override
     def label(self) -> str:
@@ -144,21 +37,16 @@ def create_batch_setup(
 )


-class PostgresBatchTestSetup(BatchTestSetup[PostgreSQLDatasourceTestConfig]):
-    def __init__(
-        self,
-        config: PostgreSQLDatasourceTestConfig,
-        data: pd.DataFrame,
-        extra_data: Mapping[str, pd.DataFrame],
-    ) -> None:
-        self.table_name = f"postgres_expectation_test_table_{randint(0, 1000000)}"
-        self.connection_string = "postgresql+psycopg2://postgres@localhost:5432/test_ci"
-        self.engine = create_engine(url=self.connection_string)
-        self.metadata = MetaData()
-        self.tables: Union[list[Table], None] = None
-        self.schema = "public"
-        self.extra_data = extra_data
-        super().__init__(config=config, data=data)
+class PostgresBatchTestSetup(SQLBatchTestSetup[PostgreSQLDatasourceTestConfig]):
+    @override
+    @property
+    def connection_string(self) -> str:
+        return "postgresql+psycopg2://postgres@localhost:5432/test_ci"
+
+    @override
+    @property
+    def schema(self) -> Union[str, None]:
+        return "public"

     @override
     def make_batch(self) -> Batch:
@@ -175,52 +63,3 @@ def make_batch(self) -> Batch:
             .add_batch_definition_whole_table(name=name)
             .get_batch()
         )

-    @override
-    def setup(self) -> None:
-        main_table = self._create_table(name=self.table_name, columns=self.get_column_types())
-        extra_tables = {
-            table_name: self._create_table(
-                name=table_name,
-                columns=self.get_extra_column_types(table_name),
-            )
-            for table_name in self.extra_data
-        }
-        self.tables = [main_table, *extra_tables.values()]
-
-        self.metadata.create_all(self.engine)
-        with self.engine.connect() as conn:
-            # pd.DataFrame(...).to_dict("index") returns a dictionary where the keys are the row
-            # index and the values are a dict of column names mapped to column values.
-            # Then we pass that list of dicts in as parameters to our insert statement.
-            # INSERT INTO test_table (my_int_column, my_str_column) VALUES (?, ?)
-            # [...] [('1', 'foo'), ('2', 'bar')]
-            with conn.begin():
-                conn.execute(insert(main_table), list(self.data.to_dict("index").values()))
-                for table_name, table_data in self.extra_data.items():
-                    conn.execute(
-                        insert(extra_tables[table_name]),
-                        list(table_data.to_dict("index").values()),
-                    )
-
-    @override
-    def teardown(self) -> None:
-        if self.tables:
-            for table in self.tables:
-                table.drop(self.engine)
-
-    def _create_table(self, name: str, columns: Mapping[str, PostgresColumnType]) -> Table:
-        column_list = [Column(col_name, col_type) for col_name, col_type in columns.items()]
-        return Table(name, self.metadata, *column_list, schema=self.schema)
-
-    def get_column_types(self) -> Mapping[str, PostgresColumnType]:
-        if self.config.column_types is None:
-            return {}
-        return self.config.column_types
-
-    def get_extra_column_types(self, table_name: str) -> Mapping[str, PostgresColumnType]:
-        extra_assets = self.config.extra_assets
-        if not extra_assets:
-            return {}
-        else:
-            return extra_assets[table_name]
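The shared SQLBatchTestSetup base class comes from the new tests/integration/test_utils/data_source_config/sql.py, one of the changed files not captured here. Judging from what postgres.py now delegates, it centralizes table naming, engine and metadata management, data insertion, teardown, and the column-type inference this commit is named for. A minimal sketch of how dtype-based inference might work — every name below is illustrative, not the module's actual API:

import pandas as pd

from great_expectations.compatibility.sqlalchemy import TypeEngine, sqltypes

# Hypothetical mapping; the real module may cover more dtypes and dialects.
_DTYPE_TO_SQL: dict[str, type[TypeEngine]] = {
    "int64": sqltypes.INTEGER,
    "float64": sqltypes.FLOAT,
    "bool": sqltypes.BOOLEAN,
    "datetime64[ns]": sqltypes.TIMESTAMP,
}

def infer_column_types(df: pd.DataFrame) -> dict[str, type[TypeEngine]]:
    # Explicit entries in config.column_types would take precedence over this.
    return {
        str(col): _DTYPE_TO_SQL.get(str(dtype), sqltypes.VARCHAR)
        for col, dtype in df.dtypes.items()
    }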
