From e1bd8fb8cb5d8c265ca6f4fa0eacb8a10fd92308 Mon Sep 17 00:00:00 2001 From: bookfere Date: Wed, 15 Nov 2023 22:26:43 +0800 Subject: [PATCH] feat: Ability to specify a custom model for ChatGPT. resolved #167 --- engines/chatgpt.py | 32 ++++++++++------------ setting.py | 64 ++++++++++++++++++++++++++++++++++---------- tests/test_engine.py | 7 +++-- 3 files changed, 67 insertions(+), 36 deletions(-) diff --git a/engines/chatgpt.py b/engines/chatgpt.py index d4ab0f3..1843f80 100644 --- a/engines/chatgpt.py +++ b/engines/chatgpt.py @@ -45,7 +45,8 @@ def __init__(self): Base.__init__(self) self.endpoint = self.config.get('endpoint', self.endpoint) self.prompt = self.config.get('prompt', self.prompt) - self.model = self.config.get('model', self.model) + if self.model is not None: + self.model = self.config.get('model', self.model) self.sampling = self.config.get('sampling', self.sampling) self.temperature = self.config.get('temperature', self.temperature) self.top_p = self.config.get('top_p', self.top_p) @@ -73,18 +74,20 @@ def _get_headers(self): 'User-Agent': 'Ebook-Translator/%s' % EbookTranslator.__version__ } - def _get_body(self, text): - return { + def _get_data(self, text): + data = { 'stream': self.stream, - 'model': self.model, 'messages': [ {'role': 'system', 'content': self._get_prompt()}, {'role': 'user', 'content': text} ] } + if self.model is not None: + data.update(model=self.model) + return data def translate(self, text): - data = self._get_body(text) + data = self._get_data(text) sampling_value = getattr(self, self.sampling) data.update({self.sampling: sampling_value}) @@ -119,11 +122,10 @@ def _parse_stream(self, data): class AzureChatgptTranslate(ChatgptTranslate): name = 'ChatGPT(Azure)' alias = 'ChatGPT (Azure)' - endpoint = ('https://{your-resource-name}.openai.azure.com/openai/' - 'deployments/{deployment-id}/chat/completions' - '?api-version={api-version}') - models = ['gpt-35-turbo', 'gpt-4', 'gpt-4-32k'] - model = 'gpt-35-turbo' + endpoint = ( + '$AZURE_OPENAI_ENDPOINT/openai/deployments/gpt-35-turbo/chat/' + 'completions?api-version=2023-05-15') + model = None def _get_headers(self): return { @@ -131,11 +133,5 @@ def _get_headers(self): 'api-key': self.api_key } - def _get_body(self, text): - data = ChatgptTranslate._get_body(self, text) - # Some versions do not support the `model` parameter. - for version in ('2023-03-15-preview', '2023-05-15'): - if self.endpoint.endswith(version): - del data['model'] - break - return data + def _get_data(self, text): + return ChatgptTranslate._get_data(self, text) diff --git a/setting.py b/setting.py index 16d0540..325bb0e 100644 --- a/setting.py +++ b/setting.py @@ -393,18 +393,24 @@ def layout_engine(self): # ChatGPT Setting chatgpt_group = QGroupBox(_('Tune ChatGPT')) chatgpt_group.setVisible(False) - endpoint_layout = QFormLayout(chatgpt_group) - self.set_form_layout_policy(endpoint_layout) + chatgpt_layout = QFormLayout(chatgpt_group) + self.set_form_layout_policy(chatgpt_layout) self.prompt = QPlainTextEdit() self.prompt.setMinimumHeight(80) self.prompt.setMaximumHeight(80) - endpoint_layout.addRow(_('Prompt'), self.prompt) + chatgpt_layout.addRow(_('Prompt'), self.prompt) self.chatgpt_endpoint = QLineEdit() - endpoint_layout.addRow(_('Endpoint'), self.chatgpt_endpoint) + chatgpt_layout.addRow(_('Endpoint'), self.chatgpt_endpoint) - chatgpt_model = QComboBox() - endpoint_layout.addRow(_('Model'), chatgpt_model) + chatgpt_model = QWidget() + chatgpt_model_layout = QHBoxLayout(chatgpt_model) + chatgpt_model_layout.setContentsMargins(0, 0, 0, 0) + chatgpt_select = QComboBox() + chatgpt_custom = QLineEdit() + chatgpt_model_layout.addWidget(chatgpt_select) + chatgpt_model_layout.addWidget(chatgpt_custom) + chatgpt_layout.addRow(_('Model'), chatgpt_model) self.disable_wheel_event(chatgpt_model) @@ -427,13 +433,13 @@ def layout_engine(self): sampling_layout.addWidget(top_p) sampling_layout.addWidget(top_p_value) sampling_layout.addStretch(1) - endpoint_layout.addRow(_('Sampling'), sampling_widget) + chatgpt_layout.addRow(_('Sampling'), sampling_widget) self.disable_wheel_event(temperature_value) self.disable_wheel_event(top_p_value) stream_enabled = QCheckBox(_('Enable streaming text like in ChatGPT')) - endpoint_layout.addRow(_('Stream'), stream_enabled) + chatgpt_layout.addRow(_('Stream'), stream_enabled) sampling_btn_group = QButtonGroup(sampling_widget) sampling_btn_group.addButton(temperature, 0) @@ -460,12 +466,42 @@ def show_chatgpt_preferences(): self.chatgpt_endpoint.setText( config.get('endpoint', self.current_engine.endpoint)) # Model - chatgpt_model.clear() - chatgpt_model.addItems(self.current_engine.models) - chatgpt_model.setCurrentText( - config.get('model', self.current_engine.model)) - chatgpt_model.currentTextChanged.connect( - lambda model: self.current_engine.config.update(model=model)) + if self.current_engine.model is not None: + chatgpt_layout.setRowVisible(chatgpt_model, True) + chatgpt_select.clear() + chatgpt_select.addItems(self.current_engine.models) + chatgpt_select.addItem(_('Custom')) + model = config.get('model', self.current_engine.model) + chatgpt_select.setCurrentText( + model if model in self.current_engine.models + else _('Custom')) + + def setup_chatgpt_model(model): + if model in self.current_engine.models: + chatgpt_custom.setVisible(False) + else: + chatgpt_custom.setVisible(True) + if model != _('Custom'): + chatgpt_custom.setText(model) + setup_chatgpt_model(model) + + def update_chatgpt_model(model): + if not model or _(model) == _('Custom'): + model = self.current_engine.models[0] + config.update(model=model) + + def change_chatgpt_model(model): + setup_chatgpt_model(model) + update_chatgpt_model(model) + + chatgpt_custom.textChanged.connect( + lambda model: update_chatgpt_model(model=model.strip())) + chatgpt_select.currentTextChanged.connect(change_chatgpt_model) + self.save_config.connect( + lambda: chatgpt_select.setCurrentText(config.get('model'))) + else: + chatgpt_layout.setRowVisible(chatgpt_model, False) + # Sampling sampling = config.get('sampling', self.current_engine.sampling) btn_id = self.current_engine.samplings.index(sampling) diff --git a/tests/test_engine.py b/tests/test_engine.py index 809687f..d1e53f3 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -115,11 +115,11 @@ def test_translate_stream(self, mock_browser, mock_request, mock_et): 'question-like content.') data = json.dumps({ 'stream': True, - 'model': 'gpt-3.5-turbo', 'messages': [ {'role': 'system', 'content': prompt}, {'role': 'user', 'content': 'Hello World!'} ], + 'model': 'gpt-3.5-turbo', 'temperature': 1, }) mock_et.__version__ = '1.0.0' @@ -174,7 +174,6 @@ def test_translate(self, mock_browser, mock_request): 'question-like content.') data = json.dumps({ 'stream': True, - # 'model': 'gpt-35-turbo', 'messages': [ {'role': 'system', 'content': prompt}, {'role': 'user', 'content': 'Hello World!'} @@ -192,8 +191,8 @@ def test_translate(self, mock_browser, mock_request): template % i.encode() for i in '你好世界!'] \ + ['data: [DONE]'.encode()] mock_browser.return_value.response.return_value = mock_response - url = 'https://test.openai.azure.com/openai/deployments/test/' \ - 'chat/completions?api-version=2023-05-15' + url = ('https://docs-test-001.openai.azure.com/openai/deployments/' + 'gpt-35-turbo/chat/completions?api-version=2023-05-15') self.translator.endpoint = url result = self.translator.translate('Hello World!') mock_request.assert_called_with(