|
|
|
""" |
|
Script to check Stripe subscriptions for all customers and update their active status. |
|
|
|
Usage: |
|
python update_customer_active_status.py |
|
|
|
This script: |
|
1. Queries all customers from basejump.billing_customers |
|
2. Checks subscription status directly on Stripe using customer_id |
|
3. Updates customer active status in database |
|
|
|
Make sure your environment variables are properly set: |
|
- SUPABASE_URL |
|
- SUPABASE_SERVICE_ROLE_KEY |
|
- STRIPE_SECRET_KEY |
|
""" |
|
|
|
import asyncio
import os
import sys
import time
from typing import Any, Dict, List, Optional, Tuple

import stripe
from dotenv import load_dotenv
|
|
|
|
|
load_dotenv(".env") |
|
|
|
|
|
from services.supabase import DBConnection |
|
from utils.logger import logger |
|
from utils.config import config |
|
|
|
|
|
# Tuning knobs for the sync run.
BATCH_SIZE: int = 100  # customers checked on Stripe and written back per batch
MAX_CONCURRENCY: int = 20  # cap on simultaneous Stripe lookups inside one batch

# Shared DB connection; created lazily by whichever function needs it first.
db_connection = None

# Global Stripe SDK credential used by every stripe.* call in this script.
stripe.api_key = config.STRIPE_SECRET_KEY
|
|
|
async def get_all_customers(page_size: int = 1000) -> List[Dict[str, Any]]:
    """
    Query all customers from the database.

    Rows are fetched page by page with ``.range()``. A single unpaginated
    ``select`` is capped by the API's maximum-rows setting (1000 rows by
    default on Supabase projects), which would silently drop customers.
    Pages are ordered by ``id`` so the offsets are stable. Paging stops only
    at the first empty page, so the result stays complete even if the
    server-side cap is smaller than ``page_size``.

    Args:
        page_size: Number of rows requested per page (must be >= 1).

    Returns:
        List of customers with their ID (customer_id is used for Stripe)
        and their current ``active`` flag. Empty list if there are none.

    Raises:
        ValueError: If ``page_size`` is less than 1.
    """
    if page_size < 1:
        raise ValueError("page_size must be at least 1")

    global db_connection
    if db_connection is None:
        db_connection = DBConnection()

    client = await db_connection.client

    print(f"Using Supabase URL: {os.getenv('SUPABASE_URL')}")

    customers: List[Dict[str, Any]] = []
    offset = 0
    while True:
        # .range() bounds are inclusive on both ends.
        result = await client.schema('basejump').from_('billing_customers').select(
            'id',
            'active'
        ).order('id').range(offset, offset + page_size - 1).execute()

        page = result.data or []
        if not page:
            break
        customers.extend(page)
        # Advance by what was actually returned, not by page_size, in case
        # the server capped the page below the requested size.
        offset += len(page)

    print(f"Found {len(customers)} customers in database")

    if not customers:
        logger.info("No customers found in database")
        return []

    return customers
|
|
|
async def check_stripe_subscription(customer_id: str) -> Optional[bool]:
    """
    Check if a customer has an active subscription directly on Stripe.

    The Stripe SDK call is synchronous (blocking HTTP). It is run in the event
    loop's default thread-pool executor so it does not block the loop. This
    lets the semaphore-limited checks in ``process_customer_batch`` overlap
    instead of running one after another.

    Only subscriptions matching Stripe's ``status='active'`` filter count.
    Other Stripe statuses (e.g. ``trialing``) are not matched by this query.

    Args:
        customer_id: Customer ID (billing_customers.id) which is the Stripe customer ID

    Returns:
        True if the customer has at least one active subscription.
        False if it has none, or if ``customer_id`` is empty.
        None if the Stripe lookup failed, so the status is unknown. None is
        falsy, so truthiness checks by callers behave as before.
    """
    if not customer_id:
        print("⚠️ Empty customer_id")
        return False

    try:
        print(f"Checking Stripe subscriptions for customer: {customer_id}")

        loop = asyncio.get_running_loop()
        # Run the blocking SDK call off the event loop thread.
        subscriptions = await loop.run_in_executor(
            None,
            lambda: stripe.Subscription.list(
                customer=customer_id,
                status='active',
                limit=1,
            ),
        )

        print(f"Stripe returned data: {subscriptions.data}")

        has_active_subscription = len(subscriptions.data) > 0

        if has_active_subscription:
            print(f"✅ Customer {customer_id} has ACTIVE subscription")
        else:
            print(f"❌ Customer {customer_id} has NO active subscription")

        return has_active_subscription

    except Exception as e:
        # Deliberately NOT False: a network error or rate limit says nothing
        # about the subscription. Reporting False would get the customer
        # written back as inactive.
        logger.error(f"Error checking Stripe subscription for customer {customer_id}: {str(e)}")
        print(f"⚠️ Error checking subscription for {customer_id}: {str(e)}")
        return None


