Skip to content

Commit

Permalink
fix: Fixed the bug preventing Gemini translation. #365
Browse files Browse the repository at this point in the history
  • Loading branch information
bookfere committed Nov 6, 2024
1 parent aa1efd1 commit 756ab9f
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 52 deletions.
6 changes: 4 additions & 2 deletions engines/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,10 @@ def _is_auto_lang(self):
def translate(self, text):
try:
response = request(
self.get_endpoint(), self.get_body(text), self.get_headers(),
self.method, self.request_timeout, self.proxy_uri, self.stream)
url=self.get_endpoint(), data=self.get_body(text),
headers=self.get_headers(), method=self.method,
timeout=self.request_timeout, proxy_uri=self.proxy_uri,
raw_object=self.stream)
return self.get_result(response)
except Exception as e:
# Combine the error messages for investigation.
Expand Down
14 changes: 9 additions & 5 deletions engines/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import json
import uuid

from mechanize._response import response_seek_wrapper as Response

from .. import EbookTranslator
from ..lib.utils import request
from ..lib.exception import UnsupportedModel
Expand Down Expand Up @@ -205,14 +207,16 @@ def retrieve(self, output_file_id):
del headers['Content-Type']
response = request(
'%s/%s/content' % (self.file_endpoint, output_file_id),
headers=headers, as_bytes=True)
headers=headers, raw_object=True)
assert isinstance(response, Response)

translations = {}
for line in io.BytesIO(response):
for line in io.BytesIO(response.read()):
result = json.loads(line)
response = result['response']
if response.get('status_code') == 200:
content = response['body']['choices'][0]['message']['content']
response_item = result['response']
if response_item.get('status_code') == 200:
content = response_item[
'body']['choices'][0]['message']['content']
translations[result.get('custom_id')] = content
return translations

Expand Down
10 changes: 3 additions & 7 deletions lib/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,14 +147,14 @@ def traceback_error():

def request(
url, data=None, headers={}, method='GET', timeout=30, proxy_uri=None,
as_bytes=False, stream=False):
raw_object=False) -> Response | str:
br = Browser()
br.set_handle_robots(False)
# Do not verify SSL certificates
br.set_ca_data(
context=ssl._create_unverified_context(cert_reqs=ssl.CERT_NONE))
# Set up proxy
proxies = {}
proxies: dict = {}
if proxy_uri is not None:
proxies.update(http=proxy_uri, https=proxy_uri)
else:
Expand All @@ -171,8 +171,4 @@ def request(
_request = Request(url, data, headers=headers, timeout=timeout)
br.open(_request)
response: Response = br.response()
if stream:
return response
if as_bytes:
return response.read()
return response.read().decode('utf-8').strip()
return response if raw_object else response.read().decode('utf-8').strip()
38 changes: 23 additions & 15 deletions tests/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,9 +192,10 @@ def test_translate(self, mock_request):
'{"text": "你好世界"}', self.translator.translate('Hello World'))

mock_request.assert_called_once_with(
'https://example.com/api', '{"text": "Hello World"}',
{'Authorization': 'Bearer a', 'Content-Type': 'application/json'},
'POST', 10.0, None, False)
url='https://example.com/api', data='{"text": "Hello World"}',
headers={
'Authorization': 'Bearer a', 'Content-Type': 'application/json'
}, method='POST', timeout=10.0, proxy_uri=None, raw_object=False)

@patch(module_name + '.base.request')
def test_translate_with_stream(self, mock_request):
Expand All @@ -205,9 +206,10 @@ def test_translate_with_stream(self, mock_request):
self.assertIs(mock_response, self.translator.translate('Hello World'))

mock_request.assert_called_once_with(
'https://example.com/api', '{"text": "Hello World"}',
{'Authorization': 'Bearer a', 'Content-Type': 'application/json'},
'POST', 10.0, None, True)
url='https://example.com/api', data='{"text": "Hello World"}',
headers={
'Authorization': 'Bearer a', 'Content-Type': 'application/json'
}, method='POST', timeout=10.0, proxy_uri=None, raw_object=True)

@patch(module_name + '.base.request')
def test_translate_with_http_error(self, mock_request):
Expand Down Expand Up @@ -457,7 +459,8 @@ def test_translate_stream(self, mock_request, mock_et):
result = self.translator.translate('Hello World!')

mock_request.assert_called_with(
url, data, headers, 'POST', 30.0, None, True)
url=url, data=data, headers=headers, method='POST', timeout=30.0,
proxy_uri=None, raw_object=True)
self.assertIsInstance(result, GeneratorType)
self.assertEqual('你好世界!', ''.join(result))

