Merge pull request #359 from harmony-one/gemini
Gemini
fegloff authored Apr 2, 2024
2 parents 6dac718 + 8d52ff7 commit 002dc84
Showing 5 changed files with 158 additions and 20 deletions.
15 changes: 10 additions & 5 deletions src/modules/llms/api/athropic.ts
@@ -16,7 +16,7 @@ const logger = pino({
}
})

-const API_ENDPOINT = config.llms.apiEndpoint // 'http://127.0.0.1:5000' // config.llms.apiEndpoint
+const API_ENDPOINT = config.llms.apiEndpoint // config.llms.apiEndpoint // 'http://127.0.0.1:5000' // config.llms.apiEndpoint

export const anthropicCompletion = async (
conversation: ChatConversation[],
@@ -88,9 +88,16 @@ export const anthropicStreamCompletion = async (
if (msg) {
if (msg.startsWith('Input Token')) {
inputTokens = msg.split('Input Token: ')[1]
-} else if (msg.startsWith('Text')) {
+} else if (msg.startsWith('Output Tokens')) {
+  outputTokens = msg.split('Output Tokens: ')[1]
+} else {
  wordCount++
-  completion += msg.split('Text: ')[1]
+  completion += msg // .split('Text: ')[1]
+  if (msg.includes('Output Tokens:')) {
+    const tokenMsg = msg.split('Output Tokens: ')[1]
+    outputTokens = tokenMsg.split('Output Tokens: ')[1]
+    completion = completion.split('Output Tokens: ')[0]
+  }
if (wordCount > wordCountMinimum) { // if (chunck === '.' && wordCount > wordCountMinimum) {
if (wordCountMinimum < 64) {
wordCountMinimum *= 2
@@ -114,8 +121,6 @@ export const anthropicStreamCompletion = async (
})
}
}
-} else if (msg.startsWith('Output Tokens')) {
-  outputTokens = msg.split('Output Tokens: ')[1]
}
}
}
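Taken together, these two hunks reorder the stream parser: the `Output Tokens` branch moves ahead of the text branch, and text chunks are now appended verbatim instead of requiring a `Text: ` prefix, with any token marker glued to a chunk's tail peeled back out of the accumulated completion. Below is a minimal sketch of that marker handling, assuming (as the diff does) that the backend streams plain-text chunks containing inline `Input Token: N` and `Output Tokens: N` markers; the `StreamState` shape and `applyChunk` name are illustrative, not part of the codebase, and the sketch takes the count directly after the marker where the hunk re-splits `tokenMsg` a second time.

```ts
// Illustrative sketch: marker-aware accumulation for a streamed completion.
interface StreamState {
  completion: string
  inputTokens: string
  outputTokens: string
}

const applyChunk = (state: StreamState, msg: string): StreamState => {
  if (msg.startsWith('Input Token')) {
    // usage marker only: record it, emit no text
    return { ...state, inputTokens: msg.split('Input Token: ')[1] }
  }
  if (msg.startsWith('Output Tokens')) {
    return { ...state, outputTokens: msg.split('Output Tokens: ')[1] }
  }
  // text chunk: append verbatim, then strip a marker stuck to its tail
  let completion = state.completion + msg
  let outputTokens = state.outputTokens
  if (msg.includes('Output Tokens:')) {
    outputTokens = msg.split('Output Tokens: ')[1]
    completion = completion.split('Output Tokens: ')[0]
  }
  return { ...state, completion, outputTokens }
}
```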
102 changes: 99 additions & 3 deletions src/modules/llms/api/vertex.ts
@@ -1,9 +1,21 @@
-import axios from 'axios'
+import axios, { type AxiosResponse } from 'axios'
import config from '../../../config'
-import { type ChatConversation } from '../../types'
+import { type OnMessageContext, type ChatConversation, type OnCallBackQueryData } from '../../types'
import { type LlmCompletion } from './llmApi'
+import { type Readable } from 'stream'
+import { GrammyError } from 'grammy'
+import { pino } from 'pino'
+import { LlmsModelsEnum } from '../types'

-const API_ENDPOINT = config.llms.apiEndpoint // http://localhost:8080' // config.llms.apiEndpoint
+const API_ENDPOINT = config.llms.apiEndpoint // config.llms.apiEndpoint // 'http://127.0.0.1:5000' // config.llms.apiEndpoint

const logger = pino({
name: 'Gemini - llmsBot',
transport: {
target: 'pino-pretty',
options: { colorize: true }
}
})

export const vertexCompletion = async (
conversation: ChatConversation[],
@@ -35,3 +47,87 @@ export const vertexCompletion = async (
price: 0
}
}

export const vertexStreamCompletion = async (
conversation: ChatConversation[],
model = LlmsModelsEnum.CLAUDE_OPUS,
ctx: OnMessageContext | OnCallBackQueryData,
msgId: number,
limitTokens = true
): Promise<LlmCompletion> => {
const data = {
model,
stream: true, // Set stream to true to receive the completion as a stream
system: config.openAi.chatGpt.chatCompletionContext,
max_tokens: limitTokens ? +config.openAi.chatGpt.maxTokens : undefined,
messages: conversation.map(m => { return { parts: { text: m.content }, role: m.role !== 'user' ? 'model' : 'user' } })
}
const url = `${API_ENDPOINT}/vertex/completions/gemini`
if (!ctx.chat?.id) {
throw new Error('Context chat id should not be empty after openAI streaming')
}
const response: AxiosResponse = await axios.post(url, data, { responseType: 'stream' })
// Create a Readable stream from the response
const completionStream: Readable = response.data
// Read and process the stream
let completion = ''
let outputTokens = ''
let inputTokens = ''
for await (const chunk of completionStream) {
const msg = chunk.toString()
if (msg) {
completion += msg // .split('Text: ')[1]
if (msg.includes('Input Token:')) {
const tokenMsg = msg.split('Input Token: ')[1]
inputTokens = tokenMsg.split('Output Tokens: ')[0]
outputTokens = tokenMsg.split('Output Tokens: ')[1]
completion = completion.split('Input Token: ')[0]
}
completion = completion.replaceAll('...', '')
completion += '...'
if (ctx.chat?.id) {
await ctx.api
.editMessageText(ctx.chat?.id, msgId, completion)
.catch(async (e: any) => {
if (e instanceof GrammyError) {
if (e.error_code !== 400) {
throw e
} else {
logger.error(e)
}
} else {
throw e
}
})
}
}
}
completion = completion.replaceAll('...', '')
await ctx.api
.editMessageText(ctx.chat?.id, msgId, completion)
.catch((e: any) => {
if (e instanceof GrammyError) {
if (e.error_code !== 400) {
throw e
} else {
logger.error(e)
}
} else {
throw e
}
})
const totalOutputTokens = outputTokens // response.headers['x-openai-output-tokens']
const totalInputTokens = inputTokens // response.headers['x-openai-input-tokens']

return {
completion: {
content: completion,
role: 'assistant',
model
},
usage: parseInt(totalOutputTokens, 10) + parseInt(totalInputTokens, 10),
price: 0,
inputTokens: parseInt(totalInputTokens, 10),
outputTokens: parseInt(totalOutputTokens, 10)
}
}
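The request body built at the top of `vertexStreamCompletion` converts the bot's OpenAI-style conversation into Gemini's wire shape: a `parts: { text }` payload and a two-valued role where everything that is not `user` becomes `model`. A self-contained sketch of just that mapping (the local `ChatConversation` here is a simplified stand-in for the type in `src/modules/types`):

```ts
// Simplified stand-ins for the repo's types, for illustration only.
interface ChatConversation { role: 'user' | 'assistant' | 'system'; content: string }
interface GeminiMessage { parts: { text: string }; role: 'user' | 'model' }

const toGeminiMessages = (conversation: ChatConversation[]): GeminiMessage[] =>
  conversation.map(m => ({
    parts: { text: m.content },
    role: m.role !== 'user' ? 'model' : 'user' // assistant/system both map to 'model'
  }))

// toGeminiMessages([{ role: 'user', content: 'hi' }, { role: 'assistant', content: 'hello' }])
// -> [{ parts: { text: 'hi' }, role: 'user' }, { parts: { text: 'hello' }, role: 'model' }]
```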
17 changes: 15 additions & 2 deletions src/modules/llms/helpers.ts
@@ -26,13 +26,16 @@ export enum SupportedCommands {
j2Ultra = 'j2-ultra',
sum = 'sum',
ctx = 'ctx',
-  pdf = 'pdf'
+  pdf = 'pdf',
+  gemini = 'gemini',
+  gShort = 'g'
}

export const MAX_TRIES = 3
const LLAMA_PREFIX_LIST = ['* ']
const BARD_PREFIX_LIST = ['b. ', 'B. ']
const CLAUDE_OPUS_PREFIX_LIST = ['c. ']
const GEMINI_PREFIX_LIST = ['g. ']

export const isMentioned = (
ctx: OnMessageContext | OnCallBackQueryData
@@ -80,6 +83,16 @@ export const hasClaudeOpusPrefix = (prompt: string): string => {
return ''
}

export const hasGeminiPrefix = (prompt: string): string => {
const prefixList = GEMINI_PREFIX_LIST
for (let i = 0; i < prefixList.length; i++) {
if (prompt.toLocaleLowerCase().startsWith(prefixList[i])) {
return prefixList[i]
}
}
return ''
}

export const hasUrl = (
ctx: OnMessageContext | OnCallBackQueryData,
prompt: string
@@ -211,7 +224,7 @@ export const sendMessage = async (

export const hasPrefix = (prompt: string): string => {
return (
-    hasBardPrefix(prompt) || hasLlamaPrefix(prompt) || hasClaudeOpusPrefix(prompt)
+    hasBardPrefix(prompt) || hasLlamaPrefix(prompt) || hasClaudeOpusPrefix(prompt) || hasGeminiPrefix(prompt)
)
}

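With this change a leading `g. ` joins `b. ` (Bard) and `c. ` (Claude Opus) as a model-routing shortcut. `hasGeminiPrefix` returns the matched prefix rather than a boolean so a caller can both test for it and strip it; the compact restatement below behaves the same, and the `stripPrefix` helper at the end is illustrative, not part of the diff:

```ts
const GEMINI_PREFIX_LIST = ['g. '] as const

// Returns the matched prefix, or '' when the prompt is not Gemini-routed.
const hasGeminiPrefix = (prompt: string): string => {
  const lower = prompt.toLocaleLowerCase()
  return GEMINI_PREFIX_LIST.find(p => lower.startsWith(p)) ?? ''
}

// Illustrative helper (not in the diff): drop the routing prefix before
// the prompt is sent to the model.
const stripPrefix = (prompt: string): string =>
  prompt.slice(hasGeminiPrefix(prompt).length)

console.log(hasGeminiPrefix('g. what is Harmony?')) // 'g. '
console.log(stripPrefix('g. what is Harmony?'))     // 'what is Harmony?'
```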
34 changes: 25 additions & 9 deletions src/modules/llms/index.ts
@@ -22,6 +22,7 @@ import {
getPromptPrice,
hasBardPrefix,
hasClaudeOpusPrefix,
hasGeminiPrefix,
hasLlamaPrefix,
hasPrefix,
hasUrl,
@@ -32,7 +33,7 @@
SupportedCommands
} from './helpers'
import { getUrlFromText, preparePrompt, sendMessage } from '../open-ai/helpers'
-import { vertexCompletion } from './api/vertex'
+import { vertexCompletion, vertexStreamCompletion } from './api/vertex'
import { type LlmCompletion, llmCompletion, llmCheckCollectionStatus, queryUrlDocument, deleteCollection } from './api/llmApi'
import { LlmsModelsEnum } from './types'
import * as Sentry from '@sentry/node'
@@ -129,6 +130,10 @@ export class LlmsBot implements PayableBot {
await this.onChat(ctx, LlmsModelsEnum.BISON)
return
}
if (ctx.hasCommand([SupportedCommands.gemini, SupportedCommands.gShort]) || (hasGeminiPrefix(ctx.message?.text ?? '') !== '')) {
await this.onChat(ctx, LlmsModelsEnum.GEMINI)
return
}
if (ctx.hasCommand([SupportedCommands.claudeOpus, SupportedCommands.opus, SupportedCommands.opusShort]) || (hasClaudeOpusPrefix(ctx.message?.text ?? '') !== '')) {
await this.onChat(ctx, LlmsModelsEnum.CLAUDE_OPUS)
return
@@ -567,13 +572,23 @@
if (isTypingEnabled) {
ctx.chatAction = 'typing'
}
-const completion = await anthropicStreamCompletion(
-  conversation,
-  model as LlmsModelsEnum,
-  ctx,
-  msgId,
-  true // telegram messages has a character limit
-)
+let completion: LlmCompletion
+if (model === LlmsModelsEnum.GEMINI) {
+  completion = await vertexStreamCompletion(conversation,
+    model as LlmsModelsEnum,
+    ctx,
+    msgId,
+    true // telegram messages has a character limit
+  )
+} else {
+  completion = await anthropicStreamCompletion(
+    conversation,
+    model as LlmsModelsEnum,
+    ctx,
+    msgId,
+    true // telegram messages has a character limit
+  )
+}
if (isTypingEnabled) {
ctx.chatAction = null
}
@@ -754,7 +769,7 @@
ctx
}
let result: { price: number, chat: ChatConversation[] } = { price: 0, chat: [] }
-    if (model === LlmsModelsEnum.CLAUDE_OPUS || model === LlmsModelsEnum.CLAUDE_SONNET) {
+    if (model === LlmsModelsEnum.CLAUDE_OPUS || model === LlmsModelsEnum.CLAUDE_SONNET || model === LlmsModelsEnum.GEMINI) {
result = await this.completionGen(payload) // , prompt.msgId, prompt.outputFormat)
} else {
result = await this.promptGen(payload)
@@ -873,6 +888,7 @@
ctx.transient.analytics.actualResponseTime = now()
}
} else if (e instanceof AxiosError) {
this.logger.error(`${e.message}`)
await sendMessage(ctx, 'Error handling your request').catch(async (e) => {
await this.onError(ctx, e, retryCount - 1)
})
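In sum, the routing layer now sends `/gemini`, `/g`, and `g. `-prefixed prompts to `onChat` with the Gemini model; `onChat` funnels Gemini into the same streaming path as the Claude models, and inside that path an if/else picks `vertexStreamCompletion` over `anthropicStreamCompletion`. A sketch of the selection logic follows; the enum values are copied from `types.ts`, while the set-based check is an illustrative restatement of the diff's `if` chain:

```ts
enum LlmsModelsEnum {
  CLAUDE_OPUS = 'claude-3-opus-20240229',
  CLAUDE_SONNET = 'claude-3-sonnet-20240229',
  GEMINI = 'gemini-1.0-pro'
}

// Models whose replies are streamed and live-edited into the Telegram message.
const STREAMING_MODELS = new Set<string>([
  LlmsModelsEnum.CLAUDE_OPUS,
  LlmsModelsEnum.CLAUDE_SONNET,
  LlmsModelsEnum.GEMINI
])

// completionGen serves streaming models; everything else goes to promptGen.
// Within the streaming path, GEMINI selects vertexStreamCompletion and the
// Claude models select anthropicStreamCompletion, as in the hunk above.
const usesStreamingPath = (model: string): boolean => STREAMING_MODELS.has(model)
```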
10 changes: 9 additions & 1 deletion src/modules/llms/types.ts
@@ -6,7 +6,8 @@ export enum LlmsModelsEnum {
J2_ULTRA = 'j2-ultra',
CLAUDE_OPUS = 'claude-3-opus-20240229',
CLAUDE_SONNET = 'claude-3-sonnet-20240229',
-  CLAUDE_HAIKU = 'claude-3-haiku-20240307'
+  CLAUDE_HAIKU = 'claude-3-haiku-20240307',
+  GEMINI = 'gemini-1.0-pro'
}

export const LlmsModels: Record<string, ChatModel> = {
@@ -17,6 +18,13 @@ export const LlmsModels: Record<string, ChatModel> = {
maxContextTokens: 8192,
chargeType: 'CHAR'
},
'gemini-1.0-pro': {
name: 'gemini-1.0-pro',
inputPrice: 0.00025, // 3.00 (1M Tokens) => 0.003 (1K tokens)
outputPrice: 0.00125,
maxContextTokens: 4096,
chargeType: 'CHAR'
},
'gpt-4-32k': {
name: 'gpt-4-32k',
inputPrice: 0.06, // 6
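The new model entry bills per character (`chargeType: 'CHAR'`) at per-1K unit prices; note the inline comment, carried over from the token-priced entries, reads `3.00 (1M Tokens) => 0.003 (1K tokens)` and does not match the `0.00025` on its line, which corresponds to 0.25 per 1M units. A worked sketch of the arithmetic, assuming prices apply per 1,000 charged units (the repo's actual charging goes through `getPromptPrice` and may round differently):

```ts
// Illustrative pricing math for the 'gemini-1.0-pro' entry above.
const GEMINI_PRICING = {
  inputPrice: 0.00025,  // $ per 1K charged units (0.25 per 1M)
  outputPrice: 0.00125, // $ per 1K charged units (1.25 per 1M)
  chargeType: 'CHAR' as const
}

const promptCost = (inputUnits: number, outputUnits: number): number =>
  (inputUnits / 1000) * GEMINI_PRICING.inputPrice +
  (outputUnits / 1000) * GEMINI_PRICING.outputPrice

// 4,000 chars in and 1,000 chars out:
// 4 * 0.00025 + 1 * 0.00125 = $0.00225
console.log(promptCost(4000, 1000).toFixed(5)) // '0.00225'
```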
