evijit HF Staff commited on
Commit
f38cb18
·
verified ·
1 Parent(s): f153269

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -11
app.py CHANGED
@@ -23,9 +23,12 @@ def load_models_data():
23
  dataset_dict = load_dataset(HF_DATASET_ID)
24
  df = dataset_dict[list(dataset_dict.keys())[0]].to_pandas()
25
  if 'params' in df.columns:
26
- df['params'] = pd.to_numeric(df['params'], errors='coerce').fillna(0)
 
 
27
  else:
28
- df['params'] = 0
 
29
  msg = f"Successfully loaded dataset in {time.time() - overall_start_time:.2f}s."
30
  print(msg)
31
  return df, True, msg
@@ -40,9 +43,15 @@ def get_param_range_values(param_range_labels):
40
  max_val = float('inf') if '>' in max_label else float(max_label.replace('B', ''))
41
  return min_val, max_val
42
 
43
- def make_treemap_data(df, count_by, top_k=25, tag_filter=None, pipeline_filter=None, param_range=None, skip_orgs=None):
44
  if df is None or df.empty: return pd.DataFrame()
45
  filtered_df = df.copy()
 
 
 
 
 
 
46
  col_map = { "Audio & Speech": "is_audio_speech", "Music": "has_music", "Robotics": "has_robot", "Biomedical": "is_biomed", "Time series": "has_series", "Sciences": "has_science", "Video": "has_video", "Images": "has_image", "Text": "has_text" }
47
  if tag_filter and tag_filter in col_map and col_map[tag_filter] in filtered_df.columns:
48
  filtered_df = filtered_df[filtered_df[col_map[tag_filter]]]
