Skip to content

Commit

Permalink
feat: support variation param for model and prompt
Browse files Browse the repository at this point in the history
  • Loading branch information
Zacharis278 committed Jun 27, 2024
1 parent a4a0d02 commit 9b9df46
Show file tree
Hide file tree
Showing 7 changed files with 89 additions and 13 deletions.
3 changes: 1 addition & 2 deletions learning_assistant/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def get_block_content(request, user_id, course_id, unit_usage_key):
return cache_data['content_length'], cache_data['content_items']


def render_prompt_template(request, user_id, course_run_id, unit_usage_key, course_id):
def render_prompt_template(request, user_id, course_run_id, unit_usage_key, course_id, template_string):
"""
Return a rendered prompt template, specified by the LEARNING_ASSISTANT_PROMPT_TEMPLATE setting.
"""
Expand All @@ -117,7 +117,6 @@ def render_prompt_template(request, user_id, course_run_id, unit_usage_key, cour
skill_names = course_data['skill_names']
title = course_data['title']

template_string = getattr(settings, 'LEARNING_ASSISTANT_PROMPT_TEMPLATE', '')
template = Environment(loader=BaseLoader).from_string(template_string)
data = template.render(unit_content=unit_content, skill_names=skill_names, title=title)
return data
Expand Down
9 changes: 9 additions & 0 deletions learning_assistant/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,12 @@
"html": "TEXT",
"video": "VIDEO",
}


class GptModels:
    """Identifiers of the OpenAI chat models the learning assistant can call."""

    GPT_3_5_TURBO = 'gpt-3.5-turbo'
    GPT_4o = 'gpt-4o'


class ResponseVariations:
    """Experiment variation identifiers accepted via the ``response_variation`` query parameter."""

    GPT4_UPDATED_PROMPT = 'updated_prompt'
9 changes: 5 additions & 4 deletions learning_assistant/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,18 +52,19 @@ def get_reduced_message_list(prompt_template, message_list):
return [system_message] + new_message_list


def create_request_body(prompt_template, message_list, gpt_model):
    """
    Form the request body to be passed to the chat endpoint.

    Arguments:
    * prompt_template: (str) rendered system prompt prepended to the conversation
    * message_list: (list) user/assistant message dicts of the conversation so far
    * gpt_model: (str) identifier of the GPT model the endpoint should use

    Returns a dict containing the reduced message list (system prompt folded in
    by ``get_reduced_message_list``) and the model identifier.
    """
    response_body = {
        'message_list': get_reduced_message_list(prompt_template, message_list),
        'model': gpt_model,
    }

    return response_body


def get_chat_response(prompt_template, message_list):
def get_chat_response(prompt_template, message_list, gpt_model):
"""
Pass message list to chat endpoint, as defined by the CHAT_COMPLETION_API setting.
"""
Expand All @@ -74,7 +75,7 @@ def get_chat_response(prompt_template, message_list):
connect_timeout = getattr(settings, 'CHAT_COMPLETION_API_CONNECT_TIMEOUT', 1)
read_timeout = getattr(settings, 'CHAT_COMPLETION_API_READ_TIMEOUT', 15)

body = create_request_body(prompt_template, message_list)
body = create_request_body(prompt_template, message_list, gpt_model)

try:
response = requests.post(
Expand Down
15 changes: 13 additions & 2 deletions learning_assistant/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""
import logging

from django.conf import settings
from edx_rest_framework_extensions.auth.jwt.authentication import JwtAuthentication
from opaque_keys import InvalidKeyError
from opaque_keys.edx.keys import CourseKey
Expand All @@ -20,6 +21,7 @@
pass

from learning_assistant.api import get_course_id, learning_assistant_enabled, render_prompt_template
from learning_assistant.constants import GptModels, ResponseVariations
from learning_assistant.serializers import MessageSerializer
from learning_assistant.utils import get_chat_response, user_role_is_staff

Expand Down Expand Up @@ -73,6 +75,7 @@ def post(self, request, course_run_id):
)

unit_id = request.query_params.get('unit_id')
response_variation = request.query_params.get('response_variation')

message_list = request.data
serializer = MessageSerializer(data=message_list, many=True)
Expand All @@ -95,9 +98,17 @@ def post(self, request, course_run_id):

course_id = get_course_id(course_run_id)

prompt_template = render_prompt_template(request, request.user.id, course_run_id, unit_id, course_id)
if response_variation == ResponseVariations.GPT4_UPDATED_PROMPT:
gpt_model = GptModels.GPT_4o
template_string = getattr(settings, 'LEARNING_ASSISTANT_EXPERIMENTAL_PROMPT_TEMPLATE', '')
else:
gpt_model = GptModels.GPT_3_5_TURBO
template_string = getattr(settings, 'LEARNING_ASSISTANT_PROMPT_TEMPLATE', '')

status_code, message = get_chat_response(prompt_template, message_list)
prompt_template = render_prompt_template(
request, request.user.id, course_run_id, unit_id, course_id, template_string
)
status_code, message = get_chat_response(prompt_template, message_list, gpt_model)

return Response(status=status_code, data=message)

