evijit HF Staff committed on
Commit
57e108c
·
verified ·
1 Parent(s): f973e99

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +101 -52
app.py CHANGED
@@ -239,18 +239,19 @@ def load_models_csv():
239
 
240
  # Create Gradio interface
241
  with gr.Blocks() as demo:
242
- models_data = gr.State() # To store loaded data
243
-
 
244
  with gr.Row():
245
  gr.Markdown("""
246
  # HuggingFace Models TreeMap Visualization
247
-
248
  This app shows how different organizations contribute to the HuggingFace ecosystem with their models.
249
  Use the filters to explore models by different metrics, tags, pipelines, and model sizes.
250
-
251
- The treemap visualizes models grouped by organization, with the size of each box representing the selected metric (downloads or likes).
252
  """)
253
-
254
  with gr.Row():
255
  with gr.Column(scale=1):
256
  count_by_dropdown = gr.Dropdown(
@@ -263,14 +264,14 @@ with gr.Blocks() as demo:
263
  value="downloads",
264
  info="Select the metric to determine box sizes"
265
  )
266
-
267
  filter_choice_radio = gr.Radio(
268
  label="Filter Type",
269
  choices=["None", "Tag Filter", "Pipeline Filter"],
270
  value="None",
271
  info="Choose how to filter the models"
272
  )
273
-
274
  tag_filter_dropdown = gr.Dropdown(
275
  label="Select Tag",
276
  choices=list(TAG_FILTER_FUNCS.keys()),
@@ -278,7 +279,7 @@ with gr.Blocks() as demo:
278
  visible=False,
279
  info="Filter models by domain/category"
280
  )
281
-
282
  pipeline_filter_dropdown = gr.Dropdown(
283
  label="Select Pipeline Tag",
284
  choices=PIPELINE_TAGS,
@@ -286,7 +287,7 @@ with gr.Blocks() as demo:
286
  visible=False,
287
  info="Filter models by specific pipeline"
288
  )
289
-
290
  size_filter_dropdown = gr.Dropdown(
291
  label="Model Size Filter",
292
  choices=["None"] + list(MODEL_SIZE_RANGES.keys()),
@@ -302,26 +303,63 @@ with gr.Blocks() as demo:
302
  step=5,
303
  info="Number of top organizations to include"
304
  )
305
-
306
  skip_orgs_textbox = gr.Textbox(
307
  label="Organizations to Skip (comma-separated)",
308
  placeholder="e.g., OpenAI, Google",
309
- value="TheBloke, MaziyarPanahi, unsloth, modularai, Gensyn, bartowski",
310
- info="Enter names of organizations to exclude from the visualization"
311
  )
312
 
313
- generate_plot_button = gr.Button("Generate Plot", variant="primary")
314
 
315
  with gr.Column(scale=3):
316
  plot_output = gr.Plot()
317
  stats_output = gr.Markdown("*Generate a plot to see statistics*")
318
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
319
  def generate_plot_on_click(count_by, filter_choice, tag_filter, pipeline_filter, size_filter, top_k, skip_orgs_text, data_df):
320
- print(f"Generating plot with: Metric={count_by}, Filter={filter_choice}, Tag={tag_filter}, Pipeline={pipeline_filter}, Size={size_filter}, Top K={top_k}")
321
-
322
- if data_df is None or len(data_df) == 0:
323
- return None, "Error: No data available. Please try again."
324
-
325
  selected_tag_filter = None
326
  selected_pipeline_filter = None
327
  selected_size_filter = None
@@ -330,17 +368,14 @@ with gr.Blocks() as demo:
330
  selected_tag_filter = tag_filter
331
  elif filter_choice == "Pipeline Filter":
332
  selected_pipeline_filter = pipeline_filter
333
-
334
  if size_filter != "None":
335
  selected_size_filter = size_filter
336
-
337
- # Process skip organizations list
338
  skip_orgs = []
339
  if skip_orgs_text and skip_orgs_text.strip():
340
  skip_orgs = [org.strip() for org in skip_orgs_text.split(',') if org.strip()]
341
- print(f"Skipping organizations: {skip_orgs}")
342
-
343
- # Process data for treemap
344
  treemap_data = make_treemap_data(
345
  df=data_df,
346
  count_by=count_by,
@@ -350,64 +385,77 @@ with gr.Blocks() as demo:
350
  size_filter=selected_size_filter,
351
  skip_orgs=skip_orgs
352
  )
353
-
354
- # Create plot
 
 
 
 
 
 
355
  fig = create_treemap(
356
  treemap_data=treemap_data,
357
  count_by=count_by,
358
- title=f"HuggingFace Models - {count_by.replace('AllTime', ' (All Time)').capitalize()} by Organization"
359
  )
360
-
361
- # Generate statistics
362
  if treemap_data.empty:
363
  stats_md = "No data matches the selected filters."
364
  else:
365
  total_models = len(treemap_data)
366
  total_value = treemap_data[count_by].sum()
 
 
367
  top_5_orgs = treemap_data.groupby("organization")[count_by].sum().sort_values(ascending=False).head(5)
