Skip to content

Commit

Permalink
Merge pull request #63 from dbt-labs/split-command
Browse files Browse the repository at this point in the history
Split command
  • Loading branch information
dave-connors-3 authored Jul 7, 2023
2 parents 912d5ad + b222ef1 commit 764ad1f
Show file tree
Hide file tree
Showing 26 changed files with 1,683 additions and 438 deletions.
7 changes: 7 additions & 0 deletions dbt_meshify/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,13 @@
help="The path to the dbt project to operate on. Defaults to the current directory.",
)

# CLI option naming the directory in which the new dbt project is created.
# The path is not required to exist beforehand (exists=False): requiring an
# existing path would make click reject any not-yet-created target directory.
create_path = click.option(
    "--create-path",
    type=click.Path(exists=False),
    default=None,
    help="The path to create the new dbt project. Defaults to the name argument supplied.",
)

exclude = click.option(
"--exclude",
"-e",
Expand Down
129 changes: 107 additions & 22 deletions dbt_meshify/dbt_projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,27 @@

import yaml
from dbt.contracts.graph.manifest import Manifest
from dbt.contracts.graph.nodes import ManifestNode, ModelNode, SourceDefinition
from dbt.contracts.graph.nodes import (
Documentation,
Exposure,
Group,
Macro,
ManifestNode,
ModelNode,
Resource,
SourceDefinition,
)
from dbt.contracts.project import Project
from dbt.contracts.results import CatalogArtifact, CatalogTable
from dbt.graph import Graph
from dbt.node_types import NodeType

from dbt_meshify.dbt import Dbt
from dbt_meshify.storage.file_content_editors import (
DbtMeshConstructor,
filter_empty_dict_items,
)
from dbt_meshify.storage.file_manager import DbtFileManager

logger = logging.getLogger()

Expand Down Expand Up @@ -100,7 +115,8 @@ def installed_packages(self) -> Set[str]:
if item.package_name:
_hash = hashlib.md5()
_hash.update(item.package_name.encode("utf-8"))
project_packages.append(_hash.hexdigest())
if _hash.hexdigest() != self.manifest.metadata.project_id:
project_packages.append(_hash.hexdigest())
return set(project_packages)

@property
Expand Down Expand Up @@ -150,9 +166,20 @@ def get_catalog_entry(self, unique_id: str) -> Optional[CatalogTable]:
"""Returns the catalog entry for a model in the dbt project's catalog"""
return self.catalog.nodes.get(unique_id)

def get_manifest_node(self, unique_id: str) -> Optional[ManifestNode]:
"""Returns the catalog entry for a model in the dbt project's catalog"""
return self.manifest.nodes.get(unique_id)
def get_manifest_node(self, unique_id: str) -> Optional[Resource]:
    """Return the manifest entry for a resource in the dbt project's manifest.

    The resource type is taken from the first dot-separated segment of
    ``unique_id``.

    Args:
        unique_id: The unique id of the resource to look up.

    Returns:
        The matching entry, or ``None`` when the collection that is searched
        has no entry for ``unique_id``.
    """
    # Parse the resource type prefix once and reuse it below.
    resource_type = unique_id.split(".")[0]

    # These resource types are looked up in `manifest.nodes`.
    if resource_type in {"model", "seed", "snapshot", "test", "analysis"}:
        return self.manifest.nodes.get(unique_id)

    # Every other type is looked up on the manifest attribute named after
    # the pluralized node type.
    pluralized = NodeType(resource_type).pluralize()
    resources = getattr(self.manifest, pluralized)
    return resources.get(unique_id)


class DbtProject(BaseDbtProject):
Expand Down Expand Up @@ -216,10 +243,13 @@ def split(
project_name: str,
select: str,
exclude: Optional[str] = None,
selector: Optional[str] = None,
) -> "DbtSubProject":
"""Create a new DbtSubProject using NodeSelection syntax."""

subproject_resources = self.select_resources(select, exclude)
subproject_resources = self.select_resources(
select=select, exclude=exclude, selector=selector, output_key="unique_id"
)

# Construct a new project and inject the new manifest
subproject = DbtSubProject(
Expand All @@ -242,24 +272,86 @@ class DbtSubProject(BaseDbtProject):
def __init__(self, name: str, parent_project: DbtProject, resources: Set[str]):
self.name = name
self.resources = resources
self.parent = parent_project
self.parent_project = parent_project
self.path = parent_project.path / Path(name)

self.manifest = parent_project.manifest.deepcopy()
# self.manifest = parent_project.manifest.deepcopy()
# i am running into a bug with the core deepcopy -- checking with michelle
self.manifest = copy.deepcopy(parent_project.manifest)
self.project = copy.deepcopy(parent_project.project)
self.catalog = parent_project.catalog
self.custom_macros = self._get_custom_macros()
self.groups = self._get_indirect_groups()

super().__init__(self.manifest, self.project, self.catalog)
self._rename_project()

def select_resources(self, select: str, exclude: Optional[str] = None) -> Set[str]:
super().__init__(self.manifest, self.project, self.catalog, self.name)

def _rename_project(self) -> None:
    """
    Re-key the project config from the parent project's name to this subproject's name.

    For every pluralized NodeType key of the project dictionary, an entry
    stored under the parent project's name is moved to the subproject's name.
    The top-level "name" is then set to the subproject name and
    `self.project` is rebuilt from the edited dictionary.
    """
    config = self.project.to_dict()
    parent_name = self.parent_project.name
    for node_type in NodeType:
        section = node_type.pluralize()
        if parent_name not in config.get(section, {}).keys():
            continue
        # Move the parent-keyed entry so it is keyed by the subproject name.
        config[section][self.name] = config[section].pop(parent_name)
    config["name"] = self.name
    self.project = Project.from_dict(config)

def _get_custom_macros(self) -> Set[str]:
    """
    Collect the macro unique_ids that the selected resources depend on.

    Only macros whose package segment (the second dot-separated part of the
    macro unique_id) has an md5 hex digest equal to the manifest's
    `project_id` are kept. Resources that cannot be found in the manifest,
    and Documentation / Group resources, are skipped.
    """
    custom_macros: Set[str] = set()
    for resource_id in self.resources:
        node = self.get_manifest_node(resource_id)
        if not node:
            continue
        if isinstance(node, (Documentation, Group)):
            continue
        for macro_id in node.depends_on.macros:  # type: ignore
            package_name = macro_id.split(".")[1]
            package_hash = hashlib.md5(package_name.encode()).hexdigest()
            if package_hash == self.manifest.metadata.project_id:
                custom_macros.add(macro_id)
    return custom_macros

def _get_indirect_groups(self) -> Set[str]:
    """
    Collect the group unique_ids referenced by the selected resources.

    Resources that cannot be found in the manifest, and Documentation, Group,
    Exposure, SourceDefinition and Macro resources, are skipped. For every
    remaining resource with a truthy `group`, the id
    "group.<parent project name>.<group>" is added to the result.
    """
    group_ids: Set[str] = set()
    for resource_id in self.resources:
        node = self.get_manifest_node(resource_id)  # type: ignore
        if not node or isinstance(
            node, (Documentation, Group, Exposure, SourceDefinition, Macro)
        ):
            continue
        group_name = node.group  # type: ignore
        if not group_name:
            continue
        group_ids.add(f"group.{self.parent_project.name}.{group_name}")
    return group_ids

def select_resources(
self,
select: str,
exclude: Optional[str] = None,
selector: Optional[str] = None,
output_key: Optional[str] = None,
) -> Set[str]:
"""
Select resources using the parent DbtProject and filtering down to only include resources in this
subproject.
"""
args = ["--select", select]
if exclude:
args.extend(["--exclude", exclude])

results = self.parent.dbt.ls(self.parent.path, args)
results = self.parent_project.select_resources(
select=select, exclude=exclude, selector=selector, output_key=output_key
)

return set(results) - self.resources

Expand All @@ -276,7 +368,7 @@ def split(
# Construct a new project and inject the new manifest
subproject = DbtSubProject(
name=project_name,
parent_project=copy.deepcopy(self.parent),
parent_project=copy.deepcopy(self.parent_project),
resources=subproject_resources,
)

Expand All @@ -285,13 +377,6 @@ def split(

return subproject

def initialize(self, target_directory: os.PathLike):
"""Initialize this subproject as a full dbt project at the provided `target_directory`."""

# TODO: Implement project initialization

raise NotImplementedError


class DbtProjectHolder:
def __init__(self) -> None:
Expand Down
57 changes: 27 additions & 30 deletions dbt_meshify/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@
from dbt.contracts.graph.unparsed import Owner
from loguru import logger

from dbt_meshify.storage.dbt_project_creator import DbtSubprojectCreator

from .cli import (
create_path,
exclude,
group_yml_path,
owner,
Expand All @@ -20,7 +23,7 @@
selector,
)
from .dbt_projects import DbtProject, DbtProjectHolder, DbtSubProject
from .storage.yaml_editors import DbtMeshModelConstructor
from .storage.file_content_editors import DbtMeshConstructor

log_format = "<white>{time:HH:mm:ss}</white> | <level>{level}</level> | <level>{message}</level>"
logger.remove() # Remove the default sink added by Loguru
Expand Down Expand Up @@ -65,40 +68,36 @@ def connect(projects_dir):


@cli.command(name="split")
@create_path
@click.argument("project_name")
@exclude
@project_path
@select
@selector
def split():
def split(project_name, select, exclude, project_path, selector, create_path):
"""
!!! info
This command is not yet implemented
Splits dbt projects apart by adding all necessary dbt Mesh constructs based on the selection syntax.
Splits out a new subproject from a dbt project by adding all necessary dbt Mesh constructs to the resources based on the selected resources.
"""
path_string = input("Enter the relative path to a dbt project you'd like to split: ")

holder = DbtProjectHolder()

path = Path(path_string).expanduser().resolve()
path = Path(project_path).expanduser().resolve()
project = DbtProject.from_directory(path)
holder.register_project(project)

while True:
subproject_name = input("Enter the name for your subproject ('done' to finish): ")
if subproject_name == "done":
break
subproject_selector = input(
f"Enter the selector that represents the subproject {subproject_name}: "
)

subproject: DbtSubProject = project.split(
project_name=subproject_name, select=subproject_selector
)
holder.register_project(subproject)

print(holder.project_map())
subproject = project.split(
project_name=project_name, select=select, exclude=exclude, selector=selector
)
logger.info(f"Selected {len(subproject.resources)} resources: {subproject.resources}")
target_directory = Path(create_path) if create_path else None
subproject_creator = DbtSubprojectCreator(
subproject=subproject, target_directory=target_directory
)
logger.info(f"Creating subproject {subproject.name}...")
try:
subproject_creator.initialize()
logger.success(f"Successfully created subproject {subproject.name}")
except Exception as e:
logger.error(f"Error creating subproject {subproject.name}")
logger.exception(e)


@operation.command(name="add-contract")
Expand Down Expand Up @@ -127,8 +126,8 @@ def add_contract(select, exclude, project_path, selector, public_only=False):
for model_unique_id in models:
model_node = project.get_manifest_node(model_unique_id)
model_catalog = project.get_catalog_entry(model_unique_id)
meshify_constructor = DbtMeshModelConstructor(
project_path=project_path, model_node=model_node, model_catalog=model_catalog
meshify_constructor = DbtMeshConstructor(
project_path=project_path, node=model_node, catalog=model_catalog
)
logger.info(f"Adding contract to model: {model_unique_id}")
try:
Expand Down Expand Up @@ -164,9 +163,7 @@ def add_version(select, exclude, project_path, selector, prerelease, defined_in)
for model_unique_id in models:
model_node = project.get_manifest_node(model_unique_id)
if model_node.version == model_node.latest_version:
meshify_constructor = DbtMeshModelConstructor(
project_path=project_path, model_node=model_node
)
meshify_constructor = DbtMeshConstructor(project_path=project_path, node=model_node)
try:
meshify_constructor.add_model_version(prerelease=prerelease, defined_in=defined_in)
logger.success(f"Successfully added version to model: {model_unique_id}")
Expand Down
Loading

0 comments on commit 764ad1f

Please sign in to comment.