Skip to content

Commit

Permalink
Use a proper cookiejar + Nicer embeds
Browse files Browse the repository at this point in the history
  • Loading branch information
beer-psi committed Mar 17, 2024
1 parent 01bfeb2 commit 1c6e38d
Show file tree
Hide file tree
Showing 9 changed files with 411 additions and 230 deletions.
300 changes: 133 additions & 167 deletions chunithm_net/__init__.py

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions chunithm_net/_bs4.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
import importlib.util

BS4_FEATURE = "lxml" if importlib.util.find_spec("lxml") else "html.parser"
27 changes: 27 additions & 0 deletions chunithm_net/_httpx_hooks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from http.client import SERVICE_UNAVAILABLE
import httpx
from bs4 import BeautifulSoup

from ._bs4 import BS4_FEATURE
from .exceptions import MaintenanceException, ChuniNetError


async def raise_on_chunithm_net_error(response: httpx.Response):
if response.url.path != "/mobile/error/":
return

html = ""
async for chunk in response.aiter_text():
html += chunk

dom = BeautifulSoup(html, BS4_FEATURE)
error_blocks = dom.select(".block.text_l .font_small")
code = int(error_blocks[0].text.split(": ", 1)[1])
description = error_blocks[1].text if len(error_blocks) > 1 else ""

raise ChuniNetError(code, description)


async def raise_on_scheduled_maintenance(response: httpx.Response):
if response.status_code == SERVICE_UNAVAILABLE:
raise MaintenanceException
38 changes: 25 additions & 13 deletions cogs/botutils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import contextlib
from http.cookiejar import LWPCookieJar
import io
from typing import TYPE_CHECKING, Optional, overload

from discord.ext import commands
from discord.ext.commands import Context
from sqlalchemy import select, text
from jarowinkler import jarowinkler_similarity
from sqlalchemy import select, text, update

from chunithm_net import ChuniNet
from chunithm_net.entities.enums import Rank
Expand Down Expand Up @@ -44,36 +47,47 @@ async def guild_prefix(self, ctx: Context) -> str:

return self.bot.prefixes.get(ctx.guild.id, default_prefix)

async def login_check(self, ctx_or_id: Context | int) -> str:
async def login_check(self, ctx_or_id: Context | int) -> LWPCookieJar:
id = ctx_or_id if isinstance(ctx_or_id, int) else ctx_or_id.author.id
clal = await self.fetch_cookie(id)
if clal is None:
msg = "You are not logged in. Please send `c>login` in my DMs to log in."
raise commands.BadArgument(msg)
return clal

async def fetch_cookie(self, id: int) -> str | None:
async def fetch_cookie(self, id: int) -> LWPCookieJar | None:
async with self.bot.begin_db_session() as session:
stmt = select(Cookie).where(Cookie.discord_id == id)
cookie = (await session.execute(stmt)).scalar_one_or_none()

if cookie is None:
return None

return cookie.cookie

jar = LWPCookieJar()
jar._really_load( # type: ignore[reportAttributeAccessIssue]
io.StringIO(cookie.cookie), "?", ignore_discard=False, ignore_expires=False
)

return jar

@contextlib.asynccontextmanager
async def chuninet(self, ctx_or_id: Context | int):
id = ctx_or_id if isinstance(ctx_or_id, int) else ctx_or_id.author.id
cookie = await self.login_check(ctx_or_id)
user_id, token = self.bot.sessions.get(id, (None, None))
jar = await self.login_check(ctx_or_id)

session = ChuniNet(cookie, user_id=user_id, token=token)
session = ChuniNet(jar)
try:
yield session
finally:
async with self.bot.begin_db_session() as db_session:
await db_session.execute(
update(Cookie)
.where(Cookie.discord_id == id)
.values(cookie=f"#LWP-Cookies-2.0\n{jar.as_lwp_str()}")
)
await db_session.commit()

await session.close()
self.bot.sessions[id] = (session.user_id, session.token)

