|
""" |
|
Stripe Billing API implementation for Suna on top of Basejump. ONLY SUPPORTS USER (PERSONAL) ACCOUNTS – no team accounts – because the user_id is used as the account_id. In personal accounts the account_id equals the user_id; in team accounts the account_id is a separate, unique ID.

To forward Stripe webhooks to a local server during development, run:

stripe listen --forward-to localhost:8000/api/billing/webhook
|
""" |
|
|
|
from fastapi import APIRouter, HTTPException, Depends, Request |
|
from typing import Optional, Dict, Any, List, Tuple |
|
import stripe |
|
from datetime import datetime, timezone |
|
from utils.logger import logger |
|
from utils.config import config, EnvMode |
|
from services.supabase import DBConnection |
|
from utils.auth_utils import get_current_user_id_from_jwt |
|
from pydantic import BaseModel, Field |
|
|
|
|
|
# Every Stripe SDK call made by this module authenticates with this key.
stripe.api_key = config.STRIPE_SECRET_KEY

# All billing routes are mounted under /billing.
router = APIRouter(prefix="/billing", tags=["billing"])

# Stripe price ID -> internal tier name and monthly allowance of agent-run
# minutes (the allowance is enforced in check_billing_status). The keys come
# from config; should two config values coincide, the later entry wins.
SUBSCRIPTION_TIERS = {
    config.STRIPE_FREE_TIER_ID: {'name': 'free', 'minutes': 60},
    config.STRIPE_TIER_2_20_ID: {'name': 'tier_2_20', 'minutes': 120},
    config.STRIPE_TIER_6_50_ID: {'name': 'tier_6_50', 'minutes': 360},
    config.STRIPE_TIER_12_100_ID: {'name': 'tier_12_100', 'minutes': 720},
    config.STRIPE_TIER_25_200_ID: {'name': 'tier_25_200', 'minutes': 1500},
    config.STRIPE_TIER_50_400_ID: {'name': 'tier_50_400', 'minutes': 3000},
    config.STRIPE_TIER_125_800_ID: {'name': 'tier_125_800', 'minutes': 7500},
    config.STRIPE_TIER_200_1000_ID: {'name': 'tier_200_1000', 'minutes': 12000},
}
|
|
|
|
|
class CreateCheckoutSessionRequest(BaseModel):
    """Request body for POST /billing/create-checkout-session."""

    # Stripe price the user wants to subscribe to.
    price_id: str
    # Redirect target passed to Stripe Checkout as success_url.
    success_url: str
    # Redirect target passed to Stripe Checkout as cancel_url.
    cancel_url: str
|
|
|
class CreatePortalSessionRequest(BaseModel):
    """Request body for POST /billing/create-portal-session."""

    # URL the Stripe Customer Portal sends the user back to.
    return_url: str
|
|
|
class SubscriptionStatus(BaseModel):
    """Subscription summary returned by GET /billing/subscription.

    Describes the user's current plan, billing period and monthly usage,
    plus any plan change scheduled for the end of the current period.
    """

    # Stripe subscription status, or "no_subscription" / "scheduled_downgrade".
    status: str
    plan_name: Optional[str] = None
    price_id: Optional[str] = None
    current_period_end: Optional[datetime] = None
    cancel_at_period_end: bool = False
    trial_end: Optional[datetime] = None
    # Monthly allowance of agent-run minutes and minutes used so far.
    minutes_limit: Optional[int] = None
    current_usage: Optional[float] = None

    # Details of a scheduled plan change, if the subscription has one.
    has_schedule: bool = False
    scheduled_plan_name: Optional[str] = None
    scheduled_price_id: Optional[str] = None
    scheduled_change_date: Optional[datetime] = None
|
|
|
|
|
async def get_stripe_customer_id(client, user_id: str) -> Optional[str]:
    """Look up the Stripe customer ID recorded for ``user_id``.

    Personal accounts use the user ID as the Basejump account ID, so the
    ``basejump.billing_customers`` lookup is keyed on ``account_id``.

    Returns:
        The ``id`` column of the first matching row, or ``None`` when the
        user has no billing customer yet.
    """
    query = (
        client.schema('basejump')
        .from_('billing_customers')
        .select('id')
        .eq('account_id', user_id)
    )
    result = await query.execute()
    rows = result.data or []
    return rows[0]['id'] if rows else None
