diff --git a/learning_assistant/api.py b/learning_assistant/api.py index 38ae1c0..99fc2b5 100644 --- a/learning_assistant/api.py +++ b/learning_assistant/api.py @@ -103,7 +103,7 @@ def get_block_content(request, user_id, course_id, unit_usage_key): return cache_data['content_length'], cache_data['content_items'] -def render_prompt_template(request, user_id, course_run_id, unit_usage_key, course_id): +def render_prompt_template(request, user_id, course_run_id, unit_usage_key, course_id, template_string): """ Return a rendered prompt template, specified by the LEARNING_ASSISTANT_PROMPT_TEMPLATE setting. """ @@ -117,7 +117,6 @@ def render_prompt_template(request, user_id, course_run_id, unit_usage_key, cour skill_names = course_data['skill_names'] title = course_data['title'] - template_string = getattr(settings, 'LEARNING_ASSISTANT_PROMPT_TEMPLATE', '') template = Environment(loader=BaseLoader).from_string(template_string) data = template.render(unit_content=unit_content, skill_names=skill_names, title=title) return data diff --git a/learning_assistant/constants.py b/learning_assistant/constants.py index 7027a28..99ecb31 100644 --- a/learning_assistant/constants.py +++ b/learning_assistant/constants.py @@ -14,3 +14,12 @@ "html": "TEXT", "video": "VIDEO", } + + +class GptModels: + GPT_3_5_TURBO = 'gpt-3.5-turbo' + GPT_4o = 'gpt-4o' + + +class ResponseVariations: + GPT4_UPDATED_PROMPT = 'updated_prompt' diff --git a/learning_assistant/utils.py b/learning_assistant/utils.py index fc72daf..8d9e79a 100644 --- a/learning_assistant/utils.py +++ b/learning_assistant/utils.py @@ -52,18 +52,19 @@ def get_reduced_message_list(prompt_template, message_list): return [system_message] + new_message_list -def create_request_body(prompt_template, message_list): +def create_request_body(prompt_template, message_list, gpt_model): """ Form request body to be passed to the chat endpoint. """ response_body = { - 'message_list': get_reduced_message_list(prompt_template, message_list) + 'message_list': get_reduced_message_list(prompt_template, message_list), + 'model': gpt_model, } return response_body -def get_chat_response(prompt_template, message_list): +def get_chat_response(prompt_template, message_list, gpt_model): """ Pass message list to chat endpoint, as defined by the CHAT_COMPLETION_API setting. """ @@ -74,7 +75,7 @@ def get_chat_response(prompt_template, message_list): connect_timeout = getattr(settings, 'CHAT_COMPLETION_API_CONNECT_TIMEOUT', 1) read_timeout = getattr(settings, 'CHAT_COMPLETION_API_READ_TIMEOUT', 15) - body = create_request_body(prompt_template, message_list) + body = create_request_body(prompt_template, message_list, gpt_model) try: response = requests.post( diff --git a/learning_assistant/views.py b/learning_assistant/views.py index 52ad248..0293fb8 100644 --- a/learning_assistant/views.py +++ b/learning_assistant/views.py @@ -3,6 +3,7 @@ """ import logging +from django.conf import settings from edx_rest_framework_extensions.auth.jwt.authentication import JwtAuthentication from opaque_keys import InvalidKeyError from opaque_keys.edx.keys import CourseKey @@ -20,6 +21,7 @@ pass from learning_assistant.api import get_course_id, learning_assistant_enabled, render_prompt_template +from learning_assistant.constants import GptModels, ResponseVariations from learning_assistant.serializers import MessageSerializer from learning_assistant.utils import get_chat_response, user_role_is_staff @@ -73,6 +75,7 @@ def post(self, request, course_run_id): ) unit_id = request.query_params.get('unit_id') + response_variation = request.query_params.get('response_variation') message_list = request.data serializer = MessageSerializer(data=message_list, many=True) @@ -95,9 +98,17 @@ def post(self, request, course_run_id): course_id = get_course_id(course_run_id) - prompt_template = render_prompt_template(request, request.user.id, course_run_id, unit_id, course_id) + if response_variation == ResponseVariations.GPT4_UPDATED_PROMPT: + gpt_model = GptModels.GPT_4o + template_string = getattr(settings, 'LEARNING_ASSISTANT_EXPERIMENTAL_PROMPT_TEMPLATE', '') + else: + gpt_model = GptModels.GPT_3_5_TURBO + template_string = getattr(settings, 'LEARNING_ASSISTANT_PROMPT_TEMPLATE', '') - status_code, message = get_chat_response(prompt_template, message_list) + prompt_template = render_prompt_template( + request, request.user.id, course_run_id, unit_id, course_id, template_string + ) + status_code, message = get_chat_response(prompt_template, message_list, gpt_model) return Response(status=status_code, data=message) diff --git a/tests/test_api.py b/tests/test_api.py index 1f18d5f..20a6b14 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -202,8 +202,11 @@ def test_render_prompt_template( course_run_id = self.course_run_id unit_usage_key = 'block-v1:edX+A+B+type@vertical+block@verticalD' course_id = 'edx+test' + template_string = getattr(settings, 'LEARNING_ASSISTANT_PROMPT_TEMPLATE', '') - prompt_text = render_prompt_template(request, user_id, course_run_id, unit_usage_key, course_id) + prompt_text = render_prompt_template( + request, user_id, course_run_id, unit_usage_key, course_id, template_string + ) if unit_content and flag_enabled: self.assertIn(unit_content, prompt_text) diff --git a/tests/test_utils.py b/tests/test_utils.py index 21b23c4..93ba1bd 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -27,7 +27,7 @@ def setUp(self): self.course_id = 'edx+test' def get_response(self): - return get_chat_response(self.prompt_template, self.message_list) + return get_chat_response(self.prompt_template, self.message_list, 'gpt-version-test') @override_settings(CHAT_COMPLETION_API=None) def test_no_endpoint_setting(self): @@ -89,7 +89,8 @@ def test_post_request_structure(self, mock_requests): headers = {'Content-Type': 'application/json', 'x-api-key': settings.CHAT_COMPLETION_API_KEY} response_body = { - 'message_list': [{'role': 'system', 'content': self.prompt_template}] + self.message_list + 'message_list': [{'role': 'system', 'content': self.prompt_template}] + self.message_list, + 'model': 'gpt-version-test', } self.get_response() diff --git a/tests/test_views.py b/tests/test_views.py index fbbe8b5..0f8068c 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -14,6 +14,8 @@ from django.test.client import Client from django.urls import reverse +from learning_assistant.constants import GptModels, ResponseVariations + User = get_user_model() @@ -155,13 +157,15 @@ def test_invalid_messages(self, mock_role, mock_waffle, mock_render): @patch('learning_assistant.views.get_user_role') @patch('learning_assistant.views.CourseEnrollment.get_enrollment') @patch('learning_assistant.views.CourseMode') - def test_chat_response(self, mock_mode, mock_enrollment, mock_role, mock_waffle, mock_chat_response, mock_render): + def test_chat_response_default( + self, mock_mode, mock_enrollment, mock_role, mock_waffle, mock_chat_response, mock_render + ): mock_waffle.return_value = True mock_role.return_value = 'student' mock_mode.VERIFIED_MODES = ['verified'] mock_enrollment.return_value = MagicMock(mode='verified') mock_chat_response.return_value = (200, {'role': 'assistant', 'content': 'Something else'}) - mock_render.return_value = 'This is a template' + mock_render.return_value = 'This is the default template' test_unit_id = 'test-unit-id' test_data = [ @@ -179,6 +183,54 @@ def test_chat_response(self, mock_mode, mock_enrollment, mock_role, mock_waffle, render_args = mock_render.call_args.args self.assertIn(test_unit_id, render_args) + mock_chat_response.assert_called_with( + 'This is the default template', + test_data, + GptModels.GPT_3_5_TURBO + ) + + @patch('learning_assistant.views.render_prompt_template') + @patch('learning_assistant.views.get_chat_response') + @patch('learning_assistant.views.learning_assistant_enabled') + @patch('learning_assistant.views.get_user_role') + @patch('learning_assistant.views.CourseEnrollment.get_enrollment') + @patch('learning_assistant.views.CourseMode') + def test_chat_response_variation( + self, mock_mode, mock_enrollment, mock_role, mock_waffle, mock_chat_response, mock_render + ): + mock_waffle.return_value = True + mock_role.return_value = 'student' + mock_mode.VERIFIED_MODES = ['verified'] + mock_enrollment.return_value = MagicMock(mode='verified') + mock_chat_response.return_value = (200, {'role': 'assistant', 'content': 'Something else'}) + mock_render.return_value = 'This is a template for GPT-4o variation' + test_unit_id = 'test-unit-id' + test_response_variation = ResponseVariations.GPT4_UPDATED_PROMPT + + test_data = [ + {'role': 'user', 'content': 'What is 2+2?'}, + {'role': 'assistant', 'content': 'It is 4'} + ] + + response = self.client.post( + reverse( + 'chat', + kwargs={'course_run_id': self.course_id} + )+f'?unit_id={test_unit_id}&response_variation={test_response_variation}', + data=json.dumps(test_data), + content_type='application/json', + ) + self.assertEqual(response.status_code, 200) + + render_args = mock_render.call_args.args + self.assertIn(test_unit_id, render_args) + + mock_chat_response.assert_called_with( + 'This is a template for GPT-4o variation', + test_data, + GptModels.GPT_4o + ) + @ddt.ddt class LearningAssistantEnabledViewTests(LoggedInTestCase):