#!/usr/bin/env python
"""
Script to check Stripe subscriptions for all customers and update their active status.

Usage:
    python update_customer_active_status.py

This script:
1. Queries all customers from basejump.billing_customers (paginated)
2. Checks subscription status directly on Stripe using customer_id
3. Updates customer active status in database

Customers whose Stripe lookup fails (network error, rate limit, ...) are
skipped and keep their current status; they are never flipped to inactive
because of a failed API call.

Make sure your environment variables are properly set:
- SUPABASE_URL
- SUPABASE_SERVICE_ROLE_KEY
- STRIPE_SECRET_KEY
"""

import asyncio
import sys
import os
import time
from typing import List, Dict, Any, Optional, Tuple
from dotenv import load_dotenv
import stripe

# Load script-specific environment variables
load_dotenv(".env")

# Import relative modules.
# NOTE(review): these stay below load_dotenv() as in the original ordering;
# presumably utils.config reads the environment at import time - confirm.
from services.supabase import DBConnection
from utils.logger import logger
from utils.config import config

# Initialize Stripe with the API key
stripe.api_key = config.STRIPE_SECRET_KEY

# Batch size settings
BATCH_SIZE = 100  # Process customers in batches
MAX_CONCURRENCY = 20  # Maximum concurrent Stripe API calls
PAGE_SIZE = 1000  # Rows requested per page when listing customers

# Global DB connection to reuse
db_connection = None


async def get_all_customers() -> List[Dict[str, Any]]:
    """
    Query all customers from the database.

    The table is read page by page: a single unpaginated select can be
    truncated by the API's per-response row cap, which would silently drop
    customers. Paging stops at the first empty page, so it also works when
    the server-side cap is smaller than PAGE_SIZE.

    Returns:
        List of customers with their ID (customer_id is used for Stripe)
        and current ``active`` flag; empty list when there are none.
    """
    global db_connection
    if db_connection is None:
        db_connection = DBConnection()

    client = await db_connection.client

    # Print the Supabase URL being used
    print(f"Using Supabase URL: {os.getenv('SUPABASE_URL')}")

    customers: List[Dict[str, Any]] = []
    while True:
        offset = len(customers)
        # A stable ordering is required for offset pagination to be reliable.
        result = await (
            client.schema('basejump')
            .from_('billing_customers')
            .select('id', 'active')
            .order('id')
            .range(offset, offset + PAGE_SIZE - 1)
            .execute()
        )
        page = result.data or []
        if not page:
            break
        customers.extend(page)

    # Print the query result
    print(f"Found {len(customers)} customers in database")

    if not customers:
        logger.info("No customers found in database")

    return customers


async def _lookup_active_subscription(customer_id: str) -> Optional[bool]:
    """
    Ask Stripe whether a customer has an active subscription.

    The Stripe call used here is blocking, so it is run in the event loop's
    default executor; this is what lets several lookups overlap.

    Args:
        customer_id: Customer ID (billing_customers.id) which is the Stripe customer ID

    Returns:
        True / False when Stripe answered, None when the status could not be
        determined (empty ID or any error raised by the lookup).
    """
    if not customer_id:
        print("⚠️ Empty customer_id")
        return None

    try:
        # Print what we're checking for debugging
        print(f"Checking Stripe subscriptions for customer: {customer_id}")

        loop = asyncio.get_running_loop()
        subscriptions = await loop.run_in_executor(
            None,
            lambda: stripe.Subscription.list(
                customer=customer_id,
                status='active',  # Only get active subscriptions
                limit=1,  # We only need to know if there's at least one
            ),
        )

        # Print the raw data for debugging
        print(f"Stripe returned data: {subscriptions.data}")

        # If there's at least one active subscription, the customer is active
        has_active_subscription = len(subscriptions.data) > 0

        if has_active_subscription:
            print(f"✅ Customer {customer_id} has ACTIVE subscription")
        else:
            print(f"❌ Customer {customer_id} has NO active subscription")

        return has_active_subscription

    except Exception as e:
        # Best effort per customer: report and let the caller decide what an
        # unknown status means, rather than aborting the whole run.
        logger.error(f"Error checking Stripe subscription for customer {customer_id}: {str(e)}")
        print(f"⚠️ Error checking subscription for {customer_id}: {str(e)}")
        return None


