Skip to content
This repository has been archived by the owner on May 1, 2022. It is now read-only.

Commit

Permalink
Clearer errors for Unions
Browse files Browse the repository at this point in the history
  • Loading branch information
AstreaTSS committed Dec 19, 2021
1 parent 831fe9f commit 38cb083
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 6 deletions.
28 changes: 23 additions & 5 deletions molter/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,24 @@
class CommandParameter:
name: str = attr.ib(default=None)
default: typing.Optional[typing.Any] = attr.ib(default=None)
type: type = attr.ib(default=None)
converters: list[
typing.Callable[[dis_snek.MessageContext, str], typing.Any]
] = attr.ib(factory=list)
union: bool = attr.ib(default=False)
variable: bool = attr.ib(default=False)
consume_rest: bool = attr.ib(default=False)


def _get_type_name(x: type):
try:
return x.__name__
except AttributeError:
if hasattr(x, "__origin__"):
return repr(x)
return x.__class__.__name__


def _convert_to_bool(argument: str) -> bool:
lowered = argument.lower()
if lowered in ("yes", "y", "true", "t", "1", "enable", "on"):
Expand Down Expand Up @@ -71,9 +82,10 @@ def _get_params(func: typing.Callable):
else dis_snek.const.MISSING
)

anno = param.annotation
cmd_param.type = anno = param.annotation

if typing.get_origin(anno) in {typing.Union, UnionType}:
cmd_param.union = True
for arg in typing.get_args(anno):
if arg != NoneType:
converter = _get_converter(arg)
Expand Down Expand Up @@ -147,7 +159,10 @@ async def call_callback(
converted = await maybe_coroutine(converter, ctx, arg)
break
except Exception as e:
if param.default == dis_snek.const.MISSING:
if (
not param.union
and param.default == dis_snek.const.MISSING
):
raise errors.BadArgument(str(e))

if converted == dis_snek.const.MISSING:
Expand All @@ -156,10 +171,13 @@ async def call_callback(
new_args.append(converted)
param_index += 1
else:
# vague, ik
union_types = typing.get_args(param.type)
union_names = tuple(_get_type_name(t) for t in union_types)
union_types_str = (
", ".join(union_names[:-1]) + f", or {union_names[-1]}"
)
raise errors.BadArgument(
f"Could not convert {arg} into a type specified for"
f" {param.name}."
f"Could not convert {arg} into {union_types_str}."
)
else:
new_args.append(converted)
Expand Down
9 changes: 8 additions & 1 deletion molter/errors.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
import typing

import dis_snek


class BadArgument(dis_snek.CommandException):
pass
def __init__(self, message: typing.Optional[str] = None, *args: typing.Any) -> None:
if message is not None:
message = dis_snek.utils.escape_mentions(message)
super().__init__(message, *args)
else:
super().__init__(*args)

0 comments on commit 38cb083

Please sign in to comment.