Feature/ssl support dbtspark (#169)
* Added support for hive ssl

* Added support for hive ssl

* Updated code to remove pure-transport dependency

* Fixed issues

* Updated test cases

* Fixed test cases

* Fixed flake8 issues

* Update README.md

* Update README.md

* Update CHANGELOG.md

* Update CHANGELOG.md

* Update CHANGELOG.md

* Update CHANGELOG.md

* Update README.md

* Update README.md

* Update README.md

* Added import except case

* Update README.md

* Fixed minor issue

* Fixed minor issue
rahulgoyal2987 authored Jun 7, 2021
1 parent b1e5f77 commit 2ab5523
Showing 4 changed files with 109 additions and 5 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -1,6 +1,8 @@
## dbt next

### Features

- Allow user to specify `use_ssl` ([#169](https://github.com/fishtown-analytics/dbt-spark/pull/169))
- Allow setting table `OPTIONS` using `config` ([#171](https://github.com/fishtown-analytics/dbt-spark/pull/171))
- Add support for column comment ([#170](https://github.com/fishtown-analytics/dbt-spark/pull/170))

@@ -17,6 +19,7 @@
- [@friendofasquid](https://github.com/friendofasquid) ([#159](https://github.com/fishtown-analytics/dbt-spark/pull/159))
- [@franloza](https://github.com/franloza) ([#160](https://github.com/fishtown-analytics/dbt-spark/pull/160))
- [@Fokko](https://github.com/Fokko) ([#165](https://github.com/fishtown-analytics/dbt-spark/pull/165))
- [@rahulgoyal2987](https://github.com/rahulgoyal2987) ([#169](https://github.com/fishtown-analytics/dbt-spark/pull/169))
- [@JCZuurmond](https://github.com/JCZuurmond) ([#171](https://github.com/fishtown-analytics/dbt-spark/pull/171))
- [@cristianoperez](https://github.com/cristianoperez) ([#170](https://github.com/fishtown-analytics/dbt-spark/pull/170))

1 change: 1 addition & 0 deletions README.md
@@ -73,6 +73,7 @@ A dbt profile for Spark connections supports the following configurations:
| user | The username to use to connect to the cluster |||| `hadoop` |
| connect_timeout | The number of seconds to wait before retrying to connect to a Pending Spark cluster || ❔ (`10`) | ❔ (`10`) | `60` |
| connect_retries | The number of times to try connecting to a Pending Spark cluster before giving up || ❔ (`0`) | ❔ (`0`) | `5` |
| use_ssl | The value of `hive.server2.use.SSL` (`True` or `False`). The default SSL store (`ssl.get_default_verify_paths()`) is used to locate the SSL certificate (see the example below) || ❔ (`False`) || `True` |
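For illustration, the new `use_ssl` flag sits alongside the existing thrift options. The sketch below mirrors the target dict used by the unit test added in this commit; the host, port, and user values are placeholders, and a `profiles.yml` target would carry the same keys in YAML form.

# Sketch of a thrift target with SSL enabled, mirroring the fixture in
# test/unit/test_adapter.py; connection details are placeholders.
ssl_thrift_target = {
    'type': 'spark',
    'method': 'thrift',
    'use_ssl': True,       # enables the SSL transport path in the adapter
    'schema': 'analytics',
    'host': 'myorg.sparkhost.com',
    'port': 10001,
    'user': 'dbt',
}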

**Databricks** connections differ based on the cloud provider:

75 changes: 70 additions & 5 deletions dbt/adapters/spark/connections.py
@@ -26,6 +26,18 @@
from hologram.helpers import StrEnum
from dataclasses import dataclass
from typing import Optional
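# Optional dependencies for SSL thrift connections; fall back to None when
# they are not installed.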
try:
    from thrift.transport.TSSLSocket import TSSLSocket
    import thrift
    import ssl
    import sasl
    import thrift_sasl
except ImportError:
    TSSLSocket = None
    thrift = None
    ssl = None
    sasl = None
    thrift_sasl = None

import base64
import time
@@ -59,6 +71,7 @@ class SparkCredentials(Credentials):
    organization: str = '0'
    connect_retries: int = 0
    connect_timeout: int = 10
    use_ssl: bool = False

    @classmethod
    def __pre_deserialize__(cls, data):
@@ -348,11 +361,20 @@ def open(cls, connection):
            cls.validate_creds(creds,
                               ['host', 'port', 'user', 'schema'])

            if creds.use_ssl:
                transport = build_ssl_transport(
                    host=creds.host,
                    port=creds.port,
                    username=creds.user,
                    auth=creds.auth,
                    kerberos_service_name=creds.kerberos_service_name)
                conn = hive.connect(thrift_transport=transport)
            else:
                conn = hive.connect(host=creds.host,
                                    port=creds.port,
                                    username=creds.user,
                                    auth=creds.auth,
                                    kerberos_service_name=creds.kerberos_service_name)  # noqa
            handle = PyhiveConnectionWrapper(conn)
        elif creds.method == SparkConnectionMethod.ODBC:
            if creds.cluster is not None:
@@ -431,6 +453,49 @@ def open(cls, connection):
        return connection


def build_ssl_transport(host, port, username, auth,
                        kerberos_service_name, password=None):
    transport = None
    if port is None:
        port = 10000
    if auth is None:
        auth = 'NONE'
    socket = TSSLSocket(host, port, cert_reqs=ssl.CERT_NONE)
    if auth == 'NOSASL':
        # NOSASL corresponds to hive.server2.authentication=NOSASL
        # in hive-site.xml
        transport = thrift.transport.TTransport.TBufferedTransport(socket)
    elif auth in ('LDAP', 'KERBEROS', 'NONE', 'CUSTOM'):
        # Defer import so package dependency is optional
        if auth == 'KERBEROS':
            # KERBEROS mode in hive.server2.authentication is GSSAPI
            # in sasl library
            sasl_auth = 'GSSAPI'
        else:
            sasl_auth = 'PLAIN'
            if password is None:
                # Password doesn't matter in NONE mode, just needs
                # to be nonempty.
                password = 'x'

        def sasl_factory():
            sasl_client = sasl.Client()
            sasl_client.setAttr('host', host)
            if sasl_auth == 'GSSAPI':
                sasl_client.setAttr('service', kerberos_service_name)
            elif sasl_auth == 'PLAIN':
                sasl_client.setAttr('username', username)
                sasl_client.setAttr('password', password)
            else:
                raise AssertionError
            sasl_client.init()
            return sasl_client

        transport = thrift_sasl.TSaslClientTransport(sasl_factory,
                                                     sasl_auth, socket)
    return transport


def _is_retryable_error(exc: Exception) -> Optional[str]:
    message = getattr(exc, 'message', None)
    if message is None:
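For context, the `open()` change above hands the SSL transport straight to PyHive. Below is a minimal standalone sketch of the same flow outside the adapter, assuming `pyhive` plus the `thrift`, `sasl`, and `thrift_sasl` packages are installed; the connection details are placeholders taken from the unit test fixture.

# Minimal sketch: build the SSL thrift transport added by this PR and pass
# it to PyHive, as the adapter does when use_ssl is set. Host, port, and
# user are placeholders; auth=None falls back to 'NONE' inside the helper.
from dbt.adapters.spark.connections import build_ssl_transport
from pyhive import hive

transport = build_ssl_transport(
    host='myorg.sparkhost.com',
    port=10001,
    username='dbt',
    auth=None,
    kerberos_service_name=None,
)
conn = hive.connect(thrift_transport=transport)
cursor = conn.cursor()
cursor.execute('select 1')  # simple connectivity check
print(cursor.fetchall())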
35 changes: 35 additions & 0 deletions test/unit/test_adapter.py
@@ -75,6 +75,22 @@ def _get_target_thrift_kerberos(self, project):
            'target': 'test'
        })

    def _get_target_use_ssl_thrift(self, project):
        return config_from_parts_or_dicts(project, {
            'outputs': {
                'test': {
                    'type': 'spark',
                    'method': 'thrift',
                    'use_ssl': True,
                    'schema': 'analytics',
                    'host': 'myorg.sparkhost.com',
                    'port': 10001,
                    'user': 'dbt'
                }
            },
            'target': 'test'
        })

    def _get_target_odbc_cluster(self, project):
        return config_from_parts_or_dicts(project, {
            'outputs': {
@@ -154,6 +170,25 @@ def hive_thrift_connect(host, port, username, auth, kerberos_service_name):
        self.assertEqual(connection.credentials.schema, 'analytics')
        self.assertIsNone(connection.credentials.database)

    def test_thrift_ssl_connection(self):
        config = self._get_target_use_ssl_thrift(self.project_cfg)
        adapter = SparkAdapter(config)

        def hive_thrift_connect(thrift_transport):
            self.assertIsNotNone(thrift_transport)
            transport = thrift_transport._trans
            self.assertEqual(transport.host, 'myorg.sparkhost.com')
            self.assertEqual(transport.port, 10001)

        with mock.patch.object(hive, 'connect', new=hive_thrift_connect):
            connection = adapter.acquire_connection('dummy')
            connection.handle  # trigger lazy-load

        self.assertEqual(connection.state, 'open')
        self.assertIsNotNone(connection.handle)
        self.assertEqual(connection.credentials.schema, 'analytics')
        self.assertIsNone(connection.credentials.database)

    def test_thrift_connection_kerberos(self):
        config = self._get_target_thrift_kerberos(self.project_cfg)
        adapter = SparkAdapter(config)
