Skip to content

Commit

Permalink
v 0.1.2
Browse files Browse the repository at this point in the history
Update version
Add `ErrorEmbed` model
Add client response error
Update `clean content` for boltalka api
  • Loading branch information
LEv145 committed Apr 25, 2022
1 parent 7c56709 commit 8083e86
Show file tree
Hide file tree
Showing 6 changed files with 100 additions and 30 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "discoboltalka"
version = "0.1.1"
version = "0.1.2"
description = "Discord bot for 'Demo Болталка'"
authors = ["lev"]
license = "MIT"
Expand Down
21 changes: 19 additions & 2 deletions src/discoboltalka/boltalka_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,19 +33,26 @@ async def predict(self, contexts: t.List[t.List[str]]) -> t.List[str]:
responses = re.findall("'([^']+)'", json_response["responses"])

return [
context.replace("%bot_name", self._client_name) # TODO?: Template
context.replace("%bot_name", self._client_name)
for context in responses
]

async def _request(self, url: str, json: t.Any) -> t.Any:
"""
Exceptions:
ValidationError
ClientResponseError
"""
async with self._client_session.post(
url=url,
json=json,
) as response:
if response.status != 200:
raise ClientResponseError(
status_code=response.status,
reason=response.reason,
)

json_response = self._json_loader(await response.text())

detail = json_response.get("detail")
Expand All @@ -60,7 +67,7 @@ async def _request(self, url: str, json: t.Any) -> t.Any:


class APIError(Exception):
...
pass


class ValidationError(APIError):
Expand All @@ -73,3 +80,13 @@ def __init__(
self.message = message
self.location = location
self.type_ = type_


class ClientResponseError(APIError):
def __init__(
self,
status_code: int,
reason: str
) -> None:
self.status_code = status_code
self.reason = reason
74 changes: 64 additions & 10 deletions src/discoboltalka/boltalka_gateway_bot.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
from __future__ import annotations

import re
import logging
import typing as t
from datetime import datetime

import hikari

from discoboltalka.boltalka_api import (
BoltalkaAPI,
ValidationError,
ClientResponseError,
)
from discoboltalka.utils.discord import clean_content
from discoboltalka.models import ErrorEmbed


_logger = logging.getLogger("discoboltalka.boltalka_gateway_bot")
Expand All @@ -35,24 +38,75 @@ async def message_listen(self, event: hikari.GuildMessageCreateEvent) -> None:
return

# Prepare content
content = clean_content(content).lstrip()
if not content:
clean_content = await self._get_clean_content_from_guild_message_create_event(
event=event,
content=content,
)
if not clean_content:
return

try:
boltalka_phrases = await self._boltalka_api.predict([[content]])
boltalka_phrases = await self._boltalka_api.predict([[clean_content]])
except ValidationError:
await event.message.respond(
embed=hikari.Embed(
title="Ошибка",
description="Я не смогла понять ваш текст",
colour=0xbd0505,
)
) # TODO?: Error handler or error embed
embed=ErrorEmbed("Я не смогла понять ваш текст"),
)
return
except ClientResponseError:
await event.message.respond(
embed=ErrorEmbed("Мои сервера дали сбой, спросите позже"),
)
return

boltalka_phrase = boltalka_phrases[0]

_logger.info(f"Boltalka response: {content!r} -> {boltalka_phrase!r}")

await event.message.respond(boltalka_phrases[0])

async def _get_clean_content_from_guild_message_create_event(
self,
event: hikari.GuildMessageCreateEvent,
content: str,
) -> str:
guild = event.get_guild()
clean_content = content

def member_repl(match: re.Match) -> str:
user_id = match[0]

member = guild.get_member(user=user_id)
if member is not None:
return f"@{member.display_name}"
return ""

def role_repl(match: re.Match) -> str:
role_id = match[0]

role = guild.get_role(role=role_id)
if role is not None:
return f"@{role.name}"
return ""

def channel_repl(match: re.Match) -> str:
channel_id = match[0]

channel = guild.get_channel(channel=channel_id)
if channel is not None:
return f"#{channel.name}"
return ""

def timestamp_repl(match: re.Match) -> str:
timestamp_ = match[0]
return str(datetime.fromtimestamp(timestamp_))

def discord_emoji_repl(_match: re.Match) -> str:
return ""

clean_content = re.sub(r'<@&(\d[1-9]+)>', role_repl, clean_content)
clean_content = re.sub(r'<#(\d[1-9]+)>', channel_repl, clean_content)
clean_content = re.sub(r'<@!?(\d[1-9]+)>', member_repl, clean_content)
clean_content = re.sub(r'<t:(\d)+(?::[a-zA-Z])?>', timestamp_repl, clean_content)
clean_content = re.sub(r'<a?:[^:]+:\d[1-9]+>', discord_emoji_repl, clean_content)

return clean_content.lstrip()
6 changes: 6 additions & 0 deletions src/discoboltalka/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from .embeds import ErrorEmbed


__all__ = (
"ErrorEmbed",
)
10 changes: 10 additions & 0 deletions src/discoboltalka/models/embeds.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import hikari


class ErrorEmbed(hikari.Embed):
def __init__(self, message: str) -> None:
super().__init__(
title="Ошибка",
description=message,
colour=0xbd0505,
)
17 changes: 0 additions & 17 deletions src/discoboltalka/utils/discord.py

This file was deleted.

0 comments on commit 8083e86

Please sign in to comment.