Spaces:
Running
on
Zero
Running
on
Zero
2025-07-31 21:23 π
Browse filesFixed a bug in app.py
app.py
CHANGED
@@ -36,7 +36,7 @@ pretrained_models = [
|
|
36 |
# -----------------------------
|
37 |
def load_model(variant: str, dataset: str = "ShanghaiTech B", metric: str = "mae"):
|
38 |
""" Load the model weights from the Hugging Face Hub."""
|
39 |
-
global loaded_model
|
40 |
# Build model
|
41 |
|
42 |
model_info_path = hf_hub_download(
|
@@ -241,31 +241,30 @@ def predict(image: Image.Image, variant_dataset: str, metric: str):
|
|
241 |
Given an input image, preprocess it, run the model to obtain a density map,
|
242 |
compute the total crowd count, and prepare the density map for display.
|
243 |
"""
|
244 |
-
global loaded_model
|
245 |
variant, dataset = variant_dataset.split(" @ ")
|
246 |
|
247 |
-
if
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
248 |
|
249 |
-
|
250 |
-
dataset_name = "sha"
|
251 |
-
elif dataset == "ShanghaiTech B":
|
252 |
-
dataset_name = "shb"
|
253 |
-
elif dataset == "UCF-QNRF":
|
254 |
-
dataset_name = "qnrf"
|
255 |
-
elif dataset == "NWPU-Crowd":
|
256 |
-
dataset_name = "nwpu"
|
257 |
-
|
258 |
load_model(variant=variant, dataset=dataset_name, metric=metric)
|
259 |
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
|
269 |
|
270 |
loaded_model.to(device)
|
271 |
|
|
|
36 |
# -----------------------------
|
37 |
def load_model(variant: str, dataset: str = "ShanghaiTech B", metric: str = "mae"):
|
38 |
""" Load the model weights from the Hugging Face Hub."""
|
39 |
+
# global loaded_model
|
40 |
# Build model
|
41 |
|
42 |
model_info_path = hf_hub_download(
|
|
|
241 |
Given an input image, preprocess it, run the model to obtain a density map,
|
242 |
compute the total crowd count, and prepare the density map for display.
|
243 |
"""
|
244 |
+
# global loaded_model
|
245 |
variant, dataset = variant_dataset.split(" @ ")
|
246 |
|
247 |
+
if dataset == "ShanghaiTech A":
|
248 |
+
dataset_name = "sha"
|
249 |
+
elif dataset == "ShanghaiTech B":
|
250 |
+
dataset_name = "shb"
|
251 |
+
elif dataset == "UCF-QNRF":
|
252 |
+
dataset_name = "qnrf"
|
253 |
+
elif dataset == "NWPU-Crowd":
|
254 |
+
dataset_name = "nwpu"
|
255 |
|
256 |
+
if loaded_model is None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
257 |
load_model(variant=variant, dataset=dataset_name, metric=metric)
|
258 |
|
259 |
+
if not hasattr(loaded_model, "input_size"):
|
260 |
+
if dataset_name == "sha":
|
261 |
+
loaded_model.input_size = 224
|
262 |
+
elif dataset_name == "shb":
|
263 |
+
loaded_model.input_size = 448
|
264 |
+
elif dataset_name == "qnrf":
|
265 |
+
loaded_model.input_size = 672
|
266 |
+
elif dataset_name == "nwpu":
|
267 |
+
loaded_model.input_size = 672
|
268 |
|
269 |
loaded_model.to(device)
|
270 |
|