async def check_stripe_subscription(customer_id: str) -> bool:
    """
    Check if a customer has an active subscription directly on Stripe.

    Args:
        customer_id: Customer ID (billing_customers.id) which is the Stripe customer ID

    Returns:
        True if customer has at least one active subscription, False otherwise.
        False is also returned when the lookup failed, so this result must not
        be used on its own to deactivate a customer.
    """
    return bool(await _lookup_active_subscription(customer_id))


async def process_customer_batch(batch: List[Dict[str, Any]], batch_number: int, total_batches: int) -> Dict[str, bool]:
    """
    Process a batch of customers by checking their Stripe subscriptions concurrently.

    Args:
        batch: List of customer records in this batch
        batch_number: Current batch number (for logging)
        total_batches: Total number of batches (for logging)

    Returns:
        Dictionary mapping customer IDs to subscription status (True/False).
        Customers whose status could not be determined are omitted so that
        they are not updated in the database.
    """
    start_time = time.time()
    batch_size = len(batch)
    print(f"Processing batch {batch_number}/{total_batches} ({batch_size} customers)...")

    # Create a semaphore to limit concurrency within the batch to avoid rate limiting
    semaphore = asyncio.Semaphore(MAX_CONCURRENCY)

    async def check_single_customer(customer: Dict[str, Any]) -> Tuple[str, Optional[bool]]:
        async with semaphore:  # Limit concurrent API calls
            customer_id = customer['id']
            # Check directly on Stripe - customer_id IS the Stripe customer ID
            is_active = await _lookup_active_subscription(customer_id)
            return customer_id, is_active

    # Run all lookups of this batch concurrently
    results = await asyncio.gather(*(check_single_customer(customer) for customer in batch))

    # Keep only customers with a known status
    subscription_status = {
        customer_id: status for customer_id, status in results if status is not None
    }

    end_time = time.time()

    # Count active/inactive/skipped in this batch
    active_count = sum(1 for status in subscription_status.values() if status)
    inactive_count = len(subscription_status) - active_count
    skipped_count = batch_size - len(subscription_status)

    print(f"Batch {batch_number} completed in {end_time - start_time:.2f} seconds")
    print(f"Results (batch {batch_number}): {active_count} active, {inactive_count} inactive subscriptions")
    if skipped_count:
        print(f"Skipped (batch {batch_number}): {skipped_count} customers with unknown status (left unchanged)")

    return subscription_status


async def _set_active_flag(client: Any, customer_ids: List[str], active: bool) -> bool:
    """
    Set ``active`` for the given customers in a single update.

    Returns:
        True when the update call succeeded, False when it raised.
    """
    label = "ACTIVE" if active else "INACTIVE"
    try:
        print(f"Updating {len(customer_ids)} customers to {label} status")
        await client.schema('basejump').from_('billing_customers').update(
            {'active': active}
        ).in_('id', customer_ids).execute()
        logger.info(f"Updated {len(customer_ids)} customers to {label} status")
        return True
    except Exception as e:
        logger.error(f"Error updating {label.lower()} customers: {str(e)}")
        return False


async def update_customer_batch(subscription_status: Dict[str, bool]) -> Dict[str, int]:
    """
    Update a batch of customers in the database.

    Args:
        subscription_status: Dictionary mapping customer IDs to active status

    Returns:
        Dictionary with statistics about the update: ``total``,
        ``active_updated``, ``inactive_updated`` and ``errors`` (number of
        failed update calls, at most one per status group).
    """
    start_time = time.time()

    global db_connection
    if db_connection is None:
        db_connection = DBConnection()

    client = await db_connection.client

    # Separate customers into active and inactive groups
    active_customers = [cid for cid, status in subscription_status.items() if status]
    inactive_customers = [cid for cid, status in subscription_status.items() if not status]

    # Update statistics
    stats = {
        'total': len(active_customers) + len(inactive_customers),
        'active_updated': 0,
        'inactive_updated': 0,
        'errors': 0
    }

    # One update call per status group
    for customer_ids, active, stat_key in (
        (active_customers, True, 'active_updated'),
        (inactive_customers, False, 'inactive_updated'),
    ):
        if not customer_ids:
            continue
        if await _set_active_flag(client, customer_ids, active):
            stats[stat_key] = len(customer_ids)
        else:
            stats['errors'] += 1

    end_time = time.time()
    print(f"Database updates completed in {end_time - start_time:.2f} seconds")

    return stats


