Merge pull request #348 from harmony-one/vision
add vision logic
fegloff authored Jan 17, 2024
2 parents f83b56f + 695f979 commit 751165f
Showing 8 changed files with 215 additions and 47 deletions.
22 changes: 12 additions & 10 deletions src/modules/llms/index.ts
@@ -85,7 +85,7 @@ export class LlmsBot implements PayableBot {
     return undefined
   }

-  private isSupportedUrlReply (ctx: OnMessageContext | OnCallBackQueryData): string | undefined {
+  private isSupportedUrlReply (ctx: OnMessageContext | OnCallBackQueryData): string[] | undefined {
     return getUrlFromText(ctx)
   }

@@ -251,14 +251,16 @@ export class LlmsBot implements PayableBot {

   async onUrlReplyHandler (ctx: OnMessageContext | OnCallBackQueryData): Promise<void> {
     try {
-      const url = getUrlFromText(ctx) ?? ''
-      const prompt = ctx.message?.text ?? 'summarize'
-      const collection = ctx.session.collections.activeCollections.find(c => c.url === url)
-      const newPrompt = `${prompt}` // ${url}
-      if (collection) {
-        await this.queryUrlCollection(ctx, url, newPrompt)
+      const url = getUrlFromText(ctx)
+      if (url) {
+        const prompt = ctx.message?.text ?? 'summarize'
+        const collection = ctx.session.collections.activeCollections.find(c => c.url === url[0])
+        const newPrompt = `${prompt}` // ${url}
+        if (collection) {
+          await this.queryUrlCollection(ctx, url[0], newPrompt)
+        }
+        ctx.transient.analytics.actualResponseTime = now()
       }
-      ctx.transient.analytics.actualResponseTime = now()
     } catch (e: any) {
       await this.onError(ctx, e)
     }
@@ -550,7 +552,7 @@ export class LlmsBot implements PayableBot {
       await ctx.api.editMessageText(
         ctx.chat.id,
         msgId,
-        response.completion.content
+        response.completion.content as string
       )
       conversation.push(response.completion)
       // const price = getPromptPrice(completion, data);
@@ -629,7 +631,7 @@ export class LlmsBot implements PayableBot {
     while (ctx.session.llms.requestQueue.length > 0) {
       try {
         const msg = ctx.session.llms.requestQueue.shift()
-        const prompt = msg?.content
+        const prompt = msg?.content as string
         const model = msg?.model
         const { chatConversation } = ctx.session.llms
         if (await this.hasBalance(ctx)) {
103 changes: 87 additions & 16 deletions src/modules/open-ai/api/openAi.ts
@@ -1,7 +1,6 @@
 import OpenAI from 'openai'
 import { encode } from 'gpt-tokenizer'
 import { GrammyError } from 'grammy'
-
 import config from '../../../config'
 import { deleteFile, getImage } from '../utils/file'
 import {
@@ -15,9 +14,12 @@ import {
   type ChatModel,
   ChatGPTModels,
   type DalleGPTModel,
-  DalleGPTModels
+  DalleGPTModels,
+  ChatGPTModelsEnum
 } from '../types'
 import type fs from 'fs'
+import { type ChatCompletionMessageParam } from 'openai/resources/chat/completions'
+import { type Stream } from 'openai/streaming'

 const openai = new OpenAI({ apiKey: config.openAiKey })

@@ -83,15 +85,12 @@ export async function chatCompletion (
   model = config.openAi.chatGpt.model,
   limitTokens = true
 ): Promise<ChatCompletion> {
-  const payload = {
+  const response = await openai.chat.completions.create({
     model,
     max_tokens: limitTokens ? config.openAi.chatGpt.maxTokens : undefined,
     temperature: config.openAi.dalle.completions.temperature,
-    messages: conversation
-  }
-  const response = await openai.chat.completions.create(
-    payload as OpenAI.Chat.CompletionCreateParamsNonStreaming
-  )
+    messages: conversation as ChatCompletionMessageParam[]
+  })
   const chatModel = getChatModel(model)
   if (response.usage?.prompt_tokens === undefined) {
     throw new Error('Unknown number of prompt tokens used')
@@ -120,7 +119,7 @@ export const streamChatCompletion = async (
   let wordCountMinimum = 2
   const stream = await openai.chat.completions.create({
     model,
-    messages: conversation as OpenAI.Chat.Completions.CreateChatCompletionRequestMessage[],
+    messages: conversation as ChatCompletionMessageParam[], // OpenAI.Chat.Completions.CreateChatCompletionRequestMessage[],
     stream: true,
     max_tokens: limitTokens ? config.openAi.chatGpt.maxTokens : undefined,
     temperature: config.openAi.dalle.completions.temperature || 0.8
@@ -177,13 +176,85 @@
       }
     })
   return completion
-  // } catch (e) {
-  //   reject(e)
-  // }
-  // })
-  // } catch (error: any) {
-  //   return await Promise.reject(error)
-  // }
 }
+
+export const streamChatVisionCompletion = async (
+  ctx: OnMessageContext | OnCallBackQueryData,
+  model = ChatGPTModelsEnum.GPT_4_VISION_PREVIEW,
+  prompt: string,
+  imgUrls: string[],
+  msgId: number,
+  limitTokens = true
+): Promise<string> => {
+  let completion = ''
+  let wordCountMinimum = 2
+  const payload: any = {
+    model,
+    messages: [
+      {
+        role: 'user',
+        content: [
+          { type: 'text', text: prompt },
+          ...imgUrls.map(img => ({
+            type: 'image_url',
+            image_url: { url: img }
+          }))
+        ]
+      }
+    ],
+    stream: true,
+    max_tokens: limitTokens ? config.openAi.chatGpt.maxTokens : undefined
+  }
+  const stream = await openai.chat.completions.create(payload) as unknown as Stream<OpenAI.Chat.Completions.ChatCompletionChunk>
+  let wordCount = 0
+  if (!ctx.chat?.id) {
+    throw new Error('Context chat id should not be empty after openAI streaming')
+  }
+  for await (const part of stream) {
+    wordCount++
+    const chunck = part.choices[0]?.delta?.content
+      ? part.choices[0]?.delta?.content
+      : ''
+    completion += chunck
+
+    if (wordCount > wordCountMinimum) {
+      if (wordCountMinimum < 64) {
+        wordCountMinimum *= 2
+      }
+      completion = completion.replaceAll('...', '')
+      completion += '...'
+      wordCount = 0
+      await ctx.api
+        .editMessageText(ctx.chat?.id, msgId, completion)
+        .catch(async (e: any) => {
+          if (e instanceof GrammyError) {
+            if (e.error_code !== 400) {
+              throw e
+            } else {
+              logger.error(e)
+            }
+          } else {
+            throw e
+          }
+        })
+    }
+  }
+  completion = completion.replaceAll('...', '')
+
+  await ctx.api
+    .editMessageText(ctx.chat?.id, msgId, completion)
+    .catch((e: any) => {
+      if (e instanceof GrammyError) {
+        if (e.error_code !== 400) {
+          throw e
+        } else {
+          logger.error(e)
+        }
+      } else {
+        throw e
+      }
+    })
+  return completion
+}

 export async function improvePrompt (promptText: string, model: string): Promise<string> {
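
The new streamChatVisionCompletion builds a single user message whose content mixes one text part with an image_url part per image, then consumes the response as a stream. A condensed standalone sketch of the same call shape, outside the bot's context (the literal model string and max_tokens value are assumptions here; the diff resolves them from ChatGPTModelsEnum and config):

import OpenAI from 'openai'

const openai = new OpenAI({ apiKey: process.env.OPENAI_API_KEY })

async function visionCompletion (prompt: string, imgUrls: string[]): Promise<string> {
  // One user message: a text part followed by one image_url part per image,
  // mirroring the payload assembled in streamChatVisionCompletion above.
  const stream = await openai.chat.completions.create({
    model: 'gpt-4-vision-preview', // assumed value of ChatGPTModelsEnum.GPT_4_VISION_PREVIEW
    messages: [{
      role: 'user',
      content: [
        { type: 'text', text: prompt },
        ...imgUrls.map(url => ({ type: 'image_url' as const, image_url: { url } }))
      ]
    }],
    stream: true,
    max_tokens: 800 // stand-in for config.openAi.chatGpt.maxTokens
  })
  // Accumulate streamed deltas, as the bot does between message edits.
  let completion = ''
  for await (const part of stream) {
    completion += part.choices[0]?.delta?.content ?? ''
  }
  return completion
}
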
21 changes: 11 additions & 10 deletions src/modules/open-ai/helpers.ts
@@ -9,7 +9,7 @@ import { isValidUrl } from './utils/web-crawler'
 export const SupportedCommands = {
   chat: { name: 'chat' },
   ask: { name: 'ask' },
-  // sum: { name: 'sum' },
+  vision: { name: 'vision' },
   ask35: { name: 'ask35' },
   new: { name: 'new' },
   gpt4: { name: 'gpt4' },
@@ -235,8 +235,8 @@ export const hasPrefix = (prompt: string): string => {
 export const getPromptPrice = (completion: string, data: ChatPayload): { price: number, promptTokens: number, completionTokens: number } => {
   const { conversation, ctx, model } = data

-  const prompt = conversation[conversation.length - 1].content
-  const promptTokens = getTokenNumber(prompt)
+  const prompt = data.prompt ? data.prompt : conversation[conversation.length - 1].content
+  const promptTokens = getTokenNumber(prompt as string)
   const completionTokens = getTokenNumber(completion)
   const modelPrice = getChatModel(model)
   const price =
@@ -263,13 +263,14 @@ export const limitPrompt = (prompt: string): string => {
   return `${prompt} in around ${config.openAi.chatGpt.wordLimit} words`
 }

-export const getUrlFromText = (ctx: OnMessageContext | OnCallBackQueryData): string | undefined => {
-  const entities = ctx.message?.reply_to_message?.entities
-  if (entities) {
-    const urlEntity = entities.find(e => e.type === 'url')
-    if (urlEntity) {
-      const url = ctx.message?.reply_to_message?.text?.slice(urlEntity.offset, urlEntity.offset + urlEntity.length)
-      return url
+export const getUrlFromText = (ctx: OnMessageContext | OnCallBackQueryData): string[] | undefined => {
+  const entities = ctx.message?.entities ? ctx.message?.entities : ctx.message?.reply_to_message?.entities
+  const text = ctx.message?.text ? ctx.message?.text : ctx.message?.reply_to_message?.text
+  if (entities && text) {
+    const urlEntity = entities.filter(e => e.type === 'url')
+    if (urlEntity.length > 0) {
+      const urls = urlEntity.map(e => text.slice(e.offset, e.offset + e.length))
+      return urls
     }
   }
   return undefined
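
getUrlFromText now returns every url entity rather than the first match, and reads the message itself before falling back to the replied-to message. A self-contained illustration of the same filter/map slicing over Telegram-style entities (the MessageLike shape is a simplified stand-in for grammY's message type, not part of this PR):

interface UrlEntity { type: string, offset: number, length: number }
interface MessageLike { text?: string, entities?: UrlEntity[] }

function urlsFromMessage (msg: MessageLike): string[] | undefined {
  const entities = msg.entities
  const text = msg.text
  if (entities && text) {
    // Telegram marks each URL as an offset/length pair into the text.
    const urlEntities = entities.filter(e => e.type === 'url')
    if (urlEntities.length > 0) {
      return urlEntities.map(e => text.slice(e.offset, e.offset + e.length))
    }
  }
  return undefined
}

console.log(urlsFromMessage({
  text: 'compare https://a.example.com and https://b.example.com',
  entities: [
    { type: 'url', offset: 8, length: 21 },
    { type: 'url', offset: 34, length: 21 }
  ]
}))
// -> [ 'https://a.example.com', 'https://b.example.com' ]
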
93 changes: 86 additions & 7 deletions src/modules/open-ai/index.ts
@@ -18,7 +18,8 @@ import {
   getDalleModel,
   getDalleModelPrice,
   postGenerateImg,
-  streamChatCompletion
+  streamChatCompletion,
+  streamChatVisionCompletion
 } from './api/openAi'
 import { appText } from './utils/text'
 import { chatService } from '../../database/services'
@@ -28,6 +29,7 @@ import { sleep } from '../sd-images/utils'
 import {
   getMessageExtras,
   getPromptPrice,
+  getUrlFromText,
   hasChatPrefix,
   hasDallePrefix,
   hasNewPrefix,
@@ -90,7 +92,7 @@ export class OpenAIBot implements PayableBot {
     try {
       const priceAdjustment = config.openAi.chatGpt.priceAdjustment
       const prompts = ctx.match
-      if (this.isSupportedImageReply(ctx)) {
+      if (this.isSupportedImageReply(ctx) && !isNaN(+prompts)) {
         const imageNumber = ctx.message?.caption || ctx.message?.text
         const imageSize = ctx.session.openAi.imageGen.imgSize
         const model = getDalleModel(imageSize)
@@ -142,7 +144,7 @@
     const photo = ctx.message?.photo ?? ctx.message?.reply_to_message?.photo
     if (photo && ctx.session.openAi.imageGen.isEnabled) {
       const prompt = ctx.message?.caption ?? ctx.message?.text
-      if (prompt && !isNaN(+prompt)) {
+      if (prompt) { // && !isNaN(+prompt)
         return true
       }
     }
@@ -161,11 +163,11 @@

     if (this.isSupportedImageReply(ctx)) {
       const photo = ctx.message?.photo ?? ctx.message?.reply_to_message?.photo
-      const prompt = ctx.message?.caption ?? ctx.message?.text
+      const prompt = ctx.message?.caption ?? ctx.message?.text ?? ''
       ctx.session.openAi.imageGen.imgRequestQueue.push({
         prompt,
         photo,
-        command: 'alter'
+        command: !isNaN(+prompt) ? 'alter' : 'vision'
       })
       if (!ctx.session.openAi.imageGen.isProcessingQueue) {
         ctx.session.openAi.imageGen.isProcessingQueue = true
@@ -228,6 +230,24 @@
       return
     }

+    if (ctx.hasCommand(SupportedCommands.vision.name)) {
+      const photoUrl = getUrlFromText(ctx)
+      if (photoUrl) {
+        const prompt = ctx.match
+        ctx.session.openAi.imageGen.imgRequestQueue.push({
+          prompt,
+          photoUrl,
+          command: !isNaN(+prompt) ? 'alter' : 'vision'
+        })
+        if (!ctx.session.openAi.imageGen.isProcessingQueue) {
+          ctx.session.openAi.imageGen.isProcessingQueue = true
+          await this.onImgRequestHandler(ctx).then(() => {
+            ctx.session.openAi.imageGen.isProcessingQueue = false
+          })
+        }
+      }
+    }
+
     if (
       ctx.hasCommand([SupportedCommands.dalle.name,
         SupportedCommands.dalleImg.name,
@@ -556,8 +576,10 @@
       if (await this.hasBalance(ctx)) {
         if (img?.command === 'dalle') {
           await this.onGenImgCmd(img?.prompt, ctx)
-        } else {
+        } else if (img?.command === 'alter') {
           await this.onAlterImage(img?.photo, img?.prompt, ctx)
+        } else {
+          await this.onInquiryImage(img?.photo, img?.photoUrl, img?.prompt, ctx)
         }
         ctx.chatAction = null
       } else {
@@ -575,7 +597,7 @@
       ctx.chatAction = 'upload_photo'
       // eslint-disable-next-line @typescript-eslint/naming-convention
       const { message_id } = await ctx.reply(
-        'Generating dalle image...', { message_thread_id: ctx.message?.message_thread_id }
+        'Generating image via OpenAI\'s DALL·E 3...', { message_thread_id: ctx.message?.message_thread_id }
       )
       const numImages = ctx.session.openAi.imageGen.numImages
       const imgSize = ctx.session.openAi.imageGen.imgSize
@@ -606,6 +628,63 @@
     }
   }

+  onInquiryImage = async (photo: PhotoSize[] | undefined, photoUrl: string[] | undefined, prompt: string | undefined, ctx: OnMessageContext | OnCallBackQueryData): Promise<void> => {
+    try {
+      if (ctx.session.openAi.imageGen.isEnabled) {
+        // let filePath = ''
+        let imgList = []
+        if (photo) {
+          const fileId = photo?.pop()?.file_id // with pop() get full image quality
+          if (!fileId) {
+            await ctx.reply('Cannot retrieve the image file. Please try again.')
+            ctx.transient.analytics.actualResponseTime = now()
+            return
+          }
+          const file = await ctx.api.getFile(fileId)
+          imgList.push(`${config.openAi.dalle.telegramFileUrl}${config.telegramBotAuthToken}/${file.file_path}`)
+        } else {
+          imgList = photoUrl ?? []
+        }
+        const msgId = (
+          await ctx.reply('...', {
+            message_thread_id:
+              ctx.message?.message_thread_id ??
+              ctx.message?.reply_to_message?.message_thread_id
+          })
+        ).message_id
+        const model = ChatGPTModelsEnum.GPT_4_VISION_PREVIEW
+        const completion = await streamChatVisionCompletion(ctx, model, prompt ?? '', imgList, msgId, true)
+        if (completion) {
+          ctx.transient.analytics.sessionState = RequestState.Success
+          ctx.transient.analytics.actualResponseTime = now()
+          const price = getPromptPrice(completion, {
+            conversation: [],
+            prompt,
+            model,
+            ctx
+          })
+          this.logger.info(
+            `streamChatCompletion result = tokens: ${
+              price.promptTokens + price.completionTokens
+            } | ${model} | price: ${price.price}¢`
+          )
+          if (
+            !(await this.payments.pay(ctx as OnMessageContext, price.price))
+          ) {
+            await this.onNotBalanceMessage(ctx)
+          }
+        }
+      }
+    } catch (e: any) {
+      await this.onError(
+        ctx,
+        e,
+        MAX_TRIES,
+        'An error occurred while generating the AI edit'
+      )
+    }
+  }
+
   onAlterImage = async (photo: PhotoSize[] | undefined, prompt: string | undefined, ctx: OnMessageContext | OnCallBackQueryData): Promise<void> => {
     try {
       if (ctx.session.openAi.imageGen.isEnabled) {
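
With these changes the image queue carries three commands: 'dalle' for generation, 'alter' for DALL·E variations, and the new 'vision' for GPT-4 Vision inquiries, chosen by whether the prompt is purely numeric. A condensed sketch of that routing rule with a hypothetical ImgRequest type (field names follow the diff; the type itself is not shown in this PR):

type ImgCommand = 'dalle' | 'alter' | 'vision'

interface ImgRequest {
  prompt: string
  photo?: unknown[] // Telegram PhotoSize[] when replying to an uploaded photo
  photoUrl?: string[] // URLs collected by getUrlFromText for /vision
  command: ImgCommand
}

// Mirrors the branch added in onImgRequestHandler: a purely numeric
// prompt keeps the old 'alter' (variations) path, anything else is
// treated as a vision question about the image.
function routeImgRequest (prompt: string, photo?: unknown[], photoUrl?: string[]): ImgRequest {
  return {
    prompt,
    photo,
    photoUrl,
    command: !isNaN(+prompt) ? 'alter' : 'vision'
  }
}

console.log(routeImgRequest('2').command) // 'alter' (two variations)
console.log(routeImgRequest('what is in this photo?').command) // 'vision'
console.log(routeImgRequest('').command) // 'alter', since +'' coerces to 0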