|
|
|
async def create_stripe_customer(client, user_id: str, email: str) -> str:
    """Create a Stripe customer for ``user_id`` and record it in Basejump.

    The Stripe customer is tagged with ``metadata.user_id``; a matching
    ``basejump.billing_customers`` row is then inserted with
    ``account_id == user_id`` (personal accounts only).

    NOTE(review): the Stripe call is synchronous, and if the insert fails the
    freshly created Stripe customer is left without a billing_customers row.

    Returns:
        The new Stripe customer ID.
    """
    new_customer = stripe.Customer.create(email=email, metadata={"user_id": user_id})

    customer_row = {
        'id': new_customer.id,
        'account_id': user_id,
        'email': email,
        'provider': 'stripe',
    }
    await client.schema('basejump').from_('billing_customers').insert(customer_row).execute()

    return new_customer.id
|
|
|
async def get_user_subscription(user_id: str) -> Optional[Dict]:
    """Return the user's active Suna subscription from Stripe, if any.

    Only subscriptions Stripe lists as ``active`` whose first item is priced
    with one of the configured tiers (the keys of ``SUBSCRIPTION_TIERS``) are
    considered. If several match, the most recently created one is returned
    and every other one not already set to lapse is flagged
    ``cancel_at_period_end`` so the user is not billed twice.

    NOTE(review): despite its name this getter has a side effect (the
    cancellation above), and the Stripe SDK calls are synchronous.

    Returns:
        The Stripe subscription object, or ``None`` when the user has no
        billing customer, no matching active subscription, or an error
        occurs (errors are logged, not raised).
    """
    try:
        db = DBConnection()
        client = await db.client
        customer_id = await get_stripe_customer_id(client, user_id)
        if not customer_id:
            return None

        subscriptions = stripe.Subscription.list(
            customer=customer_id,
            status='active'
        )
        if not subscriptions or not subscriptions.get('data'):
            return None

        our_subscriptions = [sub for sub in subscriptions['data'] if _is_suna_subscription(sub)]
        if not our_subscriptions:
            return None

        if len(our_subscriptions) == 1:
            return our_subscriptions[0]

        logger.warning(f"User {user_id} has multiple active subscriptions: {[sub['id'] for sub in our_subscriptions]}")
        most_recent = max(our_subscriptions, key=lambda sub: sub['created'])

        for sub in our_subscriptions:
            # Keep the newest subscription; skip ones already set to lapse so
            # repeated calls do not re-issue the same Stripe update.
            if sub['id'] == most_recent['id'] or sub.get('cancel_at_period_end'):
                continue
            try:
                stripe.Subscription.modify(
                    sub['id'],
                    cancel_at_period_end=True
                )
                logger.info(f"Cancelled subscription {sub['id']} for user {user_id}")
            except Exception as e:
                logger.error(f"Error cancelling subscription {sub['id']}: {str(e)}")

        return most_recent

    except Exception as e:
        logger.error(f"Error getting subscription from Stripe: {str(e)}")
        return None


def _is_suna_subscription(sub) -> bool:
    """Return True if the subscription's first item uses a price from SUBSCRIPTION_TIERS."""
    items = sub.get('items')
    item_data = items.get('data') if items else None
    if not item_data:
        return False
    price = item_data[0].get('price')
    return bool(price) and price.get('id') in SUBSCRIPTION_TIERS
|
|
|
async def calculate_monthly_usage(client, user_id: str) -> float:
    """Return the agent-run minutes the user has used this calendar month (UTC).

    Adds up the duration of every agent run that started at or after 00:00 UTC
    on the first day of the current month, across all threads whose
    ``account_id`` is the user. A run without ``completed_at`` is counted up
    to the current time.

    Returns:
        Total minutes as a float; 0.0 when the user has no threads or no runs.
    """
    now = datetime.now(timezone.utc)
    month_start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)

    threads = await client.table('threads').select('thread_id').eq('account_id', user_id).execute()
    if not threads.data:
        return 0.0
    thread_ids = [row['thread_id'] for row in threads.data]

    runs = await (
        client.table('agent_runs')
        .select('started_at, completed_at')
        .in_('thread_id', thread_ids)
        .gte('started_at', month_start.isoformat())
        .execute()
    )
    if not runs.data:
        return 0.0

    def _to_ts(value: str) -> float:
        # datetime.fromisoformat only accepts a trailing 'Z' from Python 3.11 on.
        return datetime.fromisoformat(value.replace('Z', '+00:00')).timestamp()

    now_ts = now.timestamp()
    total_seconds = sum(
        (_to_ts(run['completed_at']) if run['completed_at'] else now_ts) - _to_ts(run['started_at'])
        for run in runs.data
    )
    return total_seconds / 60
