2nzi commited on
Commit
be41de4
·
1 Parent(s): 17a5b18

update api.py

Browse files
Files changed (1) hide show
  1. api.py +66 -21
api.py CHANGED
@@ -13,6 +13,7 @@ import torchvision.transforms as T
13
  import torchvision.transforms.functional as f
14
  import yaml
15
  from tqdm import tqdm
 
16
 
17
  from get_camera_params import get_camera_parameters
18
 
@@ -40,6 +41,7 @@ app.add_middleware(
40
  # Paramètres par défaut pour l'inférence
41
  WEIGHTS_KP = "models/SV_FT_TSWC_kp"
42
  WEIGHTS_LINE = "models/SV_FT_TSWC_lines"
 
43
  DEVICE = "cpu"
44
  KP_THRESHOLD = 0.15
45
  LINE_THRESHOLD = 0.15
@@ -49,33 +51,76 @@ FRAME_STEP = 5
49
  # Cache pour les modèles (éviter de les recharger à chaque requête)
50
  _models_cache = None
51
 
 
 
 
 
 
52
  def load_inference_models():
53
- """Charge les modèles d'inférence (avec cache)"""
54
  global _models_cache
55
 
56
  if _models_cache is not None:
57
  return _models_cache
58
 
59
- device = torch.device(DEVICE if torch.cuda.is_available() else 'cpu')
60
-
61
- # Charger les configurations
62
- cfg = yaml.safe_load(open("config/hrnetv2_w48.yaml", 'r'))
63
- cfg_l = yaml.safe_load(open("config/hrnetv2_w48_l.yaml", 'r'))
64
-
65
- # Modèle keypoints
66
- model = get_cls_net(cfg)
67
- model.load_state_dict(torch.load(WEIGHTS_KP, map_location=device))
68
- model.to(device)
69
- model.eval()
70
-
71
- # Modèle lignes
72
- model_l = get_cls_net_l(cfg_l)
73
- model_l.load_state_dict(torch.load(WEIGHTS_LINE, map_location=device))
74
- model_l.to(device)
75
- model_l.eval()
76
-
77
- _models_cache = (model, model_l, device)
78
- return _models_cache
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
 
80
  def process_frame_inference(frame, model, model_l, device, frame_width, frame_height):
81
  """Traite une frame et retourne les paramètres de caméra"""
 
13
  import torchvision.transforms.functional as f
14
  import yaml
15
  from tqdm import tqdm
16
+ from huggingface_hub import hf_hub_download
17
 
18
  from get_camera_params import get_camera_parameters
19
 
 
41
  # Paramètres par défaut pour l'inférence
42
  WEIGHTS_KP = "models/SV_FT_TSWC_kp"
43
  WEIGHTS_LINE = "models/SV_FT_TSWC_lines"
44
+ # DEVICE = "cuda:0"
45
  DEVICE = "cpu"
46
  KP_THRESHOLD = 0.15
47
  LINE_THRESHOLD = 0.15
 
51
  # Cache pour les modèles (éviter de les recharger à chaque requête)
52
  _models_cache = None
53
 
54
+ # Paramètres pour HF Hub
55
+ HF_MODEL_REPO = "2nzi/SV_FT_TSWC_kp" # Remplacez par votre repo
56
+ WEIGHTS_KP_FILE = "SV_FT_TSWC_kp" # Nom du fichier dans le repo
57
+ WEIGHTS_LINE_FILE = "SV_FT_TSWC_lines" # Nom du fichier dans le repo
58
+
59
  def load_inference_models():
60
+ """Charge les modèles d'inférence depuis Hugging Face Hub"""
61
  global _models_cache
62
 
63
  if _models_cache is not None:
64
  return _models_cache
65
 
66
+ try:
67
+ # Device detection
68
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
69
+ print(f"Using device: {device}")
70
+
71
+ # Télécharger les modèles depuis HF Hub
72
+ print("Téléchargement des modèles depuis Hugging Face Hub...")
73
+
74
+ weights_kp_path = hf_hub_download(
75
+ repo_id=HF_MODEL_REPO,
76
+ filename=WEIGHTS_KP_FILE,
77
+ cache_dir="./hf_cache"
78
+ )
79
+
80
+ weights_line_path = hf_hub_download(
81
+ repo_id=HF_MODEL_REPO,
82
+ filename=WEIGHTS_LINE_FILE,
83
+ cache_dir="./hf_cache"
84
+ )
85
+
86
+ print(f"Modèles téléchargés:")
87
+ print(f" - Keypoints: {weights_kp_path}")
88
+ print(f" - Lines: {weights_line_path}")
89
+
90
+ # Vérifier l'existence des fichiers de configuration
91
+ config_files = ["config/hrnetv2_w48.yaml", "config/hrnetv2_w48_l.yaml"]
92
+ for config_file in config_files:
93
+ if not os.path.exists(config_file):
94
+ raise FileNotFoundError(f"Fichier de configuration manquant: {config_file}")
95
+
96
+ # Charger les configurations
97
+ with open("config/hrnetv2_w48.yaml", 'r') as f:
98
+ cfg = yaml.safe_load(f)
99
+ with open("config/hrnetv2_w48_l.yaml", 'r') as f:
100
+ cfg_l = yaml.safe_load(f)
101
+
102
+ # Modèle keypoints
103
+ model = get_cls_net(cfg)
104
+ model.load_state_dict(torch.load(weights_kp_path, map_location=device))
105
+ model.to(device)
106
+ model.eval()
107
+
108
+ # Modèle lignes
109
+ model_l = get_cls_net_l(cfg_l)
110
+ model_l.load_state_dict(torch.load(weights_line_path, map_location=device))
111
+ model_l.to(device)
112
+ model_l.eval()
113
+
114
+ _models_cache = (model, model_l, device)
115
+ print("✅ Modèles chargés avec succès depuis HF Hub!")
116
+ return _models_cache
117
+
118
+ except Exception as e:
119
+ print(f"❌ Erreur lors du chargement des modèles: {e}")
120
+ raise HTTPException(
121
+ status_code=503,
122
+ detail=f"Modèles non disponibles: {str(e)}. Veuillez réessayer plus tard."
123
+ )
124
 
125
  def process_frame_inference(frame, model, model_l, device, frame_width, frame_height):
126
  """Traite une frame et retourne les paramètres de caméra"""