diff --git a/src/bot.ts b/src/bot.ts index 35d711e9..f8f49389 100644 --- a/src/bot.ts +++ b/src/bot.ts @@ -1,4 +1,4 @@ -import {TranslateBot} from "./modules/translate/TranslateBot"; +import { TranslateBot } from "./modules/translate/TranslateBot"; require("events").EventEmitter.defaultMaxListeners = 30; import express from "express"; @@ -98,8 +98,8 @@ function createInitialSessionData(): BotSessionData { }, translate: { languages: [], - enable: false - } + enable: false, + }, }; } @@ -270,13 +270,15 @@ const onMessage = async (ctx: OnMessageContext) => { const price = translateBot.getEstimatedPrice(ctx); const isPaid = await payments.pay(ctx, price); - if(isPaid) { - const response = await translateBot.onEvent(ctx, (reason?: string) => { - payments.refundPayment(reason, ctx, price); - }).catch((e) => { - payments.refundPayment(e.message || "Unknown error", ctx, price); - return {next: false}; - }); + if (isPaid) { + const response = await translateBot + .onEvent(ctx, (reason?: string) => { + payments.refundPayment(reason, ctx, price); + }) + .catch((e) => { + payments.refundPayment(e.message || "Unknown error", ctx, price); + return { next: false }; + }); if (!response.next) { return; @@ -284,7 +286,7 @@ const onMessage = async (ctx: OnMessageContext) => { } } - if (openAiBot.isSupportedEvent(ctx)) { + if (await openAiBot.isSupportedEvent(ctx)) { if (ctx.session.openAi.imageGen.isEnabled) { const price = openAiBot.getEstimatedPrice(ctx); const isPaid = await payments.pay(ctx, price!); @@ -438,15 +440,15 @@ bot.command("love", (ctx) => { }); }); -bot.command('stop', (ctx) => { +bot.command("stop", (ctx) => { logger.info("/stop command"); ctx.session.openAi.chatGpt.chatConversation = []; ctx.session.openAi.chatGpt.usage = 0; - ctx.session.openAi.chatGpt.price = 0; + ctx.session.openAi.chatGpt.price = 0; ctx.session.translate.enable = false; - ctx.session.translate.languages = [] - ctx.session.oneCountry.lastDomain = "" -}) + ctx.session.translate.languages = []; + ctx.session.oneCountry.lastDomain = ""; +}); // bot.command("memo", (ctx) => { // ctx.reply(MEMO.text, { // parse_mode: "Markdown", diff --git a/src/config.ts b/src/config.ts index 4a290f6f..c979e1b0 100644 --- a/src/config.ts +++ b/src/config.ts @@ -26,10 +26,6 @@ export default { ? parseInt(process.env.SESSION_TIMEOUT) : 48, // in hours openAi: { - maxTokens: - (process.env.OPENAI_MAX_TOKENS && - parseInt(process.env.OPENAI_MAX_TOKENS)) || - 800, // telegram messages has a char limit dalle: { isEnabled: Boolean(parseInt(process.env.IMAGE_GEN_ENABLED || "1")), telegramFileUrl: "https://api.telegram.org/file/bot", @@ -48,6 +44,11 @@ export default { }, }, chatGpt: { + maxTokens: + (process.env.OPENAI_MAX_TOKENS && + parseInt(process.env.OPENAI_MAX_TOKENS)) || + 800, // telegram messages have a char limit + wordLimit: 50, wordCountBetween: process.env.WORD_COUNT_BETWEEN ? parseInt(process.env.WORD_COUNT_BETWEEN) : 100, @@ -64,13 +65,13 @@ export default { prefixes: { chatPrefix: process.env.ASK_PREFIX ? process.env.ASK_PREFIX.split(",") - : ["a.","?",">","."], + : ["a.", "?", ">", "."], dallePrefix: process.env.DALLE_PREFIX ? process.env.DALLE_PREFIX.split(",") : ["d."], newPrefix: process.env.NEW_PREFIX ? process.env.NEW_PREFIX.split(",") - : ["n."], + : ["n.", ".."], }, minimumBalance: process.env.MIN_BALANCE ?
parseInt(process.env.MIN_BALANCE) diff --git a/src/modules/open-ai/api/openAi.ts b/src/modules/open-ai/api/openAi.ts index 3251f94a..a910b43a 100644 --- a/src/modules/open-ai/api/openAi.ts +++ b/src/modules/open-ai/api/openAi.ts @@ -17,6 +17,7 @@ import { DalleGPTModel, DalleGPTModels, } from "../types"; +import { getMessageExtras } from "../helpers"; const openai = new OpenAI({ apiKey: config.openAiKey, @@ -98,7 +99,7 @@ export async function chatCompletion( try { const payload = { model: model, - max_tokens: limitTokens ? config.openAi.maxTokens : undefined, + max_tokens: limitTokens ? config.openAi.chatGpt.maxTokens : undefined, temperature: config.openAi.dalle.completions.temperature, messages: conversation, }; @@ -134,12 +135,15 @@ export const streamChatCompletion = async ( const wordCountMinimum = config.openAi.chatGpt.wordCountBetween; return new Promise(async (resolve, reject) => { try { + // const extras = getMessageExtras({ + // topicId: ctx.message?.message_thread_id + // }) const stream = await openai.chat.completions.create({ model: model, messages: conversation as OpenAI.Chat.Completions.CreateChatCompletionRequestMessage[], stream: true, - max_tokens: limitTokens ? config.openAi.maxTokens : undefined, + max_tokens: limitTokens ? config.openAi.chatGpt.maxTokens : undefined, temperature: config.openAi.dalle.completions.temperature, }); let wordCount = 0; @@ -157,6 +161,9 @@ export const streamChatCompletion = async ( completion = completion.replaceAll("..", ""); completion += ".."; wordCount = 0; + // const extras = getMessageExtras({ + // topicId: ctx.message?.message_thread_id + // }) await ctx.api .editMessageText(ctx.chat?.id!, msgId, completion) .catch(async (e: any) => { diff --git a/src/modules/open-ai/controller/index.ts b/src/modules/open-ai/controller/index.ts deleted file mode 100644 index aa521f87..00000000 --- a/src/modules/open-ai/controller/index.ts +++ /dev/null @@ -1,168 +0,0 @@ -import { pino } from "pino"; -import { - ChatConversation, - OnCallBackQueryData, - OnMessageContext, -} from "../../types"; -import { - improvePrompt, - postGenerateImg, - alterGeneratedImg, - streamChatCompletion, - getTokenNumber, - getChatModel, - getChatModelPrice, -} from "../api/openAi"; -import config from "../../../config"; - -interface ImageGenPayload { - chatId: number; - prompt: string; - numImages?: number; - imgSize?: string; - filePath?: string; - model?: string; -} - -interface ChatGptPayload { - conversation: ChatConversation[]; - model: string; - ctx: OnMessageContext | OnCallBackQueryData; -} - -const logger = pino({ - name: "openAI-controller", - transport: { - target: "pino-pretty", - options: { - colorize: true, - }, - }, -}); - -export const imgGen = async ( - data: ImageGenPayload, - ctx: OnMessageContext | OnCallBackQueryData -) => { - const { chatId, prompt, numImages, imgSize } = data; - try { - const imgs = await postGenerateImg(prompt, numImages, imgSize); - imgs.map(async (img: any) => { - await ctx - .replyWithPhoto(img.url, { - caption: `/dalle ${prompt}`, - }) - .catch((e) => { - throw e; - }); - }); - return true; - } catch (e: any) { - throw e; - } -}; - -export const imgGenEnhanced = async ( - data: ImageGenPayload, - ctx: OnMessageContext | OnCallBackQueryData -) => { - const { chatId, prompt, numImages, imgSize, model } = data; - try { - const upgratedPrompt = await improvePrompt(prompt, model!); - if (upgratedPrompt) { - await ctx - .reply( - `The following description was added to your prompt: ${upgratedPrompt}` - ) - .catch((e) => { - throw e; 
- }); - } - // bot.api.sendMessage(chatId, "generating the output..."); - const imgs = await postGenerateImg( - upgratedPrompt || prompt, - numImages, - imgSize - ); - imgs.map(async (img: any) => { - await ctx - .replyWithPhoto(img.url, { - caption: `/DALLE ${upgratedPrompt || prompt}`, - }) - .catch((e) => { - throw e; - }); - }); - return true; - } catch (e) { - throw e; - } -}; - -export const alterImg = async ( - data: ImageGenPayload, - ctx: OnMessageContext | OnCallBackQueryData -) => { - const { chatId, prompt, numImages, imgSize, filePath } = data; - try { - ctx.chatAction = "upload_photo"; - const imgs = await alterGeneratedImg(prompt!, filePath!, ctx, imgSize!); - if (imgs) { - imgs!.map(async (img: any) => { - await ctx.replyWithPhoto(img.url).catch((e) => { - throw e; - }); - }); - } - ctx.chatAction = null; - } catch (e) { - throw e; - } -}; - -export const promptGen = async ( - data: ChatGptPayload, - chat: ChatConversation[] -) => { - const { conversation, ctx, model } = data; - try { - let msgId = (await ctx.reply("...")).message_id; - const isTypingEnabled = config.openAi.chatGpt.isTypingEnabled; - if (isTypingEnabled) { - ctx.chatAction = "typing"; - } - const completion = await streamChatCompletion( - conversation!, - ctx, - model, - msgId, - true // telegram messages has a character limit - ); - if (isTypingEnabled) { - ctx.chatAction = null; - } - if (completion) { - const prompt = conversation[conversation.length - 1].content; - const promptTokens = getTokenNumber(prompt); - const completionTokens = getTokenNumber(completion); - const modelPrice = getChatModel(model); - const price = - getChatModelPrice(modelPrice, true, promptTokens, completionTokens) * - config.openAi.chatGpt.priceAdjustment; - logger.info( - `streamChatCompletion result = tokens: ${ - promptTokens + completionTokens - } | ${modelPrice.name} | price: ${price}¢` - ); - conversation.push({ content: completion, role: "system" }); - ctx.session.openAi.chatGpt.usage += promptTokens + completionTokens; - ctx.session.openAi.chatGpt.price += price; - chat = [...conversation!]; - return price; - } - return 0; - } catch (e: any) { - ctx.chatAction = null; - throw e; - } -}; diff --git a/src/modules/open-ai/helpers.ts b/src/modules/open-ai/helpers.ts new file mode 100644 index 00000000..e9e46866 --- /dev/null +++ b/src/modules/open-ai/helpers.ts @@ -0,0 +1,243 @@ +import config from "../../config"; +import { isValidUrl } from "./utils/web-crawler"; +import { OnMessageContext, OnCallBackQueryData, MessageExtras } from "../types"; +import { parse } from "path"; +import { ParseMode } from "grammy/types"; +import { getChatModel, getChatModelPrice, getTokenNumber } from "./api/openAi"; +import { ChatGptPayload } from "./types"; + +export const SupportedCommands = { + chat: { + name: "chat", + }, + ask: { + name: "ask", + }, + sum: { + name: "sum", + }, + ask35: { + name: "ask35", + }, + new: { + name: "new", + }, + gpt4: { + name: "gpt4", + }, + gpt: { + name: "gpt", + }, + last: { + name: "last", + }, + dalle: { + name: "DALLE", + }, + dalleLC: { + name: "dalle", + }, + genImgEn: { + name: "genImgEn", + }, +}; + +export const MAX_TRIES = 3; + +export const isMentioned = ( + ctx: OnMessageContext | OnCallBackQueryData +): boolean => { + if (ctx.entities()[0]) { + const { offset, text } = ctx.entities()[0]; + const { username } = ctx.me; + if (username === text.slice(1) && offset === 0) { + const prompt = ctx.message?.text!.slice(text.length); + if (prompt && prompt.split(" ").length > 0) { + return true; + } + } + 
} + return false; +}; + +export const hasChatPrefix = (prompt: string): string => { + const prefixList = config.openAi.chatGpt.prefixes.chatPrefix; + for (let i = 0; i < prefixList.length; i++) { + if (prompt.toLocaleLowerCase().startsWith(prefixList[i])) { + return prefixList[i]; + } + } + return ""; +}; + +export const hasDallePrefix = (prompt: string): string => { + const prefixList = config.openAi.chatGpt.prefixes.dallePrefix; + for (let i = 0; i < prefixList.length; i++) { + if (prompt.toLocaleLowerCase().startsWith(prefixList[i])) { + return prefixList[i]; + } + } + return ""; +}; + +export const hasNewPrefix = (prompt: string): string => { + const prefixList = config.openAi.chatGpt.prefixes.newPrefix; + for (let i = 0; i < prefixList.length; i++) { + if (prompt.toLocaleLowerCase().startsWith(prefixList[i])) { + return prefixList[i]; + } + } + return ""; +}; + +export const hasUrl = (prompt: string) => { + const promptArray = prompt.split(" "); + let url = ""; + for (let i = 0; i < promptArray.length; i++) { + if (isValidUrl(promptArray[i])) { + url = promptArray[i]; + promptArray.splice(i, 1); + break; + } + } + return { + url, + newPrompt: promptArray.join(" "), + }; +}; + +export const hasUsernamePassword = (prompt: string) => { + let user = ""; + let password = ""; + const parts = prompt.split(" "); + + for (let i = 0; i < parts.length; i++) { + const part = parts[i].toLowerCase(); + if (part.includes("=")) { + const [keyword, value] = parts[i].split("="); + if (keyword === "user" || keyword === "username") { + user = value; + } else if (keyword === "password" || keyword === "pwd") { + password = value; + } + if (user !== "" && password !== "") { + break; + } + } else if (part === "user") { + user = parts[i + 1]; + } else if (part === "password") { + password = parts[i + 1]; + } + } + return { user, password }; +}; + +// doesn't get all the special characters like ! 
+export const hasUserPasswordRegex = (prompt: string) => { + const pattern = + /\b(user=|password=|user|password)\s*([^\s]+)\b.*\b(user=|password=|user|password)\s*([^\s]+)\b/i; + const matches = pattern.exec(prompt); + + let user = ""; + let password = ""; + + if (matches) { + const [_, keyword, word, __, word2] = matches; + if (keyword.toLowerCase() === "user" || keyword.toLowerCase() === "user=") { + user = word; + password = word2; + } else if ( + keyword.toLowerCase() === "password" || + keyword.toLowerCase() === "password=" + ) { + password = word; + user = word2; + } + } + return { user, password }; +}; + +export const preparePrompt = async ( + ctx: OnMessageContext | OnCallBackQueryData, + prompt: string +) => { + const msg = await ctx.message?.reply_to_message?.text; + if (msg) { + return `${prompt} ${msg}`; + } + return prompt; +}; + +export const messageTopic = async ( + ctx: OnMessageContext | OnCallBackQueryData +) => { + return await ctx.message?.message_thread_id; +}; +interface GetMessagesExtras { + parseMode?: ParseMode | undefined; + topicId?: number | undefined; + caption?: string | undefined; +} + +export const getMessageExtras = (params: GetMessagesExtras) => { + const { parseMode, topicId, caption } = params; + let extras: MessageExtras = {}; + if (parseMode) { + extras["parse_mode"] = parseMode; + } + if (topicId) { + extras["message_thread_id"] = parseInt( + String(topicId) + ) as unknown as number; + } + if (caption) { + extras["caption"] = caption; + } + return extras; +}; + +export const sendMessage = async ( + ctx: OnMessageContext | OnCallBackQueryData, + msg: string, + msgExtras: GetMessagesExtras +) => { + const extras = getMessageExtras(msgExtras); + return await ctx.reply(msg, extras); +}; + +export const hasPrefix = (prompt: string): string => { + return ( + hasChatPrefix(prompt) || hasDallePrefix(prompt) || hasNewPrefix(prompt) + ); +}; + +export const getPromptPrice = (completion: string, data: ChatGptPayload) => { + const { conversation, ctx, model } = data; + + const prompt = conversation[conversation.length - 1].content; + const promptTokens = getTokenNumber(prompt); + const completionTokens = getTokenNumber(completion); + const modelPrice = getChatModel(model); + const price = + getChatModelPrice(modelPrice, true, promptTokens, completionTokens) * + config.openAi.chatGpt.priceAdjustment; + conversation.push({ content: completion, role: "system" }); + ctx.session.openAi.chatGpt.usage += promptTokens + completionTokens; + ctx.session.openAi.chatGpt.price += price; + return { + price, + promptTokens, + completionTokens, + }; +}; + +export const limitPrompt = (prompt: string) => { + const wordCountPattern = /(\d+)\s*word(s)?/g; + const match = wordCountPattern.exec(prompt); + + if (match) { + return `${prompt}`; + } + + return `${prompt} in around ${config.openAi.chatGpt.wordLimit} words`; +}; diff --git a/src/modules/open-ai/index.ts b/src/modules/open-ai/index.ts index 91808184..08f8d27f 100644 --- a/src/modules/open-ai/index.ts +++ b/src/modules/open-ai/index.ts @@ -9,58 +9,37 @@ import { OnCallBackQueryData, ChatConversation, } from "../types"; -import { getChatModel, getDalleModel, getDalleModelPrice } from "./api/openAi"; -import { alterImg, imgGen, imgGenEnhanced, promptGen } from "./controller"; +import { + alterGeneratedImg, + getChatModel, + getDalleModel, + getDalleModelPrice, + postGenerateImg, + streamChatCompletion, +} from "./api/openAi"; import { appText } from "./utils/text"; import { chatService } from "../../database/services"; -import 
{ ChatGPTModelsEnum } from "./types"; +import { ChatGPTModelsEnum, ChatGptPayload } from "./types"; import config from "../../config"; import { sleep } from "../sd-images/utils"; import { - isValidUrl, - getWebContent, - getCrawlerPrice, -} from "./utils/web-crawler"; - -export const SupportedCommands = { - chat: { - name: "chat", - }, - ask: { - name: "ask", - }, - sum: { - name: "sum", - }, - ask35: { - name: "ask35", - }, - new: { - name: "new", - }, - gpt4: { - name: "gpt4", - }, - gpt: { - name: "gpt", - }, - last: { - name: "last", - }, - dalle: { - name: "DALLE", - }, - dalleLC: { - name: "dalle", - }, - genImgEn: { - name: "genImgEn", - } -}; + getMessageExtras, + getPromptPrice, + hasChatPrefix, + hasNewPrefix, + hasPrefix, + hasUrl, + hasUsernamePassword, + isMentioned, + limitPrompt, + MAX_TRIES, + messageTopic, + preparePrompt, + sendMessage, + SupportedCommands, +} from "./helpers"; +import { getWebContent, getCrawlerPrice } from "./utils/web-crawler"; -const MAX_TRIES = 3; - -// const payments = new BotPayments(); export class OpenAIBot { private logger: Logger; private payments: BotPayments; @@ -83,87 +62,23 @@ export class OpenAIBot { } } - private isMentioned(ctx: OnMessageContext | OnCallBackQueryData): boolean { - if (ctx.entities()[0]) { - const { offset, text } = ctx.entities()[0]; - const { username } = ctx.me; - if (username === text.slice(1) && offset === 0) { - const prompt = ctx.message?.text!.slice(text.length); - if (prompt && prompt.split(" ").length > 0) { - return true; - } - } - } - return false; - } - - public isSupportedEvent( + public async isSupportedEvent( ctx: OnMessageContext | OnCallBackQueryData - ): boolean { + ): Promise<boolean> { const hasCommand = ctx.hasCommand( Object.values(SupportedCommands).map((command) => command.name) ); - if (this.isMentioned(ctx)) { + if (isMentioned(ctx)) { return true; } const hasReply = this.isSupportedImageReply(ctx); - const chatPrefix = this.hasPrefix(ctx.message?.text || ""); + const chatPrefix = hasPrefix(ctx.message?.text || ""); if (chatPrefix !== "") { return true; } return hasCommand || hasReply; } - private hasPrefix(prompt: string): string { - return this.hasChatPrefix(prompt) || this.hasDallePrefix(prompt) || this.hasNewPrefix(prompt) - } - - private hasChatPrefix(prompt: string): string { - const prefixList = config.openAi.chatGpt.prefixes.chatPrefix; - for (let i = 0; i < prefixList.length; i++) { - if (prompt.toLocaleLowerCase().startsWith(prefixList[i])) { - return prefixList[i]; - } - } - return ""; - } - - private hasDallePrefix(prompt: string): string { - const prefixList = config.openAi.chatGpt.prefixes.dallePrefix; - for (let i = 0; i < prefixList.length; i++) { - if (prompt.toLocaleLowerCase().startsWith(prefixList[i])) { - return prefixList[i]; - } - } - return ""; - } - - private hasNewPrefix(prompt: string): string { - const prefixList = config.openAi.chatGpt.prefixes.newPrefix; - for (let i = 0; i < prefixList.length; i++) { - if (prompt.toLocaleLowerCase().startsWith(prefixList[i])) { - return prefixList[i]; - } - } - return ""; - } - - private hasUrl(prompt: string) { - const promptArray = prompt.split(" "); - let url = ""; - for (let i = 0; i < promptArray.length; i++) { - if (isValidUrl(promptArray[i])) { - url = promptArray[i]; - promptArray.splice(i, 1); - break; - } - } - return { - url, - newPrompt: promptArray.join(" "), - }; - } - public getEstimatedPrice(ctx: any): number { try { const priceAdjustment = config.openAi.chatGpt.priceAdjustment; @@ -229,7 +144,7 @@ export class OpenAIBot {
if ( ctx.hasCommand(SupportedCommands.chat.name) || - (ctx.message?.text?.startsWith("chat ") && ctx.chat?.type === 'private') + (ctx.message?.text?.startsWith("chat ") && ctx.chat?.type === "private") ) { ctx.session.openAi.chatGpt.model = ChatGPTModelsEnum.GPT_4; await this.onChat(ctx); @@ -238,17 +153,17 @@ export class OpenAIBot { if ( ctx.hasCommand(SupportedCommands.new.name) || - (ctx.message?.text?.startsWith("new ") && ctx.chat?.type === 'private') + (ctx.message?.text?.startsWith("new ") && ctx.chat?.type === "private") ) { ctx.session.openAi.chatGpt.model = ChatGPTModelsEnum.GPT_4; - await this.onEnd(ctx) + await this.onEnd(ctx); this.onChat(ctx); return; } if ( ctx.hasCommand(SupportedCommands.ask.name) || - (ctx.message?.text?.startsWith("ask ") && ctx.chat?.type === 'private') + (ctx.message?.text?.startsWith("ask ") && ctx.chat?.type === "private") ) { ctx.session.openAi.chatGpt.model = ChatGPTModelsEnum.GPT_4; this.onChat(ctx); @@ -275,16 +190,16 @@ export class OpenAIBot { if ( ctx.hasCommand(SupportedCommands.dalle.name) || ctx.hasCommand(SupportedCommands.dalleLC.name) || - (ctx.message?.text?.startsWith("dalle ") && ctx.chat?.type === 'private') + (ctx.message?.text?.startsWith("dalle ") && ctx.chat?.type === "private") ) { this.onGenImgCmd(ctx); return; } - if (ctx.hasCommand(SupportedCommands.genImgEn.name)) { - this.onGenImgEnCmd(ctx); - return; - } + // if (ctx.hasCommand(SupportedCommands.genImgEn.name)) { + // this.onGenImgEnCmd(ctx); + // return; + // } if (this.isSupportedImageReply(ctx)) { this.onAlterImage(ctx); @@ -293,7 +208,7 @@ export class OpenAIBot { if ( ctx.hasCommand(SupportedCommands.sum.name) || - (ctx.message?.text?.startsWith("sum ") && ctx.chat?.type === 'private') + (ctx.message?.text?.startsWith("sum ") && ctx.chat?.type === "private") ) { this.onSum(ctx); return; @@ -303,18 +218,18 @@ export class OpenAIBot { return; } - if (this.hasChatPrefix(ctx.message?.text || "") !== "") { + if (hasNewPrefix(ctx.message?.text || "") !== "") { + await this.onEnd(ctx); this.onPrefix(ctx); return; } - if (this.hasNewPrefix(ctx.message?.text || "") !== "") { - await this.onEnd(ctx) + if (hasChatPrefix(ctx.message?.text || "") !== "") { this.onPrefix(ctx); return; } - if (this.isMentioned(ctx)) { + if (isMentioned(ctx)) { this.onMention(ctx); return; } @@ -325,9 +240,9 @@ export class OpenAIBot { } this.logger.warn(`### unsupported command`); - await ctx - .reply("### unsupported command") - .catch((e) => this.onError(ctx, e, MAX_TRIES, "Bot disabled")); + sendMessage(ctx, "### unsupported command", { + topicId: ctx.message?.message_thread_id, + }).catch((e) => this.onError(ctx, e, MAX_TRIES, "### unsupported command")); } private async hasBalance(ctx: OnMessageContext | OnCallBackQueryData) { @@ -350,52 +265,30 @@ export class OpenAIBot { prompt = config.openAi.dalle.defaultPrompt; } ctx.chatAction = "upload_photo"; - const payload = { - chatId: ctx.chat?.id!, - prompt: prompt as string, - numImages: await ctx.session.openAi.imageGen.numImages, // lazy load - imgSize: await ctx.session.openAi.imageGen.imgSize, // lazy load - }; - await imgGen(payload, ctx); - } else { - await ctx - .reply("Bot disabled") - .catch((e) => this.onError(ctx, e, MAX_TRIES, "Bot disabled")); - } - } catch (e) { - this.onError(ctx, e, 3, "There was an error while generating the image"); - } - }; - - onGenImgEnCmd = async (ctx: OnMessageContext | OnCallBackQueryData) => { - try { - if (ctx.session.openAi.imageGen.isEnabled) { - const prompt = await ctx.match; - if (!prompt) { - await 
ctx - .reply("Error: Missing prompt") - .catch((e) => - this.onError(ctx, e, MAX_TRIES, "Error: Missing prompt") - ); - return; - } - const payload = { - chatId: await ctx.chat?.id!, - prompt: prompt as string, - numImages: await ctx.session.openAi.imageGen.numImages, - imgSize: await ctx.session.openAi.imageGen.imgSize, - }; - await ctx - .reply("generating improved prompt...") - .catch((e) => this.onError(ctx, e)); - await imgGenEnhanced(payload, ctx); + const numImages = await ctx.session.openAi.imageGen.numImages; + const imgSize = await ctx.session.openAi.imageGen.imgSize; + const imgs = await postGenerateImg(prompt, numImages, imgSize); + const msgExtras = getMessageExtras({ + caption: `/dalle ${prompt}`, + topicId: await messageTopic(ctx), + }); + imgs.map(async (img: any) => { + await ctx.replyWithPhoto(img.url, msgExtras).catch((e) => { + this.onError(ctx, e, MAX_TRIES); + }); + }); } else { - await ctx - .reply("Bot disabled") - .catch((e) => this.onError(ctx, e, MAX_TRIES, "Bot disabled")); + sendMessage(ctx, "Bot disabled", { + topicId: ctx.message?.message_thread_id, + }).catch((e) => this.onError(ctx, e, MAX_TRIES, "Bot disabled")); } } catch (e) { - this.onError(ctx, e); + this.onError( + ctx, + e, + MAX_TRIES, + "There was an error while generating the image" + ); } }; @@ -408,14 +301,17 @@ export class OpenAIBot { const file_id = photo?.pop()?.file_id; // with pop() get full image quality const file = await ctx.api.getFile(file_id!); const filePath = `${config.openAi.dalle.telegramFileUrl}${config.telegramBotAuthToken}/${file.file_path}`; - const payload = { - chatId: ctx.chat?.id!, - prompt: prompt as string, - numImages: await ctx.session.openAi.imageGen.numImages, - imgSize: await ctx.session.openAi.imageGen.imgSize, - filePath: filePath, - }; - await alterImg(payload, ctx); + const imgSize = await ctx.session.openAi.imageGen.imgSize; + ctx.chatAction = "upload_photo"; + const imgs = await alterGeneratedImg(prompt!, filePath!, ctx, imgSize!); + if (imgs) { + imgs!.map(async (img: any) => { + await ctx.replyWithPhoto(img.url).catch((e) => { + throw e; + }); + }); + } + ctx.chatAction = null; } } catch (e: any) { this.onError( @@ -427,93 +323,68 @@ export class OpenAIBot { } }; - // doesn't get all the special characters like ! 
- private hasUserPasswordRegex(prompt: string) { - const pattern = - /\b(user=|password=|user|password)\s*([^\s]+)\b.*\b(user=|password=|user|password)\s*([^\s]+)\b/i; - const matches = pattern.exec(prompt); - - let user = ""; - let password = ""; - - if (matches) { - const [_, keyword, word, __, word2] = matches; - if ( - keyword.toLowerCase() === "user" || - keyword.toLowerCase() === "user=" - ) { - user = word; - password = word2; - } else if ( - keyword.toLowerCase() === "password" || - keyword.toLowerCase() === "password=" - ) { - password = word; - user = word2; + private async promptGen(data: ChatGptPayload, chat: ChatConversation[]) { + const { conversation, ctx, model } = data; + try { + const extras = getMessageExtras({ + topicId: ctx.message?.message_thread_id, + }); + let msgId = (await ctx.reply("...", extras)).message_id; + const isTypingEnabled = config.openAi.chatGpt.isTypingEnabled; + if (isTypingEnabled) { + ctx.chatAction = "typing"; } - } - - return { user, password }; - } - - private hasUsernamePassword(prompt: string) { - let user = ""; - let password = ""; - const parts = prompt.split(" "); - - for (let i = 0; i < parts.length; i++) { - const part = parts[i].toLowerCase(); - if (part.includes("=")) { - const [keyword, value] = parts[i].split("="); - if (keyword === "user" || keyword === "username") { - user = value; - } else if (keyword === "password" || keyword === "pwd") { - password = value; - } - if (user !== "" && password !== "") { - break; - } - } else if (part === "user") { - user = parts[i + 1]; - } else if (part === "password") { - password = parts[i + 1]; + const completion = await streamChatCompletion( + conversation!, + ctx, + model, + msgId, + true // telegram messages have a character limit + ); + if (isTypingEnabled) { + ctx.chatAction = null; } + if (completion) { + const price = getPromptPrice(completion, data); + this.logger.info( + `streamChatCompletion result = tokens: ${ + price.promptTokens + price.completionTokens + } | ${model} | price: ${price.price}¢` + ); + chat = [...conversation!]; + return price.price; + } + return 0; + } catch (e: any) { + ctx.chatAction = null; + throw e; } - return { user, password }; - } - - private async preparePrompt( - ctx: OnMessageContext | OnCallBackQueryData, - prompt: string - ) { - const msg = await ctx.message?.reply_to_message?.text; - if (msg) { - return `${prompt} ${msg}`; - } - return prompt; - } } async onSum(ctx: OnMessageContext | OnCallBackQueryData) { if (this.botSuspended) { - await ctx - .reply("The bot is suspended") - .catch((e) => this.onError(ctx, e)); + sendMessage(ctx, "The bot is suspended", { + topicId: ctx.message?.message_thread_id, + }).catch((e) => this.onError(ctx, e)); return; } try { const { prompt } = getCommandNamePrompt(ctx, SupportedCommands); - const { url, newPrompt } = this.hasUrl(prompt); + const { url, newPrompt } = hasUrl(prompt); if (url) { let chat: ChatConversation[] = []; this.onWebCrawler( ctx, - await this.preparePrompt(ctx, newPrompt), + await preparePrompt(ctx, newPrompt), chat, url, "sum" ); } else { - ctx.reply(`Error: Missing url`); + await sendMessage(ctx, `Error: Missing URL`, { + topicId: ctx.message?.message_thread_id, + }).catch((e) => this.onError(ctx, e)); } } catch (e) { this.onError(ctx, e); @@ -533,8 +404,8 @@ export class OpenAIBot { // const { model } = ctx.session.openAi.chatGpt; const chatModel = getChatModel(model); const webCrawlerMaxTokens = - chatModel.maxContextTokens -
config.openAi.maxTokens * 2; - const { user, password } = this.hasUsernamePassword(prompt); + chatModel.maxContextTokens - config.openAi.chatGpt.maxTokens * 2; + const { user, password } = hasUsernamePassword(prompt); if (user && password) { // && ctx.chat?.type !== 'private' const maskedPrompt = @@ -542,7 +413,9 @@ export class OpenAIBot { ?.text!.replaceAll(user, "****") .replaceAll(password, "*****") || ""; ctx.api.deleteMessage(ctx.chat?.id!, ctx.message?.message_id!); - ctx.reply(maskedPrompt); + sendMessage(ctx, maskedPrompt, { + topicId: ctx.message?.message_thread_id, + }); } const webContent = await getWebContent( url, @@ -551,51 +424,55 @@ export class OpenAIBot { password ); if (webContent.urlText !== "") { - // ctx.reply(`URL downloaded`, - // // `${(webContent.networkTraffic / 1048576).toFixed( + // await sendMessage(ctx,`URL downloaded`, + // // `${(webContent.networkTraffic / 1048576).toFixed( // // 2 // // )} MB in ${(webContent.elapsedTime / 1000).toFixed(2)} seconds`, - // { - // parse_mode: "Markdown", - // } - // ); + // { + // topicId: ctx.message?.message_thread_id, + // parseMode: "Markdown", + // }).catch((e) => this.onError(ctx, e)); if ( !(await this.payments.pay(ctx as OnMessageContext, webContent.fees)) ) { this.onNotBalanceMessage(ctx); } else { + let newPrompt = ""; if (prompt !== "") { - chat.push({ - content: `${ - command === "sum" && "Summarize" - } ${prompt} this text: ${webContent.urlText}`, - role: "user", - }); + newPrompt = `${command === "sum" && "Summarize"} ${limitPrompt( + prompt + )} this text: ${webContent.urlText}`; } else { - chat.push({ - content: `${ - command === "sum" && "Summarize this text in 50 words:" - } "${webContent.urlText}"`, - role: "user", - }); + newPrompt = `${ + command === "sum" && + `Summarize this text in ${config.openAi.chatGpt.wordLimit} words:` + } "${webContent.urlText}"`; } - + chat.push({ + content: newPrompt, + role: "user", + }); if (prompt || command === "sum") { const payload = { conversation: chat, model: model || config.openAi.chatGpt.model, ctx, }; - const price = await promptGen(payload, chat); + const price = await this.promptGen(payload, chat); if (!(await this.payments.pay(ctx as OnMessageContext, price))) { this.onNotBalanceMessage(ctx); } } } } else { - ctx.reply( - "Url not supported, incorrect web site address or missing user credentials" - ); + await sendMessage( + ctx, + "Url not supported, incorrect web site address or missing user credentials", + { + topicId: ctx.message?.message_thread_id, + parseMode: "Markdown", + } + ).catch((e) => this.onError(ctx, e)); return; } return { @@ -613,15 +490,15 @@ export class OpenAIBot { async onMention(ctx: OnMessageContext | OnCallBackQueryData) { try { if (this.botSuspended) { - await ctx - .reply("The bot is suspended") - .catch((e) => this.onError(ctx, e)); + sendMessage(ctx, "The bot is suspended", { + topicId: ctx.message?.message_thread_id, + }).catch((e) => this.onError(ctx, e)); return; } const { username } = ctx.me; const prompt = ctx.message?.text?.slice(username.length + 1) || ""; //@ ctx.session.openAi.chatGpt.requestQueue.push( - await this.preparePrompt(ctx, prompt) + await preparePrompt(ctx, prompt) ); if (!ctx.session.openAi.chatGpt.isProcessingQueue) { ctx.session.openAi.chatGpt.isProcessingQueue = true; @@ -637,18 +514,18 @@ export class OpenAIBot { async onPrefix(ctx: OnMessageContext | OnCallBackQueryData) { try { if (this.botSuspended) { - await ctx - .reply("The bot is suspended") - .catch((e) => this.onError(ctx, e)); + sendMessage(ctx, 
"The bot is suspended", { + topicId: ctx.message?.message_thread_id, + }).catch((e) => this.onError(ctx, e)); return; } const { prompt, commandName } = getCommandNamePrompt( ctx, SupportedCommands ); - const prefix = this.hasPrefix(prompt); + const prefix = hasPrefix(prompt); ctx.session.openAi.chatGpt.requestQueue.push( - await this.preparePrompt(ctx, prompt.slice(prefix.length)) + await preparePrompt(ctx, prompt.slice(prefix.length)) ); if (!ctx.session.openAi.chatGpt.isProcessingQueue) { ctx.session.openAi.chatGpt.isProcessingQueue = true; @@ -664,13 +541,13 @@ export class OpenAIBot { async onPrivateChat(ctx: OnMessageContext | OnCallBackQueryData) { try { if (this.botSuspended) { - await ctx - .reply("The bot is suspended") - .catch((e) => this.onError(ctx, e)); + sendMessage(ctx, "The bot is suspended", { + topicId: ctx.message?.message_thread_id, + }).catch((e) => this.onError(ctx, e)); return; } ctx.session.openAi.chatGpt.requestQueue.push( - await this.preparePrompt(ctx, ctx.message?.text!) + await preparePrompt(ctx, ctx.message?.text!) ); if (!ctx.session.openAi.chatGpt.isProcessingQueue) { ctx.session.openAi.chatGpt.isProcessingQueue = true; @@ -686,14 +563,14 @@ export class OpenAIBot { async onChat(ctx: OnMessageContext | OnCallBackQueryData) { try { if (this.botSuspended) { - await ctx - .reply("The bot is suspended") - .catch((e) => this.onError(ctx, e)); + sendMessage(ctx, "The bot is suspended", { + topicId: ctx.message?.message_thread_id, + }).catch((e) => this.onError(ctx, e)); return; } const prompt = ctx.match ? ctx.match : ctx.message?.text; ctx.session.openAi.chatGpt.requestQueue.push( - await this.preparePrompt(ctx, prompt as string) + await preparePrompt(ctx, prompt as string) ); if (!ctx.session.openAi.chatGpt.isProcessingQueue) { ctx.session.openAi.chatGpt.isProcessingQueue = true; @@ -719,12 +596,13 @@ export class OpenAIBot { chatConversation[chatConversation.length - 1].content }_` : appText.introText; - await ctx - .reply(msg, { parse_mode: "Markdown" }) - .catch((e) => this.onError(ctx, e)); + await sendMessage(ctx, msg, { + topicId: ctx.message?.message_thread_id, + parseMode: "Markdown", + }).catch((e) => this.onError(ctx, e)); return; } - const { url, newPrompt } = this.hasUrl(prompt); + const { url, newPrompt } = hasUrl(prompt); if (url) { await this.onWebCrawler( ctx, @@ -734,16 +612,16 @@ export class OpenAIBot { "ask" ); } else { - chatConversation.push({ + const newPrompt = chatConversation.push({ role: "user", - content: prompt, + content: limitPrompt(prompt), }); const payload = { conversation: chatConversation!, model: model || config.openAi.chatGpt.model, ctx, }; - const price = await promptGen(payload, chatConversation); + const price = await this.promptGen(payload, chatConversation); if (!(await this.payments.pay(ctx as OnMessageContext, price))) { this.onNotBalanceMessage(ctx); } @@ -761,17 +639,19 @@ export class OpenAIBot { async onLast(ctx: OnMessageContext | OnCallBackQueryData) { if (ctx.session.openAi.chatGpt.chatConversation.length > 0) { const chat = ctx.session.openAi.chatGpt.chatConversation; - await ctx - .reply(`${appText.gptLast}\n_${chat[chat.length - 1].content}_`, { - parse_mode: "Markdown", - }) - .catch((e) => this.onError(ctx, e)); + await sendMessage( + ctx, + `${appText.gptLast}\n_${chat[chat.length - 1].content}_`, + { + topicId: ctx.message?.message_thread_id, + parseMode: "Markdown", + } + ).catch((e) => this.onError(ctx, e)); } else { - await ctx - .reply(`To start a conversation please write */ask*`, { - parse_mode: 
"Markdown", - }) - .catch((e) => this.onError(ctx, e)); + await sendMessage(ctx, `To start a conversation please write */ask*`, { + topicId: ctx.message?.message_thread_id, + parseMode: "Markdown", + }).catch((e) => this.onError(ctx, e)); } } @@ -780,7 +660,7 @@ export class OpenAIBot { ctx.session.openAi.chatGpt.usage = 0; ctx.session.openAi.chatGpt.price = 0; } - + async onNotBalanceMessage(ctx: OnMessageContext | OnCallBackQueryData) { const accountId = this.payments.getAccountId(ctx as OnMessageContext); const account = await this.payments.getUserAccount(accountId); @@ -791,9 +671,10 @@ export class OpenAIBot { const balanceMessage = appText.notEnoughBalance .replaceAll("$CREDITS", balanceOne) .replaceAll("$WALLET_ADDRESS", account?.address || ""); - await ctx - .reply(balanceMessage, { parse_mode: "Markdown" }) - .catch((e) => this.onError(ctx, e)); + await sendMessage(ctx, balanceMessage, { + topicId: ctx.message?.message_thread_id, + parseMode: "Markdown", + }).catch((e) => this.onError(ctx, e)); } async onError( @@ -808,7 +689,15 @@ export class OpenAIBot { return; } if (e instanceof GrammyError) { - if (e.error_code === 429) { + if (e.error_code === 400 && e.description.includes("not enough rights")) { + await sendMessage( + ctx, + "Error: The bot does not have permission to send photos in chat", + { + topicId: ctx.message?.message_thread_id, + } + ); + } else if (e.error_code === 429) { this.botSuspended = true; const retryAfter = e.parameters.retry_after ? e.parameters.retry_after < 60 @@ -818,38 +707,119 @@ export class OpenAIBot { const method = e.method; const errorMessage = `On method "${method}" | ${e.error_code} - ${e.description}`; this.logger.error(errorMessage); - await ctx - .reply( - `${ - ctx.from.username ? ctx.from.username : "" - } Bot has reached limit, wait ${retryAfter} seconds` - ) - .catch((e) => this.onError(ctx, e, retryCount - 1)); + await sendMessage( + ctx, + `${ + ctx.from.username ? ctx.from.username : "" + } Bot has reached limit, wait ${retryAfter} seconds`, + { + topicId: ctx.message?.message_thread_id, + } + ).catch((e) => this.onError(ctx, e, retryCount - 1)); if (method === "editMessageText") { ctx.session.openAi.chatGpt.chatConversation.pop(); //deletes last prompt } await sleep(retryAfter * 1000); // wait retryAfter seconds to enable bot this.botSuspended = false; + } else { + this.logger.error( + `On method "${e.method}" | ${e.error_code} - ${e.description}` + ); } } else if (e instanceof OpenAI.APIError) { // 429 RateLimitError // e.status = 400 || e.code = BadRequestError this.logger.error(`OPENAI Error ${e.status}(${e.code}) - ${e.message}`); if (e.code === "context_length_exceeded") { - await ctx - .reply(`${e.message}`) - .catch((e) => this.onError(ctx, e, retryCount - 1)); + await sendMessage(ctx, e.message, { + topicId: ctx.message?.message_thread_id, + }).catch((e) => this.onError(ctx, e, retryCount - 1)); this.onEnd(ctx); } else { - await ctx - .reply(`Error accessing OpenAI (ChatGPT). Please try later`) - .catch((e) => this.onError(ctx, e, retryCount - 1)); + await sendMessage( + ctx, + `Error accessing OpenAI (ChatGPT). Please try later`, + { + topicId: ctx.message?.message_thread_id, + } + ).catch((e) => this.onError(ctx, e, retryCount - 1)); } } else { this.logger.error(`${e.toString()}`); - await ctx - .reply(msg ? 
msg : "Error handling your request") - .catch((e) => this.onError(ctx, e, retryCount - 1)); + await sendMessage(ctx, "Error handling your request", { + topicId: ctx.message?.message_thread_id, + }).catch((e) => this.onError(ctx, e, retryCount - 1)); } } } + +// onGenImgEnCmd = async (ctx: OnMessageContext | OnCallBackQueryData) => { +// try { +// if (ctx.session.openAi.imageGen.isEnabled) { +// const prompt = await ctx.match; +// if (!prompt) { +// sendMessage(ctx, "Error: Missing prompt", { +// topicId: ctx.message?.message_thread_id, +// }).catch((e) => +// this.onError(ctx, e, MAX_TRIES, "Error: Missing prompt") +// ); +// return; +// } +// const payload = { +// chatId: await ctx.chat?.id!, +// prompt: prompt as string, +// numImages: await ctx.session.openAi.imageGen.numImages, +// imgSize: await ctx.session.openAi.imageGen.imgSize, +// }; +// sendMessage(ctx, "generating improved prompt...", { +// topicId: ctx.message?.message_thread_id, +// }).catch((e) => +// this.onError(ctx, e, MAX_TRIES, "generating improved prompt...") +// ); +// await imgGenEnhanced(payload, ctx); +// } else { +// sendMessage(ctx, "Bot disabled", { +// topicId: ctx.message?.message_thread_id, +// }).catch((e) => this.onError(ctx, e, MAX_TRIES, "Bot disabled")); +// } +// } catch (e) { +// this.onError(ctx, e); +// } +// }; + +// private async imgGenEnhanced( +// data: ImageGenPayload, +// ctx: OnMessageContext | OnCallBackQueryData +// ) { +// const { chatId, prompt, numImages, imgSize, model } = data; +// try { +// const upgratedPrompt = await improvePrompt(prompt, model!); +// if (upgratedPrompt) { +// await ctx +// .reply( +// `The following description was added to your prompt: ${upgratedPrompt}` +// ) +// .catch((e) => { +// throw e; +// }); +// } +// // bot.api.sendMessage(chatId, "generating the output..."); +// const imgs = await postGenerateImg( +// upgratedPrompt || prompt, +// numImages, +// imgSize +// ); +// imgs.map(async (img: any) => { +// await ctx +// .replyWithPhoto(img.url, { +// caption: `/DALLE ${upgratedPrompt || prompt}`, +// }) +// .catch((e) => { +// throw e; +// }); +// }); +// return true; +// } catch (e) { +// throw e; +// } +// }; diff --git a/src/modules/open-ai/types.ts b/src/modules/open-ai/types.ts index 76934741..4fab0b27 100644 --- a/src/modules/open-ai/types.ts +++ b/src/modules/open-ai/types.ts @@ -1,3 +1,9 @@ +import { + ChatConversation, + OnCallBackQueryData, + OnMessageContext, +} from "../types"; + export interface ChatGPTModel { name: string; inputPrice: number; @@ -10,6 +16,12 @@ export interface DalleGPTModel { price: number; } +export interface ChatGptPayload { + conversation: ChatConversation[]; + model: string; + ctx: OnMessageContext | OnCallBackQueryData; +} + export enum ChatGPTModelsEnum { GPT_4 = "gpt-4", GPT_4_32K = "gpt-4-32k", @@ -22,39 +34,39 @@ export const ChatGPTModels: Record = { name: "gpt-4", inputPrice: 0.03, outputPrice: 0.06, - maxContextTokens: 8192 + maxContextTokens: 8192, }, "gpt-4-32k": { name: "gpt-4-32k", inputPrice: 0.06, outputPrice: 0.12, - maxContextTokens: 32000 + maxContextTokens: 32000, }, "gpt-3.5-turbo": { name: "gpt-3.5-turbo", inputPrice: 0.0015, outputPrice: 0.002, - maxContextTokens: 4000 + maxContextTokens: 4000, }, "gpt-3.5-turbo-16k": { name: "gpt-3.5-turbo-16k", inputPrice: 0.003, outputPrice: 0.004, - maxContextTokens: 16000 + maxContextTokens: 16000, }, }; export const DalleGPTModels: Record = { "1024x1024": { size: "1024x1024", - price: 0.020 + price: 0.02, }, "512x512": { size: "512x512", - price: 0.018 + price: 
0.018, }, "256x256": { size: "256x256", - price: 0.016 + price: 0.016, }, }; diff --git a/src/modules/qrcode/QRCodeBot.ts b/src/modules/qrcode/QRCodeBot.ts index 31276870..be46b414 100644 --- a/src/modules/qrcode/QRCodeBot.ts +++ b/src/modules/qrcode/QRCodeBot.ts @@ -1,52 +1,67 @@ -import {Automatic1111Client} from "./Automatic1111Client"; -import {createQRCode, isQRCodeReadable, normalizeUrl, retryAsync} from "./utils"; +import { Automatic1111Client } from "./Automatic1111Client"; +import { + createQRCode, + isQRCodeReadable, + normalizeUrl, + retryAsync, +} from "./utils"; import config from "../../config"; -import {InlineKeyboard, InputFile} from "grammy"; -import {OnCallBackQueryData, OnMessageContext, RefundCallback} from "../types"; -import {Automatic1111Config} from "./Automatic1111Configs"; -import {automatic1111DefaultConfig} from "./Automatic1111DefaultConfig"; -import {ComfyClient} from "./comfy/ComfyClient"; +import { GrammyError, InlineKeyboard, InputFile } from "grammy"; +import { + MessageExtras, + OnCallBackQueryData, + OnMessageContext, + RefundCallback, +} from "../types"; +import { Automatic1111Config } from "./Automatic1111Configs"; +import { automatic1111DefaultConfig } from "./Automatic1111DefaultConfig"; +import { ComfyClient } from "./comfy/ComfyClient"; import crypto from "crypto"; import buildQRWorkflow from "./comfy/buildQRWorkflow"; -import pino, {Logger} from "pino"; +import pino, { Logger } from "pino"; enum SupportedCommands { - QR = 'qr', + QR = "qr", } enum Callbacks { - Regenerate = 'qr-regenerate', + Regenerate = "qr-regenerate", } export class QRCodeBot { - - private logger: Logger + private logger: Logger; constructor() { - this.logger = pino({ - name: 'QRBot', + name: "QRBot", transport: { - target: 'pino-pretty', + target: "pino-pretty", options: { - colorize: true - } - } - }) - + colorize: true, + }, + }, + }); } public getEstimatedPrice(ctx: any) { return 1; // 1.5; } - public isSupportedEvent(ctx: OnMessageContext | OnCallBackQueryData): boolean { - return ctx.hasCommand(Object.values(SupportedCommands)) || ctx.hasCallbackQuery(Object.values(Callbacks)); + public isSupportedEvent( + ctx: OnMessageContext | OnCallBackQueryData + ): boolean { + return ( + ctx.hasCommand(Object.values(SupportedCommands)) || + ctx.hasCallbackQuery(Object.values(Callbacks)) + ); } - public async onEvent(ctx: OnMessageContext | OnCallBackQueryData, refundCallback: RefundCallback) { + public async onEvent( + ctx: OnMessageContext | OnCallBackQueryData, + refundCallback: RefundCallback + ) { if (!this.isSupportedEvent(ctx)) { - await ctx.reply(`Unsupported command: ${ctx.message?.text}`) - return refundCallback('Unsupported command') + await ctx.reply(`Unsupported command: ${ctx.message?.text}`); + return refundCallback("Unsupported command"); } try { @@ -54,82 +69,89 @@ export class QRCodeBot { try { await ctx.answerCallbackQuery(); } catch (ex) { - console.log('### ex', ex); + console.log("### ex", ex); } - const msg = ctx.callbackQuery.message?.text || ctx.callbackQuery.message?.caption || ''; + const msg = + ctx.callbackQuery.message?.text || + ctx.callbackQuery.message?.caption || + ""; if (!msg) { - await ctx.reply('Error: message is too old'); - return refundCallback('Error: message is too old') + await ctx.reply("Error: message is too old"); + return refundCallback("Error: message is too old"); } const cmd = this.parseQrCommand(msg); if (cmd.error || !cmd.command || !cmd.url || !cmd.prompt) { - await ctx.reply('Message haven\'t contain command: ' + msg); - return 
refundCallback('Message haven't contain command: ') + await ctx.reply("Message doesn't contain a command: " + msg); + return refundCallback("Message doesn't contain a command: "); } if (cmd.command === SupportedCommands.QR) { - return this.onQr(ctx, msg, 'img2img'); + return this.onQr(ctx, msg, "img2img"); } } if (ctx.hasCommand(SupportedCommands.QR)) { - return this.onQr(ctx, ctx.message.text, 'img2img'); + return this.onQr(ctx, ctx.message.text, "img2img"); } } catch (ex) { if (ex instanceof Error) { - this.logger.info('Error ' + ex.message); + this.logger.info("Error " + ex.message); return refundCallback(ex.message); } - this.logger.info('Error ' + ex); - return refundCallback('Unknown error'); + this.logger.info("Error " + ex); + return refundCallback("Unknown error"); } - await ctx.reply('Unsupported command'); - this.logger.info('Unsupported command'); - return refundCallback('Unsupported command'); + await ctx.reply("Unsupported command"); + this.logger.info("Unsupported command"); + return refundCallback("Unsupported command"); } public parseQrCommand(message: string) { // command: /qr url prompt1, prompt2, prompt3 - if (!message.startsWith('/')) { + if (!message.startsWith("/")) { return { - command: '', - url: '', - prompt: '', + command: "", + url: "", + prompt: "", error: true, - } + }; } - const [command, url, ...rest] = message.split(' '); + const [command, url, ...rest] = message.split(" "); return { - command: command.replace('/', ''), + command: command.replace("/", ""), url, - prompt: rest.join(' '), - } + prompt: rest.join(" "), + }; } - private async onQr(ctx: OnMessageContext | OnCallBackQueryData, message: string, method: 'txt2img' | 'img2img') { - this.logger.info('generate qr'); + private async onQr( + ctx: OnMessageContext | OnCallBackQueryData, + message: string, + method: "txt2img" | "img2img" + ) { + this.logger.info("generate qr"); const command = this.parseQrCommand(message); if (command.error || !command.command || !command.url || !command.prompt) { - command.url = 'https://s.country/ai'; - command.prompt = 'astronaut, exuberant, anime girl, smile, sky, colorful' -// ctx.reply(` -// Please add -// -// /qr h.country/ai Dramatic bonfire on a remote beach, captured at the magic hour with flames dancing against the twilight sky; using a shallow depth of field, a fast lens, and controlled exposure to emphasize the intricate patterns and textures of the fire, complemented by embers in the wind and the gentle glow reflecting on the ocean's edge, moody, intense, and alive.`, { -// disable_web_page_preview: true, -// }); -// return + command.url = "https://s.country/ai"; + command.prompt = "astronaut, exuberant, anime girl, smile, sky, colorful"; + // ctx.reply(` + // Please add + // + // /qr h.country/ai Dramatic bonfire on a remote beach, captured at the magic hour with flames dancing against the twilight sky; using a shallow depth of field, a fast lens, and controlled exposure to emphasize the intricate patterns and textures of the fire, complemented by embers in the wind and the gentle glow reflecting on the ocean's edge, moody, intense, and alive.`, { + // disable_web_page_preview: true, + // }); + // return } // ctx.reply(`Generating...`); @@ -145,84 +167,160 @@ export class QRCodeBot { method, prompt: command.prompt, }; - const qrImgBuffer = await this.genQRCodeByComfyUI(props); - if (!qrImgBuffer) { - throw new Error('internal error'); + throw new Error("internal error"); } - - if(config.qrBot.checkReadable && isQRCodeReadable(qrImgBuffer)) { - console.log('###
qr unreadable'); + if (config.qrBot.checkReadable && isQRCodeReadable(qrImgBuffer)) { + console.log("### qr unreadable"); return qrImgBuffer; } - return qrImgBuffer; - } + }; let qrImgBuffer; try { ctx.chatAction = "upload_photo"; qrImgBuffer = await retryAsync(operation, 5, 100); - } catch (ex) { ctx.chatAction = null; this.logger.error(`ex ${ex}`); - await ctx.reply("Internal error") - throw new Error('Internal error'); + await ctx.reply("Internal error"); + throw new Error("Internal error"); } - const regenButton = new InlineKeyboard() - .text("Regenerate", Callbacks.Regenerate) - + const regenButton = new InlineKeyboard().text( + "Regenerate", + Callbacks.Regenerate + ); - await ctx.replyWithPhoto(new InputFile(qrImgBuffer, `qr_code_${Date.now()}.png`), { - caption: `/qr ${command.url} ${command.prompt}`, - reply_markup: regenButton, - }) - this.logger.info('sent qr code'); - return true; + try { + await ctx.replyWithPhoto( + new InputFile(qrImgBuffer, `qr_code_${Date.now()}.png`), + { + caption: `/qr ${command.url} ${command.prompt}`, + reply_markup: regenButton, + } + ); + this.logger.info("sent qr code"); + return true; + } catch (e: any) { + const topicId = await ctx.message?.message_thread_id; + let msgExtras: MessageExtras = {}; + if (topicId) { + msgExtras["message_thread_id"] = topicId; + } + if (e instanceof GrammyError) { + if ( + e.error_code === 400 && + e.description.includes("not enough rights") + ) { + ctx.reply( + `Error: The bot does not have permission to send photos in chat...`, + msgExtras + ); + } else { + ctx.reply( + `Error: something went wrong...`, + msgExtras + ); + } + } else { + this.logger.error(e.toString()); + ctx.reply( + `Error: something went wrong...`, + msgExtras + ); + } + return false; + } } - private async genQRCode({qrUrl, qrMargin, prompt, method}: {qrUrl: string, qrMargin: number, prompt: string, method: 'img2img' | 'txt2img'}) { - const qrImgBuffer = await createQRCode({url: qrUrl, margin: qrMargin }); + private async genQRCode({ + qrUrl, + qrMargin, + prompt, + method, + }: { + qrUrl: string; + qrMargin: number; + prompt: string; + method: "img2img" | "txt2img"; + }) { + const qrImgBuffer = await createQRCode({ url: qrUrl, margin: qrMargin }); const sdClient = new Automatic1111Client(); - const extendedPrompt = prompt + ', ' + automatic1111DefaultConfig.additionalPrompt; + const extendedPrompt = + prompt + ", " + automatic1111DefaultConfig.additionalPrompt; const negativePrompt = automatic1111DefaultConfig.defaultNegativePrompt; const sdConfig: Automatic1111Config = { - imgBase64: qrImgBuffer.toString('base64'), + imgBase64: qrImgBuffer.toString("base64"), prompt: extendedPrompt, negativePrompt, }; - if (method === 'txt2img') { - return sdClient.text2img({...automatic1111DefaultConfig.text2img, ...sdConfig}); + if (method === "txt2img") { + return sdClient.text2img({ + ...automatic1111DefaultConfig.text2img, + ...sdConfig, + }); } - return sdClient.img2img({...automatic1111DefaultConfig.img2img, ...sdConfig}); + return sdClient.img2img({ + ...automatic1111DefaultConfig.img2img, + ...sdConfig, + }); } - private async genQRCodeByComfyUI({qrUrl, qrMargin, prompt, method}: {qrUrl: string, qrMargin: number, prompt: string, method: 'img2img' | 'txt2img'}) { - const qrImgBuffer = await createQRCode({url: normalizeUrl(qrUrl), width: 680, margin: qrMargin }); - const extendedPrompt = prompt + ', ' + automatic1111DefaultConfig.additionalPrompt; + private async genQRCodeByComfyUI({ + qrUrl, + qrMargin, + prompt, + method, + }: { + qrUrl: string; + 
qrMargin: number; + prompt: string; + method: "img2img" | "txt2img"; + }) { + const qrImgBuffer = await createQRCode({ + url: normalizeUrl(qrUrl), + width: 680, + margin: qrMargin, + }); + const extendedPrompt = + prompt + ", " + automatic1111DefaultConfig.additionalPrompt; const negativePrompt = automatic1111DefaultConfig.defaultNegativePrompt; - const comfyClient = new ComfyClient({host: config.comfyHost2, wsHost: config.comfyWsHost2}); - - const filenameHash = crypto.createHash('sha256').update(qrUrl, 'utf8'); - const filename = filenameHash.digest('hex') + '.png'; - - const uploadResult = await comfyClient.uploadImage({filename, fileBuffer: qrImgBuffer, override: true}); - - const workflow = buildQRWorkflow({qrFilename: uploadResult.name, clientId: comfyClient.clientId, negativePrompt, prompt: extendedPrompt}) + const comfyClient = new ComfyClient({ + host: config.comfyHost2, + wsHost: config.comfyWsHost2, + }); + + const filenameHash = crypto.createHash("sha256").update(qrUrl, "utf8"); + const filename = filenameHash.digest("hex") + ".png"; + const uploadResult = await comfyClient.uploadImage({ + filename, + fileBuffer: qrImgBuffer, + override: true, + }); + + const workflow = buildQRWorkflow({ + qrFilename: uploadResult.name, + clientId: comfyClient.clientId, + negativePrompt, + prompt: extendedPrompt, + }); const response = await comfyClient.queuePrompt(workflow); - const promptResult = await comfyClient.waitingPromptExecution(response.prompt_id); + const promptResult = await comfyClient.waitingPromptExecution( + response.prompt_id + ); comfyClient.abortWebsocket(); - - return comfyClient.downloadResult(promptResult.data.output.images[0].filename); + return comfyClient.downloadResult( + promptResult.data.output.images[0].filename + ); } } diff --git a/src/modules/sd-images/SDImagesBotBase.ts b/src/modules/sd-images/SDImagesBotBase.ts index e5f0e0c7..68b3a97a 100644 --- a/src/modules/sd-images/SDImagesBotBase.ts +++ b/src/modules/sd-images/SDImagesBotBase.ts @@ -1,8 +1,9 @@ import { SDNodeApi, IModel } from "./api"; -import { OnMessageContext, OnCallBackQueryData } from "../types"; +import { OnMessageContext, OnCallBackQueryData, MessageExtras } from "../types"; import { getTelegramFileUrl, loadFile, sleep, uuidv4 } from "./utils"; -import { InputFile } from "grammy"; +import { GrammyError, InputFile } from "grammy"; import { COMMAND } from './helpers'; +import { Logger, pino } from "pino"; import { ILora } from "./api/loras-config"; export interface ISession { @@ -19,12 +20,22 @@ export interface ISession { export class SDImagesBotBase { sdNodeApi: SDNodeApi; + private logger: Logger; private sessions: ISession[] = []; queue: string[] = []; constructor() { this.sdNodeApi = new SDNodeApi(); + this.logger = pino({ + name: "SDImagesBotBase", + transport: { + target: "pino-pretty", + options: { + colorize: true, + }, + }, + }); } createSession = async ( @@ -68,9 +79,13 @@ export class SDImagesBotBase { this.queue.push(uuid); let idx = this.queue.findIndex((v) => v === uuid); - + const topicId = await ctx.message?.message_thread_id + let msgExtras: MessageExtras = {} + if (topicId) { + msgExtras['message_thread_id'] = topicId + } const { message_id } = await ctx.reply( - `You are #${idx + 1}, wait about ${(idx + 1) * 15} seconds` + `You are #${idx + 1}, wait about ${(idx + 1) * 15} seconds`, msgExtras ); // waiting queue @@ -108,18 +123,36 @@ export class SDImagesBotBase { `${session.message} ${prompt}` : `/${model.aliases[0]} ${prompt}`; - - await ctx.replyWithPhoto(new 
InputFile(imageBuffer), { - caption: reqMessage, - }); + const topicId = await ctx.message?.message_thread_id + let msgExtras: MessageExtras = { + caption: reqMessage + } + if (topicId) { + msgExtras['message_thread_id'] = topicId + } + await ctx.replyWithPhoto(new InputFile(imageBuffer),msgExtras); if (ctx.chat?.id && queueMessageId) { await ctx.api.deleteMessage(ctx.chat?.id, queueMessageId); } - } catch (e) { - console.error(e); - ctx.reply(`Error: something went wrong... Refunding payments`); - refundCallback(); + } catch (e: any) { + ctx.chatAction = null + const topicId = await ctx.message?.message_thread_id + let msgExtras: MessageExtras = {} + if (topicId) { + msgExtras['message_thread_id'] = topicId + } + if (e instanceof GrammyError) { + if (e.error_code === 400 && e.description.includes('not enough rights')) { + ctx.reply(`Error: The bot does not have permission to send photos in chat... Refunding payments`, msgExtras); + } else { + ctx.reply(`Error: something went wrong... Refunding payments`, msgExtras) + } + } else { + this.logger.error(e.toString()); + ctx.reply(`Error: something went wrong... Refunding payments`, msgExtras); + } + refundCallback() } this.queue = this.queue.filter((v) => v !== uuid); @@ -178,7 +211,11 @@ export class SDImagesBotBase { `${session.message} ${prompt}` : `/${model.aliases[0]} ${prompt}`; - + const topicId = await ctx.message?.message_thread_id + let msgExtras: MessageExtras = {} + if (topicId) { + msgExtras['message_thread_id'] = topicId + } await ctx.replyWithMediaGroup([ { type: "photo", @@ -190,14 +227,27 @@ export class SDImagesBotBase { media: new InputFile(imageBuffer), // caption: reqMessage, } - ]); + ],msgExtras); if (ctx.chat?.id && queueMessageId) { await ctx.api.deleteMessage(ctx.chat?.id, queueMessageId); } - } catch (e) { - console.error(e); - ctx.reply(`Error: something went wrong... Refunding payments`); + } catch (e: any) { + const topicId = await ctx.message?.message_thread_id + let msgExtras: MessageExtras = {} + if (topicId) { + msgExtras['message_thread_id'] = topicId + } + if (e instanceof GrammyError) { + if (e.error_code === 400 && e.description.includes('not enough rights')) { + ctx.reply(`Error: The bot does not have permission to send photos in chat... Refunding payments`, msgExtras); + } else { + ctx.reply(`Error: something went wrong... Refunding payments`, msgExtras) + } + } else { + this.logger.error(e.toString()); + ctx.reply(`Error: something went wrong... Refunding payments`, msgExtras); + } refundCallback(); } diff --git a/src/modules/types.ts b/src/modules/types.ts index 335b5623..da11cfc8 100644 --- a/src/modules/types.ts +++ b/src/modules/types.ts @@ -6,12 +6,19 @@ import { type ConversationFlavor, } from "@grammyjs/conversations"; import { AutoChatActionFlavor } from "@grammyjs/auto-chat-action"; +import { ParseMode } from "grammy/types"; + export interface ImageGenSessionData { numImages: number; imgSize: string; isEnabled: boolean; } +export interface MessageExtras { + caption?: string; + message_thread_id?: number; + parse_mode?: ParseMode; +} export interface ChatCompletion { completion: string; usage: number; @@ -40,14 +47,14 @@ export interface OneCountryData { } export interface TranslateBotData { - languages: string[], - enable: boolean, + languages: string[]; + enable: boolean; } export interface BotSessionData { oneCountry: OneCountryData; openAi: OpenAiSessionData; - translate: TranslateBotData + translate: TranslateBotData; } export type BotContext = Context &