@@ -51,9 +60,12 @@ def make_treemap_data(df, count_by, top_k=25, tag_filter=None, pipeline_filter=N
51
  if param_range:
52
  min_params, max_params = get_param_range_values(param_range)
53
  is_default_range = (param_range[0] == PARAM_CHOICES[0] and param_range[1] == PARAM_CHOICES[-1])
 
 
54
  if not is_default_range and 'params' in filtered_df.columns:
55
  if min_params is not None: filtered_df = filtered_df[filtered_df['params'] >= min_params]
56
  if max_params is not None and max_params != float('inf'): filtered_df = filtered_df[filtered_df['params'] < max_params]
 
57
  if skip_orgs and len(skip_orgs) > 0 and "organization" in filtered_df.columns:
58
  filtered_df = filtered_df[~filtered_df["organization"].isin(skip_orgs)]
59
  if filtered_df.empty: return pd.DataFrame()
@@ -82,7 +94,6 @@ custom_css = """
82
  #param-slider-wrapper div[data-testid="range-slider"] > span {
83
  display: none !important;
84
  }
85
-
86
  /*
87
  THIS IS THE KEY FIX:
88
  We target all the individual component containers (divs with class .block)
@@ -129,6 +140,8 @@ with gr.Blocks(title="🤗 ModelVerse Explorer", fill_width=True, css=custom_css
129
  elem_id="param-slider-wrapper"
130
  )
131
  param_range_display = gr.Markdown(f"Range: `{PARAM_CHOICES[0]}` to `{PARAM_CHOICES[-1]}`")
 
 
132
 
133
  # This section remains un-grouped
134
  top_k_dropdown = gr.Dropdown(label="Number of Top Organizations", choices=TOP_K_CHOICES, value=25)
@@ -166,8 +179,11 @@ with gr.Blocks(title="🤗 ModelVerse Explorer", fill_width=True, css=custom_css
166
  if 'data_download_timestamp' in current_df.columns and pd.notna(current_df['data_download_timestamp'].iloc[0]):
167
  ts = pd.to_datetime(current_df['data_download_timestamp'].iloc[0], utc=True)
168
  date_display = ts.strftime('%B %d, %Y, %H:%M:%S %Z')
169
- param_count = (current_df['params'] > 0).sum() if 'params' in current_df.columns else 0
170
- data_info_text = f"### Data Information\n- Source: `{HF_DATASET_ID}`\n- Status: {status_msg_from_load}\n- Total models loaded: {len(current_df):,}\n- Models with parameter counts: {param_count:,}\n- Data as of: {date_display}\n"
 
 
 
171
  else:
172
  data_info_text = f"### Data Load Failed\n- {status_msg_from_load}"
173
  except Exception as e:
@@ -178,7 +194,6 @@ with gr.Blocks(title="🤗 ModelVerse Explorer", fill_width=True, css=custom_css
178
  print(f"Critical error in load_and_generate_initial_plot: {e}")
179
 
180
  # --- Part 2: Generate Initial Plot ---
181
- # We call the existing plot generation function with the default values from the UI
182
  progress(0.6, desc="Generating initial plot...")
183
  # Get default values directly from the UI component definitions
184
  default_metric = "downloads"
@@ -188,18 +203,20 @@ with gr.Blocks(title="🤗 ModelVerse Explorer", fill_width=True, css=custom_css
188
  default_param_indices = PARAM_CHOICES_DEFAULT_INDICES
189
  default_k = 25
190
  default_skip_orgs = "TheBloke,MaziyarPanahi,unsloth,modularai,Gensyn,bartowski"
 
 
191
 
192
  # Reuse the existing controller function for plotting
193
  initial_plot, initial_status = ui_generate_plot_controller(
194
  default_metric, default_filter_type, default_tag, default_pipeline,
195
- default_param_indices, default_k, default_skip_orgs, current_df, progress
196
  )
197
 
198
  # Return all the necessary updates for the UI
199
  return current_df, load_success_flag, data_info_text, initial_status, initial_plot
200
 
201
  def ui_generate_plot_controller(metric_choice, filter_type, tag_choice, pipeline_choice,
202
- param_range_indices, k_orgs, skip_orgs_input, df_current_models, progress=gr.Progress()):
203
  if df_current_models is None or df_current_models.empty:
204
  return create_treemap(pd.DataFrame(), metric_choice, "Error: Model Data Not Loaded"), "Model data is not loaded. Cannot generate plot."
205
 
@@ -212,7 +229,16 @@ with gr.Blocks(title="🤗 ModelVerse Explorer", fill_width=True, css=custom_css
212
  max_label = PARAM_CHOICES[int(param_range_indices[1])]
213
  param_labels_for_filtering = [min_label, max_label]
214
 
215
- treemap_df = make_treemap_data(df_current_models, metric_choice, k_orgs, tag_to_use, pipeline_to_use, param_labels_for_filtering, orgs_to_skip)
 
 
 
 
 
 
 
 
 
216
 
217
  progress(0.7, desc="Generating plot...")
218
  title_labels = {"downloads": "Downloads (last 30 days)", "downloadsAllTime": "Downloads (All Time)", "likes": "Likes"}
@@ -237,7 +263,7 @@ with gr.Blocks(title="🤗 ModelVerse Explorer", fill_width=True, css=custom_css
237
  generate_plot_button.click(
238
  fn=ui_generate_plot_controller,
239
  inputs=[count_by_dropdown, filter_choice_radio, tag_filter_dropdown, pipeline_filter_dropdown,
240
- param_range_slider, top_k_dropdown, skip_orgs_textbox, models_data_state],
241
  outputs=[plot_output, status_message_md]
242
  )
243
 
 
23
  dataset_dict = load_dataset(HF_DATASET_ID)
24
  df = dataset_dict[list(dataset_dict.keys())[0]].to_pandas()
25
  if 'params' in df.columns:
26
+ # IMPORTANT CHANGE: Fill NaN/coerce errors with -1 to signify unknown size
27
+ # This aligns with the utility function's return of -1.0 for unknown sizes.
28
+ df['params'] = pd.to_numeric(df['params'], errors='coerce').fillna(-1)
29
  else:
30
+ # If 'params' column doesn't exist, assume all are unknown
31
+ df['params'] = -1
32
  msg = f"Successfully loaded dataset in {time.time() - overall_start_time:.2f}s."
33
  print(msg)
34
  return df, True, msg
 
43
  max_val = float('inf') if '>' in max_label else float(max_label.replace('B', ''))
44
  return min_val, max_val
45
 
46
+ def make_treemap_data(df, count_by, top_k=25, tag_filter=None, pipeline_filter=None, param_range=None, skip_orgs=None, include_unknown_param_size=True):
47
  if df is None or df.empty: return pd.DataFrame()
48
  filtered_df = df.copy()
49
+
50
+ # New: Filter based on unknown parameter size
51
+ # If include_unknown_param_size is False, exclude models where params is -1 (unknown)
52
+ if not include_unknown_param_size and 'params' in filtered_df.columns:
53
+ filtered_df = filtered_df[filtered_df['params'] != -1]
54
+
55
  col_map = { "Audio & Speech": "is_audio_speech", "Music": "has_music", "Robotics": "has_robot", "Biomedical": "is_biomed", "Time series": "has_series", "Sciences": "has_science", "Video": "has_video", "Images": "has_image", "Text": "has_text" }
56
  if tag_filter and tag_filter in col_map and col_map[tag_filter] in filtered_df.columns:
57
  filtered_df = filtered_df[filtered_df[col_map[tag_filter]]]
 
60
  if param_range:
61
  min_params, max_params = get_param_range_values(param_range)
62
  is_default_range = (param_range[0] == PARAM_CHOICES[0] and param_range[1] == PARAM_CHOICES[-1])
63
+ # Apply parameter range filter only if it's not the default (all range) AND params column exists
64
+ # This filter will naturally exclude -1 if the min_params is >= 0, as it should.
65
  if not is_default_range and 'params' in filtered_df.columns:
66
  if min_params is not None: filtered_df = filtered_df[filtered_df['params'] >= min_params]
67
  if max_params is not None and max_params != float('inf'): filtered_df = filtered_df[filtered_df['params'] < max_params]
68
+
69
  if skip_orgs and len(skip_orgs) > 0 and "organization" in filtered_df.columns:
70
  filtered_df = filtered_df[~filtered_df["organization"].isin(skip_orgs)]
71
  if filtered_df.empty: return pd.DataFrame()
 
94
  #param-slider-wrapper div[data-testid="range-slider"] > span {
95
  display: none !important;
96
  }
 
97
  /*
98
  THIS IS THE KEY FIX:
99
  We target all the individual component containers (divs with class .block)
 
140
  elem_id="param-slider-wrapper"
141
  )
142
  param_range_display = gr.Markdown(f"Range: `{PARAM_CHOICES[0]}` to `{PARAM_CHOICES[-1]}`")
143
+ # New: Checkbox for including unknown parameter sizes
144
+ include_unknown_params_checkbox = gr.Checkbox(label="Include models with unknown parameter size", value=True)
145
 
146
  # This section remains un-grouped
147
  top_k_dropdown = gr.Dropdown(label="Number of Top Organizations", choices=TOP_K_CHOICES, value=25)
 
179
  if 'data_download_timestamp' in current_df.columns and pd.notna(current_df['data_download_timestamp'].iloc[0]):
180
  ts = pd.to_datetime(current_df['data_download_timestamp'].iloc[0], utc=True)
181
  date_display = ts.strftime('%B %d, %Y, %H:%M:%S %Z')
182
+ # Count models where params is not -1 (known size)
183
+ param_count = (current_df['params'] != -1).sum() if 'params' in current_df.columns else 0
184
+ unknown_param_count = (current_df['params'] == -1).sum() if 'params' in current_df.columns else 0
185
+
186
+ data_info_text = f"### Data Information\n- Source: `{HF_DATASET_ID}`\n- Status: {status_msg_from_load}\n- Total models loaded: {len(current_df):,}\n- Models with known parameter counts: {param_count:,}\n- Models with unknown parameter counts: {unknown_param_count:,}\n- Data as of: {date_display}\n"
187
  else:
188
  data_info_text = f"### Data Load Failed\n- {status_msg_from_load}"
189
  except Exception as e:
 
194
  print(f"Critical error in load_and_generate_initial_plot: {e}")
195
 
196
  # --- Part 2: Generate Initial Plot ---
 
197
  progress(0.6, desc="Generating initial plot...")
198
  # Get default values directly from the UI component definitions
199
  default_metric = "downloads"
 
203
  default_param_indices = PARAM_CHOICES_DEFAULT_INDICES
204
  default_k = 25
205
  default_skip_orgs = "TheBloke,MaziyarPanahi,unsloth,modularai,Gensyn,bartowski"
206
+ # New default: include unknown params initially (matches checkbox default)
207
+ default_include_unknown_params = True
208
 
209
  # Reuse the existing controller function for plotting
210
  initial_plot, initial_status = ui_generate_plot_controller(
211
  default_metric, default_filter_type, default_tag, default_pipeline,
212
+ default_param_indices, default_k, default_skip_orgs, default_include_unknown_params, current_df, progress
213
  )
214
 
215
  # Return all the necessary updates for the UI
216
  return current_df, load_success_flag, data_info_text, initial_status, initial_plot
217
 
218
  def ui_generate_plot_controller(metric_choice, filter_type, tag_choice, pipeline_choice,
219
+ param_range_indices, k_orgs, skip_orgs_input, include_unknown_param_size_flag, df_current_models, progress=gr.Progress()):
220
  if df_current_models is None or df_current_models.empty:
221
  return create_treemap(pd.DataFrame(), metric_choice, "Error: Model Data Not Loaded"), "Model data is not loaded. Cannot generate plot."
222
 
 
229
  max_label = PARAM_CHOICES[int(param_range_indices[1])]
230
  param_labels_for_filtering = [min_label, max_label]
231
 
232
+ treemap_df = make_treemap_data(
233
+ df_current_models,
234
+ metric_choice,
235
+ k_orgs,
236
+ tag_to_use,
237
+ pipeline_to_use,
238
+ param_labels_for_filtering,
239
+ orgs_to_skip,
240
+ include_unknown_param_size_flag # Pass the new flag
241
+ )
242
 
243
  progress(0.7, desc="Generating plot...")
244
  title_labels = {"downloads": "Downloads (last 30 days)", "downloadsAllTime": "Downloads (All Time)", "likes": "Likes"}
 
263
  generate_plot_button.click(
264
  fn=ui_generate_plot_controller,
265
  inputs=[count_by_dropdown, filter_choice_radio, tag_filter_dropdown, pipeline_filter_dropdown,
266
+ param_range_slider, top_k_dropdown, skip_orgs_textbox, include_unknown_params_checkbox, models_data_state], # Add checkbox to inputs
267
  outputs=[plot_output, status_message_md]
268
  )
269