Skip to content

Commit

Permalink
update vision logic to comply with openai api
Browse files Browse the repository at this point in the history
  • Loading branch information
fegloff committed Jan 15, 2024
1 parent 0844a77 commit 9395577
Show file tree
Hide file tree
Showing 7 changed files with 72 additions and 53 deletions.
4 changes: 2 additions & 2 deletions src/modules/llms/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -550,7 +550,7 @@ export class LlmsBot implements PayableBot {
await ctx.api.editMessageText(
ctx.chat.id,
msgId,
response.completion.content
response.completion.content as string
)
conversation.push(response.completion)
// const price = getPromptPrice(completion, data);
Expand Down Expand Up @@ -648,7 +648,7 @@ export class LlmsBot implements PayableBot {
return
}
const chat: ChatConversation = {
content: limitPrompt(prompt),
content: limitPrompt(prompt as string),
model
}
if (model === LlmsModelsEnum.BISON) {
Expand Down
13 changes: 5 additions & 8 deletions src/modules/open-ai/api/openAi.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import {
DalleGPTModels
} from '../types'
import type fs from 'fs'
import { type ChatCompletionCreateParamsNonStreaming } from 'openai/resources/chat/completions'
import { type ChatCompletionMessageParam, type ChatCompletionCreateParamsNonStreaming } from 'openai/resources/chat/completions'

const openai = new OpenAI({ apiKey: config.openAiKey })

Expand Down Expand Up @@ -112,15 +112,12 @@ export async function chatCompletion (
model = config.openAi.chatGpt.model,
limitTokens = true
): Promise<ChatCompletion> {
const payload = {
const response = await openai.chat.completions.create({
model,
max_tokens: limitTokens ? config.openAi.chatGpt.maxTokens : undefined,
temperature: config.openAi.dalle.completions.temperature,
messages: conversation
}
const response = await openai.chat.completions.create(
payload as OpenAI.Chat.CompletionCreateParamsNonStreaming
)
messages: conversation as ChatCompletionMessageParam[]
})
const chatModel = getChatModel(model)
if (response.usage?.prompt_tokens === undefined) {
throw new Error('Unknown number of prompt tokens used')
Expand Down Expand Up @@ -149,7 +146,7 @@ export const streamChatCompletion = async (
let wordCountMinimum = 2
const stream = await openai.chat.completions.create({
model,
messages: conversation as OpenAI.Chat.Completions.CreateChatCompletionRequestMessage[],
messages: conversation as ChatCompletionMessageParam[], // OpenAI.Chat.Completions.CreateChatCompletionRequestMessage[],
stream: true,
max_tokens: limitTokens ? config.openAi.chatGpt.maxTokens : undefined,
temperature: config.openAi.dalle.completions.temperature || 0.8
Expand Down
4 changes: 2 additions & 2 deletions src/modules/open-ai/helpers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,8 @@ export const hasPrefix = (prompt: string): string => {
export const getPromptPrice = (completion: string, data: ChatPayload): { price: number, promptTokens: number, completionTokens: number } => {
const { conversation, ctx, model } = data

const prompt = conversation[conversation.length - 1].content
const promptTokens = getTokenNumber(prompt)
const prompt = data.prompt ? data.prompt : conversation[conversation.length - 1].content
const promptTokens = getTokenNumber(prompt as string)
const completionTokens = getTokenNumber(completion)
const modelPrice = getChatModel(model)
const price =
Expand Down
86 changes: 47 additions & 39 deletions src/modules/open-ai/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@ import {
} from '../types'
import {
alterGeneratedImg,
chatCompletion,
getChatModel,
getDalleModel,
getDalleModelPrice,
postGenerateImg,
streamChatCompletion,
streamChatVisionCompletion
streamChatCompletion
} from './api/openAi'
import { appText } from './utils/text'
import { chatService } from '../../database/services'
Expand Down Expand Up @@ -91,7 +91,7 @@ export class OpenAIBot implements PayableBot {
try {
const priceAdjustment = config.openAi.chatGpt.priceAdjustment
const prompts = ctx.match
if (this.isSupportedImageReply(ctx)) {
if (this.isSupportedImageReply(ctx) && !isNaN(+prompts)) {
const imageNumber = ctx.message?.caption || ctx.message?.text
const imageSize = ctx.session.openAi.imageGen.imgSize
const model = getDalleModel(imageSize)
Expand Down Expand Up @@ -609,18 +609,6 @@ export class OpenAIBot implements PayableBot {
}
}

// imgInquiryWithVision = async (
// img: string,
// prompt: string,
// ctx: OnMessageContext | OnCallBackQueryData
// ): Promise<string> => {
// console.log(img, prompt)
// console.log('HELLO')
// const response = await openai.chat.completions.create(payLoad as unknown as ChatCompletionCreateParamsNonStreaming)
// console.log(response.choices[0].message?.content)
// return 'hi'
// }

onInquiryImage = async (photo: PhotoSize[] | undefined, prompt: string | undefined, ctx: OnMessageContext | OnCallBackQueryData): Promise<void> => {
try {
if (ctx.session.openAi.imageGen.isEnabled) {
Expand All @@ -639,30 +627,50 @@ export class OpenAIBot implements PayableBot {
ctx.message?.reply_to_message?.message_thread_id
})
).message_id
const completion = await streamChatVisionCompletion([], ctx, 'gpt-4-vision-preview', prompt ?? '', filePath, msgId, true)
console.log(completion)
// const inquiry = await imgInquiryWithVision(filePath, prompt ?? '', ctx)
// console.log(inquiry)
// const imgSize = ctx.session.openAi.imageGen.imgSize
// ctx.chatAction = 'upload_photo'
// const imgs = await alterGeneratedImg(prompt ?? '', filePath, ctx, imgSize)
// if (imgs) {
// imgs.map(async (img: any) => {
// if (img?.url) {
// await ctx
// .replyWithPhoto(img.url, { message_thread_id: ctx.message?.message_thread_id })
// .catch(async (e) => {
// await this.onError(
// ctx,
// e,
// MAX_TRIES,
// 'There was an error while generating the image'
// )
// })
// }
// })
// }
// ctx.chatAction = null
const messages = [
{
role: 'user',
content: [
{ type: 'text', text: prompt },
{
type: 'image_url',
image_url: { url: filePath }
}
]
}
]
const model = ChatGPTModelsEnum.GPT_4_VISION_PREVIEW
const completion = await chatCompletion(messages as any, model, true)
if (completion) {
await ctx.api
.editMessageText(`${ctx.chat?.id}`, msgId, completion.completion)
.catch(async (e: any) => {
await this.onError(
ctx,
e,
MAX_TRIES,
'An error occurred while generating the AI edit'
)
})
ctx.transient.analytics.sessionState = RequestState.Success
ctx.transient.analytics.actualResponseTime = now()
const price = getPromptPrice(completion.completion, {
conversation: [],
prompt,
model,
ctx
})
this.logger.info(
`streamChatCompletion result = tokens: ${
price.promptTokens + price.completionTokens
} | ${model} | price: ${price.price}¢`
)
if (
!(await this.payments.pay(ctx as OnMessageContext, price.price))
) {
await this.onNotBalanceMessage(ctx)
}
}
}
} catch (e: any) {
await this.onError(
Expand Down
8 changes: 8 additions & 0 deletions src/modules/open-ai/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ export enum ChatGPTModelsEnum {
GPT_4_32K = 'gpt-4-32k',
GPT_35_TURBO = 'gpt-3.5-turbo',
GPT_35_TURBO_16K = 'gpt-3.5-turbo-16k',
GPT_4_VISION_PREVIEW = 'gpt-4-vision-preview'
}

export const ChatGPTModels: Record<string, ChatModel> = {
Expand Down Expand Up @@ -46,6 +47,13 @@ export const ChatGPTModels: Record<string, ChatModel> = {
outputPrice: 0.004,
maxContextTokens: 16000,
chargeType: 'TOKEN'
},
'gpt-4-vision-preview': {
name: 'gpt-4-vision-preview',
inputPrice: 0.03,
outputPrice: 0.06,
maxContextTokens: 16000,
chargeType: 'TOKEN'
}
}

Expand Down
1 change: 0 additions & 1 deletion src/modules/payment/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,6 @@ export class BotPayments {
public async pay (ctx: OnMessageContext, amountUSD: number): Promise<boolean> {
// eslint-disable-next-line @typescript-eslint/naming-convention
const { from, message_id, chat } = ctx.update.message

const accountId = this.getAccountId(ctx)
const userAccount = this.getUserAccount(accountId)
if (!userAccount) {
Expand Down
9 changes: 8 additions & 1 deletion src/modules/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,20 @@ export interface ChatCompletion {
}
/**
 * Input bundle handed to chat-completion helpers (e.g. getPromptPrice)
 * describing one request: the conversation so far, the model, and the
 * bot context it belongs to.
 */
export interface ChatPayload {
  /** Prior messages of the exchange; may be empty (the vision flow passes []). */
  conversation: ChatConversation[]
  /**
   * Optional explicit prompt. When set, price calculation uses it instead of
   * the last conversation entry (see getPromptPrice in open-ai/helpers.ts).
   */
  prompt?: string
  /** Model identifier (e.g. 'gpt-4-vision-preview'); used to look up per-token pricing. */
  model: string
  /** Telegram bot context (message or callback-query) for the request. */
  ctx: OnMessageContext | OnCallBackQueryData
}

/**
 * One element of an OpenAI vision message's content array: either a text
 * part ({ type: 'text', text }) or an image part
 * ({ type: 'image_url', image_url: { url } }), mirroring the OpenAI
 * chat-completions content-part shape.
 */
export interface VisionContent {
  /** Discriminator — 'text' or 'image_url' in the OpenAI chat API. */
  type: string
  /** Present when type === 'text'. */
  text?: string
  /** Present when type === 'image_url'. */
  image_url?: { url: string }
}

/**
 * A single chat message. Plain text messages carry a string content; vision
 * messages carry an array of content parts (text and image_url entries).
 */
export interface ChatConversation {
  /** Message author role, e.g. 'user'. */
  role?: string
  /** Alternative author field — presumably for the Bison/PaLM flow; TODO confirm against llms module. */
  author?: string
  // VisionContent[] (not the one-element tuple [VisionContent]): vision
  // requests build content arrays with multiple parts — a text part plus
  // an image_url part — so a fixed-length-1 tuple type is too narrow.
  content: string | VisionContent[]
  /** Optional per-message model override. */
  model?: string
}

Expand Down

0 comments on commit 9395577

Please sign in to comment.