From 37f3c778e8dfc5db3427d49b0b258fd97839c347 Mon Sep 17 00:00:00 2001 From: Pum <28806828+PumPum7@users.noreply.github.com> Date: Mon, 5 Aug 2024 18:44:06 +0200 Subject: [PATCH] update subscription type to also use the subscription type data class --- tatsu/data_structures.py | 30 ++++++++++++++++++++++++++++++ tatsu/wrapper.py | 7 ++++++- 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/tatsu/data_structures.py b/tatsu/data_structures.py index 5dd8ec3..55a43fe 100644 --- a/tatsu/data_structures.py +++ b/tatsu/data_structures.py @@ -76,6 +76,36 @@ def to_dict(self): } +class SubscriptionType: + SUBSCRIPTION_MAP = { + 0: "None", + 1: "Supporter", + 2: "Supporter+", + 3: "Supporter++", + } + + def __init__(self, subscription_type: int): + self.subscription_type: int = subscription_type + + def __str__(self): + subscription_str = self.SUBSCRIPTION_MAP.get(self.subscription_type, "Unknown") + return f"SubscriptionType(subscription_type={subscription_str})" + + def __repr__(self): + subscription_str = self.SUBSCRIPTION_MAP.get(self.subscription_type, "Unknown") + return f"SubscriptionType(subscription_type={subscription_str!r})" + + def __eq__(self, other): + if isinstance(other, SubscriptionType): + return self.subscription_type == other.subscription_type + return False + + def to_dict(self): + return { + "subscription_type": self.subscription_type, + } + + class GuildRankings: def __init__(self, guild_id: int, rankings: list, original: dict): self.guild_id: int = int(guild_id) if guild_id else guild_id diff --git a/tatsu/wrapper.py b/tatsu/wrapper.py index bad7b6f..9ac4e0b 100644 --- a/tatsu/wrapper.py +++ b/tatsu/wrapper.py @@ -39,6 +39,11 @@ async def get_profile(self, user_id: int) -> ds.UserProfile: if subscription_renewal_str else None ) + + # Map the subscription type to the subscription type object + subscription_type = result.get("subscription_type", 0) + subscription_type = ds.SubscriptionType(subscription_type) + user_profile_data = { "avatar_hash": result.get("avatar_hash", None), "avatar_url": result.get("avatar_url", None), @@ -47,7 +52,7 @@ async def get_profile(self, user_id: int) -> ds.UserProfile: "user_id": result.get("id", None), "info_box": result.get("info_box", None), "reputation": result.get("reputation", None), - "subscription_type": result.get("subscription_type", None), + "subscription_type": subscription_type, "subscription_renewal": subscription_renewal, "title": result.get("title", None), "tokens": result.get("tokens", None),