diff --git a/client.py b/client.py index d652730..1deb48c 100644 --- a/client.py +++ b/client.py @@ -46,9 +46,12 @@ def _compose_and_draw_inner(self, pygraph: AGraph, flow: TokenSequence, name: st def _draw_architectural_connections(self, pygraph: AGraph, diagram: TokenSequence): diagram.architecture.draw_connections(pygraph=pygraph) - def draw_epc(self, out_path: str, file_format: list[FileFormat]): + def draw_epc(self, out_path: str, file_format: list[FileFormat], language: str): cluster = Cluster() - cluster.extract_flows(file_name_list=[file.input_path for file in file_format]) + cluster.extract_flows( + file_name_list=[file.input_path for file in file_format], + language=language, + ) main_flows = [flow.tokens for flow in cluster._main_flows] ARCHG = AGraph(directed=True, compound=True) diff --git a/clusterer.py b/clusterer.py index 2252418..fc6b87c 100644 --- a/clusterer.py +++ b/clusterer.py @@ -66,10 +66,10 @@ def _find_main_flows(self, token: str, index: int, flows: Flows): if flows.latest_flow is not None: flows.latest_flow.tokens.append(token) - def extract_flows(self, file_name_list: list[str]): + def extract_flows(self, file_name_list: list[str], language: str): for file in file_name_list: parser = Parser() - parsed = parser.parse(file) + parsed = parser.parse(file, language) flows = Flows() for index, token_raw in enumerate(parsed): diff --git a/main.py b/main.py index 611163a..f33b8d6 100644 --- a/main.py +++ b/main.py @@ -1,4 +1,5 @@ import os +import re from typing import Optional import typer @@ -27,8 +28,22 @@ def get_filepaths(directory): return file_paths +def is_file_allowed(file_name: str, extension: str): + file_extension = re.search(r"\.\w+", file_name)[0] + + if file_extension == f".{extension}": + return True + + return False + + @app.command() -def docupyt(path: Optional[str] = None, out_path: Optional[str] = "_outputs"): +def docupyt( + path: Optional[str] = None, + out_path: Optional[str] = "_outputs", + extension="py", + language="python", +): if not path: raise ValueError("Path is required") @@ -38,13 +53,14 @@ def docupyt(path: Optional[str] = None, out_path: Optional[str] = "_outputs"): formats = [] for filepath in get_filepaths(path): - formats.append( - FileFormat( - input_path=filepath, + if is_file_allowed(filepath, extension): + formats.append( + FileFormat( + input_path=filepath, + ) ) - ) - client.draw_epc(out_path=out_path, file_format=formats) + client.draw_epc(out_path=out_path, file_format=formats, language=language) if __name__ == "__main__": diff --git a/parser/doctree_parser.py b/parser/doctree_parser.py index b149040..4b2fd99 100644 --- a/parser/doctree_parser.py +++ b/parser/doctree_parser.py @@ -10,10 +10,10 @@ def parse(self): class Parser(IParser): - def parse(self, file_path: str): + def parse(self, file_path: str, language: str): with open(file_path) as file: file_str = file.read() return ctok.tokenize( file_str, - lang="python", + lang=language, )