Yiming-M commited on
Commit
277670a
Β·
1 Parent(s): 99459bc

2025-07-31 21:34 πŸ›

Browse files
Files changed (1) hide show
  1. app.py +65 -2
app.py CHANGED
@@ -22,6 +22,7 @@ alpha = 0.8
22
  EPS = 1e-8
23
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
  loaded_model = None
 
25
 
26
  pretrained_models = [
27
  "ZIP-B @ ShanghaiTech A", "ZIP-B @ ShanghaiTech B", "ZIP-B @ UCF-QNRF", "ZIP-B @ NWPU-Crowd",
@@ -31,6 +32,39 @@ pretrained_models = [
31
  "ZIP-P @ ShanghaiTech A", "ZIP-P @ ShanghaiTech B", "ZIP-P @ UCF-QNRF"
32
  ]
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  # -----------------------------
35
  # Define the model architecture
36
  # -----------------------------
@@ -242,7 +276,7 @@ def predict(image: Image.Image, variant_dataset: str, metric: str):
242
  Given an input image, preprocess it, run the model to obtain a density map,
243
  compute the total crowd count, and prepare the density map for display.
244
  """
245
- # global loaded_model
246
  variant, dataset = variant_dataset.split(" @ ")
247
 
248
  if dataset == "ShanghaiTech A":
@@ -254,8 +288,16 @@ def predict(image: Image.Image, variant_dataset: str, metric: str):
254
  elif dataset == "NWPU-Crowd":
255
  dataset_name = "nwpu"
256
 
257
- if loaded_model is None:
 
 
 
 
 
258
  loaded_model = load_model(variant=variant, dataset=dataset_name, metric=metric)
 
 
 
259
 
260
  if not hasattr(loaded_model, "input_size"):
261
  if dataset_name == "sha":
@@ -409,6 +451,12 @@ with gr.Blocks() as demo:
409
  label="Select Best Metric"
410
  )
411
 
 
 
 
 
 
 
412
  input_img = gr.Image(label="Input Image", sources=["upload", "clipboard"], type="pil")
413
  submit_btn = gr.Button("Predict")
414
 
@@ -422,6 +470,21 @@ with gr.Blocks() as demo:
422
  output_sampling_zero_map = gr.Image(label="Sampling Zero Map", type="pil")
423
  output_complete_zero_map = gr.Image(label="Complete Zero Map", type="pil")
424
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
425
  submit_btn.click(
426
  fn=predict,
427
  inputs=[input_img, model_dropdown, metric_dropdown],
 
22
  EPS = 1e-8
23
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
  loaded_model = None
25
+ current_model_config = {"variant": None, "dataset": None, "metric": None}
26
 
27
  pretrained_models = [
28
  "ZIP-B @ ShanghaiTech A", "ZIP-B @ ShanghaiTech B", "ZIP-B @ UCF-QNRF", "ZIP-B @ NWPU-Crowd",
 
32
  "ZIP-P @ ShanghaiTech A", "ZIP-P @ ShanghaiTech B", "ZIP-P @ UCF-QNRF"
33
  ]
34
 
35
+ # -----------------------------
36
+ # Model management functions
37
+ # -----------------------------
38
+ def update_model_if_needed(variant_dataset: str, metric: str):
39
+ """
40
+ Load a new model only if the configuration has changed.
41
+ """
42
+ global loaded_model, current_model_config
43
+ variant, dataset = variant_dataset.split(" @ ")
44
+
45
+ if dataset == "ShanghaiTech A":
46
+ dataset_name = "sha"
47
+ elif dataset == "ShanghaiTech B":
48
+ dataset_name = "shb"
49
+ elif dataset == "UCF-QNRF":
50
+ dataset_name = "qnrf"
51
+ elif dataset == "NWPU-Crowd":
52
+ dataset_name = "nwpu"
53
+
54
+ if (loaded_model is None or
55
+ current_model_config["variant"] != variant or
56
+ current_model_config["dataset"] != dataset_name or
57
+ current_model_config["metric"] != metric):
58
+
59
+ print(f"Loading new model: {variant} @ {dataset} with {metric} metric")
60
+ loaded_model = load_model(variant=variant, dataset=dataset_name, metric=metric)
61
+ current_model_config = {"variant": variant, "dataset": dataset_name, "metric": metric}
62
+ return f"Model loaded: {variant} @ {dataset} ({metric})"
63
+ else:
64
+ print(f"Using cached model: {variant} @ {dataset} with {metric} metric")
65
+ return f"Model already loaded: {variant} @ {dataset} ({metric})"
66
+
67
+
68
  # -----------------------------
69
  # Define the model architecture
70
  # -----------------------------
 
276
  Given an input image, preprocess it, run the model to obtain a density map,
277
  compute the total crowd count, and prepare the density map for display.
278
  """
279
+ global loaded_model, current_model_config
280
  variant, dataset = variant_dataset.split(" @ ")
281
 
282
  if dataset == "ShanghaiTech A":
 
288
  elif dataset == "NWPU-Crowd":
289
  dataset_name = "nwpu"
290
 
291
+ if (loaded_model is None or
292
+ current_model_config["variant"] != variant or
293
+ current_model_config["dataset"] != dataset_name or
294
+ current_model_config["metric"] != metric):
295
+
296
+ print(f"Loading new model: {variant} @ {dataset} with {metric} metric")
297
  loaded_model = load_model(variant=variant, dataset=dataset_name, metric=metric)
298
+ current_model_config = {"variant": variant, "dataset": dataset_name, "metric": metric}
299
+ else:
300
+ print(f"Using cached model: {variant} @ {dataset} with {metric} metric")
301
 
302
  if not hasattr(loaded_model, "input_size"):
303
  if dataset_name == "sha":
 
451
  label="Select Best Metric"
452
  )
453
 
454
+ model_status = gr.Textbox(
455
+ label="Model Status",
456
+ value="No model loaded",
457
+ interactive=False
458
+ )
459
+
460
  input_img = gr.Image(label="Input Image", sources=["upload", "clipboard"], type="pil")
461
  submit_btn = gr.Button("Predict")
462
 
 
470
  output_sampling_zero_map = gr.Image(label="Sampling Zero Map", type="pil")
471
  output_complete_zero_map = gr.Image(label="Complete Zero Map", type="pil")
472
 
473
+ def on_model_change(variant_dataset, metric):
474
+ return update_model_if_needed(variant_dataset, metric)
475
+
476
+ model_dropdown.change(
477
+ fn=on_model_change,
478
+ inputs=[model_dropdown, metric_dropdown],
479
+ outputs=[model_status]
480
+ )
481
+
482
+ metric_dropdown.change(
483
+ fn=on_model_change,
484
+ inputs=[model_dropdown, metric_dropdown],
485
+ outputs=[model_status]
486
+ )
487
+
488
  submit_btn.click(
489
  fn=predict,
490
  inputs=[input_img, model_dropdown, metric_dropdown],