from fastapi import FastAPI, Request, BackgroundTasks
from fastapi.templating import Jinja2Templates
from bs4 import BeautifulSoup
import asyncio
import aiohttp
from datetime import datetime, timezone
from typing import Dict, Union
import uvicorn
import os
import pandas as pd
from datasets import Dataset, load_dataset
import logging
from contextlib import asynccontextmanager
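
# Dashboard app: scrapes each provider's Hugging Face organization page for its
# public "monthly requests" counter and periodically snapshots the numbers to a
# Hub dataset so the frontend can chart them over time.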

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Hub dataset that accumulates the historical snapshots.
DATASET_REPO_NAME = os.getenv("DATASET_REPO_NAME", "nbroad/hf-inference-providers-data")
# Token with write access to DATASET_REPO_NAME; without it, snapshots are skipped.
HF_TOKEN = os.getenv("HF_TOKEN")

# How often the background task collects data, in seconds (30 minutes).
DATA_COLLECTION_INTERVAL = 1800

# Handle to the background task so it can be cancelled cleanly on shutdown.
data_collection_task = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Start the periodic data collection task on startup and stop it on shutdown."""
    global data_collection_task
    data_collection_task = asyncio.create_task(timed_data_collection())
    logger.info("Started periodic data collection task")
    yield

    if data_collection_task:
        data_collection_task.cancel()
        logger.info("Stopped periodic data collection task")


app = FastAPI(title="Inference Provider Dashboard", lifespan=lifespan)

# Hugging Face organization names of the tracked inference providers.
PROVIDERS = [
    "togethercomputer",
    "fireworks-ai",
    "nebius",
    "fal",
    "groq",
    "cerebras",
    "sambanovasystems",
    "replicate",
    "novita",
    "Hyperbolic",
    "featherless-ai",
    "CohereLabs",
    "nscale",
]

templates = Jinja2Templates(directory="templates")
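# The "/" route below renders templates/dashboard.html with this engine; the
# template itself lives outside this file.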


async def get_monthly_requests(session: aiohttp.ClientSession, provider: str) -> Dict[str, Union[str, int]]:
    """Scrape a provider's Hugging Face page for its monthly request count."""
    url = f"https://huggingface.co/{provider}"
    try:
        async with session.get(url) as response:
            html = await response.text()
            soup = BeautifulSoup(html, 'html.parser')
            # BeautifulSoup deprecated the `text` keyword in favor of `string`.
            request_div = soup.find('div', string=lambda t: t and 'monthly requests' in t.lower())
            if request_div:
                requests_text = request_div.text.split()[0].replace(',', '')
                return {
                    "provider": provider,
                    "monthly_requests": requests_text,
                    "monthly_requests_int": int(requests_text) if requests_text.isdigit() else 0
                }
            return {
                "provider": provider,
                "monthly_requests": "N/A",
                "monthly_requests_int": 0
            }
    except Exception as e:
        logger.error(f"Error fetching {provider}: {e}")
        return {
            "provider": provider,
            "monthly_requests": "N/A",
            "monthly_requests_int": 0
        }
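
# Example result shape (values illustrative, not real data):
#   {"provider": "groq", "monthly_requests": "1234567", "monthly_requests_int": 1234567}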


async def collect_and_store_data():
    """Collect current data and append it to the Hub dataset."""
    if not HF_TOKEN:
        logger.warning("No HF_TOKEN found, skipping data storage")
        return

    try:
        logger.info("Collecting data for storage...")

        # Fetch all providers concurrently.
        async with aiohttp.ClientSession() as session:
            tasks = [get_monthly_requests(session, provider) for provider in PROVIDERS]
            results = await asyncio.gather(*tasks)

        # One row per provider, all sharing the same UTC timestamp.
        timestamp = datetime.now(timezone.utc).isoformat()
        data_rows = []
        for result in results:
            data_rows.append({
                "timestamp": timestamp,
                "provider": result["provider"],
                "monthly_requests": result["monthly_requests"],
                "monthly_requests_int": result["monthly_requests_int"]
            })

        new_df = pd.DataFrame(data_rows)

        # Append to the existing history, or start a new dataset on first run.
        try:
            existing_dataset = load_dataset(DATASET_REPO_NAME, split="train")
            existing_df = existing_dataset.to_pandas()
            combined_df = pd.concat([existing_df, new_df], ignore_index=True)
        except Exception as e:
            logger.info(f"Creating new dataset (existing not found): {e}")
            combined_df = new_df

        new_dataset = Dataset.from_pandas(combined_df)
        new_dataset.push_to_hub(DATASET_REPO_NAME, token=HF_TOKEN, private=False)

        logger.info(f"Successfully stored data for {len(results)} providers")

    except Exception as e:
        logger.error(f"Error collecting and storing data: {e}")
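
# Design note: each run re-downloads and re-uploads the full history, which keeps
# the dataset a single clean split but grows linearly in transfer cost. If the
# history gets large, uploading per-run shards (e.g. via huggingface_hub.HfApi)
# would be a cheaper alternative.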


async def timed_data_collection():
    """Background task that collects data every DATA_COLLECTION_INTERVAL seconds."""
    while True:
        try:
            await collect_and_store_data()
            await asyncio.sleep(DATA_COLLECTION_INTERVAL)
        except asyncio.CancelledError:
            logger.info("Data collection task cancelled")
            break
        except Exception as e:
            logger.error(f"Error in periodic data collection: {e}")
            # Back off for five minutes before retrying after an unexpected error.
            await asyncio.sleep(300)


@app.get("/")
async def dashboard(request: Request):
    """Serve the main dashboard page"""
    return templates.TemplateResponse("dashboard.html", {"request": request})


@app.get("/api/providers")
async def get_providers_data():
    """API endpoint to get live data for all providers"""
    async with aiohttp.ClientSession() as session:
        tasks = [get_monthly_requests(session, provider) for provider in PROVIDERS]
        results = await asyncio.gather(*tasks)

    # Busiest providers first.
    results.sort(key=lambda x: x["monthly_requests_int"], reverse=True)

    return {
        "providers": results,
        "last_updated": datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
        "total_providers": len(results)
    }
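
# Example /api/providers response (illustrative):
#   {"providers": [{"provider": "groq", "monthly_requests": "1234567",
#                   "monthly_requests_int": 1234567}, ...],
#    "last_updated": "2025-01-01 12:00:00", "total_providers": 13}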


@app.get("/api/providers/{provider}")
async def get_provider_data(provider: str):
    """API endpoint to get data for a specific provider"""
    if provider not in PROVIDERS:
        return {"error": "Provider not found"}

    async with aiohttp.ClientSession() as session:
        result = await get_monthly_requests(session, provider)

    return {
        "provider_data": result,
        "last_updated": datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    }


@app.get("/api/historical")
async def get_historical_data():
    """API endpoint to get historical data for the line chart"""
    if not HF_TOKEN:
        return {"error": "Historical data not available", "data": []}

    try:
        dataset = load_dataset(DATASET_REPO_NAME, split="train")
        df = dataset.to_pandas()

        # Parse the stored ISO timestamps as UTC so they compare cleanly below.
        df['timestamp'] = pd.to_datetime(df['timestamp'], utc=True)
        df = df.sort_values('timestamp')

        # Keep only the last 48 hours of snapshots.
        cutoff_time = datetime.now(timezone.utc) - pd.Timedelta(hours=48)
        df = df[df['timestamp'] >= cutoff_time]

        # One list of {x, y} points per provider.
        historical_data = {}
        for provider in PROVIDERS:
            provider_data = df[df['provider'] == provider].copy()
            if not provider_data.empty:
                historical_data[provider] = [
                    {
                        "x": row['timestamp'].isoformat(),
                        "y": row['monthly_requests_int']
                    }
                    for _, row in provider_data.iterrows()
                ]
            else:
                historical_data[provider] = []

        return {
            "historical_data": historical_data,
            "last_updated": datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        }

    except Exception as e:
        logger.error(f"Error fetching historical data: {e}")
        return {"error": "Failed to fetch historical data", "data": []}
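
# The {x, y} point lists match the time-series format that charting libraries
# such as Chart.js consume directly; the dashboard template is assumed to do so.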


@app.post("/api/collect-now")
async def trigger_data_collection(background_tasks: BackgroundTasks):
    """Manually trigger a data collection run"""
    background_tasks.add_task(collect_and_store_data)
    return {"message": "Data collection triggered", "timestamp": datetime.now().isoformat()}


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
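
# To run locally (filename assumed to be app.py; port 7860 matches the
# Hugging Face Spaces convention):
#   HF_TOKEN=<your token> python app.py
# or, for development with auto-reload:
#   uvicorn app:app --reload --port 7860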