|
|
|
async def check_billing_status(client, user_id: str) -> Tuple[bool, str, Optional[Dict]]:
    """Decide whether ``user_id`` is allowed to start an agent run.

    Billing is bypassed entirely when ``config.ENV_MODE`` is LOCAL. Otherwise
    the plan comes from the user's active Stripe subscription. Users without
    one, and unknown prices, are treated as free tier. The run is refused once
    this month's usage reaches the plan's minute allowance.

    Returns:
        ``(can_run, message, subscription_info)``. ``subscription_info`` is the
        Stripe subscription, or a small placeholder dict for free-tier and
        local-mode users.
    """
    if config.ENV_MODE == EnvMode.LOCAL:
        logger.info("Running in local development mode - billing checks are disabled")
        local_info = {
            "price_id": "local_dev",
            "plan_name": "Local Development",
            "minutes_limit": "no limit",
        }
        return True, "Local development mode - billing disabled", local_info

    # Users without a paid subscription fall back to the free tier.
    subscription = await get_user_subscription(user_id) or {
        'price_id': config.STRIPE_FREE_TIER_ID,
        'plan_name': 'free',
    }

    # A Stripe subscription carries its price on the first item; the free-tier
    # placeholder carries it directly under 'price_id'.
    items = subscription.get('items')
    item_data = items.get('data') if items else None
    if item_data:
        price_id = item_data[0]['price']['id']
    else:
        price_id = subscription.get('price_id', config.STRIPE_FREE_TIER_ID)

    tier_info = SUBSCRIPTION_TIERS.get(price_id)
    if not tier_info:
        logger.warning(f"Unknown subscription tier: {price_id}, defaulting to free tier")
        tier_info = SUBSCRIPTION_TIERS[config.STRIPE_FREE_TIER_ID]

    minute_limit = tier_info['minutes']
    current_usage = await calculate_monthly_usage(client, user_id)
    if current_usage >= minute_limit:
        limit_message = f"Monthly limit of {minute_limit} minutes reached. Please upgrade your plan or wait until next month."
        return False, limit_message, subscription

    return True, "OK", subscription
