-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
245 lines (200 loc) · 7.21 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
import asyncio
import functools
import inspect
from collections import defaultdict
from pathlib import Path
from time import time

import httpx
import translators as ts
from pyrogram import Client, errors, filters, raw
from pyrogram.raw.types.messages import StickerSet
from pyrogram.types import Sticker as Stk
from rapidocr_onnxruntime import RapidOCR
from sqlalchemy import or_, select

from config.config import admin
from database import AutoIndexSticker, DBSession, RecentlyUsed, Sticker
def get_sticker_pack_name(client: Client, set_name):
    """Return the title of the sticker set whose short name is ``set_name``.

    Sends ``messages.GetStickerSet`` through ``client.invoke`` without awaiting
    it, so the client is presumably used in pyrogram's synchronous mode
    (NOTE(review): confirm at call sites).

    :param client: pyrogram client used to send the request.
    :param set_name: short name of the sticker set.
    :return: the set title, or ``[]`` when Telegram answers StickersetInvalid.
        NOTE(review): ``[]`` is not a string; callers presumably rely only on
        its falsiness.
    """
    request = raw.functions.messages.GetStickerSet(
        stickerset=raw.types.InputStickerSetShortName(short_name=set_name),
        hash=0,
    )
    try:
        info: StickerSet = client.invoke(request)
    except errors.StickersetInvalid:
        return []
    return info.set.title
# Get a sticker set's title (awaitable variant)
async def async_get_sticker_pack_name(client: Client, set_name):
    """Return the title of the sticker set named ``set_name``.

    Awaitable counterpart of ``get_sticker_pack_name``: the
    ``messages.GetStickerSet`` request is awaited on ``client.invoke``.

    :param client: pyrogram client used to send the request.
    :param set_name: short name of the sticker set.
    :return: the set title, or ``[]`` when Telegram answers StickersetInvalid.
    """
    short_name_input = raw.types.InputStickerSetShortName(short_name=set_name)
    try:
        info: StickerSet = await client.invoke(
            raw.functions.messages.GetStickerSet(stickerset=short_name_input, hash=0)
        )
    except errors.StickersetInvalid:
        return []
    else:
        return info.set.title
# Fetch every sticker of a sticker set
def parse_stickers(client: Client, set_name):
    """Fetch the sticker set ``set_name`` and parse each of its documents.

    :param client: pyrogram client; ``client.invoke`` is called without
        awaiting it, so presumably it runs in pyrogram's synchronous mode.
    :param set_name: short name of the sticker set.
    :return: ``[]`` when Telegram answers StickersetInvalid; otherwise a dict
        with the set's ``title``, ``count`` and ``short_name`` plus ``final``,
        the list of ``Stk._parse`` results in the order the API returned the
        documents.

    NOTE(review): ``asyncio.run`` creates a new event loop for every sticker
    and raises RuntimeError when called while an event loop is already
    running, so this helper only works from synchronous code.
    """
    request = raw.functions.messages.GetStickerSet(
        stickerset=raw.types.InputStickerSetShortName(short_name=set_name),
        hash=0,
    )
    try:
        info: StickerSet = client.invoke(request)
    except errors.StickersetInvalid:
        return []

    sticker_set = info.set
    meta = {
        "title": sticker_set.title,
        "count": sticker_set.count,
        "short_name": sticker_set.short_name,
    }
    # Stk._parse expects the document attributes keyed by their type.
    parsed = [
        asyncio.run(
            Stk._parse(client, document, {type(attr): attr for attr in document.attributes})
        )
        for document in info.documents
    ]
    return {**meta, "final": parsed}
# Look up the auto-index record of one sticker set for one user
def get_auto_indexed_packages(set_name, uid):
    """Return the ``AutoIndexSticker`` row matching ``uid`` and ``set_name``.

    :param set_name: sticker-set short name to look up.
    :param uid: owner id stored in ``AutoIndexSticker.uid``.
    :return: the single matching row, or ``None`` when there is none
        (``one_or_none`` raises if several rows match).

    NOTE(review): the row is returned after the ``DBSession.begin()`` block has
    committed and closed its session. Unless the session factory uses
    ``expire_on_commit=False``, reading the row's attributes afterwards may
    raise DetachedInstanceError; confirm how callers use the result.
    """
    query = select(AutoIndexSticker).where(
        AutoIndexSticker.uid == uid,
        AutoIndexSticker.set_name == set_name,
    )
    with DBSession.begin() as session:
        row = session.execute(query).scalars().one_or_none()
    return row
# Filter inline queries that start with a given prefix
def filter_inline_query_results(command: str):
    """Build a pyrogram custom filter for inline queries starting with ``command``.

    :param command: prefix the inline query text must start with.
    :return: a filter created by ``filters.create`` under the name
        ``InlineQueryResultFilter``; ``command`` is also passed to it as the
        ``commands`` keyword.
    """
    async def func(_flt, _client, inline_query):
        prefix_matches = inline_query.query.startswith(command)
        return prefix_matches

    return filters.create(func, name="InlineQueryResultFilter", commands=command)
