Update app.py
Browse files
app.py
CHANGED
@@ -239,18 +239,19 @@ def load_models_csv():
|
|
239 |
|
240 |
# Create Gradio interface
|
241 |
with gr.Blocks() as demo:
|
242 |
-
models_data = gr.State()
|
243 |
-
|
|
|
244 |
with gr.Row():
|
245 |
gr.Markdown("""
|
246 |
# HuggingFace Models TreeMap Visualization
|
247 |
-
|
248 |
This app shows how different organizations contribute to the HuggingFace ecosystem with their models.
|
249 |
Use the filters to explore models by different metrics, tags, pipelines, and model sizes.
|
250 |
-
|
251 |
-
The treemap visualizes models grouped by organization, with the size of each box representing the selected metric
|
252 |
""")
|
253 |
-
|
254 |
with gr.Row():
|
255 |
with gr.Column(scale=1):
|
256 |
count_by_dropdown = gr.Dropdown(
|
@@ -263,14 +264,14 @@ with gr.Blocks() as demo:
|
|
263 |
value="downloads",
|
264 |
info="Select the metric to determine box sizes"
|
265 |
)
|
266 |
-
|
267 |
filter_choice_radio = gr.Radio(
|
268 |
label="Filter Type",
|
269 |
choices=["None", "Tag Filter", "Pipeline Filter"],
|
270 |
value="None",
|
271 |
info="Choose how to filter the models"
|
272 |
)
|
273 |
-
|
274 |
tag_filter_dropdown = gr.Dropdown(
|
275 |
label="Select Tag",
|
276 |
choices=list(TAG_FILTER_FUNCS.keys()),
|
@@ -278,7 +279,7 @@ with gr.Blocks() as demo:
|
|
278 |
visible=False,
|
279 |
info="Filter models by domain/category"
|
280 |
)
|
281 |
-
|
282 |
pipeline_filter_dropdown = gr.Dropdown(
|
283 |
label="Select Pipeline Tag",
|
284 |
choices=PIPELINE_TAGS,
|
@@ -286,7 +287,7 @@ with gr.Blocks() as demo:
|
|
286 |
visible=False,
|
287 |
info="Filter models by specific pipeline"
|
288 |
)
|
289 |
-
|
290 |
size_filter_dropdown = gr.Dropdown(
|
291 |
label="Model Size Filter",
|
292 |
choices=["None"] + list(MODEL_SIZE_RANGES.keys()),
|
@@ -302,26 +303,63 @@ with gr.Blocks() as demo:
|
|
302 |
step=5,
|
303 |
info="Number of top organizations to include"
|
304 |
)
|
305 |
-
|
306 |
skip_orgs_textbox = gr.Textbox(
|
307 |
label="Organizations to Skip (comma-separated)",
|
308 |
placeholder="e.g., OpenAI, Google",
|
309 |
-
value="TheBloke, MaziyarPanahi, unsloth, modularai, Gensyn, bartowski"
|
310 |
-
info="Enter names of organizations to exclude from the visualization"
|
311 |
)
|
312 |
|
313 |
-
generate_plot_button = gr.Button("Generate Plot", variant="primary")
|
314 |
|
315 |
with gr.Column(scale=3):
|
316 |
plot_output = gr.Plot()
|
317 |
stats_output = gr.Markdown("*Generate a plot to see statistics*")
|
318 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
319 |
def generate_plot_on_click(count_by, filter_choice, tag_filter, pipeline_filter, size_filter, top_k, skip_orgs_text, data_df):
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
-
return None, "Error: No data available. Please try again."
|
324 |
-
|
325 |
selected_tag_filter = None
|
326 |
selected_pipeline_filter = None
|
327 |
selected_size_filter = None
|
@@ -330,17 +368,14 @@ with gr.Blocks() as demo:
|
|
330 |
selected_tag_filter = tag_filter
|
331 |
elif filter_choice == "Pipeline Filter":
|
332 |
selected_pipeline_filter = pipeline_filter
|
333 |
-
|
334 |
if size_filter != "None":
|
335 |
selected_size_filter = size_filter
|
336 |
-
|
337 |
-
# Process skip organizations list
|
338 |
skip_orgs = []
|
339 |
if skip_orgs_text and skip_orgs_text.strip():
|
340 |
skip_orgs = [org.strip() for org in skip_orgs_text.split(',') if org.strip()]
|
341 |
-
|
342 |
-
|
343 |
-
# Process data for treemap
|
344 |
treemap_data = make_treemap_data(
|
345 |
df=data_df,
|
346 |
count_by=count_by,
|
@@ -350,64 +385,77 @@ with gr.Blocks() as demo:
|
|
350 |
size_filter=selected_size_filter,
|
351 |
skip_orgs=skip_orgs
|
352 |
)
|
353 |
-
|
354 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
355 |
fig = create_treemap(
|
356 |
treemap_data=treemap_data,
|
357 |
count_by=count_by,
|
358 |
-
title=
|
359 |
)
|
360 |
-
|
361 |
-
# Generate statistics
|
362 |
if treemap_data.empty:
|
363 |
stats_md = "No data matches the selected filters."
|
364 |
else:
|
365 |
total_models = len(treemap_data)
|
366 |
total_value = treemap_data[count_by].sum()
|
|
|
|
|
367 |
top_5_orgs = treemap_data.groupby("organization")[count_by].sum().sort_values(ascending=False).head(5)
|
368 |
|
369 |
-
#
|
|
|
|
|
|
|
370 |
stats_md = f"""
|
371 |
## Statistics
|
372 |
- **Total models shown**: {total_models:,}
|
373 |
- **Total {count_by}**: {int(total_value):,}
|
|
|
374 |
## Top Organizations by {count_by.capitalize()}
|
|
|
375 |
| Organization | {count_by.capitalize()} | % of Total |
|
376 |
-
|
|
|
377 |
|
378 |
-
# Add
|
379 |
for org, value in top_5_orgs.items():
|
380 |
percentage = (value / total_value) * 100
|
381 |
-
stats_md += f"
|
382 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
383 |
# Add note about skipped organizations if any
|
384 |
if skip_orgs:
|
385 |
-
stats_md += f"\n
|
386 |
|
387 |
return fig, stats_md
|
388 |
|
389 |
-
|
390 |
-
if filter_choice == "Tag Filter":
|
391 |
-
return gr.update(visible=True), gr.update(visible=False)
|
392 |
-
elif filter_choice == "Pipeline Filter":
|
393 |
-
return gr.update(visible=False), gr.update(visible=True)
|
394 |
-
else: # "None"
|
395 |
-
return gr.update(visible=False), gr.update(visible=False)
|
396 |
-
|
397 |
-
filter_choice_radio.change(
|
398 |
-
fn=update_filter_visibility,
|
399 |
-
inputs=[filter_choice_radio],
|
400 |
-
outputs=[tag_filter_dropdown, pipeline_filter_dropdown]
|
401 |
-
)
|
402 |
-
|
403 |
-
# Load data once at startup
|
404 |
demo.load(
|
405 |
fn=load_models_csv,
|
406 |
inputs=[],
|
407 |
-
outputs=[models_data]
|
408 |
)
|
409 |
|
410 |
-
# Button click event to generate plot
|
411 |
generate_plot_button.click(
|
412 |
fn=generate_plot_on_click,
|
413 |
inputs=[
|
@@ -424,5 +472,6 @@ with gr.Blocks() as demo:
|
|
424 |
)
|
425 |
|
426 |
|
|
|
427 |
if __name__ == "__main__":
|
428 |
demo.launch()
|
|
|
239 |
|
240 |
# Create Gradio interface
|
241 |
with gr.Blocks() as demo:
|
242 |
+
models_data = gr.State()
|
243 |
+
loading_complete = gr.State(False) # Flag to indicate data load completion
|
244 |
+
|
245 |
with gr.Row():
|
246 |
gr.Markdown("""
|
247 |
# HuggingFace Models TreeMap Visualization
|
248 |
+
|
249 |
This app shows how different organizations contribute to the HuggingFace ecosystem with their models.
|
250 |
Use the filters to explore models by different metrics, tags, pipelines, and model sizes.
|
251 |
+
|
252 |
+
The treemap visualizes models grouped by organization, with the size of each box representing the selected metric.
|
253 |
""")
|
254 |
+
|
255 |
with gr.Row():
|
256 |
with gr.Column(scale=1):
|
257 |
count_by_dropdown = gr.Dropdown(
|
|
|
264 |
value="downloads",
|
265 |
info="Select the metric to determine box sizes"
|
266 |
)
|
267 |
+
|
268 |
filter_choice_radio = gr.Radio(
|
269 |
label="Filter Type",
|
270 |
choices=["None", "Tag Filter", "Pipeline Filter"],
|
271 |
value="None",
|
272 |
info="Choose how to filter the models"
|
273 |
)
|
274 |
+
|
275 |
tag_filter_dropdown = gr.Dropdown(
|
276 |
label="Select Tag",
|
277 |
choices=list(TAG_FILTER_FUNCS.keys()),
|
|
|
279 |
visible=False,
|
280 |
info="Filter models by domain/category"
|
281 |
)
|
282 |
+
|
283 |
pipeline_filter_dropdown = gr.Dropdown(
|
284 |
label="Select Pipeline Tag",
|
285 |
choices=PIPELINE_TAGS,
|
|
|
287 |
visible=False,
|
288 |
info="Filter models by specific pipeline"
|
289 |
)
|
290 |
+
|
291 |
size_filter_dropdown = gr.Dropdown(
|
292 |
label="Model Size Filter",
|
293 |
choices=["None"] + list(MODEL_SIZE_RANGES.keys()),
|
|
|
303 |
step=5,
|
304 |
info="Number of top organizations to include"
|
305 |
)
|
306 |
+
|
307 |
skip_orgs_textbox = gr.Textbox(
|
308 |
label="Organizations to Skip (comma-separated)",
|
309 |
placeholder="e.g., OpenAI, Google",
|
310 |
+
value="TheBloke, MaziyarPanahi, unsloth, modularai, Gensyn, bartowski"
|
|
|
311 |
)
|
312 |
|
313 |
+
generate_plot_button = gr.Button("Generate Plot", variant="primary", interactive=False)
|
314 |
|
315 |
with gr.Column(scale=3):
|
316 |
plot_output = gr.Plot()
|
317 |
stats_output = gr.Markdown("*Generate a plot to see statistics*")
|
318 |
|
319 |
+
# Updated load function returning both the data and loading flag
|
320 |
+
def load_models_csv():
|
321 |
+
df = pd.read_csv('models.csv')
|
322 |
+
|
323 |
+
def process_tags(tags_str):
|
324 |
+
if pd.isna(tags_str):
|
325 |
+
return []
|
326 |
+
tags_str = tags_str.strip("[]").replace("'", "")
|
327 |
+
tags = [tag.strip() for tag in tags_str.split() if tag.strip()]
|
328 |
+
return tags
|
329 |
+
|
330 |
+
df['tags'] = df['tags'].apply(process_tags)
|
331 |
+
return df, True
|
332 |
+
|
333 |
+
# Button enablement after data load
|
334 |
+
def enable_plot_button(loaded):
|
335 |
+
return gr.update(interactive=loaded)
|
336 |
+
|
337 |
+
loading_complete.change(
|
338 |
+
fn=enable_plot_button,
|
339 |
+
inputs=[loading_complete],
|
340 |
+
outputs=[generate_plot_button]
|
341 |
+
)
|
342 |
+
|
343 |
+
# Show/hide tag/pipeline dropdown
|
344 |
+
def update_filter_visibility(filter_choice):
|
345 |
+
if filter_choice == "Tag Filter":
|
346 |
+
return gr.update(visible=True), gr.update(visible=False)
|
347 |
+
elif filter_choice == "Pipeline Filter":
|
348 |
+
return gr.update(visible=False), gr.update(visible=True)
|
349 |
+
else:
|
350 |
+
return gr.update(visible=False), gr.update(visible=False)
|
351 |
+
|
352 |
+
filter_choice_radio.change(
|
353 |
+
fn=update_filter_visibility,
|
354 |
+
inputs=[filter_choice_radio],
|
355 |
+
outputs=[tag_filter_dropdown, pipeline_filter_dropdown]
|
356 |
+
)
|
357 |
+
|
358 |
+
# Main generate function
|
359 |
def generate_plot_on_click(count_by, filter_choice, tag_filter, pipeline_filter, size_filter, top_k, skip_orgs_text, data_df):
|
360 |
+
if data_df is None or not isinstance(data_df, pd.DataFrame) or data_df.empty:
|
361 |
+
return None, "Error: Data is still loading. Please wait a moment and try again."
|
362 |
+
|
|
|
|
|
363 |
selected_tag_filter = None
|
364 |
selected_pipeline_filter = None
|
365 |
selected_size_filter = None
|
|
|
368 |
selected_tag_filter = tag_filter
|
369 |
elif filter_choice == "Pipeline Filter":
|
370 |
selected_pipeline_filter = pipeline_filter
|
371 |
+
|
372 |
if size_filter != "None":
|
373 |
selected_size_filter = size_filter
|
374 |
+
|
|
|
375 |
skip_orgs = []
|
376 |
if skip_orgs_text and skip_orgs_text.strip():
|
377 |
skip_orgs = [org.strip() for org in skip_orgs_text.split(',') if org.strip()]
|
378 |
+
|
|
|
|
|
379 |
treemap_data = make_treemap_data(
|
380 |
df=data_df,
|
381 |
count_by=count_by,
|
|
|
385 |
size_filter=selected_size_filter,
|
386 |
skip_orgs=skip_orgs
|
387 |
)
|
388 |
+
|
389 |
+
title_labels = {
|
390 |
+
"downloads": "Downloads (last 30 days)",
|
391 |
+
"downloadsAllTime": "Downloads (All Time)",
|
392 |
+
"likes": "Likes"
|
393 |
+
}
|
394 |
+
title_text = f"HuggingFace Models - {title_labels.get(count_by, count_by)} by Organization"
|
395 |
+
|
396 |
fig = create_treemap(
|
397 |
treemap_data=treemap_data,
|
398 |
count_by=count_by,
|
399 |
+
title=title_text
|
400 |
)
|
401 |
+
|
|
|
402 |
if treemap_data.empty:
|
403 |
stats_md = "No data matches the selected filters."
|
404 |
else:
|
405 |
total_models = len(treemap_data)
|
406 |
total_value = treemap_data[count_by].sum()
|
407 |
+
|
408 |
+
# Get top 5 organizations
|
409 |
top_5_orgs = treemap_data.groupby("organization")[count_by].sum().sort_values(ascending=False).head(5)
|
410 |
|
411 |
+
# Get top 5 individual models
|
412 |
+
top_5_models = treemap_data[["id", count_by]].sort_values(by=count_by, ascending=False).head(5)
|
413 |
+
|
414 |
+
# Create statistics section
|
415 |
stats_md = f"""
|
416 |
## Statistics
|
417 |
- **Total models shown**: {total_models:,}
|
418 |
- **Total {count_by}**: {int(total_value):,}
|
419 |
+
|
420 |
## Top Organizations by {count_by.capitalize()}
|
421 |
+
|
422 |
| Organization | {count_by.capitalize()} | % of Total |
|
423 |
+
|--------------|-------------:|----------:|
|
424 |
+
"""
|
425 |
|
426 |
+
# Add top organizations to the table
|
427 |
for org, value in top_5_orgs.items():
|
428 |
percentage = (value / total_value) * 100
|
429 |
+
stats_md += f"| {org} | {int(value):,} | {percentage:.2f}% |\n"
|
430 |
|
431 |
+
# Add the top models table
|
432 |
+
stats_md += f"""
|
433 |
+
## Top Models by {count_by.capitalize()}
|
434 |
+
|
435 |
+
| Model | {count_by.capitalize()} | % of Total |
|
436 |
+
|-------|-------------:|----------:|
|
437 |
+
"""
|
438 |
+
|
439 |
+
# Add top models to the table
|
440 |
+
for _, row in top_5_models.iterrows():
|
441 |
+
model_id = row["id"]
|
442 |
+
value = row[count_by]
|
443 |
+
percentage = (value / total_value) * 100
|
444 |
+
stats_md += f"| {model_id} | {int(value):,} | {percentage:.2f}% |\n"
|
445 |
+
|
446 |
# Add note about skipped organizations if any
|
447 |
if skip_orgs:
|
448 |
+
stats_md += f"\n*Note: {len(skip_orgs)} organization(s) excluded: {', '.join(skip_orgs)}*"
|
449 |
|
450 |
return fig, stats_md
|
451 |
|
452 |
+
# Load data at startup
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
453 |
demo.load(
|
454 |
fn=load_models_csv,
|
455 |
inputs=[],
|
456 |
+
outputs=[models_data, loading_complete]
|
457 |
)
|
458 |
|
|
|
459 |
generate_plot_button.click(
|
460 |
fn=generate_plot_on_click,
|
461 |
inputs=[
|
|
|
472 |
)
|
473 |
|
474 |
|
475 |
+
|
476 |
if __name__ == "__main__":
|
477 |
demo.launch()
|