nbroad committed on
Commit 53eacf5 · verified · 1 Parent(s): e41e908

Upload app.py with huggingface_hub
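For reference, an upload like this one is usually made with the huggingface_hub client. The sketch below is only a guess at the kind of call involved, not the exact command behind this commit; the repo id and token handling are placeholders.

import os
from huggingface_hub import HfApi

# Hypothetical target repo; the actual Space id is not shown on this page.
api = HfApi(token=os.getenv("HF_TOKEN"))
api.upload_file(
    path_or_fileobj="app.py",              # local file to push
    path_in_repo="app.py",                 # destination path inside the repo
    repo_id="nbroad/inference-providers",  # placeholder repo id
    repo_type="space",
    commit_message="Upload app.py with huggingface_hub",
)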

Files changed (1)
  1. app.py +242 -0
app.py ADDED
@@ -0,0 +1,242 @@
+ from fastapi import FastAPI, Request, BackgroundTasks
+ from fastapi.responses import HTMLResponse
+ from fastapi.staticfiles import StaticFiles
+ from fastapi.templating import Jinja2Templates
+ import requests
+ from bs4 import BeautifulSoup
+ import asyncio
+ import aiohttp
+ from datetime import datetime, timezone
+ from typing import List, Dict, Optional
+ import uvicorn
+ import os
+ import pandas as pd
+ from datasets import Dataset, load_dataset
+ from huggingface_hub import HfApi
+ import logging
+ from contextlib import asynccontextmanager
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ # Global variables for dataset management
+ DATASET_REPO_NAME = os.getenv("DATASET_REPO_NAME", "nbroad/hf-inference-providers-data")
+ HF_TOKEN = os.getenv("HF_TOKEN")
+
+ # Time to wait between data collection runs in seconds
+ DATA_COLLECTION_INTERVAL = 1800
+
+ # Background task state
+ data_collection_task = None
+
+ @asynccontextmanager
+ async def lifespan(app: FastAPI):
+     """Manage application lifecycle"""
+     # Start background task
+     global data_collection_task
+     data_collection_task = asyncio.create_task(timed_data_collection())
+     logger.info("Started periodic data collection task")
+     yield
+     # Cleanup
+     if data_collection_task:
+         data_collection_task.cancel()
+         logger.info("Stopped periodic data collection task")
+
+ app = FastAPI(title="Inference Provider Dashboard", lifespan=lifespan)
+
+ # List of providers to track
+ PROVIDERS = [
+     "togethercomputer",
+     "fireworks-ai",
+     "nebius",
+     "fal",
+     "groq",
+     "cerebras",
+     "sambanovasystems",
+     "replicate",
+     "novita",
+     "Hyperbolic",
+     "featherless-ai",
+     "CohereLabs",
+     "nscale",
+ ]
+
+ templates = Jinja2Templates(directory="templates")
+
+ async def get_monthly_requests(session: aiohttp.ClientSession, provider: str) -> Dict[str, str]:
+     """Get monthly requests for a provider from HuggingFace"""
+     url = f"https://huggingface.co/{provider}"
+     try:
+         async with session.get(url) as response:
+             html = await response.text()
+             soup = BeautifulSoup(html, 'html.parser')
+             request_div = soup.find('div', text=lambda t: t and 'monthly requests' in t.lower())
+             if request_div:
+                 requests_text = request_div.text.split()[0].replace(',', '')
+                 return {
+                     "provider": provider,
+                     "monthly_requests": requests_text,
+                     "monthly_requests_int": int(requests_text) if requests_text.isdigit() else 0
+                 }
+             return {
+                 "provider": provider,
+                 "monthly_requests": "N/A",
+                 "monthly_requests_int": 0
+             }
+     except Exception as e:
+         logger.error(f"Error fetching {provider}: {e}")
+         return {
+             "provider": provider,
+             "monthly_requests": "N/A",
+             "monthly_requests_int": 0
+         }
+
+ async def collect_and_store_data():
+     """Collect current data and store it in the dataset"""
+     if not HF_TOKEN:
+         logger.warning("No HF_TOKEN found, skipping data storage")
+         return
+
+     try:
+         logger.info("Collecting data for storage...")
+
+         # Collect current data
+         async with aiohttp.ClientSession() as session:
+             tasks = [get_monthly_requests(session, provider) for provider in PROVIDERS]
+             results = await asyncio.gather(*tasks)
+
+         # Create DataFrame with timestamp
+         timestamp = datetime.now(timezone.utc).isoformat()
+         data_rows = []
+
+         for result in results:
+             data_rows.append({
+                 "timestamp": timestamp,
+                 "provider": result["provider"],
+                 "monthly_requests": result["monthly_requests"],
+                 "monthly_requests_int": result["monthly_requests_int"]
+             })
+
+         new_df = pd.DataFrame(data_rows)
+
+         # Try to load existing dataset and append
+         try:
+             existing_dataset = load_dataset(DATASET_REPO_NAME, split="train")
+             existing_df = existing_dataset.to_pandas()
+             combined_df = pd.concat([existing_df, new_df], ignore_index=True)
+         except Exception as e:
+             logger.info(f"Creating new dataset (existing not found): {e}")
+             combined_df = new_df
+
+         # Convert back to dataset and push
+         new_dataset = Dataset.from_pandas(combined_df)
+         new_dataset.push_to_hub(DATASET_REPO_NAME, token=HF_TOKEN, private=False)
+
+         logger.info(f"Successfully stored data for {len(results)} providers")
+
+     except Exception as e:
+         logger.error(f"Error collecting and storing data: {e}")
+
+ async def timed_data_collection():
+     """Background task that runs every DATA_COLLECTION_INTERVAL seconds to collect data"""
+     while True:
+         try:
+             await collect_and_store_data()
+             await asyncio.sleep(DATA_COLLECTION_INTERVAL)
+         except asyncio.CancelledError:
+             logger.info("Data collection task cancelled")
+             break
+         except Exception as e:
+             logger.error(f"Error in periodic data collection: {e}")
+             # Wait 5 minutes before retrying on error
+             await asyncio.sleep(300)
+
+ @app.get("/")
+ async def dashboard(request: Request):
+     """Serve the main dashboard page"""
+     return templates.TemplateResponse("dashboard.html", {"request": request})
+
+ @app.get("/api/providers")
+ async def get_providers_data():
+     """API endpoint to get provider data"""
+     async with aiohttp.ClientSession() as session:
+         tasks = [get_monthly_requests(session, provider) for provider in PROVIDERS]
+         results = await asyncio.gather(*tasks)
+
+     # Sort by request count descending
+     results.sort(key=lambda x: x["monthly_requests_int"], reverse=True)
+
+     return {
+         "providers": results,
+         "last_updated": datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
+         "total_providers": len(results)
+     }
+
+ @app.get("/api/providers/{provider}")
+ async def get_provider_data(provider: str):
+     """API endpoint to get data for a specific provider"""
+     if provider not in PROVIDERS:
+         return {"error": "Provider not found"}
+
+     async with aiohttp.ClientSession() as session:
+         result = await get_monthly_requests(session, provider)
+
+     return {
+         "provider_data": result,
+         "last_updated": datetime.now().strftime('%Y-%m-%d %H:%M:%S')
+     }
+
+ @app.get("/api/historical")
+ async def get_historical_data():
+     """API endpoint to get historical data for line chart"""
+     if not HF_TOKEN:
+         return {"error": "Historical data not available", "data": []}
+
+     try:
+         # Load historical dataset
+         dataset = load_dataset(DATASET_REPO_NAME, split="train")
+         df = dataset.to_pandas()
+
+         # Group by timestamp and provider, get the latest entry for each timestamp-provider combo
+         df['timestamp'] = pd.to_datetime(df['timestamp'])
+         df = df.sort_values('timestamp')
+
+         # Get last 48 hours of data (48 data points max for performance)
+         cutoff_time = datetime.now(timezone.utc) - pd.Timedelta(hours=48)
+         df = df[df['timestamp'] >= cutoff_time]
+
+         # Prepare data for Chart.js line chart
+         historical_data = {}
+
+         for provider in PROVIDERS:
+             provider_data = df[df['provider'] == provider].copy()
+             if not provider_data.empty:
+                 # Format for Chart.js: {x: timestamp, y: value}
+                 historical_data[provider] = [
+                     {
+                         "x": row['timestamp'].isoformat(),
+                         "y": row['monthly_requests_int']
+                     }
+                     for _, row in provider_data.iterrows()
+                 ]
+             else:
+                 historical_data[provider] = []
+
+         return {
+             "historical_data": historical_data,
+             "last_updated": datetime.now().strftime('%Y-%m-%d %H:%M:%S')
+         }
+
+     except Exception as e:
+         logger.error(f"Error fetching historical data: {e}")
+         return {"error": "Failed to fetch historical data", "data": []}
+
+ @app.post("/api/collect-now")
+ async def trigger_data_collection(background_tasks: BackgroundTasks):
+     """Manual trigger for data collection"""
+     background_tasks.add_task(collect_and_store_data)
+     return {"message": "Data collection triggered", "timestamp": datetime.now().isoformat()}
+
+ if __name__ == "__main__":
+     uvicorn.run(app, host="0.0.0.0", port=7860)
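Once the app is running (locally via uvicorn or as a deployed Space), the JSON endpoints defined above can be exercised directly. A minimal sketch, assuming the server is reachable at http://localhost:7860; the base URL is an assumption, the paths come from the code above.

import requests

BASE = "http://localhost:7860"  # local uvicorn instance; adjust for a deployed Space

# Current counts for all tracked providers, sorted by monthly requests
print(requests.get(f"{BASE}/api/providers").json())

# Up to 48 hours of stored history (empty unless HF_TOKEN was configured)
print(requests.get(f"{BASE}/api/historical").json())

# Kick off an immediate collection run in the background
print(requests.post(f"{BASE}/api/collect-now").json())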