diff --git a/src/modules/llms/api/athropic.ts b/src/modules/llms/api/athropic.ts
index 9eccd2c..758a791 100644
--- a/src/modules/llms/api/athropic.ts
+++ b/src/modules/llms/api/athropic.ts
@@ -16,7 +16,7 @@ const logger = pino({
   }
 })
 
-const API_ENDPOINT = config.llms.apiEndpoint // 'http://127.0.0.1:5000' // config.llms.apiEndpoint
+const API_ENDPOINT = config.llms.apiEndpoint // config.llms.apiEndpoint // 'http://127.0.0.1:5000' // config.llms.apiEndpoint
 
 export const anthropicCompletion = async (
   conversation: ChatConversation[],
@@ -88,9 +88,16 @@ export const anthropicStreamCompletion = async (
     if (msg) {
       if (msg.startsWith('Input Token')) {
         inputTokens = msg.split('Input Token: ')[1]
-      } else if (msg.startsWith('Text')) {
+      } else if (msg.startsWith('Output Tokens')) {
+        outputTokens = msg.split('Output Tokens: ')[1]
+      } else {
         wordCount++
-        completion += msg.split('Text: ')[1]
+        completion += msg // .split('Text: ')[1]
+        if (msg.includes('Output Tokens:')) {
+          const tokenMsg = msg.split('Output Tokens: ')[1]
+          outputTokens = tokenMsg.split('Output Tokens: ')[1]
+          completion = completion.split('Output Tokens: ')[0]
+        }
         if (wordCount > wordCountMinimum) { // if (chunck === '.' && wordCount > wordCountMinimum) {
           if (wordCountMinimum < 64) {
             wordCountMinimum *= 2
@@ -114,8 +121,6 @@
             })
         }
       }
-      } else if (msg.startsWith('Output Tokens')) {
-        outputTokens = msg.split('Output Tokens: ')[1]
       }
     }
   }
diff --git a/src/modules/llms/api/vertex.ts b/src/modules/llms/api/vertex.ts
index d63a87d..372cb95 100644
--- a/src/modules/llms/api/vertex.ts
+++ b/src/modules/llms/api/vertex.ts
@@ -1,9 +1,21 @@
-import axios from 'axios'
+import axios, { type AxiosResponse } from 'axios'
 import config from '../../../config'
-import { type ChatConversation } from '../../types'
+import { type OnMessageContext, type ChatConversation, type OnCallBackQueryData } from '../../types'
 import { type LlmCompletion } from './llmApi'
+import { type Readable } from 'stream'
+import { GrammyError } from 'grammy'
+import { pino } from 'pino'
+import { LlmsModelsEnum } from '../types'
 
-const API_ENDPOINT = config.llms.apiEndpoint // http://localhost:8080' // config.llms.apiEndpoint
+const API_ENDPOINT = config.llms.apiEndpoint // config.llms.apiEndpoint // 'http://127.0.0.1:5000' // config.llms.apiEndpoint
+
+const logger = pino({
+  name: 'Gemini - llmsBot',
+  transport: {
+    target: 'pino-pretty',
+    options: { colorize: true }
+  }
+})
 
 export const vertexCompletion = async (
   conversation: ChatConversation[],
@@ -35,3 +47,87 @@ export const vertexCompletion = async (
     price: 0
   }
 }
+
+export const vertexStreamCompletion = async (
+  conversation: ChatConversation[],
+  model = LlmsModelsEnum.CLAUDE_OPUS,
+  ctx: OnMessageContext | OnCallBackQueryData,
+  msgId: number,
+  limitTokens = true
+): Promise<LlmCompletion> => {
+  const data = {
+    model,
+    stream: true, // Set stream to true to receive the completion as a stream
+    system: config.openAi.chatGpt.chatCompletionContext,
+    max_tokens: limitTokens ? +config.openAi.chatGpt.maxTokens : undefined,
+    messages: conversation.map(m => { return { parts: { text: m.content }, role: m.role !== 'user' ? 'model' : 'user' } })
+  }
+  const url = `${API_ENDPOINT}/vertex/completions/gemini`
+  if (!ctx.chat?.id) {
+    throw new Error('Context chat id should not be empty after openAI streaming')
+  }
+  const response: AxiosResponse = await axios.post(url, data, { responseType: 'stream' })
+  // Create a Readable stream from the response
+  const completionStream: Readable = response.data
+  // Read and process the stream
+  let completion = ''
+  let outputTokens = ''
+  let inputTokens = ''
+  for await (const chunk of completionStream) {
+    const msg = chunk.toString()
+    if (msg) {
+      completion += msg // .split('Text: ')[1]
+      if (msg.includes('Input Token:')) {
+        const tokenMsg = msg.split('Input Token: ')[1]
+        inputTokens = tokenMsg.split('Output Tokens: ')[0]
+        outputTokens = tokenMsg.split('Output Tokens: ')[1]
+        completion = completion.split('Input Token: ')[0]
+      }
+      completion = completion.replaceAll('...', '')
+      completion += '...'
+      if (ctx.chat?.id) {
+        await ctx.api
+          .editMessageText(ctx.chat?.id, msgId, completion)
+          .catch(async (e: any) => {
+            if (e instanceof GrammyError) {
+              if (e.error_code !== 400) {
+                throw e
+              } else {
+                logger.error(e)
+              }
+            } else {
+              throw e
+            }
+          })
+      }
+    }
+  }
+  completion = completion.replaceAll('...', '')
+  await ctx.api
+    .editMessageText(ctx.chat?.id, msgId, completion)
+    .catch((e: any) => {
+      if (e instanceof GrammyError) {
+        if (e.error_code !== 400) {
+          throw e
+        } else {
+          logger.error(e)
+        }
+      } else {
+        throw e
+      }
+    })
+  const totalOutputTokens = outputTokens // response.headers['x-openai-output-tokens']
+  const totalInputTokens = inputTokens // response.headers['x-openai-input-tokens']
+
+  return {
+    completion: {
+      content: completion,
+      role: 'assistant',
+      model
+    },
+    usage: parseInt(totalOutputTokens, 10) + parseInt(totalInputTokens, 10),
+    price: 0,
+    inputTokens: parseInt(totalInputTokens, 10),
+    outputTokens: parseInt(totalOutputTokens, 10)
+  }
+}
diff --git a/src/modules/llms/helpers.ts b/src/modules/llms/helpers.ts
index 9d82dd7..ed943a1 100644
--- a/src/modules/llms/helpers.ts
+++ b/src/modules/llms/helpers.ts
@@ -26,13 +26,16 @@ export enum SupportedCommands {
   j2Ultra = 'j2-ultra',
   sum = 'sum',
   ctx = 'ctx',
-  pdf = 'pdf'
+  pdf = 'pdf',
+  gemini = 'gemini',
+  gShort = 'g'
 }
 
 export const MAX_TRIES = 3
 const LLAMA_PREFIX_LIST = ['* ']
 const BARD_PREFIX_LIST = ['b. ', 'B. ']
 const CLAUDE_OPUS_PREFIX_LIST = ['c. ']
+const GEMINI_PREFIX_LIST = ['g. ']
 
 export const isMentioned = (
   ctx: OnMessageContext | OnCallBackQueryData
@@ -80,6 +83,16 @@ export const hasClaudeOpusPrefix = (prompt: string): string => {
   return ''
 }
 
+export const hasGeminiPrefix = (prompt: string): string => {
+  const prefixList = GEMINI_PREFIX_LIST
+  for (let i = 0; i < prefixList.length; i++) {
+    if (prompt.toLocaleLowerCase().startsWith(prefixList[i])) {
+      return prefixList[i]
+    }
+  }
+  return ''
+}
+
 export const hasUrl = (
   ctx: OnMessageContext | OnCallBackQueryData,
   prompt: string
@@ -211,7 +224,7 @@ export const sendMessage = async (
 
 export const hasPrefix = (prompt: string): string => {
   return (
-    hasBardPrefix(prompt) || hasLlamaPrefix(prompt) || hasClaudeOpusPrefix(prompt)
+    hasBardPrefix(prompt) || hasLlamaPrefix(prompt) || hasClaudeOpusPrefix(prompt) || hasGeminiPrefix(prompt)
   )
 }
 
diff --git a/src/modules/llms/index.ts b/src/modules/llms/index.ts
index c6dd8bc..28af012 100644
--- a/src/modules/llms/index.ts
+++ b/src/modules/llms/index.ts
@@ -22,6 +22,7 @@ import {
   getPromptPrice,
   hasBardPrefix,
   hasClaudeOpusPrefix,
+  hasGeminiPrefix,
   hasLlamaPrefix,
   hasPrefix,
   hasUrl,
@@ -32,7 +33,7 @@ import {
   SupportedCommands
 } from './helpers'
 import { getUrlFromText, preparePrompt, sendMessage } from '../open-ai/helpers'
-import { vertexCompletion } from './api/vertex'
+import { vertexCompletion, vertexStreamCompletion } from './api/vertex'
 import { type LlmCompletion, llmCompletion, llmCheckCollectionStatus, queryUrlDocument, deleteCollection } from './api/llmApi'
 import { LlmsModelsEnum } from './types'
 import * as Sentry from '@sentry/node'
@@ -129,6 +130,10 @@
       await this.onChat(ctx, LlmsModelsEnum.BISON)
       return
     }
+    if (ctx.hasCommand([SupportedCommands.gemini, SupportedCommands.gShort]) || (hasGeminiPrefix(ctx.message?.text ?? '') !== '')) {
+      await this.onChat(ctx, LlmsModelsEnum.GEMINI)
+      return
+    }
     if (ctx.hasCommand([SupportedCommands.claudeOpus, SupportedCommands.opus, SupportedCommands.opusShort]) || (hasClaudeOpusPrefix(ctx.message?.text ?? '') !== '')) {
       await this.onChat(ctx, LlmsModelsEnum.CLAUDE_OPUS)
       return
@@ -567,13 +572,23 @@
     if (isTypingEnabled) {
       ctx.chatAction = 'typing'
     }
-    const completion = await anthropicStreamCompletion(
-      conversation,
-      model as LlmsModelsEnum,
-      ctx,
-      msgId,
-      true // telegram messages has a character limit
-    )
+    let completion: LlmCompletion
+    if (model === LlmsModelsEnum.GEMINI) {
+      completion = await vertexStreamCompletion(conversation,
+        model as LlmsModelsEnum,
+        ctx,
+        msgId,
+        true // telegram messages has a character limit
+      )
+    } else {
+      completion = await anthropicStreamCompletion(
+        conversation,
+        model as LlmsModelsEnum,
+        ctx,
+        msgId,
+        true // telegram messages has a character limit
+      )
+    }
     if (isTypingEnabled) {
       ctx.chatAction = null
     }
@@ -754,7 +769,7 @@
       ctx
     }
     let result: { price: number, chat: ChatConversation[] } = { price: 0, chat: [] }
-    if (model === LlmsModelsEnum.CLAUDE_OPUS || model === LlmsModelsEnum.CLAUDE_SONNET) {
+    if (model === LlmsModelsEnum.CLAUDE_OPUS || model === LlmsModelsEnum.CLAUDE_SONNET || model === LlmsModelsEnum.GEMINI) {
       result = await this.completionGen(payload) // , prompt.msgId, prompt.outputFormat)
     } else {
       result = await this.promptGen(payload)
@@ -873,6 +888,7 @@
         ctx.transient.analytics.actualResponseTime = now()
       }
     } else if (e instanceof AxiosError) {
+      this.logger.error(`${e.message}`)
       await sendMessage(ctx, 'Error handling your request').catch(async (e) => {
         await this.onError(ctx, e, retryCount - 1)
       })
diff --git a/src/modules/llms/types.ts b/src/modules/llms/types.ts
index 54bb775..b4ffa32 100644
--- a/src/modules/llms/types.ts
+++ b/src/modules/llms/types.ts
@@ -6,7 +6,8 @@ export enum LlmsModelsEnum {
   J2_ULTRA = 'j2-ultra',
   CLAUDE_OPUS = 'claude-3-opus-20240229',
   CLAUDE_SONNET = 'claude-3-sonnet-20240229',
-  CLAUDE_HAIKU = 'claude-3-haiku-20240307'
+  CLAUDE_HAIKU = 'claude-3-haiku-20240307',
+  GEMINI = 'gemini-1.0-pro'
 }
 
 export const LlmsModels: Record<string, ChatModel> = {
@@ -17,6 +18,13 @@
     maxContextTokens: 8192,
     chargeType: 'CHAR'
   },
+  'gemini-1.0-pro': {
+    name: 'gemini-1.0-pro',
+    inputPrice: 0.00025, // 3.00 (1M Tokens) => 0.003 (1K tokens)
+    outputPrice: 0.00125,
+    maxContextTokens: 4096,
+    chargeType: 'CHAR'
+  },
  'gpt-4-32k': {
    name: 'gpt-4-32k',
    inputPrice: 0.06, // 6
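
For reference, a minimal standalone sketch (not part of the patch) of the trailer parsing that the new vertexStreamCompletion relies on. It assumes, as the string handling in that function implies, that the streamed body is plain completion text followed by a single 'Input Token: <n> Output Tokens: <m>' trailer; the sample chunk text below is hypothetical, and only the split logic mirrors the patch.

// Trace of the trailer parsing in vertexStreamCompletion (TypeScript).
// The input string is made up for illustration.
const msg = 'Hello from Gemini Input Token: 12 Output Tokens: 34'

let completion = msg
let inputTokens = ''
let outputTokens = ''

if (msg.includes('Input Token:')) {
  const tokenMsg = msg.split('Input Token: ')[1] // '12 Output Tokens: 34'
  inputTokens = tokenMsg.split('Output Tokens: ')[0] // '12 '
  outputTokens = tokenMsg.split('Output Tokens: ')[1] // '34'
  completion = completion.split('Input Token: ')[0] // 'Hello from Gemini '
}

console.log(parseInt(inputTokens, 10) + parseInt(outputTokens, 10)) // 46, the usage figure the function returns

The same counts feed the usage, inputTokens, and outputTokens fields of the returned LlmCompletion; price is left at 0 in the stream path, with the model rates kept in the LlmsModels table in types.ts.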