Skip to content

Commit

Permalink
接入ruff、black与isort并配置pre-commit (#46)
Browse files Browse the repository at this point in the history
* 接入ruff、black与isort并配置pre-commit

* 修复导入错误
  • Loading branch information
ssttkkl authored Dec 11, 2023
1 parent 1a2e458 commit 53f2142
Show file tree
Hide file tree
Showing 16 changed files with 1,692 additions and 999 deletions.
27 changes: 27 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
default_install_hook_types: [pre-commit, prepare-commit-msg]
ci:
autofix_commit_msg: ":rotating_light: auto fix by pre-commit hooks"
autofix_prs: true
autoupdate_branch: master
autoupdate_schedule: monthly
autoupdate_commit_msg: ":arrow_up: auto update by pre-commit hooks"
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.0.292
hooks:
- id: ruff
args: [--fix]
stages: [commit]

- repo: https://github.com/pycqa/isort
rev: 5.12.0
hooks:
- id: isort
stages: [commit]

- repo: https://github.com/psf/black
rev: 23.9.1
hooks:
- id: black
stages: [commit]

1 change: 0 additions & 1 deletion nonebot/adapters/kaiheila/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,3 @@
from .adapter import Adapter as Adapter
from .message import Message as Message
from .message import MessageSegment as MessageSegment

144 changes: 95 additions & 49 deletions nonebot/adapters/kaiheila/adapter.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,53 @@
import asyncio
import inspect
import json
import re
import json
import zlib
from typing import Any, Dict, List, Type, Union, Callable, Optional, Mapping
import asyncio
import inspect
from typing_extensions import override
from typing import Any, Dict, List, Type, Union, Mapping, Callable, Optional

from nonebot.adapters import Adapter as BaseAdapter
from pygtrie import StringTrie
from pydantic import parse_obj_as
from nonebot.utils import escape_tag
from nonebot.internal.driver import Response
from nonebot.drivers import (
URL,
Driver,
Request,
WebSocket,
ForwardDriver,
HTTPClientMixin,
WebSocketClientMixin,
)
from nonebot.internal.driver import Response
from nonebot.utils import escape_tag
from pydantic import parse_obj_as
from pygtrie import StringTrie
from typing_extensions import override

from nonebot.adapters import Adapter as BaseAdapter

from . import event
from .api.handle import get_api_method, get_api_restype
from .bot import Bot
from .config import Config as KaiheilaConfig, BotConfig
from .event import *
from .event import OriginEvent
from .exception import ApiNotAvailable, ReconnectError, TokenError, UnauthorizedException, RateLimitException, \
ActionFailed, KaiheilaAdapterException, NetworkError
from .api.model import User
from .config import BotConfig
from .config import Config as KaiheilaConfig
from .message import Message, MessageSegment
from .api.handle import get_api_method, get_api_restype
from .utils import ResultStore, log, _handle_api_result
from .event import (
Event,
EventTypes,
OriginEvent,
SignalTypes,
HeartbeatMetaEvent,
LifecycleMetaEvent,
)
from .exception import (
TokenError,
ActionFailed,
NetworkError,
ReconnectError,
ApiNotAvailable,
RateLimitException,
UnauthorizedException,
KaiheilaAdapterException,
)

RECONNECT_INTERVAL = 3.0

Expand Down Expand Up @@ -109,18 +127,21 @@ async def _call_api(self, bot: Bot, api: str, **data) -> Any:
api = api.replace("_", "/")

if api.startswith("/api/v3/"):
api = api[len("/api/v3/"):]
api = api[len("/api/v3/") :]
elif api.startswith("api/v3"):
api = api[len("api/v3"):]
api = api[len("api/v3") :]
api = api.strip("/")
return await self._do_call_api(api, data, bot.token)

else:
raise ApiNotAvailable

async def _do_call_api(self, api: str,
data: Optional[Mapping[str, Any]] = None,
token: Optional[str] = None) -> Any:
async def _do_call_api(
self,
api: str,
data: Optional[Mapping[str, Any]] = None,
token: Optional[str] = None,
) -> Any:
log("DEBUG", f"Calling API <y>{api}</y>")
data = dict(data) if data is not None else {}

Expand Down Expand Up @@ -169,9 +190,11 @@ async def _get_bot_info(self, token: str) -> User:
return await self._do_call_api("user/me", token=token)

async def _get_gateway(self, token: str) -> URL:
result = await self._do_call_api("gateway/index",
data={"compress": 1 if self.kaiheila_config.compress else 0},
token=token)
result = await self._do_call_api(
"gateway/index",
data={"compress": 1 if self.kaiheila_config.compress else 0},
token=token,
)
return result.url

async def start_forward(self) -> None:
Expand All @@ -185,7 +208,7 @@ async def stop_forward(self) -> None:

await asyncio.gather(
*(asyncio.wait_for(task, timeout=10) for task in self.tasks),
return_exceptions=True
return_exceptions=True,
)

async def _forward_ws(self, bot_config: BotConfig) -> None:
Expand All @@ -212,6 +235,7 @@ async def _forward_ws(self, bot_config: BotConfig) -> None:
"Trying to reconnect...</bg #f8bbd0></r>",
e,
)
continue

