from fastapi import FastAPI, Request, BackgroundTasks
from fastapi.templating import Jinja2Templates
from bs4 import BeautifulSoup
import asyncio
import aiohttp
from datetime import datetime, timezone
from typing import Any, Dict, List
import uvicorn
import os
import pandas as pd
from datasets import Dataset, load_dataset
import logging
from contextlib import asynccontextmanager

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Hugging Face Spaces set SPACE_REPO_NAME; use it to detect where we run.
IN_SPACE = os.getenv("SPACE_REPO_NAME") is not None

# Locally, pull HF_TOKEN and friends from a .env file.
if not IN_SPACE:
    from dotenv import load_dotenv
    load_dotenv()

DATASET_REPO_NAME = os.getenv("DATASET_REPO_NAME", "nbroad/hf-inference-providers-data")
HF_TOKEN = os.getenv("HF_TOKEN")

# Collection interval in seconds (30 minutes).
DATA_COLLECTION_INTERVAL = 1800

# Handle to the background collection task so it can be cancelled on shutdown.
data_collection_task = None

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Start the periodic data collection task on startup; cancel it on shutdown."""
    global data_collection_task
    data_collection_task = asyncio.create_task(timed_data_collection())
    logger.info("Started periodic data collection task")
    yield

    if data_collection_task:
        data_collection_task.cancel()
        logger.info("Stopped periodic data collection task")


app = FastAPI(title="Inference Provider Dashboard", lifespan=lifespan)

# Hugging Face organization names for each inference provider.
PROVIDERS = [
    "togethercomputer",
    "fireworks-ai",
    "nebius",
    "fal",
    "groq",
    "cerebras",
    "sambanovasystems",
    "replicate",
    "novita",
    "Hyperbolic",
    "featherless-ai",
    "CohereLabs",
    "nscale",
]

# Map organization names to the provider names used by the models API's
# `inference_provider` filter.
PROVIDER_TO_INFERENCE_NAME = {
    # Names that differ between the organization and the provider:
    "togethercomputer": "together",
    "fal": "fal-ai",
    "sambanovasystems": "sambanova",
    "Hyperbolic": "hyperbolic",
    "CohereLabs": "cohere",
    # Names that are identical:
    "fireworks-ai": "fireworks-ai",
    "nebius": "nebius",
    "groq": "groq",
    "cerebras": "cerebras",
    "replicate": "replicate",
    "novita": "novita",
    "featherless-ai": "featherless-ai",
    "nscale": "nscale",
}

templates = Jinja2Templates(directory="templates")
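
# There is no official API for the "monthly requests" figure, so the function
# below scrapes each provider's public profile page (https://huggingface.co/<org>).
# Assumption: the page renders a div whose text looks like
# "1,234,567 monthly requests"; if that markup changes or the badge is
# missing, parsing falls back to "N/A".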
async def get_monthly_requests(session: aiohttp.ClientSession, provider: str) -> Dict[str, Any]:
    """Get monthly requests for a provider from HuggingFace"""
    url = f"https://huggingface.co/{provider}"
    try:
        async with session.get(url) as response:
            html = await response.text()
            soup = BeautifulSoup(html, 'html.parser')
            # `string=` is the modern spelling of BeautifulSoup's deprecated `text=` argument.
            request_div = soup.find('div', string=lambda t: t and 'monthly requests' in t.lower())
            if request_div:
                # e.g. "1,234,567 monthly requests" -> "1234567"
                requests_text = request_div.text.split()[0].replace(',', '')
                return {
                    "provider": provider,
                    "monthly_requests": requests_text,
                    "monthly_requests_int": int(requests_text) if requests_text.isdigit() else 0
                }
            return {
                "provider": provider,
                "monthly_requests": "N/A",
                "monthly_requests_int": 0
            }
    except Exception as e:
        logger.error(f"Error fetching {provider}: {e}")
        return {
            "provider": provider,
            "monthly_requests": "N/A",
            "monthly_requests_int": 0
        }
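
# The models API supports filtering by provider, e.g.
#   GET https://huggingface.co/api/models?inference_provider=groq&limit=50
# The function below fetches each provider's top models by downloads. Note the
# limit: the matrix built from these results covers at most 50 models per
# provider, not necessarily a provider's full catalog.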
async def get_provider_models(session: aiohttp.ClientSession, provider: str) -> List[str]:
    """Get supported models for a provider from HuggingFace API"""
    if not HF_TOKEN:
        return []

    inference_provider = PROVIDER_TO_INFERENCE_NAME.get(provider)
    if not inference_provider:
        logger.warning(f"No inference provider mapping found for {provider}")
        return []

    url = f"https://huggingface.co/api/models?inference_provider={inference_provider}&limit=50&sort=downloads&direction=-1"
    headers = {"Authorization": f"Bearer {HF_TOKEN}"}

    try:
        async with session.get(url, headers=headers) as response:
            if response.status == 200:
                models_data = await response.json()
                model_ids = [model.get('id', '') for model in models_data if model.get('id')]
                return model_ids
            else:
                logger.warning(f"Failed to fetch models for {provider} (inference_provider={inference_provider}): {response.status}")
                return []
    except Exception as e:
        logger.error(f"Error fetching models for {provider} (inference_provider={inference_provider}): {e}")
        return []
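
# Storage model: every collection cycle appends one row per provider to a
# Hugging Face dataset, then de-duplicates on (provider, request count) so
# repeated scrapes of an unchanged count keep only the earliest timestamp.
# The whole (small) dataset is rewritten on each push.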
async def collect_and_store_data():
    """Collect current data and store it in the dataset"""
    if not HF_TOKEN:
        logger.warning("No HF_TOKEN found, skipping data storage")
        return

    try:
        logger.info("Collecting data for storage...")

        async with aiohttp.ClientSession() as session:
            tasks = [get_monthly_requests(session, provider) for provider in PROVIDERS]
            results = await asyncio.gather(*tasks)

        timestamp = datetime.now(timezone.utc).isoformat()
        data_rows = []

        for result in results:
            data_rows.append({
                "timestamp": timestamp,
                "provider": result["provider"],
                "monthly_requests": result["monthly_requests"],
                "monthly_requests_int": result["monthly_requests_int"]
            })

        new_df = pd.DataFrame(data_rows)

        # Append to the existing dataset if there is one; otherwise start fresh.
        try:
            existing_dataset = load_dataset(DATASET_REPO_NAME, split="train")
            existing_df = existing_dataset.to_pandas()
            combined_df = pd.concat([existing_df, new_df], ignore_index=True)
        except Exception as e:
            logger.info(f"Creating new dataset (existing not found): {e}")
            combined_df = new_df

        combined_df['timestamp'] = pd.to_datetime(combined_df['timestamp'])
        combined_df = combined_df.sort_values('timestamp')

        # Keep the earliest row for each (provider, request count) pair.
        deduplicated_df = combined_df.groupby(['provider', 'monthly_requests_int']).first().reset_index()

        # Back to ISO-8601 strings so the dataset stores a plain string column.
        deduplicated_df['timestamp'] = deduplicated_df['timestamp'].dt.strftime('%Y-%m-%dT%H:%M:%S.%f%z')

        logger.info(f"De-duplicated dataset: {len(combined_df)} -> {len(deduplicated_df)} records")

        new_dataset = Dataset.from_pandas(deduplicated_df)
        new_dataset.push_to_hub(DATASET_REPO_NAME, token=HF_TOKEN, private=False)

        logger.info(f"Successfully stored data for {len(results)} providers")

    except Exception as e:
        logger.error(f"Error collecting and storing data: {e}")
async def timed_data_collection():
    """Background task that runs every DATA_COLLECTION_INTERVAL seconds to collect data"""
    while True:
        try:
            await collect_and_store_data()
            await asyncio.sleep(DATA_COLLECTION_INTERVAL)
        except asyncio.CancelledError:
            logger.info("Data collection task cancelled")
            break
        except Exception as e:
            logger.error(f"Error in periodic data collection: {e}")
            # Wait 5 minutes before retrying after an unexpected error.
            await asyncio.sleep(300)

@app.get("/")
async def dashboard(request: Request):
    """Serve the main dashboard page"""
    return templates.TemplateResponse("dashboard.html", {"request": request})

@app.get("/api/providers")
async def get_providers_data():
    """API endpoint to get provider data"""
    async with aiohttp.ClientSession() as session:
        tasks = [get_monthly_requests(session, provider) for provider in PROVIDERS]
        results = await asyncio.gather(*tasks)

    # Busiest providers first.
    results.sort(key=lambda x: x["monthly_requests_int"], reverse=True)

    return {
        "providers": results,
        "last_updated": datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
        "total_providers": len(results)
    }

@app.get("/api/providers/{provider}")
async def get_provider_data(provider: str):
    """API endpoint to get data for a specific provider"""
    if provider not in PROVIDERS:
        return {"error": "Provider not found"}

    async with aiohttp.ClientSession() as session:
        result = await get_monthly_requests(session, provider)

    return {
        "provider_data": result,
        "last_updated": datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    }
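
# Historical endpoint: loads the whole dataset and returns, per provider, a
# list of {"x": ISO timestamp, "y": count} points, a shape chart libraries
# such as Chart.js accept directly (the front-end lives in
# templates/dashboard.html and is not shown here).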
@app.get("/api/historical")
async def get_historical_data():
    """API endpoint to get historical data for line chart"""
    if not HF_TOKEN:
        logger.warning("No HF_TOKEN available for historical data")
        return {
            "error": "Historical data not available - no HF token",
            "historical_data": {},
            "message": "Historical data collection requires HuggingFace token"
        }

    try:
        dataset = load_dataset(DATASET_REPO_NAME, split="train")
        df = dataset.to_pandas()

        logger.info(f"Loaded dataset with {len(df)} total records")

        if df.empty:
            logger.info("Dataset is empty - no historical data available yet")
            return {
                "historical_data": {},
                "message": "No historical data available yet. Data collection is running - check back in 30 minutes.",
                "last_updated": datetime.now().strftime('%Y-%m-%d %H:%M:%S')
            }

        df['timestamp'] = pd.to_datetime(df['timestamp'])
        df = df.sort_values('timestamp')

        df_filtered = df.copy()
        logger.info(f"Using all {len(df_filtered)} records for full historical view")

        # Downsample for chart performance: keep at most ~500 evenly spaced
        # points per provider.
        max_points_per_provider = 500
        if len(df_filtered) > max_points_per_provider * len(PROVIDERS):
            df_filtered = df_filtered.groupby('provider').apply(
                lambda x: x.iloc[::max(1, len(x) // max_points_per_provider)]
            ).reset_index(drop=True)
            logger.info(f"Sampled down to {len(df_filtered)} records for performance")

        historical_data = {}
        total_data_points = 0

        for provider in PROVIDERS:
            provider_data = df_filtered[df_filtered['provider'] == provider].copy()
            if not provider_data.empty:
                historical_data[provider] = [
                    {
                        "x": row['timestamp'].isoformat(),
                        "y": row['monthly_requests_int']
                    }
                    for _, row in provider_data.iterrows()
                ]
                total_data_points += len(historical_data[provider])
            else:
                historical_data[provider] = []

        logger.info(f"Returning {total_data_points} total data points across {len([p for p in historical_data.values() if p])} providers")

        if not df_filtered.empty:
            earliest_date = df_filtered['timestamp'].min().strftime('%Y-%m-%d %H:%M')
            latest_date = df_filtered['timestamp'].max().strftime('%Y-%m-%d %H:%M')
            date_range = f"From {earliest_date} to {latest_date}"
        else:
            date_range = "No data"

        return {
            "historical_data": historical_data,
            "last_updated": datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
            "total_data_points": total_data_points,
            "data_range": date_range,
            "earliest_date": df_filtered['timestamp'].min().isoformat() if not df_filtered.empty else None,
            "latest_date": df_filtered['timestamp'].max().isoformat() if not df_filtered.empty else None
        }

    except Exception as e:
        logger.error(f"Error fetching historical data: {e}")

        # A missing dataset is expected on first run: create it and tell the
        # front-end to check back later.
        if "does not exist" in str(e).lower() or "not found" in str(e).lower():
            logger.info("Dataset doesn't exist yet, triggering initial data collection")
            try:
                await collect_and_store_data()
                return {
                    "historical_data": {},
                    "message": "Dataset created! Historical data will appear after a few data collection cycles.",
                    "last_updated": datetime.now().strftime('%Y-%m-%d %H:%M:%S')
                }
            except Exception as create_error:
                logger.error(f"Failed to create initial dataset: {create_error}")

        return {
            "error": f"Failed to fetch historical data: {str(e)}",
            "historical_data": {},
            "message": "Historical data temporarily unavailable"
        }
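
# Model-coverage matrix: one row per model, one boolean cell per provider,
# rows sorted by how many providers serve the model (ties alphabetical).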
@app.get("/api/models")
async def get_provider_models_data():
    """API endpoint to get supported models matrix for all providers"""
    if not HF_TOKEN:
        return {"error": "HF_TOKEN required for models data", "matrix": [], "providers": PROVIDERS}

    async with aiohttp.ClientSession() as session:
        tasks = [get_provider_models(session, provider) for provider in PROVIDERS]
        results = await asyncio.gather(*tasks)

    provider_models = {}
    all_models = set()

    for provider, models in zip(PROVIDERS, results):
        provider_models[provider] = set(models)
        all_models.update(models)

    # Count how many providers serve each model.
    model_popularity = []
    for model in all_models:
        provider_count = sum(1 for provider in PROVIDERS if model in provider_models.get(provider, set()))
        model_popularity.append((model, provider_count))

    # Most widely served first, then alphabetical.
    model_popularity.sort(key=lambda x: (-x[1], x[0]))

    matrix = []
    for model_id, popularity in model_popularity:
        row = {
            "model_id": model_id,
            "total_providers": popularity,
            "providers": {}
        }
        for provider in PROVIDERS:
            row["providers"][provider] = model_id in provider_models.get(provider, set())
        matrix.append(row)

    provider_totals = {}
    for provider in PROVIDERS:
        provider_totals[provider] = len(provider_models.get(provider, set()))

    return {
        "matrix": matrix,
        "providers": PROVIDERS,
        "provider_totals": provider_totals,
        "provider_mapping": PROVIDER_TO_INFERENCE_NAME,
        "total_models": len(all_models),
        "last_updated": datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    }
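
# Manual trigger: schedules a collection cycle via FastAPI's BackgroundTasks,
# so the request returns immediately and the scrape runs after the response.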
@app.post("/api/collect-now")
async def trigger_data_collection(background_tasks: BackgroundTasks):
    """Manual trigger for data collection"""
    background_tasks.add_task(collect_and_store_data)
    return {"message": "Data collection triggered", "timestamp": datetime.now().isoformat()}


if __name__ == "__main__":
    # 7860 is the port Hugging Face Spaces expect an app to listen on.
    uvicorn.run(app, host="0.0.0.0", port=7860)