Yiming-M commited on
Commit
ef5adb7
·
1 Parent(s): 277670a

2025-07-31 21:38 🐛

Browse files
Files changed (1) hide show
  1. app.py +11 -11
app.py CHANGED
@@ -277,6 +277,10 @@ def predict(image: Image.Image, variant_dataset: str, metric: str):
277
  compute the total crowd count, and prepare the density map for display.
278
  """
279
  global loaded_model, current_model_config
 
 
 
 
280
  variant, dataset = variant_dataset.split(" @ ")
281
 
282
  if dataset == "ShanghaiTech A":
@@ -288,17 +292,6 @@ def predict(image: Image.Image, variant_dataset: str, metric: str):
288
  elif dataset == "NWPU-Crowd":
289
  dataset_name = "nwpu"
290
 
291
- if (loaded_model is None or
292
- current_model_config["variant"] != variant or
293
- current_model_config["dataset"] != dataset_name or
294
- current_model_config["metric"] != metric):
295
-
296
- print(f"Loading new model: {variant} @ {dataset} with {metric} metric")
297
- loaded_model = load_model(variant=variant, dataset=dataset_name, metric=metric)
298
- current_model_config = {"variant": variant, "dataset": dataset_name, "metric": metric}
299
- else:
300
- print(f"Using cached model: {variant} @ {dataset} with {metric} metric")
301
-
302
  if not hasattr(loaded_model, "input_size"):
303
  if dataset_name == "sha":
304
  loaded_model.input_size = 224
@@ -470,6 +463,7 @@ with gr.Blocks() as demo:
470
  output_sampling_zero_map = gr.Image(label="Sampling Zero Map", type="pil")
471
  output_complete_zero_map = gr.Image(label="Complete Zero Map", type="pil")
472
 
 
473
  def on_model_change(variant_dataset, metric):
474
  return update_model_if_needed(variant_dataset, metric)
475
 
@@ -485,6 +479,12 @@ with gr.Blocks() as demo:
485
  outputs=[model_status]
486
  )
487
 
 
 
 
 
 
 
488
  submit_btn.click(
489
  fn=predict,
490
  inputs=[input_img, model_dropdown, metric_dropdown],
 
277
  compute the total crowd count, and prepare the density map for display.
278
  """
279
  global loaded_model, current_model_config
280
+
281
+ # 确保模型正确加载
282
+ update_model_if_needed(variant_dataset, metric)
283
+
284
  variant, dataset = variant_dataset.split(" @ ")
285
 
286
  if dataset == "ShanghaiTech A":
 
292
  elif dataset == "NWPU-Crowd":
293
  dataset_name = "nwpu"
294
 
 
 
 
 
 
 
 
 
 
 
 
295
  if not hasattr(loaded_model, "input_size"):
296
  if dataset_name == "sha":
297
  loaded_model.input_size = 224
 
463
  output_sampling_zero_map = gr.Image(label="Sampling Zero Map", type="pil")
464
  output_complete_zero_map = gr.Image(label="Complete Zero Map", type="pil")
465
 
466
+ # 当模型或度量参数变化时,自动更新模型
467
  def on_model_change(variant_dataset, metric):
468
  return update_model_if_needed(variant_dataset, metric)
469
 
 
479
  outputs=[model_status]
480
  )
481
 
482
+ # 页面加载时自动加载默认模型
483
+ demo.load(
484
+ fn=lambda: update_model_if_needed("ZIP-B @ NWPU-Crowd", "mae"),
485
+ outputs=[model_status]
486
+ )
487
+
488
  submit_btn.click(
489
  fn=predict,
490
  inputs=[input_img, model_dropdown, metric_dropdown],