Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dalle 3 integration #345

Merged
merged 7 commits into from
Jan 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/bot.ts
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,9 @@ function createInitialSessionData (): BotSessionData {
imageGen: {
numImages: config.openAi.dalle.sessionDefault.numImages,
imgSize: config.openAi.dalle.sessionDefault.imgSize,
isEnabled: config.openAi.dalle.isEnabled
isEnabled: config.openAi.dalle.isEnabled,
imgRequestQueue: [],
isProcessingQueue: false
},
chatGpt: {
model: config.openAi.chatGpt.model,
Expand Down
7 changes: 3 additions & 4 deletions src/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ export default {
model: 'chat-bison',
minimumBalance: 0,
isEnabled: Boolean(parseInt(process.env.LLMS_ENABLED ?? '1')),
prefixes: { bardPrefix: [',', 'b.', 'B.'] },
prefixes: { bardPrefix: ['b.', 'B.'] },
pdfUrl: process.env.PDF_URL ?? '',
processingTime: 300000
},
Expand All @@ -54,6 +54,8 @@ export default {
defaultPrompt:
'beautiful waterfall in a lush jungle, with sunlight shining through the trees',
sessionDefault: {
model: 'dall-e-3',
quality: 'hd',
numImages: 1,
imgSize: '1024x1024'
}
Expand Down Expand Up @@ -82,9 +84,6 @@ export default {
chatPrefix: process.env.ASK_PREFIX
? process.env.ASK_PREFIX.split(',')
: ['a.', '.'], // , "?", ">",
dallePrefix: process.env.DALLE_PREFIX
? process.env.DALLE_PREFIX.split(',')
: ['d.'],
newPrefix: process.env.NEW_PREFIX
? process.env.NEW_PREFIX.split(',')
: ['n.', '..'],
Expand Down
3 changes: 3 additions & 0 deletions src/modules/open-ai/api/openAi.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,16 @@ export async function postGenerateImg (
imgSize?: string
): Promise<OpenAI.Images.Image[]> {
const payload = {
model: config.openAi.dalle.sessionDefault.model,
quality: config.openAi.dalle.sessionDefault.quality,
prompt,
n: numImgs ?? config.openAi.dalle.sessionDefault.numImages,
size: imgSize ?? config.openAi.dalle.sessionDefault.imgSize
}
const response = await openai.images.generate(
payload as OpenAI.Images.ImageGenerateParams
)
console.log(response)
return response.data
}

Expand Down
34 changes: 19 additions & 15 deletions src/modules/open-ai/helpers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,19 @@ export const SupportedCommands = {
ask32: { name: 'ask32' },
gpt: { name: 'gpt' },
last: { name: 'last' },
dalle: { name: 'DALLE' },
dalleLC: { name: 'dalle' },
dalle: { name: 'dalle' },
dalleImg: { name: 'image' },
dalleShort: { name: 'img' },
dalleShorter: { name: 'i' },
genImgEn: { name: 'genImgEn' },
on: { name: 'on' },
off: { name: 'off' }
}

export const MAX_TRIES = 3

const DALLE_PREFIX_LIST = ['i. ', ',', 'image ', 'd.', 'img ', 'i ']

export const isMentioned = (
ctx: OnMessageContext | OnCallBackQueryData
): boolean => {
Expand Down Expand Up @@ -52,7 +56,7 @@ export const hasChatPrefix = (prompt: string): string => {
}

export const hasDallePrefix = (prompt: string): string => {
const prefixList = config.openAi.chatGpt.prefixes.dallePrefix
const prefixList = DALLE_PREFIX_LIST
for (let i = 0; i < prefixList.length; i++) {
if (prompt.toLocaleLowerCase().startsWith(prefixList[i])) {
return prefixList[i]
Expand Down Expand Up @@ -259,6 +263,18 @@ export const limitPrompt = (prompt: string): string => {
return `${prompt} in around ${config.openAi.chatGpt.wordLimit} words`
}

export const getUrlFromText = (ctx: OnMessageContext | OnCallBackQueryData): string | undefined => {
const entities = ctx.message?.reply_to_message?.entities
if (entities) {
const urlEntity = entities.find(e => e.type === 'url')
if (urlEntity) {
const url = ctx.message?.reply_to_message?.text?.slice(urlEntity.offset, urlEntity.offset + urlEntity.length)
return url
}
}
return undefined
}

// export async function addUrlToCollection (ctx: OnMessageContext | OnCallBackQueryData, chatId: number, url: string, prompt: string): Promise<void> {
// const collectionName = await llmAddUrlDocument({
// chatId,
Expand All @@ -278,15 +294,3 @@ export const limitPrompt = (prompt: string): string => {
// msgId
// })
// }

export const getUrlFromText = (ctx: OnMessageContext | OnCallBackQueryData): string | undefined => {
const entities = ctx.message?.reply_to_message?.entities
if (entities) {
const urlEntity = entities.find(e => e.type === 'url')
if (urlEntity) {
const url = ctx.message?.reply_to_message?.text?.slice(urlEntity.offset, urlEntity.offset + urlEntity.length)
return url
}
}
return undefined
}
181 changes: 90 additions & 91 deletions src/modules/open-ai/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import {
getMessageExtras,
getPromptPrice,
hasChatPrefix,
hasDallePrefix,
hasNewPrefix,
hasPrefix,
hasUrl,
Expand All @@ -43,6 +44,7 @@ import { now } from '../../utils/perf'
import { AxiosError } from 'axios'
import { Callbacks } from '../types'
import { LlmsBot } from '../llms'
import { type PhotoSize } from 'grammy/types'

export class OpenAIBot implements PayableBot {
public readonly module = 'OpenAIBot'
Expand Down Expand Up @@ -99,8 +101,10 @@ export class OpenAIBot implements PayableBot {
return 0
}
if (
ctx.hasCommand(SupportedCommands.dalle.name) ||
ctx.hasCommand(SupportedCommands.dalleLC.name)
ctx.hasCommand([SupportedCommands.dalle.name,
SupportedCommands.dalleImg.name,
SupportedCommands.dalleShort.name,
SupportedCommands.dalleShorter.name])
) {
const imageNumber = ctx.session.openAi.imageGen.numImages
const imageSize = ctx.session.openAi.imageGen.imgSize
Expand Down Expand Up @@ -156,7 +160,19 @@ export class OpenAIBot implements PayableBot {
ctx.transient.analytics.sessionState = RequestState.Success

if (this.isSupportedImageReply(ctx)) {
await this.onAlterImage(ctx)
const photo = ctx.message?.photo ?? ctx.message?.reply_to_message?.photo
const prompt = ctx.message?.caption ?? ctx.message?.text
ctx.session.openAi.imageGen.imgRequestQueue.push({
prompt,
photo,
command: 'alter'
})
if (!ctx.session.openAi.imageGen.isProcessingQueue) {
ctx.session.openAi.imageGen.isProcessingQueue = true
await this.onImgRequestHandler(ctx).then(() => {
ctx.session.openAi.imageGen.isProcessingQueue = false
})
}
return
}

Expand Down Expand Up @@ -213,11 +229,26 @@ export class OpenAIBot implements PayableBot {
}

if (
ctx.hasCommand(SupportedCommands.dalle.name) ||
ctx.hasCommand(SupportedCommands.dalleLC.name) ||
(ctx.message?.text?.startsWith('dalle ') && ctx.chat?.type === 'private')
ctx.hasCommand([SupportedCommands.dalle.name,
SupportedCommands.dalleImg.name,
SupportedCommands.dalleShort.name,
SupportedCommands.dalleShorter.name]) ||
(ctx.message?.text?.startsWith('image ') && ctx.chat?.type === 'private')
) {
await this.onGenImgCmd(ctx)
let prompt = (ctx.match ? ctx.match : ctx.message?.text) as string
if (!prompt || prompt.split(' ').length === 1) {
prompt = config.openAi.dalle.defaultPrompt
}
ctx.session.openAi.imageGen.imgRequestQueue.push({
command: 'dalle',
prompt
})
if (!ctx.session.openAi.imageGen.isProcessingQueue) {
ctx.session.openAi.imageGen.isProcessingQueue = true
await this.onImgRequestHandler(ctx).then(() => {
ctx.session.openAi.imageGen.isProcessingQueue = false
})
}
return
}

Expand All @@ -231,13 +262,34 @@ export class OpenAIBot implements PayableBot {
return
}

if (hasNewPrefix(ctx.message?.text ?? '') !== '') {
const text = ctx.message?.text ?? ''

if (hasNewPrefix(text) !== '') {
await this.onEnd(ctx)
await this.onPrefix(ctx)
return
}

if (hasChatPrefix(ctx.message?.text ?? '') !== '') {
if (hasDallePrefix(text) !== '') {
const prefix = hasDallePrefix(text)
let prompt = (ctx.match ? ctx.match : ctx.message?.text) as string
if (!prompt || prompt.split(' ').length === 1) {
prompt = config.openAi.dalle.defaultPrompt
}
ctx.session.openAi.imageGen.imgRequestQueue.push({
command: 'dalle',
prompt: prompt.slice(prefix.length)
})
if (!ctx.session.openAi.imageGen.isProcessingQueue) {
ctx.session.openAi.imageGen.isProcessingQueue = true
await this.onImgRequestHandler(ctx).then(() => {
ctx.session.openAi.imageGen.isProcessingQueue = false
})
}
return
}

if (hasChatPrefix(text) !== '') {
await this.onPrefix(ctx)
return
}
Expand Down Expand Up @@ -497,23 +549,44 @@ export class OpenAIBot implements PayableBot {
}
}

onGenImgCmd = async (ctx: OnMessageContext | OnCallBackQueryData): Promise<void> => {
try {
if (ctx.session.openAi.imageGen.isEnabled) {
let prompt = (ctx.match ? ctx.match : ctx.message?.text) as string
if (!prompt || prompt.split(' ').length === 1) {
prompt = config.openAi.dalle.defaultPrompt
async onImgRequestHandler (ctx: OnMessageContext | OnCallBackQueryData): Promise<void> {
while (ctx.session.openAi.imageGen.imgRequestQueue.length > 0) {
try {
const img = ctx.session.openAi.imageGen.imgRequestQueue.shift()
if (await this.hasBalance(ctx)) {
if (img?.command === 'dalle') {
await this.onGenImgCmd(img?.prompt, ctx)
} else {
await this.onAlterImage(img?.photo, img?.prompt, ctx)
}
ctx.chatAction = null
} else {
await this.onNotBalanceMessage(ctx)
}
} catch (e: any) {
await this.onError(ctx, e)
}
}
}

onGenImgCmd = async (prompt: string | undefined, ctx: OnMessageContext | OnCallBackQueryData): Promise<void> => {
try {
if (ctx.session.openAi.imageGen.isEnabled && ctx.chat?.id) {
ctx.chatAction = 'upload_photo'
// eslint-disable-next-line @typescript-eslint/naming-convention
const { message_id } = await ctx.reply(
'Generating dalle image...', { message_thread_id: ctx.message?.message_thread_id }
)
const numImages = ctx.session.openAi.imageGen.numImages
const imgSize = ctx.session.openAi.imageGen.imgSize
const imgs = await postGenerateImg(prompt, numImages, imgSize)
const imgs = await postGenerateImg(prompt ?? '', numImages, imgSize)
const msgExtras = getMessageExtras({ caption: `/dalle ${prompt}` })
await Promise.all(imgs.map(async (img: any) => {
await ctx.replyWithPhoto(img.url, msgExtras).catch(async (e) => {
await this.onError(ctx, e, MAX_TRIES)
})
}))
await ctx.api.deleteMessage(ctx.chat?.id, message_id)
ctx.transient.analytics.sessionState = RequestState.Success
ctx.transient.analytics.actualResponseTime = now()
} else {
Expand All @@ -533,12 +606,9 @@ export class OpenAIBot implements PayableBot {
}
}

onAlterImage = async (ctx: OnMessageContext | OnCallBackQueryData): Promise<void> => {
onAlterImage = async (photo: PhotoSize[] | undefined, prompt: string | undefined, ctx: OnMessageContext | OnCallBackQueryData): Promise<void> => {
try {
if (ctx.session.openAi.imageGen.isEnabled) {
const photo =
ctx.message?.photo ?? ctx.message?.reply_to_message?.photo
const prompt = ctx.message?.caption ?? ctx.message?.text
const fileId = photo?.pop()?.file_id // with pop() get full image quality
if (!fileId) {
await ctx.reply('Cannot retrieve the image file. Please try again.')
Expand Down Expand Up @@ -710,74 +780,3 @@ export class OpenAIBot implements PayableBot {
}
}
}

// onGenImgEnCmd = async (ctx: OnMessageContext | OnCallBackQueryData) => {
// try {
// if (ctx.session.openAi.imageGen.isEnabled) {
// const prompt = await ctx.match;
// if (!prompt) {
// sendMessage(ctx, "Error: Missing prompt", {
// topicId: ctx.message?.message_thread_id,
// }).catch((e) =>
// this.onError(ctx, e, MAX_TRIES, "Error: Missing prompt")
// );
// return;
// }
// const payload = {
// chatId: await ctx.chat?.id!,
// prompt: prompt as string,
// numImages: await ctx.session.openAi.imageGen.numImages,
// imgSize: await ctx.session.openAi.imageGen.imgSize,
// };
// sendMessage(ctx, "generating improved prompt...", {
// topicId: ctx.message?.message_thread_id,
// }).catch((e) =>
// this.onError(ctx, e, MAX_TRIES, "generating improved prompt...")
// );
// await imgGenEnhanced(payload, ctx);
// } else {
// sendMessage(ctx, "Bot disabled", {
// topicId: ctx.message?.message_thread_id,
// }).catch((e) => this.onError(ctx, e, MAX_TRIES, "Bot disabled"));
// }
// } catch (e) {
// this.onError(ctx, e);
// }
// };

// private async imgGenEnhanced(
// data: ImageGenPayload,
// ctx: OnMessageContext | OnCallBackQueryData
// ) {
// const { chatId, prompt, numImages, imgSize, model } = data;
// try {
// const upgratedPrompt = await improvePrompt(prompt, model!);
// if (upgratedPrompt) {
// await ctx
// .reply(
// `The following description was added to your prompt: ${upgratedPrompt}`
// )
// .catch((e) => {
// throw e;
// });
// }
// // bot.api.sendMessage(chatId, "generating the output...");
// const imgs = await postGenerateImg(
// upgratedPrompt || prompt,
// numImages,
// imgSize
// );
// imgs.map(async (img: any) => {
// await ctx
// .replyWithPhoto(img.url, {
// caption: `/DALLE ${upgratedPrompt || prompt}`,
// })
// .catch((e) => {
// throw e;
// });
// });
// return true;
// } catch (e) {
// throw e;
// }
// };
Loading