diff --git a/backend/project/endpoints/courses/courses.py b/backend/project/endpoints/courses/courses.py index b550b60f..04b645ed 100644 --- a/backend/project/endpoints/courses/courses.py +++ b/backend/project/endpoints/courses/courses.py @@ -12,10 +12,15 @@ from flask import request from flask_restful import Resource +from sqlalchemy import union, select +from sqlalchemy.exc import SQLAlchemyError + from project.models.course import Course -from project.utils.query_agent import query_selected_from_model, insert_into_model -from project.utils.authentication import login_required, authorize_teacher +from project.models.course_relation import CourseAdmin, CourseStudent +from project.utils.query_agent import insert_into_model +from project.utils.authentication import login_required_return_uid, authorize_teacher from project.endpoints.courses.courses_utils import check_data +from project.db_in import db load_dotenv() API_URL = getenv("API_HOST") @@ -24,20 +29,61 @@ class CourseForUser(Resource): """Api endpoint for the /courses link""" - @login_required - def get(self): + @login_required_return_uid + def get(self, uid=None): """ " Get function for /courses this will be the main endpoint to get all courses and filter by given query parameter like /courses?parameter=... parameters can be either one of the following: teacher,ufora_id,name. """ - return query_selected_from_model( - Course, - RESPONSE_URL, - url_mapper={"course_id": RESPONSE_URL}, - filters=request.args - ) + try: + + filter_params = request.args.to_dict() + + # Start with a base query + base_query = select(Course) + + # Apply filters dynamically if they are provided + for param, value in filter_params.items(): + if value: + attribute = getattr(Course, param, None) + if attribute: + base_query = base_query.filter(attribute == value) + + # Define the role-specific queries + student_courses = base_query.join( + CourseStudent, + Course.course_id == CourseStudent.course_id).filter( + CourseStudent.uid == uid) + admin_courses = base_query.join( + CourseAdmin, + Course.course_id == CourseAdmin.course_id).filter( + CourseAdmin.uid == uid) + teacher_courses = base_query.filter(Course.teacher == uid) + + # Combine the select statements using union to remove duplicates + all_courses_query = union(student_courses, admin_courses, teacher_courses) + + # Execute the union query and fetch all results as Course instances + courses = db.session.execute(all_courses_query).mappings().all() + courses_data = [dict(course) for course in courses] + + for course in courses_data: + course["course_id"] = urljoin(f"{RESPONSE_URL}/", str(course['course_id'])) + + return { + "data": courses_data, + "url": RESPONSE_URL, + "message": "Courses fetched successfully" + } + + except SQLAlchemyError: + db.session.rollback() + return { + "message": "An error occurred while fetching the courses", + "url": RESPONSE_URL + }, 500 @authorize_teacher def post(self, teacher_id=None): diff --git a/backend/tests/endpoints/course/courses_test.py b/backend/tests/endpoints/course/courses_test.py index ca3599c5..b3145d68 100644 --- a/backend/tests/endpoints/course/courses_test.py +++ b/backend/tests/endpoints/course/courses_test.py @@ -115,7 +115,7 @@ def test_data_fields(self, data_field_type_test: tuple[str, Any, str, dict[str, ### QUERY PARAMETER ### # Test a query parameter, should return [] for wrong values query_parameter_tests = \ - query_parameter_tests("/courses", "get", "student", [f.name for f in fields(Course)]) + query_parameter_tests("/courses", "get", "teacher", [f.name for f in fields(Course)]) @mark.parametrize("query_parameter_test", query_parameter_tests, indirect=True) def test_query_parameters(self, query_parameter_test: tuple[str, Any, str, bool]): @@ -127,7 +127,7 @@ def test_query_parameters(self, query_parameter_test: tuple[str, Any, str, bool] ### COURSES ### def test_get_courses(self, client: FlaskClient, courses: list[Course]): """Test getting all courses""" - csrf = get_csrf_from_login(client, "student") + csrf = get_csrf_from_login(client, "teacher") response = client.get("/courses", headers = {"X-CSRF-TOKEN":csrf}) assert response.status_code == 200 data = [course["name"] for course in response.json["data"]] @@ -222,7 +222,6 @@ def test_post_courses(self, client: FlaskClient, teacher: User): } ) assert response.status_code == 201 - csrf = get_csrf_from_login(client, "student") response = client.get("/courses?name=test", headers = {"X-CSRF-TOKEN":csrf}) assert response.status_code == 200 data = response.json["data"][0] diff --git a/backend/tests/endpoints/endpoint.py b/backend/tests/endpoints/endpoint.py index c2b2c796..14434353 100644 --- a/backend/tests/endpoints/endpoint.py +++ b/backend/tests/endpoints/endpoint.py @@ -69,7 +69,7 @@ def query_parameter_tests( new_endpoint = endpoint + "?parameter=0" tests.append(param( (new_endpoint, method, token, True), - id = f"{new_endpoint} {method.upper()} {token} (parameter 0 400)" + id = f"{new_endpoint} {method.upper()} {token} (parameter 0 500)" )) for parameter in parameters: @@ -114,7 +114,8 @@ def query_parameter(self, test: tuple[str, Any, str, bool]): endpoint, method, csrf, wrong_parameter = test response = method(endpoint, headers = {"X-CSRF-TOKEN":csrf}) - assert wrong_parameter == (response.status_code == 400) + if wrong_parameter: + assert wrong_parameter == (response.status_code == 200) if not wrong_parameter: assert response.json["data"] == []