evijit HF Staff commited on
Commit
9c451ee
·
verified ·
1 Parent(s): caa5704

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +398 -187
app.py CHANGED
@@ -6,9 +6,9 @@ import pyarrow.parquet as pq
6
  import os
7
  import requests
8
  from io import BytesIO
9
- import math
10
 
11
- # Define pipeline tags (keeping the same ones from the provided code)
12
  PIPELINE_TAGS = [
13
  'text-generation',
14
  'text-to-image',
@@ -60,57 +60,59 @@ MODEL_SIZE_RANGES = {
60
  }
61
 
62
  # Filter functions for tags - keeping the same from provided code
63
- def is_audio_speech(repo_dct):
64
- res = (repo_dct.get("pipeline_tag", None) and "audio" in repo_dct.get("pipeline_tag", "").lower()) or \
65
- (repo_dct.get("pipeline_tag", None) and "speech" in repo_dct.get("pipeline_tag", "").lower()) or \
66
- (repo_dct.get("tags", None) and any("audio" in tag.lower() for tag in repo_dct.get("tags", []))) or \
67
- (repo_dct.get("tags", None) and any("speech" in tag.lower() for tag in repo_dct.get("tags", [])))
68
- return res
 
69
 
70
- def is_music(repo_dct):
71
- res = (repo_dct.get("tags", None) and any("music" in tag.lower() for tag in repo_dct.get("tags", [])))
72
- return res
73
 
74
- def is_robotics(repo_dct):
75
- res = (repo_dct.get("tags", None) and any("robot" in tag.lower() for tag in repo_dct.get("tags", [])))
76
- return res
77
 
78
- def is_biomed(repo_dct):
79
- res = (repo_dct.get("tags", None) and any("bio" in tag.lower() for tag in repo_dct.get("tags", []))) or \
80
- (repo_dct.get("tags", None) and any("medic" in tag.lower() for tag in repo_dct.get("tags", [])))
81
- return res
82
 
83
- def is_timeseries(repo_dct):
84
- res = (repo_dct.get("tags", None) and any("series" in tag.lower() for tag in repo_dct.get("tags", [])))
85
- return res
86
 
87
- def is_science(repo_dct):
88
- res = (repo_dct.get("tags", None) and any("science" in tag.lower() and not "bigscience" in tag for tag in repo_dct.get("tags", [])))
89
- return res
90
 
91
- def is_video(repo_dct):
92
- res = (repo_dct.get("tags", None) and any("video" in tag.lower() for tag in repo_dct.get("tags", [])))
93
- return res
94
 
95
- def is_image(repo_dct):
96
- res = (repo_dct.get("tags", None) and any("image" in tag.lower() for tag in repo_dct.get("tags", [])))
97
- return res
98
 
99
- def is_text(repo_dct):
100
- res = (repo_dct.get("tags", None) and any("text" in tag.lower() for tag in repo_dct.get("tags", [])))
101
- return res
102
 
103
  # Add model size filter function
104
- def is_in_size_range(repo_dct, size_range):
105
  if size_range is None:
106
  return True
107
 
108
  min_size, max_size = MODEL_SIZE_RANGES[size_range]
109
 
110
  # Get model size in GB from safetensors total (if available)
111
- if repo_dct.get("safetensors") and repo_dct["safetensors"].get("total"):
 
112
  # Convert bytes to GB
113
- size_gb = repo_dct["safetensors"]["total"] / (1024 * 1024 * 1024)
114
  return min_size <= size_gb < max_size
115
 
116
  return False
@@ -127,251 +129,421 @@ TAG_FILTER_FUNCS = {
127
  "Sciences": is_science,
128
  }
129
 
130
- def make_org_stats(count_by, org_stats, top_k=20, filter_func=None, size_range=None):
131
- assert count_by in ["likes", "downloads"]
132
-
133
- # Apply both filter_func and size_range if provided
134
- def combined_filter(dct):
135
- passes_tag_filter = filter_func(dct) if filter_func else True
136
- passes_size_filter = is_in_size_range(dct, size_range) if size_range else True
137
- return passes_tag_filter and passes_size_filter
138
-
139
- # Sort organizations by total count
140
- sorted_stats = sorted(
141
- [(
142
- org_id,
143
- sum(model[count_by] for model in models if combined_filter(model))
144
- ) for org_id, models in org_stats.items()],
145
- key=lambda x: x[1],
146
- reverse=True,
147
- )
148
 
149
- # Top organizations + Others category
150
- res = sorted_stats[:top_k] + [("Others...", sum(st for auth, st in sorted_stats[top_k:]))]
151
- total_st = sum(st for o, st in res)
 
152
 
153
- # Prepare data for treemap
154
- res_plot_df = []
155
- for org, st in res:
156
- if org == "Others...":
157
- res_plot_df += [("Others...", "other", st * 100 / total_st if total_st > 0 else 0)]
158
- else:
159
- for model in org_stats[org]:
160
- if combined_filter(model):
161
- res_plot_df += [(org, model["id"], model[count_by] * 100 / total_st if total_st > 0 else 0)]
162
 
163
- return ([(o, 100 * st / total_st if total_st > 0 else 0) for o, st in res if st > 0], res_plot_df)
164
-
165
- def make_figure(count_by, org_stats, tag_filter=None, pipeline_filter=None, size_range=None):
166
- assert count_by in ["downloads", "likes"]
 
 
167
 
168
- # Determine which filter function to use
169
- filter_func = None
170
- if tag_filter:
171
- filter_func = TAG_FILTER_FUNCS[tag_filter]
172
- elif pipeline_filter:
173
- filter_func = lambda dct: dct.get("pipeline_tag", None) and dct.get("pipeline_tag", "") == pipeline_filter
174
- else:
175
- filter_func = lambda dct: True
176
 
177
- # Generate stats with filters
178
- _, res_plot_df = make_org_stats(count_by, org_stats, top_k=25, filter_func=filter_func, size_range=size_range)
 
179
 
180
- # Create DataFrame for Plotly
181
- df = pd.DataFrame(
182
- dict(
183
- organizations=[o for o, _, _ in res_plot_df],
184
- model=[r for _, r, _ in res_plot_df],
185
- stats=[s for _, _, s in res_plot_df],
186
- )
187
- )
188
 
189
- df["models"] = "models" # Root node
 
190
 
191
- # Create treemap
192
- fig = px.treemap(df, path=["models", 'organizations', 'model'], values='stats',
193
- title=f"HuggingFace Models - {count_by.capitalize()} by Organization")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
 
 
 
 
 
 
 
 
 
 
195
  fig.update_layout(
196
  margin=dict(t=50, l=25, r=25, b=25)
197
  )
198
 
 
 
 
 
 
 
199
  return fig
200
 
201
- def download_and_process_models():
202
- """Download and process the models data from HuggingFace dataset"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203
  try:
204
  # Create a cache directory
205
  if not os.path.exists('data'):
206
  os.makedirs('data')
207
 
208
  # Check if we have cached data
209
- if os.path.exists('data/processed_models.json'):
210
- print("Loading from cache...")
211
- with open('data/processed_models.json', 'r') as f:
212
- return json.load(f)
 
 
213
 
214
  # URL to the models.parquet file
215
  url = "https://huggingface.co/datasets/cfahlgren1/hub-stats/resolve/main/models.parquet"
216
 
 
 
217
  print(f"Downloading models data from {url}...")
218
- response = requests.get(url)
219
- if response.status_code != 200:
220
- raise Exception(f"Failed to download data: HTTP {response.status_code}")
 
 
 
221
 
222
  # Read the parquet file
223
- table = pq.read_table(BytesIO(response.content))
224
  df = table.to_pandas()
225
 
226
  print(f"Downloaded {len(df)} models")
227
 
228
- # Process the dataframe into the organization structure we need
229
- org_stats = {}
230
 
231
- for _, row in df.iterrows():
232
- model_id = row['id']
233
-
234
- # Extract the organization part of the model ID
235
- if '/' in model_id:
236
- org_id = model_id.split('/')[0]
237
- else:
238
- org_id = "unaffiliated"
239
-
240
- # Create model entry with needed fields
241
- model_entry = {
242
- "id": model_id,
243
- "downloads": row.get('downloads', 0),
244
- "likes": row.get('likes', 0),
245
- "pipeline_tag": row.get('pipeline_tag'),
246
- "tags": row.get('tags', []),
247
- }
248
-
249
- # Add safetensors information if available
250
- if 'safetensors' in row and row['safetensors']:
251
- if isinstance(row['safetensors'], dict) and 'total' in row['safetensors']:
252
- model_entry["safetensors"] = {"total": row['safetensors']['total']}
253
- elif isinstance(row['safetensors'], str):
254
- # Try to parse JSON string
255
  try:
256
- safetensors = json.loads(row['safetensors'])
257
- if isinstance(safetensors, dict) and 'total' in safetensors:
258
- model_entry["safetensors"] = {"total": safetensors['total']}
259
  except:
260
- pass
 
261
 
262
- # Add to organization stats
263
- if org_id not in org_stats:
264
- org_stats[org_id] = []
 
 
 
 
 
 
 
 
265
 
266
- org_stats[org_id].append(model_entry)
267
 
268
  # Cache the processed data
269
- with open('data/processed_models.json', 'w') as f:
270
- json.dump(org_stats, f)
 
 
 
 
271
 
272
- return org_stats
273
 
274
  except Exception as e:
275
  print(f"Error downloading or processing data: {e}")
 
 
276
  # Return sample data for testing if real data unavailable
277
  return create_sample_data()
278
 
279
- def create_sample_data():
280
  """Create sample data for testing when real data is unavailable"""
281
  print("Creating sample data for testing...")
282
 
283
- sample_orgs = ['openai', 'meta', 'google', 'microsoft', 'anthropic', 'stability', 'huggingface']
284
- org_stats = {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
285
 
286
- for org in sample_orgs:
287
- org_stats[org] = []
288
- num_models = 5 # Each org has 5 sample models
 
 
 
 
 
289
 
290
  for i in range(num_models):
291
- model_id = f"{org}/model-{i+1}"
 
 
 
 
 
 
 
 
 
 
 
 
292
 
293
- # Random pipeline tag
294
- pipeline_idx = i % len(PIPELINE_TAGS)
295
- pipeline_tag = PIPELINE_TAGS[pipeline_idx]
 
 
 
 
 
 
 
 
 
 
 
 
296
 
297
- # Random tags
298
- tags = [pipeline_tag, "sample-data"]
299
 
300
- # Random downloads and likes
301
- downloads = int(1000 * (10 ** (org_stats.keys().index(org) % 3))) # Different magnitudes
302
- likes = int(downloads * 0.05) # 5% like rate
 
 
 
 
 
 
 
303
 
304
- # Random model size in bytes (from 100MB to 100GB)
305
- model_size = (10**8) * (10 ** (i % 3)) # Different magnitudes
 
 
 
306
 
307
- org_stats[org].append({
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
308
  "id": model_id,
309
  "downloads": downloads,
 
310
  "likes": likes,
311
  "pipeline_tag": pipeline_tag,
312
  "tags": tags,
313
- "safetensors": {"total": model_size}
314
- })
 
 
 
 
 
 
 
 
 
315
 
316
- return org_stats
 
 
 
317
 
318
  # Create Gradio interface
319
  with gr.Blocks() as demo:
320
- models_data = gr.State(value=None) # To store loaded data
321
-
322
- with gr.Row():
 
 
 
 
 
 
 
 
 
 
 
323
  gr.Markdown("""
324
- ## HuggingFace Models TreeMap
325
 
326
  This app shows how different organizations contribute to the HuggingFace ecosystem with their models.
327
  Use the filters to explore models by different metrics, tags, pipelines, and model sizes.
 
 
328
  """)
329
 
330
- with gr.Row():
331
  with gr.Column(scale=1):
332
  count_by_dropdown = gr.Dropdown(
333
  label="Metric",
334
- choices=["downloads", "likes"],
335
- value="downloads"
 
336
  )
337
 
338
  filter_choice_radio = gr.Radio(
339
- label="Filter by",
340
  choices=["None", "Tag Filter", "Pipeline Filter"],
341
- value="None"
 
342
  )
343
 
344
  tag_filter_dropdown = gr.Dropdown(
345
  label="Select Tag",
346
  choices=list(TAG_FILTER_FUNCS.keys()),
347
  value=None,
348
- visible=False
 
349
  )
350
 
351
  pipeline_filter_dropdown = gr.Dropdown(
352
  label="Select Pipeline Tag",
353
  choices=PIPELINE_TAGS,
354
  value=None,
355
- visible=False
 
356
  )
357
 
358
  size_filter_dropdown = gr.Dropdown(
359
  label="Model Size Filter",
360
  choices=["None"] + list(MODEL_SIZE_RANGES.keys()),
361
- value="None"
 
 
 
 
 
 
 
 
 
 
362
  )
363
 
364
- generate_plot_button = gr.Button("Generate Plot")
365
 
366
  with gr.Column(scale=3):
367
  plot_output = gr.Plot()
 
368
 
369
- def generate_plot_on_click(count_by, filter_choice, tag_filter, pipeline_filter, size_filter, data):
370
- print(f"Generating plot with: Metric={count_by}, Filter={filter_choice}, Tag={tag_filter}, Pipeline={pipeline_filter}, Size={size_filter}")
371
 
372
- if data is None:
373
- print("Error: Data not loaded yet.")
374
- return None
375
 
376
  selected_tag_filter = None
377
  selected_pipeline_filter = None
@@ -384,15 +556,47 @@ with gr.Blocks() as demo:
384
 
385
  if size_filter != "None":
386
  selected_size_filter = size_filter
387
-
388
- fig = make_figure(
 
 
389
  count_by=count_by,
390
- org_stats=data,
391
  tag_filter=selected_tag_filter,
392
  pipeline_filter=selected_pipeline_filter,
393
- size_range=selected_size_filter
394
  )
395
- return fig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
396
 
397
  def update_filter_visibility(filter_choice):
398
  if filter_choice == "Tag Filter":
@@ -408,11 +612,17 @@ with gr.Blocks() as demo:
408
  outputs=[tag_filter_dropdown, pipeline_filter_dropdown]
409
  )
410
 
411
- # Load data once at startup
 
 
 
 
 
 
412
  demo.load(
413
- fn=download_and_process_models,
414
  inputs=[],
415
- outputs=[models_data]
416
  )
417
 
418
  # Button click event to generate plot
@@ -424,9 +634,10 @@ with gr.Blocks() as demo:
424
  tag_filter_dropdown,
425
  pipeline_filter_dropdown,
426
  size_filter_dropdown,
 
427
  models_data
428
  ],
429
- outputs=[plot_output]
430
  )
431
 
432
 
 
6
  import os
7
  import requests
8
  from io import BytesIO
9
+ import numpy as np
10
 
11
+ # Define pipeline tags from the provided code
12
  PIPELINE_TAGS = [
13
  'text-generation',
14
  'text-to-image',
 
60
  }
61
 
62
  # Filter functions for tags - keeping the same from provided code
63
+ def is_audio_speech(model_dict):
64
+ tags = model_dict.get("tags", [])
65
+ pipeline_tag = model_dict.get("pipeline_tag", "")
66
+
67
+ return (pipeline_tag and ("audio" in pipeline_tag.lower() or "speech" in pipeline_tag.lower())) or \
68
+ any("audio" in tag.lower() for tag in tags) or \
69
+ any("speech" in tag.lower() for tag in tags)
70
 
71
+ def is_music(model_dict):
72
+ tags = model_dict.get("tags", [])
73
+ return any("music" in tag.lower() for tag in tags)
74
 
75
+ def is_robotics(model_dict):
76
+ tags = model_dict.get("tags", [])
77
+ return any("robot" in tag.lower() for tag in tags)
78
 
79
+ def is_biomed(model_dict):
80
+ tags = model_dict.get("tags", [])
81
+ return any("bio" in tag.lower() for tag in tags) or \
82
+ any("medic" in tag.lower() for tag in tags)
83
 
84
+ def is_timeseries(model_dict):
85
+ tags = model_dict.get("tags", [])
86
+ return any("series" in tag.lower() for tag in tags)
87
 
88
+ def is_science(model_dict):
89
+ tags = model_dict.get("tags", [])
90
+ return any("science" in tag.lower() and "bigscience" not in tag for tag in tags)
91
 
92
+ def is_video(model_dict):
93
+ tags = model_dict.get("tags", [])
94
+ return any("video" in tag.lower() for tag in tags)
95
 
96
+ def is_image(model_dict):
97
+ tags = model_dict.get("tags", [])
98
+ return any("image" in tag.lower() for tag in tags)
99
 
100
+ def is_text(model_dict):
101
+ tags = model_dict.get("tags", [])
102
+ return any("text" in tag.lower() for tag in tags)
103
 
104
  # Add model size filter function
105
+ def is_in_size_range(model_dict, size_range):
106
  if size_range is None:
107
  return True
108
 
109
  min_size, max_size = MODEL_SIZE_RANGES[size_range]
110
 
111
  # Get model size in GB from safetensors total (if available)
112
+ safetensors = model_dict.get("safetensors", None)
113
+ if safetensors and isinstance(safetensors, dict) and "total" in safetensors:
114
  # Convert bytes to GB
115
+ size_gb = safetensors["total"] / (1024 * 1024 * 1024)
116
  return min_size <= size_gb < max_size
117
 
118
  return False
 
129
  "Sciences": is_science,
130
  }
131
 
132
+ def extract_org_from_id(model_id):
133
+ """Extract organization name from model ID"""
134
+ if "/" in model_id:
135
+ return model_id.split("/")[0]
136
+ return "unaffiliated"
137
+
138
+ def make_treemap_data(df, count_by, top_k=25, tag_filter=None, pipeline_filter=None, size_filter=None):
139
+ """Process DataFrame into treemap format with filters applied"""
140
+ # Create a copy to avoid modifying the original
141
+ filtered_df = df.copy()
 
 
 
 
 
 
 
 
142
 
143
+ # Apply filters
144
+ if tag_filter and tag_filter in TAG_FILTER_FUNCS:
145
+ filter_func = TAG_FILTER_FUNCS[tag_filter]
146
+ filtered_df = filtered_df[filtered_df.apply(filter_func, axis=1)]
147
 
148
+ if pipeline_filter:
149
+ filtered_df = filtered_df[filtered_df["pipeline_tag"] == pipeline_filter]
 
 
 
 
 
 
 
150
 
151
+ if size_filter and size_filter in MODEL_SIZE_RANGES:
152
+ # Create a function to check if a model is in the size range
153
+ def check_size(row):
154
+ return is_in_size_range(row, size_filter)
155
+
156
+ filtered_df = filtered_df[filtered_df.apply(check_size, axis=1)]
157
 
158
+ # Add organization column
159
+ filtered_df["organization"] = filtered_df["id"].apply(extract_org_from_id)
 
 
 
 
 
 
160
 
161
+ # Aggregate by organization
162
+ org_totals = filtered_df.groupby("organization")[count_by].sum().reset_index()
163
+ org_totals = org_totals.sort_values(by=count_by, ascending=False)
164
 
165
+ # Get top organizations
166
+ top_orgs = org_totals.head(top_k)["organization"].tolist()
167
+
168
+ # Filter to only include models from top organizations
169
+ filtered_df = filtered_df[filtered_df["organization"].isin(top_orgs)]
170
+
171
+ # Prepare data for treemap
172
+ treemap_data = filtered_df[["id", "organization", count_by]].copy()
173
 
174
+ # Add a root node
175
+ treemap_data["root"] = "models"
176
 
177
+ # Ensure numeric values
178
+ treemap_data[count_by] = pd.to_numeric(treemap_data[count_by], errors="coerce").fillna(0)
179
+
180
+ return treemap_data
181
+
182
+ def create_treemap(treemap_data, count_by, title=None):
183
+ """Create a Plotly treemap from the prepared data"""
184
+ if treemap_data.empty:
185
+ # Create an empty figure with a message
186
+ fig = px.treemap(
187
+ names=["No data matches the selected filters"],
188
+ values=[1]
189
+ )
190
+ fig.update_layout(
191
+ title="No data matches the selected filters",
192
+ margin=dict(t=50, l=25, r=25, b=25)
193
+ )
194
+ return fig
195
 
196
+ # Create the treemap
197
+ fig = px.treemap(
198
+ treemap_data,
199
+ path=["root", "organization", "id"],
200
+ values=count_by,
201
+ title=title or f"HuggingFace Models - {count_by.capitalize()} by Organization"
202
+ )
203
+
204
+ # Update layout
205
  fig.update_layout(
206
  margin=dict(t=50, l=25, r=25, b=25)
207
  )
208
 
209
+ # Update traces for better readability
210
+ fig.update_traces(
211
+ textinfo="label+value+percent root",
212
+ hovertemplate="<b>%{label}</b><br>%{value:,} " + count_by + "<br>%{percentRoot:.2%} of total<extra></extra>"
213
+ )
214
+
215
  return fig
216
 
217
+ def download_with_progress(url, progress=None):
218
+ """Download a file with progress tracking"""
219
+ response = requests.get(url, stream=True)
220
+ total_size = int(response.headers.get('content-length', 0))
221
+ block_size = 1024 # 1 Kibibyte
222
+ data = BytesIO()
223
+
224
+ if total_size == 0:
225
+ # If content length is unknown, we can't show accurate progress
226
+ if progress:
227
+ progress(0, "Starting download...")
228
+
229
+ for chunk in response.iter_content(block_size):
230
+ data.write(chunk)
231
+ if progress:
232
+ progress(0, f"Downloading... (unknown size)")
233
+ else:
234
+ downloaded = 0
235
+ for chunk in response.iter_content(block_size):
236
+ downloaded += len(chunk)
237
+ data.write(chunk)
238
+ if progress:
239
+ percent = int(100 * downloaded / total_size)
240
+ progress(percent / 100, f"Downloading... {percent}% ({downloaded//(1024*1024)}MB/{total_size//(1024*1024)}MB)")
241
+
242
+ return data.getvalue()
243
+
244
+ def download_and_process_models(progress=None):
245
+ """Download and process the models data from HuggingFace dataset with progress tracking"""
246
  try:
247
  # Create a cache directory
248
  if not os.path.exists('data'):
249
  os.makedirs('data')
250
 
251
  # Check if we have cached data
252
+ if os.path.exists('data/processed_models.parquet'):
253
+ if progress:
254
+ progress(1.0, "Loading from cache...")
255
+ print("Loading models from cache...")
256
+ df = pd.read_parquet('data/processed_models.parquet')
257
+ return df
258
 
259
  # URL to the models.parquet file
260
  url = "https://huggingface.co/datasets/cfahlgren1/hub-stats/resolve/main/models.parquet"
261
 
262
+ if progress:
263
+ progress(0.0, "Starting download...")
264
  print(f"Downloading models data from {url}...")
265
+
266
+ # Download with progress tracking
267
+ file_content = download_with_progress(url, progress)
268
+
269
+ if progress:
270
+ progress(0.9, "Parsing parquet file...")
271
 
272
  # Read the parquet file
273
+ table = pq.read_table(BytesIO(file_content))
274
  df = table.to_pandas()
275
 
276
  print(f"Downloaded {len(df)} models")
277
 
278
+ if progress:
279
+ progress(0.95, "Processing data...")
280
 
281
+ # Process the safetensors column if it's a string (JSON)
282
+ if 'safetensors' in df.columns:
283
+ def parse_safetensors(val):
284
+ if isinstance(val, str):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
285
  try:
286
+ return json.loads(val)
 
 
287
  except:
288
+ return None
289
+ return val
290
 
291
+ df['safetensors'] = df['safetensors'].apply(parse_safetensors)
292
+
293
+ # Process the tags column if needed
294
+ if 'tags' in df.columns and not isinstance(df['tags'].iloc[0], list):
295
+ def parse_tags(val):
296
+ if isinstance(val, str):
297
+ try:
298
+ return json.loads(val)
299
+ except:
300
+ return []
301
+ return val if isinstance(val, list) else []
302
 
303
+ df['tags'] = df['tags'].apply(parse_tags)
304
 
305
  # Cache the processed data
306
+ if progress:
307
+ progress(0.98, "Saving to cache...")
308
+ df.to_parquet('data/processed_models.parquet')
309
+
310
+ if progress:
311
+ progress(1.0, "Data ready!")
312
 
313
+ return df
314
 
315
  except Exception as e:
316
  print(f"Error downloading or processing data: {e}")
317
+ if progress:
318
+ progress(1.0, "Using sample data (download failed)")
319
  # Return sample data for testing if real data unavailable
320
  return create_sample_data()
321
 
322
+ def create_sample_data(progress=None):
323
  """Create sample data for testing when real data is unavailable"""
324
  print("Creating sample data for testing...")
325
 
326
+ if progress:
327
+ progress(0.3, "Creating sample data...")
328
+
329
+ # Sample organizations
330
+ orgs = ['openai', 'meta', 'google', 'microsoft', 'anthropic', 'nvidia', 'huggingface',
331
+ 'deepseek-ai', 'stability-ai', 'mistralai', 'cerebras', 'databricks', 'together',
332
+ 'facebook', 'amazon', 'deepmind', 'cohere', 'nvidia', 'bigscience', 'eleutherai']
333
+
334
+ # Common model name formats
335
+ model_name_patterns = [
336
+ "model-{size}-{version}",
337
+ "{prefix}-{size}b",
338
+ "{prefix}-{size}b-{variant}",
339
+ "llama-{size}b-{variant}",
340
+ "gpt-{variant}-{size}b",
341
+ "{prefix}-instruct-{size}b",
342
+ "{prefix}-chat-{size}b",
343
+ "{prefix}-coder-{size}b",
344
+ "stable-diffusion-{version}",
345
+ "whisper-{size}",
346
+ "bert-{size}-{variant}",
347
+ "roberta-{size}",
348
+ "t5-{size}",
349
+ "{prefix}-vision-{size}b"
350
+ ]
351
+
352
+ # Common name parts
353
+ prefixes = ["falcon", "llama", "mistral", "gpt", "phi", "gemma", "qwen", "yi", "mpt", "bloom"]
354
+ sizes = ["7", "13", "34", "70", "1", "3", "7b", "13b", "70b", "8b", "2b", "1b", "0.5b", "small", "base", "large", "huge"]
355
+ variants = ["chat", "instruct", "base", "v1.0", "v2", "beta", "turbo", "fast", "xl", "xxl"]
356
 
357
+ # Generate sample data
358
+ data = []
359
+ total_models = sum(np.random.randint(5, 20) for _ in orgs)
360
+ models_created = 0
361
+
362
+ for org_idx, org in enumerate(orgs):
363
+ # Create 5-20 models per organization
364
+ num_models = np.random.randint(5, 20)
365
 
366
  for i in range(num_models):
367
+ # Create realistic model name
368
+ pattern = np.random.choice(model_name_patterns)
369
+ prefix = np.random.choice(prefixes)
370
+ size = np.random.choice(sizes)
371
+ version = f"v{np.random.randint(1, 4)}"
372
+ variant = np.random.choice(variants)
373
+
374
+ model_name = pattern.format(
375
+ prefix=prefix,
376
+ size=size,
377
+ version=version,
378
+ variant=variant
379
+ )
380
 
381
+ model_id = f"{org}/{model_name}"
382
+
383
+ # Select a realistic pipeline tag based on name
384
+ if "diffusion" in model_name or "image" in model_name:
385
+ pipeline_tag = np.random.choice(["text-to-image", "image-to-image", "image-segmentation"])
386
+ elif "whisper" in model_name or "speech" in model_name:
387
+ pipeline_tag = np.random.choice(["automatic-speech-recognition", "text-to-speech"])
388
+ elif "coder" in model_name or "code" in model_name:
389
+ pipeline_tag = "text-generation"
390
+ elif "bert" in model_name or "roberta" in model_name:
391
+ pipeline_tag = np.random.choice(["fill-mask", "text-classification", "token-classification"])
392
+ elif "vision" in model_name:
393
+ pipeline_tag = np.random.choice(["image-classification", "image-to-text", "visual-question-answering"])
394
+ else:
395
+ pipeline_tag = "text-generation" # Most common
396
 
397
+ # Generate realistic tags
398
+ tags = [pipeline_tag]
399
 
400
+ if "text-generation" in pipeline_tag:
401
+ tags.extend(["language-model", "text", "gpt", "llm"])
402
+ if "instruct" in model_name:
403
+ tags.append("instruction-following")
404
+ if "chat" in model_name:
405
+ tags.append("chat")
406
+ elif "speech" in pipeline_tag:
407
+ tags.extend(["audio", "speech", "voice"])
408
+ elif "image" in pipeline_tag:
409
+ tags.extend(["vision", "image", "diffusion"])
410
 
411
+ # Add language tags
412
+ if np.random.random() < 0.8: # 80% chance for English
413
+ tags.append("en")
414
+ if np.random.random() < 0.3: # 30% chance for multilingual
415
+ tags.append("multilingual")
416
 
417
+ # Generate downloads and likes (weighted by org position for variety)
418
+ # Earlier orgs get more downloads to make the visualization interesting
419
+ popularity_factor = (len(orgs) - org_idx) / len(orgs) # 1.0 to 0.0
420
+ base_downloads = 1000 * (10 ** (2 * popularity_factor))
421
+ downloads = int(base_downloads * np.random.uniform(0.3, 3.0))
422
+ likes = int(downloads * np.random.uniform(0.01, 0.1)) # 1-10% like ratio
423
+
424
+ # Generate model size (in bytes for safetensors total)
425
+ # Model size should correlate somewhat with the size in the name
426
+ size_indicator = 1
427
+ for s in ["70b", "13b", "7b", "3b", "2b", "1b", "large", "huge", "xl", "xxl"]:
428
+ if s in model_name.lower():
429
+ size_indicator = float(s.replace("b", "")) if s[0].isdigit() else 3
430
+ break
431
+
432
+ # Size in GB, then convert to bytes
433
+ size_gb = np.random.uniform(0.1, 2.0) * size_indicator
434
+ if size_gb > 50: # Cap at 100GB
435
+ size_gb = min(size_gb, 100)
436
+ size_bytes = int(size_gb * 1e9)
437
+
438
+ # Create model entry
439
+ model = {
440
  "id": model_id,
441
  "downloads": downloads,
442
+ "downloadsAllTime": int(downloads * np.random.uniform(1.5, 3.0)), # All-time higher than recent
443
  "likes": likes,
444
  "pipeline_tag": pipeline_tag,
445
  "tags": tags,
446
+ "safetensors": {"total": size_bytes}
447
+ }
448
+
449
+ data.append(model)
450
+ models_created += 1
451
+
452
+ if progress and i % 5 == 0:
453
+ progress(0.3 + 0.6 * (models_created / total_models), f"Created {models_created}/{total_models} sample models...")
454
+
455
+ # Convert to DataFrame
456
+ df = pd.DataFrame(data)
457
 
458
+ if progress:
459
+ progress(0.95, "Finalizing sample data...")
460
+
461
+ return df
462
 
463
  # Create Gradio interface
464
  with gr.Blocks() as demo:
465
+ models_data = gr.State() # To store loaded data
466
+
467
+ # Loading screen components
468
+ with gr.Row(visible=True) as loading_screen:
469
+ with gr.Column(scale=1):
470
+ gr.Markdown("""
471
+ # HuggingFace Models TreeMap Visualization
472
+
473
+ Loading data... This might take a moment.
474
+ """)
475
+ data_loading_progress = gr.Progress()
476
+
477
+ # Main application components (initially hidden)
478
+ with gr.Row(visible=False) as main_app:
479
  gr.Markdown("""
480
+ # HuggingFace Models TreeMap Visualization
481
 
482
  This app shows how different organizations contribute to the HuggingFace ecosystem with their models.
483
  Use the filters to explore models by different metrics, tags, pipelines, and model sizes.
484
+
485
+ The treemap visualizes models grouped by organization, with the size of each box representing the selected metric (downloads or likes).
486
  """)
487
 
488
+ with gr.Row(visible=False) as control_panel:
489
  with gr.Column(scale=1):
490
  count_by_dropdown = gr.Dropdown(
491
  label="Metric",
492
+ choices=["downloads", "downloadsAllTime", "likes"],
493
+ value="downloads",
494
+ info="Select the metric to determine box sizes"
495
  )
496
 
497
  filter_choice_radio = gr.Radio(
498
+ label="Filter Type",
499
  choices=["None", "Tag Filter", "Pipeline Filter"],
500
+ value="None",
501
+ info="Choose how to filter the models"
502
  )
503
 
504
  tag_filter_dropdown = gr.Dropdown(
505
  label="Select Tag",
506
  choices=list(TAG_FILTER_FUNCS.keys()),
507
  value=None,
508
+ visible=False,
509
+ info="Filter models by domain/category"
510
  )
511
 
512
  pipeline_filter_dropdown = gr.Dropdown(
513
  label="Select Pipeline Tag",
514
  choices=PIPELINE_TAGS,
515
  value=None,
516
+ visible=False,
517
+ info="Filter models by specific pipeline"
518
  )
519
 
520
  size_filter_dropdown = gr.Dropdown(
521
  label="Model Size Filter",
522
  choices=["None"] + list(MODEL_SIZE_RANGES.keys()),
523
+ value="None",
524
+ info="Filter models by their size (in safetensors['total'])"
525
+ )
526
+
527
+ top_k_slider = gr.Slider(
528
+ label="Number of Top Organizations",
529
+ minimum=5,
530
+ maximum=50,
531
+ value=25,
532
+ step=5,
533
+ info="Number of top organizations to include"
534
  )
535
 
536
+ generate_plot_button = gr.Button("Generate Plot", variant="primary")
537
 
538
  with gr.Column(scale=3):
539
  plot_output = gr.Plot()
540
+ stats_output = gr.Markdown("*Generate a plot to see statistics*")
541
 
542
+ def generate_plot_on_click(count_by, filter_choice, tag_filter, pipeline_filter, size_filter, top_k, data_df):
543
+ print(f"Generating plot with: Metric={count_by}, Filter={filter_choice}, Tag={tag_filter}, Pipeline={pipeline_filter}, Size={size_filter}, Top K={top_k}")
544
 
545
+ if data_df is None or len(data_df) == 0:
546
+ return None, "Error: No data available. Please try again."
 
547
 
548
  selected_tag_filter = None
549
  selected_pipeline_filter = None
 
556
 
557
  if size_filter != "None":
558
  selected_size_filter = size_filter
559
+
560
+ # Process data for treemap
561
+ treemap_data = make_treemap_data(
562
+ df=data_df,
563
  count_by=count_by,
564
+ top_k=top_k,
565
  tag_filter=selected_tag_filter,
566
  pipeline_filter=selected_pipeline_filter,
567
+ size_filter=selected_size_filter
568
  )
569
+
570
+ # Create plot
571
+ fig = create_treemap(
572
+ treemap_data=treemap_data,
573
+ count_by=count_by,
574
+ title=f"HuggingFace Models - {count_by.capitalize()} by Organization"
575
+ )
576
+
577
+ # Generate statistics
578
+ if treemap_data.empty:
579
+ stats_md = "No data matches the selected filters."
580
+ else:
581
+ total_models = len(treemap_data)
582
+ total_value = treemap_data[count_by].sum()
583
+ top_5_orgs = treemap_data.groupby("organization")[count_by].sum().sort_values(ascending=False).head(5)
584
+
585
+ stats_md = f"""
586
+ ### Statistics
587
+ - **Total models shown**: {total_models:,}
588
+ - **Total {count_by}**: {total_value:,}
589
+
590
+ ### Top 5 Organizations
591
+ | Organization | {count_by.capitalize()} | % of Total |
592
+ | --- | --- | --- |
593
+ """
594
+
595
+ for org, value in top_5_orgs.items():
596
+ percentage = (value / total_value) * 100
597
+ stats_md += f"| {org} | {value:,} | {percentage:.2f}% |\n"
598
+
599
+ return fig, stats_md
600
 
601
  def update_filter_visibility(filter_choice):
602
  if filter_choice == "Tag Filter":
 
612
  outputs=[tag_filter_dropdown, pipeline_filter_dropdown]
613
  )
614
 
615
+ def load_data_with_progress(progress=gr.Progress()):
616
+ """Load data with progress tracking and update UI visibility"""
617
+ data_df = download_and_process_models(progress)
618
+ # Return both the data and the visibility updates
619
+ return data_df, gr.update(visible=False), gr.update(visible=True), gr.update(visible=True)
620
+
621
+ # Load data once at startup with progress bar
622
  demo.load(
623
+ fn=load_data_with_progress,
624
  inputs=[],
625
+ outputs=[models_data, loading_screen, main_app, control_panel]
626
  )
627
 
628
  # Button click event to generate plot
 
634
  tag_filter_dropdown,
635
  pipeline_filter_dropdown,
636
  size_filter_dropdown,
637
+ top_k_slider,
638
  models_data
639
  ],
640
+ outputs=[plot_output, stats_output]
641
  )
642
 
643