diff --git a/internal/roadmap_agent.py b/internal/roadmap_agent.py index 7c65585..4a119fb 100644 --- a/internal/roadmap_agent.py +++ b/internal/roadmap_agent.py @@ -301,9 +301,30 @@ def generate_roadmap(user_input: Dict[str, Any], courses: List[Dict[str, Any]]) workflow = create_roadmap_workflow() final_state = workflow.invoke(initial_state) + # Post process - Update the validated roadmap with filtered courses + filtered_courses, total_duration = validated_roadmap_course(courses, final_state["validated_roadmap"]['validated_roadmap']["recommended_courses"]) + final_state["validated_roadmap"]['validated_roadmap']["recommended_courses"] = filtered_courses + final_state["validated_roadmap"]['validated_roadmap']["total_duration"] = total_duration # Extract roadmap from the final state return final_state["validated_roadmap"] +def validated_roadmap_course(courses: List[Any], validated_roadmap: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + # Extract valid course IDs from the courses list + valid_course_ids = {course.id for course in courses} + + # Track seen IDs to ensure only the first occurrence is added, also update new total_duration + seen_ids = set() + unique_courses = [] + total_duration = 0 + + for course in validated_roadmap: + if course["id"] in valid_course_ids and course["id"] not in seen_ids: + unique_courses.append(course) + seen_ids.add(course["id"]) + total_duration += course["duration"] + + return unique_courses, total_duration + # -- Example -- # # for test llm # from utils.mock_data import MockDataGenerator