From 5e11e7c8c77d496077bdc6505bce4fd5d9613158 Mon Sep 17 00:00:00 2001 From: William Shin Date: Wed, 18 Dec 2024 14:21:23 -0800 Subject: [PATCH] [BUGFIX] `Databricks` Fix Type Translation - `ExpectColumnValuesToBeInTypeList` and `ExpectColumnValuesToBeInType` (#10791) --- ...expect_column_values_to_be_in_type_list.py | 2 +- .../expect_column_values_to_be_of_type.py | 2 +- .../expectations/metrics/util.py | 15 +- ...expect_column_values_to_be_in_type_list.py | 212 ++++++++++++++++++ ...test_expect_column_values_to_be_of_type.py | 3 +- 5 files changed, 224 insertions(+), 10 deletions(-) diff --git a/great_expectations/expectations/core/expect_column_values_to_be_in_type_list.py b/great_expectations/expectations/core/expect_column_values_to_be_in_type_list.py index 0f9d82e42771..e0bd10f36942 100644 --- a/great_expectations/expectations/core/expect_column_values_to_be_in_type_list.py +++ b/great_expectations/expectations/core/expect_column_values_to_be_in_type_list.py @@ -458,7 +458,7 @@ def _validate_pandas( # noqa: C901, PLR0912 def _validate_sqlalchemy(self, actual_column_type, expected_types_list, execution_engine): if expected_types_list is None: success = True - elif execution_engine.dialect_name == GXSqlDialect.SNOWFLAKE: + elif execution_engine.dialect_name in [GXSqlDialect.SNOWFLAKE, GXSqlDialect.DATABRICKS]: success = isinstance(actual_column_type, str) and any( actual_column_type.lower() == expected_type.lower() for expected_type in expected_types_list diff --git a/great_expectations/expectations/core/expect_column_values_to_be_of_type.py b/great_expectations/expectations/core/expect_column_values_to_be_of_type.py index 5429c3979882..647053760d6d 100644 --- a/great_expectations/expectations/core/expect_column_values_to_be_of_type.py +++ b/great_expectations/expectations/core/expect_column_values_to_be_of_type.py @@ -412,7 +412,7 @@ def _validate_sqlalchemy(self, actual_column_type, expected_type, execution_engi if expected_type is None: success = True - elif execution_engine.dialect_name == GXSqlDialect.SNOWFLAKE: + elif execution_engine.dialect_name in [GXSqlDialect.SNOWFLAKE, GXSqlDialect.DATABRICKS]: success = ( isinstance(actual_column_type, str) and actual_column_type.lower() == expected_type.lower() diff --git a/great_expectations/expectations/metrics/util.py b/great_expectations/expectations/metrics/util.py index c3dc274bfc94..c89f2e870477 100644 --- a/great_expectations/expectations/metrics/util.py +++ b/great_expectations/expectations/metrics/util.py @@ -414,16 +414,19 @@ def get_sqlalchemy_column_metadata( # noqa: C901 ) dialect_name = execution_engine.dialect.name - if dialect_name == GXSqlDialect.SNOWFLAKE: + if dialect_name in [GXSqlDialect.SNOWFLAKE, GXSqlDialect.DATABRICKS]: # WARNING: Do not alter columns in place, as they are cached on the inspector columns_copy = [column.copy() for column in columns] for column in columns_copy: column["type"] = column["type"].compile(dialect=execution_engine.dialect) - return [ - # TODO: SmartColumn should know the dialect and do lookups based on that - CaseInsensitiveNameDict(column) - for column in columns_copy - ] + if dialect_name == GXSqlDialect.SNOWFLAKE: + return [ + # TODO: SmartColumn should know the dialect and do lookups based on that + CaseInsensitiveNameDict(column) + for column in columns_copy + ] + else: + return columns_copy return columns except AttributeError as e: diff --git a/tests/integration/data_sources_and_expectations/expectations/test_expect_column_values_to_be_in_type_list.py b/tests/integration/data_sources_and_expectations/expectations/test_expect_column_values_to_be_in_type_list.py index c893f71327b9..69bbbb675dc2 100644 --- a/tests/integration/data_sources_and_expectations/expectations/test_expect_column_values_to_be_in_type_list.py +++ b/tests/integration/data_sources_and_expectations/expectations/test_expect_column_values_to_be_in_type_list.py @@ -1,9 +1,14 @@ import pandas as pd import pytest import sqlalchemy.types as sqltypes +from packaging import version import great_expectations.expectations as gxe +from great_expectations.compatibility.databricks import DATABRICKS_TYPES from great_expectations.compatibility.snowflake import SNOWFLAKE_TYPES +from great_expectations.compatibility.sqlalchemy import ( + sqlalchemy as sa, +) from great_expectations.core.result_format import ResultFormat from great_expectations.datasource.fluent.interfaces import Batch from tests.integration.conftest import parameterize_batch_for_data_sources @@ -379,3 +384,210 @@ def test_success_complete_snowflake( assert isinstance(result_dict["observed_value"], str) assert isinstance(expectation.type_list, list) assert result_dict["observed_value"] in expectation.type_list + + +@pytest.mark.parametrize( + "expectation", + [ + pytest.param( + gxe.ExpectColumnValuesToBeInTypeList(column="STRING", type_list=["STRING"]), + id="STRING", + ), + # SqlA Text gets converted to Databricks STRING + pytest.param( + gxe.ExpectColumnValuesToBeInTypeList(column="TEXT", type_list=["STRING"]), + id="TEXT", + ), + # SqlA UNICODE gets converted to Databricks STRING + pytest.param( + gxe.ExpectColumnValuesToBeInTypeList(column="UNICODE", type_list=["STRING"]), + id="UNICODE", + ), + # SqlA UNICODE_TEXT gets converted to Databricks STRING + pytest.param( + gxe.ExpectColumnValuesToBeInTypeList(column="UNICODE_TEXT", type_list=["STRING"]), + id="UNICODE_TEXT", + ), + pytest.param( + gxe.ExpectColumnValuesToBeInTypeList(column="BOOLEAN", type_list=["BOOLEAN"]), + id="BOOLEAN", + ), + pytest.param( + gxe.ExpectColumnValuesToBeInTypeList( + column="DECIMAL", type_list=["DECIMAL", "DECIMAL(10, 0)"] + ), + id="DECIMAL", + ), + pytest.param( + gxe.ExpectColumnValuesToBeInTypeList(column="DATE", type_list=["DATE"]), + id="DATE", + ), + pytest.param( + gxe.ExpectColumnValuesToBeInTypeList(column="TIMESTAMP", type_list=["TIMESTAMP"]), + id="TIMESTAMP", + ), + pytest.param( + gxe.ExpectColumnValuesToBeInTypeList( + column="TIMESTAMP_NTZ", type_list=["TIMESTAMP_NTZ"] + ), + id="TIMESTAMP_NTZ", + ), + pytest.param( + gxe.ExpectColumnValuesToBeInTypeList(column="FLOAT", type_list=["FLOAT"]), + id="FLOAT", + ), + pytest.param( + gxe.ExpectColumnValuesToBeInTypeList(column="INT", type_list=["INT"]), + id="INT", + ), + pytest.param( + gxe.ExpectColumnValuesToBeInTypeList(column="TINYINT", type_list=["TINYINT"]), + id="TINYINT", + ), + pytest.param( + gxe.ExpectColumnValuesToBeInTypeList( + column="DECIMAL", type_list=["DECIMAL", "DECIMAL(10, 0)"] + ), + id="DECIMAL", + ), + # SqlA Time gets converted to Databricks STRING, + # but is not supported by our testing framework + # pytest.param( + # gxe.ExpectColumnValuesToBeInTypeList(column="TIME", type_list=["STRING"]), + # id="TIME", + # ), + # SqlA UUID gets converted to Databricks STRING, + # but is not supported by our testing framework. + # pytest.param( + # gxe.ExpectColumnValuesToBeInTypeList(column="UUID", type_list=["STRING"]), + # id="UUID", + # ) + ], +) +@parameterize_batch_for_data_sources( + data_source_configs=[ + DatabricksDatasourceTestConfig( + column_types={ + "STRING": DATABRICKS_TYPES.STRING, + "TEXT": sqltypes.Text, + "UNICODE": sqltypes.Unicode, + "UNICODE_TEXT": sqltypes.UnicodeText, + "BIGINT": sqltypes.BigInteger, + "BOOLEAN": sqltypes.BOOLEAN, + "DATE": sqltypes.DATE, + "TIMESTAMP_NTZ": DATABRICKS_TYPES.TIMESTAMP_NTZ, + "TIMESTAMP": DATABRICKS_TYPES.TIMESTAMP, + "FLOAT": sqltypes.Float, + "INT": sqltypes.Integer, + "DECIMAL": sqltypes.Numeric, + "SMALLINT": sqltypes.SmallInteger, + "TINYINT": DATABRICKS_TYPES.TINYINT, + # "TIME": sqltypes.Time, + # "UUID": sqltypes.UUID, + } + ) + ], + data=pd.DataFrame( + { + "STRING": ["a", "b", "c"], + "TEXT": ["a", "b", "c"], + "UNICODE": ["\u00e9", "\u00e9", "\u00e9"], + "UNICODE_TEXT": ["a", "b", "c"], + "BIGINT": [1111, 2222, 3333], + "BOOLEAN": [True, True, False], + "DATE": [ + "2021-01-01", + "2021-01-02", + "2021-01-03", + ], + "TIMESTAMP_NTZ": [ + "2021-01-01 00:00:00", + "2021-01-02 00:00:00", + "2021-01-03 00:00:00", + ], + "TIMESTAMP": [ + "2021-01-01 00:00:00", + "2021-01-02 00:00:00", + "2021-01-03 00:00:00", + ], + "DOUBLE": [1.0, 2.0, 3.0], + "FLOAT": [1.0, 2.0, 3.0], + "INT": [1, 2, 3], + "DECIMAL": [1.1, 2.2, 3.3], + "SMALLINT": [1, 2, 3], + # "TIME": [ + # sa.Time("22:17:33.123456"), + # sa.Time("22:17:33.123456"), + # sa.Time("22:17:33.123456"), + # ], + # "UUID": [ + # uuid.UUID("905993ea-f50e-4284-bea0-5be3f0ed7031"), + # uuid.UUID("9406b631-fa2f-41cf-b666-f9a2ac3118c1"), + # uuid.UUID("47538f05-32e3-4594-80e2-0b3b33257ae7") + # ], + }, + dtype="object", + ), +) +def test_success_complete_databricks( + batch_for_datasource: Batch, expectation: gxe.ExpectColumnValuesToBeInTypeList +) -> None: + result = batch_for_datasource.validate(expectation, result_format=ResultFormat.COMPLETE) + result_dict = result.to_json_dict()["result"] + + assert result.success + assert isinstance(result_dict, dict) + assert isinstance(result_dict["observed_value"], str) + assert isinstance(expectation.type_list, list) + assert result_dict["observed_value"] in expectation.type_list + + +if version.parse(sa.__version__) >= version.parse("2.0.0"): + # Note: why not use pytest.skip? + # the import of `sqltypes.Double` is only possible in sqlalchemy >= 2.0.0 + # the import is done as part of the instantiation of the test, which includes + # processing the pytest.skip() statement. This way, we skip the instantiation + # of the test entirely. + @pytest.mark.parametrize( + "expectation", + [ + pytest.param( + gxe.ExpectColumnValuesToBeInTypeList( + column="DOUBLE", type_list=["DOUBLE", "FLOAT"] + ), + id="DOUBLE", + ) + ], + ) + @parameterize_batch_for_data_sources( + data_source_configs=[ + DatabricksDatasourceTestConfig( + column_types={ + "DOUBLE": sqltypes.Double, + } + ) + ], + data=pd.DataFrame( + { + "DOUBLE": [1.0, 2.0, 3.0], + }, + dtype="object", + ), + ) + def test_success_complete_databricks_double_type_only( + batch_for_datasource: Batch, expectation: gxe.ExpectColumnValuesToBeInTypeList + ) -> None: + """What does this test and why? + + Databricks mostly uses SqlA types directly, but the double type is + only available after sqlalchemy 2.0. We therefore split up the test + into 2 parts, with this test being skipped if the SA version is too low. + """ + result = batch_for_datasource.validate(expectation, result_format=ResultFormat.COMPLETE) + result_dict = result.to_json_dict()["result"] + + assert result.success + assert isinstance(result_dict, dict) + assert isinstance(result_dict["observed_value"], str) + assert isinstance(expectation.type_list, list) + assert result_dict["observed_value"] in expectation.type_list diff --git a/tests/integration/data_sources_and_expectations/expectations/test_expect_column_values_to_be_of_type.py b/tests/integration/data_sources_and_expectations/expectations/test_expect_column_values_to_be_of_type.py index 719931c5e099..ffcf64dab0b7 100644 --- a/tests/integration/data_sources_and_expectations/expectations/test_expect_column_values_to_be_of_type.py +++ b/tests/integration/data_sources_and_expectations/expectations/test_expect_column_values_to_be_of_type.py @@ -66,13 +66,12 @@ def test_success_for_type__INTEGER(batch_for_datasource: Batch) -> None: assert result.success -@pytest.mark.xfail @parameterize_batch_for_data_sources( data_source_configs=[DatabricksDatasourceTestConfig()], data=DATA, ) def test_success_for_type__Integer(batch_for_datasource: Batch) -> None: - expectation = gxe.ExpectColumnValuesToBeOfType(column=INTEGER_COLUMN, type_="Integer") + expectation = gxe.ExpectColumnValuesToBeOfType(column=INTEGER_COLUMN, type_="INT") result = batch_for_datasource.validate(expectation) assert result.success