diff --git a/.changes/unreleased/Under the Hood-20240309-150639.yaml b/.changes/unreleased/Under the Hood-20240309-150639.yaml new file mode 100644 index 00000000..82167182 --- /dev/null +++ b/.changes/unreleased/Under the Hood-20240309-150639.yaml @@ -0,0 +1,6 @@ +kind: Under the Hood +body: Lazy load agate to improve dbt-core performance +time: 2024-03-09T15:06:39.038593-05:00 +custom: + Author: dwreeves + Issue: "125" diff --git a/dbt/adapters/base/connections.py b/dbt/adapters/base/connections.py index fb23c2a8..a3a4d98d 100644 --- a/dbt/adapters/base/connections.py +++ b/dbt/adapters/base/connections.py @@ -18,9 +18,9 @@ Tuple, Type, Union, + TYPE_CHECKING, ) -import agate from dbt_common.events.contextvars import get_node_info from dbt_common.events.functions import fire_event from dbt_common.exceptions import DbtInternalError, NotImplementedError @@ -48,6 +48,9 @@ ) from dbt.adapters.exceptions import FailedToConnectError, InvalidConnectionError +if TYPE_CHECKING: + import agate + SleepTime = Union[int, float] # As taken by time.sleep. AdapterHandle = Any # Adapter connection handle objects can be any class. @@ -162,9 +165,7 @@ def set_connection_name(self, name: Optional[str] = None) -> Connection: conn.handle = LazyHandle(self.open) # Add the connection to thread_connections for this thread self.set_thread_connection(conn) - fire_event( - NewConnection(conn_name=conn_name, conn_type=self.TYPE, node_info=get_node_info()) - ) + fire_event(NewConnection(conn_name=conn_name, conn_type=self.TYPE, node_info=get_node_info())) else: # existing connection either wasn't open or didn't have the right name if conn.state != "open": conn.handle = LazyHandle(self.open) @@ -396,7 +397,7 @@ def execute( auto_begin: bool = False, fetch: bool = False, limit: Optional[int] = None, - ) -> Tuple[AdapterResponse, agate.Table]: + ) -> Tuple[AdapterResponse, "agate.Table"]: """Execute the given SQL. :param str sql: The sql to execute. diff --git a/dbt/adapters/base/impl.py b/dbt/adapters/base/impl.py index b28494be..5b4b8080 100644 --- a/dbt/adapters/base/impl.py +++ b/dbt/adapters/base/impl.py @@ -20,16 +20,9 @@ Type, TypedDict, Union, + TYPE_CHECKING, ) -import agate -from dbt_common.clients.agate_helper import ( - Integer, - empty_table, - get_column_value_uncased, - merge_tables, - table_from_rows, -) from dbt_common.clients.jinja import CallableMacroGenerator from dbt_common.contracts.constraints import ( ColumnLevelConstraint, @@ -94,6 +87,9 @@ ) from dbt.adapters.protocol import AdapterConfig, MacroContextGeneratorCallable +if TYPE_CHECKING: + import agate + GET_CATALOG_MACRO_NAME = "get_catalog" GET_CATALOG_RELATIONS_MACRO_NAME = "get_catalog_relations" @@ -107,7 +103,14 @@ class ConstraintSupport(str, Enum): NOT_SUPPORTED = "not_supported" -def _expect_row_value(key: str, row: agate.Row): +def _parse_callback_empty_table(*args, **kwargs) -> Tuple[str, "agate.Table"]: + # Lazy load agate_helper to avoid importing agate when it is not necessary. + from dbt_common.clients.agate_helper import empty_table + + return "", empty_table() + + +def _expect_row_value(key: str, row: "agate.Row"): if key not in row.keys(): raise DbtInternalError( 'Got a row without "{}" column, columns: {}'.format(key, row.keys()) @@ -117,13 +120,13 @@ def _expect_row_value(key: str, row: agate.Row): def _catalog_filter_schemas( used_schemas: FrozenSet[Tuple[str, str]] -) -> Callable[[agate.Row], bool]: +) -> Callable[["agate.Row"], bool]: """Return a function that takes a row and decides if the row should be included in the catalog output. """ schemas = frozenset((d.lower(), s.lower()) for d, s in used_schemas) - def test(row: agate.Row) -> bool: + def test(row: "agate.Row") -> bool: table_database = _expect_row_value("table_database", row) table_schema = _expect_row_value("table_schema", row) # the schema may be present but None, which is not an error and should @@ -325,14 +328,14 @@ def connection_named(self, name: str, query_header_context: Any = None) -> Itera if self.connections.query_header is not None: self.connections.query_header.reset() - @available.parse(lambda *a, **k: ("", empty_table())) + @available.parse(_parse_callback_empty_table) def execute( self, sql: str, auto_begin: bool = False, fetch: bool = False, limit: Optional[int] = None, - ) -> Tuple[AdapterResponse, agate.Table]: + ) -> Tuple[AdapterResponse, "agate.Table"]: """Execute the given SQL. This is a thin wrapper around ConnectionManager.execute. @@ -342,7 +345,7 @@ def execute( :param bool fetch: If set, fetch results. :param Optional[int] limit: If set, only fetch n number of rows :return: A tuple of the query status and results (empty if fetch=False). - :rtype: Tuple[AdapterResponse, agate.Table] + :rtype: Tuple[AdapterResponse, "agate.Table"] """ return self.connections.execute(sql=sql, auto_begin=auto_begin, fetch=fetch, limit=limit) @@ -370,8 +373,8 @@ def get_column_schema_from_query(self, sql: str) -> List[BaseColumn]: ] return columns - @available.parse(lambda *a, **k: ("", empty_table())) - def get_partitions_metadata(self, table: str) -> Tuple[agate.Table]: + @available.parse(_parse_callback_empty_table) + def get_partitions_metadata(self, table: str) -> Tuple["agate.Table"]: """ TODO: Can we move this to dbt-bigquery? Obtain partitions metadata for a BigQuery partitioned table. @@ -379,7 +382,7 @@ def get_partitions_metadata(self, table: str) -> Tuple[agate.Table]: :param str table: a partitioned table id, in standard SQL format. :return: a partition metadata tuple, as described in https://cloud.google.com/bigquery/docs/creating-partitioned-tables#getting_partition_metadata_using_meta_tables. - :rtype: agate.Table + :rtype: "agate.Table" """ if hasattr(self.connections, "get_partitions_metadata"): return self.connections.get_partitions_metadata(table=table) @@ -423,7 +426,9 @@ def _get_cache_schemas(self, relation_configs: Iterable[RelationConfig]) -> Set[ populate. """ return { - self.Relation.create_from(quoting=self.config, relation_config=relation_config).without_identifier() + self.Relation.create_from( + quoting=self.config, relation_config=relation_config + ).without_identifier() for relation_config in relation_configs } @@ -665,7 +670,7 @@ def list_relations_without_caching(self, schema_relation: BaseRelation) -> List[ # Methods about grants ### @available - def standardize_grants_dict(self, grants_table: agate.Table) -> dict: + def standardize_grants_dict(self, grants_table: "agate.Table") -> dict: """Translate the result of `show grants` (or equivalent) to match the grants which a user would configure in their project. @@ -940,7 +945,7 @@ def quote_seed_column(self, column: str, quote_config: Optional[bool]) -> str: ### @classmethod @abc.abstractmethod - def convert_text_type(cls, agate_table: agate.Table, col_idx: int) -> str: + def convert_text_type(cls, agate_table: "agate.Table", col_idx: int) -> str: """Return the type in the database that best maps to the agate.Text type for the given agate table and column index. @@ -952,7 +957,7 @@ def convert_text_type(cls, agate_table: agate.Table, col_idx: int) -> str: @classmethod @abc.abstractmethod - def convert_number_type(cls, agate_table: agate.Table, col_idx: int) -> str: + def convert_number_type(cls, agate_table: "agate.Table", col_idx: int) -> str: """Return the type in the database that best maps to the agate.Number type for the given agate table and column index. @@ -963,7 +968,7 @@ def convert_number_type(cls, agate_table: agate.Table, col_idx: int) -> str: raise NotImplementedError("`convert_number_type` is not implemented for this adapter!") @classmethod - def convert_integer_type(cls, agate_table: agate.Table, col_idx: int) -> str: + def convert_integer_type(cls, agate_table: "agate.Table", col_idx: int) -> str: """Return the type in the database that best maps to the agate.Number type for the given agate table and column index. @@ -975,7 +980,7 @@ def convert_integer_type(cls, agate_table: agate.Table, col_idx: int) -> str: @classmethod @abc.abstractmethod - def convert_boolean_type(cls, agate_table: agate.Table, col_idx: int) -> str: + def convert_boolean_type(cls, agate_table: "agate.Table", col_idx: int) -> str: """Return the type in the database that best maps to the agate.Boolean type for the given agate table and column index. @@ -987,7 +992,7 @@ def convert_boolean_type(cls, agate_table: agate.Table, col_idx: int) -> str: @classmethod @abc.abstractmethod - def convert_datetime_type(cls, agate_table: agate.Table, col_idx: int) -> str: + def convert_datetime_type(cls, agate_table: "agate.Table", col_idx: int) -> str: """Return the type in the database that best maps to the agate.DateTime type for the given agate table and column index. @@ -999,7 +1004,7 @@ def convert_datetime_type(cls, agate_table: agate.Table, col_idx: int) -> str: @classmethod @abc.abstractmethod - def convert_date_type(cls, agate_table: agate.Table, col_idx: int) -> str: + def convert_date_type(cls, agate_table: "agate.Table", col_idx: int) -> str: """Return the type in the database that best maps to the agate.Date type for the given agate table and column index. @@ -1011,7 +1016,7 @@ def convert_date_type(cls, agate_table: agate.Table, col_idx: int) -> str: @classmethod @abc.abstractmethod - def convert_time_type(cls, agate_table: agate.Table, col_idx: int) -> str: + def convert_time_type(cls, agate_table: "agate.Table", col_idx: int) -> str: """Return the type in the database that best maps to the agate.TimeDelta type for the given agate table and column index. @@ -1023,11 +1028,14 @@ def convert_time_type(cls, agate_table: agate.Table, col_idx: int) -> str: @available @classmethod - def convert_type(cls, agate_table: agate.Table, col_idx: int) -> Optional[str]: + def convert_type(cls, agate_table: "agate.Table", col_idx: int) -> Optional[str]: return cls.convert_agate_type(agate_table, col_idx) @classmethod - def convert_agate_type(cls, agate_table: agate.Table, col_idx: int) -> Optional[str]: + def convert_agate_type(cls, agate_table: "agate.Table", col_idx: int) -> Optional[str]: + import agate + from dbt_common.clients.agate_helper import Integer + agate_type: Type = agate_table.column_types[col_idx] conversions: List[Tuple[Type, Callable[..., str]]] = [ (Integer, cls.convert_integer_type), @@ -1104,11 +1112,13 @@ def execute_macro( @classmethod def _catalog_filter_table( - cls, table: agate.Table, used_schemas: FrozenSet[Tuple[str, str]] - ) -> agate.Table: + cls, table: "agate.Table", used_schemas: FrozenSet[Tuple[str, str]] + ) -> "agate.Table": """Filter the table as appropriate for catalog entries. Subclasses can override this to change filtering rules on a per-adapter basis. """ + from dbt_common.clients.agate_helper import table_from_rows + # force database + schema to be strings table = table_from_rows( table.rows, @@ -1122,7 +1132,7 @@ def _get_one_catalog( information_schema: InformationSchema, schemas: Set[str], used_schemas: FrozenSet[Tuple[str, str]], - ) -> agate.Table: + ) -> "agate.Table": kwargs = {"information_schema": information_schema, "schemas": schemas} table = self.execute_macro(GET_CATALOG_MACRO_NAME, kwargs=kwargs) @@ -1134,7 +1144,7 @@ def _get_one_catalog_by_relations( information_schema: InformationSchema, relations: List[BaseRelation], used_schemas: FrozenSet[Tuple[str, str]], - ) -> agate.Table: + ) -> "agate.Table": kwargs = { "information_schema": information_schema, "relations": relations, @@ -1150,7 +1160,7 @@ def get_filtered_catalog( used_schemas: FrozenSet[Tuple[str, str]], relations: Optional[Set[BaseRelation]] = None, ): - catalogs: agate.Table + catalogs: "agate.Table" if ( relations is None or len(relations) > self.MAX_SCHEMA_METADATA_RELATIONS @@ -1173,7 +1183,7 @@ def get_filtered_catalog( for r in relations } - def in_map(row: agate.Row): + def in_map(row: "agate.Row"): d = _expect_row_value("table_database", row) s = _expect_row_value("table_schema", row) i = _expect_row_value("table_name", row) @@ -1186,16 +1196,16 @@ def in_map(row: agate.Row): return catalogs, exceptions - def row_matches_relation(self, row: agate.Row, relations: Set[BaseRelation]): + def row_matches_relation(self, row: "agate.Row", relations: Set[BaseRelation]): pass def get_catalog( self, relation_configs: Iterable[RelationConfig], used_schemas: FrozenSet[Tuple[str, str]], - ) -> Tuple[agate.Table, List[Exception]]: + ) -> Tuple["agate.Table", List[Exception]]: with executor(self.config) as tpe: - futures: List[Future[agate.Table]] = [] + futures: List[Future["agate.Table"]] = [] schema_map: SchemaSearchMap = self._get_catalog_schemas(relation_configs) for info, schemas in schema_map.items(): if len(schemas) == 0: @@ -1211,9 +1221,9 @@ def get_catalog( def get_catalog_by_relations( self, used_schemas: FrozenSet[Tuple[str, str]], relations: Set[BaseRelation] - ) -> Tuple[agate.Table, List[Exception]]: + ) -> Tuple["agate.Table", List[Exception]]: with executor(self.config) as tpe: - futures: List[Future[agate.Table]] = [] + futures: List[Future["agate.Table"]] = [] relations_by_schema = self._get_catalog_relations_by_info_schema(relations) for info_schema in relations_by_schema: name = ".".join([str(info_schema.database), "information_schema"]) @@ -1243,6 +1253,8 @@ def calculate_freshness( macro_resolver: Optional[MacroResolverProtocol] = None, ) -> Tuple[Optional[AdapterResponse], FreshnessResponse]: """Calculate the freshness of sources in dbt, and return it""" + import agate + kwargs: Dict[str, Any] = { "source": source, "loaded_at_field": loaded_at_field, @@ -1253,8 +1265,8 @@ def calculate_freshness( # in older versions of dbt-core, the 'collect_freshness' macro returned the table of results directly # starting in v1.5, by default, we return both the table and the adapter response (metadata about the query) result: Union[ - AttrDict, # current: contains AdapterResponse + agate.Table - agate.Table, # previous: just table + AttrDict, # current: contains AdapterResponse + "agate.Table" + "agate.Table", # previous: just table ] result = self.execute_macro( FRESHNESS_MACRO_NAME, kwargs=kwargs, macro_resolver=macro_resolver @@ -1302,6 +1314,8 @@ def calculate_freshness_from_metadata( adapter_response, table = result.response, result.table # type: ignore[attr-defined] try: + from dbt_common.clients.agate_helper import get_column_value_uncased + row = table[0] last_modified_val = get_column_value_uncased("last_modified", row) snapshotted_at_val = get_column_value_uncased("snapshotted_at", row) @@ -1638,10 +1652,12 @@ def supports(cls, capability: Capability) -> bool: def catch_as_completed( - futures, # typing: List[Future[agate.Table]] -) -> Tuple[agate.Table, List[Exception]]: - # catalogs: agate.Table = agate.Table(rows=[]) - tables: List[agate.Table] = [] + futures, # typing: List[Future["agate.Table"]] +) -> Tuple["agate.Table", List[Exception]]: + from dbt_common.clients.agate_helper import merge_tables + + # catalogs: "agate.Table" =".Table(rows=[]) + tables: List["agate.Table"] = [] exceptions: List[Exception] = [] for future in as_completed(futures): diff --git a/dbt/adapters/factory.py b/dbt/adapters/factory.py index e5c7be78..d77ad2f4 100644 --- a/dbt/adapters/factory.py +++ b/dbt/adapters/factory.py @@ -100,9 +100,7 @@ def register_adapter(self, config: AdapterRequiredConfig, mp_context: SpawnConte adapter_name = config.credentials.type adapter_type = self.get_adapter_class_by_name(adapter_name) adapter_version = self._adapter_version(adapter_name) - fire_event( - AdapterRegistered(adapter_name=adapter_name, adapter_version=adapter_version) - ) + fire_event(AdapterRegistered(adapter_name=adapter_name, adapter_version=adapter_version)) with self.lock: if adapter_name in self.adapters: # this shouldn't really happen... diff --git a/dbt/adapters/protocol.py b/dbt/adapters/protocol.py index f27394ad..bbfdd330 100644 --- a/dbt/adapters/protocol.py +++ b/dbt/adapters/protocol.py @@ -10,10 +10,10 @@ Type, TypeVar, Tuple, + TYPE_CHECKING, ) from typing_extensions import Protocol -import agate from dbt_common.clients.jinja import MacroProtocol from dbt_common.contracts.config.base import BaseConfig @@ -25,6 +25,9 @@ from dbt.adapters.contracts.macros import MacroResolverProtocol from dbt.adapters.contracts.relation import HasQuoting, Policy, RelationConfig +if TYPE_CHECKING: + import agate + @dataclass class AdapterConfig(BaseConfig): @@ -169,5 +172,5 @@ def commit_if_has_connection(self) -> None: def execute( self, sql: str, auto_begin: bool = False, fetch: bool = False - ) -> Tuple[AdapterResponse, agate.Table]: + ) -> Tuple[AdapterResponse, "agate.Table"]: ... diff --git a/dbt/adapters/relation_configs/config_base.py b/dbt/adapters/relation_configs/config_base.py index e8131b67..62d14059 100644 --- a/dbt/adapters/relation_configs/config_base.py +++ b/dbt/adapters/relation_configs/config_base.py @@ -1,9 +1,11 @@ from dataclasses import dataclass -from typing import Dict, Union +from typing import Dict, Union, TYPE_CHECKING -import agate from dbt_common.utils import filter_null_values +if TYPE_CHECKING: + import agate + """ This is what relation metadata from the database looks like. It's a dictionary because there will be @@ -18,7 +20,7 @@ ]) } """ -RelationResults = Dict[str, Union[agate.Row, agate.Table]] +RelationResults = Dict[str, Union["agate.Row", "agate.Table"]] @dataclass(frozen=True) diff --git a/dbt/adapters/sql/connections.py b/dbt/adapters/sql/connections.py index f9d802aa..78cd3c9b 100644 --- a/dbt/adapters/sql/connections.py +++ b/dbt/adapters/sql/connections.py @@ -1,9 +1,7 @@ import abc import time -from typing import Any, Dict, Iterable, List, Optional, Tuple +from typing import Any, Dict, Iterable, List, Optional, Tuple, TYPE_CHECKING -import agate -from dbt_common.clients.agate_helper import empty_table, table_from_data_flat from dbt_common.events.contextvars import get_node_info from dbt_common.events.functions import fire_event from dbt_common.exceptions import DbtInternalError, NotImplementedError @@ -22,6 +20,9 @@ SQLQueryStatus, ) +if TYPE_CHECKING: + import agate + class SQLConnectionManager(BaseConnectionManager): """The default connection manager with some common SQL methods implemented. @@ -126,7 +127,9 @@ def process_results( return [dict(zip(column_names, row)) for row in rows] @classmethod - def get_result_from_cursor(cls, cursor: Any, limit: Optional[int]) -> agate.Table: + def get_result_from_cursor(cls, cursor: Any, limit: Optional[int]) -> "agate.Table": + from dbt_common.clients.agate_helper import table_from_data_flat + data: List[Any] = [] column_names: List[str] = [] @@ -146,7 +149,9 @@ def execute( auto_begin: bool = False, fetch: bool = False, limit: Optional[int] = None, - ) -> Tuple[AdapterResponse, agate.Table]: + ) -> Tuple[AdapterResponse, "agate.Table"]: + from dbt_common.clients.agate_helper import empty_table + sql = self._add_query_comment(sql) _, cursor = self.add_query(sql, auto_begin) response = self.get_response(cursor) diff --git a/dbt/adapters/sql/impl.py b/dbt/adapters/sql/impl.py index c3a75cc6..8c6e0e8e 100644 --- a/dbt/adapters/sql/impl.py +++ b/dbt/adapters/sql/impl.py @@ -1,6 +1,5 @@ -from typing import Any, List, Optional, Tuple, Type +from typing import Any, List, Optional, Tuple, Type, TYPE_CHECKING -import agate from dbt_common.events.functions import fire_event from dbt.adapters.base import BaseAdapter, BaseRelation, available @@ -23,6 +22,9 @@ ALTER_COLUMN_TYPE_MACRO_NAME = "alter_column_type" VALIDATE_SQL_MACRO_NAME = "validate_sql" +if TYPE_CHECKING: + import agate + class SQLAdapter(BaseAdapter): """The default adapter with the common agate conversions and some SQL @@ -65,33 +67,35 @@ def add_query( return self.connections.add_query(sql, auto_begin, bindings, abridge_sql_log) @classmethod - def convert_text_type(cls, agate_table: agate.Table, col_idx: int) -> str: + def convert_text_type(cls, agate_table: "agate.Table", col_idx: int) -> str: return "text" @classmethod - def convert_number_type(cls, agate_table: agate.Table, col_idx: int) -> str: + def convert_number_type(cls, agate_table: "agate.Table", col_idx: int) -> str: + import agate + # TODO CT-211 decimals = agate_table.aggregate(agate.MaxPrecision(col_idx)) # type: ignore[attr-defined] return "float8" if decimals else "integer" @classmethod - def convert_integer_type(cls, agate_table: agate.Table, col_idx: int) -> str: + def convert_integer_type(cls, agate_table: "agate.Table", col_idx: int) -> str: return "integer" @classmethod - def convert_boolean_type(cls, agate_table: agate.Table, col_idx: int) -> str: + def convert_boolean_type(cls, agate_table: "agate.Table", col_idx: int) -> str: return "boolean" @classmethod - def convert_datetime_type(cls, agate_table: agate.Table, col_idx: int) -> str: + def convert_datetime_type(cls, agate_table: "agate.Table", col_idx: int) -> str: return "timestamp without time zone" @classmethod - def convert_date_type(cls, agate_table: agate.Table, col_idx: int) -> str: + def convert_date_type(cls, agate_table: "agate.Table", col_idx: int) -> str: return "date" @classmethod - def convert_time_type(cls, agate_table: agate.Table, col_idx: int) -> str: + def convert_time_type(cls, agate_table: "agate.Table", col_idx: int) -> str: return "time" @classmethod