From ce88b3a69bbae54b446baf332f15aecea12938a4 Mon Sep 17 00:00:00 2001 From: fegloff Date: Tue, 16 Apr 2024 19:52:39 -0500 Subject: [PATCH] refactor session data --- src/bot.ts | 4 +- src/helpers.ts | 55 +++++++++++-------- src/modules/1country/index.ts | 2 +- src/modules/document-handler/index.ts | 2 +- src/modules/llms/api/openai.ts | 4 +- src/modules/llms/dalleBot.ts | 73 +++++++++++++------------- src/modules/llms/llmsBase.ts | 7 +-- src/modules/llms/menu/openaiMenu.ts | 8 +-- src/modules/llms/openaiBot.ts | 2 +- src/modules/subagents/llamaSubagent.ts | 35 ------------ src/modules/types.ts | 17 +----- 11 files changed, 86 insertions(+), 123 deletions(-) diff --git a/src/bot.ts b/src/bot.ts index 225c7363..8c739dde 100644 --- a/src/bot.ts +++ b/src/bot.ts @@ -333,7 +333,7 @@ const PayableBots: Record = { claudeBot: { bot: claudeBot }, vertexBot: { bot: vertexBot }, openAiBot: { - enabled: (ctx: OnMessageContext) => ctx.session.openAi.imageGen.isEnabled, + enabled: (ctx: OnMessageContext) => ctx.session.dalle.isEnabled, bot: openAiBot }, oneCountryBot: { bot: oneCountryBot } @@ -394,7 +394,7 @@ const onMessage = async (ctx: OnMessageContext): Promise => { return } // Any message interacts with ChatGPT (only for private chats or /ask on enabled on group chats) - if (ctx.update.message.chat && (ctx.chat.type === 'private' || ctx.session.openAi.chatGpt.isFreePromptChatGroups)) { + if (ctx.update.message.chat && (ctx.chat.type === 'private' || ctx.session.chatGpt.isFreePromptChatGroups)) { await openAiBot.onEvent(ctx, (e) => { logger.error(e) }) diff --git a/src/helpers.ts b/src/helpers.ts index d9ca1ba9..3d68b5aa 100644 --- a/src/helpers.ts +++ b/src/helpers.ts @@ -3,28 +3,28 @@ import { type BotSessionData } from './modules/types' export function createInitialSessionData (): BotSessionData { return { - openAi: { - imageGen: { - numImages: config.openAi.dalle.sessionDefault.numImages, - imgSize: config.openAi.dalle.sessionDefault.imgSize, - isEnabled: config.openAi.dalle.isEnabled, - imgRequestQueue: [], - isProcessingQueue: false, - imageGenerated: [], - isInscriptionLotteryEnabled: config.openAi.dalle.isInscriptionLotteryEnabled, - imgInquiried: [] - }, - chatGpt: { - model: config.openAi.chatGpt.model, - isEnabled: config.openAi.chatGpt.isEnabled, - isFreePromptChatGroups: config.openAi.chatGpt.isFreePromptChatGroups, - chatConversation: [], - price: 0, - usage: 0, - isProcessingQueue: false, - requestQueue: [] - } - }, + // openAi: { + // imageGen: { + // numImages: config.openAi.dalle.sessionDefault.numImages, + // imgSize: config.openAi.dalle.sessionDefault.imgSize, + // isEnabled: config.openAi.dalle.isEnabled, + // imgRequestQueue: [], + // isProcessingQueue: false, + // imageGenerated: [], + // isInscriptionLotteryEnabled: config.openAi.dalle.isInscriptionLotteryEnabled, + // imgInquiried: [] + // }, + // chatGpt: { + // model: config.openAi.chatGpt.model, + // isEnabled: config.openAi.chatGpt.isEnabled, + // isFreePromptChatGroups: config.openAi.chatGpt.isFreePromptChatGroups, + // chatConversation: [], + // price: 0, + // usage: 0, + // isProcessingQueue: false, + // requestQueue: [] + // } + // }, oneCountry: { lastDomain: '' }, translate: { languages: [], @@ -50,11 +50,22 @@ export function createInitialSessionData (): BotSessionData { chatGpt: { model: config.llms.model, isEnabled: config.llms.isEnabled, + isFreePromptChatGroups: config.openAi.chatGpt.isFreePromptChatGroups, chatConversation: [], price: 0, usage: 0, isProcessingQueue: false, requestQueue: [] + }, + dalle: { + numImages: config.openAi.dalle.sessionDefault.numImages, + imgSize: config.openAi.dalle.sessionDefault.imgSize, + isEnabled: config.openAi.dalle.isEnabled, + imgRequestQueue: [], + isProcessingQueue: false, + imageGenerated: [], + isInscriptionLotteryEnabled: config.openAi.dalle.isInscriptionLotteryEnabled, + imgInquiried: [] } } } diff --git a/src/modules/1country/index.ts b/src/modules/1country/index.ts index b3dbe5d1..83880f5b 100644 --- a/src/modules/1country/index.ts +++ b/src/modules/1country/index.ts @@ -554,7 +554,7 @@ export class OneCountryBot implements PayableBot { ).catch(async (e) => { await this.onError(ctx, e, retryCount - 1) }) ctx.transient.analytics.actualResponseTime = now() if (method === 'editMessageText') { - ctx.session.openAi.chatGpt.chatConversation.pop() // deletes last prompt + ctx.session.chatGpt.chatConversation.pop() // deletes last prompt } await sleep(retryAfter * 1000) // wait retryAfter seconds to enable bot this.botSuspended = false diff --git a/src/modules/document-handler/index.ts b/src/modules/document-handler/index.ts index 9d9db90d..46c44c00 100644 --- a/src/modules/document-handler/index.ts +++ b/src/modules/document-handler/index.ts @@ -116,7 +116,7 @@ export class DocumentHandler implements PayableBot { ).catch(async (e) => { await this.onError(ctx, e, retryCount - 1) }) ctx.transient.analytics.actualResponseTime = now() if (method === 'editMessageText') { - ctx.session.openAi.chatGpt.chatConversation.pop() // deletes last prompt + ctx.session.chatGpt.chatConversation.pop() // deletes last prompt } await sleep(retryAfter * 1000) // wait retryAfter seconds to enable bot } else { diff --git a/src/modules/llms/api/openai.ts b/src/modules/llms/api/openai.ts index 1c4cbf58..dc7816b0 100644 --- a/src/modules/llms/api/openai.ts +++ b/src/modules/llms/api/openai.ts @@ -164,7 +164,7 @@ export const streamChatCompletion = async ( } } completion = completion.replaceAll('...', '') - const inputTokens = getTokenNumber(conversation[conversation.length - 1].content as string) + ctx.session.openAi.chatGpt.usage + const inputTokens = getTokenNumber(conversation[conversation.length - 1].content as string) + ctx.session.chatGpt.usage const outputTokens = getTokenNumber(completion) await ctx.api .editMessageText(ctx.chat?.id, msgId, completion) @@ -257,7 +257,7 @@ export const streamChatVisionCompletion = async ( } } completion = completion.replaceAll('...', '') - const inputTokens = getTokenNumber(prompt) + ctx.session.openAi.chatGpt.usage + const inputTokens = getTokenNumber(prompt) + ctx.session.chatGpt.usage const outputTokens = getTokenNumber(completion) await ctx.api .editMessageText(ctx.chat?.id, msgId, completion) diff --git a/src/modules/llms/dalleBot.ts b/src/modules/llms/dalleBot.ts index 02866256..91b93f70 100644 --- a/src/modules/llms/dalleBot.ts +++ b/src/modules/llms/dalleBot.ts @@ -37,7 +37,7 @@ import { InlineKeyboard } from 'grammy' export class DalleBot extends LlmsBase { constructor (payments: BotPayments) { - super(payments, 'DalleBot', 'chatGpt') + super(payments, 'DalleBot', 'dalle') if (!config.openAi.dalle.isEnabled) { this.logger.warn('DALL·E 2 Image Bot is disabled in config') } @@ -45,13 +45,7 @@ export class DalleBot extends LlmsBase { public getEstimatedPrice (ctx: any): number { try { - // if (this.isSupportedImageReply(ctx) && !isNaN(+prompts)) { - // const imageNumber = ctx.message?.caption || ctx.message?.text - // const imageSize = ctx.session.openAi.imageGen.imgSize - // const model = getDalleModel(imageSize) - // const price = getDalleModelPrice(model, true, imageNumber) // cents - // return price * priceAdjustment - // } + const session = this.getSession(ctx) if ( ctx.hasCommand([ SupportedCommands.dalle, @@ -60,8 +54,8 @@ export class DalleBot extends LlmsBase { SupportedCommands.dalleShorter ]) ) { - const imageNumber = ctx.session.openAi.imageGen.numImages - const imageSize = ctx.session.openAi.imageGen.imgSize + const imageNumber = session.numImages + const imageSize = session.imgSize const model = getDalleModel(imageSize) const price = getDalleModelPrice(model, true, imageNumber) // cents return price * PRICE_ADJUSTMENT @@ -83,11 +77,11 @@ export class DalleBot extends LlmsBase { SupportedCommands.dalleShort, SupportedCommands.dalleShorter ]) - + const session = this.getSession(ctx) const photo = ctx.message?.reply_to_message?.photo if (photo) { const imgId = photo?.[0].file_unique_id ?? '' - if (ctx.session.openAi.imageGen.imgInquiried.find((i) => i === imgId)) { + if (session.imgInquiried.find((i) => i === imgId)) { return false } } @@ -112,8 +106,9 @@ export class DalleBot extends LlmsBase { } isSupportedImageReply (ctx: OnMessageContext | OnCallBackQueryData): boolean { + const session = this.getSession(ctx) const photo = ctx.message?.photo ?? ctx.message?.reply_to_message?.photo - if (photo && ctx.session.openAi.imageGen.isEnabled) { + if (photo && session.isEnabled) { const prompt = ctx.message?.caption ?? ctx.message?.text if ( prompt && @@ -165,6 +160,7 @@ export class DalleBot extends LlmsBase { refundCallback: (reason?: string) => void ): Promise { ctx.transient.analytics.module = this.module + const session = this.getSession(ctx) const isSupportedEvent = this.isSupportedEvent(ctx) if (!isSupportedEvent && ctx.chat?.type !== 'private') { this.logger.warn(`### unsupported command ${ctx.message?.text}`) @@ -175,16 +171,16 @@ export class DalleBot extends LlmsBase { const photo = ctx.message?.photo ?? ctx.message?.reply_to_message?.photo const prompt = ctx.message?.caption ?? ctx.message?.text ?? '' const imgId = photo?.[0].file_unique_id ?? '' - if (!ctx.session.openAi.imageGen.imgInquiried.find((i) => i === imgId)) { - ctx.session.openAi.imageGen.imgRequestQueue.push({ + if (!session.imgInquiried.find((i) => i === imgId)) { + session.imgRequestQueue.push({ prompt, photo, command: 'vision' // !isNaN(+prompt) ? 'alter' : 'vision' }) - if (!ctx.session.openAi.imageGen.isProcessingQueue) { - ctx.session.openAi.imageGen.isProcessingQueue = true + if (!session.isProcessingQueue) { + session.isProcessingQueue = true await this.onImgRequestHandler(ctx).then(() => { - ctx.session.openAi.imageGen.isProcessingQueue = false + session.isProcessingQueue = false }) } } else { @@ -197,15 +193,15 @@ export class DalleBot extends LlmsBase { const photoUrl = getUrlFromText(ctx) if (photoUrl) { const prompt = ctx.match - ctx.session.openAi.imageGen.imgRequestQueue.push({ + session.imgRequestQueue.push({ prompt, photoUrl, command: 'vision' // !isNaN(+prompt) ? 'alter' : 'vision' }) - if (!ctx.session.openAi.imageGen.isProcessingQueue) { - ctx.session.openAi.imageGen.isProcessingQueue = true + if (!session.isProcessingQueue) { + session.isProcessingQueue = true await this.onImgRequestHandler(ctx).then(() => { - ctx.session.openAi.imageGen.isProcessingQueue = false + session.isProcessingQueue = false }) } } @@ -240,14 +236,14 @@ export class DalleBot extends LlmsBase { refundCallback('Prompt has bad words') return } - ctx.session.openAi.imageGen.imgRequestQueue.push({ + session.imgRequestQueue.push({ command: 'dalle', prompt }) - if (!ctx.session.openAi.imageGen.isProcessingQueue) { - ctx.session.openAi.imageGen.isProcessingQueue = true + if (!session.isProcessingQueue) { + session.isProcessingQueue = true await this.onImgRequestHandler(ctx).then(() => { - ctx.session.openAi.imageGen.isProcessingQueue = false + session.isProcessingQueue = false }) } return @@ -261,17 +257,18 @@ export class DalleBot extends LlmsBase { } async onImgRequestHandler (ctx: OnMessageContext | OnCallBackQueryData): Promise { - while (ctx.session.openAi.imageGen.imgRequestQueue.length > 0) { + const session = this.getSession(ctx) + while (session.imgRequestQueue.length > 0) { try { - const img = ctx.session.openAi.imageGen.imgRequestQueue.shift() - const minBalance = await getMinBalance(ctx, ctx.session.openAi.chatGpt.model) + const img = session.imgRequestQueue.shift() + const minBalance = await getMinBalance(ctx, ctx.session.chatGpt.model) if (await this.hasBalance(ctx, minBalance)) { if (img?.command === 'dalle') { await this.onGenImgCmd(img?.prompt, ctx) } else { await this.onInquiryImage(img?.photo, img?.photoUrl, img?.prompt, ctx) if (img?.photo?.[0].file_unique_id) { - ctx.session.openAi.imageGen.imgInquiried.push(img?.photo?.[0].file_unique_id) + session.imgInquiried.push(img?.photo?.[0].file_unique_id) } } ctx.chatAction = null @@ -286,19 +283,20 @@ export class DalleBot extends LlmsBase { async onGenImgCmd (prompt: string | undefined, ctx: OnMessageContext | OnCallBackQueryData): Promise { try { - if (ctx.session.openAi.imageGen.isEnabled && ctx.chat?.id) { + const session = this.getSession(ctx) + if (session.isEnabled && ctx.chat?.id) { ctx.chatAction = 'upload_photo' // eslint-disable-next-line @typescript-eslint/naming-convention const { message_id } = await ctx.reply( 'Generating image via OpenAI\'s DALL·E 3...', { message_thread_id: ctx.message?.message_thread_id } ) - const numImages = ctx.session.openAi.imageGen.numImages - const imgSize = ctx.session.openAi.imageGen.imgSize + const numImages = session.numImages + const imgSize = session.imgSize const imgs = await postGenerateImg(prompt ?? '', numImages, imgSize) if (imgs.length > 0) { await Promise.all(imgs.map(async (img: any) => { - if (ctx.session.openAi.imageGen.isInscriptionLotteryEnabled) { - const inlineKeyboard = new InlineKeyboard().text('Share to enter lottery', `share-payload|${ctx.session.openAi.imageGen.imageGenerated.length}`) // ${imgs[0].url} + if (session.isInscriptionLotteryEnabled) { + const inlineKeyboard = new InlineKeyboard().text('Share to enter lottery', `share-payload|${session.imageGenerated.length}`) // ${imgs[0].url} const msgExtras = getMessageExtras({ caption: `/dalle ${prompt}\n\n Check [q.country](https://q.country) for general lottery information`, reply_markup: inlineKeyboard, @@ -308,7 +306,7 @@ export class DalleBot extends LlmsBase { const msg = await ctx.replyWithPhoto(img.url, msgExtras) const genImg = msg.photo const fileId = genImg?.pop()?.file_id - ctx.session.openAi.imageGen.imageGenerated.push({ prompt, photoUrl: img.url, photoId: fileId }) + session.imageGenerated.push({ prompt, photoUrl: img.url, photoId: fileId }) } else { const msgExtras = getMessageExtras({ caption: `/dalle ${prompt}` }) await ctx.replyWithPhoto(img.url, msgExtras) @@ -340,7 +338,8 @@ export class DalleBot extends LlmsBase { prompt: string | undefined, ctx: OnMessageContext | OnCallBackQueryData): Promise => { try { - if (ctx.session.openAi.imageGen.isEnabled) { + const session = this.getSession(ctx) + if (session.isEnabled) { // let filePath = '' let imgList = [] if (photo) { diff --git a/src/modules/llms/llmsBase.ts b/src/modules/llms/llmsBase.ts index 5d5004c6..1b3ef4a9 100644 --- a/src/modules/llms/llmsBase.ts +++ b/src/modules/llms/llmsBase.ts @@ -11,7 +11,8 @@ import { RequestState, type BotSessionData, type LlmsSessionData, - type SubagentResult + type SubagentResult, + type ImageGenSessionData } from '../types' import { appText } from '../../utils/text' import { chatService } from '../../database/services' @@ -84,8 +85,8 @@ export abstract class LlmsBase implements PayableBot { this.subagents = subagents } - protected getSession (ctx: OnMessageContext | OnCallBackQueryData): LlmsSessionData { - return (ctx.session[this.sessionDataKey as keyof BotSessionData] as LlmsSessionData) + protected getSession (ctx: OnMessageContext | OnCallBackQueryData): LlmsSessionData & ImageGenSessionData { + return (ctx.session[this.sessionDataKey as keyof BotSessionData] as LlmsSessionData & ImageGenSessionData) } protected async runSubagents (ctx: OnMessageContext | OnCallBackQueryData, msg: ChatConversation): Promise { diff --git a/src/modules/llms/menu/openaiMenu.ts b/src/modules/llms/menu/openaiMenu.ts index c6291674..c2287457 100644 --- a/src/modules/llms/menu/openaiMenu.ts +++ b/src/modules/llms/menu/openaiMenu.ts @@ -51,18 +51,18 @@ const chatGPTimageDefaultOptions = new Menu(MenuIds.CHAT_GPT_MODEL) function getLabel (m: string, ctx: any): string { let label = m console.log( - ctx.session.openAi.chatGpt.model, + ctx.session.chatGpt.model, m, - ctx.session.openAi.chatGpt.model === m + ctx.session.chatGpt.model === m ) - if (ctx.session.openAi.chatGpt.model === m) { + if (ctx.session.chatGpt.model === m) { label += ' ✅' } return label } function setModel (m: string, ctx: any): void { - ctx.session.openAi.chatGpt.model = m + ctx.session.chatGpt.model = m ctx.menu.back() } diff --git a/src/modules/llms/openaiBot.ts b/src/modules/llms/openaiBot.ts index 03674481..3044bb51 100644 --- a/src/modules/llms/openaiBot.ts +++ b/src/modules/llms/openaiBot.ts @@ -156,7 +156,7 @@ export class OpenAIBot extends LlmsBase { return } - if (ctx.chat?.type === 'private' || ctx.session.openAi.chatGpt.isFreePromptChatGroups) { + if (ctx.chat?.type === 'private' || session.isFreePromptChatGroups) { await this.onChat(ctx, LlmsModelsEnum.GPT_4, true) return } diff --git a/src/modules/subagents/llamaSubagent.ts b/src/modules/subagents/llamaSubagent.ts index cd28c6ed..0d07d281 100644 --- a/src/modules/subagents/llamaSubagent.ts +++ b/src/modules/subagents/llamaSubagent.ts @@ -172,23 +172,6 @@ export class LlamaAgent extends SubagentBase { }) } - // public async onPdfReplyHandler (ctx: OnMessageContext | OnCallBackQueryData): Promise { - // try { - // const fileName = this.isSupportedPdfReply(ctx) - // const prompt = ctx.message?.text ?? 'Summarize this context' - // if (fileName !== '') { - // const collection = ctx.session.collections.activeCollections.find(c => c.fileName === fileName) - // if (collection) { - // await this.queryUrlCollection(ctx, collection.url, prompt) - // } - // } - // ctx.transient.analytics.actualResponseTime = now() - // } catch (e: any) { - // this.logger.error(`onPdfReplyHandler error: ${e}`) - // throw e - // } - // } - private getCollectionConversation (ctx: OnMessageContext | OnCallBackQueryData, collection: Collection): ChatConversation[] { if (ctx.session.collections.currentCollection === collection.collectionName) { return ctx.session.collections.collectionConversation @@ -222,24 +205,6 @@ export class LlamaAgent extends SubagentBase { } } - // async onUrlReplyHandler (ctx: OnMessageContext | OnCallBackQueryData): Promise { - // try { - // const url = getUrlFromText(ctx) - // if (url) { - // const prompt = ctx.message?.text ?? 'summarize' - // const collection = ctx.session.collections.activeCollections.find(c => c.url === url[0]) - // const newPrompt = `${prompt}` // ${url} - // if (collection) { - // await this.queryUrlCollection(ctx, url[0], newPrompt) - // } - // ctx.transient.analytics.actualResponseTime = now() - // } - // } catch (e: any) { - // this.logger.error(`onUrlReplyHandler: ${e.toString()}`) - // throw e - // } - // } - async onCheckCollectionStatus (ctx: OnMessageContext | OnCallBackQueryData): Promise { const processingTime = config.llms.processingTime const session = this.getSession(ctx) diff --git a/src/modules/types.ts b/src/modules/types.ts index 1dbed324..d3bb21b6 100644 --- a/src/modules/types.ts +++ b/src/modules/types.ts @@ -79,30 +79,17 @@ export interface promptRequest { outputFormat?: 'text' | 'voice' commandPrefix?: string } -export interface ChatGptSessionData { - model: string - isEnabled: boolean - isFreePromptChatGroups: boolean - chatConversation: ChatConversation[] - usage: number - price: number - requestQueue: promptRequest[] - isProcessingQueue: boolean -} export interface LlmsSessionData { model: string isEnabled: boolean + isFreePromptChatGroups?: boolean chatConversation: ChatConversation[] usage: number price: number requestQueue: ChatConversation[] isProcessingQueue: boolean } -export interface OpenAiSessionData { - imageGen: ImageGenSessionData - chatGpt: ChatGptSessionData -} export interface OneCountryData { lastDomain: string @@ -177,11 +164,11 @@ export interface SubagentSessionData { export interface BotSessionData { oneCountry: OneCountryData collections: CollectionSessionData - openAi: OpenAiSessionData translate: TranslateBotData llms: LlmsSessionData chatGpt: LlmsSessionData subagents: SubagentSessionData + dalle: ImageGenSessionData } export interface TransientStateContext {