@overload
async def annotate_song(
Expand All @@ -84,11 +98,9 @@ async def annotate_song(
@overload
async def annotate_song(self, song: RecentRecord) -> AnnotatedRecentRecord:
...

@overload
async def annotate_song(
self, song: Record | MusicRecord
) -> AnnotatedMusicRecord:
async def annotate_song(self, song: Record | MusicRecord) -> AnnotatedMusicRecord:
...

async def annotate_song(
Expand Down
72 changes: 34 additions & 38 deletions cogs/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from discord.ext.commands import Context

from chunithm_net.exceptions import (
ChuniNetError,
ChuniNetException,
InvalidTokenException,
MaintenanceException,
Expand Down Expand Up @@ -36,57 +37,49 @@ async def on_command_error(
while hasattr(exc, "original"):
exc = exc.original # type: ignore[reportGeneralTypeIssues]

embed = discord.Embed(
color=discord.Color.red(),
title="Error",
)
delete_after = None

if isinstance(exc, MaintenanceException):
return await ctx.reply(
"CHUNITHM-NET is currently undergoing maintenance. Please try again later.",
mention_author=False,
)
embed.description = "CHUNITHM-NET is currently undergoing maintenance. Please try again later."
if isinstance(exc, InvalidTokenException):
message = (
embed.description = (
"CHUNITHM-NET cookie is invalid. Please use `c>login` in DMs to log in."
)
if self.bot.dev:
message += f"\nDetailed error: {exc}"
return await ctx.reply(message, mention_author=False)
embed.description += f"\nDetailed error: {exc}"
if isinstance(exc, ChuniNetError):
embed.description = f"CHUNITHM-NET error {exc.code}: {exc.description}"
if isinstance(exc, ChuniNetException):
message = "An error occurred while communicating with CHUNITHM-NET. Please try again later (or re-login)."
embed.description = "An error occurred while communicating with CHUNITHM-NET. Please try again later (or re-login)."
if self.bot.dev:
message += f"\nDetailed error: {exc}"
return await ctx.reply(message, mention_author=False)

embed.description += f"\nDetailed error: {exc}"
if isinstance(exc, commands.errors.CommandOnCooldown):
return await ctx.reply(
f"You're too fast. Take a break for {exc.retry_after:.2f} seconds.",
mention_author=False,
delete_after=exc.retry_after,
embed.description = (
f"You're too fast. Take a break for {exc.retry_after:.2f} seconds."
)
delete_after = exc.retry_after
if isinstance(exc, commands.errors.ExpectedClosingQuoteError):
return await ctx.reply(
"You're missing a quote somewhere. Perhaps you're using the wrong kind of quote (`\"` vs `”`)?",
mention_author=False,
)
embed.description = "You're missing a quote somewhere. Perhaps you're using the wrong kind of quote (`\"` vs `”`)?"
if isinstance(exc, commands.errors.UnexpectedQuoteError):
return await ctx.reply(
(
f"Unexpected quote mark, {exc.quote!r}, in non-quoted string. If this was intentional, "
"escape the quote with a backslash (\\\\)."
),
mention_author=False,
embed.description = (
f"Unexpected quote mark, {exc.quote!r}, in non-quoted string. If this was intentional, "
"escape the quote with a backslash (\\\\)."
)
if isinstance(
exc, (commands.errors.NotOwner, commands.errors.MissingPermissions)
):
return await ctx.reply("Insufficient permissions.", mention_author=False)
embed.description = "Insufficient permissions."
if isinstance(exc, commands.BadLiteralArgument):
to_string = [repr(x) for x in exc.literals]
if len(to_string) > 2:
fmt = "{}, or {}".format(", ".join(to_string[:-1]), to_string[-1])
else:
fmt = " or ".join(to_string)
return await ctx.reply(
f"`{exc.param.displayed_name or exc.param.name}` must be one of {fmt}, received {exc.argument!r}",
mention_author=False,
)
embed.description = f"`{exc.param.displayed_name or exc.param.name}` must be one of {fmt}, received {exc.argument!r}"
if isinstance(
exc,
(
Expand All @@ -100,17 +93,20 @@ async def on_command_error(
commands.PrivateMessageOnly,
),
):
return await ctx.reply(str(error), mention_author=False)
embed.description = str(error)

if embed.description is not None:
return await ctx.reply(
embed=embed, mention_author=False, delete_after=delete_after
)

logger.error("Exception in command %s", ctx.command, exc_info=exc)
await ctx.reply(
(
"Something really terrible happened. "
f"The owner <@{self.bot.owner_id}> has been notified.\n"
"Please try again in a couple of hours."
),
mention_author=False,
embed.description = (
"Something really terrible happened. "
f"The owner <@{self.bot.owner_id}> has been notified.\n"
"Please try again in a couple of hours."
)
await ctx.reply(embed=embed, mention_author=False)

if webhook_url := config.bot.error_reporting_webhook:
async with aiohttp.ClientSession() as session:
Expand Down
26 changes: 16 additions & 10 deletions cogs/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from discord.ext.commands import Context
from PIL import Image

from chunithm_net.exceptions import ChuniNetError
from utils.views.profile import ProfileView

if TYPE_CHECKING:
Expand Down Expand Up @@ -226,18 +227,23 @@ async def rename(self, ctx: Context, *, new_name: str):

async with ctx.typing(), self.utils.chuninet(ctx) as client:
await client.authenticate()

try:
if await client.change_player_name(new_name):
await ctx.reply(
"Your username has been changed.", mention_author=False
)
else:
await ctx.reply(
"There was an error changing your username.",
mention_author=False,
)
await client.change_player_name(new_name)
await ctx.reply("Your username has been changed.", mention_author=False)
except ValueError as e:
raise commands.BadArgument(str(e)) from None
msg = str(e)

if msg == "文字数が多すぎます。": # Too many characters
msg = "The new username is too long (only 8 characters allowed)."

raise commands.BadArgument(msg) from None
except ChuniNetError as e:
if e.code == 110106:
msg = "The new username contains a banned word."
raise commands.BadArgument(msg) from None

raise


async def setup(bot: "ChuniBot"):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
"""Use a proper cookie jar
Revision ID: 50bb24f19a3a
Revises: 901af18ec932
Create Date: 2024-03-17 14:18:11.579871
"""
from http.cookiejar import LWPCookieJar, Cookie
from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = "50bb24f19a3a"
down_revision: Union[str, None] = "901af18ec932"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None

cookies = sa.Table(
"cookies",
sa.MetaData(),
sa.Column("discord_id", sa.BigInteger, primary_key=True),
sa.Column("cookie", sa.String, nullable=False),
sa.Column("kamaitachi_token", sa.String(40), nullable=True),
)


def upgrade() -> None:
with op.batch_alter_table("cookies") as bop:
bop.alter_column("cookie", type_=sa.String)

conn = op.get_bind()
rows = [x._asdict() for x in conn.execute(sa.select(cookies))]

for row in rows:
cookie = Cookie(
version=0,
name="clal",
value=row["cookie"],
port=None,
port_specified=False,
domain="lng-tgk-aime-gw.am-all.net",
domain_specified=True,
domain_initial_dot=False,
path="/common_auth",
path_specified=True,
secure=False,
expires=3856586927, # 2092-03-17 10:08:47Z
discard=False,
comment=None,
comment_url=None,
rest={},
)
jar = LWPCookieJar()
jar.set_cookie(cookie)

conn.execute(
sa.update(cookies)
.where(cookies.c.discord_id == row["discord_id"])
.values(
cookie=f"#LWP-Cookies-2.0\n{jar.as_lwp_str(ignore_discard=False, ignore_expires=False)}"
)
)

conn.commit()


def downgrade() -> None:
pass
Loading

0 comments on commit 1c6e38d

Please sign in to comment.