diff --git a/agent/function/plan_agent.py b/agent/function/plan_agent.py index 68b4a88..1d513de 100644 --- a/agent/function/plan_agent.py +++ b/agent/function/plan_agent.py @@ -51,7 +51,7 @@ def gen_file_name(project, llm_agent): console.log(f"Error parsing filenames: {e}") file_name_candidates = [] # Fallback to an empty list if parsing fails else: - file_name_candidates = (file_name_candidates, ) + file_name_candidates = (file_name_candidates,) file_path_candidates = [os.path.join(project.path, filename.strip("'")) for filename in file_name_candidates] @@ -99,7 +99,7 @@ def pmpt_task_select(): Output Format: Your response should include three entries formatted as a list of strings, where each string contains the task name followed by its description, e.g.: - ["Task1, Description of Task1", "Task2, Description of Task2", "Task3, Description of Task3"] + ["Task1: Description of Task1", "Task2: Description of Task2", "Task3: Description of Task3"] Note: Return only the task names followed by a brief description, without any additional information or punctuation. """ @@ -124,9 +124,9 @@ def pmpt_model_select(): Please return a list of three strings, where each string includes the model's name followed by its summary. Ensure the description highlights how the model meets one of the selection criteria (best, balanced, fastest). Example format: - ["ModelName1, Best for task X with high accuracy of Y%, suitable for complex data analysis.", - "ModelName2, Balanced model, offers moderate accuracy with better speed, good for real-time applications.", - "ModelName3, Fastest model with lower accuracy, best for quick processing where speed is prioritized over precision."] + ["ModelName1: Best for task X with high accuracy of Y%, suitable for complex data analysis.", + "ModelName2: Balanced model, offers moderate accuracy with better speed, good for real-time applications.", + "ModelName3: Fastest model with lower accuracy, best for quick processing where speed is prioritized over precision."] Note: Ensure that the architecture names with summary are returned without any additional punctuation. """ diff --git a/agent/function/tech_leader.py b/agent/function/tech_leader.py index 837207b..3df1fb7 100644 --- a/agent/function/tech_leader.py +++ b/agent/function/tech_leader.py @@ -50,153 +50,198 @@ def __init__(self, project: Project, model): if self.project.plan is None: self.project.plan = Plan(current_task=0) - def start(self): + def user_requirement_understanding(self): """ - Execute the chain. - :return: the result of the chain. + (STEP-0) User Requirement Understanding. + :return: """ - try: - is_running = True - while is_running: - show_panel("STEP 1: User Requirements Understanding") - if self.project.requirement: - self.console.log(f"[cyan]User Requirement:[/cyan] {self.project.requirement}") - else: - self.requirement = questionary.text("Hi, what are your requirements?").ask() - self.project.requirement = self.requirement - - if not self.requirement: - raise SystemExit("The user requirement is not provided.") - - # Generate entry file name based on requirement - enhanced_requirement = self.requirement - if self.entry_file is None: - self.entry_file = gen_file_name(self.project, self.model) - self.console.log(f"The entry file is: {self.entry_file}") - - show_panel("STEP 2: Dataset Selection") - if self.project.plan.data_kind is None and self.project.plan.dataset is None: - self.project.plan.data_kind = analyze_requirement( - enhanced_requirement, - pmpt_dataset_detect(), + if self.project.requirement: + self.console.log(f"[cyan]User Requirement:[/cyan] {self.project.requirement}") + else: + show_panel("STEP 0: User Requirements Understanding") + self.requirement = questionary.text("Hi, what are your requirements?").ask() + self.project.requirement = self.requirement + + if not self.requirement: + raise SystemExit("The user requirement is not provided.") + + # Generate entry file name based on requirement + self.project.enhanced_requirement = self.requirement + if self.entry_file is None: + self.entry_file = gen_file_name(self.project, self.model) + self.console.log(f"The entry file is: {self.entry_file}") + update_project_state(self.project) + + def dataset_selection(self): + """ + (STEP-1) Dataset Selection. + :return: + """ + + if self.project.plan.data_kind is None and self.project.plan.dataset is None: + show_panel("STEP 1: Dataset Selection") + self.project.plan.data_kind = analyze_requirement( + self.project.enhanced_requirement, + pmpt_dataset_detect(), + self.model + ) + + if self.project.plan.data_kind == 'no_data_information_provided': + public_dataset_list = analyze_requirement( + self.project.enhanced_requirement, + pmpt_public_dataset_guess(), + self.model + ) + public_dataset_list = ast.literal_eval(public_dataset_list) + self.project.plan.dataset = questionary.select( + "Please select the dataset:", + choices=public_dataset_list + ).ask() + elif self.project.plan.data_kind == 'csv_data': + self.project.plan.dataset = questionary.text("Please provide the CSV data path:").ask() + if os.path.exists(self.project.plan.dataset) is False: + public_dataset_list = analyze_requirement( + self.project.enhanced_requirement, + pmpt_public_dataset_guess(), self.model ) - if self.project.plan.data_kind == 'no_data_information_provided': - public_dataset_list = analyze_requirement( - enhanced_requirement, - pmpt_public_dataset_guess(), - self.model - ) - public_dataset_list = ast.literal_eval(public_dataset_list) - self.project.plan.dataset = questionary.select( - "Please select the dataset:", - choices=public_dataset_list - ).ask() - - elif self.project.plan.data_kind == 'csv_data': - self.project.plan.dataset = questionary.text("Please provide the CSV data path:").ask() - if os.path.exists(self.project.plan.dataset) is False: - public_dataset_list = analyze_requirement( - enhanced_requirement, - pmpt_public_dataset_guess(), - self.model - ) - public_dataset_list = ast.literal_eval(public_dataset_list) - self.project.plan.dataset = questionary.select( - "Please select the dataset:", - choices=public_dataset_list - ).ask() - - if self.project.plan.dataset is None: - raise SystemExit("There is no dataset information. Aborted.") - else: - self.console.log(f"[cyan]Data source:[/cyan] {self.project.plan.dataset}") - if self.project.plan.data_kind == 'csv_data': - csv_data_sample = read_csv_file(self.project.plan.dataset) - self.console.log(f"[cyan]Dataset examples:[/cyan] {csv_data_sample}") - enhanced_requirement += f"\nDataset: {self.project.plan.dataset}" - enhanced_requirement += f"\nDataset Sample: {csv_data_sample}" - update_project_state(self.project) - - show_panel("STEP 3: Task & Model Selection") - if self.project.plan.ml_task_type is None: - ml_task_list = analyze_requirement(enhanced_requirement, pmpt_task_select(), self.model) - ml_task_list = ast.literal_eval(ml_task_list) - ml_task_type = questionary.select( - "Please select the ML task type:", - choices=ml_task_list + public_dataset_list = ast.literal_eval(public_dataset_list) + self.project.plan.dataset = questionary.select( + "Please select the dataset:", + choices=public_dataset_list ).ask() - self.console.log(f"[cyan]ML task type detected:[/cyan] {ml_task_type}") - confirm_ml_task_type = questionary.confirm("Are you sure to use this ml task type?").ask() - if confirm_ml_task_type: - self.project.plan.ml_task_type = ml_task_type - update_project_state(self.project) - else: - self.console.log("Seems you are not satisfied with the task type. Aborting the chain.") - return - - enhanced_requirement += f"\n\nML task type: {self.project.plan.ml_task_type}" - if self.project.plan.ml_model_arch is None: - ml_model_list = analyze_requirement(enhanced_requirement, pmpt_model_select(), self.model) - ml_model_list = ast.literal_eval(ml_model_list) - ml_model_arch = questionary.select( - "Please select the ML model architecture:", - choices=ml_model_list - ).ask() - self.console.log(f"[cyan]Model architecture detected:[/cyan] {ml_model_arch}") - confirm_ml_model_arch = questionary.confirm("Are you sure to use this ml arch?").ask() - if confirm_ml_model_arch: - self.project.plan.ml_model_arch = ml_model_arch - update_project_state(self.project) - else: - self.console.log("Seems you are not satisfied with the model architecture. Aborting the chain.") - return - - enhanced_requirement += f"\nModel architecture: {self.project.plan.ml_model_arch}" - - show_panel("STEP 4: Planning") - if self.project.plan.tasks is None: - self.console.log( - f"The project [cyan]{self.project.name}[/cyan] has no existing plans. Start planning..." - ) - enhanced_requirement += f"\nDataset: {self.project.plan.dataset}" - with self.console.status("Planning the tasks for you..."): - task_dicts = plan_generator( - enhanced_requirement, - self.model, - ) - self.console.print(generate_plan_card_ascii(task_dicts), highlight=False) - self.project.plan.tasks = [] - for task_dict in task_dicts.get('tasks'): - task = match_plan(task_dict) - if task: - self.project.plan.tasks.append(task) - - # Confirm the plan - confirm_plan = questionary.confirm("Are you sure to use this plan?").ask() - if confirm_plan: - update_project_state(self.project) - else: - self.console.log("Seems you are not satisfied with the plan. Aborting the chain.") - return - - task_num = len(self.project.plan.tasks) - # check if all tasks are completed. - # if self.project.plan.current_task == task_num: - # self.console.log(":tada: Looks like all tasks are completed.") - # return - - # code generation - show_panel("STEP 5: Code Generation") - code_generation_agent = CodeAgent(self.model, self.project) - code_generation_agent.invoke(task_num, self.requirement) - - # install the dependencies for this plan and code. - show_panel("STEP 6: Execution and Refection") - launch_agent = SetupAgent(self.model, self.project) - launch_agent.invoke() - is_running = False + if self.project.plan.dataset is None: + raise SystemExit("There is no dataset information. Aborted.") + else: + self.console.log(f"[cyan]Data source:[/cyan] {self.project.plan.dataset}") + if self.project.plan.data_kind == 'csv_data': + csv_data_sample = read_csv_file(self.project.plan.dataset) + self.console.log(f"[cyan]Dataset examples:[/cyan] {csv_data_sample}") + self.project.enhanced_requirement += f"\nDataset Sample: {csv_data_sample}" + + self.project.enhanced_requirement += f"\nDataset: {self.project.plan.dataset}" + update_project_state(self.project) + + def task_model_selection(self): + """ + (STEP-2) Task & Model Selection. + :return: + """ + if self.project.plan.ml_task_type is None or self.project.plan.ml_model_arch is None: + show_panel("STEP 2: Task & Model Selection") + + # select the ml task type + if self.project.plan.ml_task_type is None: + ml_task_list = analyze_requirement(self.project.enhanced_requirement, pmpt_task_select(), self.model) + ml_task_list = ast.literal_eval(ml_task_list) + ml_task_type = questionary.select( + "Please select the ML task type:", + choices=ml_task_list + ).ask() + + self.console.log(f"[cyan]ML task type detected:[/cyan] {ml_task_type}") + confirm_ml_task_type = questionary.confirm("Are you sure to use this ml task type?").ask() + if confirm_ml_task_type: + self.project.plan.ml_task_type = ml_task_type + else: + self.console.log("Seems you are not satisfied with the task type. Aborting the chain.") + return + self.project.enhanced_requirement += f"\n\nML task type: {self.project.plan.ml_task_type}" + + # select the mode architecture + if self.project.plan.ml_model_arch is None: + ml_model_list = analyze_requirement(self.project.enhanced_requirement, pmpt_model_select(), self.model) + ml_model_list = ast.literal_eval(ml_model_list) + ml_model_arch = questionary.select( + "Please select the ML model architecture:", + choices=ml_model_list + ).ask() + self.console.log(f"[cyan]Model architecture detected:[/cyan] {ml_model_arch}") + confirm_ml_model_arch = questionary.confirm("Are you sure to use this ml arch?").ask() + if confirm_ml_model_arch: + self.project.plan.ml_model_arch = ml_model_arch + else: + self.console.log("Seems you are not satisfied with the model architecture. Aborting the chain.") + return + + update_project_state(self.project) + self.project.enhanced_requirement += f"\nModel architecture: {self.project.plan.ml_model_arch}" + + def task_planning(self): + """ + (STEP-3) Task Planning. + :return: + """ + if self.project.plan.tasks is None: + show_panel("STEP 3: Task Planning") + self.console.log( + f"The project [cyan]{self.project.name}[/cyan] has no existing plans. Start planning..." + ) + self.project.enhanced_requirement += f"\nDataset: {self.project.plan.dataset}" + with self.console.status("Planning the tasks for you..."): + task_dicts = plan_generator(self.project.enhanced_requirement, self.model) + self.console.print(generate_plan_card_ascii(task_dicts), highlight=False) + self.project.plan.tasks = [] + for task_dict in task_dicts.get('tasks'): + task = match_plan(task_dict) + if task: + self.project.plan.tasks.append(task) + + # Confirm the plan + confirm_plan = questionary.confirm("Are you sure to use this plan?").ask() + if confirm_plan: + update_project_state(self.project) + else: + self.console.log("Seems you are not satisfied with the plan. Aborting the chain.") + return + else: + tasks = [] + for t in self.project.plan.tasks: + tasks.append({'name': t.name, 'resources': [r.name for r in t.resources], 'description': t.description}) + self.console.print(generate_plan_card_ascii({'tasks': tasks}), highlight=False) + + def code_generation(self): + """ + (STEP-4) Code Generation. + :return: + """ + task_num = len(self.project.plan.tasks) + if self.project.plan.current_task < task_num: + show_panel("STEP 4: Code Generation") + code_generation_agent = CodeAgent(self.model, self.project) + code_generation_agent.invoke(task_num, self.requirement) + update_project_state(self.project) + + def execution_and_reflection(self): + """ + (STEP-5) Execution and Reflection. + :return: + """ + show_panel("STEP 5: Execution and Reflection") + launch_agent = SetupAgent(self.model, self.project) + launch_agent.invoke() + update_project_state(self.project) + + def start(self): + """ + Execute the chain. + :return: the result of the chain. + """ + try: + # STEP 0: User Requirement Understanding + self.user_requirement_understanding() + # STEP 1: Dataset Selection + self.dataset_selection() + # STEP 2: Task & Model Selection + self.task_model_selection() + # STEP 3: Task Planning + self.task_planning() + # STEP 4: Code Generation + self.code_generation() + # STEP 5: Execution and Reflection + self.execution_and_reflection() + self.console.log("The chain has been completed.") except KeyboardInterrupt: self.console.log("MLE Plan Agent has been interrupted.") return diff --git a/agent/types/base.py b/agent/types/base.py index a294a9e..678c80c 100644 --- a/agent/types/base.py +++ b/agent/types/base.py @@ -50,11 +50,12 @@ class Plan(BaseModel): class Project(BaseModel): name: str - path: Optional[str] = None lang: str llm: str + path: Optional[str] = None plan: Optional[Plan] = None entry_file: Optional[str] = None debug_env: Optional[str] = DebugEnv.local description: Optional[str] = None requirement: Optional[str] = None + enhanced_requirement: Optional[str] = None