|
|
|
|
|
@router.post("/create-checkout-session")
async def create_checkout_session(
    request: CreateCheckoutSessionRequest,
    current_user_id: str = Depends(get_current_user_id_from_jwt)
):
    """Create a Stripe Checkout session or modify an existing subscription.

    * No active subscription: a Checkout session for ``request.price_id`` is
      created and its ``session_id``/``url`` returned (``status == "new"``).
    * Already on ``request.price_id``: nothing changes (``status == "no_change"``).
    * Higher ``unit_amount``: the subscription is upgraded immediately with
      ``always_invoice`` proration (``status == "updated"``).
    * Lower or equal ``unit_amount``: a subscription schedule switches the plan
      at the end of the current billing period (``status == "scheduled"``).

    Raises:
        HTTPException: 404 if the auth user is not found, 400 for an unknown
            price or a price of another product, 500 for any other failure.
    """
    try:
        db = DBConnection()
        client = await db.client

        user_result = await client.auth.admin.get_user_by_id(current_user_id)
        if not user_result:
            raise HTTPException(status_code=404, detail="User not found")
        email = user_result.user.email

        customer_id = await get_stripe_customer_id(client, current_user_id)
        if not customer_id:
            customer_id = await create_stripe_customer(client, current_user_id, email)

        try:
            price = stripe.Price.retrieve(request.price_id, expand=['product'])
            product_id = price['product']['id']
        except stripe.error.InvalidRequestError:
            raise HTTPException(status_code=400, detail=f"Invalid price ID: {request.price_id}")

        # Only prices of the configured product may be purchased through this endpoint.
        if product_id != config.STRIPE_PRODUCT_ID:
            raise HTTPException(status_code=400, detail="Price ID does not belong to the correct product.")

        existing_subscription = await get_user_subscription(current_user_id)
        if existing_subscription:
            return await _change_existing_subscription(
                client, customer_id, existing_subscription, price, request.price_id
            )

        session = stripe.checkout.Session.create(
            customer=customer_id,
            payment_method_types=['card'],
            line_items=[{'price': request.price_id, 'quantity': 1}],
            mode='subscription',
            success_url=request.success_url,
            cancel_url=request.cancel_url,
            metadata={
                'user_id': current_user_id,
                'product_id': product_id
            }
        )

        # NOTE(review): the customer is flagged active as soon as a Checkout
        # session exists, i.e. before payment completes — confirm this is intended.
        await client.schema('basejump').from_('billing_customers').update(
            {'active': True}
        ).eq('id', customer_id).execute()
        logger.info(f"Updated customer {customer_id} active status to TRUE after creating checkout session")

        return {"session_id": session['id'], "url": session['url'], "status": "new"}

    except HTTPException:
        # Deliberate HTTP errors (404/400 above, 500s from the helpers) must reach
        # the client unchanged instead of being re-wrapped as a generic 500.
        raise
    except Exception as e:
        logger.exception(f"Error creating checkout session: {str(e)}")
        # Stripe errors carry a user-readable message in json_body.
        if hasattr(e, 'json_body') and e.json_body and 'error' in e.json_body:
            error_detail = e.json_body['error'].get('message', str(e))
        else:
            error_detail = str(e)
        raise HTTPException(status_code=500, detail=f"Error creating checkout session: {error_detail}")


def _price_to_dollars(price) -> float:
    """Convert a Stripe price's ``unit_amount`` (minor units) to a 2-dp amount; 0 when unset."""
    amount = price.get('unit_amount')
    return round(amount / 100, 2) if amount else 0


def _phase_item_price_id(price_data) -> Optional[str]:
    """Return the price ID of a schedule-phase item whose ``price`` is an ID or an expanded object."""
    if isinstance(price_data, str):
        return price_data
    if price_data:
        return price_data.get('id')
    return None


def _find_current_phase(phases):
    """Return the latest schedule phase that has already started, else the last phase.

    Raises:
        ValueError: if the schedule has no phases.
    """
    if not phases:
        raise ValueError("Subscription schedule has no phases.")
    now_ts = datetime.now(timezone.utc).timestamp()
    for phase in reversed(phases):
        if phase['start_date'] <= now_ts:
            return phase
    return phases[-1]


async def _change_existing_subscription(client, customer_id: str, existing_subscription, price, new_price_id: str) -> Dict:
    """Move an existing subscription to ``new_price_id``: no-op, immediate upgrade, or scheduled downgrade.

    Raises:
        HTTPException: 500 if any Stripe or database operation fails.
    """
    try:
        subscription_id = existing_subscription['id']
        subscription_item = existing_subscription['items']['data'][0]
        current_price_id = subscription_item['price']['id']

        if current_price_id == new_price_id:
            return {
                "subscription_id": subscription_id,
                "status": "no_change",
                "message": "Already subscribed to this plan.",
                "details": {
                    "is_upgrade": None,
                    "effective_date": None,
                    "current_price": _price_to_dollars(price),
                    "new_price": _price_to_dollars(price),
                }
            }

        current_price = stripe.Price.retrieve(current_price_id)
        # A missing unit_amount is treated as 0 rather than failing the comparison with TypeError.
        is_upgrade = (price.get('unit_amount') or 0) > (current_price.get('unit_amount') or 0)

        if is_upgrade:
            return await _upgrade_subscription(
                client, customer_id, subscription_id, subscription_item, current_price, price, new_price_id
            )
        return _schedule_downgrade(existing_subscription, subscription_item, current_price, price, new_price_id)

    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"Error updating subscription {existing_subscription.get('id') if existing_subscription else 'N/A'}: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error updating subscription: {str(e)}")


