"""Inference Provider Dashboard.

A FastAPI application that scrapes the "monthly requests" figure from each
tracked provider's Hugging Face profile page, serves the live numbers over a
small JSON API, and periodically appends snapshots to a Hugging Face dataset
so a historical chart can be drawn.
"""

import asyncio
import functools
import logging
import os
from contextlib import asynccontextmanager, suppress
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional

import aiohttp
import pandas as pd
import requests
import uvicorn
from bs4 import BeautifulSoup
from datasets import Dataset, load_dataset
from fastapi import BackgroundTasks, FastAPI, Request
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from huggingface_hub import HfApi

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global variables for dataset management
DATASET_REPO_NAME = os.getenv("DATASET_REPO_NAME", "nbroad/hf-inference-providers-data")
HF_TOKEN = os.getenv("HF_TOKEN")

# Time to wait between data collection runs in seconds
DATA_COLLECTION_INTERVAL = 1800

# Time to wait before retrying after an unexpected collection error, in seconds
ERROR_RETRY_DELAY = 300

# Upper bound for a single profile-page fetch so one slow page cannot stall
# a whole collection run or an API request.
REQUEST_TIMEOUT = aiohttp.ClientTimeout(total=30)

# How far back /api/historical looks, in hours
HISTORY_WINDOW_HOURS = 48

# Background task state
data_collection_task = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage application lifecycle.

    Starts the periodic data collection task on startup and, on shutdown,
    cancels it and waits for it to finish so no task is left pending when the
    event loop closes.
    """
    global data_collection_task
    data_collection_task = asyncio.create_task(timed_data_collection())
    logger.info(
        "Started periodic data collection task (every %s seconds)",
        DATA_COLLECTION_INTERVAL,
    )

    try:
        yield
    finally:
        if data_collection_task:
            data_collection_task.cancel()
            # Await the task so cancellation is actually processed.
            with suppress(asyncio.CancelledError):
                await data_collection_task
            logger.info("Stopped periodic data collection task")


app = FastAPI(title="Inference Provider Dashboard", lifespan=lifespan)

# List of providers to track
PROVIDERS = [
    "togethercomputer",
    "fireworks-ai",
    "nebius",
    "fal",
    "groq",
    "cerebras",
    "sambanovasystems",
    "replicate",
    "novita",
    "Hyperbolic",
    "featherless-ai",
    "CohereLabs",
    "nscale",
]

templates = Jinja2Templates(directory="templates")


def _unavailable(provider: str) -> Dict[str, Any]:
    """Return the placeholder record used when a provider's count is unknown."""
    return {
        "provider": provider,
        "monthly_requests": "N/A",
        "monthly_requests_int": 0,
    }


async def _run_blocking(func, *args, **kwargs):
    """Run a blocking callable in the default executor and await its result."""
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, functools.partial(func, *args, **kwargs))


async def get_monthly_requests(session: aiohttp.ClientSession, provider: str) -> Dict[str, Any]:
    """Get monthly requests for a provider from HuggingFace.

    Args:
        session: Shared aiohttp session used for the request.
        provider: Hugging Face organisation/user name of the provider.

    Returns:
        A dict with keys ``provider``, ``monthly_requests`` (the scraped text,
        or ``"N/A"``) and ``monthly_requests_int`` (the parsed integer, or 0
        when the value is missing or not purely numeric).  Never raises: any
        failure is logged and reported as the ``"N/A"`` record.
    """
    url = f"https://huggingface.co/{provider}"
    try:
        async with session.get(url, timeout=REQUEST_TIMEOUT) as response:
            # Treat HTTP errors as failures instead of parsing an error page.
            response.raise_for_status()
            html = await response.text()

        soup = BeautifulSoup(html, 'html.parser')
        # `string=` is the current name of the deprecated `text=` argument.
        request_div = soup.find(
            'div', string=lambda t: t and 'monthly requests' in t.lower()
        )
        if not request_div:
            return _unavailable(provider)

        requests_text = request_div.text.split()[0].replace(',', '')
        return {
            "provider": provider,
            "monthly_requests": requests_text,
            "monthly_requests_int": int(requests_text) if requests_text.isdigit() else 0,
        }
    except Exception as e:
        # Best-effort scrape: one failing provider must not break the others.
        logger.error(f"Error fetching {provider}: {e}")
        return _unavailable(provider)


async def _fetch_all_providers() -> List[Dict[str, Any]]:
    """Fetch the current record for every tracked provider concurrently."""
    async with aiohttp.ClientSession() as session:
        tasks = [get_monthly_requests(session, provider) for provider in PROVIDERS]
        return await asyncio.gather(*tasks)


def _load_history() -> pd.DataFrame:
    """Load the stored dataset's train split as a DataFrame (blocking)."""
    return load_dataset(DATASET_REPO_NAME, split="train").to_pandas()


