From d331bbde71174aab2a910c10f5617a24f3b7fb3a Mon Sep 17 00:00:00 2001 From: Micah Lee Date: Fri, 27 Sep 2024 11:10:31 -0700 Subject: [PATCH] Fix Tier constructor, and update code to handle optional stripe_product_id and stripe_price_id --- hushline/model.py | 12 +++++++++--- hushline/premium.py | 16 ++++++++++++++-- 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/hushline/model.py b/hushline/model.py index 9d73849d..faf6a639 100644 --- a/hushline/model.py +++ b/hushline/model.py @@ -354,7 +354,9 @@ class Tier(Model): stripe_price_id: Mapped[Optional[str]] = mapped_column(db.String(255), unique=True) def __init__(self, name: str, monthly_amount: int) -> None: - super().__init__(name=name, monthly_amount=monthly_amount) + super().__init__() + self.name = name + self.monthly_amount = monthly_amount class StripeEvent(Model): @@ -410,7 +412,9 @@ def __init__(self, invoice: Invoice): self.created_at = datetime.fromtimestamp(invoice.created, tz=timezone.utc) # Look up the user by their customer ID - user = db.session.scalars(db.select(User).filter_by(stripe_customer_id=invoice.customer)).one_or_none() + user = db.session.scalars( + db.select(User).filter_by(stripe_customer_id=invoice.customer) + ).one_or_none() if user: self.user_id = user.id else: @@ -420,7 +424,9 @@ def __init__(self, invoice: Invoice): if invoice.lines.data[0].plan: product_id = invoice.lines.data[0].plan.product - tier = db.session.scalars(db.select(Tier).filter_by(stripe_product_id=product_id)).one_or_none() + tier = db.session.scalars( + db.select(Tier).filter_by(stripe_product_id=product_id) + ).one_or_none() if tier: self.tier_id = tier.id else: diff --git a/hushline/premium.py b/hushline/premium.py index bfe764b0..049ad126 100644 --- a/hushline/premium.py +++ b/hushline/premium.py @@ -130,6 +130,10 @@ def create_products_and_prices() -> None: def update_price(tier: Tier) -> None: current_app.logger.info(f"Updating price for tier {tier.name} to {tier.monthly_amount}") + if not tier.stripe_product_id: + current_app.logger.error(f"Tier {tier.name} does not have a product ID") + return + # See if we already have an appropriate price for this product prices = stripe.Price.search(query=f'product:"{tier.stripe_product_id}"') found_price_id = None @@ -202,7 +206,9 @@ def get_business_price_string() -> str: def handle_subscription_created(subscription: stripe.Subscription) -> None: # customer.subscription.created - user = db.session.scalars(db.select(User).filter_by(stripe_customer_id=subscription.customer)).one_or_none() + user = db.session.scalars( + db.select(User).filter_by(stripe_customer_id=subscription.customer) + ).one_or_none() if user: user.stripe_subscription_id = subscription.id user.stripe_subscription_status = StripeSubscriptionStatusEnum(subscription.status) @@ -222,7 +228,9 @@ def handle_subscription_updated(subscription: stripe.Subscription) -> None: # customer.subscription.updated # If subscription changes to cancel or unpaid, downgrade user - user = db.session.scalars(db.select(User).filter_by(stripe_subscription_id=subscription.id)).one_or_none() + user = db.session.scalars( + db.select(User).filter_by(stripe_subscription_id=subscription.id) + ).one_or_none() if user: user.stripe_subscription_status = StripeSubscriptionStatusEnum(subscription.status) user.stripe_subscription_cancel_at_period_end = subscription.cancel_at_period_end @@ -430,6 +438,10 @@ def upgrade() -> Response | str: current_app.logger.error("Could not find business tier") flash("⚠️ Something went wrong!") return redirect(url_for("premium.index")) + if not business_tier.stripe_price_id: + current_app.logger.error("Business tier does not have a price ID") + flash("⚠️ Something went wrong!") + return redirect(url_for("premium.index")) # Make sure the user has a Stripe customer try: