Skip to content

Commit

Permalink
Fix Tier constructor, and update code to handle optional stripe_produ…
Browse files Browse the repository at this point in the history
…ct_id and stripe_price_id
  • Loading branch information
micahflee committed Sep 27, 2024
1 parent e36387c commit d331bbd
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 5 deletions.
12 changes: 9 additions & 3 deletions hushline/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,9 @@ class Tier(Model):
stripe_price_id: Mapped[Optional[str]] = mapped_column(db.String(255), unique=True)

def __init__(self, name: str, monthly_amount: int) -> None:
super().__init__(name=name, monthly_amount=monthly_amount)
super().__init__()
self.name = name
self.monthly_amount = monthly_amount


class StripeEvent(Model):
Expand Down Expand Up @@ -410,7 +412,9 @@ def __init__(self, invoice: Invoice):
self.created_at = datetime.fromtimestamp(invoice.created, tz=timezone.utc)

# Look up the user by their customer ID
user = db.session.scalars(db.select(User).filter_by(stripe_customer_id=invoice.customer)).one_or_none()
user = db.session.scalars(
db.select(User).filter_by(stripe_customer_id=invoice.customer)
).one_or_none()
if user:
self.user_id = user.id
else:
Expand All @@ -420,7 +424,9 @@ def __init__(self, invoice: Invoice):
if invoice.lines.data[0].plan:
product_id = invoice.lines.data[0].plan.product

tier = db.session.scalars(db.select(Tier).filter_by(stripe_product_id=product_id)).one_or_none()
tier = db.session.scalars(
db.select(Tier).filter_by(stripe_product_id=product_id)
).one_or_none()
if tier:
self.tier_id = tier.id
else:
Expand Down
16 changes: 14 additions & 2 deletions hushline/premium.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,10 @@ def create_products_and_prices() -> None:
def update_price(tier: Tier) -> None:
current_app.logger.info(f"Updating price for tier {tier.name} to {tier.monthly_amount}")

if not tier.stripe_product_id:
current_app.logger.error(f"Tier {tier.name} does not have a product ID")
return

# See if we already have an appropriate price for this product
prices = stripe.Price.search(query=f'product:"{tier.stripe_product_id}"')
found_price_id = None
Expand Down Expand Up @@ -202,7 +206,9 @@ def get_business_price_string() -> str:
def handle_subscription_created(subscription: stripe.Subscription) -> None:
# customer.subscription.created

user = db.session.scalars(db.select(User).filter_by(stripe_customer_id=subscription.customer)).one_or_none()
user = db.session.scalars(
db.select(User).filter_by(stripe_customer_id=subscription.customer)
).one_or_none()
if user:
user.stripe_subscription_id = subscription.id
user.stripe_subscription_status = StripeSubscriptionStatusEnum(subscription.status)
Expand All @@ -222,7 +228,9 @@ def handle_subscription_updated(subscription: stripe.Subscription) -> None:
# customer.subscription.updated

# If subscription changes to cancel or unpaid, downgrade user
user = db.session.scalars(db.select(User).filter_by(stripe_subscription_id=subscription.id)).one_or_none()
user = db.session.scalars(
db.select(User).filter_by(stripe_subscription_id=subscription.id)
).one_or_none()
if user:
user.stripe_subscription_status = StripeSubscriptionStatusEnum(subscription.status)
user.stripe_subscription_cancel_at_period_end = subscription.cancel_at_period_end
Expand Down Expand Up @@ -430,6 +438,10 @@ def upgrade() -> Response | str:
current_app.logger.error("Could not find business tier")
flash("⚠️ Something went wrong!")
return redirect(url_for("premium.index"))
if not business_tier.stripe_price_id:
current_app.logger.error("Business tier does not have a price ID")
flash("⚠️ Something went wrong!")
return redirect(url_for("premium.index"))

# Make sure the user has a Stripe customer
try:
Expand Down

0 comments on commit d331bbd

Please sign in to comment.