Address type annotation issues and clean up impl #933

Merged · 10 commits · Sep 22, 2023
6 changes: 6 additions & 0 deletions .changes/unreleased/Under the Hood-20230921-155645.yaml
@@ -0,0 +1,6 @@
kind: Under the Hood
body: Address type annotation issues and remove protected method ref from impl
time: 2023-09-21T15:56:45.329798-07:00
custom:
Author: colin-rogers-dbt
Issue: "933"
6 changes: 6 additions & 0 deletions .changes/unreleased/Under the Hood-20230922-125327.yaml
@@ -0,0 +1,6 @@
kind: Under the Hood
body: Fixed a mypy failure by reworking BigQueryAdapter constructor.
time: 2023-09-22T12:53:27.339599-04:00
custom:
Author: peterallenwebb
Issue: "934"
14 changes: 14 additions & 0 deletions dbt/adapters/bigquery/connections.py
@@ -700,6 +700,20 @@ def fn():

self._retry_and_handle(msg="create dataset", conn=conn, fn=fn)

def list_dataset(self, database: str):
# the database string we get here is potentially quoted. Strip that off
# for the API call.
database = database.strip("`")
conn = self.get_thread_connection()
client = conn.handle

def query_schemas():
# this is similar to how we have to deal with listing tables
all_datasets = client.list_datasets(project=database, max_results=10000)
return [ds.dataset_id for ds in all_datasets]

return self._retry_and_handle(msg="list dataset", conn=conn, fn=query_schemas)
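
For reference, a minimal standalone sketch of what this new method does against the BigQuery API, leaving out dbt's retry wrapper; constructing a fresh client here is an illustrative assumption, since the real method reuses the thread connection's handle:

```python
# Standalone sketch of list_dataset's behavior using google-cloud-bigquery
# directly; the real method reuses the thread connection's client and wraps
# the call in _retry_and_handle.
from typing import List

from google.cloud import bigquery


def list_dataset_ids(project: str) -> List[str]:
    # dbt may hand us a backtick-quoted project (e.g. "`my-project`"),
    # so strip the quoting before the API call.
    project = project.strip("`")
    client = bigquery.Client(project=project)
    datasets = client.list_datasets(project=project, max_results=10000)
    return [dataset.dataset_id for dataset in datasets]
```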

def _query_and_results(
self,
client,
51 changes: 5 additions & 46 deletions dbt/adapters/bigquery/impl.py
@@ -217,6 +217,10 @@ class BigQueryAdapter(BaseAdapter):
ConstraintType.foreign_key: ConstraintSupport.ENFORCED,
}

def __init__(self, config) -> None:
super().__init__(config)
self.connections: BigQueryConnectionManager = self.connections
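
A note on this constructor: the base class presumably declares `connections` with the generic connection-manager type, and re-assigning the attribute under a narrower annotation is what lets mypy see BigQuery-specific methods such as `list_dataset`. A minimal sketch of the idiom, with hypothetical stand-in names and a `cast` to keep the standalone example clean under strict mypy:

```python
# Sketch of the type-narrowing idiom above; BaseManager and DerivedManager
# are hypothetical stand-ins for BaseConnectionManager and
# BigQueryConnectionManager.
from typing import cast


class BaseManager:
    pass


class DerivedManager(BaseManager):
    def list_dataset(self, database: str) -> list:
        return []


class BaseAdapter:
    def __init__(self) -> None:
        # The base class only knows the attribute as the generic type.
        self.connections: BaseManager = DerivedManager()


class DerivedAdapter(BaseAdapter):
    def __init__(self) -> None:
        super().__init__()
        # Narrow the declared type so self.connections.list_dataset(...)
        # type-checks; the diff assigns the attribute to itself under the
        # narrower annotation to the same effect.
        self.connections: DerivedManager = cast(DerivedManager, self.connections)
```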

###
# Implementations of abstract methods
###
@@ -267,18 +271,7 @@ def rename_relation(

@available
def list_schemas(self, database: str) -> List[str]:
# the database string we get here is potentially quoted. Strip that off
# for the API call.
database = database.strip("`")
conn = self.connections.get_thread_connection()
client = conn.handle

def query_schemas():
# this is similar to how we have to deal with listing tables
all_datasets = client.list_datasets(project=database, max_results=10000)
return [ds.dataset_id for ds in all_datasets]

return self.connections._retry_and_handle(msg="list dataset", conn=conn, fn=query_schemas)
Review comment from the PR author:
We should never reference a private ConnectionManager method from impl.

return self.connections.list_dataset(database)
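
The pattern in miniature: the adapter sticks to the connection manager's public surface, and the retry machinery stays encapsulated behind it. The classes below are illustrative stand-ins, not the actual implementation:

```python
# Illustrative sketch of the public-delegation pattern; the real classes
# carry far more state and error handling than shown here.
from typing import Callable, List


class ConnectionManager:
    def _retry_and_handle(self, msg: str, fn: Callable[[], List[str]]) -> List[str]:
        # Private retry machinery, internal to the manager.
        return fn()

    def list_dataset(self, database: str) -> List[str]:
        # Public method that adapters are allowed to call.
        return self._retry_and_handle("list dataset", lambda: [database.strip("`")])


class Adapter:
    def __init__(self, connections: ConnectionManager) -> None:
        self.connections = connections

    def list_schemas(self, database: str) -> List[str]:
        # Delegate to the public API instead of reaching into
        # self.connections._retry_and_handle directly.
        return self.connections.list_dataset(database)
```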

@available.parse(lambda *a, **k: False)
def check_schema_exists(self, database: str, schema: str) -> bool:
@@ -481,40 +474,6 @@ def _agate_to_schema(
bq_schema.append(SchemaField(col_name, type_)) # type: ignore[arg-type]
return bq_schema

def _materialize_as_view(self, model: Dict[str, Any]) -> str:
model_database = model.get("database")
model_schema = model.get("schema")
model_alias = model.get("alias")
model_code = model.get("compiled_code")

logger.debug("Model SQL ({}):\n{}".format(model_alias, model_code))
self.connections.create_view(
database=model_database, schema=model_schema, table_name=model_alias, sql=model_code
)
return "CREATE VIEW"

def _materialize_as_table(
self,
model: Dict[str, Any],
model_sql: str,
decorator: Optional[str] = None,
) -> str:
model_database = model.get("database")
model_schema = model.get("schema")
model_alias = model.get("alias")

if decorator is None:
table_name = model_alias
else:
table_name = "{}${}".format(model_alias, decorator)

logger.debug("Model SQL ({}):\n{}".format(table_name, model_sql))
self.connections.create_table(
database=model_database, schema=model_schema, table_name=table_name, sql=model_sql
)

return "CREATE TABLE"

@available.parse(lambda *a, **k: "")
def copy_table(self, source, destination, materialization):
if materialization == "incremental":
182 changes: 1 addition & 181 deletions tests/unit/test_bigquery_adapter.py
@@ -1,26 +1,19 @@
import time

import agate
import decimal
import json
import string
import random
import re
import pytest
import unittest
from contextlib import contextmanager
from requests.exceptions import ConnectionError
from unittest.mock import patch, MagicMock, Mock, create_autospec, ANY
from unittest.mock import patch, MagicMock, create_autospec

import dbt.dataclass_schema

from dbt.adapters.bigquery import PartitionConfig
from dbt.adapters.bigquery import BigQueryCredentials
from dbt.adapters.bigquery import BigQueryAdapter
from dbt.adapters.bigquery import BigQueryRelation
from dbt.adapters.bigquery import Plugin as BigQueryPlugin
from google.cloud.bigquery.table import Table
from dbt.adapters.bigquery.connections import BigQueryConnectionManager
from dbt.adapters.bigquery.connections import _sanitize_label, _VALIDATE_LABEL_LENGTH_LIMIT
from dbt.adapters.base.query_headers import MacroQueryStringSetter
from dbt.clients import agate_helper
@@ -543,179 +536,6 @@ def test_replace(self):
assert other_schema.quote_policy.database is False


class TestBigQueryConnectionManager(unittest.TestCase):
def setUp(self):
credentials = Mock(BigQueryCredentials)
profile = Mock(query_comment=None, credentials=credentials)
self.connections = BigQueryConnectionManager(profile=profile)

self.mock_client = Mock(dbt.adapters.bigquery.impl.google.cloud.bigquery.Client)
self.mock_connection = MagicMock()

self.mock_connection.handle = self.mock_client

self.connections.get_thread_connection = lambda: self.mock_connection
self.connections.get_job_retry_deadline_seconds = lambda x: None
self.connections.get_job_retries = lambda x: 1

@patch("dbt.adapters.bigquery.connections._is_retryable", return_value=True)
def test_retry_and_handle(self, is_retryable):
self.connections.DEFAULT_MAXIMUM_DELAY = 2.0

@contextmanager
def dummy_handler(msg):
yield

self.connections.exception_handler = dummy_handler

class DummyException(Exception):
"""Count how many times this exception is raised"""

count = 0

def __init__(self):
DummyException.count += 1

def raiseDummyException():
raise DummyException()

with self.assertRaises(DummyException):
self.connections._retry_and_handle(
"some sql", Mock(credentials=Mock(retries=8)), raiseDummyException
)
self.assertEqual(DummyException.count, 9)

@patch("dbt.adapters.bigquery.connections._is_retryable", return_value=True)
def test_retry_connection_reset(self, is_retryable):
self.connections.open = MagicMock()
self.connections.close = MagicMock()
self.connections.DEFAULT_MAXIMUM_DELAY = 2.0

@contextmanager
def dummy_handler(msg):
yield

self.connections.exception_handler = dummy_handler

def raiseConnectionResetError():
raise ConnectionResetError("Connection broke")

mock_conn = Mock(credentials=Mock(retries=1))
with self.assertRaises(ConnectionResetError):
self.connections._retry_and_handle("some sql", mock_conn, raiseConnectionResetError)
self.connections.close.assert_called_once_with(mock_conn)
self.connections.open.assert_called_once_with(mock_conn)

def test_is_retryable(self):
_is_retryable = dbt.adapters.bigquery.connections._is_retryable
exceptions = dbt.adapters.bigquery.impl.google.cloud.exceptions
internal_server_error = exceptions.InternalServerError("code broke")
bad_request_error = exceptions.BadRequest("code broke")
connection_error = ConnectionError("code broke")
client_error = exceptions.ClientError("bad code")
rate_limit_error = exceptions.Forbidden(
"code broke", errors=[{"reason": "rateLimitExceeded"}]
)

self.assertTrue(_is_retryable(internal_server_error))
self.assertTrue(_is_retryable(bad_request_error))
self.assertTrue(_is_retryable(connection_error))
self.assertFalse(_is_retryable(client_error))
self.assertTrue(_is_retryable(rate_limit_error))

def test_drop_dataset(self):
mock_table = Mock()
mock_table.reference = "table1"
self.mock_client.list_tables.return_value = [mock_table]

self.connections.drop_dataset("project", "dataset")

self.mock_client.list_tables.assert_not_called()
self.mock_client.delete_table.assert_not_called()
self.mock_client.delete_dataset.assert_called_once()

@patch("dbt.adapters.bigquery.impl.google.cloud.bigquery")
def test_query_and_results(self, mock_bq):
self.mock_client.query = Mock(return_value=Mock(state="DONE"))
self.connections._query_and_results(
self.mock_client,
"sql",
{"job_param_1": "blah"},
job_creation_timeout=15,
job_execution_timeout=3,
)

mock_bq.QueryJobConfig.assert_called_once()
self.mock_client.query.assert_called_once_with(
query="sql", job_config=mock_bq.QueryJobConfig(), timeout=15
)

@patch("dbt.adapters.bigquery.impl.google.cloud.bigquery")
def test_query_and_results_timeout(self, mock_bq):
self.mock_client.query = Mock(
return_value=Mock(result=lambda *args, **kwargs: time.sleep(4))
)
with pytest.raises(dbt.exceptions.DbtRuntimeError) as exc:
self.connections._query_and_results(
self.mock_client,
"sql",
{"job_param_1": "blah"},
job_creation_timeout=15,
job_execution_timeout=1,
)

mock_bq.QueryJobConfig.assert_called_once()
self.mock_client.query.assert_called_once_with(
query="sql", job_config=mock_bq.QueryJobConfig(), timeout=15
)
assert "Query exceeded configured timeout of 1s" in str(exc.value)

def test_copy_bq_table_appends(self):
self._copy_table(write_disposition=dbt.adapters.bigquery.impl.WRITE_APPEND)
args, kwargs = self.mock_client.copy_table.call_args
self.mock_client.copy_table.assert_called_once_with(
[self._table_ref("project", "dataset", "table1")],
self._table_ref("project", "dataset", "table2"),
job_config=ANY,
)
args, kwargs = self.mock_client.copy_table.call_args
self.assertEqual(
kwargs["job_config"].write_disposition, dbt.adapters.bigquery.impl.WRITE_APPEND
)

def test_copy_bq_table_truncates(self):
self._copy_table(write_disposition=dbt.adapters.bigquery.impl.WRITE_TRUNCATE)
args, kwargs = self.mock_client.copy_table.call_args
self.mock_client.copy_table.assert_called_once_with(
[self._table_ref("project", "dataset", "table1")],
self._table_ref("project", "dataset", "table2"),
job_config=ANY,
)
args, kwargs = self.mock_client.copy_table.call_args
self.assertEqual(
kwargs["job_config"].write_disposition, dbt.adapters.bigquery.impl.WRITE_TRUNCATE
)

def test_job_labels_valid_json(self):
expected = {"key": "value"}
labels = self.connections._labels_from_query_comment(json.dumps(expected))
self.assertEqual(labels, expected)

def test_job_labels_invalid_json(self):
labels = self.connections._labels_from_query_comment("not json")
self.assertEqual(labels, {"query_comment": "not_json"})

def _table_ref(self, proj, ds, table):
return self.connections.table_ref(proj, ds, table)

def _copy_table(self, write_disposition):
source = BigQueryRelation.create(database="project", schema="dataset", identifier="table1")
destination = BigQueryRelation.create(
database="project", schema="dataset", identifier="table2"
)
self.connections.copy_bq_table(source, destination, write_disposition)


class TestBigQueryAdapter(BaseTestBigQueryAdapter):
def test_copy_table_materialization_table(self):
adapter = self.get_adapter("oauth")