diff --git a/commands.py b/commands.py index 01f34c5..604a76b 100644 --- a/commands.py +++ b/commands.py @@ -20,19 +20,21 @@ import re import discord +import redis import sql.guild from utils import cho_command -CMD_START = "start" -CMD_STOP = "stop" +CMD_HELP = "help" CMD_SCOREBOARD = "scoreboard" CMD_SET_CHANNEL = "set-channel" CMD_SET_PREFIX = "set-prefix" -CMD_HELP = "help" +CMD_SET_STATUS = "set-status" +CMD_START = "start" +CMD_STOP = "stop" -DISCORD_CHANNEL_REGEX = re.compile("^<#([0-9]*)>$") -ALLOWED_PREFIXES = set(["!", "&", "?", "|", "^", "%"]) +DISCORD_CHANNEL_REGEX = re.compile(r"^<#([0-9]*)>$") +ALLOWED_PREFIXES = {"!", "&", "?", "|", "^", "%"} LOGGER = logging.getLogger("cho") @@ -202,7 +204,7 @@ async def handle_set_channel(self, message, args, config): if len(args) < 3: await message.channel.send( - "Please specify a channel when using \"set-channel\"." + f"Please specify a channel when using \"{CMD_SET_CHANNEL}\"." ) return @@ -237,7 +239,7 @@ async def handle_set_prefix(self, message, args, config): if len(args) < 3: await message.channel.send( - "Please specify a prefix when using \"set-prefix\"." + f"Please specify a prefix when using \"{CMD_SET_PREFIX}\"." ) return @@ -254,6 +256,40 @@ async def handle_set_prefix(self, message, args, config): config["prefix"] = new_prefix sql.guild.update_guild_config(self.engine, guild_id, config) - await message.channel.send( - "My prefix is now in \"{}\".".format(new_prefix) - ) + await message.channel.send(f"My prefix is now \"{new_prefix}\".") + + @cho_command(CMD_SET_STATUS, owner_only=True) + async def handle_set_status(self, message, args, config): + """Updates the bot's status across all shards. + + :param m message: + :param list args: + :param dict config: + :type m: discord.message.Message + """ + + if len(args) < 3: + await message.channel.send( + f"Please specify a status when using \"{CMD_SET_STATUS}\"." + ) + return + elif len(args) > 3: + await message.channel.send( + f"Too many arguments for \"{CMD_SET_STATUS}\". Surround your " + f"status with double quotes to include spaces." + ) + return + + new_status = args[2] + + try: + self.redis.set("cho:status", new_status) + await self.set_status() + except redis.ConnectionError as exc: + LOGGER.warning(exc) + + await message.channel.send( + "Unable to set status due to a redis connection error." + ) + else: + await message.channel.send(f"My status is now \"{new_status}\".") diff --git a/lorewalker_cho.py b/lorewalker_cho.py index ed12878..16e0143 100644 --- a/lorewalker_cho.py +++ b/lorewalker_cho.py @@ -18,13 +18,16 @@ import asyncio import logging +import shlex import traceback from commands import CommandsMixin import discord +import redis from discord.message import Message +from redis import Redis from sqlalchemy.engine import Engine import utils @@ -38,18 +41,21 @@ class LorewalkerCho(CommandsMixin, GameMixin, discord.Client): """Discord client wrapper that uses functionality from cho.py.""" - def __init__(self, engine: Engine): + def __init__(self, engine: Engine, redis_client: Redis): """Initializes the ChoClient with a sqlalchemy connection pool. :param e engine: SQLAlchemy engine to make queries with. + :param r redis_client: A redis client for caching non-persistant data. :type e: sqlalchemy.engine.Engine - :rtype: ChoClient + :type r: redis.Redis + :rtype: LorewalkerCho :return: """ super().__init__() self.engine = engine + self.redis = redis_client self.guild_configs = {} self.active_games = {} @@ -58,9 +64,7 @@ async def on_ready(self): LOGGER.info("Client logged in as \"%s\"", self.user) - await self.change_presence( - status=discord.Status.online, - activity=discord.Game(name="!cho help")) + await self.set_status() asyncio.ensure_future(self.resume_incomplete_games()) @@ -132,10 +136,10 @@ async def handle_command(self, message): else: _, config = guild_query_results - # TODO: Come up with a better way to split up arguments. If we want to - # support flags in the future this might need to be done using a real - # argument parser. - args = message.content.split() + # Split arguments as if they're in a shell-like syntax using shlex. + # This allows for arguments to be quoted so strings with spaces can be + # included. + args = shlex.split(message.content) # Handle cho invocations with no command. if len(args) < 2: @@ -189,3 +193,21 @@ async def handle_message_response(self, message: Message): if utils.is_message_from_trivia_channel(message, config): await self.process_answer(message) + + async def set_status(self): + """Sets the bot status to the saved one, or the default if missing.""" + + status = "!cho help" + + try: + saved_status = self.redis.get("cho:status") + if saved_status: + status = saved_status.decode() + except redis.ConnectionError as exc: + LOGGER.warning(exc) + + LOGGER.debug("Setting status to \"%s\"", status) + + await self.change_presence( + status=discord.Status.online, + activity=discord.Game(name=status)) diff --git a/main.py b/main.py index 3954a44..2392d97 100755 --- a/main.py +++ b/main.py @@ -21,7 +21,10 @@ import argparse import logging import os + +import redis import sqlalchemy as sa + import config from lorewalker_cho import LorewalkerCho @@ -58,7 +61,10 @@ def main(): engine.connect() LOGGER.info("Started connection pool with size: %d", SQLALCHEMY_POOL_SIZE) - discord_client = LorewalkerCho(engine) + redis_url = os.environ.get("CHO_REDIS_URL") or "redis://localhost:6379" + redis_client = redis.Redis.from_url(redis_url) + + discord_client = LorewalkerCho(engine, redis_client) discord_client.run(DISCORD_TOKEN) LOGGER.info("Shutting down... good bye!") diff --git a/utils.py b/utils.py index e170184..abc8f00 100644 --- a/utils.py +++ b/utils.py @@ -17,6 +17,7 @@ """Core code that controls Cho's behavior.""" import logging +import os from collections import OrderedDict @@ -34,19 +35,24 @@ CHANNEL_COMMANDS = OrderedDict() -def cho_command(command, kind="global", admin_only=False): +def cho_command(command, kind="global", admin_only=False, owner_only=False): """Marks a function as a runnable command.""" def decorator(func): def wrapper(*args, **kwargs): - if admin_only: + if owner_only: + message = args[1] + if not is_owner(message.author): + return message.channel.send( + "Sorry, only the bot owner can run that command." + ) + return func(*args, **kwargs) + elif admin_only: message = args[1] - if not is_admin(message.author, message.channel): return message.channel.send( "Sorry, only administrators run that command." ) - return func(*args, **kwargs) return func(*args, **kwargs) @@ -107,6 +113,25 @@ def is_admin(member: Member, channel: TextChannel) -> bool: return channel.permissions_for(member).administrator +def is_owner(member: Member) -> bool: + """Checks if a passed in Member is the bot owner (me!). + + :param m member: + :type m: discord.member.Member + :rtype: bool + :return: + """ + + is_bot_owner = False + + try: + is_bot_owner = member.id == int(os.environ.get("CHO_OWNER", 0)) + except ValueError: + pass + + return is_bot_owner + + def is_message_from_trivia_channel(message: Message, config: dict) -> bool: """Checks if the message is from the trivia channel.