evijit HF Staff committed on
Commit
6f8106d
·
verified ·
1 Parent(s): 57e108c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +447 -78
app.py CHANGED
@@ -5,6 +5,7 @@ import plotly.express as px
5
  import os
6
  import numpy as np
7
  import io
 
8
 
9
  # Define pipeline tags
10
  PIPELINE_TAGS = [
@@ -57,65 +58,120 @@ MODEL_SIZE_RANGES = {
57
  "XX-Large (>50GB)": (50, float('inf'))
58
  }
59
 
60
- # Filter functions for tags
61
  def is_audio_speech(row):
62
- tags = row.get("tags", [])
63
- pipeline_tag = row.get("pipeline_tag", "")
64
-
65
- return (pipeline_tag and ("audio" in pipeline_tag.lower() or "speech" in pipeline_tag.lower())) or \
66
- any("audio" in tag.lower() for tag in tags) or \
67
- any("speech" in tag.lower() for tag in tags)
68
 
69
  def is_music(row):
70
- tags = row.get("tags", [])
71
- return any("music" in tag.lower() for tag in tags)
72
 
73
  def is_robotics(row):
74
- tags = row.get("tags", [])
75
- return any("robot" in tag.lower() for tag in tags)
76
 
77
  def is_biomed(row):
78
- tags = row.get("tags", [])
79
- return any("bio" in tag.lower() for tag in tags) or \
80
- any("medic" in tag.lower() for tag in tags)
81
 
82
  def is_timeseries(row):
83
- tags = row.get("tags", [])
84
- return any("series" in tag.lower() for tag in tags)
85
 
86
  def is_science(row):
87
- tags = row.get("tags", [])
88
- return any("science" in tag.lower() and "bigscience" not in tag for tag in tags)
89
 
90
  def is_video(row):
91
- tags = row.get("tags", [])
92
- return any("video" in tag.lower() for tag in tags)
 
 
 
 
 
 
 
 
93
 
94
  def is_image(row):
95
  tags = row.get("tags", [])
96
- return any("image" in tag.lower() for tag in tags)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
  def is_text(row):
99
  tags = row.get("tags", [])
100
- return any("text" in tag.lower() for tag in tags)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
 
102
- # Add model size filter function
103
  def is_in_size_range(row, size_range):
104
- if size_range is None:
 
105
  return True
106
 
107
- min_size, max_size = MODEL_SIZE_RANGES[size_range]
108
-
109
- # Get model size in GB from params column
110
- if "params" in row and pd.notna(row["params"]):
111
- try:
112
- # Convert to GB (assuming params are in bytes or scientific notation)
113
- size_gb = float(row["params"]) / (1024 * 1024 * 1024)
114
- return min_size <= size_gb < max_size
115
- except (ValueError, TypeError):
116
- return False
117
-
118
- return False
119
 
120
  TAG_FILTER_FUNCS = {
121
  "Audio & Speech": is_audio_speech,
@@ -136,24 +192,64 @@ def extract_org_from_id(model_id):
136
  return "unaffiliated"
137
 
138
  def make_treemap_data(df, count_by, top_k=25, tag_filter=None, pipeline_filter=None, size_filter=None, skip_orgs=None):
139
- """Process DataFrame into treemap format with filters applied"""
140
  # Create a copy to avoid modifying the original
141
  filtered_df = df.copy()
142
 
143
  # Apply filters
 
 
 
 
144
  if tag_filter and tag_filter in TAG_FILTER_FUNCS:
145
- filter_func = TAG_FILTER_FUNCS[tag_filter]
146
- filtered_df = filtered_df[filtered_df.apply(filter_func, axis=1)]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
 
 
148
  if pipeline_filter:
 
149
  filtered_df = filtered_df[filtered_df["pipeline_tag"] == pipeline_filter]
 
 
 
150
 
 
151
  if size_filter and size_filter in MODEL_SIZE_RANGES:
152
- # Create a function to check if a model is in the size range
153
- def check_size(row):
154
- return is_in_size_range(row, size_filter)
 
 
 
 
 
155
 
156
- filtered_df = filtered_df[filtered_df.apply(check_size, axis=1)]
 
 
157
 
158
  # Add organization column
159
  filtered_df["organization"] = filtered_df["id"].apply(extract_org_from_id)
@@ -161,6 +257,17 @@ def make_treemap_data(df, count_by, top_k=25, tag_filter=None, pipeline_filter=N
161
  # Skip organizations if specified
162
  if skip_orgs and len(skip_orgs) > 0:
163
  filtered_df = filtered_df[~filtered_df["organization"].isin(skip_orgs)]
 
 
 
 
 
 
 
 
 
 
 
164
 
165
  # Aggregate by organization
166
  org_totals = filtered_df.groupby("organization")[count_by].sum().reset_index()
@@ -181,6 +288,7 @@ def make_treemap_data(df, count_by, top_k=25, tag_filter=None, pipeline_filter=N
181
  # Ensure numeric values
182
  treemap_data[count_by] = pd.to_numeric(treemap_data[count_by], errors="coerce").fillna(0)
183
 
 
184
  return treemap_data
185
 
186
  def create_treemap(treemap_data, count_by, title=None):
@@ -219,23 +327,271 @@ def create_treemap(treemap_data, count_by, title=None):
219
 
220
  return fig
221
 
222
- def load_models_csv():
223
- # Read the CSV file
224
- df = pd.read_csv('models.csv')
225
-
226
- # Process the tags column
227
- def process_tags(tags_str):
228
- if pd.isna(tags_str):
229
- return []
230
 
231
- # Clean the string and convert to a list
232
- tags_str = tags_str.strip("[]").replace("'", "")
233
- tags = [tag.strip() for tag in tags_str.split() if tag.strip()]
234
- return tags
235
-
236
- df['tags'] = df['tags'].apply(process_tags)
237
-
238
- return df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
239
 
240
  # Create Gradio interface
241
  with gr.Blocks() as demo:
@@ -250,6 +606,7 @@ with gr.Blocks() as demo:
250
  Use the filters to explore models by different metrics, tags, pipelines, and model sizes.
251
 
252
  The treemap visualizes models grouped by organization, with the size of each box representing the selected metric.
 
253
  """)
254
 
255
  with gr.Row():
@@ -311,24 +668,12 @@ with gr.Blocks() as demo:
311
  )
312
 
313
  generate_plot_button = gr.Button("Generate Plot", variant="primary", interactive=False)
 
314
 
315
  with gr.Column(scale=3):
316
  plot_output = gr.Plot()
317
- stats_output = gr.Markdown("*Generate a plot to see statistics*")
318
-
319
- # Updated load function returning both the data and loading flag
320
- def load_models_csv():
321
- df = pd.read_csv('models.csv')
322
-
323
- def process_tags(tags_str):
324
- if pd.isna(tags_str):
325
- return []
326
- tags_str = tags_str.strip("[]").replace("'", "")
327
- tags = [tag.strip() for tag in tags_str.split() if tag.strip()]
328
- return tags
329
-
330
- df['tags'] = df['tags'].apply(process_tags)
331
- return df, True
332
 
333
  # Button enablement after data load
334
  def enable_plot_button(loaded):
@@ -355,6 +700,25 @@ with gr.Blocks() as demo:
355
  outputs=[tag_filter_dropdown, pipeline_filter_dropdown]
356
  )
357
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
358
  # Main generate function
359
  def generate_plot_on_click(count_by, filter_choice, tag_filter, pipeline_filter, size_filter, top_k, skip_orgs_text, data_df):
360
  if data_df is None or not isinstance(data_df, pd.DataFrame) or data_df.empty:
@@ -451,9 +815,16 @@ with gr.Blocks() as demo:
451
 
452
  # Load data at startup
453
  demo.load(
454
- fn=load_models_csv,
455
  inputs=[],
456
- outputs=[models_data, loading_complete]
 
 
 
 
 
 
 
457
  )
458
 
459
  generate_plot_button.click(
@@ -471,7 +842,5 @@ with gr.Blocks() as demo:
471
  outputs=[plot_output, stats_output]
472
  )
473
 
474
-
475
-
476
  if __name__ == "__main__":
477
  demo.launch()
 
5
  import os
6
  import numpy as np
7
  import io
8
+ import duckdb
9
 
10
  # Define pipeline tags
11
  PIPELINE_TAGS = [
 
58
  "XX-Large (>50GB)": (50, float('inf'))
59
  }
60
 
61
+ # Filter functions for tags - UPDATED to use cached columns
62
  def is_audio_speech(row):
63
+ # Use cached column instead of recalculating
64
+ return row['is_audio_speech']
 
 
 
 
65
 
66
  def is_music(row):
67
+ # Use cached column instead of recalculating
68
+ return row['has_music']
69
 
70
  def is_robotics(row):
71
+ # Use cached column instead of recalculating
72
+ return row['has_robot']
73
 
74
  def is_biomed(row):
75
+ # Use cached column instead of recalculating
76
+ return row['is_biomed']
 
77
 
78
  def is_timeseries(row):
79
+ # Use cached column instead of recalculating
80
+ return row['has_series']
81
 
82
  def is_science(row):
83
+ # Use cached column instead of recalculating
84
+ return row['has_science']
85
 
86
  def is_video(row):
87
+ # Use cached column instead of recalculating
88
+ return row['has_video']
89
+
90
def is_image(row):
    """Return the precomputed ``has_image`` flag stored on *row*.

    NOTE(review): another ``is_image`` is defined further down in this file
    and, being later, replaces this one at import time -- confirm which of
    the two is meant to survive.
    """
    return row["has_image"]
93
+
94
def is_text(row):
    """Return the precomputed ``has_text`` flag stored on *row*.

    NOTE(review): another ``is_text`` is defined further down in this file
    and, being later, replaces this one at import time -- confirm which of
    the two is meant to survive.
    """
    return row["has_text"]
97
 
98
  def is_image(row):
99
  tags = row.get("tags", [])
100
+
101
+ # Check if tags exists and is not empty
102
+ if tags is not None:
103
+ # For numpy arrays
104
+ if hasattr(tags, 'dtype') and hasattr(tags, 'tolist'):
105
+ # Convert numpy array to list
106
+ tags_list = tags.tolist()
107
+ return any("image" in str(tag).lower() for tag in tags_list)
108
+ # For regular lists
109
+ elif isinstance(tags, list):
110
+ return any("image" in str(tag).lower() for tag in tags)
111
+ # For string tags
112
+ elif isinstance(tags, str):
113
+ return "image" in tags.lower()
114
+ return False
115
 
116
  def is_text(row):
117
  tags = row.get("tags", [])
118
+
119
+ # Check if tags exists and is not empty
120
+ if tags is not None:
121
+ # For numpy arrays
122
+ if hasattr(tags, 'dtype') and hasattr(tags, 'tolist'):
123
+ # Convert numpy array to list
124
+ tags_list = tags.tolist()
125
+ return any("text" in str(tag).lower() for tag in tags_list)
126
+ # For regular lists
127
+ elif isinstance(tags, list):
128
+ return any("text" in str(tag).lower() for tag in tags)
129
+ # For string tags
130
+ elif isinstance(tags, str):
131
+ return "text" in tags.lower()
132
+ return False
133
+
134
def extract_model_size(safetensors_data):
    """Extract model size in GB from safetensors data.

    Args:
        safetensors_data: Either a dict or a JSON string that decodes to a
            dict. Its ``total`` entry is converted with ``float`` and
            treated as a byte count. Every other input (``None``, NaN,
            arrays, non-object JSON, ...) means "size unknown".

    Returns:
        ``total / 1024**3`` as a float, or ``0`` when the size is unknown
        or ``total`` cannot be converted to a number.

    NOTE(review): the Hub API appears to report ``safetensors.total`` as a
    parameter count rather than bytes -- confirm the unit before trusting
    the "GB" interpretation used here.
    """
    import json  # local import: no top-level ``import json`` is visible for this file

    if isinstance(safetensors_data, str):
        try:
            safetensors_data = json.loads(safetensors_data)
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            return 0

    # Only a mapping with a "total" entry carries size information.
    if not isinstance(safetensors_data, dict) or 'total' not in safetensors_data:
        return 0

    try:
        size_bytes = float(safetensors_data['total'])
    except (ValueError, TypeError, OverflowError):
        return 0
    return size_bytes / (1024 * 1024 * 1024)  # Convert to GB
166
 
167
+ # Add model size filter function - UPDATED to use cached size_category column
168
  def is_in_size_range(row, size_range):
169
+ """Check if a model is in the specified size range using pre-calculated size category"""
170
+ if size_range is None or size_range == "None":
171
  return True
172
 
173
+ # Simply compare with cached size_category
174
+ return row['size_category'] == size_range
 
 
 
 
 
 
 
 
 
 
175
 
176
  TAG_FILTER_FUNCS = {
177
  "Audio & Speech": is_audio_speech,
 
192
  return "unaffiliated"
193
 
194
  def make_treemap_data(df, count_by, top_k=25, tag_filter=None, pipeline_filter=None, size_filter=None, skip_orgs=None):
195
+ """Process DataFrame into treemap format with filters applied - OPTIMIZED with cached columns"""
196
  # Create a copy to avoid modifying the original
197
  filtered_df = df.copy()
198
 
199
  # Apply filters
200
+ filter_stats = {"initial": len(filtered_df)}
201
+ start_time = pd.Timestamp.now()
202
+
203
+ # Apply tag filter - OPTIMIZED to use cached columns
204
  if tag_filter and tag_filter in TAG_FILTER_FUNCS:
205
+ print(f"Applying tag filter: {tag_filter}")
206
+
207
+ # Use direct column filtering instead of applying a function to each row
208
+ if tag_filter == "Audio & Speech":
209
+ filtered_df = filtered_df[filtered_df['is_audio_speech']]
210
+ elif tag_filter == "Music":
211
+ filtered_df = filtered_df[filtered_df['has_music']]
212
+ elif tag_filter == "Robotics":
213
+ filtered_df = filtered_df[filtered_df['has_robot']]
214
+ elif tag_filter == "Biomedical":
215
+ filtered_df = filtered_df[filtered_df['is_biomed']]
216
+ elif tag_filter == "Time series":
217
+ filtered_df = filtered_df[filtered_df['has_series']]
218
+ elif tag_filter == "Sciences":
219
+ filtered_df = filtered_df[filtered_df['has_science']]
220
+ elif tag_filter == "Video":
221
+ filtered_df = filtered_df[filtered_df['has_video']]
222
+ elif tag_filter == "Images":
223
+ filtered_df = filtered_df[filtered_df['has_image']]
224
+ elif tag_filter == "Text":
225
+ filtered_df = filtered_df[filtered_df['has_text']]
226
+
227
+ filter_stats["after_tag_filter"] = len(filtered_df)
228
+ print(f"Tag filter applied in {(pd.Timestamp.now() - start_time).total_seconds():.3f} seconds")
229
+ start_time = pd.Timestamp.now()
230
 
231
+ # Apply pipeline filter
232
  if pipeline_filter:
233
+ print(f"Applying pipeline filter: {pipeline_filter}")
234
  filtered_df = filtered_df[filtered_df["pipeline_tag"] == pipeline_filter]
235
+ filter_stats["after_pipeline_filter"] = len(filtered_df)
236
+ print(f"Pipeline filter applied in {(pd.Timestamp.now() - start_time).total_seconds():.3f} seconds")
237
+ start_time = pd.Timestamp.now()
238
 
239
+ # Apply size filter - OPTIMIZED to use cached size_category column
240
  if size_filter and size_filter in MODEL_SIZE_RANGES:
241
+ print(f"Applying size filter: {size_filter}")
242
+
243
+ # Use the cached size_category column directly
244
+ filtered_df = filtered_df[filtered_df['size_category'] == size_filter]
245
+
246
+ # Debug info
247
+ print(f"Size filter '{size_filter}' applied.")
248
+ print(f"Models after size filter: {len(filtered_df)}")
249
 
250
+ filter_stats["after_size_filter"] = len(filtered_df)
251
+ print(f"Size filter applied in {(pd.Timestamp.now() - start_time).total_seconds():.3f} seconds")
252
+ start_time = pd.Timestamp.now()
253
 
254
  # Add organization column
255
  filtered_df["organization"] = filtered_df["id"].apply(extract_org_from_id)
 
257
  # Skip organizations if specified
258
  if skip_orgs and len(skip_orgs) > 0:
259
  filtered_df = filtered_df[~filtered_df["organization"].isin(skip_orgs)]
260
+ filter_stats["after_skip_orgs"] = len(filtered_df)
261
+
262
+ # Print filter stats
263
+ print("Filter statistics:")
264
+ for stage, count in filter_stats.items():
265
+ print(f" {stage}: {count} models")
266
+
267
+ # Check if we have any data left
268
+ if filtered_df.empty:
269
+ print("Warning: No data left after applying filters!")
270
+ return pd.DataFrame() # Return empty DataFrame
271
 
272
  # Aggregate by organization
273
  org_totals = filtered_df.groupby("organization")[count_by].sum().reset_index()
 
288
  # Ensure numeric values
289
  treemap_data[count_by] = pd.to_numeric(treemap_data[count_by], errors="coerce").fillna(0)
290
 
291
+ print(f"Treemap data prepared in {(pd.Timestamp.now() - start_time).total_seconds():.3f} seconds")
292
  return treemap_data
293
 
294
  def create_treemap(treemap_data, count_by, title=None):
 
327
 
328
  return fig
329
 
330
def load_models_data():
    """Load Hub model statistics and precompute the columns used for filtering.

    Reads ``models.parquet`` from the ``cfahlgren1/hub-stats`` dataset through
    DuckDB, then:

    * derives ``params`` (via ``extract_model_size``) and a ``size_category``
      label from the ``safetensors`` column, which is dropped afterwards;
    * normalises ``tags`` to a list of strings per row;
    * adds boolean ``has_*`` columns (substring match on the lowercased tags)
      plus the combined ``is_audio_speech`` and ``is_biomed`` flags;
    * fills missing ``downloads`` / ``downloadsAllTime`` / ``likes`` /
      ``params`` with 0 and missing ``pipeline_tag`` with ``''``.

    Returns:
        tuple: ``(df, True)`` on success; ``(pd.DataFrame(), False)`` if any
        step raises (the error is printed, not propagated).
    """
    import json  # local import: no top-level ``import json`` is visible for this file

    parquet_url = "https://huggingface.co/datasets/cfahlgren1/hub-stats/resolve/main/models.parquet"
    needed_columns = ['id', 'downloads', 'downloadsAllTime', 'likes', 'pipeline_tag', 'tags', 'safetensors']
    zero_filled_columns = ['downloads', 'downloadsAllTime', 'likes', 'params']

    # (label, inclusive lower bound in GB, exclusive upper bound or None for open-ended)
    size_buckets = [
        ("Small (<1GB)", 0, 1),
        ("Medium (1-5GB)", 1, 5),
        ("Large (5-20GB)", 5, 20),
        ("X-Large (20-50GB)", 20, 50),
        ("XX-Large (>50GB)", 50, None),
    ]

    # (cached column, substring that must appear in a tag, substring that must NOT appear)
    tag_columns = [
        ('has_audio', 'audio', None),
        ('has_speech', 'speech', None),
        ('has_music', 'music', None),
        ('has_robot', 'robot', None),
        ('has_bio', 'bio', None),
        ('has_med', 'medic', None),
        ('has_series', 'series', None),
        ('has_science', 'science', 'bigscience'),
        ('has_video', 'video', None),
        ('has_image', 'image', None),
        ('has_text', 'text', None),
    ]

    def default_column(col, n_rows):
        """Build placeholder values for a column that the data did not provide."""
        if col in zero_filled_columns:
            return [0] * n_rows
        if col == 'pipeline_tag':
            return [''] * n_rows
        if col == 'tags':
            return [[] for _ in range(n_rows)]
        if col == 'id':
            # Create IDs based on position if missing
            return [f"model_{i}" for i in range(n_rows)]
        return [None] * n_rows  # e.g. safetensors

    def get_size_category(size_gb):
        """Map a size in GB to its bucket label; None for negative/NaN sizes."""
        for label, low, high in size_buckets:
            if low <= size_gb and (high is None or size_gb < high):
                return label
        return None

    def process_tags(tags_value):
        """Normalise one raw ``tags`` cell to a list of strings."""
        try:
            if tags_value is None:
                return []

            # Containers MUST be handled before any pd.isna() check: on an
            # array or list pd.isna() returns an element-wise array, and
            # using that in a boolean context raises ValueError for every
            # row with more than one tag.
            if hasattr(tags_value, 'dtype') and hasattr(tags_value, 'tolist'):
                tags_value = tags_value.tolist()  # numpy array -> list (or scalar for 0-d)
            if isinstance(tags_value, (list, tuple, set)):
                return [str(tag) for tag in tags_value]

            if isinstance(tags_value, str):
                try:
                    parsed = json.loads(tags_value)
                except ValueError:
                    # Not JSON: fall back to a comma-separated list
                    return [tag.strip() for tag in tags_value.split(',') if tag.strip()]
                if isinstance(parsed, list):
                    return [str(tag) for tag in parsed]
                # Valid JSON but not a list: keep the raw string as a single tag
                return [tags_value]

            # Only scalars reach this point, so pd.isna() yields a plain bool.
            if pd.isna(tags_value):
                return []

            # Last resort, convert to string and return as a single tag
            return [str(tags_value)]
        except Exception as e:
            print(f"Error processing tags: {e}")
            return []

    def any_tag_matches(lowered_tags, needle, exclude):
        """True if some tag contains *needle* and (when given) not *exclude*."""
        return any(
            needle in tag and (exclude is None or exclude not in tag)
            for tag in lowered_tags
        )

    try:
        print("Fetching data from Hugging Face models.parquet...")

        # Select only the columns we need; safetensors carries the size information.
        try:
            query = f"""
                SELECT
                    id,
                    downloads,
                    downloadsAllTime,
                    likes,
                    pipeline_tag,
                    tags,
                    safetensors
                FROM read_parquet('{parquet_url}')
            """
            df = duckdb.sql(query).df()
        except Exception as sql_error:
            print(f"Error with specific column selection: {sql_error}")
            # Fallback to just selecting everything and then filtering
            print("Falling back to select * query...")
            raw_df = duckdb.sql(f"SELECT * FROM read_parquet('{parquet_url}')").df()

            # Start from the raw index so placeholder columns get one value per row.
            df = pd.DataFrame(index=raw_df.index)
            for col in needed_columns:
                if col in raw_df.columns:
                    df[col] = raw_df[col]
                else:
                    df[col] = default_column(col, len(raw_df))

        print(f"Data fetched successfully. Shape: {df.shape}")

        if 'safetensors' in df.columns:
            # params: value derived from safetensors.total by extract_model_size
            df['params'] = df['safetensors'].apply(extract_model_size)

            # CACHE SIZE CATEGORY: label each model once for fast filtering
            df['size_category'] = df['params'].apply(get_size_category)

            # Debug output: bucket counts come from the cached column, no row loop needed.
            bucket_counts = df['size_category'].value_counts()
            print("Model size distribution:")
            for label, _low, _high in size_buckets:
                print(f" {label}: {int(bucket_counts.get(label, 0))} models")

            # Remove the safetensors column as we don't need it anymore
            df = df.drop(columns=['safetensors'])
        else:
            # If no safetensors column, add empty params column
            df['params'] = 0
            df['size_category'] = None

        if 'tags' in df.columns:
            df['tags'] = df['tags'].apply(process_tags)

            # CACHE TAG CATEGORIES: pre-calculate tag flags for faster filtering
            print("Pre-calculating cached tag categories...")
            # Lowercase every tag once instead of once per category.
            lowered = df['tags'].apply(lambda tags: [str(tag).lower() for tag in tags])

            print("Creating cached tag columns...")
            for col, needle, exclude in tag_columns:
                df[col] = lowered.apply(any_tag_matches, args=(needle, exclude)).astype(bool)

            # Combined category flags; pipeline_tag may still hold NaN here, hence na=False.
            df['is_audio_speech'] = (df['has_audio'] | df['has_speech'] |
                                     df['pipeline_tag'].str.contains('audio', case=False, na=False) |
                                     df['pipeline_tag'].str.contains('speech', case=False, na=False))
            df['is_biomed'] = df['has_bio'] | df['has_med']

            print("Cached tag columns created successfully!")
        else:
            # If no tags column, add empty tags and set all category flags to False
            df['tags'] = default_column('tags', len(df))
            for col, _needle, _exclude in tag_columns:
                df[col] = False
            df['is_audio_speech'] = False
            df['is_biomed'] = False

        # Fill NaN values
        df.fillna({'downloads': 0, 'downloadsAllTime': 0, 'likes': 0, 'params': 0}, inplace=True)

        # Ensure pipeline_tag is a string
        if 'pipeline_tag' in df.columns:
            df['pipeline_tag'] = df['pipeline_tag'].fillna('')
        else:
            df['pipeline_tag'] = ''

        # Make sure all required columns exist
        for col in ['id', 'downloads', 'downloadsAllTime', 'likes', 'pipeline_tag', 'tags', 'params']:
            if col not in df.columns:
                df[col] = default_column(col, len(df))

        print(f"Successfully processed {len(df)} models with cached tag and size information")
        return df, True

    except Exception as e:
        print(f"Error loading data: {e}")
        # Return an empty DataFrame and False to indicate loading failure
        return pd.DataFrame(), False
595
 
596
  # Create Gradio interface
597
  with gr.Blocks() as demo:
 
606
  Use the filters to explore models by different metrics, tags, pipelines, and model sizes.
607
 
608
  The treemap visualizes models grouped by organization, with the size of each box representing the selected metric.
609
+
610
  """)
611
 
612
  with gr.Row():
 
668
  )
669
 
670
  generate_plot_button = gr.Button("Generate Plot", variant="primary", interactive=False)
671
+ refresh_data_button = gr.Button("Refresh Data from Hugging Face", variant="secondary")
672
 
673
  with gr.Column(scale=3):
674
  plot_output = gr.Plot()
675
+ stats_output = gr.Markdown("*Loading data from Hugging Face...*")
676
+ data_info = gr.Markdown("")
 
 
 
 
 
 
 
 
 
 
 
 
 
677
 
678
  # Button enablement after data load
679
  def enable_plot_button(loaded):
 
700
  outputs=[tag_filter_dropdown, pipeline_filter_dropdown]
701
  )
702
 
703
+ # Function to handle data load and provide data info
704
+ def load_and_provide_info():
705
+ df, success = load_models_data()
706
+
707
+ if success:
708
+ # Generate information about the loaded data
709
+ info_text = f"""
710
+ ### Data Information
711
+ - **Total models loaded**: {len(df):,}
712
+ - **Last update**: {pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S')}
713
+ - **Data source**: [Hugging Face Hub Stats](https://huggingface.co/datasets/cfahlgren1/hub-stats) (models.parquet)
714
+ """
715
+
716
+ # Return the data, loading status, and info text
717
+ return df, True, info_text, "*Data loaded successfully. Use the controls to generate a plot.*"
718
+ else:
719
+ # Return empty data, failed loading status, and error message
720
+ return pd.DataFrame(), False, "*Error loading data from Hugging Face.*", "*Failed to load data. Please try again.*"
721
+
722
  # Main generate function
723
  def generate_plot_on_click(count_by, filter_choice, tag_filter, pipeline_filter, size_filter, top_k, skip_orgs_text, data_df):
724
  if data_df is None or not isinstance(data_df, pd.DataFrame) or data_df.empty:
 
815
 
816
  # Load data at startup
817
  demo.load(
818
+ fn=load_and_provide_info,
819
  inputs=[],
820
+ outputs=[models_data, loading_complete, data_info, stats_output]
821
+ )
822
+
823
+ # Refresh data when button is clicked
824
+ refresh_data_button.click(
825
+ fn=load_and_provide_info,
826
+ inputs=[],
827
+ outputs=[models_data, loading_complete, data_info, stats_output]
828
  )
829
 
830
  generate_plot_button.click(
 
842
  outputs=[plot_output, stats_output]
843
  )
844
 
 
 
845
  if __name__ == "__main__":
846
  demo.launch()