evijit HF Staff commited on
Commit
caa5704
·
verified ·
1 Parent(s): 4d0a8a3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +222 -63
app.py CHANGED
@@ -2,7 +2,13 @@ import json
2
  import gradio as gr
3
  import pandas as pd
4
  import plotly.express as px
 
 
 
 
 
5
 
 
6
  PIPELINE_TAGS = [
7
  'text-generation',
8
  'text-to-image',
@@ -44,6 +50,16 @@ PIPELINE_TAGS = [
44
  'table-question-answering',
45
  ]
46
 
 
 
 
 
 
 
 
 
 
 
47
  def is_audio_speech(repo_dct):
48
  res = (repo_dct.get("pipeline_tag", None) and "audio" in repo_dct.get("pipeline_tag", "").lower()) or \
49
  (repo_dct.get("pipeline_tag", None) and "speech" in repo_dct.get("pipeline_tag", "").lower()) or \
@@ -84,6 +100,21 @@ def is_text(repo_dct):
84
  res = (repo_dct.get("tags", None) and any("text" in tag.lower() for tag in repo_dct.get("tags", [])))
85
  return res
86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  TAG_FILTER_FUNCS = {
88
  "Audio & Speech": is_audio_speech,
89
  "Time series": is_timeseries,
@@ -96,79 +127,211 @@ TAG_FILTER_FUNCS = {
96
  "Sciences": is_science,
97
  }
98
 
99
- def make_org_stats(repo_type, count_by, org_stats, top_k=20, filter_func=None):
100
- assert count_by in ["likes", "downloads", "downloads_all"]
101
- assert repo_type in ["all", "datasets", "models"]
102
- repos = ["datasets", "models"] if repo_type == "all" else [repo_type]
103
- if filter_func is None:
104
- filter_func = lambda x: True
 
 
 
 
105
  sorted_stats = sorted(
106
  [(
107
- author,
108
- sum(dct[count_by] for dct in author_dct[repo] if filter_func(dct))
109
- ) for repo in repos for author, author_dct in org_stats.items()],
110
- key=lambda x:x[1],
111
  reverse=True,
112
  )
 
 
113
  res = sorted_stats[:top_k] + [("Others...", sum(st for auth, st in sorted_stats[top_k:]))]
114
  total_st = sum(st for o, st in res)
 
 
115
  res_plot_df = []
116
  for org, st in res:
117
  if org == "Others...":
118
- res_plot_df += [("Others...", "other", st * 100 / total_st)]
119
  else:
120
- for repo in repos:
121
- for dct in org_stats[org][repo]:
122
- if filter_func(dct):
123
- res_plot_df += [(org, dct["id"], dct[count_by] * 100 / total_st)]
124
- return ([(o, 100 * st / total_st) for o, st in res if st > 0], res_plot_df)
125
 
126
- def make_figure(count_by, repo_type, org_stats, tag_filter=None, pipeline_filter=None):
127
- assert count_by in ["downloads", "likes", "downloads_all"]
128
- assert repo_type in ["all", "models", "datasets"]
129
- assert tag_filter is None or pipeline_filter is None
130
  filter_func = None
131
  if tag_filter:
132
  filter_func = TAG_FILTER_FUNCS[tag_filter]
133
- if pipeline_filter:
134
  filter_func = lambda dct: dct.get("pipeline_tag", None) and dct.get("pipeline_tag", "") == pipeline_filter
135
- _, res_plot_df = make_org_stats(repo_type, count_by, org_stats, top_k=25, filter_func=filter_func)
 
 
 
 
 
 
136
  df = pd.DataFrame(
137
  dict(
138
  organizations=[o for o, _, _ in res_plot_df],
139
- repo=[r for _, r, _ in res_plot_df],
140
  stats=[s for _, _, s in res_plot_df],
141
  )
142
  )
143
- df[repo_type] = repo_type # in order to have a single root node
144
- fig = px.treemap(df, path=[repo_type, 'organizations', 'repo'], values='stats')
 
 
 
 
 
145
  fig.update_layout(
146
- treemapcolorway = ["pink" for _ in range(len(res_plot_df))],
147
- margin = dict(t=50, l=25, r=25, b=25)
148
  )
 
149
  return fig
150
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
  with gr.Blocks() as demo:
153
- org_stats_data = gr.State(value=None) # To store loaded data
154
 
155
  with gr.Row():
156
  gr.Markdown("""
157
- ## Hugging Face Organization Stats
158
 
159
- This app shows how different organizations are contributing to different aspects of the open AI ecosystem.
160
- Use the dropdowns on the left to select repository types, metrics, and optionally tags representing topics or modalities of interest.
161
  """)
 
162
  with gr.Row():
163
  with gr.Column(scale=1):
164
- repo_type_dropdown = gr.Dropdown(
165
- label="Repository Type",
166
- choices=["all", "models", "datasets"],
167
- value="all"
168
- )
169
  count_by_dropdown = gr.Dropdown(
170
  label="Metric",
171
- choices=["downloads", "likes", "downloads_all"],
172
  value="downloads"
173
  )
174
 
@@ -184,47 +347,50 @@ with gr.Blocks() as demo:
184
  value=None,
185
  visible=False
186
  )
 
187
  pipeline_filter_dropdown = gr.Dropdown(
188
  label="Select Pipeline Tag",
189
  choices=PIPELINE_TAGS,
190
  value=None,
191
  visible=False
192
  )
 
 
 
 
 
 
193
 
194
  generate_plot_button = gr.Button("Generate Plot")
195
 
196
  with gr.Column(scale=3):
197
  plot_output = gr.Plot()
198
 
199
- def generate_plot_on_click(repo_type, count_by, filter_choice, tag_filter, pipeline_filter, data):
200
- # Print the current state of the input variables
201
- print(f"Generating plot with the following inputs:")
202
- print(f" Repository Type: {repo_type}")
203
- print(f" Metric (Count By): {count_by}")
204
- print(f" Filter Choice: {filter_choice}")
205
- if filter_choice == "Tag Filter":
206
- print(f" Tag Filter: {tag_filter}")
207
- elif filter_choice == "Pipeline Filter":
208
- print(f" Pipeline Filter: {pipeline_filter}")
209
-
210
  if data is None:
211
  print("Error: Data not loaded yet.")
212
  return None
213
 
214
  selected_tag_filter = None
215
  selected_pipeline_filter = None
 
216
 
217
  if filter_choice == "Tag Filter":
218
  selected_tag_filter = tag_filter
219
  elif filter_choice == "Pipeline Filter":
220
  selected_pipeline_filter = pipeline_filter
 
 
 
221
 
222
  fig = make_figure(
223
  count_by=count_by,
224
- repo_type=repo_type,
225
  org_stats=data,
226
  tag_filter=selected_tag_filter,
227
- pipeline_filter=selected_pipeline_filter
 
228
  )
229
  return fig
230
 
@@ -233,7 +399,7 @@ with gr.Blocks() as demo:
233
  return gr.update(visible=True), gr.update(visible=False)
234
  elif filter_choice == "Pipeline Filter":
235
  return gr.update(visible=False), gr.update(visible=True)
236
- else: # "None"
237
  return gr.update(visible=False), gr.update(visible=False)
238
 
239
  filter_choice_radio.change(
@@ -243,33 +409,26 @@ with gr.Blocks() as demo:
243
  )
244
 
245
  # Load data once at startup
246
- def load_org_data():
247
- print("Loading organization statistics data...")
248
- loaded_org_stats = json.load(open("org_to_artifacts_2l_stats.json"))
249
- print("Data loaded successfully.")
250
- return loaded_org_stats
251
-
252
  demo.load(
253
- fn=load_org_data,
254
- inputs=[], # No inputs needed to just load data
255
- outputs=[org_stats_data] # Only output to the state
256
  )
257
 
258
  # Button click event to generate plot
259
  generate_plot_button.click(
260
  fn=generate_plot_on_click,
261
  inputs=[
262
- repo_type_dropdown,
263
  count_by_dropdown,
264
  filter_choice_radio,
265
  tag_filter_dropdown,
266
  pipeline_filter_dropdown,
267
- org_stats_data
 
268
  ],
269
  outputs=[plot_output]
270
  )
271
 
272
 
273
  if __name__ == "__main__":
274
- # org_stats = json.load(open("org_to_artifacts_2l_stats.json")) # Data loading handled by demo.load
275
  demo.launch()
 
2
  import gradio as gr
3
  import pandas as pd
4
  import plotly.express as px
5
+ import pyarrow.parquet as pq
6
+ import os
7
+ import requests
8
+ from io import BytesIO
9
+ import math
10
 
11
+ # Define pipeline tags (keeping the same ones from the provided code)
12
  PIPELINE_TAGS = [
13
  'text-generation',
14
  'text-to-image',
 
50
  'table-question-answering',
51
  ]
52
 
53
+ # Model size categories in GB
54
+ MODEL_SIZE_RANGES = {
55
+ "Small (<1GB)": (0, 1),
56
+ "Medium (1-5GB)": (1, 5),
57
+ "Large (5-20GB)": (5, 20),
58
+ "X-Large (20-50GB)": (20, 50),
59
+ "XX-Large (>50GB)": (50, float('inf'))
60
+ }
61
+
62
+ # Filter functions for tags - keeping the same from provided code
63
  def is_audio_speech(repo_dct):
64
  res = (repo_dct.get("pipeline_tag", None) and "audio" in repo_dct.get("pipeline_tag", "").lower()) or \
65
  (repo_dct.get("pipeline_tag", None) and "speech" in repo_dct.get("pipeline_tag", "").lower()) or \
 
100
  res = (repo_dct.get("tags", None) and any("text" in tag.lower() for tag in repo_dct.get("tags", [])))
101
  return res
102
 
103
+ # Add model size filter function
104
+ def is_in_size_range(repo_dct, size_range):
105
+ if size_range is None:
106
+ return True
107
+
108
+ min_size, max_size = MODEL_SIZE_RANGES[size_range]
109
+
110
+ # Get model size in GB from safetensors total (if available)
111
+ if repo_dct.get("safetensors") and repo_dct["safetensors"].get("total"):
112
+ # Convert bytes to GB
113
+ size_gb = repo_dct["safetensors"]["total"] / (1024 * 1024 * 1024)
114
+ return min_size <= size_gb < max_size
115
+
116
+ return False
117
+
118
  TAG_FILTER_FUNCS = {
119
  "Audio & Speech": is_audio_speech,
120
  "Time series": is_timeseries,
 
127
  "Sciences": is_science,
128
  }
129
 
130
+ def make_org_stats(count_by, org_stats, top_k=20, filter_func=None, size_range=None):
131
+ assert count_by in ["likes", "downloads"]
132
+
133
+ # Apply both filter_func and size_range if provided
134
+ def combined_filter(dct):
135
+ passes_tag_filter = filter_func(dct) if filter_func else True
136
+ passes_size_filter = is_in_size_range(dct, size_range) if size_range else True
137
+ return passes_tag_filter and passes_size_filter
138
+
139
+ # Sort organizations by total count
140
  sorted_stats = sorted(
141
  [(
142
+ org_id,
143
+ sum(model[count_by] for model in models if combined_filter(model))
144
+ ) for org_id, models in org_stats.items()],
145
+ key=lambda x: x[1],
146
  reverse=True,
147
  )
148
+
149
+ # Top organizations + Others category
150
  res = sorted_stats[:top_k] + [("Others...", sum(st for auth, st in sorted_stats[top_k:]))]
151
  total_st = sum(st for o, st in res)
152
+
153
+ # Prepare data for treemap
154
  res_plot_df = []
155
  for org, st in res:
156
  if org == "Others...":
157
+ res_plot_df += [("Others...", "other", st * 100 / total_st if total_st > 0 else 0)]
158
  else:
159
+ for model in org_stats[org]:
160
+ if combined_filter(model):
161
+ res_plot_df += [(org, model["id"], model[count_by] * 100 / total_st if total_st > 0 else 0)]
162
+
163
+ return ([(o, 100 * st / total_st if total_st > 0 else 0) for o, st in res if st > 0], res_plot_df)
164
 
165
+ def make_figure(count_by, org_stats, tag_filter=None, pipeline_filter=None, size_range=None):
166
+ assert count_by in ["downloads", "likes"]
167
+
168
+ # Determine which filter function to use
169
  filter_func = None
170
  if tag_filter:
171
  filter_func = TAG_FILTER_FUNCS[tag_filter]
172
+ elif pipeline_filter:
173
  filter_func = lambda dct: dct.get("pipeline_tag", None) and dct.get("pipeline_tag", "") == pipeline_filter
174
+ else:
175
+ filter_func = lambda dct: True
176
+
177
+ # Generate stats with filters
178
+ _, res_plot_df = make_org_stats(count_by, org_stats, top_k=25, filter_func=filter_func, size_range=size_range)
179
+
180
+ # Create DataFrame for Plotly
181
  df = pd.DataFrame(
182
  dict(
183
  organizations=[o for o, _, _ in res_plot_df],
184
+ model=[r for _, r, _ in res_plot_df],
185
  stats=[s for _, _, s in res_plot_df],
186
  )
187
  )
188
+
189
+ df["models"] = "models" # Root node
190
+
191
+ # Create treemap
192
+ fig = px.treemap(df, path=["models", 'organizations', 'model'], values='stats',
193
+ title=f"HuggingFace Models - {count_by.capitalize()} by Organization")
194
+
195
  fig.update_layout(
196
+ margin=dict(t=50, l=25, r=25, b=25)
 
197
  )
198
+
199
  return fig
200
 
201
+ def download_and_process_models():
202
+ """Download and process the models data from HuggingFace dataset"""
203
+ try:
204
+ # Create a cache directory
205
+ if not os.path.exists('data'):
206
+ os.makedirs('data')
207
+
208
+ # Check if we have cached data
209
+ if os.path.exists('data/processed_models.json'):
210
+ print("Loading from cache...")
211
+ with open('data/processed_models.json', 'r') as f:
212
+ return json.load(f)
213
+
214
+ # URL to the models.parquet file
215
+ url = "https://huggingface.co/datasets/cfahlgren1/hub-stats/resolve/main/models.parquet"
216
+
217
+ print(f"Downloading models data from {url}...")
218
+ response = requests.get(url)
219
+ if response.status_code != 200:
220
+ raise Exception(f"Failed to download data: HTTP {response.status_code}")
221
+
222
+ # Read the parquet file
223
+ table = pq.read_table(BytesIO(response.content))
224
+ df = table.to_pandas()
225
+
226
+ print(f"Downloaded {len(df)} models")
227
+
228
+ # Process the dataframe into the organization structure we need
229
+ org_stats = {}
230
+
231
+ for _, row in df.iterrows():
232
+ model_id = row['id']
233
+
234
+ # Extract the organization part of the model ID
235
+ if '/' in model_id:
236
+ org_id = model_id.split('/')[0]
237
+ else:
238
+ org_id = "unaffiliated"
239
+
240
+ # Create model entry with needed fields
241
+ model_entry = {
242
+ "id": model_id,
243
+ "downloads": row.get('downloads', 0),
244
+ "likes": row.get('likes', 0),
245
+ "pipeline_tag": row.get('pipeline_tag'),
246
+ "tags": row.get('tags', []),
247
+ }
248
+
249
+ # Add safetensors information if available
250
+ if 'safetensors' in row and row['safetensors']:
251
+ if isinstance(row['safetensors'], dict) and 'total' in row['safetensors']:
252
+ model_entry["safetensors"] = {"total": row['safetensors']['total']}
253
+ elif isinstance(row['safetensors'], str):
254
+ # Try to parse JSON string
255
+ try:
256
+ safetensors = json.loads(row['safetensors'])
257
+ if isinstance(safetensors, dict) and 'total' in safetensors:
258
+ model_entry["safetensors"] = {"total": safetensors['total']}
259
+ except:
260
+ pass
261
+
262
+ # Add to organization stats
263
+ if org_id not in org_stats:
264
+ org_stats[org_id] = []
265
+
266
+ org_stats[org_id].append(model_entry)
267
+
268
+ # Cache the processed data
269
+ with open('data/processed_models.json', 'w') as f:
270
+ json.dump(org_stats, f)
271
+
272
+ return org_stats
273
+
274
+ except Exception as e:
275
+ print(f"Error downloading or processing data: {e}")
276
+ # Return sample data for testing if real data unavailable
277
+ return create_sample_data()
278
 
279
+ def create_sample_data():
280
+ """Create sample data for testing when real data is unavailable"""
281
+ print("Creating sample data for testing...")
282
+
283
+ sample_orgs = ['openai', 'meta', 'google', 'microsoft', 'anthropic', 'stability', 'huggingface']
284
+ org_stats = {}
285
+
286
+ for org in sample_orgs:
287
+ org_stats[org] = []
288
+ num_models = 5 # Each org has 5 sample models
289
+
290
+ for i in range(num_models):
291
+ model_id = f"{org}/model-{i+1}"
292
+
293
+ # Random pipeline tag
294
+ pipeline_idx = i % len(PIPELINE_TAGS)
295
+ pipeline_tag = PIPELINE_TAGS[pipeline_idx]
296
+
297
+ # Random tags
298
+ tags = [pipeline_tag, "sample-data"]
299
+
300
+ # Random downloads and likes
301
+ downloads = int(1000 * (10 ** (org_stats.keys().index(org) % 3))) # Different magnitudes
302
+ likes = int(downloads * 0.05) # 5% like rate
303
+
304
+ # Random model size in bytes (from 100MB to 100GB)
305
+ model_size = (10**8) * (10 ** (i % 3)) # Different magnitudes
306
+
307
+ org_stats[org].append({
308
+ "id": model_id,
309
+ "downloads": downloads,
310
+ "likes": likes,
311
+ "pipeline_tag": pipeline_tag,
312
+ "tags": tags,
313
+ "safetensors": {"total": model_size}
314
+ })
315
+
316
+ return org_stats
317
+
318
+ # Create Gradio interface
319
  with gr.Blocks() as demo:
320
+ models_data = gr.State(value=None) # To store loaded data
321
 
322
  with gr.Row():
323
  gr.Markdown("""
324
+ ## HuggingFace Models TreeMap
325
 
326
+ This app shows how different organizations contribute to the HuggingFace ecosystem with their models.
327
+ Use the filters to explore models by different metrics, tags, pipelines, and model sizes.
328
  """)
329
+
330
  with gr.Row():
331
  with gr.Column(scale=1):
 
 
 
 
 
332
  count_by_dropdown = gr.Dropdown(
333
  label="Metric",
334
+ choices=["downloads", "likes"],
335
  value="downloads"
336
  )
337
 
 
347
  value=None,
348
  visible=False
349
  )