async def process_customer_batch(batch: List[Dict[str, Any]], batch_number: int, total_batches: int) -> Dict[str, bool]:
    """
    Process a batch of customers by checking their Stripe subscriptions concurrently.

    At most ``MAX_CONCURRENCY`` Stripe lookups are in flight at once.

    Args:
        batch: List of customer records in this batch
        batch_number: Current batch number (for logging)
        total_batches: Total number of batches (for logging)

    Returns:
        Dictionary mapping customer IDs to subscription status (True/False).
        Customers whose Stripe lookup failed are omitted, so their stored
        ``active`` flag is left unchanged by the subsequent DB update.
    """
    start_time = time.time()
    batch_size = len(batch)
    print(f"Processing batch {batch_number}/{total_batches} ({batch_size} customers)...")

    semaphore = asyncio.Semaphore(MAX_CONCURRENCY)

    async def check_single_customer(customer: Dict[str, Any]) -> Tuple[str, Optional[bool]]:
        async with semaphore:
            customer_id = customer['id']
            is_active = await check_stripe_subscription(customer_id)
            return customer_id, is_active

    results = await asyncio.gather(*(check_single_customer(customer) for customer in batch))

    # Keep only definitive answers; None means "Stripe lookup failed".
    subscription_status = {customer_id: status for customer_id, status in results if status is not None}
    failed_ids = [customer_id for customer_id, status in results if status is None]

    end_time = time.time()

    active_count = sum(1 for status in subscription_status.values() if status)
    inactive_count = len(subscription_status) - active_count

    print(f"Batch {batch_number} completed in {end_time - start_time:.2f} seconds")
    print(f"Results (batch {batch_number}): {active_count} active, {inactive_count} inactive subscriptions")
    if failed_ids:
        print(
            f"⚠️ Skipped {len(failed_ids)} customers in batch {batch_number} "
            f"due to Stripe errors (status left unchanged): {failed_ids}"
        )

    return subscription_status
|
|
|
async def update_customer_batch(subscription_status: Dict[str, bool]) -> Dict[str, int]:
    """
    Write the Stripe-derived ``active`` flag for one batch of customers.

    IDs are grouped by target value, and each non-empty group is written with
    a single UPDATE on basejump.billing_customers (ACTIVE group first). A
    failing UPDATE is logged and counted without aborting the other one.

    Args:
        subscription_status: Mapping of customer ID -> desired active flag
            (any falsy value is treated as inactive).

    Returns:
        Counters with these keys:
        - ``total``: number of customers in the batch.
        - ``active_updated`` / ``inactive_updated``: size of the ID list of
          each UPDATE that executed without raising.
        - ``errors``: number of UPDATEs that raised.
    """
    global db_connection

    started = time.time()

    if db_connection is None:
        db_connection = DBConnection()
    client = await db_connection.client

    to_activate: List[str] = []
    to_deactivate: List[str] = []
    for cid, status in subscription_status.items():
        (to_activate if status else to_deactivate).append(cid)

    stats = {
        'total': len(to_activate) + len(to_deactivate),
        'active_updated': 0,
        'inactive_updated': 0,
        'errors': 0,
    }

    # (value to write, ids, label used in messages, stats counter to fill)
    plan = (
        (True, to_activate, 'ACTIVE', 'active_updated'),
        (False, to_deactivate, 'INACTIVE', 'inactive_updated'),
    )
    for flag, ids, label, counter in plan:
        if not ids:
            continue
        try:
            print(f"Updating {len(ids)} customers to {label} status")
            await client.schema('basejump').from_('billing_customers').update(
                {'active': flag}
            ).in_('id', ids).execute()

            stats[counter] = len(ids)
            logger.info(f"Updated {len(ids)} customers to {label} status")
        except Exception as exc:
            logger.error(f"Error updating {label.lower()} customers: {str(exc)}")
            stats['errors'] += 1

    finished = time.time()
    print(f"Database updates completed in {finished - started:.2f} seconds")

    return stats
