Spaces:
Running
on
Zero
Running
on
Zero
2025-08-01 08:50 🐛
Browse files
app.py
CHANGED
@@ -20,8 +20,6 @@ mean = (0.485, 0.456, 0.406)
|
|
20 |
std = (0.229, 0.224, 0.225)
|
21 |
alpha = 0.8
|
22 |
EPS = 1e-8
|
23 |
-
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
24 |
-
device = torch.device("cuda")
|
25 |
loaded_model = None
|
26 |
current_model_config = {"variant": None, "dataset": None, "metric": None}
|
27 |
|
@@ -78,18 +76,18 @@ def update_model_if_needed(variant_dataset_metric: str):
|
|
78 |
else:
|
79 |
return f"Unknown dataset: {dataset}"
|
80 |
|
81 |
-
|
82 |
-
|
83 |
current_model_config["dataset"] != dataset_name or
|
84 |
current_model_config["metric"] != metric):
|
85 |
|
86 |
-
print(f"
|
87 |
-
loaded_model = load_model(variant=variant, dataset=dataset_name, metric=metric)
|
88 |
current_model_config = {"variant": variant, "dataset": dataset_name, "metric": metric}
|
89 |
-
|
|
|
90 |
else:
|
91 |
-
print(f"
|
92 |
-
return f"Model
|
93 |
|
94 |
|
95 |
# -----------------------------
|
@@ -305,13 +303,13 @@ def predict(image: Image.Image, variant_dataset_metric: str):
|
|
305 |
"""
|
306 |
global loaded_model, current_model_config
|
307 |
|
|
|
|
|
|
|
308 |
# 如果选择的是分割线,返回错误信息
|
309 |
if "━━━━━━" in variant_dataset_metric:
|
310 |
return image, None, None, "⚠️ Please select a valid model configuration", None, None, None
|
311 |
|
312 |
-
# 确保模型正确加载
|
313 |
-
update_model_if_needed(variant_dataset_metric)
|
314 |
-
|
315 |
parts = variant_dataset_metric.split(" @ ")
|
316 |
if len(parts) != 3:
|
317 |
return image, None, None, "❌ Invalid model configuration format", None, None, None
|
@@ -329,6 +327,16 @@ def predict(image: Image.Image, variant_dataset_metric: str):
|
|
329 |
else:
|
330 |
return image, None, None, f"❌ Unknown dataset: {dataset}", None, None, None
|
331 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
332 |
if not hasattr(loaded_model, "input_size"):
|
333 |
if dataset_name == "sha":
|
334 |
loaded_model.input_size = 224
|
@@ -758,9 +766,9 @@ with gr.Blocks(css=css, theme=gr.themes.Soft(), title="ZIP Crowd Counting") as d
|
|
758 |
outputs=[model_status]
|
759 |
)
|
760 |
|
761 |
-
#
|
762 |
demo.load(
|
763 |
-
fn=lambda: f"
|
764 |
outputs=[model_status]
|
765 |
)
|
766 |
|
|
|
20 |
std = (0.229, 0.224, 0.225)
|
21 |
alpha = 0.8
|
22 |
EPS = 1e-8
|
|
|
|
|
23 |
loaded_model = None
|
24 |
current_model_config = {"variant": None, "dataset": None, "metric": None}
|
25 |
|
|
|
76 |
else:
|
77 |
return f"Unknown dataset: {dataset}"
|
78 |
|
79 |
+
# 只更新配置,不在主进程中加载模型
|
80 |
+
if (current_model_config["variant"] != variant or
|
81 |
current_model_config["dataset"] != dataset_name or
|
82 |
current_model_config["metric"] != metric):
|
83 |
|
84 |
+
print(f"Model configuration updated: {variant} @ {dataset} with {metric} metric")
|
|
|
85 |
current_model_config = {"variant": variant, "dataset": dataset_name, "metric": metric}
|
86 |
+
loaded_model = None # 重置模型,将在GPU进程中重新加载
|
87 |
+
return f"Model configuration set: {variant} @ {dataset} ({metric})"
|
88 |
else:
|
89 |
+
print(f"Model configuration unchanged: {variant} @ {dataset} with {metric} metric")
|
90 |
+
return f"Model configuration: {variant} @ {dataset} ({metric})"
|
91 |
|
92 |
|
93 |
# -----------------------------
|
|
|
303 |
"""
|
304 |
global loaded_model, current_model_config
|
305 |
|
306 |
+
# 在GPU进程中定义device
|
307 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
308 |
+
|
309 |
# 如果选择的是分割线,返回错误信息
|
310 |
if "━━━━━━" in variant_dataset_metric:
|
311 |
return image, None, None, "⚠️ Please select a valid model configuration", None, None, None
|
312 |
|
|
|
|
|
|
|
313 |
parts = variant_dataset_metric.split(" @ ")
|
314 |
if len(parts) != 3:
|
315 |
return image, None, None, "❌ Invalid model configuration format", None, None, None
|
|
|
327 |
else:
|
328 |
return image, None, None, f"❌ Unknown dataset: {dataset}", None, None, None
|
329 |
|
330 |
+
# 在GPU进程中加载模型(如果需要)
|
331 |
+
if (loaded_model is None or
|
332 |
+
current_model_config["variant"] != variant or
|
333 |
+
current_model_config["dataset"] != dataset_name or
|
334 |
+
current_model_config["metric"] != metric):
|
335 |
+
|
336 |
+
print(f"Loading model in GPU process: {variant} @ {dataset} with {metric} metric")
|
337 |
+
loaded_model = load_model(variant=variant, dataset=dataset_name, metric=metric)
|
338 |
+
current_model_config = {"variant": variant, "dataset": dataset_name, "metric": metric}
|
339 |
+
|
340 |
if not hasattr(loaded_model, "input_size"):
|
341 |
if dataset_name == "sha":
|
342 |
loaded_model.input_size = 224
|
|
|
766 |
outputs=[model_status]
|
767 |
)
|
768 |
|
769 |
+
# 页面加载时设置默认模型配置(不在主进程中加载模型)
|
770 |
demo.load(
|
771 |
+
fn=lambda: f"✅ {update_model_if_needed('ZIP-B @ NWPU-Crowd @ MAE')}",
|
772 |
outputs=[model_status]
|
773 |
)
|
774 |
|