
Commit

add stream completion for vision + add vision command to work with multiple img url
fegloff committed Jan 17, 2024
1 parent ec86fcd commit 695f979
Showing 5 changed files with 60 additions and 72 deletions.
18 changes: 10 additions & 8 deletions src/modules/llms/index.ts
@@ -85,7 +85,7 @@ export class LlmsBot implements PayableBot {
     return undefined
   }
 
-  private isSupportedUrlReply (ctx: OnMessageContext | OnCallBackQueryData): string | undefined {
+  private isSupportedUrlReply (ctx: OnMessageContext | OnCallBackQueryData): string[] | undefined {
     return getUrlFromText(ctx)
   }

@@ -251,14 +251,16 @@ export class LlmsBot implements PayableBot {
 
   async onUrlReplyHandler (ctx: OnMessageContext | OnCallBackQueryData): Promise<void> {
     try {
-      const url = getUrlFromText(ctx) ?? ''
-      const prompt = ctx.message?.text ?? 'summarize'
-      const collection = ctx.session.collections.activeCollections.find(c => c.url === url)
-      const newPrompt = `${prompt}` // ${url}
-      if (collection) {
-        await this.queryUrlCollection(ctx, url, newPrompt)
+      const url = getUrlFromText(ctx)
+      if (url) {
+        const prompt = ctx.message?.text ?? 'summarize'
+        const collection = ctx.session.collections.activeCollections.find(c => c.url === url[0])
+        const newPrompt = `${prompt}` // ${url}
+        if (collection) {
+          await this.queryUrlCollection(ctx, url[0], newPrompt)
+        }
+        ctx.transient.analytics.actualResponseTime = now()
       }
-      ctx.transient.analytics.actualResponseTime = now()
     } catch (e: any) {
       await this.onError(ctx, e)
     }
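
A tiny standalone illustration of the new guard in onUrlReplyHandler: with getUrlFromText now returning an array, the handler proceeds only when at least one URL was found and keys the collection lookup on the first entry. The collection list and URL below are stubs, not data from the repository.

// Stubbed data standing in for ctx.session.collections.activeCollections
const activeCollections = [{ url: 'https://example.com/article', name: 'article' }]
// Stubbed result of getUrlFromText(ctx), which now yields string[] | undefined
const url: string[] | undefined = ['https://example.com/article']

if (url) {
  const collection = activeCollections.find(c => c.url === url[0])
  if (collection) {
    console.log(`querying collection "${collection.name}" for ${url[0]}`)
  }
}
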
39 changes: 5 additions & 34 deletions src/modules/open-ai/api/openAi.ts
@@ -18,7 +18,7 @@ import {
   ChatGPTModelsEnum
 } from '../types'
 import type fs from 'fs'
-import { type ChatCompletionMessageParam, type ChatCompletionCreateParamsNonStreaming } from 'openai/resources/chat/completions'
+import { type ChatCompletionMessageParam } from 'openai/resources/chat/completions'
 import { type Stream } from 'openai/streaming'
 
 const openai = new OpenAI({ apiKey: config.openAiKey })
@@ -50,34 +50,6 @@ export async function postGenerateImg (
   return response.data
 }
 
-export async function imgInquiryWithVision (
-  img: string,
-  prompt: string,
-  ctx: OnMessageContext | OnCallBackQueryData
-): Promise<string> {
-  console.log(img, prompt)
-  const payLoad = {
-    model: 'gpt-4-vision-preview',
-    messages: [
-      {
-        role: 'user',
-        content: [
-          { type: 'text', text: 'What’s in this image?' },
-          {
-            type: 'image_url',
-            image_url: { url: img }
-          }
-        ]
-      }
-    ],
-    max_tokens: 300
-  }
-  console.log('HELLO')
-  const response = await openai.chat.completions.create(payLoad as unknown as ChatCompletionCreateParamsNonStreaming)
-  console.log(response.choices[0].message?.content)
-  return 'hi'
-}
-
 export async function alterGeneratedImg (
   prompt: string,
   filePath: string,
@@ -207,11 +179,10 @@ export const streamChatCompletion = async (
 }
 
 export const streamChatVisionCompletion = async (
-  conversation: ChatConversation[],
   ctx: OnMessageContext | OnCallBackQueryData,
   model = ChatGPTModelsEnum.GPT_4_VISION_PREVIEW,
   prompt: string,
-  imgUrl: string,
+  imgUrls: string[],
   msgId: number,
   limitTokens = true
 ): Promise<string> => {
@@ -224,10 +195,10 @@ export const streamChatVisionCompletion = async (
       role: 'user',
       content: [
         { type: 'text', text: prompt },
-        {
+        ...imgUrls.map(img => ({
           type: 'image_url',
-          image_url: { url: imgUrl }
-        }
+          image_url: { url: img }
+        }))
       ]
     }
   ],
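
For reference, the request body the reworked streamChatVisionCompletion assembles from a list of image URLs looks roughly like the sketch below. The prompt, URLs, and max_tokens value are placeholders, and the actual streaming call to openai.chat.completions.create is omitted; only the messages shape mirrors the diff above.

// Sketch of the vision request built for multiple image URLs (illustrative values only).
const prompt = 'What do these images have in common?'
const imgUrls = [
  'https://example.com/a.png',
  'https://example.com/b.png'
]

const payload = {
  model: 'gpt-4-vision-preview',
  max_tokens: 800, // placeholder; the real limit is governed by limitTokens/config
  stream: true,
  messages: [
    {
      role: 'user',
      content: [
        { type: 'text', text: prompt },
        // one image_url entry per URL — what the spread of imgUrls.map produces
        ...imgUrls.map(img => ({ type: 'image_url', image_url: { url: img } }))
      ]
    }
  ]
}

console.log(JSON.stringify(payload, null, 2))
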
17 changes: 9 additions & 8 deletions src/modules/open-ai/helpers.ts
@@ -9,7 +9,7 @@ import { isValidUrl } from './utils/web-crawler'
 export const SupportedCommands = {
   chat: { name: 'chat' },
   ask: { name: 'ask' },
-  // sum: { name: 'sum' },
+  vision: { name: 'vision' },
   ask35: { name: 'ask35' },
   new: { name: 'new' },
   gpt4: { name: 'gpt4' },
@@ -263,13 +263,14 @@ export const limitPrompt = (prompt: string): string => {
   return `${prompt} in around ${config.openAi.chatGpt.wordLimit} words`
 }
 
-export const getUrlFromText = (ctx: OnMessageContext | OnCallBackQueryData): string | undefined => {
-  const entities = ctx.message?.reply_to_message?.entities
-  if (entities) {
-    const urlEntity = entities.find(e => e.type === 'url')
-    if (urlEntity) {
-      const url = ctx.message?.reply_to_message?.text?.slice(urlEntity.offset, urlEntity.offset + urlEntity.length)
-      return url
+export const getUrlFromText = (ctx: OnMessageContext | OnCallBackQueryData): string[] | undefined => {
+  const entities = ctx.message?.entities ? ctx.message?.entities : ctx.message?.reply_to_message?.entities
+  const text = ctx.message?.text ? ctx.message?.text : ctx.message?.reply_to_message?.text
+  if (entities && text) {
+    const urlEntity = entities.filter(e => e.type === 'url')
+    if (urlEntity.length > 0) {
+      const urls = urlEntity.map(e => text.slice(e.offset, e.offset + e.length))
+      return urls
     }
   }
   return undefined
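
The helper now collects every url entity from the message itself, falling back to the replied-to message, and returns them all. A standalone sketch of the same entity-slicing step over a stubbed message (the stub below is illustrative; the real function reads a grammY context):

// Standalone illustration of the url-entity extraction getUrlFromText now performs.
interface UrlEntity { type: string, offset: number, length: number }

const text = 'compare https://example.com/a.png and https://example.com/b.png please'
const entities: UrlEntity[] = [
  { type: 'url', offset: 8, length: 25 },
  { type: 'url', offset: 38, length: 25 }
]

const urls = entities
  .filter(e => e.type === 'url')
  .map(e => text.slice(e.offset, e.offset + e.length))

console.log(urls) // ['https://example.com/a.png', 'https://example.com/b.png']
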
57 changes: 35 additions & 22 deletions src/modules/open-ai/index.ts
@@ -29,6 +29,7 @@ import { sleep } from '../sd-images/utils'
 import {
   getMessageExtras,
   getPromptPrice,
+  getUrlFromText,
   hasChatPrefix,
   hasDallePrefix,
   hasNewPrefix,
@@ -229,6 +230,24 @@
       return
     }
 
+    if (ctx.hasCommand(SupportedCommands.vision.name)) {
+      const photoUrl = getUrlFromText(ctx)
+      if (photoUrl) {
+        const prompt = ctx.match
+        ctx.session.openAi.imageGen.imgRequestQueue.push({
+          prompt,
+          photoUrl,
+          command: !isNaN(+prompt) ? 'alter' : 'vision'
+        })
+        if (!ctx.session.openAi.imageGen.isProcessingQueue) {
+          ctx.session.openAi.imageGen.isProcessingQueue = true
+          await this.onImgRequestHandler(ctx).then(() => {
+            ctx.session.openAi.imageGen.isProcessingQueue = false
+          })
+        }
+      }
+    }
+
     if (
       ctx.hasCommand([SupportedCommands.dalle.name,
         SupportedCommands.dalleImg.name,
@@ -560,7 +579,7 @@
       } else if (img?.command === 'alter') {
         await this.onAlterImage(img?.photo, img?.prompt, ctx)
       } else {
-        await this.onInquiryImage(img?.photo, img?.prompt, ctx)
+        await this.onInquiryImage(img?.photo, img?.photoUrl, img?.prompt, ctx)
       }
       ctx.chatAction = null
     } else {
@@ -609,38 +628,32 @@
     }
   }
 
-  onInquiryImage = async (photo: PhotoSize[] | undefined, prompt: string | undefined, ctx: OnMessageContext | OnCallBackQueryData): Promise<void> => {
+  onInquiryImage = async (photo: PhotoSize[] | undefined, photoUrl: string[] | undefined, prompt: string | undefined, ctx: OnMessageContext | OnCallBackQueryData): Promise<void> => {
     try {
       if (ctx.session.openAi.imageGen.isEnabled) {
-        const fileId = photo?.pop()?.file_id // with pop() get full image quality
-        if (!fileId) {
-          await ctx.reply('Cannot retrieve the image file. Please try again.')
-          ctx.transient.analytics.actualResponseTime = now()
-          return
+        // let filePath = ''
+        let imgList = []
+        if (photo) {
+          const fileId = photo?.pop()?.file_id // with pop() get full image quality
+          if (!fileId) {
+            await ctx.reply('Cannot retrieve the image file. Please try again.')
+            ctx.transient.analytics.actualResponseTime = now()
+            return
+          }
+          const file = await ctx.api.getFile(fileId)
+          imgList.push(`${config.openAi.dalle.telegramFileUrl}${config.telegramBotAuthToken}/${file.file_path}`)
+        } else {
+          imgList = photoUrl ?? []
         }
-        const file = await ctx.api.getFile(fileId)
-        const filePath = `${config.openAi.dalle.telegramFileUrl}${config.telegramBotAuthToken}/${file.file_path}`
         const msgId = (
           await ctx.reply('...', {
             message_thread_id:
               ctx.message?.message_thread_id ??
               ctx.message?.reply_to_message?.message_thread_id
           })
         ).message_id
-        const messages = [
-          {
-            role: 'user',
-            content: [
-              { type: 'text', text: prompt },
-              {
-                type: 'image_url',
-                image_url: { url: filePath }
-              }
-            ]
-          }
-        ]
         const model = ChatGPTModelsEnum.GPT_4_VISION_PREVIEW
-        const completion = await streamChatVisionCompletion(messages, ctx, model, prompt ?? '', filePath, msgId, true)
+        const completion = await streamChatVisionCompletion(ctx, model, prompt ?? '', imgList, msgId, true)
         if (completion) {
           ctx.transient.analytics.sessionState = RequestState.Success
           ctx.transient.analytics.actualResponseTime = now()
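
End to end, a /vision command carrying one or more URLs is pushed onto the image-generation queue as an ImageRequest, and onImgRequestHandler later hands it to onInquiryImage, which forwards the URLs to streamChatVisionCompletion. A rough sketch of the queued object for a hypothetical message like "/vision what differs between these? https://example.com/a.png https://example.com/b.png" (values are illustrative):

// Illustrative ImageRequest built by the new /vision branch above.
// Field names follow the ImageRequest change in src/modules/types.ts;
// the photo field (Telegram PhotoSize[]) is omitted because URLs were supplied.
interface ImageRequest {
  command?: 'dalle' | 'alter' | 'vision'
  prompt?: string
  photoUrl?: string[]
}

const queued: ImageRequest = {
  prompt: 'what differs between these?', // ctx.match; in practice the raw text after /vision, URLs included
  photoUrl: [
    'https://example.com/a.png',
    'https://example.com/b.png'
  ],
  command: 'vision' // the prompt is not numeric, so this is not an 'alter' request
}

console.log(queued)
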
1 change: 1 addition & 0 deletions src/modules/types.ts
@@ -57,6 +57,7 @@ export interface ImageRequest {
   command?: 'dalle' | 'alter' | 'vision'
   prompt?: string
   photo?: PhotoSize[] | undefined
+  photoUrl?: string[]
 }
 export interface ChatGptSessionData {
   model: string
