diff --git a/src/modules/llms/api/athropic.ts b/src/modules/llms/api/athropic.ts
index 758a791..0f6156b 100644
--- a/src/modules/llms/api/athropic.ts
+++ b/src/modules/llms/api/athropic.ts
@@ -23,13 +23,13 @@ export const anthropicCompletion = async (
   model = LlmsModelsEnum.CLAUDE_OPUS
 ): Promise<LlmCompletion> => {
   logger.info(`Handling ${model} completion`)
-
   const data = {
     model,
     stream: false,
     system: config.openAi.chatGpt.chatCompletionContext,
     max_tokens: +config.openAi.chatGpt.maxTokens,
-    messages: conversation
+    messages: conversation.filter(c => c.model === model)
+      .map(m => { return { content: m.content, role: m.role } })
   }
   const url = `${API_ENDPOINT}/anthropic/completions`
   const response = await axios.post(url, data)
@@ -68,7 +68,7 @@ export const anthropicStreamCompletion = async (
     stream: true, // Set stream to true to receive the completion as a stream
     system: config.openAi.chatGpt.chatCompletionContext,
     max_tokens: limitTokens ? +config.openAi.chatGpt.maxTokens : undefined,
-    messages: conversation.map(m => { return { content: m.content, role: m.role } })
+    messages: conversation.filter(c => c.model === model).map(m => { return { content: m.content, role: m.role } })
   }
   let wordCount = 0
   let wordCountMinimum = 2
diff --git a/src/modules/llms/api/llmApi.ts b/src/modules/llms/api/llmApi.ts
index 0169495..10c782b 100644
--- a/src/modules/llms/api/llmApi.ts
+++ b/src/modules/llms/api/llmApi.ts
@@ -99,7 +99,7 @@ export const llmCompletion = async (
   const data = {
     model, // chat-bison@001 'chat-bison', //'gpt-3.5-turbo',
     stream: false,
-    messages: conversation
+    messages: conversation.filter(c => c.model === model)
   }
   const url = `${API_ENDPOINT}/llms/completions`
   const response = await axios.post(url, data)
diff --git a/src/modules/llms/api/vertex.ts b/src/modules/llms/api/vertex.ts
index 372cb95..3697d3a 100644
--- a/src/modules/llms/api/vertex.ts
+++ b/src/modules/llms/api/vertex.ts
@@ -24,8 +24,18 @@ export const vertexCompletion = async (
   const data = {
     model, // chat-bison@001 'chat-bison', //'gpt-3.5-turbo',
     stream: false,
-    messages: conversation
+    messages: conversation.filter(c => c.model === model)
+      .map((msg) => {
+        const msgFiltered: ChatConversation = { content: msg.content, model: msg.model }
+        if (model === LlmsModelsEnum.BISON) {
+          msgFiltered.author = msg.role
+        } else {
+          msgFiltered.role = msg.role
+        }
+        return msgFiltered
+      })
   }
+
   const url = `${API_ENDPOINT}/vertex/completions`
   const response = await axios.post(url, data)
   if (response) {
@@ -34,7 +44,7 @@ export const vertexCompletion = async (
   return {
     completion: {
       content: response.data._prediction_response[0][0].candidates[0].content,
-      author: 'bot',
+      role: 'bot', // role will be replaced with the author attribute later
       model
     },
     usage: totalOutputTokens + totalInputTokens,
@@ -60,7 +70,8 @@ export const vertexStreamCompletion = async (
     stream: true, // Set stream to true to receive the completion as a stream
     system: config.openAi.chatGpt.chatCompletionContext,
     max_tokens: limitTokens ? +config.openAi.chatGpt.maxTokens : undefined,
-    messages: conversation.map(m => { return { parts: { text: m.content }, role: m.role !== 'user' ? 'model' : 'user' } })
+    messages: conversation.filter(c => c.model === model)
+      .map(m => { return { parts: { text: m.content }, role: m.role !== 'user' ? 'model' : 'user' } })
   }
   const url = `${API_ENDPOINT}/vertex/completions/gemini`
   if (!ctx.chat?.id) {
diff --git a/src/modules/llms/helpers.ts b/src/modules/llms/helpers.ts
index ed943a1..546c0e3 100644
--- a/src/modules/llms/helpers.ts
+++ b/src/modules/llms/helpers.ts
@@ -2,11 +2,9 @@ import {
   type OnMessageContext,
   type OnCallBackQueryData,
   type MessageExtras,
-  type ChatConversation,
   type ChatPayload
 } from '../types'
 import { type ParseMode } from 'grammy/types'
-import { LlmsModelsEnum } from './types'
 import { type Message } from 'grammy/out/types'
 import { type LlmCompletion, getChatModel, llmAddUrlDocument } from './api/llmApi'
 import { getChatModelPrice } from '../open-ai/api/openAi'
@@ -16,8 +14,9 @@ export enum SupportedCommands {
   bardF = 'bard',
   claudeOpus = 'claude',
   opus = 'opus',
+  opusShort = 'o',
   claudeSonnet = 'claudes',
-  opusShort = 'c',
+  claudeShort = 'c',
   sonnet = 'sonnet',
   sonnetShort = 's',
   claudeHaiku = 'haiku',
@@ -254,23 +253,6 @@ export const limitPrompt = (prompt: string): string => {
   return prompt
 }
 
-export const prepareConversation = (
-  conversation: ChatConversation[],
-  model: string
-): ChatConversation[] => {
-  return conversation
-    .filter((msg) => msg.model === model)
-    .map((msg) => {
-      const msgFiltered: ChatConversation = { content: msg.content }
-      if (model === LlmsModelsEnum.BISON) {
-        msgFiltered.author = msg.author
-      } else {
-        msgFiltered.role = msg.role
-      }
-      return msgFiltered
-    })
-}
-
 export function extractPdfFilename (url: string): string | null {
   const matches = url.match(/\/([^/]+\.pdf)$/)
   if (matches) {
diff --git a/src/modules/llms/index.ts b/src/modules/llms/index.ts
index 28af012..42a314d 100644
--- a/src/modules/llms/index.ts
+++ b/src/modules/llms/index.ts
@@ -29,7 +29,6 @@ import {
   isMentioned,
   limitPrompt,
   MAX_TRIES,
-  prepareConversation,
   SupportedCommands
 } from './helpers'
 import { getUrlFromText, preparePrompt, sendMessage } from '../open-ai/helpers'
@@ -134,7 +133,7 @@ export class LlmsBot implements PayableBot {
       await this.onChat(ctx, LlmsModelsEnum.GEMINI)
       return
     }
-    if (ctx.hasCommand([SupportedCommands.claudeOpus, SupportedCommands.opus, SupportedCommands.opusShort]) || (hasClaudeOpusPrefix(ctx.message?.text ?? '') !== '')) {
+    if (ctx.hasCommand([SupportedCommands.claudeOpus, SupportedCommands.opus, SupportedCommands.opusShort, SupportedCommands.claudeShort]) || (hasClaudeOpusPrefix(ctx.message?.text ?? '') !== '')) {
       await this.onChat(ctx, LlmsModelsEnum.CLAUDE_OPUS)
       return
     }
@@ -601,7 +600,8 @@ export class LlmsBot implements PayableBot {
       )
       conversation.push({
         role: 'assistant',
-        content: completion.completion?.content ?? ''
+        content: completion.completion?.content ?? '',
+        model
       })
       return {
         price: price.price,
@@ -612,7 +612,8 @@
       const response = await anthropicCompletion(conversation, model as LlmsModelsEnum)
       conversation.push({
         role: 'assistant',
-        content: response.completion?.content ?? ''
+        content: response.completion?.content ?? '',
+        model
       })
       return {
         price: response.price,
@@ -644,13 +645,12 @@
       usage: 0,
       price: 0
     }
-    const chat = prepareConversation(conversation, model)
     if (model === LlmsModelsEnum.BISON) {
-      response = await vertexCompletion(chat, model) // "chat-bison@001");
+      response = await vertexCompletion(conversation, model) // "chat-bison@001");
     } else if (model === LlmsModelsEnum.CLAUDE_OPUS || model === LlmsModelsEnum.CLAUDE_SONNET || model === LlmsModelsEnum.CLAUDE_HAIKU) {
-      response = await anthropicCompletion(chat, model)
+      response = await anthropicCompletion(conversation, model)
     } else {
-      response = await llmCompletion(chat, model as LlmsModelsEnum)
+      response = await llmCompletion(conversation, model as LlmsModelsEnum)
     }
     if (response.completion) {
       await ctx.api.editMessageText(
@@ -755,13 +755,9 @@
     }
     const chat: ChatConversation = {
       content: limitPrompt(prompt),
+      role: 'user',
       model
     }
-    if (model === LlmsModelsEnum.BISON) {
-      chat.author = 'user'
-    } else {
-      chat.role = 'user'
-    }
     chatConversation.push(chat)
     const payload = {
       conversation: chatConversation,
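
The patch above replaces the shared prepareConversation helper with an inline conversation.filter(c => c.model === model) in each completion call, and tags every message pushed to the history with its model so that filter has something to match. For reference, here is a minimal standalone sketch of that pattern in TypeScript; the simplified ChatConversation shape, the 'chat-bison' string standing in for LlmsModelsEnum.BISON, and the sample history are illustrative assumptions, not code from this repository.

// Simplified stand-in for the real ChatConversation type in src/modules/types.
interface ChatConversation {
  role?: string
  author?: string // Vertex chat-bison expects `author` where the other APIs expect `role`
  content: string
  model?: string
}

const BISON = 'chat-bison' // assumed value of LlmsModelsEnum.BISON, for illustration only

// Keep only the messages that belong to the target model, then map each one
// to the field shape the downstream API expects (`author` for bison, `role` otherwise).
const toApiMessages = (conversation: ChatConversation[], model: string): ChatConversation[] =>
  conversation
    .filter(msg => msg.model === model)
    .map(msg => {
      const mapped: ChatConversation = { content: msg.content, model: msg.model }
      if (model === BISON) {
        mapped.author = msg.role
      } else {
        mapped.role = msg.role
      }
      return mapped
    })

// Hypothetical usage: a mixed history in which only the bison messages survive.
const history: ChatConversation[] = [
  { role: 'user', content: 'hello', model: BISON },
  { role: 'user', content: 'hi there', model: 'claude-3-opus' }
]
console.log(toApiMessages(history, BISON))
// -> [{ content: 'hello', model: 'chat-bison', author: 'user' }]

Because every call site now filters on msg.model, a message pushed without a model field would silently drop out of the request payload, which is why the patch also adds model when pushing both the user prompt and the assistant replies in index.ts.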