Feature/http odbc conn extra (#1093)
* add support for extra odbc connection properties

* clean up

* fix typo in test_incremental_on_schema_change.py

* fix formatting

* changelog

* Add unit test and refactor unit test fixtures

* update changie

* update changie

* remove holdover code

* remove dbt-core ref

---------

Co-authored-by: nilan3 <nilanthanb1994@gmail.com>
Co-authored-by: Mike Alfare <mike.alfare@dbtlabs.com>
3 people authored Sep 12, 2024
1 parent 2124423 commit 3fc624c
Showing 16 changed files with 336 additions and 201 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20240910-175846.yaml
@@ -0,0 +1,6 @@
kind: Features
body: Support custom ODBC connection parameters via `connection_string_suffix` config
time: 2024-09-10T17:58:46.141332-04:00
custom:
Author: colin-rogers-dbt jpoley nilan3
Issue: "1092"
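
For reference, a profile target using this new option might look like the following Python dict. This is a minimal illustrative sketch modeled on the spark_http_odbc_target fixture added in tests/conftest.py further down; the host, driver path, token, and endpoint values are placeholders, not values from this commit.

# Hypothetical ODBC profile target; all values below are illustrative placeholders.
spark_http_odbc_target = {
    "type": "spark",
    "method": "odbc",
    "host": "my-workspace.example.com",  # placeholder gateway host
    "port": 443,
    "driver": "/opt/simba/spark/lib/64/libsparkodbc_sb64.so",  # placeholder Simba Spark ODBC driver path
    # Appended verbatim to the generated ODBC connection string:
    "connection_string_suffix": "UID=token;PWD=<token>;HTTPPath=/sql/1.0/endpoints/<endpoint>;AuthMech=3;SparkServerType=3",
}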
1 change: 1 addition & 0 deletions .github/workflows/integration.yml
@@ -76,6 +76,7 @@ jobs:
test:
- "apache_spark"
- "spark_session"
- "spark_http_odbc"
- "databricks_sql_endpoint"
- "databricks_cluster"
- "databricks_http_cluster"
1 change: 1 addition & 0 deletions .github/workflows/release-internal.yml
@@ -79,6 +79,7 @@ jobs:
test:
- "apache_spark"
- "spark_session"
- "spark_http_odbc"
- "databricks_sql_endpoint"
- "databricks_cluster"
- "databricks_http_cluster"
1 change: 1 addition & 0 deletions .github/workflows/release-prep.yml
@@ -482,6 +482,7 @@ jobs:
test:
- "apache_spark"
- "spark_session"
- "spark_http_odbc"
- "databricks_sql_endpoint"
- "databricks_cluster"
- "databricks_http_cluster"
2 changes: 1 addition & 1 deletion dagger/run_dbt_spark_tests.py
@@ -137,7 +137,7 @@ async def test_spark(test_args):
spark_ctr, spark_host = get_spark_container(client)
tst_container = tst_container.with_service_binding(alias=spark_host, service=spark_ctr)

elif test_profile in ["databricks_cluster", "databricks_sql_endpoint"]:
elif test_profile in ["databricks_cluster", "databricks_sql_endpoint", "spark_http_odbc"]:
tst_container = (
tst_container.with_workdir("/")
.with_exec(["./scripts/configure_odbc.sh"])
54 changes: 34 additions & 20 deletions dbt/adapters/spark/connections.py
@@ -78,6 +78,7 @@ class SparkCredentials(Credentials):
auth: Optional[str] = None
kerberos_service_name: Optional[str] = None
organization: str = "0"
connection_string_suffix: Optional[str] = None
connect_retries: int = 0
connect_timeout: int = 10
use_ssl: bool = False
@@ -483,38 +484,51 @@ def open(cls, connection: Connection) -> Connection:
http_path = cls.SPARK_SQL_ENDPOINT_HTTP_PATH.format(
endpoint=creds.endpoint
)
elif creds.connection_string_suffix is not None:
required_fields = ["driver", "host", "port", "connection_string_suffix"]
else:
raise DbtConfigError(
"Either `cluster` or `endpoint` must set when"
"Either `cluster`, `endpoint`, `connection_string_suffix` must set when"
" using the odbc method to connect to Spark"
)

cls.validate_creds(creds, required_fields)

dbt_spark_version = __version__.version
user_agent_entry = (
f"dbt-labs-dbt-spark/{dbt_spark_version} (Databricks)" # noqa
)

# http://simba.wpengine.com/products/Spark/doc/ODBC_InstallGuide/unix/content/odbc/hi/configuring/serverside.htm
ssp = {f"SSP_{k}": f"{{{v}}}" for k, v in creds.server_side_parameters.items()}

# https://www.simba.com/products/Spark/doc/v2/ODBC_InstallGuide/unix/content/odbc/options/driver.htm
connection_str = _build_odbc_connnection_string(
DRIVER=creds.driver,
HOST=creds.host,
PORT=creds.port,
UID="token",
PWD=creds.token,
HTTPPath=http_path,
AuthMech=3,
SparkServerType=3,
ThriftTransport=2,
SSL=1,
UserAgentEntry=user_agent_entry,
LCaseSspKeyName=0 if ssp else 1,
**ssp,
)
if creds.token is not None:
# https://www.simba.com/products/Spark/doc/v2/ODBC_InstallGuide/unix/content/odbc/options/driver.htm
connection_str = _build_odbc_connnection_string(
DRIVER=creds.driver,
HOST=creds.host,
PORT=creds.port,
UID="token",
PWD=creds.token,
HTTPPath=http_path,
AuthMech=3,
SparkServerType=3,
ThriftTransport=2,
SSL=1,
UserAgentEntry=user_agent_entry,
LCaseSspKeyName=0 if ssp else 1,
**ssp,
)
else:
connection_str = _build_odbc_connnection_string(
DRIVER=creds.driver,
HOST=creds.host,
PORT=creds.port,
ThriftTransport=2,
SSL=1,
UserAgentEntry=user_agent_entry,
LCaseSspKeyName=0 if ssp else 1,
**ssp,
)
if creds.connection_string_suffix is not None:
connection_str = connection_str + ";" + creds.connection_string_suffix

conn = pyodbc.connect(connection_str, autocommit=True)
handle = PyodbcConnectionWrapper(conn)
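The key change above: the token-based connection string is only built when creds.token is set; otherwise a bare DRIVER/HOST/PORT string is assembled, and in either case a supplied connection_string_suffix is appended after a semicolon. A minimal sketch of that final step, with illustrative values:

# Illustrative only: how a suffix merges into the base ODBC connection string.
base = "DRIVER=/path/to/driver.so;HOST=my-workspace.example.com;PORT=443;ThriftTransport=2;SSL=1"
suffix = "UID=token;PWD=<token>;HTTPPath=/sql/1.0/endpoints/<endpoint>;AuthMech=3;SparkServerType=3"
connection_str = base + ";" + suffix  # mirrors `connection_str + ";" + creds.connection_string_suffix`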
16 changes: 16 additions & 0 deletions tests/conftest.py
@@ -30,6 +30,8 @@ def dbt_profile_target(request):
target = databricks_http_cluster_target()
elif profile_type == "spark_session":
target = spark_session_target()
elif profile_type == "spark_http_odbc":
target = spark_http_odbc_target()
else:
raise ValueError(f"Invalid profile type '{profile_type}'")
return target
@@ -102,6 +104,20 @@ def spark_session_target():
}


def spark_http_odbc_target():
return {
"type": "spark",
"method": "odbc",
"host": os.getenv("DBT_DATABRICKS_HOST_NAME"),
"port": 443,
"driver": os.getenv("ODBC_DRIVER"),
"connection_string_suffix": f'UID=token;PWD={os.getenv("DBT_DATABRICKS_TOKEN")};HTTPPath=/sql/1.0/endpoints/{os.getenv("DBT_DATABRICKS_ENDPOINT")};AuthMech=3;SparkServerType=3',
"connect_retries": 3,
"connect_timeout": 5,
"retry_all": True,
}


@pytest.fixture(autouse=True)
def skip_by_profile_type(request):
profile_type = request.config.getoption("--profile")
@@ -21,7 +21,7 @@ def test_run_incremental_fail_on_schema_change(self, project):
assert "Compilation Error" in results_two[1].message


@pytest.mark.skip_profile("databricks_sql_endpoint")
@pytest.mark.skip_profile("databricks_sql_endpoint", "spark_http_odbc")
class TestAppendOnSchemaChange(IncrementalOnSchemaChangeIgnoreFail):
@pytest.fixture(scope="class")
def project_config_update(self):
@@ -32,7 +32,7 @@ def project_config_update(self):
}


@pytest.mark.skip_profile("databricks_sql_endpoint", "spark_session")
@pytest.mark.skip_profile("databricks_sql_endpoint", "spark_session", "spark_http_odbc")
class TestInsertOverwriteOnSchemaChange(IncrementalOnSchemaChangeIgnoreFail):
@pytest.fixture(scope="class")
def project_config_update(self):
@@ -55,7 +55,7 @@ def run_and_test(self, project):
check_relations_equal(project.adapter, ["default_append", "expected_append"])

@pytest.mark.skip_profile(
"databricks_http_cluster", "databricks_sql_endpoint", "spark_session"
"databricks_http_cluster", "databricks_sql_endpoint", "spark_session", "spark_http_odbc"
)
def test_default_append(self, project):
self.run_and_test(project)
@@ -77,7 +77,7 @@ def run_and_test(self, project):
check_relations_equal(project.adapter, ["insert_overwrite_partitions", "expected_upsert"])

@pytest.mark.skip_profile(
"databricks_http_cluster", "databricks_sql_endpoint", "spark_session"
"databricks_http_cluster", "databricks_sql_endpoint", "spark_session", "spark_http_odbc"
)
def test_insert_overwrite(self, project):
self.run_and_test(project)
@@ -103,7 +103,11 @@ def run_and_test(self, project):
check_relations_equal(project.adapter, ["merge_update_columns", "expected_partial_upsert"])

@pytest.mark.skip_profile(
"apache_spark", "databricks_http_cluster", "databricks_sql_endpoint", "spark_session"
"apache_spark",
"databricks_http_cluster",
"databricks_sql_endpoint",
"spark_session",
"spark_http_odbc",
)
def test_delta_strategies(self, project):
self.run_and_test(project)
18 changes: 15 additions & 3 deletions tests/functional/adapter/test_constraints.py
@@ -183,7 +183,11 @@ def models(self):


@pytest.mark.skip_profile(
"spark_session", "apache_spark", "databricks_sql_endpoint", "databricks_cluster"
"spark_session",
"apache_spark",
"databricks_sql_endpoint",
"databricks_cluster",
"spark_http_odbc",
)
class TestSparkTableConstraintsColumnsEqualDatabricksHTTP(
DatabricksHTTPSetup, BaseTableConstraintsColumnsEqual
@@ -198,7 +202,11 @@ def models(self):


@pytest.mark.skip_profile(
"spark_session", "apache_spark", "databricks_sql_endpoint", "databricks_cluster"
"spark_session",
"apache_spark",
"databricks_sql_endpoint",
"databricks_cluster",
"spark_http_odbc",
)
class TestSparkViewConstraintsColumnsEqualDatabricksHTTP(
DatabricksHTTPSetup, BaseViewConstraintsColumnsEqual
@@ -213,7 +221,11 @@ def models(self):


@pytest.mark.skip_profile(
"spark_session", "apache_spark", "databricks_sql_endpoint", "databricks_cluster"
"spark_session",
"apache_spark",
"databricks_sql_endpoint",
"databricks_cluster",
"spark_http_odbc",
)
class TestSparkIncrementalConstraintsColumnsEqualDatabricksHTTP(
DatabricksHTTPSetup, BaseIncrementalConstraintsColumnsEqual
16 changes: 12 additions & 4 deletions tests/functional/adapter/test_python_model.py
@@ -8,12 +8,16 @@
from dbt.tests.adapter.python_model.test_spark import BasePySparkTests


@pytest.mark.skip_profile("apache_spark", "spark_session", "databricks_sql_endpoint")
@pytest.mark.skip_profile(
"apache_spark", "spark_session", "databricks_sql_endpoint", "spark_http_odbc"
)
class TestPythonModelSpark(BasePythonModelTests):
pass


@pytest.mark.skip_profile("apache_spark", "spark_session", "databricks_sql_endpoint")
@pytest.mark.skip_profile(
"apache_spark", "spark_session", "databricks_sql_endpoint", "spark_http_odbc"
)
class TestPySpark(BasePySparkTests):
def test_different_dataframes(self, project):
"""
Expand All @@ -33,7 +37,9 @@ def test_different_dataframes(self, project):
assert len(results) == 3


@pytest.mark.skip_profile("apache_spark", "spark_session", "databricks_sql_endpoint")
@pytest.mark.skip_profile(
"apache_spark", "spark_session", "databricks_sql_endpoint", "spark_http_odbc"
)
class TestPythonIncrementalModelSpark(BasePythonIncrementalTests):
@pytest.fixture(scope="class")
def project_config_update(self):
@@ -78,7 +84,9 @@ def model(dbt, spark):
"""


@pytest.mark.skip_profile("apache_spark", "spark_session", "databricks_sql_endpoint")
@pytest.mark.skip_profile(
"apache_spark", "spark_session", "databricks_sql_endpoint", "spark_http_odbc"
)
class TestChangingSchemaSpark:
"""
Confirm that we can setup a spot instance and parse required packages into the Databricks job.
4 changes: 3 additions & 1 deletion tests/functional/adapter/test_store_test_failures.py
@@ -7,7 +7,9 @@
)


@pytest.mark.skip_profile("spark_session", "databricks_cluster", "databricks_sql_endpoint")
@pytest.mark.skip_profile(
"spark_session", "databricks_cluster", "databricks_sql_endpoint", "spark_http_odbc"
)
class TestSparkStoreTestFailures(StoreTestFailuresBase):
@pytest.fixture(scope="class")
def project_config_update(self):
1 change: 1 addition & 0 deletions tests/unit/conftest.py
@@ -0,0 +1 @@
from .fixtures.profiles import *
Empty file added tests/unit/fixtures/__init__.py
Empty file.