Gemini #359

Merged · 8 commits · Apr 2, 2024
Changes from all commits
15 changes: 10 additions & 5 deletions src/modules/llms/api/athropic.ts
@@ -16,7 +16,7 @@ const logger = pino({
   }
 })
 
-const API_ENDPOINT = config.llms.apiEndpoint // 'http://127.0.0.1:5000' // config.llms.apiEndpoint
+const API_ENDPOINT = config.llms.apiEndpoint // config.llms.apiEndpoint // 'http://127.0.0.1:5000' // config.llms.apiEndpoint
 
 export const anthropicCompletion = async (
   conversation: ChatConversation[],
@@ -88,9 +88,16 @@ export const anthropicStreamCompletion = async (
     if (msg) {
       if (msg.startsWith('Input Token')) {
         inputTokens = msg.split('Input Token: ')[1]
-      } else if (msg.startsWith('Text')) {
+      } else if (msg.startsWith('Output Tokens')) {
+        outputTokens = msg.split('Output Tokens: ')[1]
+      } else {
         wordCount++
-        completion += msg.split('Text: ')[1]
+        completion += msg // .split('Text: ')[1]
+        if (msg.includes('Output Tokens:')) {
+          const tokenMsg = msg.split('Output Tokens: ')[1]
+          outputTokens = tokenMsg.split('Output Tokens: ')[1]
+          completion = completion.split('Output Tokens: ')[0]
+        }
         if (wordCount > wordCountMinimum) { // if (chunck === '.' && wordCount > wordCountMinimum) {
           if (wordCountMinimum < 64) {
             wordCountMinimum *= 2
@@ -114,8 +121,6 @@
               })
           }
         }
-      } else if (msg.startsWith('Output Tokens')) {
-        outputTokens = msg.split('Output Tokens: ')[1]
       }
     }
   }
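Note on the athropic.ts change above: the proxy no longer prefixes text chunks with `Text: `, so each raw chunk is appended to the completion as-is, and the `Output Tokens: <n>` marker, which may arrive glued to the end of a text chunk, is split back out of the accumulated text. A minimal sketch of the resulting parser, assuming a stream of plain-text chunks plus `Input Token: <n>` and `Output Tokens: <n>` marker messages (`handleChunk` is a hypothetical wrapper; in the diff this logic sits inline in the streaming loop, and the marker handling here is simplified):

let completion = ''
let inputTokens = ''
let outputTokens = ''

// Handle one decoded chunk from the completion stream.
const handleChunk = (msg: string): void => {
  if (msg.startsWith('Input Token')) {
    inputTokens = msg.split('Input Token: ')[1]
  } else if (msg.startsWith('Output Tokens')) {
    outputTokens = msg.split('Output Tokens: ')[1]
  } else {
    completion += msg
    // A text chunk can carry the trailing marker; strip it out of the text.
    if (msg.includes('Output Tokens:')) {
      outputTokens = msg.split('Output Tokens: ')[1]
      completion = completion.split('Output Tokens: ')[0]
    }
  }
}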
102 changes: 99 additions & 3 deletions src/modules/llms/api/vertex.ts
@@ -1,9 +1,21 @@
-import axios from 'axios'
+import axios, { type AxiosResponse } from 'axios'
 import config from '../../../config'
-import { type ChatConversation } from '../../types'
+import { type OnMessageContext, type ChatConversation, type OnCallBackQueryData } from '../../types'
 import { type LlmCompletion } from './llmApi'
+import { type Readable } from 'stream'
+import { GrammyError } from 'grammy'
+import { pino } from 'pino'
+import { LlmsModelsEnum } from '../types'
 
-const API_ENDPOINT = config.llms.apiEndpoint // http://localhost:8080' // config.llms.apiEndpoint
+const API_ENDPOINT = config.llms.apiEndpoint // config.llms.apiEndpoint // 'http://127.0.0.1:5000' // config.llms.apiEndpoint
 
+const logger = pino({
+  name: 'Gemini - llmsBot',
+  transport: {
+    target: 'pino-pretty',
+    options: { colorize: true }
+  }
+})
+
export const vertexCompletion = async (
conversation: ChatConversation[],
@@ -35,3 +47,87 @@ export const vertexCompletion = async (
     price: 0
   }
 }
+
+export const vertexStreamCompletion = async (
+  conversation: ChatConversation[],
+  model = LlmsModelsEnum.CLAUDE_OPUS,
+  ctx: OnMessageContext | OnCallBackQueryData,
+  msgId: number,
+  limitTokens = true
+): Promise<LlmCompletion> => {
+  const data = {
+    model,
+    stream: true, // Set stream to true to receive the completion as a stream
+    system: config.openAi.chatGpt.chatCompletionContext,
+    max_tokens: limitTokens ? +config.openAi.chatGpt.maxTokens : undefined,
+    messages: conversation.map(m => { return { parts: { text: m.content }, role: m.role !== 'user' ? 'model' : 'user' } })
+  }
+  const url = `${API_ENDPOINT}/vertex/completions/gemini`
+  if (!ctx.chat?.id) {
+    throw new Error('Context chat id should not be empty after openAI streaming')
+  }
+  const response: AxiosResponse = await axios.post(url, data, { responseType: 'stream' })
+  // Create a Readable stream from the response
+  const completionStream: Readable = response.data
+  // Read and process the stream
+  let completion = ''
+  let outputTokens = ''
+  let inputTokens = ''
+  for await (const chunk of completionStream) {
+    const msg = chunk.toString()
+    if (msg) {
+      completion += msg // .split('Text: ')[1]
+      if (msg.includes('Input Token:')) {
+        const tokenMsg = msg.split('Input Token: ')[1]
+        inputTokens = tokenMsg.split('Output Tokens: ')[0]
+        outputTokens = tokenMsg.split('Output Tokens: ')[1]
+        completion = completion.split('Input Token: ')[0]
+      }
+      completion = completion.replaceAll('...', '')
+      completion += '...'
+      if (ctx.chat?.id) {
+        await ctx.api
+          .editMessageText(ctx.chat?.id, msgId, completion)
+          .catch(async (e: any) => {
+            if (e instanceof GrammyError) {
+              if (e.error_code !== 400) {
+                throw e
+              } else {
+                logger.error(e)
+              }
+            } else {
+              throw e
+            }
+          })
+      }
+    }
+  }
+  completion = completion.replaceAll('...', '')
+  await ctx.api
+    .editMessageText(ctx.chat?.id, msgId, completion)
+    .catch((e: any) => {
+      if (e instanceof GrammyError) {
+        if (e.error_code !== 400) {
+          throw e
+        } else {
+          logger.error(e)
+        }
+      } else {
+        throw e
+      }
+    })
+  const totalOutputTokens = outputTokens // response.headers['x-openai-output-tokens']
+  const totalInputTokens = inputTokens // response.headers['x-openai-input-tokens']
+
+  return {
+    completion: {
+      content: completion,
+      role: 'assistant',
+      model
+    },
+    usage: parseInt(totalOutputTokens, 10) + parseInt(totalInputTokens, 10),
+    price: 0,
+    inputTokens: parseInt(totalInputTokens, 10),
+    outputTokens: parseInt(totalOutputTokens, 10)
+  }
+}
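The new `vertexStreamCompletion` above posts the conversation to the proxy's `/vertex/completions/gemini` route with `stream: true` and remaps roles to Gemini's `user`/`model` pair, since Gemini has no `assistant` role. A hedged sketch of the request it builds; the endpoint value and the conversation contents here are illustrative, not from the PR:

import axios from 'axios'

const API_ENDPOINT = 'http://127.0.0.1:5000' // stand-in for config.llms.apiEndpoint

// Hypothetical conversation to illustrate the payload mapping.
const conversation = [
  { role: 'user', content: 'One sentence on harmony.' },
  { role: 'assistant', content: 'Harmony is balance in motion.' },
  { role: 'user', content: 'Shorter, please.' }
]

const data = {
  model: 'gemini-1.0-pro',
  stream: true,
  messages: conversation.map(m => ({
    parts: { text: m.content },
    role: m.role !== 'user' ? 'model' : 'user' // Gemini expects 'model', not 'assistant'
  }))
}

// responseType: 'stream' yields a Node Readable that the loop consumes
// chunk by chunk, editing the Telegram message in place as text arrives.
const response = await axios.post(`${API_ENDPOINT}/vertex/completions/gemini`, data, { responseType: 'stream' })

The trailing `...` appended on each pass acts as a lightweight typing indicator; it is stripped with `replaceAll('...', '')` before every edit and once more after the stream ends, and GrammyError 400s from no-op message edits are logged rather than rethrown.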
17 changes: 15 additions & 2 deletions src/modules/llms/helpers.ts
@@ -26,13 +26,16 @@ export enum SupportedCommands {
   j2Ultra = 'j2-ultra',
   sum = 'sum',
   ctx = 'ctx',
-  pdf = 'pdf'
+  pdf = 'pdf',
+  gemini = 'gemini',
+  gShort = 'g'
 }
 
 export const MAX_TRIES = 3
 const LLAMA_PREFIX_LIST = ['* ']
 const BARD_PREFIX_LIST = ['b. ', 'B. ']
 const CLAUDE_OPUS_PREFIX_LIST = ['c. ']
+const GEMINI_PREFIX_LIST = ['g. ']
 
 export const isMentioned = (
   ctx: OnMessageContext | OnCallBackQueryData
@@ -80,6 +83,16 @@ export const hasClaudeOpusPrefix = (prompt: string): string => {
   return ''
 }
 
+export const hasGeminiPrefix = (prompt: string): string => {
+  const prefixList = GEMINI_PREFIX_LIST
+  for (let i = 0; i < prefixList.length; i++) {
+    if (prompt.toLocaleLowerCase().startsWith(prefixList[i])) {
+      return prefixList[i]
+    }
+  }
+  return ''
+}
+
 export const hasUrl = (
   ctx: OnMessageContext | OnCallBackQueryData,
   prompt: string
@@ -211,7 +224,7 @@ export const sendMessage = async (

 export const hasPrefix = (prompt: string): string => {
   return (
-    hasBardPrefix(prompt) || hasLlamaPrefix(prompt) || hasClaudeOpusPrefix(prompt)
+    hasBardPrefix(prompt) || hasLlamaPrefix(prompt) || hasClaudeOpusPrefix(prompt) || hasGeminiPrefix(prompt)
   )
 }

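With the new commands and `GEMINI_PREFIX_LIST`, a message can reach Gemini three ways: `/gemini`, the short `/g`, or a `g. ` prefix on plain text, and `hasPrefix` now recognizes all four prefix families. Expected behavior of the helpers above (inputs are hypothetical):

import { hasGeminiPrefix, hasPrefix } from './helpers'

// The prompt is lowercased before matching, so 'G. ' matches too;
// the matched list entry (or '') is returned.
hasGeminiPrefix('g. summarize this thread') // => 'g. '
hasGeminiPrefix('G. summarize this thread') // => 'g. '
hasGeminiPrefix('summarize this thread')    // => ''

// hasPrefix falls through bard → llama → claude opus → gemini
// and returns the first match, or '' when none apply.
hasPrefix('g. hello') // => 'g. '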
34 changes: 25 additions & 9 deletions src/modules/llms/index.ts
@@ -22,6 +22,7 @@
   getPromptPrice,
   hasBardPrefix,
   hasClaudeOpusPrefix,
+  hasGeminiPrefix,
   hasLlamaPrefix,
   hasPrefix,
   hasUrl,
@@ -32,7 +33,7 @@
   SupportedCommands
 } from './helpers'
 import { getUrlFromText, preparePrompt, sendMessage } from '../open-ai/helpers'
-import { vertexCompletion } from './api/vertex'
+import { vertexCompletion, vertexStreamCompletion } from './api/vertex'
 import { type LlmCompletion, llmCompletion, llmCheckCollectionStatus, queryUrlDocument, deleteCollection } from './api/llmApi'
 import { LlmsModelsEnum } from './types'
 import * as Sentry from '@sentry/node'
@@ -129,6 +130,10 @@ export class LlmsBot implements PayableBot {
       await this.onChat(ctx, LlmsModelsEnum.BISON)
       return
     }
+    if (ctx.hasCommand([SupportedCommands.gemini, SupportedCommands.gShort]) || (hasGeminiPrefix(ctx.message?.text ?? '') !== '')) {
+      await this.onChat(ctx, LlmsModelsEnum.GEMINI)
+      return
+    }
     if (ctx.hasCommand([SupportedCommands.claudeOpus, SupportedCommands.opus, SupportedCommands.opusShort]) || (hasClaudeOpusPrefix(ctx.message?.text ?? '') !== '')) {
       await this.onChat(ctx, LlmsModelsEnum.CLAUDE_OPUS)
       return
@@ -567,13 +572,23 @@
     if (isTypingEnabled) {
       ctx.chatAction = 'typing'
     }
-    const completion = await anthropicStreamCompletion(
-      conversation,
-      model as LlmsModelsEnum,
-      ctx,
-      msgId,
-      true // telegram messages has a character limit
-    )
+    let completion: LlmCompletion
+    if (model === LlmsModelsEnum.GEMINI) {
+      completion = await vertexStreamCompletion(conversation,
+        model as LlmsModelsEnum,
+        ctx,
+        msgId,
+        true // telegram messages has a character limit
+      )
+    } else {
+      completion = await anthropicStreamCompletion(
+        conversation,
+        model as LlmsModelsEnum,
+        ctx,
+        msgId,
+        true // telegram messages has a character limit
+      )
+    }
     if (isTypingEnabled) {
       ctx.chatAction = null
     }
@@ -754,7 +769,7 @@
       ctx
     }
     let result: { price: number, chat: ChatConversation[] } = { price: 0, chat: [] }
-    if (model === LlmsModelsEnum.CLAUDE_OPUS || model === LlmsModelsEnum.CLAUDE_SONNET) {
+    if (model === LlmsModelsEnum.CLAUDE_OPUS || model === LlmsModelsEnum.CLAUDE_SONNET || model === LlmsModelsEnum.GEMINI) {
       result = await this.completionGen(payload) // , prompt.msgId, prompt.outputFormat)
     } else {
       result = await this.promptGen(payload)
@@ -873,6 +888,7 @@
         ctx.transient.analytics.actualResponseTime = now()
       }
     } else if (e instanceof AxiosError) {
+      this.logger.error(`${e.message}`)
       await sendMessage(ctx, 'Error handling your request').catch(async (e) => {
         await this.onError(ctx, e, retryCount - 1)
       })
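The completionGen change above makes the streaming call fork on the model instead of hard-coding Anthropic: GEMINI goes through `vertexStreamCompletion`, everything else keeps `anthropicStreamCompletion`. The merged if/else could be condensed to a ternary; the following is an equivalent sketch, not the code as merged:

const completion: LlmCompletion = model === LlmsModelsEnum.GEMINI
  ? await vertexStreamCompletion(conversation, model as LlmsModelsEnum, ctx, msgId, true)
  : await anthropicStreamCompletion(conversation, model as LlmsModelsEnum, ctx, msgId, true)

Routing is otherwise unchanged: onChat sends CLAUDE_OPUS, CLAUDE_SONNET, and now GEMINI through completionGen, while other models still take the promptGen path.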
10 changes: 9 additions & 1 deletion src/modules/llms/types.ts
@@ -6,7 +6,8 @@ export enum LlmsModelsEnum {
   J2_ULTRA = 'j2-ultra',
   CLAUDE_OPUS = 'claude-3-opus-20240229',
   CLAUDE_SONNET = 'claude-3-sonnet-20240229',
-  CLAUDE_HAIKU = 'claude-3-haiku-20240307'
+  CLAUDE_HAIKU = 'claude-3-haiku-20240307',
+  GEMINI = 'gemini-1.0-pro'
 }
 
 export const LlmsModels: Record<string, ChatModel> = {
@@ -17,6 +18,13 @@ export const LlmsModels: Record<string, ChatModel> = {
     maxContextTokens: 8192,
     chargeType: 'CHAR'
   },
+  'gemini-1.0-pro': {
+    name: 'gemini-1.0-pro',
+    inputPrice: 0.00025, // 3.00 (1M Tokens) => 0.003 (1K tokens)
+    outputPrice: 0.00125,
+    maxContextTokens: 4096,
+    chargeType: 'CHAR'
+  },
   'gpt-4-32k': {
     name: 'gpt-4-32k',
     inputPrice: 0.06, // 6
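Reading the new entry's rates as dollars per 1K tokens (the convention suggested by the comments on neighboring entries; note `chargeType: 'CHAR'` means the bot's own billing is per character), a Gemini call would cost roughly:

// Hedged pricing sketch using the per-1K-token reading of the new entry.
const gemini = { inputPrice: 0.00025, outputPrice: 0.00125 }

const priceUsd = (inputTokens: number, outputTokens: number): number =>
  (inputTokens / 1000) * gemini.inputPrice + (outputTokens / 1000) * gemini.outputPrice

// e.g. a 2,000-token prompt with an 800-token reply:
priceUsd(2000, 800) // => 0.0005 + 0.001 = 0.0015 USD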