diff --git a/src/modules/llms/api/athropic.ts b/src/modules/llms/api/athropic.ts
index 9eccd2c6..0f6156b4 100644
--- a/src/modules/llms/api/athropic.ts
+++ b/src/modules/llms/api/athropic.ts
@@ -16,20 +16,20 @@ const logger = pino({
   }
 })
 
-const API_ENDPOINT = config.llms.apiEndpoint // 'http://127.0.0.1:5000' // config.llms.apiEndpoint
+const API_ENDPOINT = config.llms.apiEndpoint // config.llms.apiEndpoint // 'http://127.0.0.1:5000' // config.llms.apiEndpoint
 
 export const anthropicCompletion = async (
   conversation: ChatConversation[],
   model = LlmsModelsEnum.CLAUDE_OPUS
 ): Promise<LlmCompletion> => {
   logger.info(`Handling ${model} completion`)
-
   const data = {
     model,
     stream: false,
     system: config.openAi.chatGpt.chatCompletionContext,
     max_tokens: +config.openAi.chatGpt.maxTokens,
-    messages: conversation
+    messages: conversation.filter(c => c.model === model)
+      .map(m => { return { content: m.content, role: m.role } })
   }
   const url = `${API_ENDPOINT}/anthropic/completions`
   const response = await axios.post(url, data)
@@ -68,7 +68,7 @@ export const anthropicStreamCompletion = async (
     stream: true, // Set stream to true to receive the completion as a stream
     system: config.openAi.chatGpt.chatCompletionContext,
     max_tokens: limitTokens ? +config.openAi.chatGpt.maxTokens : undefined,
-    messages: conversation.map(m => { return { content: m.content, role: m.role } })
+    messages: conversation.filter(c => c.model === model).map(m => { return { content: m.content, role: m.role } })
   }
   let wordCount = 0
   let wordCountMinimum = 2
@@ -88,9 +88,16 @@
       if (msg) {
         if (msg.startsWith('Input Token')) {
           inputTokens = msg.split('Input Token: ')[1]
-        } else if (msg.startsWith('Text')) {
+        } else if (msg.startsWith('Output Tokens')) {
+          outputTokens = msg.split('Output Tokens: ')[1]
+        } else {
           wordCount++
-          completion += msg.split('Text: ')[1]
+          completion += msg // .split('Text: ')[1]
+          if (msg.includes('Output Tokens:')) {
+            const tokenMsg = msg.split('Output Tokens: ')[1]
+            outputTokens = tokenMsg
+            completion = completion.split('Output Tokens: ')[0]
+          }
           if (wordCount > wordCountMinimum) { // if (chunck === '.' && wordCount > wordCountMinimum) {
             if (wordCountMinimum < 64) {
               wordCountMinimum *= 2
@@ -114,8 +121,6 @@ export const anthropicStreamCompletion = async (
             })
           }
         }
-      } else if (msg.startsWith('Output Tokens')) {
-        outputTokens = msg.split('Output Tokens: ')[1]
       }
     }
   }
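
The stream handler above assumes the completion endpoint interleaves plain text chunks with 'Input Token: N' and 'Output Tokens: N' markers, and that a marker can arrive glued to the tail of a text chunk, in which case the remainder after the marker is the token count. A minimal standalone sketch of that parsing under those assumptions (the ParsedStream shape and parseChunks name are illustrative, not part of the diff):

// Sketch only: mirrors the chunk handling in anthropicStreamCompletion above.
interface ParsedStream { completion: string, inputTokens: string, outputTokens: string }

export const parseChunks = (chunks: string[]): ParsedStream => {
  let completion = ''
  let inputTokens = ''
  let outputTokens = ''
  for (const msg of chunks) {
    if (msg.startsWith('Input Token')) {
      inputTokens = msg.split('Input Token: ')[1]
    } else if (msg.startsWith('Output Tokens')) {
      outputTokens = msg.split('Output Tokens: ')[1]
    } else {
      completion += msg
      if (msg.includes('Output Tokens:')) {
        // Marker glued to the tail of a text chunk: keep the text, store the count.
        outputTokens = msg.split('Output Tokens: ')[1]
        completion = completion.split('Output Tokens: ')[0]
      }
    }
  }
  return { completion, inputTokens, outputTokens }
}
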
diff --git a/src/modules/llms/api/llmApi.ts b/src/modules/llms/api/llmApi.ts
index 0169495e..10c782b3 100644
--- a/src/modules/llms/api/llmApi.ts
+++ b/src/modules/llms/api/llmApi.ts
@@ -99,7 +99,7 @@ export const llmCompletion = async (
   const data = {
     model, // chat-bison@001 'chat-bison', //'gpt-3.5-turbo',
     stream: false,
-    messages: conversation
+    messages: conversation.filter(c => c.model === model)
   }
   const url = `${API_ENDPOINT}/llms/completions`
   const response = await axios.post(url, data)
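
Each backend now narrows the shared history to its own turns with conversation.filter(c => c.model === model), replacing the prepareConversation helper removed in helpers.ts below. A small illustration with hypothetical history entries (ChatConversation is simplified here):

// Illustration with made-up data, not part of the diff.
interface ChatConversation { content: string, role?: string, model?: string }

const history: ChatConversation[] = [
  { content: 'hi', role: 'user', model: 'claude-3-opus-20240229' },
  { content: 'Hello!', role: 'assistant', model: 'claude-3-opus-20240229' },
  { content: 'hi', role: 'user', model: 'gemini-1.0-pro' }
]

// Only the Gemini turns are sent to the Gemini endpoint.
const forGemini = history.filter(c => c.model === 'gemini-1.0-pro')
console.log(forGemini) // [{ content: 'hi', role: 'user', model: 'gemini-1.0-pro' }]
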
diff --git a/src/modules/llms/api/vertex.ts b/src/modules/llms/api/vertex.ts
index d63a87dc..3697d3a3 100644
--- a/src/modules/llms/api/vertex.ts
+++ b/src/modules/llms/api/vertex.ts
@@ -1,9 +1,21 @@
-import axios from 'axios'
+import axios, { type AxiosResponse } from 'axios'
 import config from '../../../config'
-import { type ChatConversation } from '../../types'
+import { type OnMessageContext, type ChatConversation, type OnCallBackQueryData } from '../../types'
 import { type LlmCompletion } from './llmApi'
+import { type Readable } from 'stream'
+import { GrammyError } from 'grammy'
+import { pino } from 'pino'
+import { LlmsModelsEnum } from '../types'
 
-const API_ENDPOINT = config.llms.apiEndpoint // http://localhost:8080' // config.llms.apiEndpoint
+const API_ENDPOINT = config.llms.apiEndpoint // config.llms.apiEndpoint // 'http://127.0.0.1:5000' // config.llms.apiEndpoint
+
+const logger = pino({
+  name: 'Gemini - llmsBot',
+  transport: {
+    target: 'pino-pretty',
+    options: { colorize: true }
+  }
+})
 
 export const vertexCompletion = async (
   conversation: ChatConversation[],
@@ -12,8 +24,18 @@
   const data = {
     model, // chat-bison@001 'chat-bison', //'gpt-3.5-turbo',
     stream: false,
-    messages: conversation
+    messages: conversation.filter(c => c.model === model)
+      .map((msg) => {
+        const msgFiltered: ChatConversation = { content: msg.content, model: msg.model }
+        if (model === LlmsModelsEnum.BISON) {
+          msgFiltered.author = msg.role
+        } else {
+          msgFiltered.role = msg.role
+        }
+        return msgFiltered
+      })
   }
+
   const url = `${API_ENDPOINT}/vertex/completions`
   const response = await axios.post(url, data)
   if (response) {
@@ -22,7 +44,7 @@
     return {
       completion: {
         content: response.data._prediction_response[0][0].candidates[0].content,
-        author: 'bot',
+        role: 'bot', // role replace to author attribute will be done later
        model
       },
       usage: totalOutputTokens + totalInputTokens,
@@ -35,3 +57,88 @@
     price: 0
   }
 }
+
+export const vertexStreamCompletion = async (
+  conversation: ChatConversation[],
+  model = LlmsModelsEnum.CLAUDE_OPUS,
+  ctx: OnMessageContext | OnCallBackQueryData,
+  msgId: number,
+  limitTokens = true
+): Promise<LlmCompletion> => {
+  const data = {
+    model,
+    stream: true, // Set stream to true to receive the completion as a stream
+    system: config.openAi.chatGpt.chatCompletionContext,
+    max_tokens: limitTokens ? +config.openAi.chatGpt.maxTokens : undefined,
+    messages: conversation.filter(c => c.model === model)
+      .map(m => { return { parts: { text: m.content }, role: m.role !== 'user' ? 'model' : 'user' } })
+  }
+  const url = `${API_ENDPOINT}/vertex/completions/gemini`
+  if (!ctx.chat?.id) {
+    throw new Error('Context chat id should not be empty after openAI streaming')
+  }
+  const response: AxiosResponse = await axios.post(url, data, { responseType: 'stream' })
+  // Create a Readable stream from the response
+  const completionStream: Readable = response.data
+  // Read and process the stream
+  let completion = ''
+  let outputTokens = ''
+  let inputTokens = ''
+  for await (const chunk of completionStream) {
+    const msg = chunk.toString()
+    if (msg) {
+      completion += msg // .split('Text: ')[1]
+      if (msg.includes('Input Token:')) {
+        const tokenMsg = msg.split('Input Token: ')[1]
+        inputTokens = tokenMsg.split('Output Tokens: ')[0]
+        outputTokens = tokenMsg.split('Output Tokens: ')[1]
+        completion = completion.split('Input Token: ')[0]
+      }
+      completion = completion.replaceAll('...', '')
+      completion += '...'
+      if (ctx.chat?.id) {
+        await ctx.api
+          .editMessageText(ctx.chat?.id, msgId, completion)
+          .catch(async (e: any) => {
+            if (e instanceof GrammyError) {
+              if (e.error_code !== 400) {
+                throw e
+              } else {
+                logger.error(e)
+              }
+            } else {
+              throw e
+            }
+          })
+      }
+    }
+  }
+  completion = completion.replaceAll('...', '')
+  await ctx.api
+    .editMessageText(ctx.chat?.id, msgId, completion)
+    .catch((e: any) => {
+      if (e instanceof GrammyError) {
+        if (e.error_code !== 400) {
+          throw e
+        } else {
+          logger.error(e)
+        }
+      } else {
+        throw e
+      }
+    })
+  const totalOutputTokens = outputTokens // response.headers['x-openai-output-tokens']
+  const totalInputTokens = inputTokens // response.headers['x-openai-input-tokens']
+
+  return {
+    completion: {
+      content: completion,
+      role: 'assistant',
+      model
+    },
+    usage: parseInt(totalOutputTokens, 10) + parseInt(totalInputTokens, 10),
+    price: 0,
+    inputTokens: parseInt(totalInputTokens, 10),
+    outputTokens: parseInt(totalOutputTokens, 10)
+  }
+}
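
vertexStreamCompletion reshapes each turn into the Gemini-style { parts: { text }, role } form shown in the payload above, mapping any non-user role to 'model'. A sketch of just that mapping, with the types simplified for illustration:

// Sketch of the message mapping used by vertexStreamCompletion above.
interface ChatConversation { content: string, role?: string, model?: string }
interface GeminiMessage { parts: { text: string }, role: 'user' | 'model' }

const toGeminiMessage = (m: ChatConversation): GeminiMessage => ({
  parts: { text: m.content },
  role: m.role !== 'user' ? 'model' : 'user'
})

console.log(toGeminiMessage({ content: 'Hello!', role: 'assistant' }))
// => { parts: { text: 'Hello!' }, role: 'model' }
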
diff --git a/src/modules/llms/helpers.ts b/src/modules/llms/helpers.ts
index 9d82dd72..546c0e39 100644
--- a/src/modules/llms/helpers.ts
+++ b/src/modules/llms/helpers.ts
@@ -2,11 +2,9 @@ import {
   type OnMessageContext,
   type OnCallBackQueryData,
   type MessageExtras,
-  type ChatConversation,
   type ChatPayload
 } from '../types'
 import { type ParseMode } from 'grammy/types'
-import { LlmsModelsEnum } from './types'
 import { type Message } from 'grammy/out/types'
 import { type LlmCompletion, getChatModel, llmAddUrlDocument } from './api/llmApi'
 import { getChatModelPrice } from '../open-ai/api/openAi'
@@ -16,8 +14,9 @@ export enum SupportedCommands {
   bardF = 'bard',
   claudeOpus = 'claude',
   opus = 'opus',
+  opusShort = 'o',
   claudeSonnet = 'claudes',
-  opusShort = 'c',
+  claudeShort = 'c',
   sonnet = 'sonnet',
   sonnetShort = 's',
   claudeHaiku = 'haiku',
@@ -26,13 +25,16 @@ export enum SupportedCommands {
   j2Ultra = 'j2-ultra',
   sum = 'sum',
   ctx = 'ctx',
-  pdf = 'pdf'
+  pdf = 'pdf',
+  gemini = 'gemini',
+  gShort = 'g'
 }
 
 export const MAX_TRIES = 3
 const LLAMA_PREFIX_LIST = ['* ']
 const BARD_PREFIX_LIST = ['b. ', 'B. ']
 const CLAUDE_OPUS_PREFIX_LIST = ['c. ']
+const GEMINI_PREFIX_LIST = ['g. ']
 
 export const isMentioned = (
   ctx: OnMessageContext | OnCallBackQueryData
@@ -80,6 +82,16 @@ export const hasClaudeOpusPrefix = (prompt: string): string => {
   return ''
 }
 
+export const hasGeminiPrefix = (prompt: string): string => {
+  const prefixList = GEMINI_PREFIX_LIST
+  for (let i = 0; i < prefixList.length; i++) {
+    if (prompt.toLocaleLowerCase().startsWith(prefixList[i])) {
+      return prefixList[i]
+    }
+  }
+  return ''
+}
+
 export const hasUrl = (
   ctx: OnMessageContext | OnCallBackQueryData,
   prompt: string
@@ -211,7 +223,7 @@ export const sendMessage = async (
 
 export const hasPrefix = (prompt: string): string => {
   return (
-    hasBardPrefix(prompt) || hasLlamaPrefix(prompt) || hasClaudeOpusPrefix(prompt)
+    hasBardPrefix(prompt) || hasLlamaPrefix(prompt) || hasClaudeOpusPrefix(prompt) || hasGeminiPrefix(prompt)
   )
 }
 
@@ -241,23 +253,6 @@ export const limitPrompt = (prompt: string): string => {
   return prompt
 }
 
-export const prepareConversation = (
-  conversation: ChatConversation[],
-  model: string
-): ChatConversation[] => {
-  return conversation
-    .filter((msg) => msg.model === model)
-    .map((msg) => {
-      const msgFiltered: ChatConversation = { content: msg.content }
-      if (model === LlmsModelsEnum.BISON) {
-        msgFiltered.author = msg.author
-      } else {
-        msgFiltered.role = msg.role
-      }
-      return msgFiltered
-    })
-}
-
 export function extractPdfFilename (url: string): string | null {
   const matches = url.match(/\/([^/]+\.pdf)$/)
   if (matches) {
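
With the 'g. ' prefix list registered, a message such as 'g. what is quantum entanglement?' routes to Gemini without a slash command. A usage sketch (stripPrefix is a hypothetical helper, not part of the diff):

// Usage sketch for hasGeminiPrefix above.
const GEMINI_PREFIX_LIST = ['g. ']

const hasGeminiPrefix = (prompt: string): string => {
  for (const prefix of GEMINI_PREFIX_LIST) {
    if (prompt.toLocaleLowerCase().startsWith(prefix)) {
      return prefix
    }
  }
  return ''
}

const stripPrefix = (prompt: string): string => {
  const prefix = hasGeminiPrefix(prompt)
  return prefix === '' ? prompt : prompt.slice(prefix.length)
}

console.log(stripPrefix('g. what is quantum entanglement?')) // 'what is quantum entanglement?'
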
diff --git a/src/modules/llms/index.ts b/src/modules/llms/index.ts
index c6dd8bc0..42a314d5 100644
--- a/src/modules/llms/index.ts
+++ b/src/modules/llms/index.ts
@@ -22,17 +22,17 @@ import {
   getPromptPrice,
   hasBardPrefix,
   hasClaudeOpusPrefix,
+  hasGeminiPrefix,
   hasLlamaPrefix,
   hasPrefix,
   hasUrl,
   isMentioned,
   limitPrompt,
   MAX_TRIES,
-  prepareConversation,
   SupportedCommands
 } from './helpers'
 import { getUrlFromText, preparePrompt, sendMessage } from '../open-ai/helpers'
-import { vertexCompletion } from './api/vertex'
+import { vertexCompletion, vertexStreamCompletion } from './api/vertex'
 import { type LlmCompletion, llmCompletion, llmCheckCollectionStatus, queryUrlDocument, deleteCollection } from './api/llmApi'
 import { LlmsModelsEnum } from './types'
 import * as Sentry from '@sentry/node'
@@ -129,7 +129,11 @@ export class LlmsBot implements PayableBot {
       await this.onChat(ctx, LlmsModelsEnum.BISON)
       return
     }
-    if (ctx.hasCommand([SupportedCommands.claudeOpus, SupportedCommands.opus, SupportedCommands.opusShort]) || (hasClaudeOpusPrefix(ctx.message?.text ?? '') !== '')) {
+    if (ctx.hasCommand([SupportedCommands.gemini, SupportedCommands.gShort]) || (hasGeminiPrefix(ctx.message?.text ?? '') !== '')) {
+      await this.onChat(ctx, LlmsModelsEnum.GEMINI)
+      return
+    }
+    if (ctx.hasCommand([SupportedCommands.claudeOpus, SupportedCommands.opus, SupportedCommands.opusShort, SupportedCommands.claudeShort]) || (hasClaudeOpusPrefix(ctx.message?.text ?? '') !== '')) {
       await this.onChat(ctx, LlmsModelsEnum.CLAUDE_OPUS)
       return
     }
@@ -567,13 +571,23 @@ export class LlmsBot implements PayableBot {
     if (isTypingEnabled) {
       ctx.chatAction = 'typing'
     }
-    const completion = await anthropicStreamCompletion(
-      conversation,
-      model as LlmsModelsEnum,
-      ctx,
-      msgId,
-      true // telegram messages has a character limit
-    )
+    let completion: LlmCompletion
+    if (model === LlmsModelsEnum.GEMINI) {
+      completion = await vertexStreamCompletion(conversation,
+        model as LlmsModelsEnum,
+        ctx,
+        msgId,
+        true // telegram messages has a character limit
+      )
+    } else {
+      completion = await anthropicStreamCompletion(
+        conversation,
+        model as LlmsModelsEnum,
+        ctx,
+        msgId,
+        true // telegram messages has a character limit
+      )
+    }
     if (isTypingEnabled) {
       ctx.chatAction = null
     }
@@ -586,7 +600,8 @@ export class LlmsBot implements PayableBot {
       )
       conversation.push({
         role: 'assistant',
-        content: completion.completion?.content ?? ''
+        content: completion.completion?.content ?? '',
+        model
       })
       return {
         price: price.price,
@@ -597,7 +612,8 @@ export class LlmsBot implements PayableBot {
     const response = await anthropicCompletion(conversation, model as LlmsModelsEnum)
     conversation.push({
       role: 'assistant',
-      content: response.completion?.content ?? ''
+      content: response.completion?.content ?? '',
+      model
     })
     return {
       price: response.price,
@@ -629,13 +645,12 @@ export class LlmsBot implements PayableBot {
       usage: 0,
       price: 0
     }
-    const chat = prepareConversation(conversation, model)
     if (model === LlmsModelsEnum.BISON) {
-      response = await vertexCompletion(chat, model) // "chat-bison@001");
+      response = await vertexCompletion(conversation, model) // "chat-bison@001");
     } else if (model === LlmsModelsEnum.CLAUDE_OPUS || model === LlmsModelsEnum.CLAUDE_SONNET || model === LlmsModelsEnum.CLAUDE_HAIKU) {
-      response = await anthropicCompletion(chat, model)
+      response = await anthropicCompletion(conversation, model)
     } else {
-      response = await llmCompletion(chat, model as LlmsModelsEnum)
+      response = await llmCompletion(conversation, model as LlmsModelsEnum)
     }
     if (response.completion) {
       await ctx.api.editMessageText(
@@ -740,13 +755,9 @@ export class LlmsBot implements PayableBot {
       }
       const chat: ChatConversation = {
         content: limitPrompt(prompt),
+        role: 'user',
         model
       }
-      if (model === LlmsModelsEnum.BISON) {
-        chat.author = 'user'
-      } else {
-        chat.role = 'user'
-      }
       chatConversation.push(chat)
       const payload = {
         conversation: chatConversation,
@@ -754,7 +765,7 @@ export class LlmsBot implements PayableBot {
         ctx
       }
       let result: { price: number, chat: ChatConversation[] } = { price: 0, chat: [] }
-      if (model === LlmsModelsEnum.CLAUDE_OPUS || model === LlmsModelsEnum.CLAUDE_SONNET) {
+      if (model === LlmsModelsEnum.CLAUDE_OPUS || model === LlmsModelsEnum.CLAUDE_SONNET || model === LlmsModelsEnum.GEMINI) {
        result = await this.completionGen(payload) // , prompt.msgId, prompt.outputFormat)
       } else {
         result = await this.promptGen(payload)
@@ -873,6 +884,7 @@ export class LlmsBot implements PayableBot {
         ctx.transient.analytics.actualResponseTime = now()
       }
     } else if (e instanceof AxiosError) {
+      this.logger.error(`${e.message}`)
       await sendMessage(ctx, 'Error handling your request').catch(async (e) => {
         await this.onError(ctx, e, retryCount - 1)
       })
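
The dispatch now splits three ways: Gemini streams through vertexStreamCompletion, the Claude models through anthropicStreamCompletion, and the remaining models stay on the non-streaming path. A condensed sketch of that decision (usesStreamingPath is illustrative; enum values are taken from types.ts below):

// Condensed sketch of the routing introduced above.
enum LlmsModelsEnum {
  J2_ULTRA = 'j2-ultra',
  CLAUDE_OPUS = 'claude-3-opus-20240229',
  CLAUDE_SONNET = 'claude-3-sonnet-20240229',
  GEMINI = 'gemini-1.0-pro'
}

const usesStreamingPath = (model: LlmsModelsEnum): boolean =>
  model === LlmsModelsEnum.CLAUDE_OPUS ||
  model === LlmsModelsEnum.CLAUDE_SONNET ||
  model === LlmsModelsEnum.GEMINI

console.log(usesStreamingPath(LlmsModelsEnum.GEMINI)) // true  -> completionGen (streaming)
console.log(usesStreamingPath(LlmsModelsEnum.J2_ULTRA)) // false -> promptGen (non-streaming)
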
diff --git a/src/modules/llms/types.ts b/src/modules/llms/types.ts
index 54bb7756..b4ffa32e 100644
--- a/src/modules/llms/types.ts
+++ b/src/modules/llms/types.ts
@@ -6,7 +6,8 @@ export enum LlmsModelsEnum {
   J2_ULTRA = 'j2-ultra',
   CLAUDE_OPUS = 'claude-3-opus-20240229',
   CLAUDE_SONNET = 'claude-3-sonnet-20240229',
-  CLAUDE_HAIKU = 'claude-3-haiku-20240307'
+  CLAUDE_HAIKU = 'claude-3-haiku-20240307',
+  GEMINI = 'gemini-1.0-pro'
 }
 
 export const LlmsModels: Record<string, ChatModel> = {
@@ -17,6 +18,13 @@
     maxContextTokens: 8192,
     chargeType: 'CHAR'
   },
+  'gemini-1.0-pro': {
+    name: 'gemini-1.0-pro',
+    inputPrice: 0.00025, // 3.00 (1M Tokens) => 0.003 (1K tokens)
+    outputPrice: 0.00125,
+    maxContextTokens: 4096,
+    chargeType: 'CHAR'
+  },
   'gpt-4-32k': {
     name: 'gpt-4-32k',
     inputPrice: 0.06, // 6
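
The new gemini-1.0-pro entry follows the existing ChatModel shape. A hedged estimate sketch, assuming inputPrice/outputPrice are per-1K-unit rates with chargeType selecting the billed unit; estimatePrice is illustrative, and the bot's production calculation goes through getChatModelPrice (imported in helpers.ts above):

// Sketch only: ChatModel is simplified and estimatePrice is hypothetical.
interface ChatModel {
  name: string
  inputPrice: number
  outputPrice: number
  maxContextTokens: number
  chargeType: string
}

const gemini: ChatModel = {
  name: 'gemini-1.0-pro',
  inputPrice: 0.00025,
  outputPrice: 0.00125,
  maxContextTokens: 4096,
  chargeType: 'CHAR'
}

const estimatePrice = (model: ChatModel, inputUnits: number, outputUnits: number): number =>
  (inputUnits / 1000) * model.inputPrice + (outputUnits / 1000) * model.outputPrice

console.log(estimatePrice(gemini, 1200, 800).toFixed(6)) // '0.001300'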