async def _upgrade_subscription(client, customer_id: str, subscription_id: str, subscription_item,
                                current_price, new_price, new_price_id: str) -> Dict:
    """Switch to ``new_price_id`` now, invoicing the proration and restarting the billing cycle."""
    updated_subscription = stripe.Subscription.modify(
        subscription_id,
        items=[{
            'id': subscription_item['id'],
            'price': new_price_id,
        }],
        proration_behavior='always_invoice',
        billing_cycle_anchor='now'
    )

    await client.schema('basejump').from_('billing_customers').update(
        {'active': True}
    ).eq('id', customer_id).execute()
    logger.info(f"Updated customer {customer_id} active status to TRUE after subscription upgrade")

    latest_invoice = None
    if updated_subscription.get('latest_invoice'):
        latest_invoice = stripe.Invoice.retrieve(updated_subscription['latest_invoice'])

    return {
        "subscription_id": updated_subscription['id'],
        "status": "updated",
        "message": "Subscription upgraded successfully",
        "details": {
            "is_upgrade": True,
            "effective_date": "immediate",
            "current_price": _price_to_dollars(current_price),
            "new_price": _price_to_dollars(new_price),
            "invoice": {
                "id": latest_invoice['id'],
                "status": latest_invoice['status'],
                "amount_due": round(latest_invoice['amount_due'] / 100, 2),
                "amount_paid": round(latest_invoice['amount_paid'] / 100, 2)
            } if latest_invoice else None
        }
    }


def _schedule_downgrade(existing_subscription, subscription_item, current_price, new_price, new_price_id: str) -> Dict:
    """Schedule the switch to ``new_price_id`` for the end of the current billing period.

    Raises:
        HTTPException: 500 if the subscription schedule cannot be created or updated.
    """
    subscription_id = existing_subscription['id']
    try:
        # Newer Stripe API versions report the billing period on subscription
        # items; fall back to the subscription object for older API versions.
        current_period_end_ts = (
            subscription_item.get('current_period_end')
            or existing_subscription.get('current_period_end')
        )
        if not current_period_end_ts:
            raise ValueError("Could not determine the end of the current billing period.")

        # Re-fetch the subscription to read the schedule currently attached to it, if any.
        schedule_id = stripe.Subscription.retrieve(subscription_id).get('schedule')

        if schedule_id:
            logger.info(f"Updating existing schedule {schedule_id} for subscription {subscription_id}")
            schedule = stripe.SubscriptionSchedule.retrieve(schedule_id)
        else:
            # Stripe does not accept phases/end_behavior together with
            # from_subscription, so first attach a schedule that mirrors the
            # subscription, then rewrite its phases with the modify call below.
            logger.info(f"Creating new schedule for subscription {subscription_id}")
            schedule = stripe.SubscriptionSchedule.create(from_subscription=subscription_id)
            logger.info(f"Created new schedule {schedule['id']} from subscription {subscription_id}")

        current_phase = _find_current_phase(schedule.get('phases') or [])

        current_phase_items = []
        for item in current_phase.get('items', []):
            price_id = _phase_item_price_id(item.get('price'))
            quantity = item.get('quantity')
            if price_id and quantity is not None:
                current_phase_items.append({'price': price_id, 'quantity': quantity})
            else:
                logger.warning(f"Skipping item in current phase due to missing price ID or quantity: {item}")

        if not current_phase_items:
            raise ValueError("Could not determine valid items for the current phase.")

        phases = [
            {
                # Keep the current plan (start_date unchanged) until the period ends.
                'items': current_phase_items,
                'start_date': current_phase['start_date'],
                'end_date': current_period_end_ts,
                'proration_behavior': 'none'
            },
            {
                'items': [{'price': new_price_id, 'quantity': 1}],
                'start_date': current_period_end_ts,
                'proration_behavior': 'none'
            },
        ]
        logger.debug(f"Schedule phases for {schedule['id']}: {phases}")

        updated_schedule = stripe.SubscriptionSchedule.modify(
            schedule['id'],
            phases=phases,
            # When the final phase ends, release the subscription from the schedule.
            end_behavior='release'
        )
        logger.info(f"Successfully updated schedule {updated_schedule['id']}")

        return {
            "subscription_id": subscription_id,
            "schedule_id": updated_schedule['id'],
            "status": "scheduled",
            "message": "Subscription downgrade scheduled",
            "details": {
                "is_upgrade": False,
                "effective_date": "end_of_period",
                "current_price": _price_to_dollars(current_price),
                "new_price": _price_to_dollars(new_price),
                "effective_at": datetime.fromtimestamp(current_period_end_ts, tz=timezone.utc).isoformat()
            }
        }
    except Exception as e:
        logger.exception(f"Error handling subscription schedule for sub {subscription_id}: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error handling subscription schedule: {str(e)}")
