Skip to content

Commit

Permalink
feat: Validate a custom model for batch translation in ChatGPT.
Browse files Browse the repository at this point in the history
  • Loading branch information
bookfere committed Nov 4, 2024
1 parent 77af131 commit c0d7204
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 51 deletions.
30 changes: 7 additions & 23 deletions engines/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,28 +116,6 @@ def _parse_stream(self, response):

class ChatgptBatchTranslate:
"""https://cookbook.openai.com/examples/batch_processing"""

supported_models = [
'gpt-4o',
'gpt-4-turbo',
'gpt-4',
'gpt-4-32k',
'gpt-3.5-turbo',
'gpt-3.5-turbo-16k',
'gpt-4-turbo-preview',
'gpt-4-vision-preview',
'gpt-4-turbo-2024-04-09',
'gpt-4-0314',
'gpt-4-32k-0314',
'gpt-4-32k-0613',
'gpt-3.5-turbo-0301',
'gpt-3.5-turbo-16k-0613',
'gpt-3.5-turbo-1106',
'gpt-3.5-turbo-0613',
'text-embedding-3-large',
'text-embedding-3-small',
'text-embedding-ada-002',
]
boundary = uuid.uuid4().hex

def __init__(self, translator):
Expand All @@ -146,6 +124,7 @@ def __init__(self, translator):

domain_name = '://'.join(
urlsplit(self.translator.endpoint, 'https')[:2])
self.model_endpint = '%s/v1/models' % domain_name
self.file_endpoint = '%s/v1/files' % domain_name
self.batch_endpoint = '%s/v1/batches' % domain_name

Expand All @@ -166,6 +145,11 @@ def _create_multipart_form_data(self, body):
data.append('--%s--' % self.boundary)
return '\r\n'.join(data).encode('utf-8')

def supported_models(self):
response = request(
self.model_endpint, headers=self.translator.get_headers())
return [item['id'] for item in json.loads(response).get('data')]

def headers(self, extra_headers={}):
headers = self.translator.get_headers()
headers.update(extra_headers)
Expand All @@ -178,7 +162,7 @@ def upload(self, paragraphs):
"""Upload the original content and retrieve the file id.
https://platform.openai.com/docs/api-reference/files/create
"""
if self.translator.model not in self.supported_models:
if self.translator.model not in self.supported_models():
raise UnsupportedModel(
'The model "{}" does not support batch functionality.'
.format(self.translator.model))
Expand Down
73 changes: 45 additions & 28 deletions tests/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ def setUp(self):

