Skip to content

Commit

Permalink
Merge pull request #49 from he0119/feat/pyd2
Browse files Browse the repository at this point in the history
feat: upgrade to pydantic v2
  • Loading branch information
Tian-que authored Feb 16, 2024
2 parents 8a13a61 + f10fb29 commit 7627a49
Show file tree
Hide file tree
Showing 8 changed files with 338 additions and 548 deletions.
17 changes: 9 additions & 8 deletions nonebot/adapters/kaiheila/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
from typing import Any, Dict, List, Tuple, Type, Union, Mapping, Callable, Optional

from pygtrie import StringTrie
from pydantic import parse_obj_as
from nonebot.utils import escape_tag
from nonebot.internal.driver import Response
from nonebot.compat import model_dump, type_validate_python
from nonebot.drivers import (
URL,
Driver,
Expand All @@ -20,6 +20,7 @@
WebSocketClientMixin,
)

from nonebot import get_plugin_config
from nonebot.adapters import Adapter as BaseAdapter

from . import event
Expand Down Expand Up @@ -64,7 +65,7 @@ class Adapter(BaseAdapter):
@override
def __init__(self, driver: Driver, **kwargs: Any):
super().__init__(driver, **kwargs)
self.kaiheila_config: KaiheilaConfig = KaiheilaConfig(**self.config.dict())
self.kaiheila_config: KaiheilaConfig = get_plugin_config(KaiheilaConfig)
self.api_root = "https://www.kaiheila.cn/api/v3/"
self.connections: Dict[str, WebSocket] = {}
self.tasks: List[asyncio.Task] = []
Expand Down Expand Up @@ -182,7 +183,7 @@ async def _do_call_api(
try:
resp = await self.request(request)
result = _handle_api_result(resp)
return parse_obj_as(result_type, result) if result_type else None
return type_validate_python(result_type, result) if result_type else None
except Exception as e:
raise e

Expand Down Expand Up @@ -378,7 +379,7 @@ def json_to_event(
data["post_type"] = "meta_event"
data["sub_type"] = "connect"
data["meta_event_type"] = "lifecycle"
return LifecycleMetaEvent.parse_obj(data)
return type_validate_python(LifecycleMetaEvent, data)
elif json_data["d"]["code"] == 40103:
raise ReconnectError
elif json_data["d"]["code"] == 40101:
Expand All @@ -391,7 +392,7 @@ def json_to_event(
"TRACE",
f"<y>Bot {escape_tag(str(self_id))}</y> HeartBeat",
)
return HeartbeatMetaEvent.parse_obj(data)
return type_validate_python(HeartbeatMetaEvent, data)
elif signal == SignalTypes.EVENT:
ResultStore.set_sn(self_id, json_data["sn"])
elif signal == SignalTypes.RECONNECT:
Expand Down Expand Up @@ -452,13 +453,13 @@ def json_to_event(
models = cls.get_event_model(event_name)
for model in models:
try:
event = model.parse_obj(data)
event = type_validate_python(model, data)
break
except Exception as e:
log("DEBUG", "Event Parser Error", e)
else:
event = Event.parse_obj(json_data)
log("DEBUG", str(event.dict()))
event = type_validate_python(Event, json_data)
log("DEBUG", str(model_dump(event)))
return event
except Exception as e:
log(
Expand Down
2 changes: 1 addition & 1 deletion nonebot/adapters/kaiheila/api/handle.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

class ApiMethod(NamedTuple):
method: str
restype: Optional[type]
restype: Optional[type] = None


api_method_map = {
Expand Down
22 changes: 22 additions & 0 deletions nonebot/adapters/kaiheila/compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from typing import Literal, overload

from nonebot.compat import PYDANTIC_V2

__all__ = ("model_validator",)


if PYDANTIC_V2:
from pydantic import model_validator as model_validator
else:
from pydantic import root_validator

@overload
def model_validator(*, mode: Literal["before"]):
...

@overload
def model_validator(*, mode: Literal["after"]):
...

def model_validator(*, mode: Literal["before", "after"]):
return root_validator(pre=mode == "before", allow_reuse=True)
27 changes: 21 additions & 6 deletions nonebot/adapters/kaiheila/config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import List, Optional, Tuple

from pydantic import Field, BaseModel
from nonebot.compat import PYDANTIC_V2, ConfigDict


class BotConfig(BaseModel):
Expand All @@ -12,9 +13,16 @@ class BotConfig(BaseModel):

token: str

class Config:
extra = "ignore"
allow_population_by_field_name = True
if PYDANTIC_V2:
model_config = ConfigDict(
extra="ignore",
populate_by_name=True,
)
else:

class Config(ConfigDict):
extra = "ignore"
allow_population_by_field_name = True


class Config(BaseModel):
Expand All @@ -37,6 +45,13 @@ class Config(BaseModel):
compress: Optional[bool] = Field(default=False)
kaiheila_ignore_events: Tuple[str, ...] = Field(default_factory=tuple)

class Config:
extra = "allow"
allow_population_by_field_name = True
if PYDANTIC_V2:
model_config = ConfigDict(
extra="allow",
populate_by_name=True,
)
else:

class Config(ConfigDict):
extra = "allow"
allow_population_by_field_name = True
26 changes: 14 additions & 12 deletions nonebot/adapters/kaiheila/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@

from pygtrie import StringTrie
from nonebot.utils import escape_tag
from pydantic import Field, HttpUrl, BaseModel, validator, root_validator
from nonebot.compat import model_dump
from pydantic import Field, HttpUrl, BaseModel, validator

from nonebot.adapters import Event as BaseEvent

from .utils import AttrDict
from .compat import model_validator
from .exception import NoLogException
from .message import Message, MessageDeserializer
from .api import Role, User, Emoji, Guild, Channel
Expand Down Expand Up @@ -109,7 +111,7 @@ def get_event_name(self) -> str:

@override
def get_event_description(self) -> str:
return escape_tag(str(self.dict()))
return escape_tag(str(model_dump(self)))

@override
def get_message(self) -> Message:
Expand Down Expand Up @@ -140,23 +142,23 @@ class Kmarkdown(BaseModel):

class EventMessage(BaseModel):
type: Union[int, str]
guild_id: Optional[str]
channel_name: Optional[str]
mention: Optional[List]
mention_all: Optional[bool]
mention_roles: Optional[List]
mention_here: Optional[bool]
nav_channels: Optional[List]
guild_id: Optional[str] = None
channel_name: Optional[str] = None
mention: Optional[List] = None
mention_all: Optional[bool] = None
mention_roles: Optional[List] = None
mention_here: Optional[bool] = None
nav_channels: Optional[List] = None
author: User

kmarkdown: Optional[Kmarkdown]
kmarkdown: Optional[Kmarkdown] = None

code: Optional[str] = None
attachments: Optional[Attachment] = None

content: Message

@root_validator(pre=True)
@model_validator(mode="before")
def parse_message(cls, values: dict):
values["content"] = MessageDeserializer(
values["type"],
Expand Down Expand Up @@ -196,7 +198,7 @@ class Event(OriginEvent):

@override
def get_event_description(self) -> str:
return escape_tag(str(self.dict()))
return escape_tag(str(model_dump(self)))

@override
def get_plaintext(self) -> str:
Expand Down
22 changes: 11 additions & 11 deletions nonebot/adapters/kaiheila/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ class Video(Media):
if TYPE_CHECKING:

class _VideoData(Media._MediaData):
title: Optional[str]
title: Optional[str] = None

data: _VideoData

Expand Down Expand Up @@ -383,8 +383,8 @@ class Audio(Media):
if TYPE_CHECKING:

class _AudioData(Media._MediaData):
title: Optional[str]
cover_file_key: Optional[str]
title: Optional[str] = None
cover_file_key: Optional[str] = None

data: _AudioData

Expand Down Expand Up @@ -425,7 +425,7 @@ class File(Media):
if TYPE_CHECKING:

class _FileData(Media._MediaData):
title: Optional[str]
title: Optional[str] = None

data: _FileData

Expand Down Expand Up @@ -458,9 +458,9 @@ class LocalMedia(VirtualMessageSegment):
if TYPE_CHECKING:

class _LocalMediaData(TypedDict):
content: Optional[bytes]
title: Optional[str]
file: Optional[Path]
content: Optional[bytes] = None
title: Optional[str] = None
file: Optional[Path] = None

data: _LocalMediaData

Expand Down Expand Up @@ -564,8 +564,8 @@ class LocalAudio(LocalMedia):
if TYPE_CHECKING:

class _LocalAudioData(LocalMedia._LocalMediaData):
cover_content: Optional[bytes]
cover_file: Optional[Path]
cover_content: Optional[bytes] = None
cover_file: Optional[Path] = None

data: _LocalAudioData

Expand Down Expand Up @@ -635,7 +635,7 @@ class Mention(VirtualMessageSegment):

class _MentionData(TypedDict):
user_id: str
username: Optional[str]
username: Optional[str] = None

data: _MentionData

Expand All @@ -662,7 +662,7 @@ class MentionRole(VirtualMessageSegment):

class _MentionData(TypedDict):
role_id: str
name: Optional[str]
name: Optional[str] = None

data: _MentionData

Expand Down
Loading

0 comments on commit 7627a49

Please sign in to comment.