Llms refactoring #360

Merged: 10 commits, Apr 2, 2024
21 changes: 13 additions & 8 deletions src/modules/llms/api/athropic.ts
@@ -16,20 +16,20 @@ const logger = pino({
}
})

const API_ENDPOINT = config.llms.apiEndpoint // 'http://127.0.0.1:5000' // config.llms.apiEndpoint
const API_ENDPOINT = config.llms.apiEndpoint // config.llms.apiEndpoint // 'http://127.0.0.1:5000' // config.llms.apiEndpoint

export const anthropicCompletion = async (
conversation: ChatConversation[],
model = LlmsModelsEnum.CLAUDE_OPUS
): Promise<LlmCompletion> => {
logger.info(`Handling ${model} completion`)

const data = {
model,
stream: false,
system: config.openAi.chatGpt.chatCompletionContext,
max_tokens: +config.openAi.chatGpt.maxTokens,
messages: conversation
messages: conversation.filter(c => c.model === model)
.map(m => { return { content: m.content, role: m.role } })
}
const url = `${API_ENDPOINT}/anthropic/completions`
const response = await axios.post(url, data)
@@ -68,7 +68,7 @@ export const anthropicStreamCompletion = async (
stream: true, // Set stream to true to receive the completion as a stream
system: config.openAi.chatGpt.chatCompletionContext,
max_tokens: limitTokens ? +config.openAi.chatGpt.maxTokens : undefined,
messages: conversation.map(m => { return { content: m.content, role: m.role } })
messages: conversation.filter(c => c.model === model).map(m => { return { content: m.content, role: m.role } })
}
let wordCount = 0
let wordCountMinimum = 2
@@ -88,9 +88,16 @@
if (msg) {
if (msg.startsWith('Input Token')) {
inputTokens = msg.split('Input Token: ')[1]
} else if (msg.startsWith('Text')) {
} else if (msg.startsWith('Output Tokens')) {
outputTokens = msg.split('Output Tokens: ')[1]
} else {
wordCount++
completion += msg.split('Text: ')[1]
completion += msg // .split('Text: ')[1]
if (msg.includes('Output Tokens:')) {
const tokenMsg = msg.split('Output Tokens: ')[1]
outputTokens = tokenMsg.split('Output Tokens: ')[1]
completion = completion.split('Output Tokens: ')[0]
}
if (wordCount > wordCountMinimum) { // if (chunck === '.' && wordCount > wordCountMinimum) {
if (wordCountMinimum < 64) {
wordCountMinimum *= 2
@@ -114,8 +121,6 @@ export const anthropicStreamCompletion = async (
})
}
}
} else if (msg.startsWith('Output Tokens')) {
outputTokens = msg.split('Output Tokens: ')[1]
}
}
}
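For reference, the streaming branch above accumulates raw chunks and reads token counts out of in-band markers rather than response headers. Below is a minimal standalone sketch of that parsing approach; the marker strings ('Input Token: ', 'Output Tokens: ') and plain-text chunk shape are assumptions read off this diff, not a documented API contract.

```typescript
import { type Readable } from 'stream'

// Sketch only: collect streamed completion text and strip in-band token markers.
export const collectStreamedCompletion = async (
  stream: Readable
): Promise<{ completion: string, inputTokens: number, outputTokens: number }> => {
  let completion = ''
  let inputTokens = 0
  let outputTokens = 0
  for await (const chunk of stream) {
    const msg: string = chunk.toString()
    if (msg.startsWith('Input Token: ')) {
      inputTokens = parseInt(msg.split('Input Token: ')[1], 10)
    } else if (msg.startsWith('Output Tokens: ')) {
      outputTokens = parseInt(msg.split('Output Tokens: ')[1], 10)
    } else {
      // Anything without a marker is completion text.
      completion += msg
    }
  }
  return { completion, inputTokens, outputTokens }
}
```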
2 changes: 1 addition & 1 deletion src/modules/llms/api/llmApi.ts
@@ -99,7 +99,7 @@ export const llmCompletion = async (
const data = {
model, // chat-bison@001 'chat-bison', //'gpt-3.5-turbo',
stream: false,
messages: conversation
messages: conversation.filter(c => c.model === model)
}
const url = `${API_ENDPOINT}/llms/completions`
const response = await axios.post(url, data)
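The one-line change here is the same pattern applied across the other API clients in this PR: only the history that belongs to the requested model is sent. A hedged sketch of that filtering step, using a simplified message shape (the real type is ChatConversation in src/modules/types):

```typescript
// Simplified shape for illustration only.
interface ChatMessage {
  content: string
  role: string
  model?: string
}

// Keep only the turns that belong to the model being called, so e.g. Claude
// history is not replayed to a Gemini request and vice versa.
const messagesForModel = (
  conversation: ChatMessage[],
  model: string
): Array<{ content: string, role: string }> =>
  conversation
    .filter(m => m.model === model)
    .map(m => ({ content: m.content, role: m.role }))
```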
117 changes: 112 additions & 5 deletions src/modules/llms/api/vertex.ts
@@ -1,9 +1,21 @@
import axios from 'axios'
import axios, { type AxiosResponse } from 'axios'
import config from '../../../config'
import { type ChatConversation } from '../../types'
import { type OnMessageContext, type ChatConversation, type OnCallBackQueryData } from '../../types'
import { type LlmCompletion } from './llmApi'
import { type Readable } from 'stream'
import { GrammyError } from 'grammy'
import { pino } from 'pino'
import { LlmsModelsEnum } from '../types'

const API_ENDPOINT = config.llms.apiEndpoint // http://localhost:8080' // config.llms.apiEndpoint
const API_ENDPOINT = config.llms.apiEndpoint // config.llms.apiEndpoint // 'http://127.0.0.1:5000' // config.llms.apiEndpoint

const logger = pino({
name: 'Gemini - llmsBot',
transport: {
target: 'pino-pretty',
options: { colorize: true }
}
})

export const vertexCompletion = async (
conversation: ChatConversation[],
@@ -12,8 +24,18 @@ export const vertexCompletion = async (
const data = {
model, // chat-bison@001 'chat-bison', //'gpt-3.5-turbo',
stream: false,
messages: conversation
messages: conversation.filter(c => c.model === model)
.map((msg) => {
const msgFiltered: ChatConversation = { content: msg.content, model: msg.model }
if (model === LlmsModelsEnum.BISON) {
msgFiltered.author = msg.role
} else {
msgFiltered.role = msg.role
}
return msgFiltered
})
}

const url = `${API_ENDPOINT}/vertex/completions`
const response = await axios.post(url, data)
if (response) {
@@ -22,7 +44,7 @@ export const vertexCompletion = async (
return {
completion: {
content: response.data._prediction_response[0][0].candidates[0].content,
author: 'bot',
role: 'bot', // role replace to author attribute will be done later
model
},
usage: totalOutputTokens + totalInputTokens,
@@ -35,3 +57,88 @@ export const vertexCompletion = async (
price: 0
}
}

export const vertexStreamCompletion = async (
conversation: ChatConversation[],
model = LlmsModelsEnum.CLAUDE_OPUS,
ctx: OnMessageContext | OnCallBackQueryData,
msgId: number,
limitTokens = true
): Promise<LlmCompletion> => {
const data = {
model,
stream: true, // Set stream to true to receive the completion as a stream
system: config.openAi.chatGpt.chatCompletionContext,
max_tokens: limitTokens ? +config.openAi.chatGpt.maxTokens : undefined,
messages: conversation.filter(c => c.model === model)
.map(m => { return { parts: { text: m.content }, role: m.role !== 'user' ? 'model' : 'user' } })
}
const url = `${API_ENDPOINT}/vertex/completions/gemini`
if (!ctx.chat?.id) {
throw new Error('Context chat id should not be empty after openAI streaming')
}
const response: AxiosResponse = await axios.post(url, data, { responseType: 'stream' })
// Create a Readable stream from the response
const completionStream: Readable = response.data
// Read and process the stream
let completion = ''
let outputTokens = ''
let inputTokens = ''
for await (const chunk of completionStream) {
const msg = chunk.toString()
if (msg) {
completion += msg // .split('Text: ')[1]
if (msg.includes('Input Token:')) {
const tokenMsg = msg.split('Input Token: ')[1]
inputTokens = tokenMsg.split('Output Tokens: ')[0]
outputTokens = tokenMsg.split('Output Tokens: ')[1]
completion = completion.split('Input Token: ')[0]
}
completion = completion.replaceAll('...', '')
completion += '...'
if (ctx.chat?.id) {
await ctx.api
.editMessageText(ctx.chat?.id, msgId, completion)
.catch(async (e: any) => {
if (e instanceof GrammyError) {
if (e.error_code !== 400) {
throw e
} else {
logger.error(e)
}
} else {
throw e
}
})
}
}
}
completion = completion.replaceAll('...', '')
await ctx.api
.editMessageText(ctx.chat?.id, msgId, completion)
.catch((e: any) => {
if (e instanceof GrammyError) {
if (e.error_code !== 400) {
throw e
} else {
logger.error(e)
}
} else {
throw e
}
})
const totalOutputTokens = outputTokens // response.headers['x-openai-output-tokens']
const totalInputTokens = inputTokens // response.headers['x-openai-input-tokens']

return {
completion: {
content: completion,
role: 'assistant',
model
},
usage: parseInt(totalOutputTokens, 10) + parseInt(totalInputTokens, 10),
price: 0,
inputTokens: parseInt(totalInputTokens, 10),
outputTokens: parseInt(totalOutputTokens, 10)
}
}
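A rough usage sketch for the new vertexStreamCompletion, assuming a handler that has already posted a placeholder message whose id is passed as msgId. The handler name, the GEMINI enum member, and the import paths are assumptions for illustration, not taken from this diff:

```typescript
import { vertexStreamCompletion } from './api/vertex'
import { LlmsModelsEnum } from './types'
import { type ChatConversation, type OnMessageContext } from '../types'

// Hypothetical handler: stream a Gemini reply into a previously sent placeholder message.
async function onGeminiPrompt (ctx: OnMessageContext, conversation: ChatConversation[]): Promise<void> {
  const placeholder = await ctx.reply('...')
  const result = await vertexStreamCompletion(
    conversation,
    LlmsModelsEnum.GEMINI, // assumed enum member; the diff only shows CLAUDE_OPUS as a default
    ctx,
    placeholder.message_id,
    true // limitTokens
  )
  // result.usage, result.inputTokens and result.outputTokens can feed the bot's accounting.
  console.log(`Gemini stream finished, usage: ${result.usage} tokens`)
}
```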
39 changes: 17 additions & 22 deletions src/modules/llms/helpers.ts
@@ -2,11 +2,9 @@ import {
type OnMessageContext,
type OnCallBackQueryData,
type MessageExtras,
type ChatConversation,
type ChatPayload
} from '../types'
import { type ParseMode } from 'grammy/types'
import { LlmsModelsEnum } from './types'
import { type Message } from 'grammy/out/types'
import { type LlmCompletion, getChatModel, llmAddUrlDocument } from './api/llmApi'
import { getChatModelPrice } from '../open-ai/api/openAi'
@@ -16,8 +14,9 @@ export enum SupportedCommands {
bardF = 'bard',
claudeOpus = 'claude',
opus = 'opus',
opusShort = 'o',
claudeSonnet = 'claudes',
opusShort = 'c',
claudeShort = 'c',
sonnet = 'sonnet',
sonnetShort = 's',
claudeHaiku = 'haiku',
@@ -26,13 +25,16 @@
j2Ultra = 'j2-ultra',
sum = 'sum',
ctx = 'ctx',
pdf = 'pdf'
pdf = 'pdf',
gemini = 'gemini',
gShort = 'g'
}

export const MAX_TRIES = 3
const LLAMA_PREFIX_LIST = ['* ']
const BARD_PREFIX_LIST = ['b. ', 'B. ']
const CLAUDE_OPUS_PREFIX_LIST = ['c. ']
const GEMINI_PREFIX_LIST = ['g. ']

export const isMentioned = (
ctx: OnMessageContext | OnCallBackQueryData
@@ -80,6 +82,16 @@ export const hasClaudeOpusPrefix = (prompt: string): string => {
return ''
}

export const hasGeminiPrefix = (prompt: string): string => {
const prefixList = GEMINI_PREFIX_LIST
for (let i = 0; i < prefixList.length; i++) {
if (prompt.toLocaleLowerCase().startsWith(prefixList[i])) {
return prefixList[i]
}
}
return ''
}

export const hasUrl = (
ctx: OnMessageContext | OnCallBackQueryData,
prompt: string
@@ -211,7 +223,7 @@ export const sendMessage = async (

export const hasPrefix = (prompt: string): string => {
return (
hasBardPrefix(prompt) || hasLlamaPrefix(prompt) || hasClaudeOpusPrefix(prompt)
hasBardPrefix(prompt) || hasLlamaPrefix(prompt) || hasClaudeOpusPrefix(prompt) || hasGeminiPrefix(prompt)
)
}

@@ -241,23 +253,6 @@ export const limitPrompt = (prompt: string): string => {
return prompt
}

export const prepareConversation = (
conversation: ChatConversation[],
model: string
): ChatConversation[] => {
return conversation
.filter((msg) => msg.model === model)
.map((msg) => {
const msgFiltered: ChatConversation = { content: msg.content }
if (model === LlmsModelsEnum.BISON) {
msgFiltered.author = msg.author
} else {
msgFiltered.role = msg.role
}
return msgFiltered
})
}

export function extractPdfFilename (url: string): string | null {
const matches = url.match(/\/([^/]+\.pdf)$/)
if (matches) {
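The new hasGeminiPrefix helper mirrors the existing Bard and Claude prefix checks, so a prompt starting with 'g. ' can be routed to Gemini. A hedged sketch of how these helpers might be combined in a dispatcher; the routing function itself is hypothetical, and it assumes the individual prefix helpers are exported from helpers.ts:

```typescript
import { hasBardPrefix, hasClaudeOpusPrefix, hasGeminiPrefix, hasLlamaPrefix } from './helpers'

// Hypothetical dispatcher: pick a backend from the prompt prefix and strip the prefix.
function routeByPrefix (prompt: string): { backend: string, prompt: string } {
  const gemini = hasGeminiPrefix(prompt) // 'g. '
  if (gemini !== '') return { backend: 'gemini', prompt: prompt.slice(gemini.length) }
  const claude = hasClaudeOpusPrefix(prompt) // 'c. '
  if (claude !== '') return { backend: 'claude', prompt: prompt.slice(claude.length) }
  const bard = hasBardPrefix(prompt) // 'b. ' or 'B. '
  if (bard !== '') return { backend: 'bard', prompt: prompt.slice(bard.length) }
  const llama = hasLlamaPrefix(prompt) // '* '
  if (llama !== '') return { backend: 'llama', prompt: prompt.slice(llama.length) }
  return { backend: 'default', prompt }
}
```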