async def main():
    """Main function to run the script.

    Loads all customers, asks for confirmation, then for each batch checks
    Stripe and writes the resulting status back. Exits with status 1 when an
    unexpected error aborts the run.
    """
    total_start_time = time.time()
    logger.info("Starting customer active status update process")

    try:
        # Check Stripe API key
        print(f"Stripe API key configured: {'Yes' if config.STRIPE_SECRET_KEY else 'No'}")
        if not config.STRIPE_SECRET_KEY:
            print("ERROR: Stripe API key not configured. Please set STRIPE_SECRET_KEY in your environment.")
            return

        # Initialize global DB connection
        global db_connection
        db_connection = DBConnection()

        # Get all customers from the database
        all_customers = await get_all_customers()

        if not all_customers:
            logger.info("No customers to process")
            return

        # Print a small sample of the customer data
        print("\nCustomer data sample (customer_id = Stripe customer ID):")
        for i, customer in enumerate(all_customers[:5]):  # Show first 5 only
            print(f"  {i+1}. ID: {customer['id']}, Active: {customer.get('active')}")
        if len(all_customers) > 5:
            print(f"  ... and {len(all_customers) - 5} more")

        # Split customers into batches
        batches = [all_customers[i:i + BATCH_SIZE] for i in range(0, len(all_customers), BATCH_SIZE)]
        total_batches = len(batches)

        # Ask for confirmation before proceeding
        confirm = input(f"\nProcess {len(all_customers)} customers in {total_batches} batches of {BATCH_SIZE}? (y/n): ")
        if confirm.lower() != 'y':
            logger.info("Operation cancelled by user")
            return

        # Overall statistics
        all_stats = {
            'total': 0,
            'active_updated': 0,
            'inactive_updated': 0,
            'errors': 0
        }
        skipped_total = 0  # Customers left unchanged because Stripe could not be queried

        # Process each batch
        for batch_number, batch in enumerate(batches, start=1):
            # STEP 1: Process this batch of customers
            subscription_status = await process_customer_batch(batch, batch_number, total_batches)
            skipped_total += len(batch) - len(subscription_status)

            # STEP 2: Update this batch in the database
            batch_stats = await update_customer_batch(subscription_status)

            # Accumulate statistics
            for key in all_stats:
                all_stats[key] += batch_stats[key]

            # Show batch completion
            print(f"Completed batch {batch_number}/{total_batches}")

            # Brief pause between batches to avoid Stripe rate limiting
            if batch_number < total_batches:
                await asyncio.sleep(1)  # 1 second pause between batches

        # Print summary
        total_end_time = time.time()
        total_time = total_end_time - total_start_time

        print("\nCustomer Status Update Summary:")
        print(f"Total customers processed: {all_stats['total']}")
        print(f"Customers set to active: {all_stats['active_updated']}")
        print(f"Customers set to inactive: {all_stats['inactive_updated']}")
        if skipped_total > 0:
            print(f"Customers skipped (Stripe lookup failed, left unchanged): {skipped_total}")
        if all_stats['errors'] > 0:
            print(f"Update errors: {all_stats['errors']}")
        print(f"Total processing time: {total_time:.2f} seconds")

        logger.info(f"Customer active status update completed in {total_time:.2f} seconds")

    except Exception as e:
        logger.error(f"Error during customer status update: {str(e)}")
        sys.exit(1)
    finally:
        # Clean up database connection
        if db_connection:
            await DBConnection.disconnect()


if __name__ == "__main__":
    asyncio.run(main())