Skip to content
This repository has been archived by the owner on May 1, 2022. It is now read-only.

Commit

Permalink
Various fixes to consume rest behavior
Browse files Browse the repository at this point in the history
Also added proper error handling for functions with more than 2 args
  • Loading branch information
AstreaTSS committed Dec 20, 2021
1 parent 38cb083 commit ab23615
Showing 1 changed file with 60 additions and 21 deletions.
81 changes: 60 additions & 21 deletions molter/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,37 @@ class CommandParameter:
consume_rest: bool = attr.ib(default=False)


def _get_type_name(x: type):
@attr.s(slots=True)
class ArgsIterator:
args: typing.Sequence[str] = attr.ib(converter=tuple)
index: int = attr.ib(init=False, default=0)
length: int = attr.ib(init=False, default=0)

def __iter__(self):
self.length = len(self.args)
self.index = 0
return self

def __next__(self):
if self.index >= self.length:
raise StopIteration

result = self.args[self.index]
self.index += 1
return result

def consume_rest(self):
result = self.args[self.index - 1 :]
self.index = self.length
return result

def forward(self, count: int = 1):
result = self.args[self.index - 1 : self.length + (count - 1)]
self.index += count
return result


def _get_name(x: typing.Any):
try:
return x.__name__
except AttributeError:
Expand All @@ -44,7 +74,7 @@ def _convert_to_bool(argument: str) -> bool:


def _get_converter(
anno: type,
anno: type, name: str
) -> typing.Callable[[dis_snek.MessageContext, str], typing.Any]: # type: ignore
if converter := converters.SNEK_OBJECT_TO_CONVERTER.get(anno, None):
return converter().convert # type: ignore
Expand All @@ -57,8 +87,13 @@ def _get_converter(
return lambda ctx, arg: anno(ctx, arg)
case 1:
return lambda ctx, arg: anno(arg)
case 0:
return lambda ctx, arg: anno()
case _:
errors.BadArgument(anno)
errors.BadArgument(
f"{_get_name(anno)} for {name} has more than 2 arguments, which is"
" unsupported."
)
elif anno == bool:
return lambda ctx, arg: _convert_to_bool(arg)
elif anno == inspect._empty:
Expand Down Expand Up @@ -88,19 +123,23 @@ def _get_params(func: typing.Callable):
cmd_param.union = True
for arg in typing.get_args(anno):
if arg != NoneType:
converter = _get_converter(arg)
converter = _get_converter(arg, name)
cmd_param.converters.append(converter)
elif cmd_param.default == dis_snek.const.MISSING: # d.py-like behavior
cmd_param.default = None
else:
converter = _get_converter(anno)
converter = _get_converter(anno, name)
cmd_param.converters.append(converter)

match param.kind:
case param.KEYWORD_ONLY:
cmd_param.consume_rest = True
cmd_params.append(cmd_param)
break
case param.VAR_POSITIONAL:
cmd_param.variable = True
cmd_params.append(cmd_param)
break

cmd_params.append(cmd_param)

Expand Down Expand Up @@ -135,22 +174,19 @@ async def call_callback(
return await callback(ctx)
else:
new_args: list[typing.Any] = []
args: list[str] = ctx.args
kwargs: dict[str, typing.Any] = {}
args = ArgsIterator(ctx.args)
param_index = 0

break_for_loop = False

for index, arg in enumerate(args):
for arg in args:
while param_index < len(self.params):
param = self.params[param_index]

if param.consume_rest:
arg = " ".join(args[index:])
break_for_loop = True
arg = " ".join(args.consume_rest())
if param.variable:
# temp behavior until i decide what to do with this
new_args.append(args[index:])
break_for_loop = True
new_args.append(args.consume_rest())
break

converted = dis_snek.const.MISSING
Expand All @@ -168,31 +204,34 @@ async def call_callback(
if converted == dis_snek.const.MISSING:
if param.default != dis_snek.const.MISSING:
converted = param.default
new_args.append(converted)
if not param.consume_rest:
new_args.append(converted)
else:
kwargs[param.name] = converted
param_index += 1
else:
union_types = typing.get_args(param.type)
union_names = tuple(_get_type_name(t) for t in union_types)
union_names = tuple(_get_name(t) for t in union_types)
union_types_str = (
", ".join(union_names[:-1]) + f", or {union_names[-1]}"
)
raise errors.BadArgument(
f"Could not convert {arg} into {union_types_str}."
)
else:
new_args.append(converted)
if not param.consume_rest:
new_args.append(converted)
else:
kwargs[param.name] = converted
param_index += 1
break

if break_for_loop:
break

if len(new_args) < len(self.params):
if param_index < len(self.params):
raise errors.BadArgument(
f"Missing argument for {self.params[len(new_args)].name}"
)

return await callback(ctx, *new_args)
return await callback(ctx, *new_args, **kwargs)


def message_command(
Expand Down

0 comments on commit ab23615

Please sign in to comment.