Expand Down Expand Up @@ -626,7 +629,7 @@ def test_retrieve(self, mock_request):
line_2 = (
b'{"custom_id":"def","response":{"status_code":200,"body":{'
b'"choices": [{"message": {"content": "B"}}]}}}')
mock_request.return_value = line_1 + b'\n' + line_2
mock_request.return_value.read.return_value = line_1 + b'\n' + line_2
self.mock_translator.get_headers.return_value = {
'Content-Type': 'application/json',
'Authorization': 'Bearer abc',
Expand All @@ -641,7 +644,8 @@ def test_retrieve(self, mock_request):
'User-Agent': 'Ebook-Translator/v1.0.0'}
mock_request.assert_called_once_with(
'https://api.openai.com/v1/files/test-batch-id/content',
headers=headers, as_bytes=True)
headers=headers, raw_object=True)
mock_request().read.assert_called_once()

@patch(module_name + '.openai.request')
def test_create(self, mock_request):
Expand Down Expand Up @@ -804,7 +808,8 @@ def test_translate_stream(self, mock_request):
self.translator.endpoint = url
result = self.translator.translate('Hello World!')
mock_request.assert_called_with(
url, data, headers, 'POST', 30.0, None, True)
url=url, data=data, headers=headers, method='POST', timeout=30.0,
proxy_uri=None, raw_object=True)
self.assertIsInstance(result, GeneratorType)
self.assertEqual('你好世界!', ''.join(result))

Expand Down Expand Up @@ -873,7 +878,8 @@ def test_translate(self, mock_request, mock_et):
result = self.translator.translate('Hello World!')

mock_request.assert_called_with(
url, data, headers, 'POST', 30.0, None, False)
url=url, data=data, headers=headers, method='POST', timeout=30.0,
proxy_uri=None, raw_object=False)
self.assertEqual('你好世界!', result)

@patch(module_name + '.anthropic.EbookTranslator')
Expand Down Expand Up @@ -943,7 +949,8 @@ def test_translate_stream(self, mock_request, mock_et):
self.translator.model = 'claude-2.1'
result = self.translator.translate('Hello World!')
mock_request.assert_called_with(
url, data, headers, 'POST', 30.0, None, True)
url=url, data=data, headers=headers, method='POST', timeout=30.0,
proxy_uri=None, raw_object=True)
self.assertIsInstance(result, GeneratorType)
self.assertEqual('你好世界!', ''.join(result))

Expand Down Expand Up @@ -1076,9 +1083,10 @@ def test_translate(self, mock_request):
mock_request.return_value = '{"text": "你好世界"}'
self.assertEqual('你好世界', translator.translate('Hello "World"'))
mock_request.assert_called_with(
'https://example.api',
b'{"source": "en", "target": "zh", "text": "Hello \\"World\\""}',
{'Content-Type': 'application/json'}, 'POST', 10.0, None, False)
url='https://example.api', data=b'{"source": "en", "target": "zh",'
b' "text": "Hello \\"World\\""}',
headers={'Content-Type': 'application/json'}, method='POST',
timeout=10.0, proxy_uri=None, raw_object=False)
# XML response
translator.response = 'response.text'
mock_request.return_value = '<test>你好世界</test>'
Expand Down
25 changes: 2 additions & 23 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,33 +101,12 @@ def test_request_output_as_string(
@patch(module_name + '.ssl')
@patch(module_name + '.Request')
@patch(module_name + '.Browser')
def test_request_output_as_bytes(
def test_request_output_as_raw_object(
self, mock_browser, mock_request, mock_ssl):
browser = mock_browser()

self.assertIs(
request('https://example.com/api', 'test data', as_bytes=True),
browser.response().read())

browser.set_handle_robots.assert_called_once_with(False)
mock_ssl._create_unverified_context.assert_called_once_with(
cert_reqs=mock_ssl.CERT_NONE)
browser.set_ca_data.assert_called_once_with(
context=mock_ssl._create_unverified_context())

mock_request.assert_called_once_with(
'https://example.com/api', 'test data', headers={}, timeout=30,
method='GET')
browser.open.assert_called_once_with(mock_request())

@patch(module_name + '.ssl')
@patch(module_name + '.Request')
@patch(module_name + '.Browser')
def test_request_with_stream(self, mock_browser, mock_request, mock_ssl):
browser = mock_browser()

self.assertIs(
request('https://example.com/api', 'test data', stream=True),
request('https://example.com/api', 'test data', raw_object=True),
browser.response())

browser.set_handle_robots.assert_called_once_with(False)
Expand Down

0 comments on commit 756ab9f

Please sign in to comment.