Skip to content

Commit

Permalink
Merge branch 'master' into stats-update
Browse files Browse the repository at this point in the history
  • Loading branch information
fegloff committed Jan 12, 2024
2 parents 6b0ef48 + f348cf8 commit bf0ceb5
Show file tree
Hide file tree
Showing 8 changed files with 144 additions and 121 deletions.
4 changes: 3 additions & 1 deletion src/bot.ts
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,9 @@ function createInitialSessionData (): BotSessionData {
imageGen: {
numImages: config.openAi.dalle.sessionDefault.numImages,
imgSize: config.openAi.dalle.sessionDefault.imgSize,
isEnabled: config.openAi.dalle.isEnabled
isEnabled: config.openAi.dalle.isEnabled,
imgRequestQueue: [],
isProcessingQueue: false
},
chatGpt: {
model: config.openAi.chatGpt.model,
Expand Down
7 changes: 3 additions & 4 deletions src/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ export default {
model: 'chat-bison',
minimumBalance: 0,
isEnabled: Boolean(parseInt(process.env.LLMS_ENABLED ?? '1')),
prefixes: { bardPrefix: [',', 'b.', 'B.'] },
prefixes: { bardPrefix: ['b.', 'B.'] },
pdfUrl: process.env.PDF_URL ?? '',
processingTime: 300000
},
Expand All @@ -54,6 +54,8 @@ export default {
defaultPrompt:
'beautiful waterfall in a lush jungle, with sunlight shining through the trees',
sessionDefault: {
model: 'dall-e-3',
quality: 'hd',
numImages: 1,
imgSize: '1024x1024'
}
Expand Down Expand Up @@ -82,9 +84,6 @@ export default {
chatPrefix: process.env.ASK_PREFIX
? process.env.ASK_PREFIX.split(',')
: ['a.', '.'], // , "?", ">",
dallePrefix: process.env.DALLE_PREFIX
? process.env.DALLE_PREFIX.split(',')
: ['d.'],
newPrefix: process.env.NEW_PREFIX
? process.env.NEW_PREFIX.split(',')
: ['n.', '..'],
Expand Down
3 changes: 3 additions & 0 deletions src/modules/open-ai/api/openAi.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,16 @@ export async function postGenerateImg (
imgSize?: string
): Promise<OpenAI.Images.Image[]> {
const payload = {
model: config.openAi.dalle.sessionDefault.model,
quality: config.openAi.dalle.sessionDefault.quality,
prompt,
n: numImgs ?? config.openAi.dalle.sessionDefault.numImages,
size: imgSize ?? config.openAi.dalle.sessionDefault.imgSize
}
const response = await openai.images.generate(
payload as OpenAI.Images.ImageGenerateParams
)
console.log(response)
return response.data
}

Expand Down
34 changes: 19 additions & 15 deletions src/modules/open-ai/helpers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,19 @@ export const SupportedCommands = {
ask32: { name: 'ask32' },
gpt: { name: 'gpt' },
last: { name: 'last' },
dalle: { name: 'DALLE' },
dalleLC: { name: 'dalle' },
dalle: { name: 'dalle' },
dalleImg: { name: 'image' },
dalleShort: { name: 'img' },
dalleShorter: { name: 'i' },
genImgEn: { name: 'genImgEn' },
on: { name: 'on' },
off: { name: 'off' }
}

export const MAX_TRIES = 3

const DALLE_PREFIX_LIST = ['i. ', ',', 'image ', 'd.', 'img ', 'i ']

export const isMentioned = (
ctx: OnMessageContext | OnCallBackQueryData
): boolean => {
Expand Down Expand Up @@ -52,7 +56,7 @@ export const hasChatPrefix = (prompt: string): string => {
}

export const hasDallePrefix = (prompt: string): string => {
const prefixList = config.openAi.chatGpt.prefixes.dallePrefix
const prefixList = DALLE_PREFIX_LIST
for (let i = 0; i < prefixList.length; i++) {
if (prompt.toLocaleLowerCase().startsWith(prefixList[i])) {
return prefixList[i]
Expand Down Expand Up @@ -259,6 +263,18 @@ export const limitPrompt = (prompt: string): string => {
return `${prompt} in around ${config.openAi.chatGpt.wordLimit} words`
}

export const getUrlFromText = (ctx: OnMessageContext | OnCallBackQueryData): string | undefined => {
const entities = ctx.message?.reply_to_message?.entities
if (entities) {
const urlEntity = entities.find(e => e.type === 'url')
if (urlEntity) {
const url = ctx.message?.reply_to_message?.text?.slice(urlEntity.offset, urlEntity.offset + urlEntity.length)
return url
}
}
return undefined
}

// export async function addUrlToCollection (ctx: OnMessageContext | OnCallBackQueryData, chatId: number, url: string, prompt: string): Promise<void> {
// const collectionName = await llmAddUrlDocument({
// chatId,
Expand All @@ -278,15 +294,3 @@ export const limitPrompt = (prompt: string): string => {
// msgId
// })
// }

export const getUrlFromText = (ctx: OnMessageContext | OnCallBackQueryData): string | undefined => {
const entities = ctx.message?.reply_to_message?.entities
if (entities) {
const urlEntity = entities.find(e => e.type === 'url')
if (urlEntity) {
const url = ctx.message?.reply_to_message?.text?.slice(urlEntity.offset, urlEntity.offset + urlEntity.length)
return url
}
}
return undefined
}
181 changes: 90 additions & 91 deletions src/modules/open-ai/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import {
getMessageExtras,
getPromptPrice,
hasChatPrefix,
hasDallePrefix,
hasNewPrefix,
hasPrefix,
hasUrl,
Expand All @@ -43,6 +44,7 @@ import { now } from '../../utils/perf'
import { AxiosError } from 'axios'
import { Callbacks } from '../types'
import { LlmsBot } from '../llms'
import { type PhotoSize } from 'grammy/types'

export class OpenAIBot implements PayableBot {
public readonly module = 'OpenAIBot'
Expand Down Expand Up @@ -99,8 +101,10 @@ export class OpenAIBot implements PayableBot {
return 0
}
if (
ctx.hasCommand(SupportedCommands.dalle.name) ||
ctx.hasCommand(SupportedCommands.dalleLC.name)
ctx.hasCommand([SupportedCommands.dalle.name,
SupportedCommands.dalleImg.name,
SupportedCommands.dalleShort.name,
SupportedCommands.dalleShorter.name])
) {
const imageNumber = ctx.session.openAi.imageGen.numImages
const imageSize = ctx.session.openAi.imageGen.imgSize
Expand Down Expand Up @@ -156,7 +160,19 @@ export class OpenAIBot implements PayableBot {
ctx.transient.analytics.sessionState = RequestState.Success

if (this.isSupportedImageReply(ctx)) {
await this.onAlterImage(ctx)
const photo = ctx.message?.photo ?? ctx.message?.reply_to_message?.photo
const prompt = ctx.message?.caption ?? ctx.message?.text
ctx.session.openAi.imageGen.imgRequestQueue.push({
prompt,
photo,
command: 'alter'
})
if (!ctx.session.openAi.imageGen.isProcessingQueue) {
ctx.session.openAi.imageGen.isProcessingQueue = true
await this.onImgRequestHandler(ctx).then(() => {
ctx.session.openAi.imageGen.isProcessingQueue = false
})
}
return
}

Expand Down Expand Up @@ -213,11 +229,26 @@ export class OpenAIBot implements PayableBot {
}

if (
ctx.hasCommand(SupportedCommands.dalle.name) ||
ctx.hasCommand(SupportedCommands.dalleLC.name) ||
(ctx.message?.text?.startsWith('dalle ') && ctx.chat?.type === 'private')
ctx.hasCommand([SupportedCommands.dalle.name,
SupportedCommands.dalleImg.name,
SupportedCommands.dalleShort.name,
SupportedCommands.dalleShorter.name]) ||
(ctx.message?.text?.startsWith('image ') && ctx.chat?.type === 'private')
) {
await this.onGenImgCmd(ctx)
let prompt = (ctx.match ? ctx.match : ctx.message?.text) as string
if (!prompt || prompt.split(' ').length === 1) {
prompt = config.openAi.dalle.defaultPrompt
}
ctx.session.openAi.imageGen.imgRequestQueue.push({
command: 'dalle',
prompt
})
if (!ctx.session.openAi.imageGen.isProcessingQueue) {
ctx.session.openAi.imageGen.isProcessingQueue = true
await this.onImgRequestHandler(ctx).then(() => {
ctx.session.openAi.imageGen.isProcessingQueue = false
})
}
return
}

Expand All @@ -231,13 +262,34 @@ export class OpenAIBot implements PayableBot {
return
}

if (hasNewPrefix(ctx.message?.text ?? '') !== '') {
const text = ctx.message?.text ?? ''

if (hasNewPrefix(text) !== '') {
await this.onEnd(ctx)
await this.onPrefix(ctx)
return
}

if (hasChatPrefix(ctx.message?.text ?? '') !== '') {
if (hasDallePrefix(text) !== '') {
const prefix = hasDallePrefix(text)
let prompt = (ctx.match ? ctx.match : ctx.message?.text) as string
if (!prompt || prompt.split(' ').length === 1) {
prompt = config.openAi.dalle.defaultPrompt
}
ctx.session.openAi.imageGen.imgRequestQueue.push({
command: 'dalle',
prompt: prompt.slice(prefix.length)
})
if (!ctx.session.openAi.imageGen.isProcessingQueue) {
ctx.session.openAi.imageGen.isProcessingQueue = true
await this.onImgRequestHandler(ctx).then(() => {
ctx.session.openAi.imageGen.isProcessingQueue = false
})
}
return
}

if (hasChatPrefix(text) !== '') {
await this.onPrefix(ctx)
return
}
Expand Down Expand Up @@ -497,23 +549,44 @@ export class OpenAIBot implements PayableBot {
}
}

onGenImgCmd = async (ctx: OnMessageContext | OnCallBackQueryData): Promise<void> => {
try {
if (ctx.session.openAi.imageGen.isEnabled) {
let prompt = (ctx.match ? ctx.match : ctx.message?.text) as string
if (!prompt || prompt.split(' ').length === 1) {
prompt = config.openAi.dalle.defaultPrompt
async onImgRequestHandler (ctx: OnMessageContext | OnCallBackQueryData): Promise<void> {
while (ctx.session.openAi.imageGen.imgRequestQueue.length > 0) {
try {
const img = ctx.session.openAi.imageGen.imgRequestQueue.shift()
if (await this.hasBalance(ctx)) {
if (img?.command === 'dalle') {
await this.onGenImgCmd(img?.prompt, ctx)
} else {
await this.onAlterImage(img?.photo, img?.prompt, ctx)
}
ctx.chatAction = null
} else {
await this.onNotBalanceMessage(ctx)
}
} catch (e: any) {
await this.onError(ctx, e)
}
}
}

onGenImgCmd = async (prompt: string | undefined, ctx: OnMessageContext | OnCallBackQueryData): Promise<void> => {
try {
if (ctx.session.openAi.imageGen.isEnabled && ctx.chat?.id) {
ctx.chatAction = 'upload_photo'
// eslint-disable-next-line @typescript-eslint/naming-convention
const { message_id } = await ctx.reply(
'Generating dalle image...', { message_thread_id: ctx.message?.message_thread_id }
)
const numImages = ctx.session.openAi.imageGen.numImages
const imgSize = ctx.session.openAi.imageGen.imgSize
const imgs = await postGenerateImg(prompt, numImages, imgSize)
const imgs = await postGenerateImg(prompt ?? '', numImages, imgSize)
const msgExtras = getMessageExtras({ caption: `/dalle ${prompt}` })
await Promise.all(imgs.map(async (img: any) => {
await ctx.replyWithPhoto(img.url, msgExtras).catch(async (e) => {
await this.onError(ctx, e, MAX_TRIES)
})
}))
await ctx.api.deleteMessage(ctx.chat?.id, message_id)
ctx.transient.analytics.sessionState = RequestState.Success
ctx.transient.analytics.actualResponseTime = now()
} else {
Expand All @@ -533,12 +606,9 @@ export class OpenAIBot implements PayableBot {
}
}

onAlterImage = async (ctx: OnMessageContext | OnCallBackQueryData): Promise<void> => {
onAlterImage = async (photo: PhotoSize[] | undefined, prompt: string | undefined, ctx: OnMessageContext | OnCallBackQueryData): Promise<void> => {
try {
if (ctx.session.openAi.imageGen.isEnabled) {
const photo =
ctx.message?.photo ?? ctx.message?.reply_to_message?.photo
const prompt = ctx.message?.caption ?? ctx.message?.text
const fileId = photo?.pop()?.file_id // with pop() get full image quality
if (!fileId) {
await ctx.reply('Cannot retrieve the image file. Please try again.')
Expand Down Expand Up @@ -710,74 +780,3 @@ export class OpenAIBot implements PayableBot {
}
}
}

// onGenImgEnCmd = async (ctx: OnMessageContext | OnCallBackQueryData) => {
// try {
// if (ctx.session.openAi.imageGen.isEnabled) {
// const prompt = await ctx.match;
// if (!prompt) {
// sendMessage(ctx, "Error: Missing prompt", {
// topicId: ctx.message?.message_thread_id,
// }).catch((e) =>
// this.onError(ctx, e, MAX_TRIES, "Error: Missing prompt")
// );
// return;
// }
// const payload = {
// chatId: await ctx.chat?.id!,
// prompt: prompt as string,
// numImages: await ctx.session.openAi.imageGen.numImages,
// imgSize: await ctx.session.openAi.imageGen.imgSize,
// };
// sendMessage(ctx, "generating improved prompt...", {
// topicId: ctx.message?.message_thread_id,
// }).catch((e) =>
// this.onError(ctx, e, MAX_TRIES, "generating improved prompt...")
// );
// await imgGenEnhanced(payload, ctx);
// } else {
// sendMessage(ctx, "Bot disabled", {
// topicId: ctx.message?.message_thread_id,
// }).catch((e) => this.onError(ctx, e, MAX_TRIES, "Bot disabled"));
// }
// } catch (e) {
// this.onError(ctx, e);
// }
// };

// private async imgGenEnhanced(
// data: ImageGenPayload,
// ctx: OnMessageContext | OnCallBackQueryData
// ) {
// const { chatId, prompt, numImages, imgSize, model } = data;
// try {
// const upgratedPrompt = await improvePrompt(prompt, model!);
// if (upgratedPrompt) {
// await ctx
// .reply(
// `The following description was added to your prompt: ${upgratedPrompt}`
// )
// .catch((e) => {
// throw e;
// });
// }
// // bot.api.sendMessage(chatId, "generating the output...");
// const imgs = await postGenerateImg(
// upgratedPrompt || prompt,
// numImages,
// imgSize
// );
// imgs.map(async (img: any) => {
// await ctx
// .replyWithPhoto(img.url, {
// caption: `/DALLE ${upgratedPrompt || prompt}`,
// })
// .catch((e) => {
// throw e;
// });
// });
// return true;
// } catch (e) {
// throw e;
// }
// };
Loading

0 comments on commit bf0ceb5

Please sign in to comment.