From ab23615ad657cc46ba175d1538721ecf575d6a1a Mon Sep 17 00:00:00 2001 From: Sonic4999 Date: Sun, 19 Dec 2021 22:51:43 -0500 Subject: [PATCH] Various fixes to consume rest behavior Also added proper error handling for functions with more than 2 args --- molter/command.py | 81 +++++++++++++++++++++++++++++++++++------------ 1 file changed, 60 insertions(+), 21 deletions(-) diff --git a/molter/command.py b/molter/command.py index 360b04a..9b51985 100644 --- a/molter/command.py +++ b/molter/command.py @@ -24,7 +24,37 @@ class CommandParameter: consume_rest: bool = attr.ib(default=False) -def _get_type_name(x: type): +@attr.s(slots=True) +class ArgsIterator: + args: typing.Sequence[str] = attr.ib(converter=tuple) + index: int = attr.ib(init=False, default=0) + length: int = attr.ib(init=False, default=0) + + def __iter__(self): + self.length = len(self.args) + self.index = 0 + return self + + def __next__(self): + if self.index >= self.length: + raise StopIteration + + result = self.args[self.index] + self.index += 1 + return result + + def consume_rest(self): + result = self.args[self.index - 1 :] + self.index = self.length + return result + + def forward(self, count: int = 1): + result = self.args[self.index - 1 : self.length + (count - 1)] + self.index += count + return result + + +def _get_name(x: typing.Any): try: return x.__name__ except AttributeError: @@ -44,7 +74,7 @@ def _convert_to_bool(argument: str) -> bool: def _get_converter( - anno: type, + anno: type, name: str ) -> typing.Callable[[dis_snek.MessageContext, str], typing.Any]: # type: ignore if converter := converters.SNEK_OBJECT_TO_CONVERTER.get(anno, None): return converter().convert # type: ignore @@ -57,8 +87,13 @@ def _get_converter( return lambda ctx, arg: anno(ctx, arg) case 1: return lambda ctx, arg: anno(arg) + case 0: + return lambda ctx, arg: anno() case _: - errors.BadArgument(anno) + errors.BadArgument( + f"{_get_name(anno)} for {name} has more than 2 arguments, which is" + " unsupported." + ) elif anno == bool: return lambda ctx, arg: _convert_to_bool(arg) elif anno == inspect._empty: @@ -88,19 +123,23 @@ def _get_params(func: typing.Callable): cmd_param.union = True for arg in typing.get_args(anno): if arg != NoneType: - converter = _get_converter(arg) + converter = _get_converter(arg, name) cmd_param.converters.append(converter) elif cmd_param.default == dis_snek.const.MISSING: # d.py-like behavior cmd_param.default = None else: - converter = _get_converter(anno) + converter = _get_converter(anno, name) cmd_param.converters.append(converter) match param.kind: case param.KEYWORD_ONLY: cmd_param.consume_rest = True + cmd_params.append(cmd_param) + break case param.VAR_POSITIONAL: cmd_param.variable = True + cmd_params.append(cmd_param) + break cmd_params.append(cmd_param) @@ -135,22 +174,19 @@ async def call_callback( return await callback(ctx) else: new_args: list[typing.Any] = [] - args: list[str] = ctx.args + kwargs: dict[str, typing.Any] = {} + args = ArgsIterator(ctx.args) param_index = 0 - break_for_loop = False - - for index, arg in enumerate(args): + for arg in args: while param_index < len(self.params): param = self.params[param_index] if param.consume_rest: - arg = " ".join(args[index:]) - break_for_loop = True + arg = " ".join(args.consume_rest()) if param.variable: # temp behavior until i decide what to do with this - new_args.append(args[index:]) - break_for_loop = True + new_args.append(args.consume_rest()) break converted = dis_snek.const.MISSING @@ -168,11 +204,14 @@ async def call_callback( if converted == dis_snek.const.MISSING: if param.default != dis_snek.const.MISSING: converted = param.default - new_args.append(converted) + if not param.consume_rest: + new_args.append(converted) + else: + kwargs[param.name] = converted param_index += 1 else: union_types = typing.get_args(param.type) - union_names = tuple(_get_type_name(t) for t in union_types) + union_names = tuple(_get_name(t) for t in union_types) union_types_str = ( ", ".join(union_names[:-1]) + f", or {union_names[-1]}" ) @@ -180,19 +219,19 @@ async def call_callback( f"Could not convert {arg} into {union_types_str}." ) else: - new_args.append(converted) + if not param.consume_rest: + new_args.append(converted) + else: + kwargs[param.name] = converted param_index += 1 break - if break_for_loop: - break - - if len(new_args) < len(self.params): + if param_index < len(self.params): raise errors.BadArgument( f"Missing argument for {self.params[len(new_args)].name}" ) - return await callback(ctx, *new_args) + return await callback(ctx, *new_args, **kwargs) def message_command(