def test_get_body(self):
self.assertEqual(self.translator.get_body('test content'), json.dumps({
'model': 'gpt-3.5-turbo',
'model': 'gpt-4o',
'messages': [
{
'role': 'system',
Expand All @@ -406,7 +406,7 @@ def test_get_body_without_stream(self):
self.assertEqual(
self.translator.get_body('test content'),
json.dumps({
'model': 'gpt-3.5-turbo',
'model': 'gpt-4o',
'messages': [
{
'role': 'system',
Expand All @@ -433,7 +433,7 @@ def test_translate_stream(self, mock_request, mock_et):
'only. Do not explain any term or answer any question-like '
'content.')
data = json.dumps({
'model': 'gpt-3.5-turbo',
'model': 'gpt-4o',
'messages': [
{'role': 'system', 'content': prompt},
{'role': 'user', 'content': 'Hello World!'}],
Expand Down Expand Up @@ -479,26 +479,6 @@ def setUp(self):
self.batch_translator = ChatgptBatchTranslate(self.mock_translator)

def test_class_object(self):
self.assertEqual(ChatgptBatchTranslate.supported_models, [
'gpt-4o',
'gpt-4-turbo',
'gpt-4',
'gpt-4-32k',
'gpt-3.5-turbo',
'gpt-3.5-turbo-16k',
'gpt-4-turbo-preview',
'gpt-4-vision-preview',
'gpt-4-turbo-2024-04-09',
'gpt-4-0314',
'gpt-4-32k-0314',
'gpt-4-32k-0613',
'gpt-3.5-turbo-0301',
'gpt-3.5-turbo-16k-0613',
'gpt-3.5-turbo-1106',
'gpt-3.5-turbo-0613',
'text-embedding-3-large',
'text-embedding-3-small',
'text-embedding-ada-002'])
self.assertRegex(ChatgptBatchTranslate.boundary, r'(?a)^\w+$')

def test_created_translator(self):
Expand All @@ -512,7 +492,41 @@ def test_created_translator(self):
self.batch_translator.batch_endpoint,
'https://api.openai.com/v1/batches')

def test_upload_with_unsupported_model(self):
@patch(module_name + '.openai.request')
def test_supportd_models(self, mock_request):
mock_request.return_value = """
{
"object": "list",
"data": [
{
"id": "model-id-0",
"object": "model",
"created": 1686935002,
"owned_by": "organization-owner"
},
{
"id": "model-id-1",
"object": "model",
"created": 1686935002,
"owned_by": "organization-owner"
},
{
"id": "model-id-2",
"object": "model",
"created": 1686935002,
"owned_by": "openai"
}
],
"object": "list"
}
"""
self.assertEqual(
self.batch_translator.supported_models(),
['model-id-0', 'model-id-1', 'model-id-2'])

@patch(module_name + '.openai.ChatgptBatchTranslate.supported_models')
def test_upload_with_unsupported_model(self, mock_suppored_models):
mock_suppored_models.return_value = ['gpt-4o']
self.mock_translator.model = 'fake-model'
self.mock_translator.stream = True
with self.assertRaises(UnsupportedModel) as cm:
Expand All @@ -522,8 +536,9 @@ def test_upload_with_unsupported_model(self):
'The model "fake-model" does not support batch functionality.')

@patch.object(ChatgptBatchTranslate, 'boundary', new='xxxxxxxxxx')
@patch(module_name + '.openai.ChatgptBatchTranslate.supported_models')
@patch(module_name + '.openai.request')
def test_upload(self, mock_request):
def test_upload(self, mock_request, mock_suppored_models):
mock_request.return_value = """
{
"id": "test-file-id",
Expand All @@ -534,6 +549,8 @@ def test_upload(self, mock_request):
"purpose": "fine-tune"
}
"""
mock_suppored_models.return_value = ['gpt-4o']

mock_paragraph_1 = Mock(Paragraph)
mock_paragraph_1.md5 = 'abc'
mock_paragraph_1.original = 'test content 1'
Expand All @@ -545,7 +562,7 @@ def test_upload(self, mock_request):

def mock_get_body(text):
return json.dumps({
'model': 'gpt-3.5-turbo',
'model': 'gpt-4o',
'messages': [
{'role': 'system', 'content': 'some prompt...'},
{'role': 'user', 'content': text}],
Expand All @@ -566,13 +583,13 @@ def mock_get_body(text):
'Content-Type: application/json\r\n'
'\r\n{"custom_id": "abc", "method": "POST", '
'"url": "/v1/chat/completions", '
'"body": {"model": "gpt-3.5-turbo", '
'"body": {"model": "gpt-4o", '
'"messages": [{"role": "system", '
'"content": "some prompt..."}, {"role": "user", '
'"content": "test content 1"}], "temperature": 1.0}}\n'
'{"custom_id": "def", "method": "POST", '
'"url": "/v1/chat/completions", '
'"body": {"model": "gpt-3.5-turbo", '
'"body": {"model": "gpt-4o", '
'"messages": [{"role": "system", '
'"content": "some prompt..."}, {"role": "user", '
'"content": "test content 2"}], "temperature": 1.0}}\r\n'
Expand Down

0 comments on commit c0d7204

Please sign in to comment.