368
 
369
- # Format the statistics using clean markdown
 
 
 
370
  stats_md = f"""
371
  ## Statistics
372
  - **Total models shown**: {total_models:,}
373
  - **Total {count_by}**: {int(total_value):,}
 
374
  ## Top Organizations by {count_by.capitalize()}
 
375
  | Organization | {count_by.capitalize()} | % of Total |
376
- |--------------|--------:|--------:|"""
 
377
 
378
- # Add each organization as a row in the table
379
  for org, value in top_5_orgs.items():
380
  percentage = (value / total_value) * 100
381
- stats_md += f"\n| {org} | {int(value):,} | {percentage:.2f}% |"
382
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
383
  # Add note about skipped organizations if any
384
  if skip_orgs:
385
- stats_md += f"\n\n*Note: {len(skip_orgs)} organization(s) excluded: {', '.join(skip_orgs)}*"
386
 
387
  return fig, stats_md
388
 
389
- def update_filter_visibility(filter_choice):
390
- if filter_choice == "Tag Filter":
391
- return gr.update(visible=True), gr.update(visible=False)
392
- elif filter_choice == "Pipeline Filter":
393
- return gr.update(visible=False), gr.update(visible=True)
394
- else: # "None"
395
- return gr.update(visible=False), gr.update(visible=False)
396
-
397
- filter_choice_radio.change(
398
- fn=update_filter_visibility,
399
- inputs=[filter_choice_radio],
400
- outputs=[tag_filter_dropdown, pipeline_filter_dropdown]
401
- )
402
-
403
- # Load data once at startup
404
  demo.load(
405
  fn=load_models_csv,
406
  inputs=[],
407
- outputs=[models_data]
408
  )
409
 
