add vision logic #348

Merged · 6 commits · Jan 17, 2024
22 changes: 12 additions & 10 deletions src/modules/llms/index.ts

@@ -85,7 +85,7 @@ export class LlmsBot implements PayableBot {
     return undefined
   }
 
-  private isSupportedUrlReply (ctx: OnMessageContext | OnCallBackQueryData): string | undefined {
+  private isSupportedUrlReply (ctx: OnMessageContext | OnCallBackQueryData): string[] | undefined {
     return getUrlFromText(ctx)
   }
 
@@ -251,14 +251,16 @@ export class LlmsBot implements PayableBot {
 
   async onUrlReplyHandler (ctx: OnMessageContext | OnCallBackQueryData): Promise<void> {
     try {
-      const url = getUrlFromText(ctx) ?? ''
-      const prompt = ctx.message?.text ?? 'summarize'
-      const collection = ctx.session.collections.activeCollections.find(c => c.url === url)
-      const newPrompt = `${prompt}` // ${url}
-      if (collection) {
-        await this.queryUrlCollection(ctx, url, newPrompt)
+      const url = getUrlFromText(ctx)
+      if (url) {
+        const prompt = ctx.message?.text ?? 'summarize'
+        const collection = ctx.session.collections.activeCollections.find(c => c.url === url[0])
+        const newPrompt = `${prompt}` // ${url}
+        if (collection) {
+          await this.queryUrlCollection(ctx, url[0], newPrompt)
+        }
+        ctx.transient.analytics.actualResponseTime = now()
       }
-      ctx.transient.analytics.actualResponseTime = now()
     } catch (e: any) {
       await this.onError(ctx, e)
     }
@@ -550,7 +552,7 @@
       await ctx.api.editMessageText(
         ctx.chat.id,
         msgId,
-        response.completion.content
+        response.completion.content as string
       )
       conversation.push(response.completion)
       // const price = getPromptPrice(completion, data);
@@ -629,7 +631,7 @@
     while (ctx.session.llms.requestQueue.length > 0) {
       try {
         const msg = ctx.session.llms.requestQueue.shift()
-        const prompt = msg?.content
+        const prompt = msg?.content as string
         const model = msg?.model
         const { chatConversation } = ctx.session.llms
         if (await this.hasBalance(ctx)) {
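
Note on the call sites above: since getUrlFromText now returns string[] | undefined, every consumer has to narrow before indexing, as onUrlReplyHandler does with its if (url) guard. A minimal standalone sketch of that pattern (getUrls and firstUrlOrSkip are illustrative names, not part of this PR):

```ts
// Sketch: narrowing a string[] | undefined result before indexing.
// getUrls stands in for getUrlFromText; the real ctx type is much richer.
function getUrls (text: string): string[] | undefined {
  const matches = text.match(/https?:\/\/\S+/g)
  return matches ?? undefined // match() with /g returns all matches or null
}

function firstUrlOrSkip (text: string): string | undefined {
  const urls = getUrls(text)
  if (!urls) return undefined // guard: nothing to index into
  return urls[0] // safe: a non-undefined result is non-empty by construction
}
```
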
103 changes: 87 additions & 16 deletions src/modules/open-ai/api/openAi.ts

@@ -1,7 +1,6 @@
 import OpenAI from 'openai'
 import { encode } from 'gpt-tokenizer'
 import { GrammyError } from 'grammy'
-
 import config from '../../../config'
 import { deleteFile, getImage } from '../utils/file'
 import {
@@ -15,9 +14,12 @@ import {
   type ChatModel,
   ChatGPTModels,
   type DalleGPTModel,
-  DalleGPTModels
+  DalleGPTModels,
+  ChatGPTModelsEnum
 } from '../types'
 import type fs from 'fs'
+import { type ChatCompletionMessageParam } from 'openai/resources/chat/completions'
+import { type Stream } from 'openai/streaming'
 
 const openai = new OpenAI({ apiKey: config.openAiKey })
 
@@ -83,15 +85,12 @@ export async function chatCompletion (
   model = config.openAi.chatGpt.model,
   limitTokens = true
 ): Promise<ChatCompletion> {
-  const payload = {
+  const response = await openai.chat.completions.create({
     model,
     max_tokens: limitTokens ? config.openAi.chatGpt.maxTokens : undefined,
     temperature: config.openAi.dalle.completions.temperature,
-    messages: conversation
-  }
-  const response = await openai.chat.completions.create(
-    payload as OpenAI.Chat.CompletionCreateParamsNonStreaming
-  )
+    messages: conversation as ChatCompletionMessageParam[]
+  })
   const chatModel = getChatModel(model)
   if (response.usage?.prompt_tokens === undefined) {
     throw new Error('Unknown number of prompt tokens used')
@@ -120,7 +119,7 @@ export const streamChatCompletion = async (
   let wordCountMinimum = 2
   const stream = await openai.chat.completions.create({
     model,
-    messages: conversation as OpenAI.Chat.Completions.CreateChatCompletionRequestMessage[],
+    messages: conversation as ChatCompletionMessageParam[], // OpenAI.Chat.Completions.CreateChatCompletionRequestMessage[],
     stream: true,
     max_tokens: limitTokens ? config.openAi.chatGpt.maxTokens : undefined,
     temperature: config.openAi.dalle.completions.temperature || 0.8
@@ -177,13 +176,85 @@ export const streamChatCompletion = async (
     }
   })
   return completion
-  // } catch (e) {
-  //   reject(e)
-  // }
-  // })
-  // } catch (error: any) {
-  //   return await Promise.reject(error)
-  // }
 }
 
+export const streamChatVisionCompletion = async (
+  ctx: OnMessageContext | OnCallBackQueryData,
+  model = ChatGPTModelsEnum.GPT_4_VISION_PREVIEW,
+  prompt: string,
+  imgUrls: string[],
+  msgId: number,
+  limitTokens = true
+): Promise<string> => {
+  let completion = ''
+  let wordCountMinimum = 2
+  const payload: any = {
+    model,
+    messages: [
+      {
+        role: 'user',
+        content: [
+          { type: 'text', text: prompt },
+          ...imgUrls.map(img => ({
+            type: 'image_url',
+            image_url: { url: img }
+          }))
+        ]
+      }
+    ],
+    stream: true,
+    max_tokens: limitTokens ? config.openAi.chatGpt.maxTokens : undefined
+  }
+  const stream = await openai.chat.completions.create(payload) as unknown as Stream<OpenAI.Chat.Completions.ChatCompletionChunk>
+  let wordCount = 0
+  if (!ctx.chat?.id) {
+    throw new Error('Context chat id should not be empty after openAI streaming')
+  }
+  for await (const part of stream) {
+    wordCount++
+    const chunk = part.choices[0]?.delta?.content
+      ? part.choices[0]?.delta?.content
+      : ''
+    completion += chunk
+
+    if (wordCount > wordCountMinimum) {
+      if (wordCountMinimum < 64) {
+        wordCountMinimum *= 2
+      }
+      completion = completion.replaceAll('...', '')
+      completion += '...'
+      wordCount = 0
+      await ctx.api
+        .editMessageText(ctx.chat?.id, msgId, completion)
+        .catch(async (e: any) => {
+          if (e instanceof GrammyError) {
+            if (e.error_code !== 400) {
+              throw e
+            } else {
+              logger.error(e)
+            }
+          } else {
+            throw e
+          }
+        })
+    }
+  }
+  completion = completion.replaceAll('...', '')
+
+  await ctx.api
+    .editMessageText(ctx.chat?.id, msgId, completion)
+    .catch((e: any) => {
+      if (e instanceof GrammyError) {
+        if (e.error_code !== 400) {
+          throw e
+        } else {
+          logger.error(e)
        }
+      } else {
+        throw e
+      }
+    })
+  return completion
+}
 
 export async function improvePrompt (promptText: string, model: string): Promise<string> {
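
For readers unfamiliar with the vision API: the payload streamChatVisionCompletion assembles follows OpenAI's multimodal chat-completions shape, a single user message whose content array mixes one text part with one image_url part per image. A self-contained sketch of just that construction (buildVisionMessage, the type alias, and the example URL are illustrative, not from the PR):

```ts
// Sketch: building a multimodal user message for the OpenAI
// chat-completions API: one text part plus one image_url part per image.
type ContentPart =
  | { type: 'text', text: string }
  | { type: 'image_url', image_url: { url: string } }

function buildVisionMessage (prompt: string, imgUrls: string[]): { role: 'user', content: ContentPart[] } {
  return {
    role: 'user',
    content: [
      { type: 'text', text: prompt },
      ...imgUrls.map((url): ContentPart => ({ type: 'image_url', image_url: { url } }))
    ]
  }
}

// e.g. buildVisionMessage('What is in this photo?', ['https://example.com/cat.jpg'])
```

The streaming loop then throttles its Telegram edits: it doubles wordCountMinimum (up to 64) after each edit, so message updates become progressively less frequent as the completion grows, which keeps the bot under Telegram's edit rate limits.
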
21 changes: 11 additions & 10 deletions src/modules/open-ai/helpers.ts

@@ -9,7 +9,7 @@ import { isValidUrl } from './utils/web-crawler'
 export const SupportedCommands = {
   chat: { name: 'chat' },
   ask: { name: 'ask' },
-  // sum: { name: 'sum' },
+  vision: { name: 'vision' },
   ask35: { name: 'ask35' },
   new: { name: 'new' },
   gpt4: { name: 'gpt4' },
@@ -235,8 +235,8 @@
 export const getPromptPrice = (completion: string, data: ChatPayload): { price: number, promptTokens: number, completionTokens: number } => {
   const { conversation, ctx, model } = data
 
-  const prompt = conversation[conversation.length - 1].content
-  const promptTokens = getTokenNumber(prompt)
+  const prompt = data.prompt ? data.prompt : conversation[conversation.length - 1].content
+  const promptTokens = getTokenNumber(prompt as string)
   const completionTokens = getTokenNumber(completion)
   const modelPrice = getChatModel(model)
   const price =
@@ -263,13 +263,14 @@
   return `${prompt} in around ${config.openAi.chatGpt.wordLimit} words`
 }
 
-export const getUrlFromText = (ctx: OnMessageContext | OnCallBackQueryData): string | undefined => {
-  const entities = ctx.message?.reply_to_message?.entities
-  if (entities) {
-    const urlEntity = entities.find(e => e.type === 'url')
-    if (urlEntity) {
-      const url = ctx.message?.reply_to_message?.text?.slice(urlEntity.offset, urlEntity.offset + urlEntity.length)
-      return url
+export const getUrlFromText = (ctx: OnMessageContext | OnCallBackQueryData): string[] | undefined => {
+  const entities = ctx.message?.entities ? ctx.message?.entities : ctx.message?.reply_to_message?.entities
+  const text = ctx.message?.text ? ctx.message?.text : ctx.message?.reply_to_message?.text
+  if (entities && text) {
+    const urlEntity = entities.filter(e => e.type === 'url')
+    if (urlEntity.length > 0) {
+      const urls = urlEntity.map(e => text.slice(e.offset, e.offset + e.length))
+      return urls
     }
   }
   return undefined
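
The rewritten getUrlFromText now reads entities from the message itself, falling back to the replied-to message, and collects every url entity instead of only the first. A small illustrative check of the offset/length slicing with a hand-built entity list (extractUrls is a stand-in name; real Telegram offsets are UTF-16 code units, which match JavaScript string indexing):

```ts
// Sketch: slicing URLs out of message text via Telegram-style entities.
// Entities carry offsets into the text; hand-built here for illustration.
interface MessageEntity { type: string, offset: number, length: number }

function extractUrls (text: string, entities: MessageEntity[]): string[] | undefined {
  const urlEntities = entities.filter(e => e.type === 'url')
  if (urlEntities.length === 0) return undefined
  return urlEntities.map(e => text.slice(e.offset, e.offset + e.length))
}

const text = 'compare https://a.example and https://b.example please'
const entities = [
  { type: 'url', offset: 8, length: 17 },
  { type: 'url', offset: 30, length: 17 }
]
console.log(extractUrls(text, entities)) // [ 'https://a.example', 'https://b.example' ]
```
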
93 changes: 86 additions & 7 deletions src/modules/open-ai/index.ts

@@ -18,7 +18,8 @@ import {
   getDalleModel,
   getDalleModelPrice,
   postGenerateImg,
-  streamChatCompletion
+  streamChatCompletion,
+  streamChatVisionCompletion
 } from './api/openAi'
 import { appText } from './utils/text'
 import { chatService } from '../../database/services'
@@ -28,6 +29,7 @@ import { sleep } from '../sd-images/utils'
 import {
   getMessageExtras,
   getPromptPrice,
+  getUrlFromText,
   hasChatPrefix,
   hasDallePrefix,
   hasNewPrefix,
@@ -90,7 +92,7 @@
     try {
       const priceAdjustment = config.openAi.chatGpt.priceAdjustment
       const prompts = ctx.match
-      if (this.isSupportedImageReply(ctx)) {
+      if (this.isSupportedImageReply(ctx) && !isNaN(+prompts)) {
        const imageNumber = ctx.message?.caption || ctx.message?.text
         const imageSize = ctx.session.openAi.imageGen.imgSize
         const model = getDalleModel(imageSize)
@@ -142,7 +144,7 @@
     const photo = ctx.message?.photo ?? ctx.message?.reply_to_message?.photo
     if (photo && ctx.session.openAi.imageGen.isEnabled) {
       const prompt = ctx.message?.caption ?? ctx.message?.text
-      if (prompt && !isNaN(+prompt)) {
+      if (prompt) { // && !isNaN(+prompt)
         return true
       }
     }
@@ -161,11 +163,11 @@
 
     if (this.isSupportedImageReply(ctx)) {
       const photo = ctx.message?.photo ?? ctx.message?.reply_to_message?.photo
-      const prompt = ctx.message?.caption ?? ctx.message?.text
+      const prompt = ctx.message?.caption ?? ctx.message?.text ?? ''
       ctx.session.openAi.imageGen.imgRequestQueue.push({
         prompt,
         photo,
-        command: 'alter'
+        command: !isNaN(+prompt) ? 'alter' : 'vision'
       })
       if (!ctx.session.openAi.imageGen.isProcessingQueue) {
         ctx.session.openAi.imageGen.isProcessingQueue = true
@@ -228,6 +230,24 @@
       return
     }
 
+    if (ctx.hasCommand(SupportedCommands.vision.name)) {
+      const photoUrl = getUrlFromText(ctx)
+      if (photoUrl) {
+        const prompt = ctx.match
+        ctx.session.openAi.imageGen.imgRequestQueue.push({
+          prompt,
+          photoUrl,
+          command: !isNaN(+prompt) ? 'alter' : 'vision'
+        })
+        if (!ctx.session.openAi.imageGen.isProcessingQueue) {
+          ctx.session.openAi.imageGen.isProcessingQueue = true
+          await this.onImgRequestHandler(ctx).then(() => {
+            ctx.session.openAi.imageGen.isProcessingQueue = false
+          })
+        }
+      }
+    }
+
     if (
       ctx.hasCommand([SupportedCommands.dalle.name,
         SupportedCommands.dalleImg.name,
@@ -556,8 +576,10 @@
       if (await this.hasBalance(ctx)) {
         if (img?.command === 'dalle') {
           await this.onGenImgCmd(img?.prompt, ctx)
-        } else {
+        } else if (img?.command === 'alter') {
           await this.onAlterImage(img?.photo, img?.prompt, ctx)
+        } else {
+          await this.onInquiryImage(img?.photo, img?.photoUrl, img?.prompt, ctx)
         }
         ctx.chatAction = null
       } else {
@@ -575,7 +597,7 @@
       ctx.chatAction = 'upload_photo'
       // eslint-disable-next-line @typescript-eslint/naming-convention
       const { message_id } = await ctx.reply(
-        'Generating dalle image...', { message_thread_id: ctx.message?.message_thread_id }
+        'Generating image via OpenAI\'s DALL·E 3...', { message_thread_id: ctx.message?.message_thread_id }
       )
       const numImages = ctx.session.openAi.imageGen.numImages
       const imgSize = ctx.session.openAi.imageGen.imgSize
@@ -606,6 +628,63 @@
     }
   }
 
+  onInquiryImage = async (photo: PhotoSize[] | undefined, photoUrl: string[] | undefined, prompt: string | undefined, ctx: OnMessageContext | OnCallBackQueryData): Promise<void> => {
+    try {
+      if (ctx.session.openAi.imageGen.isEnabled) {
+        // let filePath = ''
+        let imgList = []
+        if (photo) {
+          const fileId = photo?.pop()?.file_id // with pop() get full image quality
+          if (!fileId) {
+            await ctx.reply('Cannot retrieve the image file. Please try again.')
+            ctx.transient.analytics.actualResponseTime = now()
+            return
+          }
+          const file = await ctx.api.getFile(fileId)
+          imgList.push(`${config.openAi.dalle.telegramFileUrl}${config.telegramBotAuthToken}/${file.file_path}`)
+        } else {
+          imgList = photoUrl ?? []
+        }
+        const msgId = (
+          await ctx.reply('...', {
+            message_thread_id:
+              ctx.message?.message_thread_id ??
+              ctx.message?.reply_to_message?.message_thread_id
+          })
+        ).message_id
+        const model = ChatGPTModelsEnum.GPT_4_VISION_PREVIEW
+        const completion = await streamChatVisionCompletion(ctx, model, prompt ?? '', imgList, msgId, true)
+        if (completion) {
+          ctx.transient.analytics.sessionState = RequestState.Success
+          ctx.transient.analytics.actualResponseTime = now()
+          const price = getPromptPrice(completion, {
+            conversation: [],
+            prompt,
+            model,
+            ctx
+          })
+          this.logger.info(
+            `streamChatVisionCompletion result = tokens: ${
+              price.promptTokens + price.completionTokens
+            } | ${model} | price: ${price.price}¢`
+          )
+          if (
+            !(await this.payments.pay(ctx as OnMessageContext, price.price))
+          ) {
+            await this.onNotBalanceMessage(ctx)
+          }
+        }
+      }
+    } catch (e: any) {
+      await this.onError(
+        ctx,
+        e,
+        MAX_TRIES,
+        'An error occurred while processing the vision request'
+      )
+    }
+  }
 
   onAlterImage = async (photo: PhotoSize[] | undefined, prompt: string | undefined, ctx: OnMessageContext | OnCallBackQueryData): Promise<void> => {
     try {
       if (ctx.session.openAi.imageGen.isEnabled) {
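
The queue routing above, command: !isNaN(+prompt) ? 'alter' : 'vision', sends a purely numeric caption (apparently the requested number of image variations) to the DALL·E alter path and anything else to GPT-4 Vision. A sketch of that dispatch rule in isolation (routeImageReply is an illustrative name, not from the PR):

```ts
// Sketch: the caption-based dispatch rule used when an image reply arrives.
// A numeric caption means "N variations" (alter); anything else is treated
// as a natural-language question about the image (vision).
type ImgCommand = 'alter' | 'vision'

function routeImageReply (caption: string): ImgCommand {
  return !isNaN(+caption) ? 'alter' : 'vision'
}

console.log(routeImageReply('2')) // 'alter': generate 2 variations
console.log(routeImageReply('what breed is this dog?')) // 'vision': ask GPT-4 Vision

// Edge case mirrored from the PR: +'' coerces to 0, so an empty caption
// also routes to 'alter'.
```
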