evijit HF Staff commited on
Commit
7262ace
·
verified ·
1 Parent(s): 4d0a8a3

Read live data from DuckDB

Browse files
Files changed (1) hide show
  1. app.py +710 -139
app.py CHANGED
@@ -2,7 +2,12 @@ import json
2
  import gradio as gr
3
  import pandas as pd
4
  import plotly.express as px
 
 
 
 
5
 
 
6
  PIPELINE_TAGS = [
7
  'text-generation',
8
  'text-to-image',
@@ -44,45 +49,129 @@ PIPELINE_TAGS = [
44
  'table-question-answering',
45
  ]
46
 
47
- def is_audio_speech(repo_dct):
48
- res = (repo_dct.get("pipeline_tag", None) and "audio" in repo_dct.get("pipeline_tag", "").lower()) or \
49
- (repo_dct.get("pipeline_tag", None) and "speech" in repo_dct.get("pipeline_tag", "").lower()) or \
50
- (repo_dct.get("tags", None) and any("audio" in tag.lower() for tag in repo_dct.get("tags", []))) or \
51
- (repo_dct.get("tags", None) and any("speech" in tag.lower() for tag in repo_dct.get("tags", [])))
52
- return res
 
 
 
 
 
 
 
 
 
 
 
53
 
54
- def is_music(repo_dct):
55
- res = (repo_dct.get("tags", None) and any("music" in tag.lower() for tag in repo_dct.get("tags", [])))
56
- return res
57
 
58
- def is_robotics(repo_dct):
59
- res = (repo_dct.get("tags", None) and any("robot" in tag.lower() for tag in repo_dct.get("tags", [])))
60
- return res
61
 
62
- def is_biomed(repo_dct):
63
- res = (repo_dct.get("tags", None) and any("bio" in tag.lower() for tag in repo_dct.get("tags", []))) or \
64
- (repo_dct.get("tags", None) and any("medic" in tag.lower() for tag in repo_dct.get("tags", [])))
65
- return res
66
 
67
- def is_timeseries(repo_dct):
68
- res = (repo_dct.get("tags", None) and any("series" in tag.lower() for tag in repo_dct.get("tags", [])))
69
- return res
70
 
71
- def is_science(repo_dct):
72
- res = (repo_dct.get("tags", None) and any("science" in tag.lower() and not "bigscience" in tag for tag in repo_dct.get("tags", [])))
73
- return res
74
 
75
- def is_video(repo_dct):
76
- res = (repo_dct.get("tags", None) and any("video" in tag.lower() for tag in repo_dct.get("tags", [])))
77
- return res
78
 
79
- def is_image(repo_dct):
80
- res = (repo_dct.get("tags", None) and any("image" in tag.lower() for tag in repo_dct.get("tags", [])))
81
- return res
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
- def is_text(repo_dct):
84
- res = (repo_dct.get("tags", None) and any("text" in tag.lower() for tag in repo_dct.get("tags", [])))
85
- return res
 
 
 
 
 
86
 
87
  TAG_FILTER_FUNCS = {
88
  "Audio & Speech": is_audio_speech,
@@ -96,180 +185,662 @@ TAG_FILTER_FUNCS = {
96
  "Sciences": is_science,
97
  }
98
 
99
- def make_org_stats(repo_type, count_by, org_stats, top_k=20, filter_func=None):
100
- assert count_by in ["likes", "downloads", "downloads_all"]
101
- assert repo_type in ["all", "datasets", "models"]
102
- repos = ["datasets", "models"] if repo_type == "all" else [repo_type]
103
- if filter_func is None:
104
- filter_func = lambda x: True
105
- sorted_stats = sorted(
106
- [(
107
- author,
108
- sum(dct[count_by] for dct in author_dct[repo] if filter_func(dct))
109
- ) for repo in repos for author, author_dct in org_stats.items()],
110
- key=lambda x:x[1],
111
- reverse=True,
112
- )
113
- res = sorted_stats[:top_k] + [("Others...", sum(st for auth, st in sorted_stats[top_k:]))]
114
- total_st = sum(st for o, st in res)
115
- res_plot_df = []
116
- for org, st in res:
117
- if org == "Others...":
118
- res_plot_df += [("Others...", "other", st * 100 / total_st)]
119
- else:
120
- for repo in repos:
121
- for dct in org_stats[org][repo]:
122
- if filter_func(dct):
123
- res_plot_df += [(org, dct["id"], dct[count_by] * 100 / total_st)]
124
- return ([(o, 100 * st / total_st) for o, st in res if st > 0], res_plot_df)
125
-
126
- def make_figure(count_by, repo_type, org_stats, tag_filter=None, pipeline_filter=None):
127
- assert count_by in ["downloads", "likes", "downloads_all"]
128
- assert repo_type in ["all", "models", "datasets"]
129
- assert tag_filter is None or pipeline_filter is None
130
- filter_func = None
131
- if tag_filter:
132
- filter_func = TAG_FILTER_FUNCS[tag_filter]
 
 
 
 
 
 
 
 
 
 
133
  if pipeline_filter:
134
- filter_func = lambda dct: dct.get("pipeline_tag", None) and dct.get("pipeline_tag", "") == pipeline_filter
135
- _, res_plot_df = make_org_stats(repo_type, count_by, org_stats, top_k=25, filter_func=filter_func)
136
- df = pd.DataFrame(
137
- dict(
138
- organizations=[o for o, _, _ in res_plot_df],
139
- repo=[r for _, r, _ in res_plot_df],
140
- stats=[s for _, _, s in res_plot_df],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
  )
 
 
 
 
 
 
 
 
 
142
  )
143
- df[repo_type] = repo_type # in order to have a single root node
144
- fig = px.treemap(df, path=[repo_type, 'organizations', 'repo'], values='stats')
145
  fig.update_layout(
146
- treemapcolorway = ["pink" for _ in range(len(res_plot_df))],
147
- margin = dict(t=50, l=25, r=25, b=25)
148
  )
 
 
 
 
 
 
 
149
  return fig
150
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
 
 
152
  with gr.Blocks() as demo:
153
- org_stats_data = gr.State(value=None) # To store loaded data
 
154
 
155
  with gr.Row():
156
  gr.Markdown("""
157
- ## Hugging Face Organization Stats
 
 
 
 
 
158
 
159
- This app shows how different organizations are contributing to different aspects of the open AI ecosystem.
160
- Use the dropdowns on the left to select repository types, metrics, and optionally tags representing topics or modalities of interest.
161
  """)
 
162
  with gr.Row():
163
  with gr.Column(scale=1):
164
- repo_type_dropdown = gr.Dropdown(
165
- label="Repository Type",
166
- choices=["all", "models", "datasets"],
167
- value="all"
168
- )
169
  count_by_dropdown = gr.Dropdown(
170
  label="Metric",
171
- choices=["downloads", "likes", "downloads_all"],
172
- value="downloads"
 
 
 
 
 
173
  )
174
-
175
  filter_choice_radio = gr.Radio(
176
- label="Filter by",
177
  choices=["None", "Tag Filter", "Pipeline Filter"],
178
- value="None"
 
179
  )
180
-
181
  tag_filter_dropdown = gr.Dropdown(
182
  label="Select Tag",
183
  choices=list(TAG_FILTER_FUNCS.keys()),
184
  value=None,
185
- visible=False
 
186
  )
 
187
  pipeline_filter_dropdown = gr.Dropdown(
188
  label="Select Pipeline Tag",
189
  choices=PIPELINE_TAGS,
190
  value=None,
191
- visible=False
 
192
  )
193
 
194
- generate_plot_button = gr.Button("Generate Plot")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
195
 
196
  with gr.Column(scale=3):
197
  plot_output = gr.Plot()
 
 
 
 
 
 
 
 
 
 
 
 
198
 
199
- def generate_plot_on_click(repo_type, count_by, filter_choice, tag_filter, pipeline_filter, data):
200
- # Print the current state of the input variables
201
- print(f"Generating plot with the following inputs:")
202
- print(f" Repository Type: {repo_type}")
203
- print(f" Metric (Count By): {count_by}")
204
- print(f" Filter Choice: {filter_choice}")
205
  if filter_choice == "Tag Filter":
206
- print(f" Tag Filter: {tag_filter}")
207
  elif filter_choice == "Pipeline Filter":
208
- print(f" Pipeline Filter: {pipeline_filter}")
 
 
209
 
210
- if data is None:
211
- print("Error: Data not loaded yet.")
212
- return None
 
 
 
 
 
 
213
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
214
  selected_tag_filter = None
215
  selected_pipeline_filter = None
 
216
 
217
  if filter_choice == "Tag Filter":
218
  selected_tag_filter = tag_filter
219
  elif filter_choice == "Pipeline Filter":
220
  selected_pipeline_filter = pipeline_filter
221
-
222
- fig = make_figure(
 
 
 
 
 
 
 
 
223
  count_by=count_by,
224
- repo_type=repo_type,
225
- org_stats=data,
226
  tag_filter=selected_tag_filter,
227
- pipeline_filter=selected_pipeline_filter
 
 
228
  )
229
- return fig
230
 
231
- def update_filter_visibility(filter_choice):
232
- if filter_choice == "Tag Filter":
233
- return gr.update(visible=True), gr.update(visible=False)
234
- elif filter_choice == "Pipeline Filter":
235
- return gr.update(visible=False), gr.update(visible=True)
236
- else: # "None"
237
- return gr.update(visible=False), gr.update(visible=False)
238
 
239
- filter_choice_radio.change(
240
- fn=update_filter_visibility,
241
- inputs=[filter_choice_radio],
242
- outputs=[tag_filter_dropdown, pipeline_filter_dropdown]
243
- )
244
-
245
- # Load data once at startup
246
- def load_org_data():
247
- print("Loading organization statistics data...")
248
- loaded_org_stats = json.load(open("org_to_artifacts_2l_stats.json"))
249
- print("Data loaded successfully.")
250
- return loaded_org_stats
 
 
 
 
 
251
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
252
  demo.load(
253
- fn=load_org_data,
254
- inputs=[], # No inputs needed to just load data
255
- outputs=[org_stats_data] # Only output to the state
 
 
 
 
 
 
 
256
  )
257
 
258
- # Button click event to generate plot
259
  generate_plot_button.click(
260
  fn=generate_plot_on_click,
261
  inputs=[
262
- repo_type_dropdown,
263
  count_by_dropdown,
264
  filter_choice_radio,
265
  tag_filter_dropdown,
266
  pipeline_filter_dropdown,
267
- org_stats_data
 
 
 
268
  ],
269
- outputs=[plot_output]
270
  )
271
 
272
-
273
  if __name__ == "__main__":
274
- # org_stats = json.load(open("org_to_artifacts_2l_stats.json")) # Data loading handled by demo.load
275
  demo.launch()
 
2
  import gradio as gr
3
  import pandas as pd
4
  import plotly.express as px
5
+ import os
6
+ import numpy as np
7
+ import io
8
+ import duckdb
9
 
10
+ # Define pipeline tags
11
  PIPELINE_TAGS = [
12
  'text-generation',
13
  'text-to-image',
 
49
  'table-question-answering',
50
  ]
51
 
52
+ # Model size categories in GB
53
+ MODEL_SIZE_RANGES = {
54
+ "Small (<1GB)": (0, 1),
55
+ "Medium (1-5GB)": (1, 5),
56
+ "Large (5-20GB)": (5, 20),
57
+ "X-Large (20-50GB)": (20, 50),
58
+ "XX-Large (>50GB)": (50, float('inf'))
59
+ }
60
+
61
+ # Filter functions for tags - UPDATED to use cached columns
62
+ def is_audio_speech(row):
63
+ # Use cached column instead of recalculating
64
+ return row['is_audio_speech']
65
+
66
+ def is_music(row):
67
+ # Use cached column instead of recalculating
68
+ return row['has_music']
69
 
70
+ def is_robotics(row):
71
+ # Use cached column instead of recalculating
72
+ return row['has_robot']
73
 
74
+ def is_biomed(row):
75
+ # Use cached column instead of recalculating
76
+ return row['is_biomed']
77
 
78
+ def is_timeseries(row):
79
+ # Use cached column instead of recalculating
80
+ return row['has_series']
 
81
 
82
+ def is_science(row):
83
+ # Use cached column instead of recalculating
84
+ return row['has_science']
85
 
86
+ def is_video(row):
87
+ # Use cached column instead of recalculating
88
+ return row['has_video']
89
 
90
+ def is_image(row):
91
+ # Use cached column instead of recalculating
92
+ return row['has_image']
93
 
94
+ def is_text(row):
95
+ # Use cached column instead of recalculating
96
+ return row['has_text']
97
+
98
+ def is_image(row):
99
+ tags = row.get("tags", [])
100
+
101
+ # Check if tags exists and is not empty
102
+ if tags is not None:
103
+ # For numpy arrays
104
+ if hasattr(tags, 'dtype') and hasattr(tags, 'tolist'):
105
+ # Convert numpy array to list
106
+ tags_list = tags.tolist()
107
+ return any("image" in str(tag).lower() for tag in tags_list)
108
+ # For regular lists
109
+ elif isinstance(tags, list):
110
+ return any("image" in str(tag).lower() for tag in tags)
111
+ # For string tags
112
+ elif isinstance(tags, str):
113
+ return "image" in tags.lower()
114
+ return False
115
+
116
+ def is_text(row):
117
+ tags = row.get("tags", [])
118
+
119
+ # Check if tags exists and is not empty
120
+ if tags is not None:
121
+ # For numpy arrays
122
+ if hasattr(tags, 'dtype') and hasattr(tags, 'tolist'):
123
+ # Convert numpy array to list
124
+ tags_list = tags.tolist()
125
+ return any("text" in str(tag).lower() for tag in tags_list)
126
+ # For regular lists
127
+ elif isinstance(tags, list):
128
+ return any("text" in str(tag).lower() for tag in tags)
129
+ # For string tags
130
+ elif isinstance(tags, str):
131
+ return "text" in tags.lower()
132
+ return False
133
+
134
+ def extract_model_size(safetensors_data):
135
+ """Extract model size in GB from safetensors data"""
136
+ try:
137
+ if pd.isna(safetensors_data):
138
+ return 0
139
+
140
+ # If it's already a dictionary, use it directly
141
+ if isinstance(safetensors_data, dict):
142
+ if 'total' in safetensors_data:
143
+ try:
144
+ size_bytes = float(safetensors_data['total'])
145
+ return size_bytes / (1024 * 1024 * 1024) # Convert to GB
146
+ except (ValueError, TypeError):
147
+ pass
148
+
149
+ # If it's a string, try to parse it as JSON
150
+ elif isinstance(safetensors_data, str):
151
+ try:
152
+ data_dict = json.loads(safetensors_data)
153
+ if 'total' in data_dict:
154
+ try:
155
+ size_bytes = float(data_dict['total'])
156
+ return size_bytes / (1024 * 1024 * 1024) # Convert to GB
157
+ except (ValueError, TypeError):
158
+ pass
159
+ except:
160
+ pass
161
+
162
+ return 0
163
+ except Exception as e:
164
+ print(f"Error extracting model size: {e}")
165
+ return 0
166
 
167
+ # Add model size filter function - UPDATED to use cached size_category column
168
+ def is_in_size_range(row, size_range):
169
+ """Check if a model is in the specified size range using pre-calculated size category"""
170
+ if size_range is None or size_range == "None":
171
+ return True
172
+
173
+ # Simply compare with cached size_category
174
+ return row['size_category'] == size_range
175
 
176
  TAG_FILTER_FUNCS = {
177
  "Audio & Speech": is_audio_speech,
 
185
  "Sciences": is_science,
186
  }
187
 
188
+ def extract_org_from_id(model_id):
189
+ """Extract organization name from model ID"""
190
+ if "/" in model_id:
191
+ return model_id.split("/")[0]
192
+ return "unaffiliated"
193
+
194
+ def make_treemap_data(df, count_by, top_k=25, tag_filter=None, pipeline_filter=None, size_filter=None, skip_orgs=None):
195
+ """Process DataFrame into treemap format with filters applied - OPTIMIZED with cached columns"""
196
+ # Create a copy to avoid modifying the original
197
+ filtered_df = df.copy()
198
+
199
+ # Apply filters
200
+ filter_stats = {"initial": len(filtered_df)}
201
+ start_time = pd.Timestamp.now()
202
+
203
+ # Apply tag filter - OPTIMIZED to use cached columns
204
+ if tag_filter and tag_filter in TAG_FILTER_FUNCS:
205
+ print(f"Applying tag filter: {tag_filter}")
206
+
207
+ # Use direct column filtering instead of applying a function to each row
208
+ if tag_filter == "Audio & Speech":
209
+ filtered_df = filtered_df[filtered_df['is_audio_speech']]
210
+ elif tag_filter == "Music":
211
+ filtered_df = filtered_df[filtered_df['has_music']]
212
+ elif tag_filter == "Robotics":
213
+ filtered_df = filtered_df[filtered_df['has_robot']]
214
+ elif tag_filter == "Biomedical":
215
+ filtered_df = filtered_df[filtered_df['is_biomed']]
216
+ elif tag_filter == "Time series":
217
+ filtered_df = filtered_df[filtered_df['has_series']]
218
+ elif tag_filter == "Sciences":
219
+ filtered_df = filtered_df[filtered_df['has_science']]
220
+ elif tag_filter == "Video":
221
+ filtered_df = filtered_df[filtered_df['has_video']]
222
+ elif tag_filter == "Images":
223
+ filtered_df = filtered_df[filtered_df['has_image']]
224
+ elif tag_filter == "Text":
225
+ filtered_df = filtered_df[filtered_df['has_text']]
226
+
227
+ filter_stats["after_tag_filter"] = len(filtered_df)
228
+ print(f"Tag filter applied in {(pd.Timestamp.now() - start_time).total_seconds():.3f} seconds")
229
+ start_time = pd.Timestamp.now()
230
+
231
+ # Apply pipeline filter
232
  if pipeline_filter:
233
+ print(f"Applying pipeline filter: {pipeline_filter}")
234
+ filtered_df = filtered_df[filtered_df["pipeline_tag"] == pipeline_filter]
235
+ filter_stats["after_pipeline_filter"] = len(filtered_df)
236
+ print(f"Pipeline filter applied in {(pd.Timestamp.now() - start_time).total_seconds():.3f} seconds")
237
+ start_time = pd.Timestamp.now()
238
+
239
+ # Apply size filter - OPTIMIZED to use cached size_category column
240
+ if size_filter and size_filter in MODEL_SIZE_RANGES:
241
+ print(f"Applying size filter: {size_filter}")
242
+
243
+ # Use the cached size_category column directly
244
+ filtered_df = filtered_df[filtered_df['size_category'] == size_filter]
245
+
246
+ # Debug info
247
+ print(f"Size filter '{size_filter}' applied.")
248
+ print(f"Models after size filter: {len(filtered_df)}")
249
+
250
+ filter_stats["after_size_filter"] = len(filtered_df)
251
+ print(f"Size filter applied in {(pd.Timestamp.now() - start_time).total_seconds():.3f} seconds")
252
+ start_time = pd.Timestamp.now()
253
+
254
+ # Add organization column
255
+ filtered_df["organization"] = filtered_df["id"].apply(extract_org_from_id)
256
+
257
+ # Skip organizations if specified
258
+ if skip_orgs and len(skip_orgs) > 0:
259
+ filtered_df = filtered_df[~filtered_df["organization"].isin(skip_orgs)]
260
+ filter_stats["after_skip_orgs"] = len(filtered_df)
261
+
262
+ # Print filter stats
263
+ print("Filter statistics:")
264
+ for stage, count in filter_stats.items():
265
+ print(f" {stage}: {count} models")
266
+
267
+ # Check if we have any data left
268
+ if filtered_df.empty:
269
+ print("Warning: No data left after applying filters!")
270
+ return pd.DataFrame() # Return empty DataFrame
271
+
272
+ # Aggregate by organization
273
+ org_totals = filtered_df.groupby("organization")[count_by].sum().reset_index()
274
+ org_totals = org_totals.sort_values(by=count_by, ascending=False)
275
+
276
+ # Get top organizations
277
+ top_orgs = org_totals.head(top_k)["organization"].tolist()
278
+
279
+ # Filter to only include models from top organizations
280
+ filtered_df = filtered_df[filtered_df["organization"].isin(top_orgs)]
281
+
282
+ # Prepare data for treemap
283
+ treemap_data = filtered_df[["id", "organization", count_by]].copy()
284
+
285
+ # Add a root node
286
+ treemap_data["root"] = "models"
287
+
288
+ # Ensure numeric values
289
+ treemap_data[count_by] = pd.to_numeric(treemap_data[count_by], errors="coerce").fillna(0)
290
+
291
+ print(f"Treemap data prepared in {(pd.Timestamp.now() - start_time).total_seconds():.3f} seconds")
292
+ return treemap_data
293
+
294
+ def create_treemap(treemap_data, count_by, title=None):
295
+ """Create a Plotly treemap from the prepared data"""
296
+ if treemap_data.empty:
297
+ # Create an empty figure with a message
298
+ fig = px.treemap(
299
+ names=["No data matches the selected filters"],
300
+ values=[1]
301
+ )
302
+ fig.update_layout(
303
+ title="No data matches the selected filters",
304
+ margin=dict(t=50, l=25, r=25, b=25)
305
  )
306
+ return fig
307
+
308
+ # Create the treemap
309
+ fig = px.treemap(
310
+ treemap_data,
311
+ path=["root", "organization", "id"],
312
+ values=count_by,
313
+ title=title or f"HuggingFace Models - {count_by.capitalize()} by Organization",
314
+ color_discrete_sequence=px.colors.qualitative.Plotly
315
  )
316
+
317
+ # Update layout
318
  fig.update_layout(
319
+ margin=dict(t=50, l=25, r=25, b=25)
 
320
  )
321
+
322
+ # Update traces for better readability
323
+ fig.update_traces(
324
+ textinfo="label+value+percent root",
325
+ hovertemplate="<b>%{label}</b><br>%{value:,} " + count_by + "<br>%{percentRoot:.2%} of total<extra></extra>"
326
+ )
327
+
328
  return fig
329
 
330
+ def load_models_data():
331
+ """Load models data from Hugging Face using DuckDB with caching for improved performance"""
332
+ try:
333
+ # The URL to the parquet file
334
+ parquet_url = "https://huggingface.co/datasets/cfahlgren1/hub-stats/resolve/main/models.parquet"
335
+
336
+ print("Fetching data from Hugging Face models.parquet...")
337
+
338
+ # Based on the column names provided, we can directly select the columns we need
339
+ # Note: We need to select safetensors to get the model size information
340
+ try:
341
+ query = """
342
+ SELECT
343
+ id,
344
+ downloads,
345
+ downloadsAllTime,
346
+ likes,
347
+ pipeline_tag,
348
+ tags,
349
+ safetensors
350
+ FROM read_parquet('https://huggingface.co/datasets/cfahlgren1/hub-stats/resolve/main/models.parquet')
351
+ """
352
+ df = duckdb.sql(query).df()
353
+ except Exception as sql_error:
354
+ print(f"Error with specific column selection: {sql_error}")
355
+ # Fallback to just selecting everything and then filtering
356
+ print("Falling back to select * query...")
357
+ query = "SELECT * FROM read_parquet('https://huggingface.co/datasets/cfahlgren1/hub-stats/resolve/main/models.parquet')"
358
+ raw_df = duckdb.sql(query).df()
359
+
360
+ # Now extract only the columns we need
361
+ needed_columns = ['id', 'downloads', 'downloadsAllTime', 'likes', 'pipeline_tag', 'tags', 'safetensors']
362
+ available_columns = set(raw_df.columns)
363
+ df = pd.DataFrame()
364
+
365
+ # Copy over columns that exist
366
+ for col in needed_columns:
367
+ if col in available_columns:
368
+ df[col] = raw_df[col]
369
+ else:
370
+ # Create empty columns for missing data
371
+ if col in ['downloads', 'downloadsAllTime', 'likes']:
372
+ df[col] = 0
373
+ elif col == 'pipeline_tag':
374
+ df[col] = ''
375
+ elif col == 'tags':
376
+ df[col] = [[] for _ in range(len(raw_df))]
377
+ elif col == 'safetensors':
378
+ df[col] = None
379
+ elif col == 'id':
380
+ # Create IDs based on index if missing
381
+ df[col] = [f"model_{i}" for i in range(len(raw_df))]
382
+
383
+ print(f"Data fetched successfully. Shape: {df.shape}")
384
+
385
+ # Check if safetensors column exists before trying to process it
386
+ if 'safetensors' in df.columns:
387
+ # Add params column derived from safetensors.total (model size in GB)
388
+ df['params'] = df['safetensors'].apply(extract_model_size)
389
+
390
+ # Debug model sizes
391
+ size_ranges = {
392
+ "Small (<1GB)": 0,
393
+ "Medium (1-5GB)": 0,
394
+ "Large (5-20GB)": 0,
395
+ "X-Large (20-50GB)": 0,
396
+ "XX-Large (>50GB)": 0
397
+ }
398
+
399
+ # Count models in each size range
400
+ for idx, row in df.iterrows():
401
+ size_gb = row['params']
402
+ if 0 <= size_gb < 1:
403
+ size_ranges["Small (<1GB)"] += 1
404
+ elif 1 <= size_gb < 5:
405
+ size_ranges["Medium (1-5GB)"] += 1
406
+ elif 5 <= size_gb < 20:
407
+ size_ranges["Large (5-20GB)"] += 1
408
+ elif 20 <= size_gb < 50:
409
+ size_ranges["X-Large (20-50GB)"] += 1
410
+ elif size_gb >= 50:
411
+ size_ranges["XX-Large (>50GB)"] += 1
412
+
413
+ print("Model size distribution:")
414
+ for size_range, count in size_ranges.items():
415
+ print(f" {size_range}: {count} models")
416
+
417
+ # CACHE SIZE CATEGORY: Add a size_category column for faster filtering
418
+ def get_size_category(size_gb):
419
+ if 0 <= size_gb < 1:
420
+ return "Small (<1GB)"
421
+ elif 1 <= size_gb < 5:
422
+ return "Medium (1-5GB)"
423
+ elif 5 <= size_gb < 20:
424
+ return "Large (5-20GB)"
425
+ elif 20 <= size_gb < 50:
426
+ return "X-Large (20-50GB)"
427
+ elif size_gb >= 50:
428
+ return "XX-Large (>50GB)"
429
+ return None
430
+
431
+ # Add cached size category column
432
+ df['size_category'] = df['params'].apply(get_size_category)
433
+
434
+ # Remove the safetensors column as we don't need it anymore
435
+ df = df.drop(columns=['safetensors'])
436
+ else:
437
+ # If no safetensors column, add empty params column
438
+ df['params'] = 0
439
+ df['size_category'] = None
440
+
441
+ # Process tags to ensure it's in the right format - FIXED
442
+ def process_tags(tags_value):
443
+ try:
444
+ if pd.isna(tags_value) or tags_value is None:
445
+ return []
446
+
447
+ # If it's a numpy array, convert to a list of strings
448
+ if hasattr(tags_value, 'dtype') and hasattr(tags_value, 'tolist'):
449
+ # Note: This is the fix for the error
450
+ return [str(tag) for tag in tags_value.tolist()]
451
+
452
+ # If already a list, ensure all elements are strings
453
+ if isinstance(tags_value, list):
454
+ return [str(tag) for tag in tags_value]
455
+
456
+ # If string, try to parse as JSON or split by comma
457
+ if isinstance(tags_value, str):
458
+ try:
459
+ tags_list = json.loads(tags_value)
460
+ if isinstance(tags_list, list):
461
+ return [str(tag) for tag in tags_list]
462
+ except:
463
+ # Split by comma if JSON parsing fails
464
+ return [tag.strip() for tag in tags_value.split(',') if tag.strip()]
465
+
466
+ # Last resort, convert to string and return as a single tag
467
+ return [str(tags_value)]
468
+
469
+ except Exception as e:
470
+ print(f"Error processing tags: {e}")
471
+ return []
472
+
473
+ # Check if tags column exists before trying to process it
474
+ if 'tags' in df.columns:
475
+ # Process tags column
476
+ df['tags'] = df['tags'].apply(process_tags)
477
+
478
+ # CACHE TAG CATEGORIES: Pre-calculate tag categories for faster filtering
479
+ print("Pre-calculating cached tag categories...")
480
+
481
+ # Helper functions to check for specific tags (simplified for caching)
482
+ def has_audio_tag(tags):
483
+ if tags and isinstance(tags, list):
484
+ return any("audio" in str(tag).lower() for tag in tags)
485
+ return False
486
+
487
+ def has_speech_tag(tags):
488
+ if tags and isinstance(tags, list):
489
+ return any("speech" in str(tag).lower() for tag in tags)
490
+ return False
491
+
492
+ def has_music_tag(tags):
493
+ if tags and isinstance(tags, list):
494
+ return any("music" in str(tag).lower() for tag in tags)
495
+ return False
496
+
497
+ def has_robot_tag(tags):
498
+ if tags and isinstance(tags, list):
499
+ return any("robot" in str(tag).lower() for tag in tags)
500
+ return False
501
+
502
+ def has_bio_tag(tags):
503
+ if tags and isinstance(tags, list):
504
+ return any("bio" in str(tag).lower() for tag in tags)
505
+ return False
506
+
507
+ def has_med_tag(tags):
508
+ if tags and isinstance(tags, list):
509
+ return any("medic" in str(tag).lower() for tag in tags)
510
+ return False
511
+
512
+ def has_series_tag(tags):
513
+ if tags and isinstance(tags, list):
514
+ return any("series" in str(tag).lower() for tag in tags)
515
+ return False
516
+
517
+ def has_science_tag(tags):
518
+ if tags and isinstance(tags, list):
519
+ return any("science" in str(tag).lower() and "bigscience" not in str(tag).lower() for tag in tags)
520
+ return False
521
+
522
+ def has_video_tag(tags):
523
+ if tags and isinstance(tags, list):
524
+ return any("video" in str(tag).lower() for tag in tags)
525
+ return False
526
+
527
+ def has_image_tag(tags):
528
+ if tags and isinstance(tags, list):
529
+ return any("image" in str(tag).lower() for tag in tags)
530
+ return False
531
+
532
+ def has_text_tag(tags):
533
+ if tags and isinstance(tags, list):
534
+ return any("text" in str(tag).lower() for tag in tags)
535
+ return False
536
+
537
+ # Add cached columns for tag categories
538
+ print("Creating cached tag columns...")
539
+ df['has_audio'] = df['tags'].apply(has_audio_tag)
540
+ df['has_speech'] = df['tags'].apply(has_speech_tag)
541
+ df['has_music'] = df['tags'].apply(has_music_tag)
542
+ df['has_robot'] = df['tags'].apply(has_robot_tag)
543
+ df['has_bio'] = df['tags'].apply(has_bio_tag)
544
+ df['has_med'] = df['tags'].apply(has_med_tag)
545
+ df['has_series'] = df['tags'].apply(has_series_tag)
546
+ df['has_science'] = df['tags'].apply(has_science_tag)
547
+ df['has_video'] = df['tags'].apply(has_video_tag)
548
+ df['has_image'] = df['tags'].apply(has_image_tag)
549
+ df['has_text'] = df['tags'].apply(has_text_tag)
550
+
551
+ # Create combined category flags for faster filtering
552
+ df['is_audio_speech'] = (df['has_audio'] | df['has_speech'] |
553
+ df['pipeline_tag'].str.contains('audio', case=False, na=False) |
554
+ df['pipeline_tag'].str.contains('speech', case=False, na=False))
555
+ df['is_biomed'] = df['has_bio'] | df['has_med']
556
+
557
+ print("Cached tag columns created successfully!")
558
+ else:
559
+ # If no tags column, add empty tags and set all category flags to False
560
+ df['tags'] = [[] for _ in range(len(df))]
561
+ for col in ['has_audio', 'has_speech', 'has_music', 'has_robot',
562
+ 'has_bio', 'has_med', 'has_series', 'has_science',
563
+ 'has_video', 'has_image', 'has_text',
564
+ 'is_audio_speech', 'is_biomed']:
565
+ df[col] = False
566
+
567
+ # Fill NaN values
568
+ df.fillna({'downloads': 0, 'downloadsAllTime': 0, 'likes': 0, 'params': 0}, inplace=True)
569
+
570
+ # Ensure pipeline_tag is a string
571
+ if 'pipeline_tag' in df.columns:
572
+ df['pipeline_tag'] = df['pipeline_tag'].fillna('')
573
+ else:
574
+ df['pipeline_tag'] = ''
575
+
576
+ # Make sure all required columns exist
577
+ for col in ['id', 'downloads', 'downloadsAllTime', 'likes', 'pipeline_tag', 'tags', 'params']:
578
+ if col not in df.columns:
579
+ if col in ['downloads', 'downloadsAllTime', 'likes', 'params']:
580
+ df[col] = 0
581
+ elif col == 'pipeline_tag':
582
+ df[col] = ''
583
+ elif col == 'tags':
584
+ df[col] = [[] for _ in range(len(df))]
585
+ elif col == 'id':
586
+ df[col] = [f"model_{i}" for i in range(len(df))]
587
+
588
+ print(f"Successfully processed {len(df)} models with cached tag and size information")
589
+ return df, True
590
+
591
+ except Exception as e:
592
+ print(f"Error loading data: {e}")
593
+ # Return an empty DataFrame and False to indicate loading failure
594
+ return pd.DataFrame(), False
595
 
596
+ # Create Gradio interface
597
  with gr.Blocks() as demo:
598
+ models_data = gr.State()
599
+ loading_complete = gr.State(False) # Flag to indicate data load completion
600
 
601
  with gr.Row():
602
  gr.Markdown("""
603
+ # HuggingFace Models TreeMap Visualization
604
+
605
+ This app shows how different organizations contribute to the HuggingFace ecosystem with their models.
606
+ Use the filters to explore models by different metrics, tags, pipelines, and model sizes.
607
+
608
+ The treemap visualizes models grouped by organization, with the size of each box representing the selected metric.
609
 
 
 
610
  """)
611
+
612
    # Main layout row: narrow controls column (scale=1); the plot column
    # (scale=3) follows in the same row.
    with gr.Row():
        with gr.Column(scale=1):
            # Metric that determines treemap box sizes; the value is the
            # DataFrame column name, the label is what the user sees.
            count_by_dropdown = gr.Dropdown(
                label="Metric",
                choices=[
                    ("Downloads (last 30 days)", "downloads"),
                    ("Downloads (All Time)", "downloadsAllTime"),
                    ("Likes", "likes")
                ],
                value="downloads",
                info="Select the metric to determine box sizes"
            )

            # Chooses which of the two filter dropdowns below is active.
            filter_choice_radio = gr.Radio(
                label="Filter Type",
                choices=["None", "Tag Filter", "Pipeline Filter"],
                value="None",
                info="Choose how to filter the models"
            )

            # Hidden until "Tag Filter" is selected in the radio above.
            tag_filter_dropdown = gr.Dropdown(
                label="Select Tag",
                choices=list(TAG_FILTER_FUNCS.keys()),
                value=None,
                visible=False,
                info="Filter models by domain/category"
            )

            # Hidden until "Pipeline Filter" is selected in the radio above.
            pipeline_filter_dropdown = gr.Dropdown(
                label="Select Pipeline Tag",
                choices=PIPELINE_TAGS,
                value=None,
                visible=False,
                info="Filter models by specific pipeline"
            )

            # Optional size bucket filter; "None" disables size filtering.
            size_filter_dropdown = gr.Dropdown(
                label="Model Size Filter",
                choices=["None"] + list(MODEL_SIZE_RANGES.keys()),
                value="None",
                info="Filter models by their size (using params column)"
            )

            # How many organizations to show in the treemap.
            top_k_slider = gr.Slider(
                label="Number of Top Organizations",
                minimum=5,
                maximum=50,
                value=25,
                step=5,
                info="Number of top organizations to include"
            )

            # Defaults exclude well-known re-upload/quantization accounts
            # that would otherwise dominate the treemap.
            skip_orgs_textbox = gr.Textbox(
                label="Organizations to Skip (comma-separated)",
                placeholder="e.g., OpenAI, Google",
                value="TheBloke, MaziyarPanahi, unsloth, modularai, Gensyn, bartowski"
            )

            # Disabled until data finishes loading (see loading_complete wiring).
            generate_plot_button = gr.Button("Generate Plot", variant="primary", interactive=False)
            refresh_data_button = gr.Button("Refresh Data from Hugging Face", variant="secondary")
673
        # Wide output column: the treemap itself plus two markdown panes
        # (statistics below the plot, dataset provenance info beneath that).
        with gr.Column(scale=3):
            plot_output = gr.Plot()
            stats_output = gr.Markdown("*Loading data from Hugging Face...*")
            data_info = gr.Markdown("")
678
+ # Button enablement after data load
679
+ def enable_plot_button(loaded):
680
+ return gr.update(interactive=loaded)
681
+
682
+ loading_complete.change(
683
+ fn=enable_plot_button,
684
+ inputs=[loading_complete],
685
+ outputs=[generate_plot_button]
686
+ )
687
 
688
+ # Show/hide tag/pipeline dropdown
689
+ def update_filter_visibility(filter_choice):
 
 
 
 
690
  if filter_choice == "Tag Filter":
691
+ return gr.update(visible=True), gr.update(visible=False)
692
  elif filter_choice == "Pipeline Filter":
693
+ return gr.update(visible=False), gr.update(visible=True)
694
+ else:
695
+ return gr.update(visible=False), gr.update(visible=False)
696
 
697
+ filter_choice_radio.change(
698
+ fn=update_filter_visibility,
699
+ inputs=[filter_choice_radio],
700
+ outputs=[tag_filter_dropdown, pipeline_filter_dropdown]
701
+ )
702
+
703
+ # Function to handle data load and provide data info
704
+ def load_and_provide_info():
705
+ df, success = load_models_data()
706
 
707
+ if success:
708
+ # Generate information about the loaded data
709
+ info_text = f"""
710
+ ### Data Information
711
+ - **Total models loaded**: {len(df):,}
712
+ - **Last update**: {pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S')}
713
+ - **Data source**: [Hugging Face Hub Stats](https://huggingface.co/datasets/cfahlgren1/hub-stats) (models.parquet)
714
+ """
715
+
716
+ # Return the data, loading status, and info text
717
+ return df, True, info_text, "*Data loaded successfully. Use the controls to generate a plot.*"
718
+ else:
719
+ # Return empty data, failed loading status, and error message
720
+ return pd.DataFrame(), False, "*Error loading data from Hugging Face.*", "*Failed to load data. Please try again.*"
721
+
722
+ # Main generate function
723
+ def generate_plot_on_click(count_by, filter_choice, tag_filter, pipeline_filter, size_filter, top_k, skip_orgs_text, data_df):
724
+ if data_df is None or not isinstance(data_df, pd.DataFrame) or data_df.empty:
725
+ return None, "Error: Data is still loading. Please wait a moment and try again."
726
+
727
  selected_tag_filter = None
728
  selected_pipeline_filter = None
729
+ selected_size_filter = None
730
 
731
  if filter_choice == "Tag Filter":
732
  selected_tag_filter = tag_filter
733
  elif filter_choice == "Pipeline Filter":
734
  selected_pipeline_filter = pipeline_filter
735
+
736
+ if size_filter != "None":
737
+ selected_size_filter = size_filter
738
+
739
+ skip_orgs = []
740
+ if skip_orgs_text and skip_orgs_text.strip():
741
+ skip_orgs = [org.strip() for org in skip_orgs_text.split(',') if org.strip()]
742
+
743
+ treemap_data = make_treemap_data(
744
+ df=data_df,
745
  count_by=count_by,
746
+ top_k=top_k,
 
747
  tag_filter=selected_tag_filter,
748
+ pipeline_filter=selected_pipeline_filter,
749
+ size_filter=selected_size_filter,
750
+ skip_orgs=skip_orgs
751
  )
 
752
 
753
+ title_labels = {
754
+ "downloads": "Downloads (last 30 days)",
755
+ "downloadsAllTime": "Downloads (All Time)",
756
+ "likes": "Likes"
757
+ }
758
+ title_text = f"HuggingFace Models - {title_labels.get(count_by, count_by)} by Organization"
 
759
 
760
+ fig = create_treemap(
761
+ treemap_data=treemap_data,
762
+ count_by=count_by,
763
+ title=title_text
764
+ )
765
+
766
+ if treemap_data.empty:
767
+ stats_md = "No data matches the selected filters."
768
+ else:
769
+ total_models = len(treemap_data)
770
+ total_value = treemap_data[count_by].sum()
771
+
772
+ # Get top 5 organizations
773
+ top_5_orgs = treemap_data.groupby("organization")[count_by].sum().sort_values(ascending=False).head(5)
774
+
775
+ # Get top 5 individual models
776
+ top_5_models = treemap_data[["id", count_by]].sort_values(by=count_by, ascending=False).head(5)
777
 
778
+ # Create statistics section
779
+ stats_md = f"""
780
+ ## Statistics
781
+ - **Total models shown**: {total_models:,}
782
+ - **Total {count_by}**: {int(total_value):,}
783
+
784
+ ## Top Organizations by {count_by.capitalize()}
785
+
786
+ | Organization | {count_by.capitalize()} | % of Total |
787
+ |--------------|-------------:|----------:|
788
+ """
789
+
790
+ # Add top organizations to the table
791
+ for org, value in top_5_orgs.items():
792
+ percentage = (value / total_value) * 100
793
+ stats_md += f"| {org} | {int(value):,} | {percentage:.2f}% |\n"
794
+
795
+ # Add the top models table
796
+ stats_md += f"""
797
+ ## Top Models by {count_by.capitalize()}
798
+
799
+ | Model | {count_by.capitalize()} | % of Total |
800
+ |-------|-------------:|----------:|
801
+ """
802
+
803
+ # Add top models to the table
804
+ for _, row in top_5_models.iterrows():
805
+ model_id = row["id"]
806
+ value = row[count_by]
807
+ percentage = (value / total_value) * 100
808
+ stats_md += f"| {model_id} | {int(value):,} | {percentage:.2f}% |\n"
809
+
810
+ # Add note about skipped organizations if any
811
+ if skip_orgs:
812
+ stats_md += f"\n*Note: {len(skip_orgs)} organization(s) excluded: {', '.join(skip_orgs)}*"
813
+
814
+ return fig, stats_md
815
+
816
+ # Load data at startup
817
  demo.load(
818
+ fn=load_and_provide_info,
819
+ inputs=[],
820
+ outputs=[models_data, loading_complete, data_info, stats_output]
821
+ )
822
+
823
+ # Refresh data when button is clicked
824
+ refresh_data_button.click(
825
+ fn=load_and_provide_info,
826
+ inputs=[],
827
+ outputs=[models_data, loading_complete, data_info, stats_output]
828
  )
829
 
 
830
  generate_plot_button.click(
831
  fn=generate_plot_on_click,
832
  inputs=[
 
833
  count_by_dropdown,
834
  filter_choice_radio,
835
  tag_filter_dropdown,
836
  pipeline_filter_dropdown,
837
+ size_filter_dropdown,
838
+ top_k_slider,
839
+ skip_orgs_textbox,
840
+ models_data
841
  ],
842
+ outputs=[plot_output, stats_output]
843
  )
844
 
 
845
# Launch the Gradio app only when run as a script (not when imported).
if __name__ == "__main__":
    demo.launch()