diff --git a/src/snowflake/cli/_plugins/workspace/manager.py b/src/snowflake/cli/_plugins/workspace/manager.py index 25b56d542f..5d3a12fb23 100644 --- a/src/snowflake/cli/_plugins/workspace/manager.py +++ b/src/snowflake/cli/_plugins/workspace/manager.py @@ -1,21 +1,17 @@ from pathlib import Path from typing import Dict -from snowflake.cli._plugins.workspace.context import ActionContext, WorkspaceContext -from snowflake.cli.api.cli_global_context import get_cli_context +from snowflake.cli._plugins.workspace.context import ActionContext from snowflake.cli.api.console import cli_console as cc -from snowflake.cli.api.entities.common import EntityActions, get_sql_executor -from snowflake.cli.api.exceptions import InvalidProjectDefinitionVersionError -from snowflake.cli.api.project.definition import default_role -from snowflake.cli.api.project.schemas.entities.entities import ( - Entity, - v2_entity_model_to_entity_map, +from snowflake.cli.api.entities.common import ( + EntityActions, ) +from snowflake.cli.api.exceptions import InvalidProjectDefinitionVersionError +from snowflake.cli.api.project.schemas.entities.entities import Entity from snowflake.cli.api.project.schemas.project_definition import ( DefinitionV20, ProjectDefinition, ) -from snowflake.cli.api.project.util import to_identifier class WorkspaceManager: @@ -41,15 +37,7 @@ def get_entity(self, entity_id: str): entity_model = self._project_definition.entities.get(entity_id, None) if entity_model is None: raise ValueError(f"No such entity ID: {entity_id}") - entity_model_cls = entity_model.__class__ - entity_cls = v2_entity_model_to_entity_map[entity_model_cls] - workspace_ctx = WorkspaceContext( - console=cc, - project_root=self.project_root, - get_default_role=_get_default_role, - get_default_warehouse=_get_default_warehouse, - ) - self._entities_cache[entity_id] = entity_cls(entity_model, workspace_ctx) + self._entities_cache[entity_id] = entity_model.get_entity(cc, self.project_root) return self._entities_cache[entity_id] def perform_action(self, entity_id: str, action: EntityActions, *args, **kwargs): @@ -68,17 +56,3 @@ def perform_action(self, entity_id: str, action: EntityActions, *args, **kwargs) @property def project_root(self) -> Path: return self._project_root - - -def _get_default_role() -> str: - role = default_role() - if role is None: - role = get_sql_executor().current_role() - return role - - -def _get_default_warehouse() -> str | None: - warehouse = get_cli_context().connection.warehouse - if warehouse: - warehouse = to_identifier(warehouse) - return warehouse diff --git a/src/snowflake/cli/api/entities/common.py b/src/snowflake/cli/api/entities/common.py index 9f98f970ec..fb4cc2fa72 100644 --- a/src/snowflake/cli/api/entities/common.py +++ b/src/snowflake/cli/api/entities/common.py @@ -20,7 +20,32 @@ class EntityActions(str, Enum): T = TypeVar("T") -class EntityBase(Generic[T]): +class EntityBaseMetaclass(type): + def __new__(mcs, name, bases, attrs): # noqa: N804 + cls = super().__new__(mcs, name, bases, attrs) + generic_bases = attrs.get("__orig_bases__", []) + if not generic_bases: + # Subclass is not generic + return cls + + target_model_class = get_args(generic_bases[0])[0] # type: ignore[attr-defined] + if target_model_class is T: + # Generic parameter is not filled in + return cls + + target_entity_class = getattr(target_model_class, "_entity_class", None) + if target_entity_class is not None: + raise ValueError( + f"Entity model class {target_model_class} is already " + f"associated with entity class {target_entity_class}, " + f"cannot associate with {cls}" + ) + + setattr(target_model_class, "_entity_class", cls) + return cls + + +class EntityBase(Generic[T], metaclass=EntityBaseMetaclass): """ Base class for the fully-featured entity classes. """ diff --git a/src/snowflake/cli/api/project/schemas/entities/common.py b/src/snowflake/cli/api/project/schemas/entities/common.py index d9036d9a4c..4b6896b6a6 100644 --- a/src/snowflake/cli/api/project/schemas/entities/common.py +++ b/src/snowflake/cli/api/project/schemas/entities/common.py @@ -15,9 +15,12 @@ from __future__ import annotations from abc import ABC +from pathlib import Path from typing import Dict, Generic, List, Optional, TypeVar, Union from pydantic import Field, PrivateAttr, field_validator +from snowflake.cli._plugins.workspace.context import WorkspaceContext +from snowflake.cli.api.console.abc import AbstractConsole from snowflake.cli.api.identifiers import FQN from snowflake.cli.api.project.schemas.updatable_model import ( IdentifierField, @@ -110,6 +113,24 @@ def fqn(self) -> FQN: if self.entity_id: return FQN.from_string(self.entity_id) + def get_entity(self, console: AbstractConsole, project_root: Path): + if type(self) is EntityModelBase: + raise NotImplementedError + # Set by EntityBaseMetaclass when creating the + # Entity class that refers to this model + entity_class = getattr(self, "_entity_class", None) + if entity_class is None: + raise ValueError( + f"Entity model class {type(self).__name__} is not associated with an entity class" + ) + workspace_ctx = WorkspaceContext( + console=console, + project_root=project_root, + get_default_role=_get_default_role, + get_default_warehouse=_get_default_warehouse, + ) + return entity_class(self, workspace_ctx) + TargetType = TypeVar("TargetType") @@ -162,3 +183,23 @@ def get_secrets_sql(self) -> str | None: return None secrets = ", ".join(f"'{key}'={value}" for key, value in self.secrets.items()) return f"secrets=({secrets})" + + +def _get_default_role() -> str: + from snowflake.cli.api.entities.common import get_sql_executor + from snowflake.cli.api.project.definition import default_role + + role = default_role() + if role is None: + role = get_sql_executor().current_role() + return role + + +def _get_default_warehouse() -> str | None: + from snowflake.cli.api.cli_global_context import get_cli_context + from snowflake.cli.api.project.util import to_identifier + + warehouse = get_cli_context().connection.warehouse + if warehouse: + warehouse = to_identifier(warehouse) + return warehouse diff --git a/tests/nativeapp/test_version_create.py b/tests/nativeapp/test_version_create.py index 30454bc35b..510adcb0ac 100644 --- a/tests/nativeapp/test_version_create.py +++ b/tests/nativeapp/test_version_create.py @@ -32,7 +32,7 @@ AskAlwaysPolicy, DenyAlwaysPolicy, ) -from snowflake.cli._plugins.workspace.context import ActionContext, WorkspaceContext +from snowflake.cli._plugins.workspace.context import ActionContext from snowflake.cli.api.console import cli_console as cc from snowflake.cli.api.project.definition_manager import DefinitionManager from snowflake.connector.cursor import DictCursor @@ -60,13 +60,7 @@ def _version_create( dm = DefinitionManager() pd = dm.project_definition pkg_model: ApplicationPackageEntityModel = pd.entities["app_pkg"] - ctx = WorkspaceContext( - console=cc, - project_root=dm.project_root, - get_default_role=lambda: "mock_role", - get_default_warehouse=lambda: "mock_warehouse", - ) - pkg = ApplicationPackageEntity(pkg_model, ctx) + pkg = pkg_model.get_entity(cc, dm.project_root) return pkg.action_version_create( action_ctx=mock.Mock(spec=ActionContext), version=version, diff --git a/tests/project/test_project_definition_v2.py b/tests/project/test_project_definition_v2.py index 03e7544581..4a0acf8213 100644 --- a/tests/project/test_project_definition_v2.py +++ b/tests/project/test_project_definition_v2.py @@ -23,12 +23,6 @@ ) from snowflake.cli.api.project.definition_manager import DefinitionManager from snowflake.cli.api.project.errors import SchemaValidationError -from snowflake.cli.api.project.schemas.entities.entities import ( - ALL_ENTITIES, - ALL_ENTITY_MODELS, - v2_entity_model_to_entity_map, - v2_entity_model_types_map, -) from snowflake.cli.api.project.schemas.project_definition import ( DefinitionV20, ) @@ -310,25 +304,6 @@ def test_identifiers(): assert entities["D"].entity_id == "D" -# Verify that each entity model type has the correct "type" field -def test_entity_types(): - for entity_type, entity_class in v2_entity_model_types_map.items(): - model_entity_type = entity_class.get_type() - assert model_entity_type == entity_type - - -# Verify that each entity class has a corresponding entity model class, and that all entities are covered -def test_entity_model_to_entity_map(): - entities = set(ALL_ENTITIES) - entity_models = set(ALL_ENTITY_MODELS) - assert len(entities) == len(entity_models) - for entity_model_class, entity_class in v2_entity_model_to_entity_map.items(): - entities.remove(entity_class) - entity_models.remove(entity_model_class) - assert len(entities) == 0 - assert len(entity_models) == 0 - - @pytest.mark.parametrize( "project_name", [