From 2ba76142ba36779d998396f1e1f3bb76b30d28c4 Mon Sep 17 00:00:00 2001 From: Archento Date: Tue, 29 Oct 2024 13:46:22 +0100 Subject: [PATCH] add: quota protocol fix --- .../uagents/experimental/quota/__init__.py | 49 ++++++++++--------- 1 file changed, 27 insertions(+), 22 deletions(-) diff --git a/python/src/uagents/experimental/quota/__init__.py b/python/src/uagents/experimental/quota/__init__.py index 7b48bad0..7cfd79d4 100644 --- a/python/src/uagents/experimental/quota/__init__.py +++ b/python/src/uagents/experimental/quota/__init__.py @@ -57,7 +57,7 @@ async def handle(ctx: Context, sender: str, msg: ExampleMessage3): runtime. This can be useful for dynamic access control rules based on the state of the agent or the network. ```python -acl = AccessControlList(default=True, allowed={""}, blocked={""}) +acl = AccessControlList(default=True) @proto.on_message(model=Message, access_control_list=acl) async def message_handler(ctx: Context, sender: str, msg: Message): @@ -97,8 +97,9 @@ class RateLimit(BaseModel): class AccessControlList(BaseModel): default: bool - allowed: set[str] - blocked: set[str] + allowed: set[str] = set() + blocked: set[str] = set() + bypass_rate_limit: set[str] = set() class QuotaProtocol(Protocol): @@ -108,6 +109,7 @@ def __init__( name: Optional[str] = None, version: Optional[str] = None, default_rate_limit: Optional[RateLimit] = None, + default_acl: Optional[AccessControlList] = None, ): """ Initialize a QuotaProtocol instance. @@ -116,11 +118,15 @@ def __init__( storage_reference (StorageAPI): The storage reference to use for rate limiting. name (Optional[str], optional): The name of the protocol. Defaults to None. version (Optional[str], optional): The version of the protocol. Defaults to None. - acl (Optional[AccessControlList], optional): The access control list. Defaults to None. + default_rate_limit (Optional[RateLimit], optional): The default rate limit. + Defaults to None. + default_acl (Optional[AccessControlList], optional): The access control list. + Defaults to None. """ super().__init__(name=name, version=version) self.storage_ref = storage_reference self.default_rate_limit = default_rate_limit + self.default_acl = default_acl def on_message( self, @@ -172,8 +178,9 @@ def wrap( Returns: Callable: The decorated """ + acl = acl or self.default_acl if acl is None: - acl = AccessControlList(default=True, allowed=set(), blocked=set()) + acl = AccessControlList(default=True) rate_limit = rate_limit or self.default_rate_limit @@ -184,28 +191,26 @@ async def decorator(ctx: Context, sender: str, msg: Type[Model]): ): return await ctx.send( sender, - ErrorMessage( - error=("You are not allowed to access this endpoint.") - ), + ErrorMessage(error=("You are not allowed to access this handler.")), + ) + if ( + sender in acl.bypass_rate_limit + or not rate_limit + or self.add_request( + sender, + func.__name__, + rate_limit.window_size_minutes, + rate_limit.max_requests, ) - if not rate_limit or self.add_request( - sender, - func.__name__, - rate_limit.window_size_minutes, - rate_limit.max_requests, ): result = await func(ctx, sender, msg) else: - result = await ctx.send( - sender, - ErrorMessage( - error=( - f"Rate limit exceeded for {msg.schema()["title"]}. " - f"This endpoint allows for {rate_limit.max_requests} calls per " - f"{rate_limit.window_size_minutes} minutes. Try again later." - ) - ), + err = ( + f"Rate limit exceeded for {msg.schema()['title']}. " + f"This handler allows for {rate_limit.max_requests} calls per " + f"{rate_limit.window_size_minutes} minutes. Try again later." ) + result = await ctx.send(sender, ErrorMessage(error=err)) return result return decorator # type: ignore