def stick_find(query, uid) -> list[Sticker]:
    """Search the stickers stored for user ``uid``.

    With a non-empty ``query``, a sticker matches when its ``tag``, ``emoji`` or
    ``title`` contains ``query`` (case-insensitive ``ILIKE``), or when its
    ``set_name`` or ``sticker_unique_id`` equals ``query`` exactly. With an
    empty or ``None`` query, every sticker of the user is returned.

    NOTE(review): ``%`` and ``_`` typed in ``query`` act as LIKE wildcards.

    :param query: search text (may be empty/None).
    :param uid: owner id stored in ``Sticker.uid``.
    :return: matching stickers ordered by ``set_name`` then ``time``, ascending.
    """
    if query:
        stmt = select(Sticker).filter(
            or_(
                Sticker.tag.ilike(f"%{query}%"),
                Sticker.emoji.ilike(f"%{query}%"),
                Sticker.title.ilike(f"%{query}%"),
                Sticker.set_name == query,
                Sticker.sticker_unique_id == query,
            ),
            Sticker.uid == uid,
        )
    else:
        stmt = select(Sticker).filter(Sticker.uid == uid)
    # Order by sticker-set name, then by `time`, both ascending.
    stmt = stmt.order_by(Sticker.set_name.asc(), Sticker.time.asc())
    # Close the session when done so its connection is returned to the pool
    # instead of waiting for garbage collection. Closing does not expire the
    # already-loaded column attributes, so the returned objects stay readable
    # (lazy-loaded relationships will not load afterwards).
    with DBSession() as session:
        return session.execute(stmt).scalars().all()
def recently_used_find(uid) -> list[RecentlyUsed]:
    """Return the ``RecentlyUsed`` rows of user ``uid``, ordered by ``time`` ascending.

    :param uid: owner id stored in ``RecentlyUsed.uid``.
    :return: list of ``RecentlyUsed`` rows (the previous ``list[Sticker]``
        annotation did not match what the query selects).
    """
    stmt = (
        select(RecentlyUsed)
        .filter(RecentlyUsed.uid == uid)
        .order_by(RecentlyUsed.time.asc())
    )
    # Close the session when done so its connection is released promptly;
    # loaded column attributes remain readable after close.
    with DBSession() as session:
        return session.execute(stmt).scalars().all()
# Former PaddleOCR-based implementation, kept commented out for reference:
# _ocr = PaddleOCR(use_angle_cls=True, enable_mkldnn=True, ocr_version='PP-OCRv4')
# def ocr(path: str) -> list:
# result = _ocr.ocr(path, cls=False)
# return [i[1][0] for i in result[0]]
# Best explicit model combination found for RapidOCR:
# ch_PP-OCRv3_det + ch_ppocr_mobile_v2.0_cls + ch_PP-OCRv3_rec
# Its speed is close to the v4 models, but its text detection is worse than v4.
# That explicit-model configuration is disabled below:
# rapid_ocr = RapidOCR(
# det_model_path="resources/models/ch_PP-OCRv3_det_infer.onnx", # detection model file path
# cls_model_path="resources/models/ch_ppocr_mobile_v2.0_cls_infer.onnx", # direction-classifier model file path
# rec_model_path="resources/models/ch_PP-OCRv3_rec_infer.onnx", # recognition model file path
# )
# Module-level OCR engine, constructed with RapidOCR's defaults once at import
# time and used by ocr_rapid().
rapid_ocr = RapidOCR()
def ocr_rapid(path, *, text_score: float = 0.4, use_angle_cls: bool = False) -> list[str]:
    """Run the module-level ``rapid_ocr`` engine on ``path`` and return the text lines.

    :param path: image passed straight to ``rapid_ocr``.
    :param text_score: confidence threshold forwarded to RapidOCR
        (default 0.4, the previously hard-coded value).
    :param use_angle_cls: whether RapidOCR runs its text-direction classifier
        (default False, the previously hard-coded value).
    :return: the recognised text of each detected line, or ``[]`` when RapidOCR
        finds nothing.
    """
    result, _ = rapid_ocr(path, text_score=text_score, use_angle_cls=use_angle_cls)
    if not result:
        return []
    # RapidOCR entries are [box, text, score]; keep only the text.
    return [entry[1] for entry in result]
def get_sticker_id(sid: str) -> str:
    """Drop everything up to and including the first ``_`` in ``sid``.

    An id without any underscore is returned unchanged.
    """
    _prefix, separator, remainder = sid.partition("_")
    if separator:
        return remainder
    return sid
# Per-user rate-limit state, shared by every handler decorated with rate_limit:
# `requests` counts the requests accepted in the user's current window and
# `last_request_time` holds the time() at which that window started. Both
# default to 0 for users not seen yet.
requests = defaultdict(int)
last_request_time = defaultdict(int)


