lazy load agate and add TYPE_CHECKING (#126)
dwreeves authored Mar 26, 2024
1 parent bc3f8eb commit 1c137ee
Showing 8 changed files with 107 additions and 72 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Under the Hood-20240309-150639.yaml
@@ -0,0 +1,6 @@
kind: Under the Hood
body: Lazy load agate to improve dbt-core performance
time: 2024-03-09T15:06:39.038593-05:00
custom:
Author: dwreeves
Issue: "125"
11 changes: 6 additions & 5 deletions dbt/adapters/base/connections.py
@@ -18,9 +18,9 @@
Tuple,
Type,
Union,
TYPE_CHECKING,
)

import agate
from dbt_common.events.contextvars import get_node_info
from dbt_common.events.functions import fire_event
from dbt_common.exceptions import DbtInternalError, NotImplementedError
@@ -48,6 +48,9 @@
)
from dbt.adapters.exceptions import FailedToConnectError, InvalidConnectionError

if TYPE_CHECKING:
import agate


SleepTime = Union[int, float] # As taken by time.sleep.
AdapterHandle = Any # Adapter connection handle objects can be any class.
@@ -162,9 +165,7 @@ def set_connection_name(self, name: Optional[str] = None) -> Connection:
conn.handle = LazyHandle(self.open)
# Add the connection to thread_connections for this thread
self.set_thread_connection(conn)
fire_event(
NewConnection(conn_name=conn_name, conn_type=self.TYPE, node_info=get_node_info())
)
fire_event(NewConnection(conn_name=conn_name, conn_type=self.TYPE, node_info=get_node_info()))
else: # existing connection either wasn't open or didn't have the right name
if conn.state != "open":
conn.handle = LazyHandle(self.open)
@@ -396,7 +397,7 @@ def execute(
auto_begin: bool = False,
fetch: bool = False,
limit: Optional[int] = None,
) -> Tuple[AdapterResponse, agate.Table]:
) -> Tuple[AdapterResponse, "agate.Table"]:
"""Execute the given SQL.
:param str sql: The sql to execute.
106 changes: 61 additions & 45 deletions dbt/adapters/base/impl.py
@@ -20,16 +20,9 @@
Type,
TypedDict,
Union,
TYPE_CHECKING,
)

import agate
from dbt_common.clients.agate_helper import (
Integer,
empty_table,
get_column_value_uncased,
merge_tables,
table_from_rows,
)
from dbt_common.clients.jinja import CallableMacroGenerator
from dbt_common.contracts.constraints import (
ColumnLevelConstraint,
@@ -94,6 +87,9 @@
)
from dbt.adapters.protocol import AdapterConfig, MacroContextGeneratorCallable

if TYPE_CHECKING:
import agate


GET_CATALOG_MACRO_NAME = "get_catalog"
GET_CATALOG_RELATIONS_MACRO_NAME = "get_catalog_relations"
@@ -107,7 +103,14 @@ class ConstraintSupport(str, Enum):
NOT_SUPPORTED = "not_supported"


def _expect_row_value(key: str, row: agate.Row):
def _parse_callback_empty_table(*args, **kwargs) -> Tuple[str, "agate.Table"]:
# Lazy load agate_helper to avoid importing agate when it is not necessary.
from dbt_common.clients.agate_helper import empty_table

return "", empty_table()


def _expect_row_value(key: str, row: "agate.Row"):
if key not in row.keys():
raise DbtInternalError(
'Got a row without "{}" column, columns: {}'.format(key, row.keys())
@@ -117,13 +120,13 @@ def _expect_row_value(key: str, row: agate.Row):

def _catalog_filter_schemas(
used_schemas: FrozenSet[Tuple[str, str]]
) -> Callable[[agate.Row], bool]:
) -> Callable[["agate.Row"], bool]:
"""Return a function that takes a row and decides if the row should be
included in the catalog output.
"""
schemas = frozenset((d.lower(), s.lower()) for d, s in used_schemas)

def test(row: agate.Row) -> bool:
def test(row: "agate.Row") -> bool:
table_database = _expect_row_value("table_database", row)
table_schema = _expect_row_value("table_schema", row)
# the schema may be present but None, which is not an error and should
@@ -325,14 +328,14 @@ def connection_named(self, name: str, query_header_context: Any = None) -> Itera
if self.connections.query_header is not None:
self.connections.query_header.reset()

@available.parse(lambda *a, **k: ("", empty_table()))
@available.parse(_parse_callback_empty_table)
def execute(
self,
sql: str,
auto_begin: bool = False,
fetch: bool = False,
limit: Optional[int] = None,
) -> Tuple[AdapterResponse, agate.Table]:
) -> Tuple[AdapterResponse, "agate.Table"]:
"""Execute the given SQL. This is a thin wrapper around
ConnectionManager.execute.
@@ -342,7 +345,7 @@ def execute(
:param bool fetch: If set, fetch results.
:param Optional[int] limit: If set, only fetch n number of rows
:return: A tuple of the query status and results (empty if fetch=False).
:rtype: Tuple[AdapterResponse, agate.Table]
:rtype: Tuple[AdapterResponse, "agate.Table"]
"""
return self.connections.execute(sql=sql, auto_begin=auto_begin, fetch=fetch, limit=limit)

@@ -370,16 +373,16 @@ def get_column_schema_from_query(self, sql: str) -> List[BaseColumn]:
]
return columns

@available.parse(lambda *a, **k: ("", empty_table()))
def get_partitions_metadata(self, table: str) -> Tuple[agate.Table]:
@available.parse(_parse_callback_empty_table)
def get_partitions_metadata(self, table: str) -> Tuple["agate.Table"]:
"""
TODO: Can we move this to dbt-bigquery?
Obtain partitions metadata for a BigQuery partitioned table.
:param str table: a partitioned table id, in standard SQL format.
:return: a partition metadata tuple, as described in
https://cloud.google.com/bigquery/docs/creating-partitioned-tables#getting_partition_metadata_using_meta_tables.
:rtype: agate.Table
:rtype: "agate.Table"
"""
if hasattr(self.connections, "get_partitions_metadata"):
return self.connections.get_partitions_metadata(table=table)
@@ -423,7 +426,9 @@ def _get_cache_schemas(self, relation_configs: Iterable[RelationConfig]) -> Set[
populate.
"""
return {
self.Relation.create_from(quoting=self.config, relation_config=relation_config).without_identifier()
self.Relation.create_from(
quoting=self.config, relation_config=relation_config
).without_identifier()
for relation_config in relation_configs
}

@@ -665,7 +670,7 @@ def list_relations_without_caching(self, schema_relation: BaseRelation) -> List[
# Methods about grants
###
@available
def standardize_grants_dict(self, grants_table: agate.Table) -> dict:
def standardize_grants_dict(self, grants_table: "agate.Table") -> dict:
"""Translate the result of `show grants` (or equivalent) to match the
grants which a user would configure in their project.
@@ -940,7 +945,7 @@ def quote_seed_column(self, column: str, quote_config: Optional[bool]) -> str:
###
@classmethod
@abc.abstractmethod
def convert_text_type(cls, agate_table: agate.Table, col_idx: int) -> str:
def convert_text_type(cls, agate_table: "agate.Table", col_idx: int) -> str:
"""Return the type in the database that best maps to the agate.Text
type for the given agate table and column index.
@@ -952,7 +957,7 @@ def convert_text_type(cls, agate_table: agate.Table, col_idx: int) -> str:

@classmethod
@abc.abstractmethod
def convert_number_type(cls, agate_table: agate.Table, col_idx: int) -> str:
def convert_number_type(cls, agate_table: "agate.Table", col_idx: int) -> str:
"""Return the type in the database that best maps to the agate.Number
type for the given agate table and column index.
@@ -963,7 +968,7 @@ def convert_number_type(cls, agate_table: agate.Table, col_idx: int) -> str:
raise NotImplementedError("`convert_number_type` is not implemented for this adapter!")

@classmethod
def convert_integer_type(cls, agate_table: agate.Table, col_idx: int) -> str:
def convert_integer_type(cls, agate_table: "agate.Table", col_idx: int) -> str:
"""Return the type in the database that best maps to the agate.Number
type for the given agate table and column index.
@@ -975,7 +980,7 @@ def convert_integer_type(cls, agate_table: agate.Table, col_idx: int) -> str:

@classmethod
@abc.abstractmethod
def convert_boolean_type(cls, agate_table: agate.Table, col_idx: int) -> str:
def convert_boolean_type(cls, agate_table: "agate.Table", col_idx: int) -> str:
"""Return the type in the database that best maps to the agate.Boolean
type for the given agate table and column index.
@@ -987,7 +992,7 @@ def convert_boolean_type(cls, agate_table: agate.Table, col_idx: int) -> str:

@classmethod
@abc.abstractmethod
def convert_datetime_type(cls, agate_table: agate.Table, col_idx: int) -> str:
def convert_datetime_type(cls, agate_table: "agate.Table", col_idx: int) -> str:
"""Return the type in the database that best maps to the agate.DateTime
type for the given agate table and column index.
@@ -999,7 +1004,7 @@ def convert_datetime_type(cls, agate_table: agate.Table, col_idx: int) -> str:

@classmethod
@abc.abstractmethod
def convert_date_type(cls, agate_table: agate.Table, col_idx: int) -> str:
def convert_date_type(cls, agate_table: "agate.Table", col_idx: int) -> str:
"""Return the type in the database that best maps to the agate.Date
type for the given agate table and column index.
@@ -1011,7 +1016,7 @@ def convert_date_type(cls, agate_table: agate.Table, col_idx: int) -> str:

@classmethod
@abc.abstractmethod
def convert_time_type(cls, agate_table: agate.Table, col_idx: int) -> str:
def convert_time_type(cls, agate_table: "agate.Table", col_idx: int) -> str:
"""Return the type in the database that best maps to the
agate.TimeDelta type for the given agate table and column index.
@@ -1023,11 +1028,14 @@ def convert_time_type(cls, agate_table: agate.Table, col_idx: int) -> str:

@available
@classmethod
def convert_type(cls, agate_table: agate.Table, col_idx: int) -> Optional[str]:
def convert_type(cls, agate_table: "agate.Table", col_idx: int) -> Optional[str]:
return cls.convert_agate_type(agate_table, col_idx)

@classmethod
def convert_agate_type(cls, agate_table: agate.Table, col_idx: int) -> Optional[str]:
def convert_agate_type(cls, agate_table: "agate.Table", col_idx: int) -> Optional[str]:
import agate
from dbt_common.clients.agate_helper import Integer

agate_type: Type = agate_table.column_types[col_idx]
conversions: List[Tuple[Type, Callable[..., str]]] = [
(Integer, cls.convert_integer_type),
@@ -1104,11 +1112,13 @@ def execute_macro(

@classmethod
def _catalog_filter_table(
cls, table: agate.Table, used_schemas: FrozenSet[Tuple[str, str]]
) -> agate.Table:
cls, table: "agate.Table", used_schemas: FrozenSet[Tuple[str, str]]
) -> "agate.Table":
"""Filter the table as appropriate for catalog entries. Subclasses can
override this to change filtering rules on a per-adapter basis.
"""
from dbt_common.clients.agate_helper import table_from_rows

# force database + schema to be strings
table = table_from_rows(
table.rows,
@@ -1122,7 +1132,7 @@ def _get_one_catalog(
information_schema: InformationSchema,
schemas: Set[str],
used_schemas: FrozenSet[Tuple[str, str]],
) -> agate.Table:
) -> "agate.Table":
kwargs = {"information_schema": information_schema, "schemas": schemas}
table = self.execute_macro(GET_CATALOG_MACRO_NAME, kwargs=kwargs)

@@ -1134,7 +1144,7 @@ def _get_one_catalog_by_relations(
information_schema: InformationSchema,
relations: List[BaseRelation],
used_schemas: FrozenSet[Tuple[str, str]],
) -> agate.Table:
) -> "agate.Table":
kwargs = {
"information_schema": information_schema,
"relations": relations,
@@ -1150,7 +1160,7 @@ def get_filtered_catalog(
used_schemas: FrozenSet[Tuple[str, str]],
relations: Optional[Set[BaseRelation]] = None,
):
catalogs: agate.Table
catalogs: "agate.Table"
if (
relations is None
or len(relations) > self.MAX_SCHEMA_METADATA_RELATIONS
@@ -1173,7 +1183,7 @@ def get_filtered_catalog(
for r in relations
}

def in_map(row: agate.Row):
def in_map(row: "agate.Row"):
d = _expect_row_value("table_database", row)
s = _expect_row_value("table_schema", row)
i = _expect_row_value("table_name", row)
@@ -1186,16 +1196,16 @@ def in_map(row: agate.Row):

return catalogs, exceptions

def row_matches_relation(self, row: agate.Row, relations: Set[BaseRelation]):
def row_matches_relation(self, row: "agate.Row", relations: Set[BaseRelation]):
pass

def get_catalog(
self,
relation_configs: Iterable[RelationConfig],
used_schemas: FrozenSet[Tuple[str, str]],
) -> Tuple[agate.Table, List[Exception]]:
) -> Tuple["agate.Table", List[Exception]]:
with executor(self.config) as tpe:
futures: List[Future[agate.Table]] = []
futures: List[Future["agate.Table"]] = []
schema_map: SchemaSearchMap = self._get_catalog_schemas(relation_configs)
for info, schemas in schema_map.items():
if len(schemas) == 0:
@@ -1211,9 +1221,9 @@ def get_catalog(

def get_catalog_by_relations(
self, used_schemas: FrozenSet[Tuple[str, str]], relations: Set[BaseRelation]
) -> Tuple[agate.Table, List[Exception]]:
) -> Tuple["agate.Table", List[Exception]]:
with executor(self.config) as tpe:
futures: List[Future[agate.Table]] = []
futures: List[Future["agate.Table"]] = []
relations_by_schema = self._get_catalog_relations_by_info_schema(relations)
for info_schema in relations_by_schema:
name = ".".join([str(info_schema.database), "information_schema"])
@@ -1243,6 +1253,8 @@ def calculate_freshness(
macro_resolver: Optional[MacroResolverProtocol] = None,
) -> Tuple[Optional[AdapterResponse], FreshnessResponse]:
"""Calculate the freshness of sources in dbt, and return it"""
import agate

kwargs: Dict[str, Any] = {
"source": source,
"loaded_at_field": loaded_at_field,
@@ -1253,8 +1265,8 @@ def calculate_freshness(
# in older versions of dbt-core, the 'collect_freshness' macro returned the table of results directly
# starting in v1.5, by default, we return both the table and the adapter response (metadata about the query)
result: Union[
AttrDict, # current: contains AdapterResponse + agate.Table
agate.Table, # previous: just table
AttrDict, # current: contains AdapterResponse + "agate.Table"
"agate.Table", # previous: just table
]
result = self.execute_macro(
FRESHNESS_MACRO_NAME, kwargs=kwargs, macro_resolver=macro_resolver
@@ -1302,6 +1314,8 @@ def calculate_freshness_from_metadata(
adapter_response, table = result.response, result.table # type: ignore[attr-defined]

try:
from dbt_common.clients.agate_helper import get_column_value_uncased

row = table[0]
last_modified_val = get_column_value_uncased("last_modified", row)
snapshotted_at_val = get_column_value_uncased("snapshotted_at", row)
@@ -1638,10 +1652,12 @@ def supports(cls, capability: Capability) -> bool:


def catch_as_completed(
futures, # typing: List[Future[agate.Table]]
) -> Tuple[agate.Table, List[Exception]]:
# catalogs: agate.Table = agate.Table(rows=[])
tables: List[agate.Table] = []
futures, # typing: List[Future["agate.Table"]]
) -> Tuple["agate.Table", List[Exception]]:
from dbt_common.clients.agate_helper import merge_tables

# catalogs: "agate.Table" = agate.Table(rows=[])
tables: List["agate.Table"] = []
exceptions: List[Exception] = []

for future in as_completed(futures):
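
One hedged way to spot-check the effect of these deferred imports, assuming nothing else on the import path pulls agate in eagerly, is to inspect `sys.modules` right after importing the module:

```python
import sys

import dbt.adapters.base.impl  # noqa: F401

# With the lazy imports above, importing the adapter base module should no
# longer drag agate in; it loads only once a table is actually built.
assert "agate" not in sys.modules
```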
4 changes: 1 addition & 3 deletions dbt/adapters/factory.py
@@ -100,9 +100,7 @@ def register_adapter(self, config: AdapterRequiredConfig, mp_context: SpawnConte
adapter_name = config.credentials.type
adapter_type = self.get_adapter_class_by_name(adapter_name)
adapter_version = self._adapter_version(adapter_name)
fire_event(
AdapterRegistered(adapter_name=adapter_name, adapter_version=adapter_version)
)
fire_event(AdapterRegistered(adapter_name=adapter_name, adapter_version=adapter_version))
with self.lock:
if adapter_name in self.adapters:
# this shouldn't really happen...