Skip to content

Commit

Permalink
Merge pull request #360 from harmony-one/llms-refactoring
Browse files Browse the repository at this point in the history
Llms refactoring
  • Loading branch information
fegloff authored Apr 2, 2024
2 parents 002dc84 + 33e6fa8 commit 85fe199
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 40 deletions.
6 changes: 3 additions & 3 deletions src/modules/llms/api/athropic.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@ export const anthropicCompletion = async (
model = LlmsModelsEnum.CLAUDE_OPUS
): Promise<LlmCompletion> => {
logger.info(`Handling ${model} completion`)

const data = {
model,
stream: false,
system: config.openAi.chatGpt.chatCompletionContext,
max_tokens: +config.openAi.chatGpt.maxTokens,
messages: conversation
messages: conversation.filter(c => c.model === model)
.map(m => { return { content: m.content, role: m.role } })
}
const url = `${API_ENDPOINT}/anthropic/completions`
const response = await axios.post(url, data)
Expand Down Expand Up @@ -68,7 +68,7 @@ export const anthropicStreamCompletion = async (
stream: true, // Set stream to true to receive the completion as a stream
system: config.openAi.chatGpt.chatCompletionContext,
max_tokens: limitTokens ? +config.openAi.chatGpt.maxTokens : undefined,
messages: conversation.map(m => { return { content: m.content, role: m.role } })
messages: conversation.filter(c => c.model === model).map(m => { return { content: m.content, role: m.role } })
}
let wordCount = 0
let wordCountMinimum = 2
Expand Down
2 changes: 1 addition & 1 deletion src/modules/llms/api/llmApi.ts
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ export const llmCompletion = async (
const data = {
model, // chat-bison@001 'chat-bison', //'gpt-3.5-turbo',
stream: false,
messages: conversation
messages: conversation.filter(c => c.model === model)
}
const url = `${API_ENDPOINT}/llms/completions`
const response = await axios.post(url, data)
Expand Down
17 changes: 14 additions & 3 deletions src/modules/llms/api/vertex.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,18 @@ export const vertexCompletion = async (
const data = {
model, // chat-bison@001 'chat-bison', //'gpt-3.5-turbo',
stream: false,
messages: conversation
messages: conversation.filter(c => c.model === model)
.map((msg) => {
const msgFiltered: ChatConversation = { content: msg.content, model: msg.model }
if (model === LlmsModelsEnum.BISON) {
msgFiltered.author = msg.role
} else {
msgFiltered.role = msg.role
}
return msgFiltered
})
}

const url = `${API_ENDPOINT}/vertex/completions`
const response = await axios.post(url, data)
if (response) {
Expand All @@ -34,7 +44,7 @@ export const vertexCompletion = async (
return {
completion: {
content: response.data._prediction_response[0][0].candidates[0].content,
author: 'bot',
role: 'bot', // role replace to author attribute will be done later
model
},
usage: totalOutputTokens + totalInputTokens,
Expand All @@ -60,7 +70,8 @@ export const vertexStreamCompletion = async (
stream: true, // Set stream to true to receive the completion as a stream
system: config.openAi.chatGpt.chatCompletionContext,
max_tokens: limitTokens ? +config.openAi.chatGpt.maxTokens : undefined,
messages: conversation.map(m => { return { parts: { text: m.content }, role: m.role !== 'user' ? 'model' : 'user' } })
messages: conversation.filter(c => c.model === model)
.map(m => { return { parts: { text: m.content }, role: m.role !== 'user' ? 'model' : 'user' } })
}
const url = `${API_ENDPOINT}/vertex/completions/gemini`
if (!ctx.chat?.id) {
Expand Down
22 changes: 2 additions & 20 deletions src/modules/llms/helpers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,9 @@ import {
type OnMessageContext,
type OnCallBackQueryData,
type MessageExtras,
type ChatConversation,
type ChatPayload
} from '../types'
import { type ParseMode } from 'grammy/types'
import { LlmsModelsEnum } from './types'
import { type Message } from 'grammy/out/types'
import { type LlmCompletion, getChatModel, llmAddUrlDocument } from './api/llmApi'
import { getChatModelPrice } from '../open-ai/api/openAi'
Expand All @@ -16,8 +14,9 @@ export enum SupportedCommands {
bardF = 'bard',
claudeOpus = 'claude',
opus = 'opus',
opusShort = 'o',
claudeSonnet = 'claudes',
opusShort = 'c',
claudeShort = 'c',
sonnet = 'sonnet',
sonnetShort = 's',
claudeHaiku = 'haiku',
Expand Down Expand Up @@ -254,23 +253,6 @@ export const limitPrompt = (prompt: string): string => {
return prompt
}

// Builds the provider-specific message list for a completion request:
// keeps only the messages that belong to `model` and strips each one
// down to the minimal shape the downstream completion API expects.
// NOTE(review): removed by this commit — the filter/map logic was inlined
// into the individual API modules (athropic.ts, llmApi.ts, vertex.ts).
export const prepareConversation = (
  conversation: ChatConversation[],
  model: string
): ChatConversation[] => {
  return conversation
    // Drop messages generated for/by a different model's thread.
    .filter((msg) => msg.model === model)
    .map((msg) => {
      const msgFiltered: ChatConversation = { content: msg.content }
      // Vertex's BISON endpoint labels the sender with `author`;
      // every other supported model uses the OpenAI-style `role` field.
      if (model === LlmsModelsEnum.BISON) {
        msgFiltered.author = msg.author
      } else {
        msgFiltered.role = msg.role
      }
      return msgFiltered
    })
}

export function extractPdfFilename (url: string): string | null {
const matches = url.match(/\/([^/]+\.pdf)$/)
if (matches) {
Expand Down
22 changes: 9 additions & 13 deletions src/modules/llms/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ import {
isMentioned,
limitPrompt,
MAX_TRIES,
prepareConversation,
SupportedCommands
} from './helpers'
import { getUrlFromText, preparePrompt, sendMessage } from '../open-ai/helpers'
Expand Down Expand Up @@ -134,7 +133,7 @@ export class LlmsBot implements PayableBot {
await this.onChat(ctx, LlmsModelsEnum.GEMINI)
return
}
if (ctx.hasCommand([SupportedCommands.claudeOpus, SupportedCommands.opus, SupportedCommands.opusShort]) || (hasClaudeOpusPrefix(ctx.message?.text ?? '') !== '')) {
if (ctx.hasCommand([SupportedCommands.claudeOpus, SupportedCommands.opus, SupportedCommands.opusShort, SupportedCommands.claudeShort]) || (hasClaudeOpusPrefix(ctx.message?.text ?? '') !== '')) {
await this.onChat(ctx, LlmsModelsEnum.CLAUDE_OPUS)
return
}
Expand Down Expand Up @@ -601,7 +600,8 @@ export class LlmsBot implements PayableBot {
)
conversation.push({
role: 'assistant',
content: completion.completion?.content ?? ''
content: completion.completion?.content ?? '',
model
})
return {
price: price.price,
Expand All @@ -612,7 +612,8 @@ export class LlmsBot implements PayableBot {
const response = await anthropicCompletion(conversation, model as LlmsModelsEnum)
conversation.push({
role: 'assistant',
content: response.completion?.content ?? ''
content: response.completion?.content ?? '',
model
})
return {
price: response.price,
Expand Down Expand Up @@ -644,13 +645,12 @@ export class LlmsBot implements PayableBot {
usage: 0,
price: 0
}
const chat = prepareConversation(conversation, model)
if (model === LlmsModelsEnum.BISON) {
response = await vertexCompletion(chat, model) // "chat-bison@001");
response = await vertexCompletion(conversation, model) // "chat-bison@001");
} else if (model === LlmsModelsEnum.CLAUDE_OPUS || model === LlmsModelsEnum.CLAUDE_SONNET || model === LlmsModelsEnum.CLAUDE_HAIKU) {
response = await anthropicCompletion(chat, model)
response = await anthropicCompletion(conversation, model)
} else {
response = await llmCompletion(chat, model as LlmsModelsEnum)
response = await llmCompletion(conversation, model as LlmsModelsEnum)
}
if (response.completion) {
await ctx.api.editMessageText(
Expand Down Expand Up @@ -755,13 +755,9 @@ export class LlmsBot implements PayableBot {
}
const chat: ChatConversation = {
content: limitPrompt(prompt),
role: 'user',
model
}
if (model === LlmsModelsEnum.BISON) {
chat.author = 'user'
} else {
chat.role = 'user'
}
chatConversation.push(chat)
const payload = {
conversation: chatConversation,
Expand Down

0 comments on commit 85fe199

Please sign in to comment.