headers = {}
if bot_config.token:
Expand All @@ -224,7 +248,11 @@ async def _forward_ws(self, bot_config: BotConfig) -> None:
f"WebSocket Connection to {escape_tag(str(url))} established",
)
try:
data_decompress_func = zlib.decompress if self.kaiheila_config.compress else lambda x: x
data_decompress_func = (
zlib.decompress
if self.kaiheila_config.compress
else lambda x: x
)
while True:
data = await ws.receive()
data = data_decompress_func(data)
Expand All @@ -234,18 +262,22 @@ async def _forward_ws(self, bot_config: BotConfig) -> None:
continue
if not bot:
if (
not isinstance(event, LifecycleMetaEvent)
or event.sub_type != "connect"
not isinstance(event, LifecycleMetaEvent)
or event.sub_type != "connect"
):
continue
bot_info = await self._get_bot_info(bot_config.token)
self_id = bot_info.id_
bot = Bot(self, self_id, bot_info.username, bot_config.token)
bot = Bot(
self, self_id, bot_info.username, bot_config.token
)
self.connections[self_id] = ws
self.bot_connect(bot)

# start heartbeat
heartbeat_task = asyncio.create_task(self.start_heartbeat(bot))
heartbeat_task = asyncio.create_task(
self.start_heartbeat(bot)
)

log(
"INFO",
Expand Down Expand Up @@ -275,7 +307,7 @@ async def _forward_ws(self, bot_config: BotConfig) -> None:

try:
await ws.close()
except Exception:
except: # noqa: E722
pass

if bot:
Expand Down Expand Up @@ -303,24 +335,32 @@ async def start_heartbeat(self, bot: Bot) -> None:
if self.connections.get(bot.self_id).closed:
break
try:
await self.connections.get(bot.self_id).send(json.dumps({
"s": 2,
"sn": ResultStore.get_sn(bot.self_id) # 客户端目前收到的最新的消息 sn
}))
await self.connections.get(bot.self_id).send(
json.dumps(
{
"s": 2,
"sn": ResultStore.get_sn(bot.self_id), # 客户端目前收到的最新的消息 sn
}
)
)
await asyncio.sleep(26)
except asyncio.CancelledError:
raise
except Exception as e:
log("ERROR",
log(
"ERROR",
"<r><bg #f8bbd0>Error while sending heartbeat for bot"
f"{escape_tag(bot.self_id)}. Will retry after 1s ...</bg #f8bbd0></r>",
e)
e,
)
await asyncio.sleep(1)

@classmethod
def json_to_event(
cls, json_data: Any, self_id: Optional[str] = None,
) -> Optional[Event]:
cls,
json_data: Any,
self_id: Optional[str] = None,
) -> Union[OriginEvent, Event, None]:
if not isinstance(json_data, dict):
return None

Expand All @@ -340,9 +380,7 @@ def json_to_event(
elif json_data["d"]["code"] == 40102:
raise TokenError("token 验证失败")
elif signal == SignalTypes.PONG:
data = {}
data["post_type"] = "meta_event"
data["meta_event_type"] = "heartbeat"
data = {"post_type": "meta_event", "meta_event_type": "heartbeat"}
log(
"TRACE",
f"<y>Bot {escape_tag(str(self_id))}</y> HeartBeat",
Expand All @@ -367,7 +405,9 @@ def json_to_event(
data["self_id"] = self_id
data["group_id"] = data.get("target_id")
data["time"] = data.get("msg_timestamp")
data["user_id"] = data.get("author_id") if data.get("author_id") != "1" else "SYSTEM"
data["user_id"] = (
data.get("author_id") if data.get("author_id") != "1" else "SYSTEM"
)

if data["type"] == EventTypes.sys:
data["post_type"] = "notice"
Expand All @@ -378,9 +418,15 @@ def json_to_event(
# data['notice_type'] = 'private' if data['notice_type'] == 'person' else data['notice_type']
else:
data["post_type"] = "message"
data["sub_type"] = [i.name.lower() for i in EventTypes if i.value == extra.get("type")][0]
data["sub_type"] = [
i.name.lower() for i in EventTypes if i.value == extra.get("type")
][0]
data["message_type"] = data.get("channel_type").lower()
data["message_type"] = "private" if data["message_type"] == "person" else data["message_type"]
data["message_type"] = (
"private"
if data["message_type"] == "person"
else data["message_type"]
)
data["extra"]["content"] = data.get("content")
data["event"] = data["extra"]

Expand Down Expand Up @@ -428,12 +474,12 @@ def get_event_model(cls, event_name: str) -> List[Type[Event]]:
- ``List[Type[Event]]``
"""
return [model.value for model in cls.event_models.prefixes("." + event_name)][
::-1
]
::-1
]

@classmethod
def custom_send(
cls,
send_func: Callable[[Bot, Event, Union[str, Message, MessageSegment]], None],
cls,
send_func: Callable[[Bot, Event, Union[str, Message, MessageSegment]], None],
):
setattr(Bot, "send_handler", send_func)
2 changes: 1 addition & 1 deletion nonebot/adapters/kaiheila/api/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .client import ApiClient as ApiClient
from .model import *
from .client import ApiClient as ApiClient
Loading

0 comments on commit 53f2142

Please sign in to comment.