From 2ab55238451ef2a4c6c4dec226e844d52148e4b2 Mon Sep 17 00:00:00 2001 From: rahulgoyal2987 Date: Mon, 7 Jun 2021 17:58:05 +0530 Subject: [PATCH] Feature/ssl support dbtspark (#169) * Added support for hive ssl * Added support for hive ssl * Updated code to remove pure-trasport dependency * Fixed issues * Updated test cases * fixed test cases * Fixed flake8 issues * Update README.md * Update README.md * Update CHANGELOG.md * Update CHANGELOG.md * Update CHANGELOG.md * Update CHANGELOG.md * Update README.md * Update README.md * Update README.md * Added import except case * Update README.md * Fixed minor issue * Fixed minor issue --- CHANGELOG.md | 3 ++ README.md | 1 + dbt/adapters/spark/connections.py | 75 ++++++++++++++++++++++++++++--- test/unit/test_adapter.py | 35 +++++++++++++++ 4 files changed, 109 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index fd2cf9b1a..948a64a8a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,8 @@ ## dbt next ### Features + +- Allow user to specify `use_ssl` ([#169](https://github.com/fishtown-analytics/dbt-spark/pull/169)) - Allow setting table `OPTIONS` using `config` ([#171](https://github.com/fishtown-analytics/dbt-spark/pull/171)) - Add support for column comment ([#170](https://github.com/fishtown-analytics/dbt-spark/pull/170)) @@ -17,6 +19,7 @@ - [@friendofasquid](https://github.com/friendofasquid) ([#159](https://github.com/fishtown-analytics/dbt-spark/pull/159)) - [@franloza](https://github.com/franloza) ([#160](https://github.com/fishtown-analytics/dbt-spark/pull/160)) - [@Fokko](https://github.com/Fokko) ([#165](https://github.com/fishtown-analytics/dbt-spark/pull/165)) +- [@rahulgoyal2987](https://github.com/rahulgoyal2987) ([#169](https://github.com/fishtown-analytics/dbt-spark/pull/169)) - [@JCZuurmond](https://github.com/JCZuurmond) ([#171](https://github.com/fishtown-analytics/dbt-spark/pull/171)) - [@cristianoperez](https://github.com/cristianoperez) ([#170](https://github.com/fishtown-analytics/dbt-spark/pull/170)) diff --git a/README.md b/README.md index 9e57e6075..71ec7cdf5 100644 --- a/README.md +++ b/README.md @@ -73,6 +73,7 @@ A dbt profile for Spark connections support the following configurations: | user | The username to use to connect to the cluster | ❔ | ❔ | ❔ | `hadoop` | | connect_timeout | The number of seconds to wait before retrying to connect to a Pending Spark cluster | ❌ | ❔ (`10`) | ❔ (`10`) | `60` | | connect_retries | The number of times to try connecting to a Pending Spark cluster before giving up | ❌ | ❔ (`0`) | ❔ (`0`) | `5` | +| use_ssl | The value of `hive.server2.use.SSL` (`True` or `False`). 
The default SSL store (`ssl.get_default_verify_paths()`) is used as the location of the SSL certificate | ❌ | ❔ (`False`) | ❌ | `True` |
 
 **Databricks** connections differ based on the cloud provider:
 
diff --git a/dbt/adapters/spark/connections.py b/dbt/adapters/spark/connections.py
index bd26f6efe..1bc8d80b0 100644
--- a/dbt/adapters/spark/connections.py
+++ b/dbt/adapters/spark/connections.py
@@ -26,6 +26,18 @@
 from hologram.helpers import StrEnum
 from dataclasses import dataclass
 from typing import Optional
+try:
+    from thrift.transport.TSSLSocket import TSSLSocket
+    import thrift
+    import ssl
+    import sasl
+    import thrift_sasl
+except ImportError:
+    TSSLSocket = None
+    thrift = None
+    ssl = None
+    sasl = None
+    thrift_sasl = None
 
 import base64
 import time
@@ -59,6 +71,7 @@ class SparkCredentials(Credentials):
     organization: str = '0'
     connect_retries: int = 0
     connect_timeout: int = 10
+    use_ssl: bool = False
 
     @classmethod
     def __pre_deserialize__(cls, data):
@@ -348,11 +361,20 @@ def open(cls, connection):
 
                 cls.validate_creds(creds, ['host', 'port', 'user', 'schema'])
 
-                conn = hive.connect(host=creds.host,
-                                    port=creds.port,
-                                    username=creds.user,
-                                    auth=creds.auth,
-                                    kerberos_service_name=creds.kerberos_service_name)  # noqa
+                if creds.use_ssl:
+                    transport = build_ssl_transport(
+                        host=creds.host,
+                        port=creds.port,
+                        username=creds.user,
+                        auth=creds.auth,
+                        kerberos_service_name=creds.kerberos_service_name)
+                    conn = hive.connect(thrift_transport=transport)
+                else:
+                    conn = hive.connect(host=creds.host,
+                                        port=creds.port,
+                                        username=creds.user,
+                                        auth=creds.auth,
+                                        kerberos_service_name=creds.kerberos_service_name)  # noqa
                 handle = PyhiveConnectionWrapper(conn)
             elif creds.method == SparkConnectionMethod.ODBC:
                 if creds.cluster is not None:
@@ -431,6 +453,49 @@ def open(cls, connection):
         return connection
 
 
+def build_ssl_transport(host, port, username, auth,
+                        kerberos_service_name, password=None):
+    transport = None
+    if port is None:
+        port = 10000
+    if auth is None:
+        auth = 'NONE'
+    socket = TSSLSocket(host, port, cert_reqs=ssl.CERT_NONE)
+    if auth == 'NOSASL':
+        # NOSASL corresponds to hive.server2.authentication=NOSASL
+        # in hive-site.xml
+        transport = thrift.transport.TTransport.TBufferedTransport(socket)
+    elif auth in ('LDAP', 'KERBEROS', 'NONE', 'CUSTOM'):
+        # Defer import so package dependency is optional
+        if auth == 'KERBEROS':
+            # KERBEROS mode in hive.server2.authentication is GSSAPI
+            # in sasl library
+            sasl_auth = 'GSSAPI'
+        else:
+            sasl_auth = 'PLAIN'
+            if password is None:
+                # Password doesn't matter in NONE mode, just needs
+                # to be nonempty.
+ password = 'x' + + def sasl_factory(): + sasl_client = sasl.Client() + sasl_client.setAttr('host', host) + if sasl_auth == 'GSSAPI': + sasl_client.setAttr('service', kerberos_service_name) + elif sasl_auth == 'PLAIN': + sasl_client.setAttr('username', username) + sasl_client.setAttr('password', password) + else: + raise AssertionError + sasl_client.init() + return sasl_client + + transport = thrift_sasl.TSaslClientTransport(sasl_factory, + sasl_auth, socket) + return transport + + def _is_retryable_error(exc: Exception) -> Optional[str]: message = getattr(exc, 'message', None) if message is None: diff --git a/test/unit/test_adapter.py b/test/unit/test_adapter.py index d886ddee3..ddfbeddb2 100644 --- a/test/unit/test_adapter.py +++ b/test/unit/test_adapter.py @@ -75,6 +75,22 @@ def _get_target_thrift_kerberos(self, project): 'target': 'test' }) + def _get_target_use_ssl_thrift(self, project): + return config_from_parts_or_dicts(project, { + 'outputs': { + 'test': { + 'type': 'spark', + 'method': 'thrift', + 'use_ssl': True, + 'schema': 'analytics', + 'host': 'myorg.sparkhost.com', + 'port': 10001, + 'user': 'dbt' + } + }, + 'target': 'test' + }) + def _get_target_odbc_cluster(self, project): return config_from_parts_or_dicts(project, { 'outputs': { @@ -154,6 +170,25 @@ def hive_thrift_connect(host, port, username, auth, kerberos_service_name): self.assertEqual(connection.credentials.schema, 'analytics') self.assertIsNone(connection.credentials.database) + def test_thrift_ssl_connection(self): + config = self._get_target_use_ssl_thrift(self.project_cfg) + adapter = SparkAdapter(config) + + def hive_thrift_connect(thrift_transport): + self.assertIsNotNone(thrift_transport) + transport = thrift_transport._trans + self.assertEqual(transport.host, 'myorg.sparkhost.com') + self.assertEqual(transport.port, 10001) + + with mock.patch.object(hive, 'connect', new=hive_thrift_connect): + connection = adapter.acquire_connection('dummy') + connection.handle # trigger lazy-load + + self.assertEqual(connection.state, 'open') + self.assertIsNotNone(connection.handle) + self.assertEqual(connection.credentials.schema, 'analytics') + self.assertIsNone(connection.credentials.database) + def test_thrift_connection_kerberos(self): config = self._get_target_thrift_kerberos(self.project_cfg) adapter = SparkAdapter(config)
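
For context, a minimal usage sketch (not part of the patch): it builds the TLS transport the same way the patched `open()` does when `use_ssl` is enabled on a thrift-method profile, and hands it to PyHive via `thrift_transport`. The host, port, and user are hypothetical placeholders, and the optional `sasl` and `thrift_sasl` dependencies are assumed to be installed.

```python
# Minimal sketch, not part of the patch: exercising the new SSL code path
# directly. Host, port, and user are hypothetical placeholders.
from pyhive import hive

from dbt.adapters.spark.connections import build_ssl_transport

# Build the TLS-wrapped SASL transport, as open() does when use_ssl is true.
# auth=None falls back to 'NONE' (PLAIN SASL with a dummy password).
transport = build_ssl_transport(
    host='myorg.sparkhost.com',
    port=10001,
    username='dbt',
    auth=None,
    kerberos_service_name=None,
)

# Hand the prebuilt transport to PyHive instead of host/port/auth kwargs,
# mirroring hive.connect(thrift_transport=transport) in the patched open().
conn = hive.connect(thrift_transport=transport)
cursor = conn.cursor()
cursor.execute('select 1')
print(cursor.fetchall())
cursor.close()
conn.close()
```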