data.ai / billing.py
lattmamb's picture
Upload 251 files
a01965e verified
raw
history blame
4.72 kB
from datetime import datetime, timezone
from typing import Dict, Optional, Tuple
from utils.logger import logger
from utils.config import config, EnvMode
# Subscription tiers keyed by billing price ID. Each entry carries the tier's
# display name and its monthly allowance of agent-run time, in minutes.
SUBSCRIPTION_TIERS = {
    # Free tier
    'price_1RGJ9GG6l1KZGqIroxSqgphC': {
        'name': 'free',
        'minutes': 8,
    },
    # Base tier
    'price_1RGJ9LG6l1KZGqIrd9pwzeNW': {
        'name': 'base',
        'minutes': 300,
    },
    # Extra tier
    'price_1RGJ9JG6l1KZGqIrVUU4ZRv6': {
        'name': 'extra',
        'minutes': 2400,
    },
}
async def get_account_subscription(client, account_id: str) -> Optional[Dict]:
    """Return the newest active subscription row for an account, if any.

    Queries ``billing_subscriptions`` in the ``basejump`` schema for rows
    matching ``account_id`` with status ``'active'``, ordered by ``created``
    descending, and keeps only the first row.

    Args:
        client: Database client exposing the ``schema(...).from_(...)`` query
            builder used below.
        account_id: Identifier of the account to look up.

    Returns:
        The first matching row as returned by the client, or ``None`` when
        the query yields no rows.
    """
    query = (
        client.schema('basejump')
        .from_('billing_subscriptions')
        .select('*')
        .eq('account_id', account_id)
        .eq('status', 'active')
        .order('created', desc=True)
        .limit(1)
    )
    response = await query.execute()

    rows = response.data
    if not rows:
        return None
    return rows[0]
def _parse_run_timestamp(value: str) -> datetime:
    """Parse an ISO-8601 timestamp string into a timezone-aware datetime.

    Normalizes three things before handing the text to
    ``datetime.fromisoformat``:

    * a ``Z`` suffix is rewritten as ``+00:00``;
    * the fractional-seconds part is padded/truncated to exactly six digits,
      because ``fromisoformat`` on Python < 3.11 rejects other widths;
    * a value without any UTC offset is treated as UTC, matching the UTC
      month boundary used by the usage query, instead of the host's local
      timezone.
    """
    text = value.replace('Z', '+00:00')
    head, dot, tail = text.partition('.')
    if dot:
        # Split the tail into its leading digits (the fraction) and the rest
        # (the UTC offset, if present).
        n_digits = len(tail) - len(tail.lstrip('0123456789'))
        fraction = tail[:n_digits][:6].ljust(6, '0')
        text = f"{head}.{fraction}{tail[n_digits:]}"
    parsed = datetime.fromisoformat(text)
    if parsed.tzinfo is None:
        parsed = parsed.replace(tzinfo=timezone.utc)
    return parsed


async def calculate_monthly_usage(client, account_id: str) -> float:
    """Calculate total agent run minutes for the current month for an account.

    The account's threads are looked up first; then every agent run on those
    threads whose ``started_at`` is on or after the first instant of the
    current UTC month is summed. A run with no ``completed_at`` is counted up
    to the current time.

    Args:
        client: Database client exposing the ``table(...)`` query builder
            used below.
        account_id: Identifier of the account whose usage is computed.

    Returns:
        Total run time in minutes as a float; ``0.0`` when the account has no
        threads or no runs in the current month.

    Raises:
        ValueError: If a stored timestamp is not a parseable ISO-8601 string.
    """
    # Start of the current month in UTC.
    now = datetime.now(timezone.utc)
    start_of_month = datetime(now.year, now.month, 1, tzinfo=timezone.utc)

    # First get all threads for this account.
    threads_result = await client.table('threads') \
        .select('thread_id') \
        .eq('account_id', account_id) \
        .execute()
    if not threads_result.data:
        return 0.0
    thread_ids = [t['thread_id'] for t in threads_result.data]

    # Then get all agent runs for these threads in the current month.
    # NOTE(review): runs that started before the month boundary but were
    # still executing afterwards are not included by this filter - confirm
    # that is the intended billing policy.
    # NOTE(review): all thread IDs go into a single IN filter; verify this
    # stays within the backend's request-size limits for very large accounts.
    runs_result = await client.table('agent_runs') \
        .select('started_at, completed_at') \
        .in_('thread_id', thread_ids) \
        .gte('started_at', start_of_month.isoformat()) \
        .execute()
    if not runs_result.data:
        return 0.0

    total_seconds = 0.0
    for run in runs_result.data:
        started = _parse_run_timestamp(run['started_at'])
        completed_raw = run.get('completed_at')
        # For running jobs, use the current time as the end.
        ended = _parse_run_timestamp(completed_raw) if completed_raw else now
        # Subtract aware datetimes (exact to the microsecond) and never let a
        # run with inconsistent timestamps reduce the account's usage.
        total_seconds += max(0.0, (ended - started).total_seconds())

    return total_seconds / 60  # Convert to minutes
async def check_billing_status(client, account_id: str) -> Tuple[bool, str, Optional[Dict]]:
    """Decide whether an account is currently allowed to run agents.

    In local development mode the check is skipped entirely. Otherwise the
    account's active subscription (or the free tier, when there is none) is
    mapped to a monthly minute allowance and compared with the minutes
    already used this month.

    Args:
        client: Database client forwarded to the subscription and usage
            lookups.
        account_id: Identifier of the account being checked.

    Returns:
        Tuple[bool, str, Optional[Dict]]: ``(can_run, message,
        subscription_info)``.
    """
    # Billing is not enforced when running locally.
    if config.ENV_MODE == EnvMode.LOCAL:
        logger.info("Running in local development mode - billing checks are disabled")
        local_subscription = {
            "price_id": "local_dev",
            "plan_name": "Local Development",
            "minutes_limit": "no limit"
        }
        return True, "Local development mode - billing disabled", local_subscription

    # Staging/production: look up the active subscription and fall back to
    # the free tier when the account has none.
    # NOTE: a stricter policy that rejected unsubscribed / free-tier accounts
    # outright was sketched here previously and is currently disabled.
    subscription = (await get_account_subscription(client, account_id)) or {
        'price_id': 'price_1RGJ9GG6l1KZGqIroxSqgphC',  # Free tier
        'plan_name': 'free'
    }

    # Resolve the tier definition for the subscription's price ID.
    tier_info = SUBSCRIPTION_TIERS.get(subscription['price_id'])
    if not tier_info:
        return False, "Invalid subscription tier", subscription

    # Compare this month's usage with the tier's allowance.
    minutes_limit = tier_info['minutes']
    current_usage = await calculate_monthly_usage(client, account_id)
    limit_reached = current_usage >= minutes_limit
    if limit_reached:
        return False, f"Monthly limit of {minutes_limit} minutes reached. Please upgrade your plan or wait until next month.", subscription

    return True, "OK", subscription
# Helper: resolve the owning account of a thread.
async def get_account_id_from_thread(client, thread_id: str) -> Optional[str]:
    """Look up which account a thread belongs to.

    Args:
        client: Database client exposing the ``table(...)`` query builder
            used below.
        thread_id: Identifier of the thread to look up.

    Returns:
        The ``account_id`` of the first matching ``threads`` row, or ``None``
        when no row matches.
    """
    query = (
        client.table('threads')
        .select('account_id')
        .eq('thread_id', thread_id)
        .limit(1)
    )
    response = await query.execute()

    rows = response.data
    if not rows:
        return None
    return rows[0]['account_id']