diff --git a/.changes/unreleased/Features-20240220-195925.yaml b/.changes/unreleased/Features-20240220-195925.yaml
new file mode 100644
index 000000000..c5d86ab7c
--- /dev/null
+++ b/.changes/unreleased/Features-20240220-195925.yaml
@@ -0,0 +1,6 @@
+kind: Features
+body: Implement spark__safe_cast and add functional tests for unit testing
+time: 2024-02-20T19:59:25.907821-05:00
+custom:
+  Author: michelleark
+  Issue: "987"
diff --git a/dagger/run_dbt_spark_tests.py b/dagger/run_dbt_spark_tests.py
index 436cb1e92..15f9cf2c2 100644
--- a/dagger/run_dbt_spark_tests.py
+++ b/dagger/run_dbt_spark_tests.py
@@ -112,15 +112,27 @@ async def test_spark(test_args):
         .with_exec(["./scripts/install_os_reqs.sh"])
         # install dbt-spark + python deps
         .with_directory("/src", req_files)
-        .with_directory("src/dbt", dbt_spark_dir)
-        .with_directory("src/tests", test_dir)
-        .with_workdir("/src")
         .with_exec(["pip", "install", "-U", "pip"])
+        .with_workdir("/src")
         .with_exec(["pip", "install", "-r", "requirements.txt"])
         .with_exec(["pip", "install", "-r", "dev-requirements.txt"])
+    )
+
+    # install local dbt-spark changes
+    tst_container = (
+        tst_container.with_workdir("/")
+        .with_directory("src/dbt", dbt_spark_dir)
+        .with_workdir("/src")
         .with_exec(["pip", "install", "-e", "."])
     )
 
+    # install local test changes
+    tst_container = (
+        tst_container.with_workdir("/")
+        .with_directory("src/tests", test_dir)
+        .with_workdir("/src")
+    )
+
     if test_profile == "apache_spark":
         spark_ctr, spark_host = get_spark_container(client)
         tst_container = tst_container.with_service_binding(alias=spark_host, service=spark_ctr)
diff --git a/dbt/include/spark/macros/adapters.sql b/dbt/include/spark/macros/adapters.sql
index bf9f63cf9..a6404a2de 100644
--- a/dbt/include/spark/macros/adapters.sql
+++ b/dbt/include/spark/macros/adapters.sql
@@ -387,6 +387,7 @@
         "identifier": tmp_identifier
     }) -%}
 
+    {%- set tmp_relation = tmp_relation.include(database=false, schema=false) -%}
     {% do return(tmp_relation) %}
 {% endmacro %}
 
diff --git a/dbt/include/spark/macros/utils/safe_cast.sql b/dbt/include/spark/macros/utils/safe_cast.sql
new file mode 100644
index 000000000..3ce5820a8
--- /dev/null
+++ b/dbt/include/spark/macros/utils/safe_cast.sql
@@ -0,0 +1,8 @@
+{% macro spark__safe_cast(field, type) %}
+{%- set field_clean = field.strip('"').strip("'") if (cast_from_string_unsupported_for(type) and field is string) else field -%}
+cast({{field_clean}} as {{type}})
+{% endmacro %}
+
+{% macro cast_from_string_unsupported_for(type) %}
+    {{ return(type.lower().startswith('struct') or type.lower().startswith('array') or type.lower().startswith('map')) }}
+{% endmacro %}
diff --git a/dev-requirements.txt b/dev-requirements.txt
index 28a626fc3..8f674d84b 100644
--- a/dev-requirements.txt
+++ b/dev-requirements.txt
@@ -1,5 +1,8 @@
 # install latest changes in dbt-core
 # TODO: how to automate switching from develop to version branches?
+git+https://github.com/dbt-labs/dbt-core.git#egg=dbt-core&subdirectory=core
+git+https://github.com/dbt-labs/dbt-common.git
+git+https://github.com/dbt-labs/dbt-adapters.git
 git+https://github.com/dbt-labs/dbt-adapters.git#subdirectory=dbt-tests-adapter
 
 # if version 1.x or greater -> pin to major version
diff --git a/tests/functional/adapter/unit_testing/test_unit_testing.py b/tests/functional/adapter/unit_testing/test_unit_testing.py
new file mode 100644
index 000000000..b70c581d1
--- /dev/null
+++ b/tests/functional/adapter/unit_testing/test_unit_testing.py
@@ -0,0 +1,34 @@
+import pytest
+
+from dbt.tests.adapter.unit_testing.test_types import BaseUnitTestingTypes
+from dbt.tests.adapter.unit_testing.test_case_insensitivity import BaseUnitTestCaseInsensivity
+from dbt.tests.adapter.unit_testing.test_invalid_input import BaseUnitTestInvalidInput
+
+
+class TestSparkUnitTestingTypes(BaseUnitTestingTypes):
+    @pytest.fixture
+    def data_types(self):
+        # sql_value, yaml_value
+        return [
+            ["1", "1"],
+            ["2.0", "2.0"],
+            ["'12345'", "12345"],
+            ["'string'", "string"],
+            ["true", "true"],
+            ["date '2011-11-11'", "2011-11-11"],
+            ["timestamp '2013-11-03 00:00:00-0'", "2013-11-03 00:00:00-0"],
+            ["array(1, 2, 3)", "'array(1, 2, 3)'"],
+            [
+                "map('10', 't', '15', 'f', '20', NULL)",
+                """'map("10", "t", "15", "f", "20", NULL)'""",
+            ],
+            ['named_struct("a", 1, "b", 2, "c", 3)', """'named_struct("a", 1, "b", 2, "c", 3)'"""],
+        ]
+
+
+class TestSparkUnitTestCaseInsensitivity(BaseUnitTestCaseInsensivity):
+    pass
+
+
+class TestSparkUnitTestInvalidInput(BaseUnitTestInvalidInput):
+    pass
diff --git a/tests/functional/conftest.py b/tests/functional/conftest.py
index c1a0397bd..476ffb474 100644
--- a/tests/functional/conftest.py
+++ b/tests/functional/conftest.py
@@ -1,19 +1,27 @@
-from multiprocessing import Lock
-
+import time
 import pytest
 
-_db_start_lock = Lock()
-_DB_CLUSTER_STARTED = False
+
+def _wait_for_databricks_cluster(project):
+    """
+    It takes roughly 3 minutes for the cluster to start; to be safe, wait up to 10 minutes (60 attempts at 10-second intervals).
+    """
+    for _ in range(60):
+        try:
+            project.run_sql("SELECT 1", fetch=True)
+            return
+        except Exception:
+            time.sleep(10)
+
+    raise Exception("Databricks cluster did not start in time")
 
 
 # Running this should prevent tests from needing to be retried because the Databricks cluster isn't available
 @pytest.fixture(scope="class", autouse=True)
 def start_databricks_cluster(project, request):
-    global _DB_CLUSTER_STARTED
     profile_type = request.config.getoption("--profile")
-    with _db_start_lock:
-        if "databricks" in profile_type and not _DB_CLUSTER_STARTED:
-            print("Starting Databricks cluster")
-            project.run_sql("SELECT 1")
-            _DB_CLUSTER_STARTED = True
+    if "databricks" in profile_type:
+        _wait_for_databricks_cluster(project)
+
+    yield 1
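
For context on the new macro (this note is a sketch, not part of the diff): assuming spark__safe_cast is reached through dbt's cross-database safe_cast dispatch, the rendered SQL should look roughly like the examples below; the target type strings and exact whitespace are illustrative only.

-- scalar target type: the field value passes through unchanged
-- safe_cast("'12345'", "bigint") compiles to roughly:
select cast('12345' as bigint);

-- struct/array/map target type: Spark cannot cast a quoted string literal to a complex type,
-- so cast_from_string_unsupported_for() triggers quote-stripping and the expression is cast directly
-- safe_cast("'array(1, 2, 3)'", "array<int>") compiles to roughly:
select cast(array(1, 2, 3) as array<int>);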