diff --git a/hushline/premium.py b/hushline/premium.py index 049ad126..12f13a0b 100644 --- a/hushline/premium.py +++ b/hushline/premium.py @@ -44,7 +44,7 @@ def create_products_and_prices() -> None: current_app.logger.info("Creating products and prices") # Make sure the products and prices are created in Stripe - business_tier = db.session.query(Tier).get(BUSINESS_TIER) + business_tier = db.session.get(Tier, BUSINESS_TIER) if not business_tier: current_app.logger.error("Could not find business tier") return @@ -104,7 +104,8 @@ def create_products_and_prices() -> None: stripe_price = stripe.Price.retrieve(str(stripe_product.default_price)) current_app.logger.info(f"Found Stripe price for tier: {business_tier.name}") business_tier.stripe_price_id = stripe_price.id - business_tier.monthly_amount = stripe_price.unit_amount + if stripe_price.unit_amount: + business_tier.monthly_amount = stripe_price.unit_amount db.session.add(business_tier) db.session.commit() found = True @@ -189,7 +190,7 @@ def get_subscription(user: User) -> stripe.Subscription | None: def get_business_price_string() -> str: - business_tier = db.session.query(Tier).get(BUSINESS_TIER) + business_tier = db.session.get(Tier, BUSINESS_TIER) if not business_tier: current_app.logger.error("Could not find business tier") return "NA" @@ -255,7 +256,9 @@ def handle_subscription_updated(subscription: stripe.Subscription) -> None: def handle_subscription_deleted(subscription: stripe.Subscription) -> None: # customer.subscription.deleted - user = db.session.query(User).filter_by(stripe_subscription_id=subscription.id).first() + user = db.session.scalars( + db.select(User).filter_by(stripe_subscription_id=subscription.id) + ).one_or_none() if user: user.tier_id = FREE_TIER user.stripe_subscription_id = None @@ -282,7 +285,9 @@ def handle_invoice_created(invoice: stripe.Invoice) -> None: def handle_invoice_updated(invoice: stripe.Invoice) -> None: # invoice.updated - stripe_invoice = db.session.query(StripeInvoice).filter_by(invoice_id=invoice.id).first() + stripe_invoice = db.session.scalars( + db.select(StripeInvoice).filter_by(invoice_id=invoice.id) + ).one_or_none() if stripe_invoice: stripe_invoice.total = invoice.total stripe_invoice.status = StripeInvoiceStatusEnum(invoice.status)