🐛
Commit 61ba3dd · Parent(s): 9a167d2
Yiming-M committed on 2025-08-01 08:50

Files changed (1)
  1. app.py +22 -14
app.py CHANGED
@@ -20,8 +20,6 @@ mean = (0.485, 0.456, 0.406)
 std = (0.229, 0.224, 0.225)
 alpha = 0.8
 EPS = 1e-8
-# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-device = torch.device("cuda")
 loaded_model = None
 current_model_config = {"variant": None, "dataset": None, "metric": None}
 
@@ -78,18 +76,18 @@ def update_model_if_needed(variant_dataset_metric: str):
     else:
         return f"Unknown dataset: {dataset}"
 
-    if (loaded_model is None or
-        current_model_config["variant"] != variant or
+    # Only update the configuration; do not load the model in the main process
+    if (current_model_config["variant"] != variant or
         current_model_config["dataset"] != dataset_name or
         current_model_config["metric"] != metric):
 
-        print(f"Loading new model: {variant} @ {dataset} with {metric} metric")
-        loaded_model = load_model(variant=variant, dataset=dataset_name, metric=metric)
+        print(f"Model configuration updated: {variant} @ {dataset} with {metric} metric")
         current_model_config = {"variant": variant, "dataset": dataset_name, "metric": metric}
-        return f"Model loaded: {variant} @ {dataset} ({metric})"
+        loaded_model = None  # Reset the model; it will be reloaded in the GPU process
+        return f"Model configuration set: {variant} @ {dataset} ({metric})"
     else:
-        print(f"Using cached model: {variant} @ {dataset} with {metric} metric")
-        return f"Model already loaded: {variant} @ {dataset} ({metric})"
+        print(f"Model configuration unchanged: {variant} @ {dataset} with {metric} metric")
+        return f"Model configuration: {variant} @ {dataset} ({metric})"
 
 
 # -----------------------------
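
With the module-level device pin removed (first hunk) and this helper reduced to bookkeeping, the main Gradio process no longer touches CUDA at all; device selection now happens only where inference actually runs. A minimal sketch of that deferred lookup, written as a hypothetical resolve_device helper that is not part of app.py:

import torch

def resolve_device() -> torch.device:
    # Prefer CUDA only when the process that is actually running inference can
    # see a GPU; otherwise fall back to CPU. Calling this inside the inference
    # path, rather than at import time, keeps the main process off CUDA.
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

The next hunk inlines the same expression directly inside predict.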
@@ -305,13 +303,13 @@ def predict(image: Image.Image, variant_dataset_metric: str):
     """
     global loaded_model, current_model_config
 
+    # Define the device inside the GPU process
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
     # If a separator line is selected, return an error message
     if "━━━━━━" in variant_dataset_metric:
         return image, None, None, "⚠️ Please select a valid model configuration", None, None, None
 
-    # Make sure the model is loaded correctly
-    update_model_if_needed(variant_dataset_metric)
-
     parts = variant_dataset_metric.split(" @ ")
     if len(parts) != 3:
         return image, None, None, "❌ Invalid model configuration format", None, None, None
@@ -329,6 +327,16 @@ def predict(image: Image.Image, variant_dataset_metric: str):
     else:
         return image, None, None, f"❌ Unknown dataset: {dataset}", None, None, None
 
+    # Load the model in the GPU process (if needed)
+    if (loaded_model is None or
+        current_model_config["variant"] != variant or
+        current_model_config["dataset"] != dataset_name or
+        current_model_config["metric"] != metric):
+
+        print(f"Loading model in GPU process: {variant} @ {dataset} with {metric} metric")
+        loaded_model = load_model(variant=variant, dataset=dataset_name, metric=metric)
+        current_model_config = {"variant": variant, "dataset": dataset_name, "metric": metric}
+
     if not hasattr(loaded_model, "input_size"):
         if dataset_name == "sha":
             loaded_model.input_size = 224
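
Together with the device lookup above, this hunk turns predict into a lazily populated model cache. A short sketch of the pattern, assuming inference is scheduled through Hugging Face ZeroGPU's @spaces.GPU decorator (the decorator itself is not visible in this diff); run_inference, _loaded_model and _current_config are hypothetical names, and load_model is app.py's own loader:

import spaces   # ZeroGPU: the decorated function runs in a separate GPU worker
import torch

_loaded_model = None      # model cache, populated inside the GPU call
_current_config = None    # (variant, dataset, metric) of the cached model

@spaces.GPU
def run_inference(pixel_values: torch.Tensor, variant: str, dataset_name: str, metric: str):
    # Everything CUDA-related happens here, never in the main Gradio process.
    global _loaded_model, _current_config

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    config = (variant, dataset_name, metric)
    if _loaded_model is None or _current_config != config:
        # load_model is app.py's own loader; the arguments mirror the diff above.
        _loaded_model = load_model(variant=variant, dataset=dataset_name, metric=metric)
        _loaded_model.to(device).eval()
        _current_config = config

    with torch.no_grad():
        return _loaded_model(pixel_values.to(device))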
@@ -758,9 +766,9 @@ with gr.Blocks(css=css, theme=gr.themes.Soft(), title="ZIP Crowd Counting") as demo:
         outputs=[model_status]
     )
 
-    # Automatically load the default model when the page loads
+    # Set the default model configuration on page load (the model is not loaded in the main process)
     demo.load(
-        fn=lambda: f"🔄 {update_model_if_needed('ZIP-B @ NWPU-Crowd @ MAE')}",
+        fn=lambda: f" {update_model_if_needed('ZIP-B @ NWPU-Crowd @ MAE')}",
         outputs=[model_status]
     )
 
 
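The same CPU-only status update used in demo.load above could back other UI events as well. A hypothetical wiring sketch (model_dropdown is an assumed component name not shown in this diff; model_status and update_model_if_needed come from app.py, and .change is the standard Gradio event listener):

# Hypothetical: let the model selector reuse the lightweight, CPU-only
# configuration update, mirroring the demo.load call above.
model_dropdown.change(
    fn=update_model_if_needed,
    inputs=[model_dropdown],
    outputs=[model_status],
)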