410
- # Button click event to generate plot
411
  generate_plot_button.click(
412
  fn=generate_plot_on_click,
413
  inputs=[
@@ -424,5 +472,6 @@ with gr.Blocks() as demo:
424
  )
425
 
426
 
 
427
  if __name__ == "__main__":
428
  demo.launch()
 
239
 
240
  # Create Gradio interface
241
  with gr.Blocks() as demo:
242
+ models_data = gr.State()
243
+ loading_complete = gr.State(False) # Flag to indicate data load completion
244
+
245
  with gr.Row():
246
  gr.Markdown("""
247
  # HuggingFace Models TreeMap Visualization
248
+
249
  This app shows how different organizations contribute to the HuggingFace ecosystem with their models.
250
  Use the filters to explore models by different metrics, tags, pipelines, and model sizes.
251
+
252
+ The treemap visualizes models grouped by organization, with the size of each box representing the selected metric.
253
  """)
254
+
255
  with gr.Row():
256
  with gr.Column(scale=1):
257
  count_by_dropdown = gr.Dropdown(
 
264
  value="downloads",
265
  info="Select the metric to determine box sizes"
266
  )
267
+
268
  filter_choice_radio = gr.Radio(
269
  label="Filter Type",
270
  choices=["None", "Tag Filter", "Pipeline Filter"],
271
  value="None",
272
  info="Choose how to filter the models"
273
  )
274
+
275
  tag_filter_dropdown = gr.Dropdown(
276
  label="Select Tag",
277
  choices=list(TAG_FILTER_FUNCS.keys()),
 
279
  visible=False,
280
  info="Filter models by domain/category"
281
  )
282
+
283
  pipeline_filter_dropdown = gr.Dropdown(
284
  label="Select Pipeline Tag",
285
  choices=PIPELINE_TAGS,
 
287
  visible=False,
288
  info="Filter models by specific pipeline"
289
  )
290
+
291
  size_filter_dropdown = gr.Dropdown(
292
  label="Model Size Filter",
293
  choices=["None"] + list(MODEL_SIZE_RANGES.keys()),
 
303
  step=5,
304
  info="Number of top organizations to include"
305
  )
306
+
307
  skip_orgs_textbox = gr.Textbox(
308
  label="Organizations to Skip (comma-separated)",
309
  placeholder="e.g., OpenAI, Google",
310
+ value="TheBloke, MaziyarPanahi, unsloth, modularai, Gensyn, bartowski"
 
311
  )
312
 
313
+ generate_plot_button = gr.Button("Generate Plot", variant="primary", interactive=False)
314
 
315
  with gr.Column(scale=3):
316
  plot_output = gr.Plot()
317
  stats_output = gr.Markdown("*Generate a plot to see statistics*")
318
 
319
+ # Updated load function returning both the data and loading flag
320
+ def load_models_csv():
321
+ df = pd.read_csv('models.csv')
322
+
323
+ def process_tags(tags_str):
324
+ if pd.isna(tags_str):
325
+ return []
326
+ tags_str = tags_str.strip("[]").replace("'", "")
327
+ tags = [tag.strip() for tag in tags_str.split() if tag.strip()]
328
+ return tags
329
+
330
+ df['tags'] = df['tags'].apply(process_tags)
331
+ return df, True
332
+
333
+ # Button enablement after data load
334
+ def enable_plot_button(loaded):
335
+ return gr.update(interactive=loaded)
336
+
337
+ loading_complete.change(
338
+ fn=enable_plot_button,
339
+ inputs=[loading_complete],
340
+ outputs=[generate_plot_button]
341
+ )
342
+
343
+ # Show/hide tag/pipeline dropdown
344
+ def update_filter_visibility(filter_choice):
345
+ if filter_choice == "Tag Filter":
346
+ return gr.update(visible=True), gr.update(visible=False)
347
+ elif filter_choice == "Pipeline Filter":
348
+ return gr.update(visible=False), gr.update(visible=True)
349
+ else:
350
+ return gr.update(visible=False), gr.update(visible=False)
351
+
352
+ filter_choice_radio.change(
353
+ fn=update_filter_visibility,
354
+ inputs=[filter_choice_radio],
355
+ outputs=[tag_filter_dropdown, pipeline_filter_dropdown]
356
+ )
357
+
358
+ # Main generate function
359
  def generate_plot_on_click(count_by, filter_choice, tag_filter, pipeline_filter, size_filter, top_k, skip_orgs_text, data_df):
360
+ if data_df is None or not isinstance(data_df, pd.DataFrame) or data_df.empty:
361
+ return None, "Error: Data is still loading. Please wait a moment and try again."
362
+
 
 
363
  selected_tag_filter = None
364
  selected_pipeline_filter = None
365
  selected_size_filter = None
 
368
  selected_tag_filter = tag_filter
369
  elif filter_choice == "Pipeline Filter":
370
  selected_pipeline_filter = pipeline_filter
371
+
372
  if size_filter != "None":
373
  selected_size_filter = size_filter
374
+
 
375
  skip_orgs = []
376
  if skip_orgs_text and skip_orgs_text.strip():
377
  skip_orgs = [org.strip() for org in skip_orgs_text.split(',') if org.strip()]
378
+
 
 
379
  treemap_data = make_treemap_data(
380
  df=data_df,
381
  count_by=count_by,
 
385
  size_filter=selected_size_filter,
386
  skip_orgs=skip_orgs
387
  )
388
+
389
+ title_labels = {
390
+ "downloads": "Downloads (last 30 days)",
391
+ "downloadsAllTime": "Downloads (All Time)",
392
+ "likes": "Likes"
393
+ }
394
+ title_text = f"HuggingFace Models - {title_labels.get(count_by, count_by)} by Organization"
395
+
396
  fig = create_treemap(
397
  treemap_data=treemap_data,
398
  count_by=count_by,
399
+ title=title_text
400
  )
401
+
 
402
  if treemap_data.empty:
403
  stats_md = "No data matches the selected filters."
404
  else:
405
  total_models = len(treemap_data)
406
  total_value = treemap_data[count_by].sum()
407
+
408
+ # Get top 5 organizations
409
  top_5_orgs = treemap_data.groupby("organization")[count_by].sum().sort_values(ascending=False).head(5)
410
 
411
+ # Get top 5 individual models
412
+ top_5_models = treemap_data[["id", count_by]].sort_values(by=count_by, ascending=False).head(5)
413
+
414
+ # Create statistics section
415
  stats_md = f"""
416
  ## Statistics
417
  - **Total models shown**: {total_models:,}
418
  - **Total {count_by}**: {int(total_value):,}
419
+
420
  ## Top Organizations by {count_by.capitalize()}
421
+
422
  | Organization | {count_by.capitalize()} | % of Total |
423
+ |--------------|-------------:|----------:|
424
+ """
425
 
426
+ # Add top organizations to the table
427
  for org, value in top_5_orgs.items():
428
  percentage = (value / total_value) * 100
429
+ stats_md += f"| {org} | {int(value):,} | {percentage:.2f}% |\n"
430
 
431
+ # Add the top models table
432
+ stats_md += f"""
433
+ ## Top Models by {count_by.capitalize()}
434
+
435
+ | Model | {count_by.capitalize()} | % of Total |
436
+ |-------|-------------:|----------:|
437
+ """
438
+
439
+ # Add top models to the table
440
+ for _, row in top_5_models.iterrows():
441
+ model_id = row["id"]
442
+ value = row[count_by]
443
+ percentage = (value / total_value) * 100
444
+ stats_md += f"| {model_id} | {int(value):,} | {percentage:.2f}% |\n"
445
+
446
  # Add note about skipped organizations if any
447
  if skip_orgs:
448
+ stats_md += f"\n*Note: {len(skip_orgs)} organization(s) excluded: {', '.join(skip_orgs)}*"
449
 
450
  return fig, stats_md
451
 
452
+ # Load data at startup
 
 
 
 
 
 
 
 
 
 
 
 
 
 
453
  demo.load(
454
  fn=load_models_csv,
455
  inputs=[],
456
+ outputs=[models_data, loading_complete]
457
  )
458
 
 
459
  generate_plot_button.click(
460
  fn=generate_plot_on_click,
461
  inputs=[
 
472
  )
473
 
474
 
475
+
476
  if __name__ == "__main__":
477
  demo.launch()