
Merge branch 'main' into config/remove-jira
mikealfare authored Feb 27, 2024
2 parents 30076fd + ef91425 commit a1b62b2
Showing 7 changed files with 85 additions and 13 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20240220-195925.yaml
@@ -0,0 +1,6 @@
kind: Features
body: Implement spark__safe_cast and add functional tests for unit testing
time: 2024-02-20T19:59:25.907821-05:00
custom:
Author: michelleark
Issue: "987"
18 changes: 15 additions & 3 deletions dagger/run_dbt_spark_tests.py
@@ -112,15 +112,27 @@ async def test_spark(test_args):
         .with_exec(["./scripts/install_os_reqs.sh"])
         # install dbt-spark + python deps
         .with_directory("/src", req_files)
-        .with_directory("src/dbt", dbt_spark_dir)
-        .with_directory("src/tests", test_dir)
-        .with_workdir("/src")
         .with_exec(["pip", "install", "-U", "pip"])
+        .with_workdir("/src")
         .with_exec(["pip", "install", "-r", "requirements.txt"])
         .with_exec(["pip", "install", "-r", "dev-requirements.txt"])
     )
 
+    # install local dbt-spark changes
+    tst_container = (
+        tst_container.with_workdir("/")
+        .with_directory("src/dbt", dbt_spark_dir)
+        .with_workdir("/src")
+        .with_exec(["pip", "install", "-e", "."])
+    )
+
+    # install local test changes
+    tst_container = (
+        tst_container.with_workdir("/")
+        .with_directory("src/tests", test_dir)
+        .with_workdir("/src")
+    )
 
     if test_profile == "apache_spark":
         spark_ctr, spark_host = get_spark_container(client)
         tst_container = tst_container.with_service_binding(alias=spark_host, service=spark_ctr)
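Splitting the install into separate stages follows the usual container layer-caching pattern: the OS and Python dependency layers are slow to build but change rarely, while the local dbt-spark sources and tests change on every run, so they are mounted and installed in their own later layers that can be rebuilt cheaply. A minimal sketch of the same pattern, using only the container methods that already appear in this diff (the base image and paths are illustrative assumptions, not taken from this repo):

# Sketch: order Dagger layers so the slow, stable steps stay cached.
def build_test_container(client, req_files, dbt_spark_dir, test_dir):
    return (
        client.container()
        .from_("python:3.10-slim")  # assumed base image
        .with_directory("/src", req_files)  # requirements change rarely
        .with_workdir("/src")
        .with_exec(["pip", "install", "-r", "requirements.txt"])
        # frequently edited sources go last, so the pip layers above
        # are reused between runs instead of being rebuilt
        .with_directory("/src/dbt", dbt_spark_dir)
        .with_exec(["pip", "install", "-e", "."])
        .with_directory("/src/tests", test_dir)
    )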
1 change: 1 addition & 0 deletions dbt/include/spark/macros/adapters.sql
@@ -387,6 +387,7 @@
         "identifier": tmp_identifier
     }) -%}
 
+    {%- set tmp_relation = tmp_relation.include(database=false, schema=false) -%}
     {% do return(tmp_relation) %}
 {% endmacro %}
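The added line drops the database and schema qualifiers from the temporary relation before it is returned: Spark registers temporary views in the session rather than inside a schema, so a fully qualified name would not resolve. A rough sketch of what include does on dbt's Relation API (values are made up, and the import path and exact rendering may vary by dbt version):

from dbt.adapters.base.relation import BaseRelation

rel = BaseRelation.create(database="my_db", schema="my_schema", identifier="my_table__dbt_tmp")
tmp = rel.include(database=False, schema=False)
print(tmp.render())  # renders just the identifier, roughly: my_table__dbt_tmp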
8 changes: 8 additions & 0 deletions dbt/include/spark/macros/utils/safe_cast.sql
@@ -0,0 +1,8 @@
{% macro spark__safe_cast(field, type) %}
{%- set field_clean = field.strip('"').strip("'") if (cast_from_string_unsupported_for(type) and field is string) else field -%}
cast({{field_clean}} as {{type}})
{% endmacro %}

{% macro cast_from_string_unsupported_for(type) %}
{{ return(type.lower().startswith('struct') or type.lower().startswith('array') or type.lower().startswith('map')) }}
{% endmacro %}
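Spark can cast a quoted string literal to scalar types, but not to struct, array, or map, so the macro strips the surrounding quotes and emits the literal as a bare expression for those types. The same logic restated as a Python sketch (function names here are illustrative, not part of dbt):

def cast_from_string_unsupported_for(type_: str) -> bool:
    # Spark cannot cast a quoted string literal to these complex types.
    t = type_.lower()
    return t.startswith("struct") or t.startswith("array") or t.startswith("map")

def spark_safe_cast(field: str, type_: str) -> str:
    # Mirrors spark__safe_cast: strip surrounding quotes so complex-type
    # literals are emitted as bare expressions.
    if cast_from_string_unsupported_for(type_) and isinstance(field, str):
        field = field.strip('"').strip("'")
    return f"cast({field} as {type_})"

# e.g. spark_safe_cast("'array(1, 2, 3)'", "array<int>")
#   -> "cast(array(1, 2, 3) as array<int>)"

This is also why the unit-test fixture below wraps array, map, and struct values in quotes: they pass through YAML as strings and rely on this macro to unwrap them.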
3 changes: 3 additions & 0 deletions dev-requirements.txt
@@ -1,5 +1,8 @@
 # install latest changes in dbt-core
 # TODO: how to automate switching from develop to version branches?
 git+https://github.com/dbt-labs/dbt-core.git#egg=dbt-core&subdirectory=core
+git+https://github.com/dbt-labs/dbt-common.git
+git+https://github.com/dbt-labs/dbt-adapters.git
+git+https://github.com/dbt-labs/dbt-adapters.git#subdirectory=dbt-tests-adapter
 
 # if version 1.x or greater -> pin to major version
34 changes: 34 additions & 0 deletions tests/functional/adapter/unit_testing/test_unit_testing.py
@@ -0,0 +1,34 @@
import pytest

from dbt.tests.adapter.unit_testing.test_types import BaseUnitTestingTypes
from dbt.tests.adapter.unit_testing.test_case_insensitivity import BaseUnitTestCaseInsensivity
from dbt.tests.adapter.unit_testing.test_invalid_input import BaseUnitTestInvalidInput


class TestSparkUnitTestingTypes(BaseUnitTestingTypes):
@pytest.fixture
def data_types(self):
# sql_value, yaml_value
return [
["1", "1"],
["2.0", "2.0"],
["'12345'", "12345"],
["'string'", "string"],
["true", "true"],
["date '2011-11-11'", "2011-11-11"],
["timestamp '2013-11-03 00:00:00-0'", "2013-11-03 00:00:00-0"],
["array(1, 2, 3)", "'array(1, 2, 3)'"],
[
"map('10', 't', '15', 'f', '20', NULL)",
"""'map("10", "t", "15", "f", "20", NULL)'""",
],
['named_struct("a", 1, "b", 2, "c", 3)', """'named_struct("a", 1, "b", 2, "c", 3)'"""],
]


class TestSparkUnitTestCaseInsensitivity(BaseUnitTestCaseInsensivity):
pass


class TestSparkUnitTestInvalidInput(BaseUnitTestInvalidInput):
pass
28 changes: 18 additions & 10 deletions tests/functional/conftest.py
@@ -1,19 +1,27 @@
-from multiprocessing import Lock
+import time
+
 import pytest
 
-_db_start_lock = Lock()
-_DB_CLUSTER_STARTED = False
+
+def _wait_for_databricks_cluster(project):
+    """
+    It takes roughly 3min for the cluster to start, to be safe we'll wait for 5min
+    """
+    for _ in range(60):
+        try:
+            project.run_sql("SELECT 1", fetch=True)
+            return
+        except Exception:
+            time.sleep(10)
+
+    raise Exception("Databricks cluster did not start in time")
+
 
 # Running this should prevent tests from needing to be retried because the Databricks cluster isn't available
 @pytest.fixture(scope="class", autouse=True)
 def start_databricks_cluster(project, request):
-    global _DB_CLUSTER_STARTED
     profile_type = request.config.getoption("--profile")
-    with _db_start_lock:
-        if "databricks" in profile_type and not _DB_CLUSTER_STARTED:
-            print("Starting Databricks cluster")
-            project.run_sql("SELECT 1")
-
-            _DB_CLUSTER_STARTED = True
+    if "databricks" in profile_type:
+        _wait_for_databricks_cluster(project)
 
     yield 1
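The new fixture polls the cluster instead of relying on a cross-process lock: sixty attempts ten seconds apart gives a ten-minute ceiling, a wider margin than the five minutes the docstring mentions. The same retry shape, generalized (names are illustrative):

import time

def wait_until(probe, attempts=60, delay=10):
    # Retry `probe` until it succeeds or the attempts run out.
    for _ in range(attempts):
        try:
            return probe()
        except Exception:
            time.sleep(delay)
    raise Exception("probe did not succeed in time")

# usage: wait_until(lambda: project.run_sql("SELECT 1", fetch=True))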