def _consume_request(user_id, request_limit, time_limit) -> bool:
    """Record a request from ``user_id``; return False if it is over the limit.

    Fixed-window limiter. When more than ``time_limit`` seconds have passed
    since the user's window started, a new window starts and this request is
    its first. Otherwise the request is accepted only while fewer than
    ``request_limit`` requests were accepted in the window. Rejected requests
    are not counted.
    """
    current_time = time()
    if current_time - last_request_time[user_id] > time_limit:
        requests[user_id] = 1
        last_request_time[user_id] = current_time
        return True
    if requests[user_id] >= request_limit:
        return False
    requests[user_id] += 1
    return True


# Rate limiting
def rate_limit(request_limit=3, time_limit=60):
    """Decorate a ``(client, message)`` handler with a per-user rate limit.

    Over the limit, the handler is skipped and the user gets a reply instead.
    Works for both plain and ``async def`` handlers. The wrapper returns the
    handler's result, or the result of ``message.reply`` when the request is
    rejected.

    NOTE(review): the user is read from ``message.from_user.id``; messages
    without ``from_user`` will raise AttributeError.

    :param request_limit: maximum accepted requests per window.
    :param time_limit: window length in seconds.
    """
    def decorator(func):
        limit_text = f"速率限制:{request_limit}张/{time_limit}秒,请稍后再试"

        if inspect.iscoroutinefunction(func):
            # Async handlers must be awaited; a plain wrapper would only
            # create the coroutine and never run it.
            @functools.wraps(func)
            async def async_wrapper(client, message):
                if not _consume_request(message.from_user.id, request_limit, time_limit):
                    return await message.reply(limit_text)
                return await func(client, message)

            return async_wrapper

        @functools.wraps(func)
        def wrapper(client, message):
            if not _consume_request(message.from_user.id, request_limit, time_limit):
                # Over the limit: notify the user and skip the handler.
                return message.reply(limit_text)
            return func(client, message)

        return wrapper

    return decorator
def is_admin():
    """Build a pyrogram filter that lets only the configured ``admin`` through.

    When ``admin`` is falsy (not configured), every update passes the filter.
    NOTE(review): updates without ``from_user`` would raise AttributeError
    once an admin is configured.
    """
    async def func(_flt, _client, update):
        if not admin:
            return True
        return update.from_user.id == admin

    return filters.create(func)
def azure_img_tag(path: str | Path) -> list[str]:
    """Tag the image at ``path`` via the Azure Vision Studio demo analyze endpoint.

    Uploads the file as multipart field ``file`` with ``features=tags`` and
    ``language=zh``.

    :param path: image file to upload.
    :return: the ``name`` of every entry in ``tagsResult.values``, or ``[]``
        when the endpoint does not answer HTTP 200. (The previous ``-> str``
        annotation did not match this list return value.)
    """
    params = {
        "features": "tags",
        "language": "zh",
    }
    with open(path, "rb") as f:
        response = httpx.post(
            "https://portal.vision.cognitive.azure.com/api/demo/analyze",
            params=params,
            files={"file": f},
        )
    if response.status_code != 200:
        return []
    payload = response.json()
    return [tag["name"] for tag in payload["tagsResult"]["values"]]
# Microsoft (Azure Vision) demo endpoint: image caption
def azure_img_caption(path: str | Path) -> str:
    """Caption the image at ``path`` and try to translate the caption to Chinese.

    Uploads the file to the Azure Vision Studio demo analyze endpoint with
    ``features=caption`` and ``language=en``. It then translates
    ``captionResult.text`` to ``zh`` through ``translators`` using the
    ``google`` backend.

    :param path: image file to upload.
    :return: the translated caption, or the untranslated caption if
        translation raises an ordinary exception.
    :raises Exception: ``"识别失败"`` when the endpoint does not answer HTTP 200.
    """
    params = {
        "features": "caption",
        "language": "en",
    }
    with open(path, "rb") as f:
        response = httpx.post(
            "https://portal.vision.cognitive.azure.com/api/demo/analyze",
            params=params,
            files={"file": f},
        )
    if response.status_code != 200:
        raise Exception("识别失败")
    text = response.json()["captionResult"]["text"]
    try:
        return ts.translate_text(text, "google", to_language="zh")
    except Exception:
        # Translation is best effort: fall back to the untranslated caption.
        # Unlike a `return` inside `finally`, this lets KeyboardInterrupt and
        # SystemExit propagate.
        return text
# Microsoft (Azure Vision) demo endpoint: OCR
def azure_ocr(path: str | Path) -> list[str]:
    """OCR the image at ``path`` via the Azure Vision Studio demo analyze endpoint.

    Uploads the file with ``features=read``.

    :param path: image file to upload.
    :return: the ``content`` of every line on every page of ``readResult``,
        in page order. Returns ``[]`` when ``pages`` is empty; before this
        change, reading ``pages[0]`` raised IndexError in that case.
    :raises Exception: ``"识别失败"`` when the endpoint does not answer HTTP 200.
    """
    params = {
        "features": "read",
    }
    with open(path, "rb") as f:
        response = httpx.post(
            "https://portal.vision.cognitive.azure.com/api/demo/analyze",
            params=params,
            files={"file": f},
        )
    if response.status_code != 200:
        raise Exception("识别失败")
    pages = response.json()["readResult"]["pages"]
    return [line["content"] for page in pages for line in page["lines"]]