diff --git a/src/modules/llms/index.ts b/src/modules/llms/index.ts
index 35b381b1..4696ddad 100644
--- a/src/modules/llms/index.ts
+++ b/src/modules/llms/index.ts
@@ -85,7 +85,7 @@ export class LlmsBot implements PayableBot {
     return undefined
   }
 
-  private isSupportedUrlReply (ctx: OnMessageContext | OnCallBackQueryData): string | undefined {
+  private isSupportedUrlReply (ctx: OnMessageContext | OnCallBackQueryData): string[] | undefined {
     return getUrlFromText(ctx)
   }
 
@@ -251,14 +251,16 @@ export class LlmsBot implements PayableBot {
   async onUrlReplyHandler (ctx: OnMessageContext | OnCallBackQueryData): Promise<void> {
     try {
-      const url = getUrlFromText(ctx) ?? ''
-      const prompt = ctx.message?.text ?? 'summarize'
-      const collection = ctx.session.collections.activeCollections.find(c => c.url === url)
-      const newPrompt = `${prompt}` // ${url}
-      if (collection) {
-        await this.queryUrlCollection(ctx, url, newPrompt)
+      const url = getUrlFromText(ctx)
+      if (url) {
+        const prompt = ctx.message?.text ?? 'summarize'
+        const collection = ctx.session.collections.activeCollections.find(c => c.url === url[0])
+        const newPrompt = `${prompt}` // ${url}
+        if (collection) {
+          await this.queryUrlCollection(ctx, url[0], newPrompt)
+        }
+        ctx.transient.analytics.actualResponseTime = now()
       }
-      ctx.transient.analytics.actualResponseTime = now()
     } catch (e: any) {
       await this.onError(ctx, e)
     }
@@ -550,7 +552,7 @@ export class LlmsBot implements PayableBot {
       await ctx.api.editMessageText(
         ctx.chat.id,
         msgId,
-        response.completion.content
+        response.completion.content as string
       )
       conversation.push(response.completion)
       // const price = getPromptPrice(completion, data);
@@ -629,7 +631,7 @@ export class LlmsBot implements PayableBot {
     while (ctx.session.llms.requestQueue.length > 0) {
       try {
         const msg = ctx.session.llms.requestQueue.shift()
-        const prompt = msg?.content
+        const prompt = msg?.content as string
         const model = msg?.model
         const { chatConversation } = ctx.session.llms
         if (await this.hasBalance(ctx)) {
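A minimal sketch of the caller-side pattern this file now follows: `getUrlFromText` returns `string[] | undefined`, so handlers guard once and index explicitly. The regex helper below is a simplified stand-in for the real entity-based extraction in `helpers.ts` further down.

```ts
// Stand-in for getUrlFromText: returns every URL found, or undefined.
const getUrlFromText = (text: string): string[] | undefined => {
  const matches = text.match(/https?:\/\/\S+/g)
  return matches ?? undefined
}

async function onUrlReply (text: string): Promise<void> {
  const urls = getUrlFromText(text)
  if (urls) {
    // Only the first URL is queried, mirroring queryUrlCollection(ctx, url[0], ...)
    console.log(`querying collection for ${urls[0]}`)
  }
  // Note: in the hunk above the analytics timestamp also moves inside the
  // guard, so a reply with no URL records neither a query nor a response time.
}

void onUrlReply('summarize https://example.com')
```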
diff --git a/src/modules/open-ai/api/openAi.ts b/src/modules/open-ai/api/openAi.ts
index d1d431d5..ad34603d 100644
--- a/src/modules/open-ai/api/openAi.ts
+++ b/src/modules/open-ai/api/openAi.ts
@@ -1,7 +1,6 @@
 import OpenAI from 'openai'
 import { encode } from 'gpt-tokenizer'
 import { GrammyError } from 'grammy'
-
 import config from '../../../config'
 import { deleteFile, getImage } from '../utils/file'
 import {
@@ -15,9 +14,12 @@ import {
   type ChatModel,
   ChatGPTModels,
   type DalleGPTModel,
-  DalleGPTModels
+  DalleGPTModels,
+  ChatGPTModelsEnum
 } from '../types'
 import type fs from 'fs'
+import { type ChatCompletionMessageParam } from 'openai/resources/chat/completions'
+import { type Stream } from 'openai/streaming'
 
 const openai = new OpenAI({ apiKey: config.openAiKey })
@@ -83,15 +85,12 @@ export async function chatCompletion (
   model = config.openAi.chatGpt.model,
   limitTokens = true
 ): Promise<ChatCompletion> {
-  const payload = {
+  const response = await openai.chat.completions.create({
     model,
     max_tokens: limitTokens ? config.openAi.chatGpt.maxTokens : undefined,
     temperature: config.openAi.dalle.completions.temperature,
-    messages: conversation
-  }
-  const response = await openai.chat.completions.create(
-    payload as OpenAI.Chat.CompletionCreateParamsNonStreaming
-  )
+    messages: conversation as ChatCompletionMessageParam[]
+  })
   const chatModel = getChatModel(model)
   if (response.usage?.prompt_tokens === undefined) {
     throw new Error('Unknown number of prompt tokens used')
@@ -120,7 +119,7 @@ export const streamChatCompletion = async (
   let wordCountMinimum = 2
   const stream = await openai.chat.completions.create({
     model,
-    messages: conversation as OpenAI.Chat.Completions.CreateChatCompletionRequestMessage[],
+    messages: conversation as ChatCompletionMessageParam[], // OpenAI.Chat.Completions.CreateChatCompletionRequestMessage[],
     stream: true,
     max_tokens: limitTokens ? config.openAi.chatGpt.maxTokens : undefined,
     temperature: config.openAi.dalle.completions.temperature || 0.8
@@ -177,13 +176,85 @@ export const streamChatCompletion = async (
       }
     })
   return completion
-  // } catch (e) {
-  //   reject(e)
-  // }
-  // })
-  // } catch (error: any) {
-  //   return await Promise.reject(error)
-  // }
+}
+
+export const streamChatVisionCompletion = async (
+  ctx: OnMessageContext | OnCallBackQueryData,
+  model = ChatGPTModelsEnum.GPT_4_VISION_PREVIEW,
+  prompt: string,
+  imgUrls: string[],
+  msgId: number,
+  limitTokens = true
+): Promise<string> => {
+  let completion = ''
+  let wordCountMinimum = 2
+  const payload: any = {
+    model,
+    messages: [
+      {
+        role: 'user',
+        content: [
+          { type: 'text', text: prompt },
+          ...imgUrls.map(img => ({
+            type: 'image_url',
+            image_url: { url: img }
+          }))
+        ]
+      }
+    ],
+    stream: true,
+    max_tokens: limitTokens ? config.openAi.chatGpt.maxTokens : undefined
+  }
+  const stream = await openai.chat.completions.create(payload) as unknown as Stream<OpenAI.Chat.Completions.ChatCompletionChunk>
+  let wordCount = 0
+  if (!ctx.chat?.id) {
+    throw new Error('Context chat id should not be empty after openAI streaming')
+  }
+  for await (const part of stream) {
+    wordCount++
+    const chunk = part.choices[0]?.delta?.content
+      ? part.choices[0]?.delta?.content
+      : ''
+    completion += chunk
+
+    if (wordCount > wordCountMinimum) {
+      if (wordCountMinimum < 64) {
+        wordCountMinimum *= 2
+      }
+      completion = completion.replaceAll('...', '')
+      completion += '...'
+      wordCount = 0
+      await ctx.api
+        .editMessageText(ctx.chat?.id, msgId, completion)
+        .catch(async (e: any) => {
+          if (e instanceof GrammyError) {
+            if (e.error_code !== 400) {
+              throw e
+            } else {
+              logger.error(e)
+            }
+          } else {
+            throw e
+          }
+        })
+    }
+  }
+  completion = completion.replaceAll('...', '')
+
+  await ctx.api
+    .editMessageText(ctx.chat?.id, msgId, completion)
+    .catch((e: any) => {
+      if (e instanceof GrammyError) {
+        if (e.error_code !== 400) {
+          throw e
+        } else {
+          logger.error(e)
+        }
+      } else {
+        throw e
+      }
+    })
+  return completion
 }
 
 export async function improvePrompt (promptText: string, model: string): Promise<string> {
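For context, a standalone sketch of the streaming vision call that `streamChatVisionCompletion` wraps: a single user message whose `content` mixes one text part with one `image_url` part per image, consumed chunk by chunk. It assumes `OPENAI_API_KEY` is set and the account has access to `gpt-4-vision-preview`; the token cap is illustrative.

```ts
import OpenAI from 'openai'

const openai = new OpenAI({ apiKey: process.env.OPENAI_API_KEY })

async function describeImages (prompt: string, imgUrls: string[]): Promise<string> {
  const stream = await openai.chat.completions.create({
    model: 'gpt-4-vision-preview',
    stream: true,
    max_tokens: 800,
    messages: [{
      role: 'user',
      content: [
        { type: 'text', text: prompt },
        // one image_url part per attached image, as in the payload above
        ...imgUrls.map(url => ({ type: 'image_url' as const, image_url: { url } }))
      ]
    }]
  })
  let completion = ''
  for await (const part of stream) {
    completion += part.choices[0]?.delta?.content ?? ''
  }
  return completion
}

void describeImages('What is in this photo?', ['https://example.com/photo.jpg'])
  .then(console.log)
```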
diff --git a/src/modules/open-ai/helpers.ts b/src/modules/open-ai/helpers.ts
index 42b71457..5a858a41 100644
--- a/src/modules/open-ai/helpers.ts
+++ b/src/modules/open-ai/helpers.ts
@@ -9,7 +9,7 @@ import { isValidUrl } from './utils/web-crawler'
 export const SupportedCommands = {
   chat: { name: 'chat' },
   ask: { name: 'ask' },
-  // sum: { name: 'sum' },
+  vision: { name: 'vision' },
   ask35: { name: 'ask35' },
   new: { name: 'new' },
   gpt4: { name: 'gpt4' },
@@ -235,8 +235,8 @@ export const hasPrefix = (prompt: string): string => {
 
 export const getPromptPrice = (completion: string, data: ChatPayload): { price: number, promptTokens: number, completionTokens: number } => {
   const { conversation, ctx, model } = data
-  const prompt = conversation[conversation.length - 1].content
-  const promptTokens = getTokenNumber(prompt)
+  const prompt = data.prompt ? data.prompt : conversation[conversation.length - 1].content
+  const promptTokens = getTokenNumber(prompt as string)
   const completionTokens = getTokenNumber(completion)
   const modelPrice = getChatModel(model)
   const price =
@@ -263,13 +263,14 @@ export const limitPrompt = (prompt: string): string => {
   return `${prompt} in around ${config.openAi.chatGpt.wordLimit} words`
 }
 
-export const getUrlFromText = (ctx: OnMessageContext | OnCallBackQueryData): string | undefined => {
-  const entities = ctx.message?.reply_to_message?.entities
-  if (entities) {
-    const urlEntity = entities.find(e => e.type === 'url')
-    if (urlEntity) {
-      const url = ctx.message?.reply_to_message?.text?.slice(urlEntity.offset, urlEntity.offset + urlEntity.length)
-      return url
+export const getUrlFromText = (ctx: OnMessageContext | OnCallBackQueryData): string[] | undefined => {
+  const entities = ctx.message?.entities ? ctx.message?.entities : ctx.message?.reply_to_message?.entities
+  const text = ctx.message?.text ? ctx.message?.text : ctx.message?.reply_to_message?.text
+  if (entities && text) {
+    const urlEntity = entities.filter(e => e.type === 'url')
+    if (urlEntity.length > 0) {
+      const urls = urlEntity.map(e => text.slice(e.offset, e.offset + e.length))
+      return urls
     }
   }
   return undefined
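A worked example of the entity slicing above. Telegram reports `MessageEntity` offsets and lengths in UTF-16 code units, which is also what `String.prototype.slice` operates on, so the arithmetic transfers directly. The text and entities here are hand-made fixtures, not real Telegram traffic.

```ts
interface MessageEntity { type: string, offset: number, length: number }

const text = 'compare https://a.example and https://b.example'
const entities: MessageEntity[] = [
  { type: 'url', offset: 8, length: 17 },
  { type: 'url', offset: 30, length: 17 }
]

// filter + map replaces the old find + slice, so every URL is collected
const urls = entities
  .filter(e => e.type === 'url')
  .map(e => text.slice(e.offset, e.offset + e.length))

console.log(urls) // [ 'https://a.example', 'https://b.example' ]
```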
diff --git a/src/modules/open-ai/index.ts b/src/modules/open-ai/index.ts
index 809006aa..746cc18c 100644
--- a/src/modules/open-ai/index.ts
+++ b/src/modules/open-ai/index.ts
@@ -18,7 +18,8 @@ import {
   getDalleModel,
   getDalleModelPrice,
   postGenerateImg,
-  streamChatCompletion
+  streamChatCompletion,
+  streamChatVisionCompletion
 } from './api/openAi'
 import { appText } from './utils/text'
 import { chatService } from '../../database/services'
@@ -28,6 +29,7 @@ import { sleep } from '../sd-images/utils'
 import {
   getMessageExtras,
   getPromptPrice,
+  getUrlFromText,
   hasChatPrefix,
   hasDallePrefix,
   hasNewPrefix,
@@ -90,7 +92,7 @@ export class OpenAIBot implements PayableBot {
     try {
       const priceAdjustment = config.openAi.chatGpt.priceAdjustment
       const prompts = ctx.match
-      if (this.isSupportedImageReply(ctx)) {
+      if (this.isSupportedImageReply(ctx) && !isNaN(+prompts)) {
         const imageNumber = ctx.message?.caption || ctx.message?.text
         const imageSize = ctx.session.openAi.imageGen.imgSize
         const model = getDalleModel(imageSize)
@@ -142,7 +144,7 @@ export class OpenAIBot implements PayableBot {
     const photo = ctx.message?.photo ?? ctx.message?.reply_to_message?.photo
     if (photo && ctx.session.openAi.imageGen.isEnabled) {
       const prompt = ctx.message?.caption ?? ctx.message?.text
-      if (prompt && !isNaN(+prompt)) {
+      if (prompt) { // && !isNaN(+prompt)
         return true
       }
     }
@@ -161,11 +163,11 @@ export class OpenAIBot implements PayableBot {
 
     if (this.isSupportedImageReply(ctx)) {
       const photo = ctx.message?.photo ?? ctx.message?.reply_to_message?.photo
-      const prompt = ctx.message?.caption ?? ctx.message?.text
+      const prompt = ctx.message?.caption ?? ctx.message?.text ?? ''
       ctx.session.openAi.imageGen.imgRequestQueue.push({
         prompt,
         photo,
-        command: 'alter'
+        command: !isNaN(+prompt) ? 'alter' : 'vision'
       })
       if (!ctx.session.openAi.imageGen.isProcessingQueue) {
         ctx.session.openAi.imageGen.isProcessingQueue = true
@@ -228,6 +230,24 @@ export class OpenAIBot implements PayableBot {
       return
     }
 
+    if (ctx.hasCommand(SupportedCommands.vision.name)) {
+      const photoUrl = getUrlFromText(ctx)
+      if (photoUrl) {
+        const prompt = ctx.match
+        ctx.session.openAi.imageGen.imgRequestQueue.push({
+          prompt,
+          photoUrl,
+          command: !isNaN(+prompt) ? 'alter' : 'vision'
+        })
+        if (!ctx.session.openAi.imageGen.isProcessingQueue) {
+          ctx.session.openAi.imageGen.isProcessingQueue = true
+          await this.onImgRequestHandler(ctx).then(() => {
+            ctx.session.openAi.imageGen.isProcessingQueue = false
+          })
+        }
+      }
+    }
+
     if (
       ctx.hasCommand([SupportedCommands.dalle.name,
         SupportedCommands.dalleImg.name,
@@ -556,8 +576,10 @@ export class OpenAIBot implements PayableBot {
         if (await this.hasBalance(ctx)) {
           if (img?.command === 'dalle') {
             await this.onGenImgCmd(img?.prompt, ctx)
-          } else {
+          } else if (img?.command === 'alter') {
             await this.onAlterImage(img?.photo, img?.prompt, ctx)
+          } else {
+            await this.onInquiryImage(img?.photo, img?.photoUrl, img?.prompt, ctx)
           }
           ctx.chatAction = null
         } else {
@@ -575,7 +597,7 @@ export class OpenAIBot implements PayableBot {
       ctx.chatAction = 'upload_photo'
       // eslint-disable-next-line @typescript-eslint/naming-convention
       const { message_id } = await ctx.reply(
-        'Generating dalle image...', { message_thread_id: ctx.message?.message_thread_id }
+        'Generating image via OpenAI\'s DALL·E 3...', { message_thread_id: ctx.message?.message_thread_id }
       )
       const numImages = ctx.session.openAi.imageGen.numImages
       const imgSize = ctx.session.openAi.imageGen.imgSize
@@ -606,6 +628,63 @@ export class OpenAIBot implements PayableBot {
     }
   }
 
+  onInquiryImage = async (photo: PhotoSize[] | undefined, photoUrl: string[] | undefined, prompt: string | undefined, ctx: OnMessageContext | OnCallBackQueryData): Promise<void> => {
+    try {
+      if (ctx.session.openAi.imageGen.isEnabled) {
+        // let filePath = ''
+        let imgList = []
+        if (photo) {
+          const fileId = photo?.pop()?.file_id // with pop() get full image quality
+          if (!fileId) {
+            await ctx.reply('Cannot retrieve the image file. Please try again.')
+            ctx.transient.analytics.actualResponseTime = now()
+            return
+          }
+          const file = await ctx.api.getFile(fileId)
+          imgList.push(`${config.openAi.dalle.telegramFileUrl}${config.telegramBotAuthToken}/${file.file_path}`)
+        } else {
+          imgList = photoUrl ?? []
+        }
+        const msgId = (
+          await ctx.reply('...', {
+            message_thread_id:
+              ctx.message?.message_thread_id ??
+              ctx.message?.reply_to_message?.message_thread_id
+          })
+        ).message_id
+        const model = ChatGPTModelsEnum.GPT_4_VISION_PREVIEW
+        const completion = await streamChatVisionCompletion(ctx, model, prompt ?? '', imgList, msgId, true)
+        if (completion) {
+          ctx.transient.analytics.sessionState = RequestState.Success
+          ctx.transient.analytics.actualResponseTime = now()
+          const price = getPromptPrice(completion, {
+            conversation: [],
+            prompt,
+            model,
+            ctx
+          })
+          this.logger.info(
+            `streamChatCompletion result = tokens: ${
+              price.promptTokens + price.completionTokens
+            } | ${model} | price: ${price.price}¢`
+          )
+          if (
+            !(await this.payments.pay(ctx as OnMessageContext, price.price))
+          ) {
+            await this.onNotBalanceMessage(ctx)
+          }
+        }
+      }
+    } catch (e: any) {
+      await this.onError(
+        ctx,
+        e,
+        MAX_TRIES,
+        'An error occurred while generating the AI edit'
+      )
+    }
+  }
+
   onAlterImage = async (photo: PhotoSize[] | undefined, prompt: string | undefined, ctx: OnMessageContext | OnCallBackQueryData): Promise<void> => {
     try {
       if (ctx.session.openAi.imageGen.isEnabled) {
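The same `!isNaN(+prompt)` heuristic appears in both queue producers above: a purely numeric caption is treated as a DALL·E variation count (`'alter'`), anything else as a vision question (`'vision'`). A small sketch, including one edge case worth knowing: `+''` coerces to `0`, so an empty caption also routes to `'alter'`.

```ts
type ImgCommand = 'dalle' | 'alter' | 'vision'

const routeImageReply = (prompt: string): ImgCommand =>
  !isNaN(+prompt) ? 'alter' : 'vision'

console.log(routeImageReply('2'))               // 'alter'  (generate 2 variations)
console.log(routeImageReply('what is shown?'))  // 'vision' (GPT-4 vision inquiry)
console.log(routeImageReply(''))                // 'alter'  (+'' is 0, not NaN)
```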
diff --git a/src/modules/open-ai/types.ts b/src/modules/open-ai/types.ts
index 46a60c9e..3feb117e 100644
--- a/src/modules/open-ai/types.ts
+++ b/src/modules/open-ai/types.ts
@@ -16,6 +16,7 @@ export enum ChatGPTModelsEnum {
   GPT_4_32K = 'gpt-4-32k',
   GPT_35_TURBO = 'gpt-3.5-turbo',
   GPT_35_TURBO_16K = 'gpt-3.5-turbo-16k',
+  GPT_4_VISION_PREVIEW = 'gpt-4-vision-preview'
 }
 
 export const ChatGPTModels: Record<string, ChatModel> = {
@@ -46,6 +47,13 @@ export const ChatGPTModels: Record<string, ChatModel> = {
     outputPrice: 0.004,
     maxContextTokens: 16000,
     chargeType: 'TOKEN'
+  },
+  'gpt-4-vision-preview': {
+    name: 'gpt-4-vision-preview',
+    inputPrice: 0.03,
+    outputPrice: 0.06,
+    maxContextTokens: 16000,
+    chargeType: 'TOKEN'
   }
 }
diff --git a/src/modules/payment/index.ts b/src/modules/payment/index.ts
index 0b8ba854..5f855685 100644
--- a/src/modules/payment/index.ts
+++ b/src/modules/payment/index.ts
@@ -388,7 +388,6 @@ export class BotPayments {
   public async pay (ctx: OnMessageContext, amountUSD: number): Promise<boolean> {
     // eslint-disable-next-line @typescript-eslint/naming-convention
     const { from, message_id, chat } = ctx.update.message
-
     const accountId = this.getAccountId(ctx)
     const userAccount = this.getUserAccount(accountId)
     if (!userAccount) {
diff --git a/src/modules/types.ts b/src/modules/types.ts
index 2d34a5b6..ef9073e4 100644
--- a/src/modules/types.ts
+++ b/src/modules/types.ts
@@ -36,20 +36,28 @@ export interface ChatCompletion {
 }
 export interface ChatPayload {
   conversation: ChatConversation[]
+  prompt?: string
   model: string
   ctx: OnMessageContext | OnCallBackQueryData
 }
+
+export interface VisionContent {
+  type: string
+  text?: string
+  image_url?: { url: string }
+}
 export interface ChatConversation {
   role?: string
   author?: string
-  content: string
+  content: string | VisionContent[]
   model?: string
 }
 export interface ImageRequest {
-  command?: 'dalle' | 'alter'
+  command?: 'dalle' | 'alter' | 'vision'
   prompt?: string
   photo?: PhotoSize[] | undefined
+  photoUrl?: string[]
 }
 export interface ChatGptSessionData {
   model: string
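To illustrate the widened types, a sketch of a vision turn under the new `ChatConversation`/`VisionContent` shapes, plus a cost estimate against the new `gpt-4-vision-preview` table entry. Reading `inputPrice`/`outputPrice` as USD per 1K tokens is an assumption about the table's units.

```ts
interface VisionContent {
  type: string
  text?: string
  image_url?: { url: string }
}

interface ChatConversation {
  role?: string
  content: string | VisionContent[]
  model?: string
}

const turn: ChatConversation = {
  role: 'user',
  model: 'gpt-4-vision-preview',
  content: [
    { type: 'text', text: 'What landmark is this?' },
    { type: 'image_url', image_url: { url: 'https://example.com/photo.jpg' } }
  ]
}

// Hypothetical cost, assuming prices are USD per 1K tokens:
const promptTokens = 120
const completionTokens = 80
const priceUSD = (promptTokens * 0.03 + completionTokens * 0.06) / 1000
console.log(turn.role, priceUSD.toFixed(4)) // user 0.0084
```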
diff --git a/tsconfig.json b/tsconfig.json
index 4c6d0019..932a5b76 100644
--- a/tsconfig.json
+++ b/tsconfig.json
@@ -4,7 +4,7 @@
     "module": "CommonJS",
     "outDir": "dist",
     "types": ["node"],
-    "lib": ["es2022"],
+    "lib": ["es2022"], // , "dom", "dom.iterable"]
     "target": "es2020",
    "emitDecoratorMetadata": true,
    "experimentalDecorators": true,