Skip to content

Commit

Permalink
chore: add some more type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
davidlougheed committed Sep 21, 2024
1 parent 946bc3d commit 7d58330
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 20 deletions.
15 changes: 7 additions & 8 deletions canary/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import aiosqlite
import contextlib
import logging
import sys
import traceback

from canary.config import Config
Expand Down Expand Up @@ -68,13 +67,13 @@


class _WebhookHandler(logging.Handler):
def __init__(self, webhook_id, webhook_token, username=None):
self.username = username or "Bot Logs"
def __init__(self, webhook_id: int, webhook_token: str, username: str | None = None):
self.username: str = username or "Bot Logs"
logging.Handler.__init__(self)
self.webhook = Webhook.partial(webhook_id, webhook_token, adapter=RequestsWebhookAdapter())
self.max_webhook_payload_size: int = 1800

def emit(self, record):
def emit(self, record: logging.LogRecord) -> None:
msg = self.format(record)
try:
self.webhook.send(
Expand Down Expand Up @@ -117,7 +116,7 @@ def __init__(self, *args, **kwargs):
self.mod_logger = mod_logger
self.config = config

async def start(self, *args, **kwargs): # TODO: discordpy 2.0: use setup_hook for database setup
async def start(self, *args, **kwargs) -> None: # TODO: discordpy 2.0: use setup_hook for database setup
await self._start_database()
await super().start(*args, **kwargs)
await self.health_check()
Expand All @@ -132,7 +131,7 @@ async def db(self) -> AsyncGenerator[aiosqlite.Connection, None]:
async def db_nocm(self) -> aiosqlite.Connection:
return await aiosqlite.connect(self.config.db_path)

async def _start_database(self):
async def _start_database(self) -> None:
if not self.config.db_path:
self.dev_logger.warning("No path to database configuration file")
return
Expand All @@ -147,12 +146,12 @@ async def _start_database(self):

self.dev_logger.debug("Database is ready")

async def health_check(self):
async def health_check(self) -> None:
guild = self.get_guild(self.config.server_id)
if not guild:
self.dev_logger.error(f"Could not get guild for bot (specified server ID {self.config.server_id})")

def log_traceback(self, exception):
def log_traceback(self, exception: Exception):
self.dev_logger.error("".join(traceback.format_exception(type(exception), exception, exception.__traceback__)))

async def on_command_error(self, ctx, error):
Expand Down
4 changes: 2 additions & 2 deletions canary/cogs/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,12 +146,12 @@ async def weather(self, ctx: commands.Context):
# Getting the wind specifically, because otherwise it starts being ugly very quickly
wind = soup.find("wind")

def retrieve_string(label, search=None, search_soup=soup):
def retrieve_string(label, search=None, search_soup=soup) -> str | None:
if elem := search_soup.find(label, string=search):
return elem.get_text().strip()
return None

def retrieve_attribute(label, key, search_soup=soup):
def retrieve_attribute(label, key, search_soup=soup) -> str | None:
if attr := search_soup.find(label)[key]:
return attr.strip()
return None
Expand Down
8 changes: 4 additions & 4 deletions canary/cogs/utils/auto_incorrect.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
__all__ = ["auto_incorrect"]


def _swap(s: str, i: int):
def _swap(s: str, i: int) -> str:
"""
Given string s and index i, it swaps s[i] and s[i+1] and returns
By @lazho
Expand All @@ -14,14 +14,14 @@ def _swap(s: str, i: int):
return s


def _repeat(s: str, i: int):
def _repeat(s: str, i: int) -> str:
"""
By @lazho
"""
return s[:i] + s[i] + s[i:]


def _omit(s: str, i: int):
def _omit(s: str, i: int) -> str:
"""
Given string s and index i, it omits the i-th character.
By @lazho
Expand Down Expand Up @@ -58,7 +58,7 @@ def _omit(s: str, i: int):
}


def auto_incorrect(input_str: str):
def auto_incorrect(input_str: str) -> str:
"""
By @lazho
"""
Expand Down
5 changes: 3 additions & 2 deletions canary/cogs/utils/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,15 @@

import discord
from discord.ext import commands
from typing import Literal

from canary.bot import config


def is_moderator():
"""Returns True if user has a moderator role, raises an exception otherwise"""

def predicate(ctx: commands.Context):
def predicate(ctx: commands.Context) -> Literal[True]:
if discord.utils.get(ctx.author.roles, name=config.moderator_role) is None:
raise commands.MissingPermissions([config.moderator_role])
return True
Expand All @@ -35,7 +36,7 @@ def predicate(ctx: commands.Context):
def is_developer():
"""Returns True if user is a bot developer, raises an exception otherwise"""

def predicate(ctx: commands.Context):
def predicate(ctx: commands.Context) -> Literal[True]:
if discord.utils.get(ctx.author.roles, name=config.developer_role) is None:
raise commands.MissingPermissions([config.developer_role])
return True
Expand Down
8 changes: 4 additions & 4 deletions canary/cogs/utils/role_restoration.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

async def save_existing_roles(
bot: Canary, user: discord.Member, muted: bool = False, appeal_channel: discord.TextChannel | None = None
):
) -> None:
roles_id = [role.id for role in user.roles if role.name not in ("@everyone", bot.config.muted_role)]

if not roles_id and not muted:
Expand Down Expand Up @@ -78,12 +78,12 @@ async def fetch_saved_roles(bot: Canary, guild, user: discord.Member, muted: boo
)


def has_muted_role(bot: Canary, user: discord.Member):
def has_muted_role(bot: Canary, user: discord.Member) -> bool:
muted_role = utils.get(user.guild.roles, name=bot.config.muted_role)
return muted_role and next((r for r in user.roles if r == muted_role), None) is not None


async def is_in_muted_table(bot: Canary, user: discord.Member):
async def is_in_muted_table(bot: Canary, user: discord.Member) -> bool:
db: aiosqlite.Connection
c: aiosqlite.Cursor
async with bot.db() as db:
Expand All @@ -104,7 +104,7 @@ async def role_restoring_page(
user: discord.Member,
roles: list[discord.Role] | None,
muted: bool = False,
):
) -> None:
channel: discord.TextChannel | None = ctx.channel # Can be None from MockContext

if channel is None:
Expand Down

0 comments on commit 7d58330

Please sign in to comment.