350
+
351
  pipeline_filter_dropdown = gr.Dropdown(
352
  label="Select Pipeline Tag",
353
  choices=PIPELINE_TAGS,
354
  value=None,
355
  visible=False
356
  )
357
+
358
+ size_filter_dropdown = gr.Dropdown(
359
+ label="Model Size Filter",
360
+ choices=["None"] + list(MODEL_SIZE_RANGES.keys()),
361
+ value="None"
362
+ )
363
 
364
  generate_plot_button = gr.Button("Generate Plot")
365
 
366
  with gr.Column(scale=3):
367
  plot_output = gr.Plot()
368
 
369
+ def generate_plot_on_click(count_by, filter_choice, tag_filter, pipeline_filter, size_filter, data):
370
+ print(f"Generating plot with: Metric={count_by}, Filter={filter_choice}, Tag={tag_filter}, Pipeline={pipeline_filter}, Size={size_filter}")
371
+
 
 
 
 
 
 
 
 
372
  if data is None:
373
  print("Error: Data not loaded yet.")
374
  return None
375
 
376
  selected_tag_filter = None
377
  selected_pipeline_filter = None
378
+ selected_size_filter = None
379
 
380
  if filter_choice == "Tag Filter":
381
  selected_tag_filter = tag_filter
382
  elif filter_choice == "Pipeline Filter":
