diff --git a/molter/command.py b/molter/command.py index 6cffe67..360b04a 100644 --- a/molter/command.py +++ b/molter/command.py @@ -15,13 +15,24 @@ class CommandParameter: name: str = attr.ib(default=None) default: typing.Optional[typing.Any] = attr.ib(default=None) + type: type = attr.ib(default=None) converters: list[ typing.Callable[[dis_snek.MessageContext, str], typing.Any] ] = attr.ib(factory=list) + union: bool = attr.ib(default=False) variable: bool = attr.ib(default=False) consume_rest: bool = attr.ib(default=False) +def _get_type_name(x: type): + try: + return x.__name__ + except AttributeError: + if hasattr(x, "__origin__"): + return repr(x) + return x.__class__.__name__ + + def _convert_to_bool(argument: str) -> bool: lowered = argument.lower() if lowered in ("yes", "y", "true", "t", "1", "enable", "on"): @@ -71,9 +82,10 @@ def _get_params(func: typing.Callable): else dis_snek.const.MISSING ) - anno = param.annotation + cmd_param.type = anno = param.annotation if typing.get_origin(anno) in {typing.Union, UnionType}: + cmd_param.union = True for arg in typing.get_args(anno): if arg != NoneType: converter = _get_converter(arg) @@ -147,7 +159,10 @@ async def call_callback( converted = await maybe_coroutine(converter, ctx, arg) break except Exception as e: - if param.default == dis_snek.const.MISSING: + if ( + not param.union + and param.default == dis_snek.const.MISSING + ): raise errors.BadArgument(str(e)) if converted == dis_snek.const.MISSING: @@ -156,10 +171,13 @@ async def call_callback( new_args.append(converted) param_index += 1 else: - # vague, ik + union_types = typing.get_args(param.type) + union_names = tuple(_get_type_name(t) for t in union_types) + union_types_str = ( + ", ".join(union_names[:-1]) + f", or {union_names[-1]}" + ) raise errors.BadArgument( - f"Could not convert {arg} into a type specified for" - f" {param.name}." + f"Could not convert {arg} into {union_types_str}." ) else: new_args.append(converted) diff --git a/molter/errors.py b/molter/errors.py index c18dc47..eef18e6 100644 --- a/molter/errors.py +++ b/molter/errors.py @@ -1,5 +1,12 @@ +import typing + import dis_snek class BadArgument(dis_snek.CommandException): - pass + def __init__(self, message: typing.Optional[str] = None, *args: typing.Any) -> None: + if message is not None: + message = dis_snek.utils.escape_mentions(message) + super().__init__(message, *args) + else: + super().__init__(*args)