Commit 049319e: add gemini model
fegloff committed Mar 19, 2024 (1 parent: 1950121)

Showing 5 changed files with 129 additions and 16 deletions.
2 changes: 1 addition & 1 deletion src/modules/llms/api/athropic.ts
@@ -16,7 +16,7 @@ const logger = pino({
   }
 })

-const API_ENDPOINT = config.llms.apiEndpoint // 'http://127.0.0.1:5000' // config.llms.apiEndpoint
+const API_ENDPOINT = config.llms.apiEndpoint // config.llms.apiEndpoint // 'http://127.0.0.1:5000' // config.llms.apiEndpoint

 export const anthropicCompletion = async (
   conversation: ChatConversation[],
93 changes: 90 additions & 3 deletions src/modules/llms/api/vertex.ts
@@ -1,9 +1,21 @@
-import axios from 'axios'
+import axios, { type AxiosResponse } from 'axios'
 import config from '../../../config'
-import { type ChatConversation } from '../../types'
+import { type OnMessageContext, type ChatConversation, type OnCallBackQueryData } from '../../types'
 import { type LlmCompletion } from './llmApi'
+import { type Readable } from 'stream'
+import { GrammyError } from 'grammy'
+import { pino } from 'pino'
+import { LlmsModelsEnum } from '../types'

-const API_ENDPOINT = config.llms.apiEndpoint // http://localhost:8080' // config.llms.apiEndpoint
+const API_ENDPOINT = config.llms.apiEndpoint // config.llms.apiEndpoint // http://localhost:8080' // config.llms.apiEndpoint

+const logger = pino({
+  name: 'Gemini - llmsBot',
+  transport: {
+    target: 'pino-pretty',
+    options: { colorize: true }
+  }
+})
+
 export const vertexCompletion = async (
   conversation: ChatConversation[],
@@ -35,3 +47,78 @@ export const vertexCompletion = async (
     price: 0
   }
 }
+
+export const vertexStreamCompletion = async (
+  conversation: ChatConversation[],
+  model = LlmsModelsEnum.CLAUDE_OPUS,
+  ctx: OnMessageContext | OnCallBackQueryData,
+  msgId: number,
+  limitTokens = true
+): Promise<LlmCompletion> => {
+  const data = {
+    model,
+    stream: true, // Set stream to true to receive the completion as a stream
+    system: config.openAi.chatGpt.chatCompletionContext,
+    max_tokens: limitTokens ? +config.openAi.chatGpt.maxTokens : undefined,
+    messages: conversation.map(m => { return { parts: { text: m.content }, role: m.role !== 'user' ? 'model' : 'user' } })
+  }
+  const url = `${API_ENDPOINT}/vertex/completions/gemini`
+  if (!ctx.chat?.id) {
+    throw new Error('Context chat id should not be empty for Vertex streaming')
+  }
+  const response: AxiosResponse = await axios.post(url, data, { responseType: 'stream' })
+  // Create a Readable stream from the response
+  const completionStream: Readable = response.data
+  // Read and process the stream
+  let completion = ''
+  for await (const chunk of completionStream) {
+    const msg = chunk.toString()
+    if (msg) {
+      completion += msg.split('Text: ')[1] ?? '' // guard against chunks without a 'Text: ' payload
+      completion = completion.replaceAll('...', '')
+      completion += '...'
+      if (ctx.chat?.id) {
+        await ctx.api
+          .editMessageText(ctx.chat?.id, msgId, completion)
+          .catch(async (e: any) => {
+            if (e instanceof GrammyError) {
+              if (e.error_code !== 400) {
+                throw e
+              } else {
+                logger.error(e)
+              }
+            } else {
+              throw e
+            }
+          })
+      }
+    }
+  }
+  completion = completion.replaceAll('...', '')
+  await ctx.api
+    .editMessageText(ctx.chat?.id, msgId, completion)
+    .catch((e: any) => {
+      if (e instanceof GrammyError) {
+        if (e.error_code !== 400) {
+          throw e
+        } else {
+          logger.error(e)
+        }
+      } else {
+        throw e
+      }
+    })
+  const totalOutputTokens = '10' // response.headers['x-openai-output-tokens']
+  const totalInputTokens = '10' // response.headers['x-openai-input-tokens']
+  return {
+    completion: {
+      content: completion,
+      role: 'assistant',
+      model
+    },
+    usage: parseInt(totalOutputTokens, 10) + parseInt(totalInputTokens, 10),
+    price: 0,
+    inputTokens: parseInt(totalInputTokens, 10),
+    outputTokens: parseInt(totalOutputTokens, 10)
+  }
+}
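
Note on the streaming logic above: the handler assumes the bot's backend emits plain-text chunks of the form `Text: <delta>` and edits a single Telegram message in place, using a trailing '...' as a typing indicator. A minimal usage sketch under those assumptions; the helper, placeholder reply, and prompt below are hypothetical, not part of this commit:

// Hypothetical caller for vertexStreamCompletion; assumes a grammY OnMessageContext.
const askGemini = async (ctx: OnMessageContext, prompt: string): Promise<string> => {
  const placeholder = await ctx.reply('...') // message the stream handler edits in place
  const result = await vertexStreamCompletion(
    [{ role: 'user', content: prompt }],
    LlmsModelsEnum.GEMINI,
    ctx,
    placeholder.message_id,
    true // keep output within Telegram's message length limit
  )
  return result.completion?.content ?? ''
}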
4 changes: 3 additions & 1 deletion src/modules/llms/helpers.ts
@@ -26,7 +26,9 @@ export enum SupportedCommands {
   j2Ultra = 'j2-ultra',
   sum = 'sum',
   ctx = 'ctx',
-  pdf = 'pdf'
+  pdf = 'pdf',
+  gemini = 'gemini',
+  gShort = 'g'
 }

 export const MAX_TRIES = 3
36 changes: 26 additions & 10 deletions src/modules/llms/index.ts
@@ -32,7 +32,7 @@ import {
   SupportedCommands
 } from './helpers'
 import { getUrlFromText, preparePrompt, sendMessage } from '../open-ai/helpers'
-import { vertexCompletion } from './api/vertex'
+import { vertexCompletion, vertexStreamCompletion } from './api/vertex'
 import { type LlmCompletion, llmCompletion, llmCheckCollectionStatus, queryUrlDocument, deleteCollection } from './api/llmApi'
 import { LlmsModelsEnum } from './types'
 import * as Sentry from '@sentry/node'
@@ -129,6 +129,10 @@ export class LlmsBot implements PayableBot {
       await this.onChat(ctx, LlmsModelsEnum.BISON)
       return
     }
+    if (ctx.hasCommand(SupportedCommands.gemini) || ctx.hasCommand(SupportedCommands.gShort)) {
+      await this.onChat(ctx, LlmsModelsEnum.GEMINI)
+      return
+    }
     if (ctx.hasCommand([SupportedCommands.claudeOpus, SupportedCommands.opus, SupportedCommands.opusShort]) || (hasClaudeOpusPrefix(ctx.message?.text ?? '') !== '')) {
       await this.onChat(ctx, LlmsModelsEnum.CLAUDE_OPUS)
       return
@@ -567,13 +571,23 @@ export class LlmsBot implements PayableBot {
     if (isTypingEnabled) {
       ctx.chatAction = 'typing'
     }
-    const completion = await anthropicStreamCompletion(
-      conversation,
-      model as LlmsModelsEnum,
-      ctx,
-      msgId,
-      true // Telegram messages have a character limit
-    )
+    let completion: LlmCompletion
+    if (model === LlmsModelsEnum.GEMINI) {
+      completion = await vertexStreamCompletion(conversation,
+        model as LlmsModelsEnum,
+        ctx,
+        msgId,
+        true // Telegram messages have a character limit
+      )
+    } else {
+      completion = await anthropicStreamCompletion(
+        conversation,
+        model as LlmsModelsEnum,
+        ctx,
+        msgId,
+        true // Telegram messages have a character limit
+      )
+    }
     if (isTypingEnabled) {
       ctx.chatAction = null
     }
@@ -585,7 +599,7 @@ export class LlmsBot implements PayableBot {
         `streamChatCompletion result = tokens: ${price.promptTokens + price.completionTokens} | ${model} | price: ${price.price}¢` // }
       )
       conversation.push({
-        role: 'assistant',
+        role: model === LlmsModelsEnum.GEMINI ? 'model' : 'assistant',
         content: completion.completion?.content ?? ''
       })
       return {
@@ -754,7 +768,7 @@ export class LlmsBot implements PayableBot {
       ctx
     }
     let result: { price: number, chat: ChatConversation[] } = { price: 0, chat: [] }
-    if (model === LlmsModelsEnum.CLAUDE_OPUS || model === LlmsModelsEnum.CLAUDE_SONNET) {
+    if (model === LlmsModelsEnum.CLAUDE_OPUS || model === LlmsModelsEnum.CLAUDE_SONNET || model === LlmsModelsEnum.GEMINI) {
       result = await this.completionGen(payload) // , prompt.msgId, prompt.outputFormat)
     } else {
       result = await this.promptGen(payload)
@@ -816,6 +830,7 @@ export class LlmsBot implements PayableBot {
     Sentry.setContext('llms', { retryCount, msg })
     Sentry.captureException(e)
     ctx.chatAction = null
+    console.log('FCO', e)
     if (retryCount === 0) {
       // Retry limit reached, log an error or take alternative action
       this.logger.error(`Retry limit reached for error: ${e}`)
@@ -873,6 +888,7 @@ export class LlmsBot implements PayableBot {
         ctx.transient.analytics.actualResponseTime = now()
       }
     } else if (e instanceof AxiosError) {
+      this.logger.error(`${e.message}`)
       await sendMessage(ctx, 'Error handling your request').catch(async (e) => {
         await this.onError(ctx, e, retryCount - 1)
       })
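
Aside: the conversation.push change above branches on the model because Gemini's chat format names the bot turn 'model', while Anthropic-style histories use 'assistant' (the same mapping vertex.ts applies to outgoing messages). A hypothetical one-line helper capturing this, for illustration only:

// Hypothetical helper, not part of this commit: map a model id to its history role.
const historyRole = (model: LlmsModelsEnum): 'model' | 'assistant' =>
  model === LlmsModelsEnum.GEMINI ? 'model' : 'assistant'
// usage: conversation.push({ role: historyRole(model), content: text })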
10 changes: 9 additions & 1 deletion src/modules/llms/types.ts
@@ -6,7 +6,8 @@ export enum LlmsModelsEnum {
   J2_ULTRA = 'j2-ultra',
   CLAUDE_OPUS = 'claude-3-opus-20240229',
   CLAUDE_SONNET = 'claude-3-sonnet-20240229',
-  CLAUDE_HAIKU = 'claude-3-haiku-20240307'
+  CLAUDE_HAIKU = 'claude-3-haiku-20240307',
+  GEMINI = 'gemini-1.0-pro'
 }

 export const LlmsModels: Record<string, ChatModel> = {
@@ -17,6 +18,13 @@ export const LlmsModels: Record<string, ChatModel> = {
     maxContextTokens: 8192,
     chargeType: 'CHAR'
   },
+  'gemini-1.0-pro': {
+    name: 'gemini-1.0-pro',
+    inputPrice: 0.00025, // 0.25 (1M tokens) => 0.00025 (1K tokens)
+    outputPrice: 0.00125,
+    maxContextTokens: 4096,
+    chargeType: 'CHAR'
+  },
   'gpt-4-32k': {
     name: 'gpt-4-32k',
     inputPrice: 0.06, // 6
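
For context, a sketch of how a LlmsModels entry could drive a cost estimate. Treating inputPrice/outputPrice as USD per 1K charged units is an assumption read off the comments above, and estimateCostUSD is not a function from this repo:

// Hypothetical cost estimate; assumes prices are USD per 1K charged units.
const estimateCostUSD = (modelId: string, inputUnits: number, outputUnits: number): number => {
  const cfg = LlmsModels[modelId]
  return (inputUnits / 1000) * cfg.inputPrice + (outputUnits / 1000) * cfg.outputPrice
}
// e.g. estimateCostUSD('gemini-1.0-pro', 1000, 500) => 0.00025 + 0.000625 = 0.000875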
