From 3fc624cb99488e803956304c9dea2c10facab08d Mon Sep 17 00:00:00 2001 From: Colin Rogers <111200756+colin-rogers-dbt@users.noreply.github.com> Date: Thu, 12 Sep 2024 13:00:36 -0700 Subject: [PATCH] Feature/http odbc conn extra (#1093) * add support for extra odbc connection properties * clean up * fix typo in test_incremental_on_schema_change.py * fix formatting * changelog * Add unit test and refactor unit test fixtures * update changie * update changie * remove holdover code * remove dbt-core ref --------- Co-authored-by: nilan3 Co-authored-by: Mike Alfare --- .../unreleased/Features-20240910-175846.yaml | 6 + .github/workflows/integration.yml | 1 + .github/workflows/release-internal.yml | 1 + .github/workflows/release-prep.yml | 1 + dagger/run_dbt_spark_tests.py | 2 +- dbt/adapters/spark/connections.py | 54 +++-- tests/conftest.py | 16 ++ .../test_incremental_on_schema_change.py | 4 +- .../test_incremental_strategies.py | 10 +- tests/functional/adapter/test_constraints.py | 18 +- tests/functional/adapter/test_python_model.py | 16 +- .../adapter/test_store_test_failures.py | 4 +- tests/unit/conftest.py | 1 + tests/unit/fixtures/__init__.py | 0 tests/unit/fixtures/profiles.py | 174 +++++++++++++ tests/unit/test_adapter.py | 229 +++++------------- 16 files changed, 336 insertions(+), 201 deletions(-) create mode 100644 .changes/unreleased/Features-20240910-175846.yaml create mode 100644 tests/unit/conftest.py create mode 100644 tests/unit/fixtures/__init__.py create mode 100644 tests/unit/fixtures/profiles.py diff --git a/.changes/unreleased/Features-20240910-175846.yaml b/.changes/unreleased/Features-20240910-175846.yaml new file mode 100644 index 000000000..68ef8551e --- /dev/null +++ b/.changes/unreleased/Features-20240910-175846.yaml @@ -0,0 +1,6 @@ +kind: Features +body: Support custom ODBC connection parameters via `connection_string_suffix` config +time: 2024-09-10T17:58:46.141332-04:00 +custom: + Author: colin-rogers-dbt jpoley nilan3 + Issue: "1092" diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index 699d45391..35bd9cae0 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -76,6 +76,7 @@ jobs: test: - "apache_spark" - "spark_session" + - "spark_http_odbc" - "databricks_sql_endpoint" - "databricks_cluster" - "databricks_http_cluster" diff --git a/.github/workflows/release-internal.yml b/.github/workflows/release-internal.yml index d4e7a3c93..1a5090312 100644 --- a/.github/workflows/release-internal.yml +++ b/.github/workflows/release-internal.yml @@ -79,6 +79,7 @@ jobs: test: - "apache_spark" - "spark_session" + - "spark_http_odbc" - "databricks_sql_endpoint" - "databricks_cluster" - "databricks_http_cluster" diff --git a/.github/workflows/release-prep.yml b/.github/workflows/release-prep.yml index 9cb2c3e19..9937463d3 100644 --- a/.github/workflows/release-prep.yml +++ b/.github/workflows/release-prep.yml @@ -482,6 +482,7 @@ jobs: test: - "apache_spark" - "spark_session" + - "spark_http_odbc" - "databricks_sql_endpoint" - "databricks_cluster" - "databricks_http_cluster" diff --git a/dagger/run_dbt_spark_tests.py b/dagger/run_dbt_spark_tests.py index 15f9cf2c2..67fa56587 100644 --- a/dagger/run_dbt_spark_tests.py +++ b/dagger/run_dbt_spark_tests.py @@ -137,7 +137,7 @@ async def test_spark(test_args): spark_ctr, spark_host = get_spark_container(client) tst_container = tst_container.with_service_binding(alias=spark_host, service=spark_ctr) - elif test_profile in ["databricks_cluster", "databricks_sql_endpoint"]: + elif test_profile in ["databricks_cluster", "databricks_sql_endpoint", "spark_http_odbc"]: tst_container = ( tst_container.with_workdir("/") .with_exec(["./scripts/configure_odbc.sh"]) diff --git a/dbt/adapters/spark/connections.py b/dbt/adapters/spark/connections.py index 0405eaf5b..d9b615ecb 100644 --- a/dbt/adapters/spark/connections.py +++ b/dbt/adapters/spark/connections.py @@ -78,6 +78,7 @@ class SparkCredentials(Credentials): auth: Optional[str] = None kerberos_service_name: Optional[str] = None organization: str = "0" + connection_string_suffix: Optional[str] = None connect_retries: int = 0 connect_timeout: int = 10 use_ssl: bool = False @@ -483,38 +484,51 @@ def open(cls, connection: Connection) -> Connection: http_path = cls.SPARK_SQL_ENDPOINT_HTTP_PATH.format( endpoint=creds.endpoint ) + elif creds.connection_string_suffix is not None: + required_fields = ["driver", "host", "port", "connection_string_suffix"] else: raise DbtConfigError( - "Either `cluster` or `endpoint` must set when" + "Either `cluster`, `endpoint`, `connection_string_suffix` must set when" " using the odbc method to connect to Spark" ) cls.validate_creds(creds, required_fields) - dbt_spark_version = __version__.version user_agent_entry = ( f"dbt-labs-dbt-spark/{dbt_spark_version} (Databricks)" # noqa ) - # http://simba.wpengine.com/products/Spark/doc/ODBC_InstallGuide/unix/content/odbc/hi/configuring/serverside.htm ssp = {f"SSP_{k}": f"{{{v}}}" for k, v in creds.server_side_parameters.items()} - - # https://www.simba.com/products/Spark/doc/v2/ODBC_InstallGuide/unix/content/odbc/options/driver.htm - connection_str = _build_odbc_connnection_string( - DRIVER=creds.driver, - HOST=creds.host, - PORT=creds.port, - UID="token", - PWD=creds.token, - HTTPPath=http_path, - AuthMech=3, - SparkServerType=3, - ThriftTransport=2, - SSL=1, - UserAgentEntry=user_agent_entry, - LCaseSspKeyName=0 if ssp else 1, - **ssp, - ) + if creds.token is not None: + # https://www.simba.com/products/Spark/doc/v2/ODBC_InstallGuide/unix/content/odbc/options/driver.htm + connection_str = _build_odbc_connnection_string( + DRIVER=creds.driver, + HOST=creds.host, + PORT=creds.port, + UID="token", + PWD=creds.token, + HTTPPath=http_path, + AuthMech=3, + SparkServerType=3, + ThriftTransport=2, + SSL=1, + UserAgentEntry=user_agent_entry, + LCaseSspKeyName=0 if ssp else 1, + **ssp, + ) + else: + connection_str = _build_odbc_connnection_string( + DRIVER=creds.driver, + HOST=creds.host, + PORT=creds.port, + ThriftTransport=2, + SSL=1, + UserAgentEntry=user_agent_entry, + LCaseSspKeyName=0 if ssp else 1, + **ssp, + ) + if creds.connection_string_suffix is not None: + connection_str = connection_str + ";" + creds.connection_string_suffix conn = pyodbc.connect(connection_str, autocommit=True) handle = PyodbcConnectionWrapper(conn) diff --git a/tests/conftest.py b/tests/conftest.py index efba41a5f..09b31f406 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -30,6 +30,8 @@ def dbt_profile_target(request): target = databricks_http_cluster_target() elif profile_type == "spark_session": target = spark_session_target() + elif profile_type == "spark_http_odbc": + target = spark_http_odbc_target() else: raise ValueError(f"Invalid profile type '{profile_type}'") return target @@ -102,6 +104,20 @@ def spark_session_target(): } +def spark_http_odbc_target(): + return { + "type": "spark", + "method": "odbc", + "host": os.getenv("DBT_DATABRICKS_HOST_NAME"), + "port": 443, + "driver": os.getenv("ODBC_DRIVER"), + "connection_string_suffix": f'UID=token;PWD={os.getenv("DBT_DATABRICKS_TOKEN")};HTTPPath=/sql/1.0/endpoints/{os.getenv("DBT_DATABRICKS_ENDPOINT")};AuthMech=3;SparkServerType=3', + "connect_retries": 3, + "connect_timeout": 5, + "retry_all": True, + } + + @pytest.fixture(autouse=True) def skip_by_profile_type(request): profile_type = request.config.getoption("--profile") diff --git a/tests/functional/adapter/incremental/test_incremental_on_schema_change.py b/tests/functional/adapter/incremental/test_incremental_on_schema_change.py index 478329668..6f881697c 100644 --- a/tests/functional/adapter/incremental/test_incremental_on_schema_change.py +++ b/tests/functional/adapter/incremental/test_incremental_on_schema_change.py @@ -21,7 +21,7 @@ def test_run_incremental_fail_on_schema_change(self, project): assert "Compilation Error" in results_two[1].message -@pytest.mark.skip_profile("databricks_sql_endpoint") +@pytest.mark.skip_profile("databricks_sql_endpoint", "spark_http_odbc") class TestAppendOnSchemaChange(IncrementalOnSchemaChangeIgnoreFail): @pytest.fixture(scope="class") def project_config_update(self): @@ -32,7 +32,7 @@ def project_config_update(self): } -@pytest.mark.skip_profile("databricks_sql_endpoint", "spark_session") +@pytest.mark.skip_profile("databricks_sql_endpoint", "spark_session", "spark_http_odbc") class TestInsertOverwriteOnSchemaChange(IncrementalOnSchemaChangeIgnoreFail): @pytest.fixture(scope="class") def project_config_update(self): diff --git a/tests/functional/adapter/incremental_strategies/test_incremental_strategies.py b/tests/functional/adapter/incremental_strategies/test_incremental_strategies.py index b05fcb279..a44a1d23e 100644 --- a/tests/functional/adapter/incremental_strategies/test_incremental_strategies.py +++ b/tests/functional/adapter/incremental_strategies/test_incremental_strategies.py @@ -55,7 +55,7 @@ def run_and_test(self, project): check_relations_equal(project.adapter, ["default_append", "expected_append"]) @pytest.mark.skip_profile( - "databricks_http_cluster", "databricks_sql_endpoint", "spark_session" + "databricks_http_cluster", "databricks_sql_endpoint", "spark_session", "spark_http_odbc" ) def test_default_append(self, project): self.run_and_test(project) @@ -77,7 +77,7 @@ def run_and_test(self, project): check_relations_equal(project.adapter, ["insert_overwrite_partitions", "expected_upsert"]) @pytest.mark.skip_profile( - "databricks_http_cluster", "databricks_sql_endpoint", "spark_session" + "databricks_http_cluster", "databricks_sql_endpoint", "spark_session", "spark_http_odbc" ) def test_insert_overwrite(self, project): self.run_and_test(project) @@ -103,7 +103,11 @@ def run_and_test(self, project): check_relations_equal(project.adapter, ["merge_update_columns", "expected_partial_upsert"]) @pytest.mark.skip_profile( - "apache_spark", "databricks_http_cluster", "databricks_sql_endpoint", "spark_session" + "apache_spark", + "databricks_http_cluster", + "databricks_sql_endpoint", + "spark_session", + "spark_http_odbc", ) def test_delta_strategies(self, project): self.run_and_test(project) diff --git a/tests/functional/adapter/test_constraints.py b/tests/functional/adapter/test_constraints.py index e35a13a64..f33359262 100644 --- a/tests/functional/adapter/test_constraints.py +++ b/tests/functional/adapter/test_constraints.py @@ -183,7 +183,11 @@ def models(self): @pytest.mark.skip_profile( - "spark_session", "apache_spark", "databricks_sql_endpoint", "databricks_cluster" + "spark_session", + "apache_spark", + "databricks_sql_endpoint", + "databricks_cluster", + "spark_http_odbc", ) class TestSparkTableConstraintsColumnsEqualDatabricksHTTP( DatabricksHTTPSetup, BaseTableConstraintsColumnsEqual @@ -198,7 +202,11 @@ def models(self): @pytest.mark.skip_profile( - "spark_session", "apache_spark", "databricks_sql_endpoint", "databricks_cluster" + "spark_session", + "apache_spark", + "databricks_sql_endpoint", + "databricks_cluster", + "spark_http_odbc", ) class TestSparkViewConstraintsColumnsEqualDatabricksHTTP( DatabricksHTTPSetup, BaseViewConstraintsColumnsEqual @@ -213,7 +221,11 @@ def models(self): @pytest.mark.skip_profile( - "spark_session", "apache_spark", "databricks_sql_endpoint", "databricks_cluster" + "spark_session", + "apache_spark", + "databricks_sql_endpoint", + "databricks_cluster", + "spark_http_odbc", ) class TestSparkIncrementalConstraintsColumnsEqualDatabricksHTTP( DatabricksHTTPSetup, BaseIncrementalConstraintsColumnsEqual diff --git a/tests/functional/adapter/test_python_model.py b/tests/functional/adapter/test_python_model.py index cd798d1da..50132b883 100644 --- a/tests/functional/adapter/test_python_model.py +++ b/tests/functional/adapter/test_python_model.py @@ -8,12 +8,16 @@ from dbt.tests.adapter.python_model.test_spark import BasePySparkTests -@pytest.mark.skip_profile("apache_spark", "spark_session", "databricks_sql_endpoint") +@pytest.mark.skip_profile( + "apache_spark", "spark_session", "databricks_sql_endpoint", "spark_http_odbc" +) class TestPythonModelSpark(BasePythonModelTests): pass -@pytest.mark.skip_profile("apache_spark", "spark_session", "databricks_sql_endpoint") +@pytest.mark.skip_profile( + "apache_spark", "spark_session", "databricks_sql_endpoint", "spark_http_odbc" +) class TestPySpark(BasePySparkTests): def test_different_dataframes(self, project): """ @@ -33,7 +37,9 @@ def test_different_dataframes(self, project): assert len(results) == 3 -@pytest.mark.skip_profile("apache_spark", "spark_session", "databricks_sql_endpoint") +@pytest.mark.skip_profile( + "apache_spark", "spark_session", "databricks_sql_endpoint", "spark_http_odbc" +) class TestPythonIncrementalModelSpark(BasePythonIncrementalTests): @pytest.fixture(scope="class") def project_config_update(self): @@ -78,7 +84,9 @@ def model(dbt, spark): """ -@pytest.mark.skip_profile("apache_spark", "spark_session", "databricks_sql_endpoint") +@pytest.mark.skip_profile( + "apache_spark", "spark_session", "databricks_sql_endpoint", "spark_http_odbc" +) class TestChangingSchemaSpark: """ Confirm that we can setup a spot instance and parse required packages into the Databricks job. diff --git a/tests/functional/adapter/test_store_test_failures.py b/tests/functional/adapter/test_store_test_failures.py index e78bd4f71..3d8a4c192 100644 --- a/tests/functional/adapter/test_store_test_failures.py +++ b/tests/functional/adapter/test_store_test_failures.py @@ -7,7 +7,9 @@ ) -@pytest.mark.skip_profile("spark_session", "databricks_cluster", "databricks_sql_endpoint") +@pytest.mark.skip_profile( + "spark_session", "databricks_cluster", "databricks_sql_endpoint", "spark_http_odbc" +) class TestSparkStoreTestFailures(StoreTestFailuresBase): @pytest.fixture(scope="class") def project_config_update(self): diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py new file mode 100644 index 000000000..c3b000352 --- /dev/null +++ b/tests/unit/conftest.py @@ -0,0 +1 @@ +from .fixtures.profiles import * diff --git a/tests/unit/fixtures/__init__.py b/tests/unit/fixtures/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/fixtures/profiles.py b/tests/unit/fixtures/profiles.py new file mode 100644 index 000000000..c5f24581e --- /dev/null +++ b/tests/unit/fixtures/profiles.py @@ -0,0 +1,174 @@ +import pytest + +from tests.unit.utils import config_from_parts_or_dicts + + +@pytest.fixture(scope="session", autouse=True) +def base_project_cfg(): + return { + "name": "X", + "version": "0.1", + "profile": "test", + "project-root": "/tmp/dbt/does-not-exist", + "quoting": { + "identifier": False, + "schema": False, + }, + "config-version": 2, + } + + +@pytest.fixture(scope="session", autouse=True) +def target_http(base_project_cfg): + config = config_from_parts_or_dicts( + base_project_cfg, + { + "outputs": { + "test": { + "type": "spark", + "method": "http", + "schema": "analytics", + "host": "myorg.sparkhost.com", + "port": 443, + "token": "abc123", + "organization": "0123456789", + "cluster": "01234-23423-coffeetime", + "server_side_parameters": {"spark.driver.memory": "4g"}, + } + }, + "target": "test", + }, + ) + return config + + +@pytest.fixture(scope="session", autouse=True) +def target_thrift(base_project_cfg): + return config_from_parts_or_dicts( + base_project_cfg, + { + "outputs": { + "test": { + "type": "spark", + "method": "thrift", + "schema": "analytics", + "host": "myorg.sparkhost.com", + "port": 10001, + "user": "dbt", + } + }, + "target": "test", + }, + ) + + +@pytest.fixture(scope="session", autouse=True) +def target_thrift_kerberos(base_project_cfg): + return config_from_parts_or_dicts( + base_project_cfg, + { + "outputs": { + "test": { + "type": "spark", + "method": "thrift", + "schema": "analytics", + "host": "myorg.sparkhost.com", + "port": 10001, + "user": "dbt", + "auth": "KERBEROS", + "kerberos_service_name": "hive", + } + }, + "target": "test", + }, + ) + + +@pytest.fixture(scope="session", autouse=True) +def target_use_ssl_thrift(base_project_cfg): + return config_from_parts_or_dicts( + base_project_cfg, + { + "outputs": { + "test": { + "type": "spark", + "method": "thrift", + "use_ssl": True, + "schema": "analytics", + "host": "myorg.sparkhost.com", + "port": 10001, + "user": "dbt", + } + }, + "target": "test", + }, + ) + + +@pytest.fixture(scope="session", autouse=True) +def target_odbc_cluster(base_project_cfg): + return config_from_parts_or_dicts( + base_project_cfg, + { + "outputs": { + "test": { + "type": "spark", + "method": "odbc", + "schema": "analytics", + "host": "myorg.sparkhost.com", + "port": 443, + "token": "abc123", + "organization": "0123456789", + "cluster": "01234-23423-coffeetime", + "driver": "Simba", + } + }, + "target": "test", + }, + ) + + +@pytest.fixture(scope="session", autouse=True) +def target_odbc_sql_endpoint(base_project_cfg): + return config_from_parts_or_dicts( + base_project_cfg, + { + "outputs": { + "test": { + "type": "spark", + "method": "odbc", + "schema": "analytics", + "host": "myorg.sparkhost.com", + "port": 443, + "token": "abc123", + "endpoint": "012342342393920a", + "driver": "Simba", + } + }, + "target": "test", + }, + ) + + +@pytest.fixture(scope="session", autouse=True) +def target_odbc_with_extra_conn(base_project_cfg): + return config_from_parts_or_dicts( + base_project_cfg, + { + "outputs": { + "test": { + "type": "spark", + "method": "odbc", + "host": "myorg.sparkhost.com", + "schema": "analytics", + "port": 443, + "driver": "Simba", + "connection_string_suffix": "someExtraValues", + "connect_retries": 3, + "connect_timeout": 5, + "retry_all": True, + } + }, + "target": "test", + }, + ) diff --git a/tests/unit/test_adapter.py b/tests/unit/test_adapter.py index 54e9f0158..323e82a11 100644 --- a/tests/unit/test_adapter.py +++ b/tests/unit/test_adapter.py @@ -1,8 +1,8 @@ import unittest +import pytest from multiprocessing import get_context from unittest import mock -import dbt.flags as flags from dbt.exceptions import DbtRuntimeError from agate import Row from pyhive import hive @@ -11,143 +11,29 @@ class TestSparkAdapter(unittest.TestCase): - def setUp(self): - flags.STRICT_MODE = False - - self.project_cfg = { - "name": "X", - "version": "0.1", - "profile": "test", - "project-root": "/tmp/dbt/does-not-exist", - "quoting": { - "identifier": False, - "schema": False, - }, - "config-version": 2, - } - - def _get_target_http(self, project): - return config_from_parts_or_dicts( - project, - { - "outputs": { - "test": { - "type": "spark", - "method": "http", - "schema": "analytics", - "host": "myorg.sparkhost.com", - "port": 443, - "token": "abc123", - "organization": "0123456789", - "cluster": "01234-23423-coffeetime", - "server_side_parameters": {"spark.driver.memory": "4g"}, - } - }, - "target": "test", - }, - ) - - def _get_target_thrift(self, project): - return config_from_parts_or_dicts( - project, - { - "outputs": { - "test": { - "type": "spark", - "method": "thrift", - "schema": "analytics", - "host": "myorg.sparkhost.com", - "port": 10001, - "user": "dbt", - } - }, - "target": "test", - }, - ) - - def _get_target_thrift_kerberos(self, project): - return config_from_parts_or_dicts( - project, - { - "outputs": { - "test": { - "type": "spark", - "method": "thrift", - "schema": "analytics", - "host": "myorg.sparkhost.com", - "port": 10001, - "user": "dbt", - "auth": "KERBEROS", - "kerberos_service_name": "hive", - } - }, - "target": "test", - }, - ) - - def _get_target_use_ssl_thrift(self, project): - return config_from_parts_or_dicts( - project, - { - "outputs": { - "test": { - "type": "spark", - "method": "thrift", - "use_ssl": True, - "schema": "analytics", - "host": "myorg.sparkhost.com", - "port": 10001, - "user": "dbt", - } - }, - "target": "test", - }, - ) - - def _get_target_odbc_cluster(self, project): - return config_from_parts_or_dicts( - project, - { - "outputs": { - "test": { - "type": "spark", - "method": "odbc", - "schema": "analytics", - "host": "myorg.sparkhost.com", - "port": 443, - "token": "abc123", - "organization": "0123456789", - "cluster": "01234-23423-coffeetime", - "driver": "Simba", - } - }, - "target": "test", - }, - ) - - def _get_target_odbc_sql_endpoint(self, project): - return config_from_parts_or_dicts( - project, - { - "outputs": { - "test": { - "type": "spark", - "method": "odbc", - "schema": "analytics", - "host": "myorg.sparkhost.com", - "port": 443, - "token": "abc123", - "endpoint": "012342342393920a", - "driver": "Simba", - } - }, - "target": "test", - }, - ) + @pytest.fixture(autouse=True) + def set_up_fixtures( + self, + target_http, + target_odbc_with_extra_conn, + target_thrift, + target_thrift_kerberos, + target_odbc_sql_endpoint, + target_odbc_cluster, + target_use_ssl_thrift, + base_project_cfg, + ): + self.base_project_cfg = base_project_cfg + self.target_http = target_http + self.target_odbc_with_extra_conn = target_odbc_with_extra_conn + self.target_odbc_sql_endpoint = target_odbc_sql_endpoint + self.target_odbc_cluster = target_odbc_cluster + self.target_thrift = target_thrift + self.target_thrift_kerberos = target_thrift_kerberos + self.target_use_ssl_thrift = target_use_ssl_thrift def test_http_connection(self): - config = self._get_target_http(self.project_cfg) - adapter = SparkAdapter(config, get_context("spawn")) + adapter = SparkAdapter(self.target_http, get_context("spawn")) def hive_http_connect(thrift_transport, configuration): self.assertEqual(thrift_transport.scheme, "https") @@ -171,7 +57,7 @@ def hive_http_connect(thrift_transport, configuration): self.assertIsNone(connection.credentials.database) def test_thrift_connection(self): - config = self._get_target_thrift(self.project_cfg) + config = self.target_thrift adapter = SparkAdapter(config, get_context("spawn")) def hive_thrift_connect( @@ -195,8 +81,7 @@ def hive_thrift_connect( self.assertIsNone(connection.credentials.database) def test_thrift_ssl_connection(self): - config = self._get_target_use_ssl_thrift(self.project_cfg) - adapter = SparkAdapter(config, get_context("spawn")) + adapter = SparkAdapter(self.target_use_ssl_thrift, get_context("spawn")) def hive_thrift_connect(thrift_transport, configuration): self.assertIsNotNone(thrift_transport) @@ -215,8 +100,7 @@ def hive_thrift_connect(thrift_transport, configuration): self.assertIsNone(connection.credentials.database) def test_thrift_connection_kerberos(self): - config = self._get_target_thrift_kerberos(self.project_cfg) - adapter = SparkAdapter(config, get_context("spawn")) + adapter = SparkAdapter(self.target_thrift_kerberos, get_context("spawn")) def hive_thrift_connect( host, port, username, auth, kerberos_service_name, password, configuration @@ -239,8 +123,7 @@ def hive_thrift_connect( self.assertIsNone(connection.credentials.database) def test_odbc_cluster_connection(self): - config = self._get_target_odbc_cluster(self.project_cfg) - adapter = SparkAdapter(config, get_context("spawn")) + adapter = SparkAdapter(self.target_odbc_cluster, get_context("spawn")) def pyodbc_connect(connection_str, autocommit): self.assertTrue(autocommit) @@ -266,8 +149,7 @@ def pyodbc_connect(connection_str, autocommit): self.assertIsNone(connection.credentials.database) def test_odbc_endpoint_connection(self): - config = self._get_target_odbc_sql_endpoint(self.project_cfg) - adapter = SparkAdapter(config, get_context("spawn")) + adapter = SparkAdapter(self.target_odbc_sql_endpoint, get_context("spawn")) def pyodbc_connect(connection_str, autocommit): self.assertTrue(autocommit) @@ -291,6 +173,26 @@ def pyodbc_connect(connection_str, autocommit): self.assertEqual(connection.credentials.schema, "analytics") self.assertIsNone(connection.credentials.database) + def test_odbc_with_extra_connection_string(self): + adapter = SparkAdapter(self.target_odbc_with_extra_conn, get_context("spawn")) + + def pyodbc_connect(connection_str, autocommit): + self.assertTrue(autocommit) + self.assertIn("driver=simba;", connection_str.lower()) + self.assertIn("port=443;", connection_str.lower()) + self.assertIn("host=myorg.sparkhost.com;", connection_str.lower()) + self.assertIn("someExtraValues", connection_str) + + with mock.patch( + "dbt.adapters.spark.connections.pyodbc.connect", new=pyodbc_connect + ): # noqa + connection = adapter.acquire_connection("dummy") + connection.handle # trigger lazy-load + + self.assertEqual(connection.state, "open") + self.assertIsNotNone(connection.handle) + self.assertIsNone(connection.credentials.database) + def test_parse_relation(self): self.maxDiff = None rel_type = SparkRelation.get_relation_type.Table @@ -329,8 +231,7 @@ def test_parse_relation(self): input_cols = [Row(keys=["col_name", "data_type"], values=r) for r in plain_rows] - config = self._get_target_http(self.project_cfg) - rows = SparkAdapter(config, get_context("spawn")).parse_describe_extended( + rows = SparkAdapter(self.target_http, get_context("spawn")).parse_describe_extended( relation, input_cols ) self.assertEqual(len(rows), 4) @@ -420,8 +321,7 @@ def test_parse_relation_with_integer_owner(self): input_cols = [Row(keys=["col_name", "data_type"], values=r) for r in plain_rows] - config = self._get_target_http(self.project_cfg) - rows = SparkAdapter(config, get_context("spawn")).parse_describe_extended( + rows = SparkAdapter(self.target_http, get_context("spawn")).parse_describe_extended( relation, input_cols ) @@ -458,8 +358,7 @@ def test_parse_relation_with_statistics(self): input_cols = [Row(keys=["col_name", "data_type"], values=r) for r in plain_rows] - config = self._get_target_http(self.project_cfg) - rows = SparkAdapter(config, get_context("spawn")).parse_describe_extended( + rows = SparkAdapter(self.target_http, get_context("spawn")).parse_describe_extended( relation, input_cols ) self.assertEqual(len(rows), 1) @@ -489,8 +388,7 @@ def test_parse_relation_with_statistics(self): ) def test_relation_with_database(self): - config = self._get_target_http(self.project_cfg) - adapter = SparkAdapter(config, get_context("spawn")) + adapter = SparkAdapter(self.target_http, get_context("spawn")) # fine adapter.Relation.create(schema="different", identifier="table") with self.assertRaises(DbtRuntimeError): @@ -516,7 +414,7 @@ def test_profile_with_database(self): "target": "test", } with self.assertRaises(DbtRuntimeError): - config_from_parts_or_dicts(self.project_cfg, profile) + config_from_parts_or_dicts(self.base_project_cfg, profile) def test_profile_with_cluster_and_sql_endpoint(self): profile = { @@ -536,7 +434,7 @@ def test_profile_with_cluster_and_sql_endpoint(self): "target": "test", } with self.assertRaises(DbtRuntimeError): - config_from_parts_or_dicts(self.project_cfg, profile) + config_from_parts_or_dicts(self.base_project_cfg, profile) def test_parse_columns_from_information_with_table_type_and_delta_provider(self): self.maxDiff = None @@ -570,10 +468,9 @@ def test_parse_columns_from_information_with_table_type_and_delta_provider(self) schema="default_schema", identifier="mytable", type=rel_type, information=information ) - config = self._get_target_http(self.project_cfg) - columns = SparkAdapter(config, get_context("spawn")).parse_columns_from_information( - relation - ) + columns = SparkAdapter( + self.target_http, get_context("spawn") + ).parse_columns_from_information(relation) self.assertEqual(len(columns), 4) self.assertEqual( columns[0].to_column_dict(omit_none=False), @@ -657,10 +554,9 @@ def test_parse_columns_from_information_with_view_type(self): schema="default_schema", identifier="myview", type=rel_type, information=information ) - config = self._get_target_http(self.project_cfg) - columns = SparkAdapter(config, get_context("spawn")).parse_columns_from_information( - relation - ) + columns = SparkAdapter( + self.target_http, get_context("spawn") + ).parse_columns_from_information(relation) self.assertEqual(len(columns), 4) self.assertEqual( columns[1].to_column_dict(omit_none=False), @@ -725,10 +621,9 @@ def test_parse_columns_from_information_with_table_type_and_parquet_provider(sel schema="default_schema", identifier="mytable", type=rel_type, information=information ) - config = self._get_target_http(self.project_cfg) - columns = SparkAdapter(config, get_context("spawn")).parse_columns_from_information( - relation - ) + columns = SparkAdapter( + self.target_http, get_context("spawn") + ).parse_columns_from_information(relation) self.assertEqual(len(columns), 4) self.assertEqual(