diff --git a/src/bots/GeminiBot.js b/src/bots/GeminiBot.js index 6bd5caf83e..c7bd540197 100644 --- a/src/bots/GeminiBot.js +++ b/src/bots/GeminiBot.js @@ -16,15 +16,7 @@ export default class GeminiBot extends LangChainBot { let available = false; if (store.state.gemini.apiKey) { - const chatModel = new ChatGoogleGenerativeAI({ - apiKey: store.state.gemini.apiKey, - modelName: this.constructor._model ? this.constructor._model : "", - temperature: store.state.gemini.temperature, - streaming: true, - topK: store.state.gemini.topK, - topP: store.state.gemini.topP, - }); - this.constructor._chatModel = chatModel; + this.setupModel(); available = true; } return available; @@ -33,4 +25,16 @@ export default class GeminiBot extends LangChainBot { getPastRounds() { return store.state.gemini.pastRounds; } + + _setupModel() { + const chatModel = new ChatGoogleGenerativeAI({ + apiKey: store.state.gemini.apiKey, + modelName: this.constructor._model ? this.constructor._model : "", + temperature: store.state.gemini.temperature, + streaming: true, + topK: store.state.gemini.topK, + topP: store.state.gemini.topP, + }); + return chatModel; + } } diff --git a/src/bots/LangChainBot.js b/src/bots/LangChainBot.js index ef775e127f..0ecba51be6 100644 --- a/src/bots/LangChainBot.js +++ b/src/bots/LangChainBot.js @@ -64,6 +64,16 @@ export default class LangChainBot extends Bot { return []; } + setupModel() { + this.constructor._chatModel = this._setupModel(); + } + + _setupModel() { + throw new Error( + "Abstract property '_setupModel' must be implemented in the subclass.", + ); + } + getPastRounds() { throw new Error( "Abstract property 'pastRounds' must be implemented in the subclass.", diff --git a/src/bots/baidu/WenxinQianfanBot.js b/src/bots/baidu/WenxinQianfanBot.js index 7872ac6724..f94504d449 100644 --- a/src/bots/baidu/WenxinQianfanBot.js +++ b/src/bots/baidu/WenxinQianfanBot.js @@ -18,18 +18,23 @@ export default class WenxinQianfanBot extends LangChainBot { let available = false; const { apiKey, secretKey } = store.state.wenxinQianfan; if (apiKey && secretKey) { - const chatModel = new ChatBaiduWenxin({ - modelName: this.constructor._model, - baiduApiKey: apiKey, - baiduSecretKey: secretKey, - streaming: true, - }); - this.constructor._chatModel = chatModel; + this.setupModel(); available = true; } return available; } + _setupModel() { + const { apiKey, secretKey } = store.state.wenxinQianfan; + const chatModel = new ChatBaiduWenxin({ + modelName: this.constructor._model, + baiduApiKey: apiKey, + baiduSecretKey: secretKey, + streaming: true, + }); + return chatModel; + } + getPastRounds() { return store.state.wenxinQianfan.pastRounds; } diff --git a/src/bots/microsoft/AzureOpenAIAPIBot.js b/src/bots/microsoft/AzureOpenAIAPIBot.js index 8240d55683..7cdd66572c 100644 --- a/src/bots/microsoft/AzureOpenAIAPIBot.js +++ b/src/bots/microsoft/AzureOpenAIAPIBot.js @@ -20,22 +20,26 @@ export default class AzureOpenAIAPIBot extends LangChainBot { store.state.azureOpenaiApi.azureOpenAIApiDeploymentName && store.state.azureOpenaiApi.azureOpenAIApiVersion ) { - const chatModel = new ChatOpenAI({ - azureOpenAIApiKey: store.state.azureOpenaiApi.azureApiKey, - azureOpenAIApiInstanceName: - store.state.azureOpenaiApi.azureApiInstanceName, - azureOpenAIApiDeploymentName: - store.state.azureOpenaiApi.azureOpenAIApiDeploymentName, - azureOpenAIApiVersion: store.state.azureOpenaiApi.azureOpenAIApiVersion, - temperature: store.state.azureOpenaiApi.temperature, - streaming: true, - }); - this.constructor._chatModel = chatModel; + this.setupModel(); available = true; } return available; } + _setupModel() { + const chatModel = new ChatOpenAI({ + azureOpenAIApiKey: store.state.azureOpenaiApi.azureApiKey, + azureOpenAIApiInstanceName: + store.state.azureOpenaiApi.azureApiInstanceName, + azureOpenAIApiDeploymentName: + store.state.azureOpenaiApi.azureOpenAIApiDeploymentName, + azureOpenAIApiVersion: store.state.azureOpenaiApi.azureOpenAIApiVersion, + temperature: store.state.azureOpenaiApi.temperature, + streaming: true, + }); + return chatModel; + } + getPastRounds() { return store.state.azureOpenaiApi.pastRounds; } diff --git a/src/bots/openai/OpenAIAPIBot.js b/src/bots/openai/OpenAIAPIBot.js index ce0d8b3c2f..d4166b1677 100644 --- a/src/bots/openai/OpenAIAPIBot.js +++ b/src/bots/openai/OpenAIAPIBot.js @@ -14,23 +14,27 @@ export default class OpenAIAPIBot extends LangChainBot { let available = false; if (store.state.openaiApi.apiKey) { - const chatModel = new ChatOpenAI({ - configuration: { - basePath: store.state.openaiApi.alterUrl - ? store.state.openaiApi.alterUrl - : "", - }, - openAIApiKey: store.state.openaiApi.apiKey, - modelName: this.constructor._model ? this.constructor._model : "", - temperature: store.state.openaiApi.temperature, - streaming: true, - }); - this.constructor._chatModel = chatModel; + this.setupModel(); available = true; } return available; } + _setupModel() { + const chatModel = new ChatOpenAI({ + configuration: { + basePath: store.state.openaiApi.alterUrl + ? store.state.openaiApi.alterUrl + : "", + }, + openAIApiKey: store.state.openaiApi.apiKey, + modelName: this.constructor._model ? this.constructor._model : "", + temperature: store.state.openaiApi.temperature, + streaming: true, + }); + return chatModel; + } + getPastRounds() { return store.state.openaiApi.pastRounds; } diff --git a/src/components/BotSettings/AzureOpenAIAPIBotSettings.vue b/src/components/BotSettings/AzureOpenAIAPIBotSettings.vue index 229b594265..fed5e0aef8 100644 --- a/src/components/BotSettings/AzureOpenAIAPIBotSettings.vue +++ b/src/components/BotSettings/AzureOpenAIAPIBotSettings.vue @@ -3,6 +3,7 @@ :settings="settings" :brand-id="brandId" mutation-type="setAzureOpenaiApi" + :watcher="watcher" > @@ -73,5 +74,10 @@ export default { brandId: Bot._brandId, }; }, + methods: { + watcher() { + Bot.getInstance().setupModel(); + }, + }, }; diff --git a/src/components/BotSettings/CommonBotSettings.vue b/src/components/BotSettings/CommonBotSettings.vue index 65ad067917..2467019db4 100644 --- a/src/components/BotSettings/CommonBotSettings.vue +++ b/src/components/BotSettings/CommonBotSettings.vue @@ -82,7 +82,7 @@ diff --git a/src/components/BotSettings/OpenAIAPIBotSettings.vue b/src/components/BotSettings/OpenAIAPIBotSettings.vue index ce9f7b780a..ccf438ce7c 100644 --- a/src/components/BotSettings/OpenAIAPIBotSettings.vue +++ b/src/components/BotSettings/OpenAIAPIBotSettings.vue @@ -3,10 +3,12 @@ :settings="settings" :brand-id="brandId" mutation-type="setOpenaiApi" + :watcher="watcher" >