diff --git a/src/bot.ts b/src/bot.ts
index 7c476166..0348b6df 100644
--- a/src/bot.ts
+++ b/src/bot.ts
@@ -186,7 +186,9 @@ function createInitialSessionData (): BotSessionData {
     imageGen: {
       numImages: config.openAi.dalle.sessionDefault.numImages,
       imgSize: config.openAi.dalle.sessionDefault.imgSize,
-      isEnabled: config.openAi.dalle.isEnabled
+      isEnabled: config.openAi.dalle.isEnabled,
+      imgRequestQueue: [],
+      isProcessingQueue: false
     },
     chatGpt: {
       model: config.openAi.chatGpt.model,
diff --git a/src/config.ts b/src/config.ts
index af086989..a917ce31 100644
--- a/src/config.ts
+++ b/src/config.ts
@@ -38,7 +38,7 @@ export default {
     model: 'chat-bison',
     minimumBalance: 0,
     isEnabled: Boolean(parseInt(process.env.LLMS_ENABLED ?? '1')),
-    prefixes: { bardPrefix: [',', 'b.', 'B.'] },
+    prefixes: { bardPrefix: ['b.', 'B.'] },
     pdfUrl: process.env.PDF_URL ?? '',
     processingTime: 300000
   },
@@ -54,6 +54,8 @@ export default {
       defaultPrompt: 'beautiful waterfall in a lush jungle, with sunlight shining through the trees',
       sessionDefault: {
+        model: 'dall-e-3',
+        quality: 'hd',
         numImages: 1,
         imgSize: '1024x1024'
       }
@@ -82,9 +84,6 @@ export default {
       chatPrefix: process.env.ASK_PREFIX
         ? process.env.ASK_PREFIX.split(',')
        : ['a.', '.'], // , "?", ">",
-      dallePrefix: process.env.DALLE_PREFIX
-        ? process.env.DALLE_PREFIX.split(',')
-        : ['d.'],
       newPrefix: process.env.NEW_PREFIX
         ? process.env.NEW_PREFIX.split(',')
         : ['n.', '..'],
diff --git a/src/modules/open-ai/api/openAi.ts b/src/modules/open-ai/api/openAi.ts
index dc7b8188..d1d431d5 100644
--- a/src/modules/open-ai/api/openAi.ts
+++ b/src/modules/open-ai/api/openAi.ts
@@ -35,6 +35,8 @@ export async function postGenerateImg (
   imgSize?: string
 ): Promise<OpenAI.Images.Image[]> {
   const payload = {
+    model: config.openAi.dalle.sessionDefault.model,
+    quality: config.openAi.dalle.sessionDefault.quality,
     prompt,
     n: numImgs ?? config.openAi.dalle.sessionDefault.numImages,
     size: imgSize ?? config.openAi.dalle.sessionDefault.imgSize
@@ -42,6 +44,7 @@ export async function postGenerateImg (
   const response = await openai.images.generate(
     payload as OpenAI.Images.ImageGenerateParams
   )
+  console.log(response)
   return response.data
 }
diff --git a/src/modules/open-ai/helpers.ts b/src/modules/open-ai/helpers.ts
index 847cb460..a873124f 100644
--- a/src/modules/open-ai/helpers.ts
+++ b/src/modules/open-ai/helpers.ts
@@ -16,8 +16,10 @@ export const SupportedCommands = {
   ask32: { name: 'ask32' },
   gpt: { name: 'gpt' },
   last: { name: 'last' },
-  dalle: { name: 'DALLE' },
-  dalleLC: { name: 'dalle' },
+  dalle: { name: 'dalle' },
+  dalleImg: { name: 'image' },
+  dalleShort: { name: 'img' },
+  dalleShorter: { name: 'i' },
   genImgEn: { name: 'genImgEn' },
   on: { name: 'on' },
   off: { name: 'off' }
@@ -25,6 +27,8 @@ export const SupportedCommands = {
 
 export const MAX_TRIES = 3
 
+const DALLE_PREFIX_LIST = ['i. ', ',', 'image ', 'd.', 'img ', 'i ']
+
 export const isMentioned = (
   ctx: OnMessageContext | OnCallBackQueryData
 ): boolean => {
@@ -52,7 +56,7 @@ export const hasChatPrefix = (prompt: string): string => {
 }
 
 export const hasDallePrefix = (prompt: string): string => {
-  const prefixList = config.openAi.chatGpt.prefixes.dallePrefix
+  const prefixList = DALLE_PREFIX_LIST
   for (let i = 0; i < prefixList.length; i++) {
     if (prompt.toLocaleLowerCase().startsWith(prefixList[i])) {
       return prefixList[i]
@@ -259,6 +263,18 @@ export const limitPrompt = (prompt: string): string => {
   return `${prompt} in around ${config.openAi.chatGpt.wordLimit} words`
 }
 
+export const getUrlFromText = (ctx: OnMessageContext | OnCallBackQueryData): string | undefined => {
+  const entities = ctx.message?.reply_to_message?.entities
+  if (entities) {
+    const urlEntity = entities.find(e => e.type === 'url')
+    if (urlEntity) {
+      const url = ctx.message?.reply_to_message?.text?.slice(urlEntity.offset, urlEntity.offset + urlEntity.length)
+      return url
+    }
+  }
+  return undefined
+}
+
 // export async function addUrlToCollection (ctx: OnMessageContext | OnCallBackQueryData, chatId: number, url: string, prompt: string): Promise<void> {
 //   const collectionName = await llmAddUrlDocument({
 //     chatId,
 //     url
 //   })
 //   const msgId = (await ctx.reply('...', {
 //     message_thread_id: ctx.message?.message_thread_id
 //   })).message_id
 //   await ctx.session.collections.collectionRequestQueue.push({
 //     collectionName,
 //     collectionType: 'URL',
 //     url,
 //     prompt,
 //     msgId
 //   })
 // }
-
-export const getUrlFromText = (ctx: OnMessageContext | OnCallBackQueryData): string | undefined => {
-  const entities = ctx.message?.reply_to_message?.entities
-  if (entities) {
-    const urlEntity = entities.find(e => e.type === 'url')
-    if (urlEntity) {
-      const url = ctx.message?.reply_to_message?.text?.slice(urlEntity.offset, urlEntity.offset + urlEntity.length)
-      return url
-    }
-  }
-  return undefined
-}
diff --git a/src/modules/open-ai/index.ts b/src/modules/open-ai/index.ts
index 6c76010d..809006aa 100644
--- a/src/modules/open-ai/index.ts
+++ b/src/modules/open-ai/index.ts
@@ -29,6 +29,7 @@ import {
   getMessageExtras,
   getPromptPrice,
   hasChatPrefix,
+  hasDallePrefix,
   hasNewPrefix,
   hasPrefix,
   hasUrl,
@@ -43,6 +44,7 @@ import { now } from '../../utils/perf'
 import { AxiosError } from 'axios'
 import { Callbacks } from '../types'
 import { LlmsBot } from '../llms'
+import { type PhotoSize } from 'grammy/types'
 
 export class OpenAIBot implements PayableBot {
   public readonly module = 'OpenAIBot'
@@ -99,8 +101,10 @@ export class OpenAIBot implements PayableBot {
       return 0
     }
     if (
-      ctx.hasCommand(SupportedCommands.dalle.name) ||
-      ctx.hasCommand(SupportedCommands.dalleLC.name)
+      ctx.hasCommand([SupportedCommands.dalle.name,
+        SupportedCommands.dalleImg.name,
+        SupportedCommands.dalleShort.name,
+        SupportedCommands.dalleShorter.name])
     ) {
       const imageNumber = ctx.session.openAi.imageGen.numImages
       const imageSize = ctx.session.openAi.imageGen.imgSize
@@ -156,7 +160,19 @@ export class OpenAIBot implements PayableBot {
     ctx.transient.analytics.sessionState = RequestState.Success
 
     if (this.isSupportedImageReply(ctx)) {
-      await this.onAlterImage(ctx)
+      const photo = ctx.message?.photo ?? ctx.message?.reply_to_message?.photo
+      const prompt = ctx.message?.caption ?? ctx.message?.text
+      ctx.session.openAi.imageGen.imgRequestQueue.push({
+        prompt,
+        photo,
+        command: 'alter'
+      })
+      if (!ctx.session.openAi.imageGen.isProcessingQueue) {
+        ctx.session.openAi.imageGen.isProcessingQueue = true
+        await this.onImgRequestHandler(ctx).then(() => {
+          ctx.session.openAi.imageGen.isProcessingQueue = false
+        })
+      }
       return
     }
 
@@ -213,11 +229,26 @@ export class OpenAIBot implements PayableBot {
     }
 
     if (
-      ctx.hasCommand(SupportedCommands.dalle.name) ||
-      ctx.hasCommand(SupportedCommands.dalleLC.name) ||
-      (ctx.message?.text?.startsWith('dalle ') && ctx.chat?.type === 'private')
+      ctx.hasCommand([SupportedCommands.dalle.name,
+        SupportedCommands.dalleImg.name,
+        SupportedCommands.dalleShort.name,
+        SupportedCommands.dalleShorter.name]) ||
+      (ctx.message?.text?.startsWith('image ') && ctx.chat?.type === 'private')
     ) {
-      await this.onGenImgCmd(ctx)
+      let prompt = (ctx.match ? ctx.match : ctx.message?.text) as string
+      if (!prompt || prompt.split(' ').length === 1) {
+        prompt = config.openAi.dalle.defaultPrompt
+      }
+      ctx.session.openAi.imageGen.imgRequestQueue.push({
+        command: 'dalle',
+        prompt
+      })
+      if (!ctx.session.openAi.imageGen.isProcessingQueue) {
+        ctx.session.openAi.imageGen.isProcessingQueue = true
+        await this.onImgRequestHandler(ctx).then(() => {
+          ctx.session.openAi.imageGen.isProcessingQueue = false
+        })
+      }
       return
     }
 
@@ -231,13 +262,34 @@ export class OpenAIBot implements PayableBot {
       return
     }
 
-    if (hasNewPrefix(ctx.message?.text ?? '') !== '') {
+    const text = ctx.message?.text ?? ''
+
+    if (hasNewPrefix(text) !== '') {
       await this.onEnd(ctx)
       await this.onPrefix(ctx)
       return
     }
 
-    if (hasChatPrefix(ctx.message?.text ?? '') !== '') {
+    if (hasDallePrefix(text) !== '') {
+      const prefix = hasDallePrefix(text)
+      let prompt = (ctx.match ? ctx.match : ctx.message?.text) as string
+      if (!prompt || prompt.split(' ').length === 1) {
+        prompt = config.openAi.dalle.defaultPrompt
+      }
+      ctx.session.openAi.imageGen.imgRequestQueue.push({
+        command: 'dalle',
+        prompt: prompt.slice(prefix.length)
+      })
+      if (!ctx.session.openAi.imageGen.isProcessingQueue) {
+        ctx.session.openAi.imageGen.isProcessingQueue = true
+        await this.onImgRequestHandler(ctx).then(() => {
+          ctx.session.openAi.imageGen.isProcessingQueue = false
+        })
+      }
+      return
+    }
+
+    if (hasChatPrefix(text) !== '') {
       await this.onPrefix(ctx)
       return
     }
@@ -497,23 +549,44 @@ export class OpenAIBot implements PayableBot {
     }
   }
 
-  onGenImgCmd = async (ctx: OnMessageContext | OnCallBackQueryData): Promise<void> => {
-    try {
-      if (ctx.session.openAi.imageGen.isEnabled) {
-        let prompt = (ctx.match ? ctx.match : ctx.message?.text) as string
-        if (!prompt || prompt.split(' ').length === 1) {
-          prompt = config.openAi.dalle.defaultPrompt
+  async onImgRequestHandler (ctx: OnMessageContext | OnCallBackQueryData): Promise<void> {
+    while (ctx.session.openAi.imageGen.imgRequestQueue.length > 0) {
+      try {
+        const img = ctx.session.openAi.imageGen.imgRequestQueue.shift()
+        if (await this.hasBalance(ctx)) {
+          if (img?.command === 'dalle') {
+            await this.onGenImgCmd(img?.prompt, ctx)
+          } else {
+            await this.onAlterImage(img?.photo, img?.prompt, ctx)
+          }
+          ctx.chatAction = null
+        } else {
+          await this.onNotBalanceMessage(ctx)
         }
+      } catch (e: any) {
+        await this.onError(ctx, e)
+      }
+    }
+  }
+
+  onGenImgCmd = async (prompt: string | undefined, ctx: OnMessageContext | OnCallBackQueryData): Promise<void> => {
+    try {
+      if (ctx.session.openAi.imageGen.isEnabled && ctx.chat?.id) {
         ctx.chatAction = 'upload_photo'
+        // eslint-disable-next-line @typescript-eslint/naming-convention
+        const { message_id } = await ctx.reply(
+          'Generating dalle image...', { message_thread_id: ctx.message?.message_thread_id }
+        )
         const numImages = ctx.session.openAi.imageGen.numImages
         const imgSize = ctx.session.openAi.imageGen.imgSize
-        const imgs = await postGenerateImg(prompt, numImages, imgSize)
+        const imgs = await postGenerateImg(prompt ?? '', numImages, imgSize)
         const msgExtras = getMessageExtras({ caption: `/dalle ${prompt}` })
         await Promise.all(imgs.map(async (img: any) => {
           await ctx.replyWithPhoto(img.url, msgExtras).catch(async (e) => {
             await this.onError(ctx, e, MAX_TRIES)
           })
         }))
+        await ctx.api.deleteMessage(ctx.chat?.id, message_id)
         ctx.transient.analytics.sessionState = RequestState.Success
         ctx.transient.analytics.actualResponseTime = now()
       } else {
@@ -533,12 +606,9 @@ export class OpenAIBot implements PayableBot {
     }
   }
 
-  onAlterImage = async (ctx: OnMessageContext | OnCallBackQueryData): Promise<void> => {
+  onAlterImage = async (photo: PhotoSize[] | undefined, prompt: string | undefined, ctx: OnMessageContext | OnCallBackQueryData): Promise<void> => {
     try {
       if (ctx.session.openAi.imageGen.isEnabled) {
-        const photo =
-          ctx.message?.photo ?? ctx.message?.reply_to_message?.photo
-        const prompt = ctx.message?.caption ?? ctx.message?.text
         const fileId = photo?.pop()?.file_id // with pop() get full image quality
         if (!fileId) {
           await ctx.reply('Cannot retrieve the image file. Please try again.')
@@ -710,74 +780,3 @@ export class OpenAIBot implements PayableBot {
     }
   }
 }
-
-// onGenImgEnCmd = async (ctx: OnMessageContext | OnCallBackQueryData) => {
-//   try {
-//     if (ctx.session.openAi.imageGen.isEnabled) {
-//       const prompt = await ctx.match;
-//       if (!prompt) {
-//         sendMessage(ctx, "Error: Missing prompt", {
-//           topicId: ctx.message?.message_thread_id,
-//         }).catch((e) =>
-//           this.onError(ctx, e, MAX_TRIES, "Error: Missing prompt")
-//         );
-//         return;
-//       }
-//       const payload = {
-//         chatId: await ctx.chat?.id!,
-//         prompt: prompt as string,
-//         numImages: await ctx.session.openAi.imageGen.numImages,
-//         imgSize: await ctx.session.openAi.imageGen.imgSize,
-//       };
-//       sendMessage(ctx, "generating improved prompt...", {
-//         topicId: ctx.message?.message_thread_id,
-//       }).catch((e) =>
-//         this.onError(ctx, e, MAX_TRIES, "generating improved prompt...")
-//       );
-//       await imgGenEnhanced(payload, ctx);
-//     } else {
-//       sendMessage(ctx, "Bot disabled", {
-//         topicId: ctx.message?.message_thread_id,
-//       }).catch((e) => this.onError(ctx, e, MAX_TRIES, "Bot disabled"));
-//     }
-//   } catch (e) {
-//     this.onError(ctx, e);
-//   }
-// };
-
-// private async imgGenEnhanced(
-//   data: ImageGenPayload,
-//   ctx: OnMessageContext | OnCallBackQueryData
-// ) {
-//   const { chatId, prompt, numImages, imgSize, model } = data;
-//   try {
-//     const upgratedPrompt = await improvePrompt(prompt, model!);
-//     if (upgratedPrompt) {
-//       await ctx
-//         .reply(
-//           `The following description was added to your prompt: ${upgratedPrompt}`
-//         )
-//         .catch((e) => {
-//           throw e;
-//         });
-//     }
-//     // bot.api.sendMessage(chatId, "generating the output...");
-//     const imgs = await postGenerateImg(
-//       upgratedPrompt || prompt,
-//       numImages,
-//       imgSize
-//     );
-//     imgs.map(async (img: any) => {
-//       await ctx
-//         .replyWithPhoto(img.url, {
-//           caption: `/DALLE ${upgratedPrompt || prompt}`,
-//         })
-//         .catch((e) => {
-//           throw e;
-//         });
-//     });
-//     return true;
-//   } catch (e) {
-//     throw e;
-//   }
-// };
diff --git a/src/modules/open-ai/types.ts b/src/modules/open-ai/types.ts
index 8c2e6b6f..46a60c9e 100644
--- a/src/modules/open-ai/types.ts
+++ b/src/modules/open-ai/types.ts
@@ -50,16 +50,24 @@ export const ChatGPTModels: Record = {
 }
 
 export const DalleGPTModels: Record = {
+  '1024x1792': {
+    size: '1024x1792',
+    price: 0.10
+  },
+  '1792x1024': {
+    size: '1792x1024',
+    price: 0.10
+  },
   '1024x1024': {
     size: '1024x1024',
-    price: 0.02
+    price: 0.10
   },
   '512x512': {
     size: '512x512',
-    price: 0.018
+    price: 0.10
   },
   '256x256': {
     size: '256x256',
-    price: 0.016
+    price: 0.10
   }
 }
diff --git a/src/modules/sd-images/helpers.ts b/src/modules/sd-images/helpers.ts
index 8b2f2e03..9025679e 100644
--- a/src/modules/sd-images/helpers.ts
+++ b/src/modules/sd-images/helpers.ts
@@ -5,9 +5,9 @@ import { getLoraByParam, type ILora } from './api/loras-config'
 import { childrenWords, sexWords } from './words-blacklist'
 
 export enum COMMAND {
-  TEXT_TO_IMAGE = 'image',
+  TEXT_TO_IMAGE = 'sdimage',
   IMAGE_TO_IMAGE = 'img2img',
-  TEXT_TO_IMAGES = 'images',
+  TEXT_TO_IMAGES = 'sdimages',
   CONSTRUCTOR = 'constructor',
   HELP = 'help',
   TRAIN = 'train'
@@ -38,7 +38,7 @@ const removeSpaceFromBegin = (text: string): string => {
   return text.slice(idx)
 }
 
-const SPECIAL_IMG_CMD_SYMBOLS = ['i.', 'l.', 'I.', '? ', '! ', ': ', '; ', 'r.', 'R.', 'd.', 'D.', '( ', '$ ', '& ', '< ']
+const SPECIAL_IMG_CMD_SYMBOLS = ['l.', '? ', '! ', ': ', '; ', 'r.', 'R.', 'd.', 'D.', '( ', '$ ', '& ', '< ']
 
 export const getPrefix = (prompt: string, prefixList: string[]): string => {
   for (let i = 0; i < prefixList.length; i++) {
@@ -126,13 +126,13 @@ export const parseCtx = (ctx: Context): IOperation | false => {
     }
 
     if (
-      (hasCommand(ctx, 'image') || hasCommand(ctx, 'imagine')) || hasCommand(ctx, 'img')
+      (hasCommand(ctx, 'sdimage') || hasCommand(ctx, 'sdimagine')) || hasCommand(ctx, 'sdimg')
     ) {
       command = COMMAND.TEXT_TO_IMAGE
     }
 
     if (
-      (hasCommand(ctx, 'image2') || hasCommand(ctx, 'imagine2')) || hasCommand(ctx, 'img2')
+      (hasCommand(ctx, 'sdimage2') || hasCommand(ctx, 'sdimagine2')) || hasCommand(ctx, 'sdimg2')
    ) {
       command = COMMAND.TEXT_TO_IMAGE
       model = model && ({ ...model, serverNumber: 2 })
@@ -146,7 +146,7 @@ export const parseCtx = (ctx: Context): IOperation | false => {
       lora = getLoraByParam('logo', model?.baseModel ?? 'SDXL 1.0')
     }
 
-    if (hasCommand(ctx, 'images')) {
+    if (hasCommand(ctx, 'sd-images')) {
       command = COMMAND.TEXT_TO_IMAGES
     }
diff --git a/src/modules/types.ts b/src/modules/types.ts
index 5ce60a29..2d34a5b6 100644
--- a/src/modules/types.ts
+++ b/src/modules/types.ts
@@ -9,7 +9,7 @@ import {
   type ConversationFlavor
 } from '@grammyjs/conversations'
 import { type AutoChatActionFlavor } from '@grammyjs/auto-chat-action'
-import { type ParseMode } from 'grammy/types'
+import { type PhotoSize, type ParseMode } from 'grammy/types'
 import { type InlineKeyboardMarkup } from 'grammy/out/types'
 import type { FileFlavor } from '@grammyjs/files'
 
@@ -17,6 +17,8 @@ export interface ImageGenSessionData {
   numImages: number
   imgSize: string
   isEnabled: boolean
+  imgRequestQueue: ImageRequest[]
+  isProcessingQueue: boolean
 }
 
 export interface MessageExtras {
@@ -43,6 +45,12 @@ export interface ChatConversation {
   content: string
   model?: string
 }
+
+export interface ImageRequest {
+  command?: 'dalle' | 'alter'
+  prompt?: string
+  photo?: PhotoSize[] | undefined
+}
 export interface ChatGptSessionData {
   model: string
   isEnabled: boolean
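
Below is a short illustrative sketch, not part of the diff, of how an image request reaches the OpenAI API once the new `model: 'dall-e-3'` and `quality: 'hd'` session defaults are wired through postGenerateImg. The values mirror config.ts and DalleGPTModels above; the client call is the standard openai v4 Node SDK, and the helper name generateDalleImage is made up for the example.

// Sketch, assuming the openai v4 Node SDK and the session defaults from config.ts.
import OpenAI from 'openai'

const openai = new OpenAI({ apiKey: process.env.OPENAI_API_KEY })

// Generate a single DALL·E 3 image and return its URL, mirroring the payload
// that postGenerateImg now builds from config.openAi.dalle.sessionDefault.
async function generateDalleImage (prompt: string): Promise<string | undefined> {
  const response = await openai.images.generate({
    model: 'dall-e-3',   // sessionDefault.model in config.ts
    quality: 'hd',       // sessionDefault.quality in config.ts
    prompt,
    n: 1,                // DALL·E 3 accepts one image per request
    size: '1024x1024'    // one of the sizes priced in DalleGPTModels
  })
  return response.data?.[0]?.url
}

The imgRequestQueue / isProcessingQueue pair added to ImageGenSessionData acts as a simple gate: handlers push a request and only the first caller drains the queue, so prompts in one chat are processed one at a time. A minimal standalone sketch of that pattern follows; ImageJob, enqueueImageJob and handle are illustrative names, not the bot's actual helpers.

// Sketch of the queue gate: push jobs, let a single loop drain them in order.
interface ImageJob { command: 'dalle' | 'alter', prompt?: string }

const imgRequestQueue: ImageJob[] = []
let isProcessingQueue = false

async function enqueueImageJob (job: ImageJob, handle: (j: ImageJob) => Promise<void>): Promise<void> {
  imgRequestQueue.push(job)
  if (isProcessingQueue) return // another caller is already draining the queue
  isProcessingQueue = true
  while (imgRequestQueue.length > 0) {
    const next = imgRequestQueue.shift()
    if (next) await handle(next)
  }
  isProcessingQueue = false
}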