|
|
|
def _print_customer_sample(customers: List[Dict[str, Any]], limit: int = 5) -> None:
    """Print the first ``limit`` customers so the operator can sanity-check the data."""
    print("\nCustomer data sample (customer_id = Stripe customer ID):")
    for i, customer in enumerate(customers[:limit]):
        print(f" {i+1}. ID: {customer['id']}, Active: {customer.get('active')}")
    if len(customers) > limit:
        print(f" ... and {len(customers) - limit} more")


def _print_summary(stats: Dict[str, int], total_time: float) -> None:
    """Print the aggregated run statistics and the total runtime."""
    print("\nCustomer Status Update Summary:")
    print(f"Total customers processed: {stats['total']}")
    print(f"Customers set to active: {stats['active_updated']}")
    print(f"Customers set to inactive: {stats['inactive_updated']}")
    if stats['errors'] > 0:
        print(f"Update errors: {stats['errors']}")
    print(f"Total processing time: {total_time:.2f} seconds")


async def main():
    """
    Main function to run the script.

    Steps:
    1. Verify the Stripe key is configured.
    2. Load all customers and show a sample.
    3. Ask the operator to confirm.
    4. Check Stripe and write back the ``active`` flag batch by batch.

    On an unexpected exception the error is logged and the process exits with
    status 1. In every case, ``DBConnection.disconnect()`` is awaited in
    ``finally`` if a connection object was created.
    """
    # Declared up front: this function assigns the module-level connection
    # and reads it again in the finally clause.
    global db_connection

    total_start_time = time.time()
    logger.info("Starting customer active status update process")

    try:
        print(f"Stripe API key configured: {'Yes' if config.STRIPE_SECRET_KEY else 'No'}")
        if not config.STRIPE_SECRET_KEY:
            print("ERROR: Stripe API key not configured. Please set STRIPE_SECRET_KEY in your environment.")
            return

        db_connection = DBConnection()

        all_customers = await get_all_customers()
        if not all_customers:
            logger.info("No customers to process")
            return

        _print_customer_sample(all_customers)

        batches = [all_customers[i:i + BATCH_SIZE] for i in range(0, len(all_customers), BATCH_SIZE)]
        total_batches = len(batches)

        confirm = input(f"\nProcess {len(all_customers)} customers in {total_batches} batches of {BATCH_SIZE}? (y/n): ")
        # strip() so stray whitespace around the answer (e.g. "y ") still confirms.
        if confirm.strip().lower() != 'y':
            logger.info("Operation cancelled by user")
            return

        all_stats = {
            'total': 0,
            'active_updated': 0,
            'inactive_updated': 0,
            'errors': 0,
        }

        for batch_number, batch in enumerate(batches, start=1):
            subscription_status = await process_customer_batch(batch, batch_number, total_batches)
            batch_stats = await update_customer_batch(subscription_status)

            for key in all_stats:
                all_stats[key] += batch_stats[key]

            print(f"Completed batch {batch_number}/{total_batches}")

            # Short pause between batches (presumably to throttle Stripe/DB load).
            if batch_number < total_batches:
                await asyncio.sleep(1)

        total_time = time.time() - total_start_time
        _print_summary(all_stats, total_time)

        logger.info(f"Customer active status update completed in {total_time:.2f} seconds")

    except Exception as e:
        logger.error(f"Error during customer status update: {str(e)}")
        sys.exit(1)
    finally:
        if db_connection:
            await DBConnection.disconnect()
|
|
|
|
|
# Script entry point: run the async main() to completion with asyncio.run.
if __name__ == "__main__":
    asyncio.run(main())