Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use deterministic job_ids to avoid retrying successful queries #977

Closed
wants to merge 17 commits into from
Closed
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 43 additions & 5 deletions dbt/adapters/bigquery/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import functools
import json
import re
import uuid
from contextlib import contextmanager
from dataclasses import dataclass, field

Expand All @@ -11,7 +12,7 @@
from functools import lru_cache
import agate
from requests.exceptions import ConnectionError
from typing import Optional, Any, Dict, Tuple
from typing import Optional, Any, Dict, List, Tuple

import google.auth
import google.auth.exceptions
Expand All @@ -25,10 +26,11 @@
)

from dbt.adapters.bigquery import gcloud
from dbt.adapters.bigquery.jobs import define_job_id
from dbt.clients import agate_helper
from dbt.config.profile import INVALID_PROFILE_MESSAGE
from dbt.tracking import active_user
from dbt.contracts.connection import ConnectionState, AdapterResponse
from dbt.contracts.connection import ConnectionState, AdapterResponse, AdapterRequiredConfig
from dbt.exceptions import (
FailedToConnectError,
DbtRuntimeError,
Expand Down Expand Up @@ -239,6 +241,10 @@ class BigQueryConnectionManager(BaseConnectionManager):
DEFAULT_INITIAL_DELAY = 1.0 # Seconds
DEFAULT_MAXIMUM_DELAY = 3.0 # Seconds

def __init__(self, profile: AdapterRequiredConfig):
    """Initialize the connection manager and the per-thread job registry."""
    super().__init__(profile)
    # Maps a thread identifier -> list of BigQuery job_id strings started by
    # that thread (see raw_execute); consumed by cancel_open() to cancel
    # in-flight jobs before closing the thread's connection.
    self.jobs_by_thread: Dict[Any, List[str]] = {}
McKnight-42 marked this conversation as resolved.
Show resolved Hide resolved

@classmethod
def handle_error(cls, error, message):
error_msg = "\n".join([item["message"] for item in error.errors])
Expand Down Expand Up @@ -292,8 +298,29 @@ def exception_handler(self, sql):
exc_message = exc_message.split(BQ_QUERY_JOB_SPLIT)[0].strip()
raise DbtRuntimeError(exc_message)

def cancel_open(self) -> List[str]:
    """Cancel in-flight BigQuery jobs on every open connection other than
    the calling thread's own connection, close those connections, and
    return the names of the affected connections.

    Fixes:
    - Removed the superseded stub ``def cancel_open(self) -> None: pass``
      that was left above this definition (dead code that shadowed nothing
      but confused readers).
    - A thread's job ids are now dropped from ``jobs_by_thread`` once its
      jobs are cancelled, so a later call does not try to re-cancel
      already-finished jobs.

    :return: names of the connections whose jobs were cancelled/closed
        (connections without a name are skipped).
    """
    names = []
    this_connection = self.get_if_exists()

    with self.lock:
        for thread_id, connection in self.thread_connections.items():
            # Never cancel the connection that is doing the cancelling.
            if connection is this_connection:
                continue

            if connection.handle is not None and connection.state == ConnectionState.OPEN:
                client = connection.handle
                # Cancel every job this thread has started before closing
                # its connection. fn is invoked immediately inside each
                # iteration, so the closure over job_id is safe here.
                for job_id in self.jobs_by_thread.get(thread_id, []):

                    def fn():
                        return client.cancel_job(job_id)

                    self._retry_and_handle(msg=f"Cancel job: {job_id}", conn=connection, fn=fn)

                # Jobs are cancelled; forget them so a subsequent call does
                # not attempt to re-cancel them.
                self.jobs_by_thread.pop(thread_id, None)

                self.close(connection)

            if connection.name is not None:
                names.append(connection.name)

    return names

@classmethod
def close(cls, connection):
Expand Down Expand Up @@ -491,18 +518,26 @@ def raw_execute(

job_creation_timeout = self.get_job_creation_timeout_seconds(conn)
job_execution_timeout = self.get_job_execution_timeout_seconds(conn)
# build out a deterministic job id
model_name = conn.credentials.schema # schema name as model name is not
invocation_id = str(uuid.uuid4())
job_id = define_job_id(model_name, invocation_id)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think uuid.uuid4() is deterministic, which means job_id is not either. Have you considered an md5 hash of sufficient attributes (model, connection name, etc.)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently calling uuid directly as part of getting the unit tests swapped over to this functionality. The initial/current plan was to use the invocation_id we define via tracking in core (https://docs.getdbt.com/reference/dbt-jinja-functions/invocation_id), which is itself a uuid according to the docs.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of using invocation_id (which we only sometimes have) should we just use the actual query text (which we have to have)?

thread_id = self.get_thread_identifier()
self.jobs_by_thread[thread_id] = self.jobs_by_thread.get(thread_id, []) + [job_id]

def fn():
return self._query_and_results(
client,
sql,
job_params,
job_id,
job_creation_timeout=job_creation_timeout,
job_execution_timeout=job_execution_timeout,
limit=limit,
)

query_job, iterator = self._retry_and_handle(msg=sql, conn=conn, fn=fn)
self.jobs_by_thread.get(thread_id, []).remove(job_id)

return query_job, iterator

Expand Down Expand Up @@ -734,14 +769,17 @@ def _query_and_results(
client,
sql,
job_params,
job_id,
job_creation_timeout=None,
job_execution_timeout=None,
limit: Optional[int] = None,
):
"""Query the client and wait for results."""
# Cannot reuse job_config if destination is set and ddl is used
job_config = google.cloud.bigquery.QueryJobConfig(**job_params)
query_job = client.query(query=sql, job_config=job_config, timeout=job_creation_timeout)
query_job = client.query(
query=sql, job_config=job_config, job_id=job_id, timeout=job_creation_timeout
)
if (
query_job.location is not None
and query_job.job_id is not None
Expand Down
2 changes: 1 addition & 1 deletion dbt/adapters/bigquery/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def date_function(cls) -> str:

@classmethod
def is_cancelable(cls) -> bool:
    """Report that this adapter supports query cancellation.

    BigQuery jobs can be cancelled via the client API (see
    ``cancel_open`` in the connection manager), so this returns True.
    The leftover ``return False`` line from the previous implementation
    was unreachable dead code and has been removed.
    """
    return True

def drop_relation(self, relation: BigQueryRelation) -> None:
is_cached = self._schema_is_cached(relation.database, relation.schema) # type: ignore[arg-type]
Expand Down
3 changes: 3 additions & 0 deletions dbt/adapters/bigquery/jobs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
def define_job_id(model_name, invocation_id):
    """Build a BigQuery job_id from a model name and an invocation id.

    BigQuery job ids may only contain letters, digits, underscores and
    dashes, and are limited to 1024 characters, so the model name is
    sanitized and the combined id is truncated defensively. For model
    names that are already safe (and short), the result is identical to
    the previous ``f"{model_name}_{invocation_id}"`` behavior.

    :param model_name: name of the model; may contain arbitrary characters
    :param invocation_id: unique id for this invocation (e.g. a uuid string)
    :return: a string that is safe to submit to BigQuery as a job_id
    """
    import re

    # Model names can contain characters BigQuery rejects in job ids;
    # map anything outside [A-Za-z0-9_-] to an underscore.
    safe_model_name = re.sub(r"[^0-9a-zA-Z_\-]", "_", str(model_name))
    job_id = f"{safe_model_name}_{invocation_id}"
    # BigQuery job ids are capped at 1024 characters.
    return job_id[:1024]
26 changes: 18 additions & 8 deletions tests/unit/test_bigquery_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,12 @@

from google.cloud.bigquery import AccessEntry

from .utils import config_from_parts_or_dicts, inject_adapter, TestAdapterConversions
from .utils import (
config_from_parts_or_dicts,
inject_adapter,
mock_connection,
TestAdapterConversions,
)


def _bq_conn():
Expand Down Expand Up @@ -340,23 +345,28 @@ def test_acquire_connection_maximum_bytes_billed(self, mock_open_connection):

def test_cancel_open_connections_empty(self):
adapter = self.get_adapter("oauth")
self.assertEqual(adapter.cancel_open_connections(), None)
self.assertEqual(len(list(adapter.cancel_open_connections())), 0)

def test_cancel_open_connections_master(self):
adapter = self.get_adapter("oauth")
adapter.connections.thread_connections[0] = object()
self.assertEqual(adapter.cancel_open_connections(), None)
key = adapter.connections.get_thread_identifier()
adapter.connections.thread_connections[key] = mock_connection("master")
self.assertEqual(len(list(adapter.cancel_open_connections())), 0)

def test_cancel_open_connections_single(self):
adapter = self.get_adapter("oauth")
master = mock_connection("master")
model = mock_connection("model")
model.handle.session_id = 42

key = adapter.connections.get_thread_identifier()
adapter.connections.thread_connections.update(
{
0: object(),
1: object(),
key: master,
1: model,
}
)
# actually does nothing
self.assertEqual(adapter.cancel_open_connections(), None)
self.assertEqual(len(list(adapter.cancel_open_connections())), 1)

@patch("dbt.adapters.bigquery.impl.google.auth.default")
@patch("dbt.adapters.bigquery.impl.google.cloud.bigquery")
Expand Down
7 changes: 5 additions & 2 deletions tests/unit/test_bigquery_connection_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,32 +113,35 @@ def test_query_and_results(self, mock_bq):
self.mock_client,
"sql",
{"job_param_1": "blah"},
job_id=1,
job_creation_timeout=15,
job_execution_timeout=3,
)

mock_bq.QueryJobConfig.assert_called_once()
self.mock_client.query.assert_called_once_with(
query="sql", job_config=mock_bq.QueryJobConfig(), timeout=15
query="sql", job_config=mock_bq.QueryJobConfig(), job_id=1, timeout=15
)

@patch("dbt.adapters.bigquery.impl.google.cloud.bigquery")
def test_query_and_results_timeout(self, mock_bq):
self.mock_client.query = Mock(
return_value=Mock(result=lambda *args, **kwargs: time.sleep(4))
)

with pytest.raises(dbt.exceptions.DbtRuntimeError) as exc:
self.connections._query_and_results(
self.mock_client,
"sql",
{"job_param_1": "blah"},
job_id=1,
job_creation_timeout=15,
job_execution_timeout=1,
)

mock_bq.QueryJobConfig.assert_called_once()
self.mock_client.query.assert_called_once_with(
query="sql", job_config=mock_bq.QueryJobConfig(), timeout=15
query="sql", job_config=mock_bq.QueryJobConfig(), job_id=1, timeout=15
)
assert "Query exceeded configured timeout of 1s" in str(exc.value)

Expand Down
Loading