def _append_and_push(new_df: pd.DataFrame) -> None:
    """Append ``new_df`` to the stored dataset and push it back (blocking).

    A fresh dataset is created only when the existing one is genuinely
    missing.  Any other load failure propagates so that a transient error
    cannot cause the stored history to be overwritten by the new rows alone.
    """
    try:
        existing_df = _load_history()
    except FileNotFoundError as e:
        # NOTE(review): relies on `datasets` signalling a missing/empty repo
        # with a FileNotFoundError subclass - confirm for the pinned version.
        logger.info(f"Creating new dataset (existing not found): {e}")
        combined_df = new_df
    else:
        combined_df = pd.concat([existing_df, new_df], ignore_index=True)

    new_dataset = Dataset.from_pandas(combined_df)
    new_dataset.push_to_hub(DATASET_REPO_NAME, token=HF_TOKEN, private=False)


async def collect_and_store_data():
    """Collect current data and store it in the dataset.

    Does nothing (beyond a warning) when ``HF_TOKEN`` is not set.  All errors
    are logged rather than raised so callers (the periodic task and the
    manual trigger) never fail because of a collection problem.
    """
    if not HF_TOKEN:
        logger.warning("No HF_TOKEN found, skipping data storage")
        return

    try:
        logger.info("Collecting data for storage...")
        results = await _fetch_all_providers()

        # One shared UTC timestamp per snapshot
        timestamp = datetime.now(timezone.utc).isoformat()
        new_df = pd.DataFrame(
            [
                {
                    "timestamp": timestamp,
                    "provider": result["provider"],
                    "monthly_requests": result["monthly_requests"],
                    "monthly_requests_int": result["monthly_requests_int"],
                }
                for result in results
            ]
        )

        # Hub download/upload is blocking I/O; keep it off the event loop.
        await _run_blocking(_append_and_push, new_df)
        logger.info(f"Successfully stored data for {len(results)} providers")
    except Exception as e:
        logger.error(f"Error collecting and storing data: {e}")


async def timed_data_collection():
    """Background task that runs every DATA_COLLECTION_INTERVAL seconds to collect data.

    Runs until cancelled.  After an unexpected error it waits
    ERROR_RETRY_DELAY seconds before trying again.
    """
    try:
        while True:
            delay = DATA_COLLECTION_INTERVAL
            try:
                await collect_and_store_data()
            except Exception as e:
                logger.error(f"Error in periodic data collection: {e}")
                delay = ERROR_RETRY_DELAY
            # Sleeping here (not inside the except handler) keeps every sleep
            # covered by the cancellation handling below.
            await asyncio.sleep(delay)
    except asyncio.CancelledError:
        logger.info("Data collection task cancelled")


@app.get("/")
async def dashboard(request: Request):
    """Serve the main dashboard page"""
    return templates.TemplateResponse("dashboard.html", {"request": request})


@app.get("/api/providers")
async def get_providers_data():
    """API endpoint to get provider data, sorted by request count descending."""
    results = await _fetch_all_providers()
    results.sort(key=lambda x: x["monthly_requests_int"], reverse=True)

    return {
        "providers": results,
        "last_updated": datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
        "total_providers": len(results),
    }


@app.get("/api/providers/{provider}")
async def get_provider_data(provider: str):
    """API endpoint to get data for a specific provider"""
    if provider not in PROVIDERS:
        return {"error": "Provider not found"}

    async with aiohttp.ClientSession() as session:
        result = await get_monthly_requests(session, provider)

    return {
        "provider_data": result,
        "last_updated": datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
    }


@app.get("/api/historical")
async def get_historical_data():
    """API endpoint to get historical data for line chart.

    Returns, for every tracked provider, the stored points from the last
    HISTORY_WINDOW_HOURS hours as ``{"x": iso_timestamp, "y": count}`` dicts
    in chronological order (an empty list when there are none).
    """
    if not HF_TOKEN:
        return {"error": "Historical data not available", "data": []}

    try:
        # Loading from the Hub is blocking I/O; keep it off the event loop.
        df = await _run_blocking(_load_history)

        # Force timezone-aware UTC so the cutoff comparison is always valid.
        df['timestamp'] = pd.to_datetime(df['timestamp'], utc=True)
        df = df.sort_values('timestamp')

        # Keep only the recent window to bound the payload size.
        cutoff_time = pd.Timestamp.now(tz="UTC") - pd.Timedelta(hours=HISTORY_WINDOW_HOURS)
        df = df[df['timestamp'] >= cutoff_time]

        # Prepare data for Chart.js line chart: {x: timestamp, y: value}
        historical_data = {}
        for provider in PROVIDERS:
            provider_rows = df[df['provider'] == provider]
            historical_data[provider] = [
                # int() turns numpy integers into plain JSON-serialisable ints.
                {"x": ts.isoformat(), "y": int(value)}
                for ts, value in zip(
                    provider_rows['timestamp'], provider_rows['monthly_requests_int']
                )
            ]

        return {
            "historical_data": historical_data,
            "last_updated": datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
        }
    except Exception as e:
        logger.error(f"Error fetching historical data: {e}")
        return {"error": "Failed to fetch historical data", "data": []}


@app.post("/api/collect-now")
async def trigger_data_collection(background_tasks: BackgroundTasks):
    """Manual trigger for data collection"""
    background_tasks.add_task(collect_and_store_data)
    return {"message": "Data collection triggered", "timestamp": datetime.now().isoformat()}


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)