nbroad committed
Commit 70ed3ab · verified · 1 Parent(s): 5dd9ac2

Upload app.py with huggingface_hub

Files changed (1):
  1. app.py +100 -0

app.py CHANGED
@@ -69,6 +69,24 @@ PROVIDERS = [
     "nscale",
 ]
 
+# Mapping from display provider names to inference provider API names
+PROVIDER_TO_INFERENCE_NAME = {
+    "togethercomputer": "together",
+    "fal": "fal-ai",
+    "sambanovasystems": "sambanova",
+    "Hyperbolic": "hyperbolic",
+    "CohereLabs": "cohere",
+    # Other providers may not have inference provider support or use different names
+    "fireworks-ai": "fireworks-ai",
+    "nebius": "nebius",
+    "groq": "groq",
+    "cerebras": "cerebras",
+    "replicate": "replicate",
+    "novita": "novita",
+    "featherless-ai": "featherless-ai",
+    "nscale": "nscale",
+}
+
 templates = Jinja2Templates(directory="templates")
 
 async def get_monthly_requests(session: aiohttp.ClientSession, provider: str) -> Dict[str, str]:
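A note on the PROVIDER_TO_INFERENCE_NAME mapping added above: the keys are the display/organization names used in this app's PROVIDERS list, and the values are the slugs passed to the Hub's /api/models inference_provider filter. A minimal sketch of how it gets resolved (assuming app.py is importable from the working directory; the chosen provider name is just one of the mapped entries):

from app import PROVIDER_TO_INFERENCE_NAME

# "togethercomputer" is the display/org name; "together" is the API slug.
slug = PROVIDER_TO_INFERENCE_NAME.get("togethercomputer")
print(slug)  # -> together
print(f"https://huggingface.co/api/models?inference_provider={slug}&limit=50&sort=downloads&direction=-1")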
 
@@ -99,6 +117,33 @@ async def get_monthly_requests(session: aiohttp.ClientSession, provider: str) ->
         "monthly_requests_int": 0
     }
 
+async def get_provider_models(session: aiohttp.ClientSession, provider: str) -> List[str]:
+    """Get supported models for a provider from HuggingFace API"""
+    if not HF_TOKEN:
+        return []
+
+    # Map display provider name to inference provider API name
+    inference_provider = PROVIDER_TO_INFERENCE_NAME.get(provider)
+    if not inference_provider:
+        logger.warning(f"No inference provider mapping found for {provider}")
+        return []
+
+    url = f"https://huggingface.co/api/models?inference_provider={inference_provider}&limit=50&sort=downloads&direction=-1"
+    headers = {"Authorization": f"Bearer {HF_TOKEN}"}
+
+    try:
+        async with session.get(url, headers=headers) as response:
+            if response.status == 200:
+                models_data = await response.json()
+                model_ids = [model.get('id', '') for model in models_data if model.get('id')]
+                return model_ids
+            else:
+                logger.warning(f"Failed to fetch models for {provider} (inference_provider={inference_provider}): {response.status}")
+                return []
+    except Exception as e:
+        logger.error(f"Error fetching models for {provider} (inference_provider={inference_provider}): {e}")
+        return []
+
 async def collect_and_store_data():
     """Collect current data and store it in the dataset"""
     if not HF_TOKEN:
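For reference, a minimal sketch of exercising the new get_provider_models helper outside the FastAPI app (assuming app.py is importable and HF_TOKEN is set in the environment; "groq" is one of the mapped providers, and at most 50 model ids come back because of the limit=50 in the URL above):

import asyncio
import aiohttp
from app import get_provider_models

async def main():
    async with aiohttp.ClientSession() as session:
        # Returns a list of model ids served by the provider, or [] on failure.
        models = await get_provider_models(session, "groq")
    print(f"groq: {len(models)} models returned")

asyncio.run(main())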
 
@@ -309,6 +354,61 @@ async def get_historical_data():
         "message": "Historical data temporarily unavailable"
     }
 
+@app.get("/api/models")
+async def get_provider_models_data():
+    """API endpoint to get supported models matrix for all providers"""
+    if not HF_TOKEN:
+        return {"error": "HF_TOKEN required for models data", "matrix": [], "providers": PROVIDERS}
+
+    async with aiohttp.ClientSession() as session:
+        tasks = [get_provider_models(session, provider) for provider in PROVIDERS]
+        results = await asyncio.gather(*tasks)
+
+    # Create provider -> models mapping
+    provider_models = {}
+    all_models = set()
+
+    for provider, models in zip(PROVIDERS, results):
+        provider_models[provider] = set(models)
+        all_models.update(models)
+
+    # Convert to list and sort by popularity (number of providers supporting each model)
+    model_popularity = []
+    for model in all_models:
+        provider_count = sum(1 for provider in PROVIDERS if model in provider_models.get(provider, set()))
+        model_popularity.append((model, provider_count))
+
+    # Sort by popularity (descending) then by model name
+    model_popularity.sort(key=lambda x: (-x[1], x[0]))
+
+    # Build matrix data
+    matrix = []
+    for model_id, popularity in model_popularity:
+        row = {
+            "model_id": model_id,
+            "total_providers": popularity,
+            "providers": {}
+        }
+
+        for provider in PROVIDERS:
+            row["providers"][provider] = model_id in provider_models.get(provider, set())
+
+        matrix.append(row)
+
+    # Calculate totals per provider
+    provider_totals = {}
+    for provider in PROVIDERS:
+        provider_totals[provider] = len(provider_models.get(provider, set()))
+
+    return {
+        "matrix": matrix,
+        "providers": PROVIDERS,
+        "provider_totals": provider_totals,
+        "provider_mapping": PROVIDER_TO_INFERENCE_NAME,
+        "total_models": len(all_models),
+        "last_updated": datetime.now().strftime('%Y-%m-%d %H:%M:%S')
+    }
+
 @app.post("/api/collect-now")
 async def trigger_data_collection(background_tasks: BackgroundTasks):
     """Manual trigger for data collection"""