Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
Co-authored-by: Nicholas A. Yager <yager@nicholasyager.com>
  • Loading branch information
dave-connors-3 and nicholasyager authored Jul 7, 2023
1 parent e20f99e commit 4d5e2df
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 17 deletions.
7 changes: 3 additions & 4 deletions dbt_meshify/dbt_projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,10 +168,9 @@ def get_manifest_node(self, unique_id: str) -> Optional[ManifestNode]:
"snapshot",
]:
return self.manifest.nodes.get(unique_id)
else:
pluralized = NodeType(unique_id.split(".")[0]).pluralize()
resources = getattr(self.manifest, pluralized)
return resources.get(unique_id)
pluralized = NodeType(unique_id.split(".")[0]).pluralize()
resources = getattr(self.manifest, pluralized)
return resources.get(unique_id)


class DbtProject(BaseDbtProject):
Expand Down
25 changes: 12 additions & 13 deletions dbt_meshify/storage/file_content_editors.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def get_source_yml_entry(
"""
sources = resources_yml_to_dict(full_yml, NodeType.Source)
source_definition = sources.get(source_name)
tables = source_definition.get("tables", None)
tables = source_definition.get("tables", [])
table = list(filter(lambda x: x["name"] == resource_name, tables))
remaining_tables = list(filter(lambda x: x["name"] != resource_name, tables))
resource_yml = source_definition.copy()
Expand All @@ -127,9 +127,9 @@ def get_source_yml_entry(
sources[source_name] = source_definition
if len(remaining_tables) == 0:
return resource_yml, None
else:
full_yml["sources"] = list(sources.values())
return resource_yml, full_yml

full_yml["sources"] = list(sources.values())
return resource_yml, full_yml

def get_yml_entry(
self,
Expand Down Expand Up @@ -447,17 +447,16 @@ def update_model_refs(self, model_name: str, project_name: str) -> None:

# read the model file
model_code = str(self.file_manager.read_file(model_path))
if self.node.language == "sql":
updated_code = self.update_sql_refs(
model_name=model_name,
project_name=project_name,
model_code=model_code,
)
elif self.node.language == "python":
updated_code = self.update_python_refs(
# This can be defined in the init for this clas.
ref_update_methods = {
'sql': self.update_sql_refs,
'python': self.update_python_refs
}
# Here, we're trusting the dbt-core code to check the languages for us. 🐉
updated_code = ref_update_methods[self.node.language](
model_name=model_name,
project_name=project_name,
model_code=model_code,
)
)
# write the updated model code to the file
self.file_manager.write_file(model_path, updated_code)

0 comments on commit 4d5e2df

Please sign in to comment.