Yiming-M commited on
Commit
ccc542d
Β·
1 Parent(s): 449e3d1

2025-07-31 21:23 πŸ›

Browse files

Fixed a bug in app.py

Files changed (1) hide show
  1. app.py +20 -21
app.py CHANGED
@@ -36,7 +36,7 @@ pretrained_models = [
36
  # -----------------------------
37
  def load_model(variant: str, dataset: str = "ShanghaiTech B", metric: str = "mae"):
38
  """ Load the model weights from the Hugging Face Hub."""
39
- global loaded_model
40
  # Build model
41
 
42
  model_info_path = hf_hub_download(
@@ -241,31 +241,30 @@ def predict(image: Image.Image, variant_dataset: str, metric: str):
241
  Given an input image, preprocess it, run the model to obtain a density map,
242
  compute the total crowd count, and prepare the density map for display.
243
  """
244
- global loaded_model
245
  variant, dataset = variant_dataset.split(" @ ")
246
 
247
- if loaded_model is None:
 
 
 
 
 
 
 
248
 
249
- if dataset == "ShanghaiTech A":
250
- dataset_name = "sha"
251
- elif dataset == "ShanghaiTech B":
252
- dataset_name = "shb"
253
- elif dataset == "UCF-QNRF":
254
- dataset_name = "qnrf"
255
- elif dataset == "NWPU-Crowd":
256
- dataset_name = "nwpu"
257
-
258
  load_model(variant=variant, dataset=dataset_name, metric=metric)
259
 
260
- if not hasattr(loaded_model, "input_size"):
261
- if dataset_name == "sha":
262
- loaded_model.input_size = 224
263
- elif dataset_name == "shb":
264
- loaded_model.input_size = 448
265
- elif dataset_name == "qnrf":
266
- loaded_model.input_size = 672
267
- elif dataset_name == "nwpu":
268
- loaded_model.input_size = 672
269
 
270
  loaded_model.to(device)
271
 
 
36
  # -----------------------------
37
  def load_model(variant: str, dataset: str = "ShanghaiTech B", metric: str = "mae"):
38
  """ Load the model weights from the Hugging Face Hub."""
39
+ # global loaded_model
40
  # Build model
41
 
42
  model_info_path = hf_hub_download(
 
241
  Given an input image, preprocess it, run the model to obtain a density map,
242
  compute the total crowd count, and prepare the density map for display.
243
  """
244
+ # global loaded_model
245
  variant, dataset = variant_dataset.split(" @ ")
246
 
247
+ if dataset == "ShanghaiTech A":
248
+ dataset_name = "sha"
249
+ elif dataset == "ShanghaiTech B":
250
+ dataset_name = "shb"
251
+ elif dataset == "UCF-QNRF":
252
+ dataset_name = "qnrf"
253
+ elif dataset == "NWPU-Crowd":
254
+ dataset_name = "nwpu"
255
 
256
+ if loaded_model is None:
 
 
 
 
 
 
 
 
257
  load_model(variant=variant, dataset=dataset_name, metric=metric)
258
 
259
+ if not hasattr(loaded_model, "input_size"):
260
+ if dataset_name == "sha":
261
+ loaded_model.input_size = 224
262
+ elif dataset_name == "shb":
263
+ loaded_model.input_size = 448
264
+ elif dataset_name == "qnrf":
265
+ loaded_model.input_size = 672
266
+ elif dataset_name == "nwpu":
267
+ loaded_model.input_size = 672
268
 
269
  loaded_model.to(device)
270