Read live data from DuckDB
app.py
CHANGED
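This commit swaps the static JSON snapshot for a live query against the models.parquet file of the cfahlgren1/hub-stats dataset, run with DuckDB at startup. The core pattern, reduced to a minimal sketch (DuckDB fetches the remote parquet over HTTP and returns a pandas DataFrame; this needs network access and the duckdb package, and uses a column subset from the real query below):

import duckdb

# Read only the needed columns straight from the remote parquet file.
df = duckdb.sql(
    """
    SELECT id, downloads, likes
    FROM read_parquet('https://huggingface.co/datasets/cfahlgren1/hub-stats/resolve/main/models.parquet')
    """
).df()
print(df.shape)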
@@ -2,7 +2,12 @@ import json
@@ -44,45 +49,129 @@ PIPELINE_TAGS = [
@@ -96,180 +185,662 @@ TAG_FILTER_FUNCS = {

[Removed side of the diff. The deleted lines are collapsed in this view and most of their bodies are not recoverable. What the old app.py can still be seen to contain: imports limited to json, gradio, pandas, and plotly.express; a "Repository Type" dropdown (choices "all", "models", "datasets", default "all") alongside the Metric dropdown, Filter radio, and tag/pipeline dropdowns, but no size filter, top-k slider, or skip-orgs box; debug prints of the selected repository type, metric, and filter choice in its plot handler; and a commented-out loader, `# org_stats = json.load(open("org_to_artifacts_2l_stats.json"))  # Data loading handled by demo.load`. Its description read: "This app shows how different organizations are contributing to different aspects of the open AI ecosystem. Use the dropdowns on the left to select repository types, metrics, and optionally tags representing topics or modalities of interest."]
import gradio as gr
import pandas as pd
import plotly.express as px
import os
import numpy as np
import io
import duckdb

# Define pipeline tags
PIPELINE_TAGS = [
    'text-generation',
    'text-to-image',
    # ... (tags on lines 14-48 are collapsed in the diff view) ...
    'table-question-answering',
]

# Model size categories in GB
MODEL_SIZE_RANGES = {
    "Small (<1GB)": (0, 1),
    "Medium (1-5GB)": (1, 5),
    "Large (5-20GB)": (5, 20),
    "X-Large (20-50GB)": (20, 50),
    "XX-Large (>50GB)": (50, float('inf'))
}

# Filter functions for tags - UPDATED to use cached columns.
# Each one reads a boolean column precomputed in load_models_data()
# instead of rescanning the tag list on every call.
def is_audio_speech(row):
    return row['is_audio_speech']

def is_music(row):
    return row['has_music']

def is_robotics(row):
    return row['has_robot']

def is_biomed(row):
    return row['is_biomed']

def is_timeseries(row):
    return row['has_series']

def is_science(row):
    return row['has_science']

def is_video(row):
    return row['has_video']

def is_image(row):
    return row['has_image']

def is_text(row):
    return row['has_text']

def extract_model_size(safetensors_data):
    """Extract model size in GB from safetensors data"""
    try:
        if pd.isna(safetensors_data):
            return 0

        # If it's already a dictionary, use it directly
        if isinstance(safetensors_data, dict):
            if 'total' in safetensors_data:
                try:
                    size_bytes = float(safetensors_data['total'])
                    return size_bytes / (1024 * 1024 * 1024)  # Convert to GB
                except (ValueError, TypeError):
                    pass

        # If it's a string, try to parse it as JSON
        elif isinstance(safetensors_data, str):
            try:
                data_dict = json.loads(safetensors_data)
                if 'total' in data_dict:
                    try:
                        size_bytes = float(data_dict['total'])
                        return size_bytes / (1024 * 1024 * 1024)  # Convert to GB
                    except (ValueError, TypeError):
                        pass
            except (json.JSONDecodeError, TypeError):
                pass

        return 0
    except Exception as e:
        print(f"Error extracting model size: {e}")
        return 0

# Add model size filter function - UPDATED to use cached size_category column
def is_in_size_range(row, size_range):
    """Check if a model is in the specified size range using pre-calculated size category"""
    if size_range is None or size_range == "None":
        return True

    # Simply compare with cached size_category
    return row['size_category'] == size_range

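# Editor's sketch (not part of the committed file): extract_model_size handles the
# three payload shapes above - a dict, its JSON-string form, and a missing value.
# The byte counts are invented.
def _demo_extract_model_size():
    assert extract_model_size({'total': 2 * 1024 ** 3}) == 2.0
    assert extract_model_size('{"total": 2147483648}') == 2.0
    assert extract_model_size(None) == 0  # missing metadata falls back to 0
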
TAG_FILTER_FUNCS = {
    "Audio & Speech": is_audio_speech,
    # ... (entries on lines 178-184 are collapsed in the diff view) ...
    "Sciences": is_science,
}

def extract_org_from_id(model_id):
    """Extract organization name from model ID"""
    if "/" in model_id:
        return model_id.split("/")[0]
    return "unaffiliated"

def make_treemap_data(df, count_by, top_k=25, tag_filter=None, pipeline_filter=None, size_filter=None, skip_orgs=None):
    """Process DataFrame into treemap format with filters applied - OPTIMIZED with cached columns"""
    # Create a copy to avoid modifying the original
    filtered_df = df.copy()

    # Apply filters
    filter_stats = {"initial": len(filtered_df)}
    start_time = pd.Timestamp.now()

    # Apply tag filter - OPTIMIZED to use cached columns
    if tag_filter and tag_filter in TAG_FILTER_FUNCS:
        print(f"Applying tag filter: {tag_filter}")

        # Use direct column filtering instead of applying a function to each row
        if tag_filter == "Audio & Speech":
            filtered_df = filtered_df[filtered_df['is_audio_speech']]
        elif tag_filter == "Music":
            filtered_df = filtered_df[filtered_df['has_music']]
        elif tag_filter == "Robotics":
            filtered_df = filtered_df[filtered_df['has_robot']]
        elif tag_filter == "Biomedical":
            filtered_df = filtered_df[filtered_df['is_biomed']]
        elif tag_filter == "Time series":
            filtered_df = filtered_df[filtered_df['has_series']]
        elif tag_filter == "Sciences":
            filtered_df = filtered_df[filtered_df['has_science']]
        elif tag_filter == "Video":
            filtered_df = filtered_df[filtered_df['has_video']]
        elif tag_filter == "Images":
            filtered_df = filtered_df[filtered_df['has_image']]
        elif tag_filter == "Text":
            filtered_df = filtered_df[filtered_df['has_text']]

        filter_stats["after_tag_filter"] = len(filtered_df)
        print(f"Tag filter applied in {(pd.Timestamp.now() - start_time).total_seconds():.3f} seconds")
        start_time = pd.Timestamp.now()

    # Apply pipeline filter
    if pipeline_filter:
        print(f"Applying pipeline filter: {pipeline_filter}")
        filtered_df = filtered_df[filtered_df["pipeline_tag"] == pipeline_filter]
        filter_stats["after_pipeline_filter"] = len(filtered_df)
        print(f"Pipeline filter applied in {(pd.Timestamp.now() - start_time).total_seconds():.3f} seconds")
        start_time = pd.Timestamp.now()

    # Apply size filter - OPTIMIZED to use cached size_category column
    if size_filter and size_filter in MODEL_SIZE_RANGES:
        print(f"Applying size filter: {size_filter}")

        # Use the cached size_category column directly
        filtered_df = filtered_df[filtered_df['size_category'] == size_filter]

        # Debug info
        print(f"Size filter '{size_filter}' applied.")
        print(f"Models after size filter: {len(filtered_df)}")

        filter_stats["after_size_filter"] = len(filtered_df)
        print(f"Size filter applied in {(pd.Timestamp.now() - start_time).total_seconds():.3f} seconds")
        start_time = pd.Timestamp.now()

    # Add organization column
    filtered_df["organization"] = filtered_df["id"].apply(extract_org_from_id)

    # Skip organizations if specified
    if skip_orgs and len(skip_orgs) > 0:
        filtered_df = filtered_df[~filtered_df["organization"].isin(skip_orgs)]
        filter_stats["after_skip_orgs"] = len(filtered_df)

    # Print filter stats
    print("Filter statistics:")
    for stage, count in filter_stats.items():
        print(f"  {stage}: {count} models")

    # Check if we have any data left
    if filtered_df.empty:
        print("Warning: No data left after applying filters!")
        return pd.DataFrame()  # Return empty DataFrame

    # Aggregate by organization
    org_totals = filtered_df.groupby("organization")[count_by].sum().reset_index()
    org_totals = org_totals.sort_values(by=count_by, ascending=False)

    # Get top organizations
    top_orgs = org_totals.head(top_k)["organization"].tolist()

    # Filter to only include models from top organizations
    filtered_df = filtered_df[filtered_df["organization"].isin(top_orgs)]

    # Prepare data for treemap
    treemap_data = filtered_df[["id", "organization", count_by]].copy()

    # Add a root node
    treemap_data["root"] = "models"

    # Ensure numeric values
    treemap_data[count_by] = pd.to_numeric(treemap_data[count_by], errors="coerce").fillna(0)

    print(f"Treemap data prepared in {(pd.Timestamp.now() - start_time).total_seconds():.3f} seconds")
    return treemap_data

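# Editor's sketch (not part of the committed file): a tiny, self-contained check of
# the aggregation above. "standalone-model" has no "org/" prefix, so it is grouped
# under "unaffiliated" and falls out of the top-2 organizations.
def _demo_make_treemap_data():
    toy = pd.DataFrame({
        "id": ["orgA/m1", "orgA/m2", "orgB/m1", "standalone-model"],
        "downloads": [100, 50, 30, 5],
    })
    out = make_treemap_data(toy, "downloads", top_k=2)
    assert set(out["organization"]) == {"orgA", "orgB"}
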
def create_treemap(treemap_data, count_by, title=None):
    """Create a Plotly treemap from the prepared data"""
    if treemap_data.empty:
        # Create an empty figure with a message
        fig = px.treemap(
            names=["No data matches the selected filters"],
            values=[1]
        )
        fig.update_layout(
            title="No data matches the selected filters",
            margin=dict(t=50, l=25, r=25, b=25)
        )
        return fig

    # Create the treemap
    fig = px.treemap(
        treemap_data,
        path=["root", "organization", "id"],
        values=count_by,
        title=title or f"HuggingFace Models - {count_by.capitalize()} by Organization",
        color_discrete_sequence=px.colors.qualitative.Plotly
    )

    # Update layout
    fig.update_layout(
        margin=dict(t=50, l=25, r=25, b=25)
    )

    # Update traces for better readability
    fig.update_traces(
        textinfo="label+value+percent root",
        hovertemplate="<b>%{label}</b><br>%{value:,} " + count_by + "<br>%{percentRoot:.2%} of total<extra></extra>"
    )

    return fig

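# Editor's sketch (not part of the committed file): px.treemap's `path` argument
# builds the root -> organization -> id hierarchy from plain columns, so two rows
# are enough to see the nesting.
def _demo_create_treemap():
    toy = pd.DataFrame({
        "root": ["models", "models"],
        "organization": ["orgA", "orgB"],
        "id": ["orgA/m1", "orgB/m1"],
        "likes": [10, 5],
    })
    create_treemap(toy, "likes").show()  # opens in a browser when run locally
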
def load_models_data():
    """Load models data from Hugging Face using DuckDB with caching for improved performance"""
    try:
        # The URL to the parquet file
        parquet_url = "https://huggingface.co/datasets/cfahlgren1/hub-stats/resolve/main/models.parquet"

        print("Fetching data from Hugging Face models.parquet...")

        # Based on the column names provided, we can directly select the columns we need
        # Note: We need to select safetensors to get the model size information
        try:
            query = f"""
                SELECT
                    id,
                    downloads,
                    downloadsAllTime,
                    likes,
                    pipeline_tag,
                    tags,
                    safetensors
                FROM read_parquet('{parquet_url}')
            """
            df = duckdb.sql(query).df()
        except Exception as sql_error:
            print(f"Error with specific column selection: {sql_error}")
            # Fall back to selecting everything and then filtering
            print("Falling back to select * query...")
            query = f"SELECT * FROM read_parquet('{parquet_url}')"
            raw_df = duckdb.sql(query).df()

            # Now extract only the columns we need
            needed_columns = ['id', 'downloads', 'downloadsAllTime', 'likes', 'pipeline_tag', 'tags', 'safetensors']
            available_columns = set(raw_df.columns)
            df = pd.DataFrame()

            # Copy over columns that exist
            for col in needed_columns:
                if col in available_columns:
                    df[col] = raw_df[col]
                else:
                    # Create empty columns for missing data
                    if col in ['downloads', 'downloadsAllTime', 'likes']:
                        df[col] = 0
                    elif col == 'pipeline_tag':
                        df[col] = ''
                    elif col == 'tags':
                        df[col] = [[] for _ in range(len(raw_df))]
                    elif col == 'safetensors':
                        df[col] = None
                    elif col == 'id':
                        # Create IDs based on index if missing
                        df[col] = [f"model_{i}" for i in range(len(raw_df))]

        print(f"Data fetched successfully. Shape: {df.shape}")

        # Check if safetensors column exists before trying to process it
        if 'safetensors' in df.columns:
            # Add params column derived from safetensors.total (model size in GB)
            df['params'] = df['safetensors'].apply(extract_model_size)

            # Debug model sizes
            size_ranges = {
                "Small (<1GB)": 0,
                "Medium (1-5GB)": 0,
                "Large (5-20GB)": 0,
                "X-Large (20-50GB)": 0,
                "XX-Large (>50GB)": 0
            }

            # Count models in each size range
            for idx, row in df.iterrows():
                size_gb = row['params']
                if 0 <= size_gb < 1:
                    size_ranges["Small (<1GB)"] += 1
                elif 1 <= size_gb < 5:
                    size_ranges["Medium (1-5GB)"] += 1
                elif 5 <= size_gb < 20:
                    size_ranges["Large (5-20GB)"] += 1
                elif 20 <= size_gb < 50:
                    size_ranges["X-Large (20-50GB)"] += 1
                elif size_gb >= 50:
                    size_ranges["XX-Large (>50GB)"] += 1

            print("Model size distribution:")
            for size_range, count in size_ranges.items():
                print(f"  {size_range}: {count} models")

            # CACHE SIZE CATEGORY: Add a size_category column for faster filtering
            def get_size_category(size_gb):
                if 0 <= size_gb < 1:
                    return "Small (<1GB)"
                elif 1 <= size_gb < 5:
                    return "Medium (1-5GB)"
                elif 5 <= size_gb < 20:
                    return "Large (5-20GB)"
                elif 20 <= size_gb < 50:
                    return "X-Large (20-50GB)"
                elif size_gb >= 50:
                    return "XX-Large (>50GB)"
                return None

            # Add cached size category column
            df['size_category'] = df['params'].apply(get_size_category)

            # Remove the safetensors column as we don't need it anymore
            df = df.drop(columns=['safetensors'])
        else:
            # If no safetensors column, add empty params column
            df['params'] = 0
            df['size_category'] = None

        # Process tags to ensure it's in the right format - FIXED
        def process_tags(tags_value):
            try:
                if tags_value is None:
                    return []

                # If it's a numpy array, convert to a list of strings.
                # This must happen before pd.isna, which is elementwise on arrays.
                if hasattr(tags_value, 'dtype') and hasattr(tags_value, 'tolist'):
                    return [str(tag) for tag in tags_value.tolist()]

                # If already a list, ensure all elements are strings
                if isinstance(tags_value, list):
                    return [str(tag) for tag in tags_value]

                if pd.isna(tags_value):
                    return []

                # If string, try to parse as JSON or split by comma
                if isinstance(tags_value, str):
                    try:
                        tags_list = json.loads(tags_value)
                        if isinstance(tags_list, list):
                            return [str(tag) for tag in tags_list]
                    except json.JSONDecodeError:
                        # Split by comma if JSON parsing fails
                        return [tag.strip() for tag in tags_value.split(',') if tag.strip()]

                # Last resort, convert to string and return as a single tag
                return [str(tags_value)]

            except Exception as e:
                print(f"Error processing tags: {e}")
                return []

        # Check if tags column exists before trying to process it
        if 'tags' in df.columns:
            # Process tags column
            df['tags'] = df['tags'].apply(process_tags)

            # CACHE TAG CATEGORIES: Pre-calculate tag categories for faster filtering
            print("Pre-calculating cached tag categories...")

            # Helper to check for a keyword in the tag list (simplified for caching)
            def has_tag_keyword(tags, keyword):
                if tags and isinstance(tags, list):
                    return any(keyword in str(tag).lower() for tag in tags)
                return False

            def has_science_tag(tags):
                # "science" but not "bigscience" (the org name would be a false positive)
                if tags and isinstance(tags, list):
                    return any("science" in str(tag).lower() and "bigscience" not in str(tag).lower()
                               for tag in tags)
                return False

            # Add cached columns for tag categories
            print("Creating cached tag columns...")
            df['has_audio'] = df['tags'].apply(lambda t: has_tag_keyword(t, "audio"))
            df['has_speech'] = df['tags'].apply(lambda t: has_tag_keyword(t, "speech"))
            df['has_music'] = df['tags'].apply(lambda t: has_tag_keyword(t, "music"))
            df['has_robot'] = df['tags'].apply(lambda t: has_tag_keyword(t, "robot"))
            df['has_bio'] = df['tags'].apply(lambda t: has_tag_keyword(t, "bio"))
            df['has_med'] = df['tags'].apply(lambda t: has_tag_keyword(t, "medic"))
            df['has_series'] = df['tags'].apply(lambda t: has_tag_keyword(t, "series"))
            df['has_science'] = df['tags'].apply(has_science_tag)
            df['has_video'] = df['tags'].apply(lambda t: has_tag_keyword(t, "video"))
            df['has_image'] = df['tags'].apply(lambda t: has_tag_keyword(t, "image"))
            df['has_text'] = df['tags'].apply(lambda t: has_tag_keyword(t, "text"))

            # Create combined category flags for faster filtering
            df['is_audio_speech'] = (df['has_audio'] | df['has_speech'] |
                                     df['pipeline_tag'].str.contains('audio', case=False, na=False) |
                                     df['pipeline_tag'].str.contains('speech', case=False, na=False))
            df['is_biomed'] = df['has_bio'] | df['has_med']

            print("Cached tag columns created successfully!")
        else:
            # If no tags column, add empty tags and set all category flags to False
            df['tags'] = [[] for _ in range(len(df))]
            for col in ['has_audio', 'has_speech', 'has_music', 'has_robot',
                        'has_bio', 'has_med', 'has_series', 'has_science',
                        'has_video', 'has_image', 'has_text',
                        'is_audio_speech', 'is_biomed']:
                df[col] = False

        # Fill NaN values
        df.fillna({'downloads': 0, 'downloadsAllTime': 0, 'likes': 0, 'params': 0}, inplace=True)

        # Ensure pipeline_tag is a string
        if 'pipeline_tag' in df.columns:
            df['pipeline_tag'] = df['pipeline_tag'].fillna('')
        else:
            df['pipeline_tag'] = ''

        # Make sure all required columns exist
        for col in ['id', 'downloads', 'downloadsAllTime', 'likes', 'pipeline_tag', 'tags', 'params']:
            if col not in df.columns:
                if col in ['downloads', 'downloadsAllTime', 'likes', 'params']:
                    df[col] = 0
                elif col == 'pipeline_tag':
                    df[col] = ''
                elif col == 'tags':
                    df[col] = [[] for _ in range(len(df))]
                elif col == 'id':
                    df[col] = [f"model_{i}" for i in range(len(df))]

        print(f"Successfully processed {len(df)} models with cached tag and size information")
        return df, True

    except Exception as e:
        print(f"Error loading data: {e}")
        # Return an empty DataFrame and False to indicate loading failure
        return pd.DataFrame(), False

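# Editor's sketch (not part of the committed file): why the cached boolean columns
# matter - filtering becomes a single vectorized mask instead of a row-wise scan
# of every tag list. Toy data; the real frame comes from load_models_data().
def _demo_cached_filtering():
    toy = pd.DataFrame({
        "tags": [["music", "audio"], ["vision"]],
        "has_music": [True, False],
    })
    slow = toy[toy.apply(lambda row: any("music" in str(t).lower() for t in row["tags"]), axis=1)]
    fast = toy[toy["has_music"]]  # same rows, one vectorized comparison
    assert slow.equals(fast)
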
# Create Gradio interface
with gr.Blocks() as demo:
    models_data = gr.State()
    loading_complete = gr.State(False)  # Flag to indicate data load completion

    with gr.Row():
        gr.Markdown("""
        # HuggingFace Models TreeMap Visualization

        This app shows how different organizations contribute to the HuggingFace ecosystem with their models.
        Use the filters to explore models by different metrics, tags, pipelines, and model sizes.

        The treemap visualizes models grouped by organization, with the size of each box representing the selected metric.
        """)

    with gr.Row():
        with gr.Column(scale=1):
            count_by_dropdown = gr.Dropdown(
                label="Metric",
                choices=[
                    ("Downloads (last 30 days)", "downloads"),
                    ("Downloads (All Time)", "downloadsAllTime"),
                    ("Likes", "likes")
                ],
                value="downloads",
                info="Select the metric to determine box sizes"
            )

            filter_choice_radio = gr.Radio(
                label="Filter Type",
                choices=["None", "Tag Filter", "Pipeline Filter"],
                value="None",
                info="Choose how to filter the models"
            )

            tag_filter_dropdown = gr.Dropdown(
                label="Select Tag",
                choices=list(TAG_FILTER_FUNCS.keys()),
                value=None,
                visible=False,
                info="Filter models by domain/category"
            )

            pipeline_filter_dropdown = gr.Dropdown(
                label="Select Pipeline Tag",
                choices=PIPELINE_TAGS,
                value=None,
                visible=False,
                info="Filter models by specific pipeline"
            )

            size_filter_dropdown = gr.Dropdown(
                label="Model Size Filter",
                choices=["None"] + list(MODEL_SIZE_RANGES.keys()),
                value="None",
                info="Filter models by their size (using params column)"
            )

            top_k_slider = gr.Slider(
                label="Number of Top Organizations",
                minimum=5,
                maximum=50,
                value=25,
                step=5,
                info="Number of top organizations to include"
            )

            skip_orgs_textbox = gr.Textbox(
                label="Organizations to Skip (comma-separated)",
                placeholder="e.g., OpenAI, Google",
                value="TheBloke, MaziyarPanahi, unsloth, modularai, Gensyn, bartowski"
            )

            generate_plot_button = gr.Button("Generate Plot", variant="primary", interactive=False)
            refresh_data_button = gr.Button("Refresh Data from Hugging Face", variant="secondary")

        with gr.Column(scale=3):
            plot_output = gr.Plot()
            stats_output = gr.Markdown("*Loading data from Hugging Face...*")
            data_info = gr.Markdown("")

    # Button enablement after data load
    def enable_plot_button(loaded):
        return gr.update(interactive=loaded)

    loading_complete.change(
        fn=enable_plot_button,
        inputs=[loading_complete],
        outputs=[generate_plot_button]
    )

    # Show/hide tag/pipeline dropdown
    def update_filter_visibility(filter_choice):
        if filter_choice == "Tag Filter":
            return gr.update(visible=True), gr.update(visible=False)
        elif filter_choice == "Pipeline Filter":
            return gr.update(visible=False), gr.update(visible=True)
        else:
            return gr.update(visible=False), gr.update(visible=False)

    filter_choice_radio.change(
        fn=update_filter_visibility,
        inputs=[filter_choice_radio],
        outputs=[tag_filter_dropdown, pipeline_filter_dropdown]
    )

    # Function to handle data load and provide data info
    def load_and_provide_info():
        df, success = load_models_data()

        if success:
            # Generate information about the loaded data
            info_text = f"""
            ### Data Information
            - **Total models loaded**: {len(df):,}
            - **Last update**: {pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S')}
            - **Data source**: [Hugging Face Hub Stats](https://huggingface.co/datasets/cfahlgren1/hub-stats) (models.parquet)
            """

            # Return the data, loading status, and info text
            return df, True, info_text, "*Data loaded successfully. Use the controls to generate a plot.*"
        else:
            # Return empty data, failed loading status, and error message
            return pd.DataFrame(), False, "*Error loading data from Hugging Face.*", "*Failed to load data. Please try again.*"

    # Main generate function
    def generate_plot_on_click(count_by, filter_choice, tag_filter, pipeline_filter, size_filter, top_k, skip_orgs_text, data_df):
        if data_df is None or not isinstance(data_df, pd.DataFrame) or data_df.empty:
            return None, "Error: Data is still loading. Please wait a moment and try again."

        selected_tag_filter = None
        selected_pipeline_filter = None
        selected_size_filter = None

        if filter_choice == "Tag Filter":
            selected_tag_filter = tag_filter
        elif filter_choice == "Pipeline Filter":
            selected_pipeline_filter = pipeline_filter

        if size_filter != "None":
            selected_size_filter = size_filter

        skip_orgs = []
        if skip_orgs_text and skip_orgs_text.strip():
            skip_orgs = [org.strip() for org in skip_orgs_text.split(',') if org.strip()]

        treemap_data = make_treemap_data(
            df=data_df,
            count_by=count_by,
            top_k=top_k,
            tag_filter=selected_tag_filter,
            pipeline_filter=selected_pipeline_filter,
            size_filter=selected_size_filter,
            skip_orgs=skip_orgs
        )

        title_labels = {
            "downloads": "Downloads (last 30 days)",
            "downloadsAllTime": "Downloads (All Time)",
            "likes": "Likes"
        }
        title_text = f"HuggingFace Models - {title_labels.get(count_by, count_by)} by Organization"

        fig = create_treemap(
            treemap_data=treemap_data,
            count_by=count_by,
            title=title_text
        )

        if treemap_data.empty:
            stats_md = "No data matches the selected filters."
        else:
            total_models = len(treemap_data)
            total_value = treemap_data[count_by].sum()

            # Get top 5 organizations
            top_5_orgs = treemap_data.groupby("organization")[count_by].sum().sort_values(ascending=False).head(5)

            # Get top 5 individual models
            top_5_models = treemap_data[["id", count_by]].sort_values(by=count_by, ascending=False).head(5)

            # Create statistics section
            stats_md = f"""
            ## Statistics
            - **Total models shown**: {total_models:,}
            - **Total {count_by}**: {int(total_value):,}

            ## Top Organizations by {count_by.capitalize()}

            | Organization | {count_by.capitalize()} | % of Total |
            |--------------|-------------:|----------:|
            """

            # Add top organizations to the table
            for org, value in top_5_orgs.items():
                percentage = (value / total_value) * 100
                stats_md += f"| {org} | {int(value):,} | {percentage:.2f}% |\n"

            # Add the top models table
            stats_md += f"""
            ## Top Models by {count_by.capitalize()}

            | Model | {count_by.capitalize()} | % of Total |
            |-------|-------------:|----------:|
            """

            # Add top models to the table
            for _, row in top_5_models.iterrows():
                model_id = row["id"]
                value = row[count_by]
                percentage = (value / total_value) * 100
                stats_md += f"| {model_id} | {int(value):,} | {percentage:.2f}% |\n"

            # Add note about skipped organizations if any
            if skip_orgs:
                stats_md += f"\n*Note: {len(skip_orgs)} organization(s) excluded: {', '.join(skip_orgs)}*"

        return fig, stats_md

    # Load data at startup
    demo.load(
        fn=load_and_provide_info,
        inputs=[],
        outputs=[models_data, loading_complete, data_info, stats_output]
    )

    # Refresh data when button is clicked
    refresh_data_button.click(
        fn=load_and_provide_info,
        inputs=[],
        outputs=[models_data, loading_complete, data_info, stats_output]
    )

    generate_plot_button.click(
        fn=generate_plot_on_click,
        inputs=[
            count_by_dropdown,
            filter_choice_radio,
            tag_filter_dropdown,
            pipeline_filter_dropdown,
            size_filter_dropdown,
            top_k_slider,
            skip_orgs_textbox,
            models_data
        ],
        outputs=[plot_output, stats_output]
    )

if __name__ == "__main__":
    demo.launch()
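To try the updated app locally, something like the following should work (assuming the Space's usual dependency set; requirements.txt is not shown in this diff):

pip install gradio pandas plotly duckdb
python app.py

Note that the first plot can only be generated once demo.load has finished fetching the parquet, which is why the Generate Plot button starts out with interactive=False and is enabled by the loading_complete state change.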