Expand Down
5 changes: 4 additions & 1 deletion tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,8 +202,11 @@ def test_render_prompt_template(
course_run_id = self.course_run_id
unit_usage_key = 'block-v1:edX+A+B+type@vertical+block@verticalD'
course_id = 'edx+test'
template_string = getattr(settings, 'LEARNING_ASSISTANT_PROMPT_TEMPLATE', '')

prompt_text = render_prompt_template(request, user_id, course_run_id, unit_usage_key, course_id)
prompt_text = render_prompt_template(
request, user_id, course_run_id, unit_usage_key, course_id, template_string
)

if unit_content and flag_enabled:
self.assertIn(unit_content, prompt_text)
Expand Down
5 changes: 3 additions & 2 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def setUp(self):
self.course_id = 'edx+test'

def get_response(self):
    """Call the chat endpoint helper with this test's fixture prompt/messages and a stub model name."""
    return get_chat_response(self.prompt_template, self.message_list, 'gpt-version-test')

@override_settings(CHAT_COMPLETION_API=None)
def test_no_endpoint_setting(self):
Expand Down Expand Up @@ -89,7 +89,8 @@ def test_post_request_structure(self, mock_requests):
headers = {'Content-Type': 'application/json', 'x-api-key': settings.CHAT_COMPLETION_API_KEY}

response_body = {
'message_list': [{'role': 'system', 'content': self.prompt_template}] + self.message_list
'message_list': [{'role': 'system', 'content': self.prompt_template}] + self.message_list,
'model': 'gpt-version-test',
}

self.get_response()
Expand Down
56 changes: 54 additions & 2 deletions tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from django.test.client import Client
from django.urls import reverse

from learning_assistant.constants import GptModels, ResponseVariations

User = get_user_model()


Expand Down Expand Up @@ -155,13 +157,15 @@ def test_invalid_messages(self, mock_role, mock_waffle, mock_render):
@patch('learning_assistant.views.get_user_role')
@patch('learning_assistant.views.CourseEnrollment.get_enrollment')
@patch('learning_assistant.views.CourseMode')
def test_chat_response(self, mock_mode, mock_enrollment, mock_role, mock_waffle, mock_chat_response, mock_render):
def test_chat_response_default(
self, mock_mode, mock_enrollment, mock_role, mock_waffle, mock_chat_response, mock_render
):
mock_waffle.return_value = True
mock_role.return_value = 'student'
mock_mode.VERIFIED_MODES = ['verified']
mock_enrollment.return_value = MagicMock(mode='verified')
mock_chat_response.return_value = (200, {'role': 'assistant', 'content': 'Something else'})
mock_render.return_value = 'This is a template'
mock_render.return_value = 'This is the default template'
test_unit_id = 'test-unit-id'

test_data = [
Expand All @@ -179,6 +183,54 @@ def test_chat_response(self, mock_mode, mock_enrollment, mock_role, mock_waffle,
render_args = mock_render.call_args.args
self.assertIn(test_unit_id, render_args)

mock_chat_response.assert_called_with(
'This is the default template',
test_data,
GptModels.GPT_3_5_TURBO
)

@patch('learning_assistant.views.render_prompt_template')
@patch('learning_assistant.views.get_chat_response')
@patch('learning_assistant.views.learning_assistant_enabled')
@patch('learning_assistant.views.get_user_role')
@patch('learning_assistant.views.CourseEnrollment.get_enrollment')
@patch('learning_assistant.views.CourseMode')
def test_chat_response_variation(
    self, mock_mode, mock_enrollment, mock_role, mock_waffle, mock_chat_response, mock_render
):
    """
    Posting with the GPT-4o response variation renders the experimental prompt
    template and routes the conversation to the GPT-4o model.
    """
    # Arrange: a verified student in a course with the assistant enabled.
    mock_waffle.return_value = True
    mock_role.return_value = 'student'
    mock_mode.VERIFIED_MODES = ['verified']
    mock_enrollment.return_value = MagicMock(mode='verified')
    mock_render.return_value = 'This is a template for GPT-4o variation'
    mock_chat_response.return_value = (200, {'role': 'assistant', 'content': 'Something else'})

    unit_id = 'test-unit-id'
    variation = ResponseVariations.GPT4_UPDATED_PROMPT
    conversation = [
        {'role': 'user', 'content': 'What is 2+2?'},
        {'role': 'assistant', 'content': 'It is 4'}
    ]

    # Act: post the conversation with the variation query parameter set.
    url = reverse('chat', kwargs={'course_run_id': self.course_id})
    url += f'?unit_id={unit_id}&response_variation={variation}'
    response = self.client.post(
        url,
        data=json.dumps(conversation),
        content_type='application/json',
    )

    # Assert: success, the prompt was rendered for the requested unit,
    # and the GPT-4o model received the experimental template.
    self.assertEqual(response.status_code, 200)
    self.assertIn(unit_id, mock_render.call_args.args)
    mock_chat_response.assert_called_with(
        'This is a template for GPT-4o variation',
        conversation,
        GptModels.GPT_4o
    )


@ddt.ddt
class LearningAssistantEnabledViewTests(LoggedInTestCase):
Expand Down

0 comments on commit 9b9df46

Please sign in to comment.