383
  selected_pipeline_filter = pipeline_filter
384
+
385
+ if size_filter != "None":
386
+ selected_size_filter = size_filter
387
 
388
  fig = make_figure(
389
  count_by=count_by,
 
390
  org_stats=data,
391
  tag_filter=selected_tag_filter,
392
+ pipeline_filter=selected_pipeline_filter,
393
+ size_range=selected_size_filter
394
  )
395
  return fig
396
 
 
399
  return gr.update(visible=True), gr.update(visible=False)
400
  elif filter_choice == "Pipeline Filter":
401
  return gr.update(visible=False), gr.update(visible=True)
402
+ else: # "None"
403
  return gr.update(visible=False), gr.update(visible=False)
404
 
405
  filter_choice_radio.change(
 
409
  )
410
 
411
  # Load data once at startup
 
 
 
 
 
 
412
  demo.load(
413
+ fn=download_and_process_models,
414
+ inputs=[],
415
+ outputs=[models_data]
416
  )
417
 
418
  # Button click event to generate plot
419
  generate_plot_button.click(
420
  fn=generate_plot_on_click,
421
  inputs=[
 
422
  count_by_dropdown,
423
  filter_choice_radio,
424
  tag_filter_dropdown,
425
  pipeline_filter_dropdown,
426
+ size_filter_dropdown,
427
+ models_data
428
  ],
429
  outputs=[plot_output]
430
  )
431
 
432
 
433
  if __name__ == "__main__":
 
434
  demo.launch()