|
|
|
@router.post("/create-portal-session")
async def create_portal_session(
    request: CreatePortalSessionRequest,
    current_user_id: str = Depends(get_current_user_id_from_jwt)
):
    """Create a Stripe Customer Portal session for subscription management.

    Prefers a portal configuration with ``subscription_update`` enabled. If none
    exists, the first listed configuration is updated to enable it, or a new one
    is created. Any error while doing so is logged, and the session then uses
    Stripe's default portal configuration.

    Returns:
        ``{"url": <portal session URL>}``.

    Raises:
        HTTPException: 404 if the user has no billing customer, 500 on any
            other failure.
    """
    try:
        db = DBConnection()
        client = await db.client

        customer_id = await get_stripe_customer_id(client, current_user_id)
        if not customer_id:
            raise HTTPException(status_code=404, detail="No billing customer found")

        # Initialised outside the try so a failure while listing configurations
        # falls back to the default configuration instead of an UnboundLocalError.
        active_config = None
        try:
            active_config = _get_portal_configuration_with_updates()
            updates_enabled = active_config.get('features', {}).get('subscription_update', {}).get('enabled', False)
            logger.info(f"Using portal configuration: {active_config['id']} with subscription_update: {updates_enabled}")
        except Exception as config_error:
            logger.warning(f"Error configuring portal: {config_error}. Continuing with default configuration.")

        portal_params = {
            "customer": customer_id,
            "return_url": request.return_url
        }
        if active_config:
            portal_params["configuration"] = active_config['id']

        session = stripe.billing_portal.Session.create(**portal_params)
        return {"url": session.url}

    except HTTPException:
        # Keep the intended status code (e.g. the 404 above) instead of turning it into a 500.
        raise
    except Exception as e:
        logger.error(f"Error creating portal session: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


def _get_portal_configuration_with_updates():
    """Return a portal configuration with ``subscription_update`` enabled, enabling or creating one if needed."""
    configurations = stripe.billing_portal.Configuration.list(limit=100)
    existing_configs = configurations.get('data', [])

    # NOTE: the loop variable must not be named `config`, which would shadow the
    # module-level settings object used for FRONTEND_URL below.
    for portal_config in existing_configs:
        features = portal_config.get('features', {})
        if features.get('subscription_update', {}).get('enabled', False):
            logger.info(f"Found existing portal configuration with subscription_update enabled: {portal_config['id']}")
            return portal_config

    if existing_configs:
        default_config = existing_configs[0]
        logger.info(f"Updating default portal configuration: {default_config['id']} to enable subscription_update")
        return stripe.billing_portal.Configuration.update(
            default_config['id'],
            features={
                'subscription_update': {
                    'enabled': True,
                    'proration_behavior': 'create_prorations',
                    'default_allowed_updates': ['price']
                },
                # Preserve the existing customer_update settings when present.
                'customer_update': default_config.get('features', {}).get('customer_update', {'enabled': True, 'allowed_updates': ['email', 'address']}),
                'invoice_history': {'enabled': True},
                'payment_method_update': {'enabled': True}
            }
        )

    logger.info("Creating new portal configuration with subscription_update enabled")
    return stripe.billing_portal.Configuration.create(
        business_profile={
            'headline': 'Subscription Management',
            'privacy_policy_url': config.FRONTEND_URL + '/privacy',
            'terms_of_service_url': config.FRONTEND_URL + '/terms'
        },
        features={
            'subscription_update': {
                'enabled': True,
                'proration_behavior': 'create_prorations',
                'default_allowed_updates': ['price']
            },
            'customer_update': {
                'enabled': True,
                'allowed_updates': ['email', 'address']
            },
            'invoice_history': {'enabled': True},
            'payment_method_update': {'enabled': True}
        }
    )
|
|
|
@router.get("/subscription")
async def get_subscription(
    current_user_id: str = Depends(get_current_user_id_from_jwt)
):
    """Get the current subscription status for the current user, including scheduled changes.

    Users without an active subscription get a ``no_subscription`` status
    describing the free tier. Otherwise the response reports the plan, billing
    period, minute allowance and this month's usage. If the subscription's
    schedule has a phase starting when the current period ends, with a
    different price, that change is reported as ``scheduled_downgrade``.

    Raises:
        HTTPException: 500 if the status cannot be retrieved.
    """
    try:
        subscription = await get_user_subscription(current_user_id)

        if not subscription:
            free_tier_id = config.STRIPE_FREE_TIER_ID
            free_tier_info = SUBSCRIPTION_TIERS.get(free_tier_id)
            return SubscriptionStatus(
                status="no_subscription",
                plan_name=free_tier_info.get('name', 'free') if free_tier_info else 'free',
                price_id=free_tier_id,
                minutes_limit=free_tier_info.get('minutes') if free_tier_info else 0
            )

        current_item = subscription['items']['data'][0]
        current_price_id = current_item['price']['id']
        current_tier_info = SUBSCRIPTION_TIERS.get(current_price_id)
        if not current_tier_info:
            logger.warning(f"User {current_user_id} subscribed to unknown price {current_price_id}. Defaulting info.")
            current_tier_info = {'name': 'unknown', 'minutes': 0}

        db = DBConnection()
        client = await db.client
        current_usage = await calculate_monthly_usage(client, current_user_id)

        # Newer Stripe API versions report the billing period on the item;
        # older ones on the subscription itself.
        current_period_end_ts = current_item.get('current_period_end') or subscription.get('current_period_end')
        # Stripe only populates `plan` for single-plan subscriptions.
        plan = subscription.get('plan') or {}

        status_response = SubscriptionStatus(
            status=subscription['status'],
            plan_name=plan.get('nickname') or current_tier_info['name'],
            price_id=current_price_id,
            current_period_end=datetime.fromtimestamp(current_period_end_ts, tz=timezone.utc) if current_period_end_ts else None,
            cancel_at_period_end=subscription['cancel_at_period_end'],
            trial_end=datetime.fromtimestamp(subscription['trial_end'], tz=timezone.utc) if subscription.get('trial_end') else None,
            minutes_limit=current_tier_info['minutes'],
            current_usage=round(current_usage, 2),
            has_schedule=False
        )

        schedule_id = subscription.get('schedule')
        if schedule_id and current_period_end_ts:
            try:
                schedule = stripe.SubscriptionSchedule.retrieve(schedule_id)
                # The pending change is the phase that starts exactly when the current period ends.
                next_phase = next(
                    (phase for phase in schedule.get('phases', []) if phase.get('start_date') == current_period_end_ts),
                    None
                )
                if next_phase:
                    scheduled_price = next_phase['items'][0]['price']
                    # Phase item prices are usually IDs but may be expanded Price objects.
                    if isinstance(scheduled_price, str):
                        scheduled_price_id = scheduled_price
                    elif scheduled_price:
                        scheduled_price_id = scheduled_price.get('id')
                    else:
                        scheduled_price_id = None

                    # A phase that keeps the current price is not a plan change.
                    if scheduled_price_id and scheduled_price_id != current_price_id:
                        scheduled_tier_info = SUBSCRIPTION_TIERS.get(scheduled_price_id)
                        status_response.has_schedule = True
                        status_response.status = 'scheduled_downgrade'
                        status_response.scheduled_plan_name = scheduled_tier_info.get('name', 'unknown') if scheduled_tier_info else 'unknown'
                        status_response.scheduled_price_id = scheduled_price_id
                        status_response.scheduled_change_date = datetime.fromtimestamp(next_phase['start_date'], tz=timezone.utc)
            except Exception as schedule_error:
                # Schedule details are optional; report the subscription without them.
                logger.error(f"Error retrieving or parsing schedule {schedule_id} for sub {subscription['id']}: {schedule_error}")

        return status_response

    except Exception as e:
        logger.exception(f"Error getting subscription status for user {current_user_id}: {str(e)}")
        raise HTTPException(status_code=500, detail="Error retrieving subscription status.")
|
|
|
@router.get("/check-status")
async def check_status(
    current_user_id: str = Depends(get_current_user_id_from_jwt)
):
    """Report whether the current user may run agents right now.

    Returns:
        ``{"can_run": bool, "message": str, "subscription": ...}`` as produced
        by ``check_billing_status``.

    Raises:
        HTTPException: 500 carrying the underlying error message on any failure.
    """
    try:
        client = await DBConnection().client
        can_run, message, subscription = await check_billing_status(client, current_user_id)
    except Exception as e:
        logger.error(f"Error checking billing status: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

    return {"can_run": can_run, "message": message, "subscription": subscription}
|
|
|
@router.post("/webhook")
async def stripe_webhook(request: Request):
    """Handle Stripe webhook events.

    The payload is verified against the ``stripe-signature`` header and
    ``config.STRIPE_WEBHOOK_SECRET``. For ``customer.subscription.created``,
    ``.updated`` and ``.deleted`` events, the ``active`` flag of the customer's
    ``basejump.billing_customers`` row is kept in sync:

    * created/updated with status ``active`` or ``trialing`` -> ``active = True``;
    * otherwise -> ``active = False``, unless Stripe still lists another active
      or trialing subscription for the customer.

    Other event types are acknowledged without action.

    Raises:
        HTTPException: 400 for an invalid payload or signature, 500 if
            processing fails.
    """
    try:
        webhook_secret = config.STRIPE_WEBHOOK_SECRET
        payload = await request.body()
        sig_header = request.headers.get('stripe-signature')

        try:
            event = stripe.Webhook.construct_event(
                payload, sig_header, webhook_secret
            )
        except ValueError:
            raise HTTPException(status_code=400, detail="Invalid payload")
        except stripe.error.SignatureVerificationError:
            raise HTTPException(status_code=400, detail="Invalid signature")

        if event.type in ['customer.subscription.created', 'customer.subscription.updated', 'customer.subscription.deleted']:
            subscription = event.data.object
            customer_id = subscription.get('customer')
            if not customer_id:
                logger.warning(f"No customer ID found in subscription event: {event.type}")
                return {"status": "error", "message": "No customer ID found"}

            db = DBConnection()
            client = await db.client

            if event.type == 'customer.subscription.deleted':
                if not _customer_has_live_subscription(customer_id):
                    await _set_billing_customer_active(client, customer_id, False)
                    logger.info(f"Webhook: Updated customer {customer_id} active status to FALSE after subscription deletion")
            elif subscription.get('status') in ['active', 'trialing']:
                await _set_billing_customer_active(client, customer_id, True)
                logger.info(f"Webhook: Updated customer {customer_id} active status to TRUE based on {event.type}")
            elif not _customer_has_live_subscription(customer_id):
                # This subscription is no longer live; only deactivate if no other one is.
                await _set_billing_customer_active(client, customer_id, False)
                logger.info(f"Webhook: Updated customer {customer_id} active status to FALSE based on {event.type}")

            logger.info(f"Processed {event.type} event for customer {customer_id}")

        return {"status": "success"}

    except HTTPException:
        # Keep the 400s for bad payloads/signatures instead of re-wrapping them as 500.
        raise
    except Exception as e:
        logger.error(f"Error processing webhook: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


async def _set_billing_customer_active(client, customer_id: str, active: bool) -> None:
    """Set the ``active`` flag on the basejump.billing_customers row with id ``customer_id``."""
    await client.schema('basejump').from_('billing_customers').update(
        {'active': active}
    ).eq('id', customer_id).execute()


def _customer_has_live_subscription(customer_id: str) -> bool:
    """Return True if Stripe lists an ``active`` or ``trialing`` subscription for the customer.

    ``trialing`` is included to match the webhook's own notion of an active
    subscription; Stripe's list endpoint filters on one status per call.
    """
    for status in ('active', 'trialing'):
        listing = stripe.Subscription.list(customer=customer_id, status=status, limit=1)
        if listing.get('data', []):
            return True
    return False
|
|