Skip to content

Commit

Permalink
Markdown & Stream regression
Browse files Browse the repository at this point in the history
  • Loading branch information
ruecat committed May 12, 2024
1 parent 43b6498 commit ebc78f2
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 167 deletions.
50 changes: 18 additions & 32 deletions bot/func/functions.py → bot/func/interactions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# >> interactions
import logging
import os
import aiohttp
Expand All @@ -6,31 +7,19 @@
from asyncio import Lock
from functools import wraps
from dotenv import load_dotenv
# --- Environment
load_dotenv()
# --- Environment Checker
token = os.getenv("TOKEN")
allowed_ids = list(map(int, os.getenv("USER_IDS", "").split(",")))
admin_ids = list(map(int, os.getenv("ADMIN_IDS", "").split(",")))
ollama_base_url = os.getenv("OLLAMA_BASE_URL")
ollama_port = os.getenv("OLLAMA_PORT", "11434")
log_level_str = os.getenv("LOG_LEVEL", "INFO")

# --- Other
log_levels = list(logging._levelToName.values())
# ['CRITICAL', 'ERROR', 'WARNING', 'INFO', 'DEBUG', 'NOTSET']

# Set default level to be INFO
if log_level_str not in log_levels:
log_level = logging.DEBUG
else:
log_level = logging.getLevelName(log_level_str)

logging.basicConfig(level=log_level)


# Ollama API
# Model List
async def model_list():
async with aiohttp.ClientSession() as session:
url = f"http://{ollama_base_url}:{ollama_port}/api/tags"
Expand All @@ -41,20 +30,27 @@ async def model_list():
else:
return []
async def generate(payload: dict, modelname: str, prompt: str):
# try:
async with aiohttp.ClientSession() as session:
url = f"http://{ollama_base_url}:{ollama_port}/api/chat"

# Stream from API
async with session.post(url, json=payload) as response:
async for chunk in response.content:
if chunk:
decoded_chunk = chunk.decode()
if decoded_chunk.strip():
yield json.loads(decoded_chunk)
try:
async with session.post(url, json=payload) as response:
if response.status != 200:
raise aiohttp.ClientResponseError(
status=response.status, message=response.reason
)
buffer = b""

async for chunk in response.content.iter_any():
buffer += chunk
while b"\n" in buffer:
line, buffer = buffer.split(b"\n", 1)
line = line.strip()
if line:
yield json.loads(line)
except aiohttp.ClientError as e:
print(f"Error during request: {e}")


# Aiogram functions & wraps
def perms_allowed(func):
@wraps(func)
async def wrapper(message: types.Message = None, query: types.CallbackQuery = None):
Expand Down Expand Up @@ -103,16 +99,6 @@ async def wrapper(message: types.Message = None, query: types.CallbackQuery = No
)

return wrapper


def md_autofixer(text: str) -> str:
# In MarkdownV2, these characters must be escaped: _ * [ ] ( ) ~ ` > # + - = | { } . !
escape_chars = r"_[]()~>#+-=|{}.!"
# Use a backslash to escape special characters
return "".join("\\" + char if char in escape_chars else char for char in text)


# Context-Related
class contextLock:
lock = Lock()

Expand Down
Loading

0 comments on commit ebc78f2

Please sign in to comment.