From 061a4e78836a1998ff509411cb7efbcf1ef24683 Mon Sep 17 00:00:00 2001 From: Ollie <69084614+olijeffers0n@users.noreply.github.com> Date: Mon, 1 Jul 2024 15:37:42 +0100 Subject: [PATCH] rework addresses --- rustplus/annotations/chat_event.py | 4 +- rustplus/annotations/command.py | 4 +- rustplus/annotations/entity_event.py | 4 +- rustplus/annotations/protobuf_event.py | 4 +- rustplus/annotations/team_event.py | 4 +- rustplus/commands/chat_command.py | 4 +- rustplus/identification/__init__.py | 2 +- .../{server_id.py => server_details.py} | 21 ++++-- rustplus/remote/handler_list.py | 65 +++++++++++-------- rustplus/remote/ratelimiter/__init__.py | 30 ++++----- rustplus/remote/rustplus_proto/rustplus.py | 2 +- rustplus/remote/websocket/ws.py | 35 ++++++---- rustplus/rust_api.py | 20 +++--- 13 files changed, 119 insertions(+), 80 deletions(-) rename rustplus/identification/{server_id.py => server_details.py} (69%) diff --git a/rustplus/annotations/chat_event.py b/rustplus/annotations/chat_event.py index 3053150..595398d 100644 --- a/rustplus/annotations/chat_event.py +++ b/rustplus/annotations/chat_event.py @@ -5,7 +5,7 @@ from ..events import ChatEventPayload as ChatEventManager -def ChatEvent(server_id: ServerDetails) -> Callable: +def ChatEvent(server_details: ServerDetails) -> Callable: def wrapper(func) -> RegisteredListener: @@ -14,7 +14,7 @@ def wrapper(func) -> RegisteredListener: listener = RegisteredListener(func.__name__, func) - ChatEventManager.HANDLER_LIST.register(listener, server_id) + ChatEventManager.HANDLER_LIST.register(listener, server_details) return listener diff --git a/rustplus/annotations/command.py b/rustplus/annotations/command.py index 0835fb4..24ecb4c 100644 --- a/rustplus/annotations/command.py +++ b/rustplus/annotations/command.py @@ -5,7 +5,7 @@ def Command( - server_id: ServerDetails, aliases: list = None, alias_func: Callable = None + server_details: ServerDetails, aliases: list = None, alias_func: Callable = None ) -> Callable: def wrapper(func): @@ -16,7 +16,7 @@ def wrapper(func): command_data = ChatCommandData( coroutine=func, aliases=aliases, callable_func=alias_func ) - ChatCommand.REGISTERED_COMMANDS[server_id][func.__name__] = command_data + ChatCommand.REGISTERED_COMMANDS[server_details][func.__name__] = command_data return RegisteredListener(func.__name__, func) diff --git a/rustplus/annotations/entity_event.py b/rustplus/annotations/entity_event.py index adc5b74..84ba11d 100644 --- a/rustplus/annotations/entity_event.py +++ b/rustplus/annotations/entity_event.py @@ -5,7 +5,7 @@ from ..events import EntityEventPayload as EntityEventManager -def EntityEvent(server_id: ServerDetails, eid: int) -> Callable: +def EntityEvent(server_details: ServerDetails, eid: int) -> Callable: def wrapper(func) -> RegisteredListener: if isinstance(func, RegisteredListener): func = func.get_coro() @@ -14,7 +14,7 @@ def wrapper(func) -> RegisteredListener: str(eid), func, 1 ) # TODO, how are we going to handle the entity type? - EntityEventManager.HANDLER_LIST.register(listener, server_id) + EntityEventManager.HANDLER_LIST.register(listener, server_details) return listener diff --git a/rustplus/annotations/protobuf_event.py b/rustplus/annotations/protobuf_event.py index 2ccabbe..889aaee 100644 --- a/rustplus/annotations/protobuf_event.py +++ b/rustplus/annotations/protobuf_event.py @@ -5,14 +5,14 @@ from ..events import ProtobufEventPayload as ProtobufEventManager -def ProtobufEvent(server_id: ServerDetails) -> Callable: +def ProtobufEvent(server_details: ServerDetails) -> Callable: def wrapper(func) -> RegisteredListener: if isinstance(func, RegisteredListener): func = func.get_coro() listener = RegisteredListener(func.__name__, func) - ProtobufEventManager.HANDLER_LIST.register(listener, server_id) + ProtobufEventManager.HANDLER_LIST.register(listener, server_details) return listener diff --git a/rustplus/annotations/team_event.py b/rustplus/annotations/team_event.py index 4f9d9f8..c25801b 100644 --- a/rustplus/annotations/team_event.py +++ b/rustplus/annotations/team_event.py @@ -5,14 +5,14 @@ from ..events import TeamEventPayload as TeamEventManager -def TeamEvent(server_id: ServerDetails) -> Callable: +def TeamEvent(server_details: ServerDetails) -> Callable: def wrapper(func) -> RegisteredListener: if isinstance(func, RegisteredListener): func = func.get_coro() listener = RegisteredListener(func.__name__, func) - TeamEventManager.HANDLER_LIST.register(listener, server_id) + TeamEventManager.HANDLER_LIST.register(listener, server_details) return listener diff --git a/rustplus/commands/chat_command.py b/rustplus/commands/chat_command.py index cbee811..c7cbce7 100644 --- a/rustplus/commands/chat_command.py +++ b/rustplus/commands/chat_command.py @@ -14,7 +14,9 @@ class ChatCommandTime: class ChatCommand: - REGISTERED_COMMANDS: Dict[ServerDetails, Dict[str, ChatCommandData]] = defaultdict(dict) + REGISTERED_COMMANDS: Dict[ServerDetails, Dict[str, ChatCommandData]] = defaultdict( + dict + ) def __init__( self, diff --git a/rustplus/identification/__init__.py b/rustplus/identification/__init__.py index 3eac49e..e2b323b 100644 --- a/rustplus/identification/__init__.py +++ b/rustplus/identification/__init__.py @@ -1,2 +1,2 @@ from .registered_listener import RegisteredListener, RegisteredEntityListener -from .server_id import ServerDetails +from .server_details import ServerDetails diff --git a/rustplus/identification/server_id.py b/rustplus/identification/server_details.py similarity index 69% rename from rustplus/identification/server_id.py rename to rustplus/identification/server_details.py index 2bb9640..24e2679 100644 --- a/rustplus/identification/server_id.py +++ b/rustplus/identification/server_details.py @@ -1,16 +1,29 @@ +from typing import Union + + class ServerDetails: - def __init__(self, ip: str, port: str, player_id: int, player_token: int) -> None: + def __init__( + self, + ip: str, + port: Union[str, int, None], + player_id: int, + player_token: int, + secure: bool = False, + ) -> None: self.ip = ip self.port = port self.player_id = player_id self.player_token = player_token - - def __str__(self) -> str: - return f"{self.ip}:{self.port} {self.player_id} {self.player_token}" + self.secure = secure def get_server_string(self) -> str: + if self.port is None: + return f"{self.ip}" return f"{self.ip}:{self.port}" + def __str__(self) -> str: + return f"{self.ip}:{self.port} {self.player_id} {self.player_token}" + def __hash__(self): return hash(self.__str__()) diff --git a/rustplus/remote/handler_list.py b/rustplus/remote/handler_list.py index 95b1a47..703a232 100644 --- a/rustplus/remote/handler_list.py +++ b/rustplus/remote/handler_list.py @@ -7,49 +7,62 @@ class HandlerList: def __init__(self) -> None: self._handlers: Dict[ServerDetails, Set[RegisteredListener]] = defaultdict(set) - def unregister(self, listener: RegisteredListener, server_id: ServerDetails) -> None: - self._handlers[server_id].remove(listener) + def unregister( + self, listener: RegisteredListener, server_details: ServerDetails + ) -> None: + self._handlers[server_details].remove(listener) - def register(self, listener: RegisteredListener, server_id: ServerDetails) -> None: - self._handlers[server_id].add(listener) + def register( + self, listener: RegisteredListener, server_details: ServerDetails + ) -> None: + self._handlers[server_details].add(listener) - def has(self, listener: RegisteredListener, server_id: ServerDetails) -> bool: - return listener in self._handlers[server_id] + def has(self, listener: RegisteredListener, server_details: ServerDetails) -> bool: + return listener in self._handlers[server_details] def unregister_all(self) -> None: self._handlers.clear() - def get_handlers(self, server_id: ServerDetails) -> Set[RegisteredListener]: - return self._handlers.get(server_id, set()) + def get_handlers(self, server_details: ServerDetails) -> Set[RegisteredListener]: + return self._handlers.get(server_details, set()) class EntityHandlerList(HandlerList): def __init__(self) -> None: super().__init__() - self._handlers: Dict[ServerDetails, Dict[str, Set[RegisteredEntityListener]]] = ( - defaultdict(dict) - ) + self._handlers: Dict[ + ServerDetails, Dict[str, Set[RegisteredEntityListener]] + ] = defaultdict(dict) def unregister( - self, listener: RegisteredEntityListener, server_id: ServerDetails + self, listener: RegisteredEntityListener, server_details: ServerDetails ) -> None: - if listener.listener_id in self._handlers.get(server_id): - self._handlers.get(server_id).get(listener.listener_id).remove(listener) + if listener.listener_id in self._handlers.get(server_details): + self._handlers.get(server_details).get(listener.listener_id).remove( + listener + ) - def register(self, listener: RegisteredEntityListener, server_id: ServerDetails) -> None: - if server_id not in self._handlers: - self._handlers[server_id] = defaultdict(set) + def register( + self, listener: RegisteredEntityListener, server_details: ServerDetails + ) -> None: + if server_details not in self._handlers: + self._handlers[server_details] = defaultdict(set) - if listener.listener_id not in self._handlers.get(server_id): - self._handlers.get(server_id)[listener.listener_id] = set() + if listener.listener_id not in self._handlers.get(server_details): + self._handlers.get(server_details)[listener.listener_id] = set() - self._handlers.get(server_id).get(listener.listener_id).add(listener) + self._handlers.get(server_details).get(listener.listener_id).add(listener) - def has(self, listener: RegisteredEntityListener, server_id: ServerDetails) -> bool: - if server_id in self._handlers and listener.listener_id in self._handlers.get( - server_id + def has( + self, listener: RegisteredEntityListener, server_details: ServerDetails + ) -> bool: + if ( + server_details in self._handlers + and listener.listener_id in self._handlers.get(server_details) ): - return listener in self._handlers.get(server_id).get(listener.listener_id) + return listener in self._handlers.get(server_details).get( + listener.listener_id + ) return False @@ -57,6 +70,6 @@ def unregister_all(self) -> None: self._handlers.clear() def get_handlers( - self, server_id: ServerDetails + self, server_details: ServerDetails ) -> Dict[str, Set[RegisteredEntityListener]]: - return self._handlers.get(server_id, dict()) + return self._handlers.get(server_details, dict()) diff --git a/rustplus/remote/ratelimiter/__init__.py b/rustplus/remote/ratelimiter/__init__.py index d950106..f39ff86 100644 --- a/rustplus/remote/ratelimiter/__init__.py +++ b/rustplus/remote/ratelimiter/__init__.py @@ -54,21 +54,21 @@ def __init__(self) -> None: def add_socket( self, - server_id: ServerDetails, + server_details: ServerDetails, current: float, maximum: float, refresh_rate: float, refresh_amount: float, ) -> None: - self.socket_buckets[server_id] = TokenBucket( + self.socket_buckets[server_details] = TokenBucket( current, maximum, refresh_rate, refresh_amount ) - if server_id.get_server_string() not in self.server_buckets: - self.server_buckets[server_id.get_server_string()] = TokenBucket( + if server_details.get_server_string() not in self.server_buckets: + self.server_buckets[server_details.get_server_string()] = TokenBucket( self.SERVER_LIMIT, self.SERVER_LIMIT, 1, self.SERVER_REFRESH_AMOUNT ) - async def can_consume(self, server_id: ServerDetails, amount: int = 1) -> bool: + async def can_consume(self, server_details: ServerDetails, amount: int = 1) -> bool: """ Returns whether the user can consume the amount of tokens provided """ @@ -76,8 +76,8 @@ async def can_consume(self, server_id: ServerDetails, amount: int = 1) -> bool: can_consume = True for bucket in [ - self.socket_buckets.get(server_id), - self.server_buckets.get(server_id.get_server_string()), + self.socket_buckets.get(server_details), + self.server_buckets.get(server_details.get_server_string()), ]: bucket.refresh() if not bucket.can_consume(amount): @@ -85,14 +85,14 @@ async def can_consume(self, server_id: ServerDetails, amount: int = 1) -> bool: return can_consume - async def consume(self, server_id: ServerDetails, amount: int = 1) -> None: + async def consume(self, server_details: ServerDetails, amount: int = 1) -> None: """ Consumes an amount of tokens from the bucket. You should first check to see whether it is possible with can_consume """ async with self.lock: for bucket in [ - self.socket_buckets.get(server_id), - self.server_buckets.get(server_id.get_server_string()), + self.socket_buckets.get(server_details), + self.server_buckets.get(server_details.get_server_string()), ]: bucket.refresh() if not bucket.can_consume(amount): @@ -101,7 +101,7 @@ async def consume(self, server_id: ServerDetails, amount: int = 1) -> None: bucket.consume(amount) async def get_estimated_delay_time( - self, server_id: ServerDetails, target_cost: int + self, server_details: ServerDetails, target_cost: int ) -> float: """ Returns how long until the amount of tokens needed will be available @@ -109,8 +109,8 @@ async def get_estimated_delay_time( async with self.lock: delay = 0 for bucket in [ - self.socket_buckets.get(server_id), - self.server_buckets.get(server_id.get_server_string()), + self.socket_buckets.get(server_details), + self.server_buckets.get(server_details.get_server_string()), ]: val = ( math.ceil( @@ -126,9 +126,9 @@ async def get_estimated_delay_time( delay = val return delay - async def remove(self, server_id: ServerDetails) -> None: + async def remove(self, server_details: ServerDetails) -> None: """ Removes the limiter """ async with self.lock: - del self.socket_buckets[server_id] + del self.socket_buckets[server_details] diff --git a/rustplus/remote/rustplus_proto/rustplus.py b/rustplus/remote/rustplus_proto/rustplus.py index 83d882b..dc53467 100644 --- a/rustplus/remote/rustplus_proto/rustplus.py +++ b/rustplus/remote/rustplus_proto/rustplus.py @@ -426,7 +426,7 @@ class AppClanChat(betterproto.Message): @dataclass class AppNexusAuth(betterproto.Message): - server_id: str = betterproto.string_field(1) + server_details: str = betterproto.string_field(1) player_token: int = betterproto.int32_field(2) diff --git a/rustplus/remote/websocket/ws.py b/rustplus/remote/websocket/ws.py index 6fc5d21..d325d6d 100644 --- a/rustplus/remote/websocket/ws.py +++ b/rustplus/remote/websocket/ws.py @@ -28,9 +28,11 @@ class RustWebsocket: RESPONSE_TIMEOUT = 5 def __init__( - self, server_id: ServerDetails, command_options: Union[CommandOptions, None] + self, + server_details: ServerDetails, + command_options: Union[CommandOptions, None], ) -> None: - self.server_id: ServerDetails = server_id + self.server_details: ServerDetails = server_details self.command_options: Union[CommandOptions, None] = command_options self.connection: Union[WebSocketClientProtocol, None] = None self.logger: logging.Logger = logging.getLogger("rustplus.py") @@ -42,7 +44,10 @@ def __init__( async def connect(self) -> bool: - address = "ws://" + self.server_id.get_server_string() + address = ( + f"{'wss' if self.server_details.secure else 'ws'}://" + + self.server_details.get_server_string() + ) try: self.connection = await connect( @@ -69,7 +74,7 @@ async def run(self) -> None: data = await self.connection.recv() await self.run_coroutine_non_blocking( - self.run_proto_event(data, self.server_id) + self.run_proto_event(data, self.server_details) ) app_message = AppMessage() @@ -141,7 +146,9 @@ async def handle_message(self, app_message: AppMessage) -> None: parts = shlex.split(message.message) command = parts[0][len(prefix) :] - data = ChatCommand.REGISTERED_COMMANDS[self.server_id].get(command, None) + data = ChatCommand.REGISTERED_COMMANDS[self.server_details].get( + command, None + ) dao = ChatCommand( message.name, @@ -158,7 +165,7 @@ async def handle_message(self, app_message: AppMessage) -> None: await data.coroutine(dao) else: for command_name, data in ChatCommand.REGISTERED_COMMANDS[ - self.server_id + self.server_details ].items(): if command in data.aliases or data.callable_func(command): await data.coroutine(dao) @@ -170,9 +177,9 @@ async def handle_message(self, app_message: AppMessage) -> None: if self.debug: self.logger.info(f"Running Entity Event: {app_message}") - handlers = EntityEventPayload.HANDLER_LIST.get_handlers(self.server_id).get( - str(app_message.broadcast.entity_changed.entity_id), [] - ) + handlers = EntityEventPayload.HANDLER_LIST.get_handlers( + self.server_details + ).get(str(app_message.broadcast.entity_changed.entity_id), []) for handler in handlers: handler.get_coro()( EntityEventPayload( @@ -196,7 +203,7 @@ async def handle_message(self, app_message: AppMessage) -> None: self.logger.info(f"Running Team Event: {app_message}") # This means that the team of the current player has changed - handlers = TeamEventPayload.HANDLER_LIST.get_handlers(self.server_id) + handlers = TeamEventPayload.HANDLER_LIST.get_handlers(self.server_details) team_event = TeamEventPayload( app_message.broadcast.team_changed.player_id, RustTeamInfo(app_message.broadcast.team_changed.team_info), @@ -210,7 +217,7 @@ async def handle_message(self, app_message: AppMessage) -> None: if self.debug: self.logger.info(f"Running Chat Event: {app_message}") - handlers = ChatEventPayload.HANDLER_LIST.get_handlers(self.server_id) + handlers = ChatEventPayload.HANDLER_LIST.get_handlers(self.server_details) chat_event = ChatEventPayload( RustChatMessage(app_message.broadcast.team_message.message) ) @@ -237,9 +244,11 @@ def get_prefix(self, message: str) -> Optional[str]: return None @staticmethod - async def run_proto_event(data: Union[str, bytes], server_id: ServerDetails) -> None: + async def run_proto_event( + data: Union[str, bytes], server_details: ServerDetails + ) -> None: handlers: Set[RegisteredListener] = ( - ProtobufEventPayload.HANDLER_LIST.get_handlers(server_id) + ProtobufEventPayload.HANDLER_LIST.get_handlers(server_details) ) for handler in handlers: await handler.get_coro()(data) diff --git a/rustplus/rust_api.py b/rustplus/rust_api.py index f74c5ba..e91334f 100644 --- a/rustplus/rust_api.py +++ b/rustplus/rust_api.py @@ -36,11 +36,11 @@ class RustSocket: def __init__( self, - server_id: ServerDetails, + server_details: ServerDetails, ratelimiter: Union[None, RateLimiter] = None, command_options: Union[None, CommandOptions] = None, ) -> None: - self.server_id = server_id + self.server_details = server_details self.command_options = command_options self.logger = logging.getLogger("rustplus.py") @@ -52,7 +52,7 @@ def __init__( self.logger.addHandler(console_handler) self.logger.setLevel(logging.DEBUG) - self.ws = RustWebsocket(self.server_id, self.command_options) + self.ws = RustWebsocket(self.server_details, self.command_options) self.seq = 1 if ratelimiter: @@ -61,7 +61,7 @@ def __init__( self.ratelimiter = RateLimiter() self.ratelimiter.add_socket( - self.server_id, + self.server_details, RateLimiter.SERVER_LIMIT, RateLimiter.SERVER_LIMIT, 1, @@ -70,12 +70,14 @@ def __init__( async def _handle_ratelimit(self, tokens) -> None: while True: - if await self.ratelimiter.can_consume(self.server_id, tokens): - await self.ratelimiter.consume(self.server_id, tokens) + if await self.ratelimiter.can_consume(self.server_details, tokens): + await self.ratelimiter.consume(self.server_details, tokens) break await asyncio.sleep( - await self.ratelimiter.get_estimated_delay_time(self.server_id, tokens) + await self.ratelimiter.get_estimated_delay_time( + self.server_details, tokens + ) ) async def _generate_request(self, tokens=1) -> AppRequest: @@ -84,8 +86,8 @@ async def _generate_request(self, tokens=1) -> AppRequest: app_request = AppRequest() app_request.seq = self.seq self.seq += 1 - app_request.player_id = self.server_id.player_id - app_request.player_token = self.server_id.player_token + app_request.player_id = self.server_details.player_id + app_request.player_token = self.server_details.player_token return app_request