2nzi commited on
Commit
9a5e1bd
·
1 Parent(s): 1892c37

update params api.py

Browse files
Files changed (1) hide show
  1. api.py +484 -483
api.py CHANGED
@@ -1,484 +1,485 @@
1
- from fastapi import FastAPI, HTTPException, UploadFile, File, Form
2
- from fastapi.middleware.cors import CORSMiddleware
3
- from pydantic import BaseModel
4
- from typing import Dict, List, Any, Optional
5
- import json
6
- import tempfile
7
- import os
8
- from PIL import Image
9
- import numpy as np
10
- import cv2
11
- import torch
12
- import torchvision.transforms as T
13
- import torchvision.transforms.functional as f
14
- import yaml
15
- from tqdm import tqdm
16
-
17
- from get_camera_params import get_camera_parameters
18
-
19
- # Imports pour l'inférence automatique
20
- from model.cls_hrnet import get_cls_net
21
- from model.cls_hrnet_l import get_cls_net as get_cls_net_l
22
- from utils.utils_calib import FramebyFrameCalib
23
- from utils.utils_heatmap import get_keypoints_from_heatmap_batch_maxpool, get_keypoints_from_heatmap_batch_maxpool_l, complete_keypoints, coords_to_dict
24
-
25
- app = FastAPI(
26
- title="Football Vision Calibration API",
27
- description="API pour la calibration de caméras à partir de lignes de terrain de football",
28
- version="1.0.0"
29
- )
30
-
31
- # Configuration CORS pour autoriser les requêtes depuis le frontend
32
- app.add_middleware(
33
- CORSMiddleware,
34
- allow_origins=["*"], # En production, spécifiez les domaines autorisés
35
- allow_credentials=True,
36
- allow_methods=["*"],
37
- allow_headers=["*"],
38
- )
39
-
40
- # Paramètres par défaut pour l'inférence
41
- WEIGHTS_KP = "models/SV_FT_TSWC_kp"
42
- WEIGHTS_LINE = "models/SV_FT_TSWC_lines"
43
- DEVICE = "cuda:0"
44
- KP_THRESHOLD = 0.15
45
- LINE_THRESHOLD = 0.15
46
- PNL_REFINE = True
47
- FRAME_STEP = 5
48
-
49
- # Cache pour les modèles (éviter de les recharger à chaque requête)
50
- _models_cache = None
51
-
52
- def load_inference_models():
53
- """Charge les modèles d'inférence (avec cache)"""
54
- global _models_cache
55
-
56
- if _models_cache is not None:
57
- return _models_cache
58
-
59
- device = torch.device(DEVICE if torch.cuda.is_available() else 'cpu')
60
-
61
- # Charger les configurations
62
- cfg = yaml.safe_load(open("config/hrnetv2_w48.yaml", 'r'))
63
- cfg_l = yaml.safe_load(open("config/hrnetv2_w48_l.yaml", 'r'))
64
-
65
- # Modèle keypoints
66
- model = get_cls_net(cfg)
67
- model.load_state_dict(torch.load(WEIGHTS_KP, map_location=device))
68
- model.to(device)
69
- model.eval()
70
-
71
- # Modèle lignes
72
- model_l = get_cls_net_l(cfg_l)
73
- model_l.load_state_dict(torch.load(WEIGHTS_LINE, map_location=device))
74
- model_l.to(device)
75
- model_l.eval()
76
-
77
- _models_cache = (model, model_l, device)
78
- return _models_cache
79
-
80
- def process_frame_inference(frame, model, model_l, device, frame_width, frame_height):
81
- """Traite une frame et retourne les paramètres de caméra"""
82
- transform = T.Resize((540, 960))
83
-
84
- # Préparer la frame pour l'inférence
85
- frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
86
- frame_pil = Image.fromarray(frame_rgb)
87
- frame_tensor = f.to_tensor(frame_pil).float().unsqueeze(0)
88
-
89
- if frame_tensor.size()[-1] != 960:
90
- frame_tensor = transform(frame_tensor)
91
-
92
- frame_tensor = frame_tensor.to(device)
93
- b, c, h, w = frame_tensor.size()
94
-
95
- # Inférence
96
- with torch.no_grad():
97
- heatmaps = model(frame_tensor)
98
- heatmaps_l = model_l(frame_tensor)
99
-
100
- # Extraire les keypoints et lignes
101
- kp_coords = get_keypoints_from_heatmap_batch_maxpool(heatmaps[:,:-1,:,:])
102
- line_coords = get_keypoints_from_heatmap_batch_maxpool_l(heatmaps_l[:,:-1,:,:])
103
- kp_dict = coords_to_dict(kp_coords, threshold=KP_THRESHOLD)
104
- lines_dict = coords_to_dict(line_coords, threshold=LINE_THRESHOLD)
105
- kp_dict, lines_dict = complete_keypoints(kp_dict[0], lines_dict[0], w=w, h=h, normalize=True)
106
-
107
- # Calibration
108
- cam = FramebyFrameCalib(iwidth=frame_width, iheight=frame_height, denormalize=True)
109
- cam.update(kp_dict, lines_dict)
110
- final_params_dict = cam.heuristic_voting(refine_lines=PNL_REFINE)
111
-
112
- return final_params_dict
113
-
114
- # Modèles Pydantic pour la validation des données
115
- class Point(BaseModel):
116
- x: float
117
- y: float
118
-
119
- class LinePolygon(BaseModel):
120
- points: List[Point]
121
-
122
- class CalibrationRequest(BaseModel):
123
- lines: Dict[str, List[Point]]
124
-
125
- class CalibrationResponse(BaseModel):
126
- status: str
127
- camera_parameters: Dict[str, Any]
128
- input_lines: Dict[str, List[Point]]
129
- message: str
130
-
131
- class InferenceImageResponse(BaseModel):
132
- status: str
133
- camera_parameters: Optional[Dict[str, Any]]
134
- image_info: Dict[str, Any]
135
- message: str
136
-
137
- class InferenceVideoResponse(BaseModel):
138
- status: str
139
- camera_parameters: List[Dict[str, Any]]
140
- video_info: Dict[str, Any]
141
- frames_processed: int
142
- message: str
143
-
144
- @app.get("/")
145
- async def root():
146
- return {
147
- "message": "Football Vision Calibration API",
148
- "version": "1.0.0",
149
- "endpoints": {
150
- "/calibrate": "POST - Calibrer une caméra à partir d'une image et de lignes",
151
- "/inference/image": "POST - Extraire les paramètres de caméra d'une image automatiquement",
152
- "/inference/video": "POST - Extraire les paramètres de caméra d'une vidéo automatiquement",
153
- "/health": "GET - Vérifier l'état de l'API"
154
- }
155
- }
156
-
157
- @app.get("/health")
158
- async def health_check():
159
- return {"status": "healthy", "message": "API is running"}
160
-
161
- @app.post("/calibrate", response_model=CalibrationResponse)
162
- async def calibrate_camera(
163
- image: UploadFile = File(..., description="Image du terrain de football"),
164
- lines_data: str = Form(..., description="JSON des lignes du terrain")
165
- ):
166
- """
167
- Calibrer une caméra à partir d'une image et des lignes du terrain.
168
-
169
- Args:
170
- image: Image du terrain de football (formats: jpg, jpeg, png)
171
- lines_data: JSON contenant les lignes du terrain au format:
172
- {"nom_ligne": [{"x": float, "y": float}, ...], ...}
173
-
174
- Returns:
175
- Paramètres de calibration de la caméra et lignes d'entrée
176
- """
177
- try:
178
- # Validation du format d'image - version robuste
179
- content_type = getattr(image, 'content_type', None) or ""
180
- filename = getattr(image, 'filename', "") or ""
181
-
182
- # Vérifier le type MIME ou l'extension du fichier
183
- image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif']
184
- is_image_content = content_type.startswith('image/') if content_type else False
185
- is_image_extension = any(filename.lower().endswith(ext) for ext in image_extensions)
186
-
187
- if not is_image_content and not is_image_extension:
188
- raise HTTPException(
189
- status_code=400,
190
- detail=f"Le fichier doit être une image. Type détecté: {content_type}, Fichier: {filename}"
191
- )
192
-
193
- # Parse des données de lignes
194
- try:
195
- lines_dict = json.loads(lines_data)
196
- except json.JSONDecodeError:
197
- raise HTTPException(status_code=400, detail="Format JSON invalide pour les lignes")
198
-
199
- # Validation de la structure des lignes
200
- validated_lines = {}
201
- for line_name, points in lines_dict.items():
202
- if not isinstance(points, list):
203
- raise HTTPException(
204
- status_code=400,
205
- detail=f"Les points de la ligne '{line_name}' doivent être une liste"
206
- )
207
-
208
- validated_points = []
209
- for i, point in enumerate(points):
210
- if not isinstance(point, dict) or 'x' not in point or 'y' not in point:
211
- raise HTTPException(
212
- status_code=400,
213
- detail=f"Point {i} de la ligne '{line_name}' doit avoir les clés 'x' et 'y'"
214
- )
215
- try:
216
- validated_points.append({
217
- "x": float(point['x']),
218
- "y": float(point['y'])
219
- })
220
- except (ValueError, TypeError):
221
- raise HTTPException(
222
- status_code=400,
223
- detail=f"Coordonnées invalides pour le point {i} de la ligne '{line_name}'"
224
- )
225
-
226
- validated_lines[line_name] = validated_points
227
-
228
- # Sauvegarde temporaire de l'image
229
- file_extension = os.path.splitext(filename)[1] if filename else '.jpg'
230
- with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as temp_file:
231
- content = await image.read()
232
- temp_file.write(content)
233
- temp_image_path = temp_file.name
234
-
235
- try:
236
- # Validation de l'image
237
- pil_image = Image.open(temp_image_path)
238
- pil_image.verify() # Vérification de l'intégrité de l'image
239
-
240
- # Calibration de la caméra
241
- camera_params = get_camera_parameters(temp_image_path, validated_lines)
242
-
243
- # Formatage de la réponse
244
- response = CalibrationResponse(
245
- status="success",
246
- camera_parameters=camera_params,
247
- input_lines=validated_lines,
248
- message="Calibration réussie"
249
- )
250
-
251
- return response
252
-
253
- except Exception as e:
254
- raise HTTPException(
255
- status_code=500,
256
- detail=f"Erreur lors de la calibration: {str(e)}"
257
- )
258
-
259
- finally:
260
- # Nettoyage du fichier temporaire
261
- if os.path.exists(temp_image_path):
262
- os.unlink(temp_image_path)
263
-
264
- except HTTPException:
265
- raise
266
- except Exception as e:
267
- raise HTTPException(status_code=500, detail=f"Erreur interne: {str(e)}")
268
-
269
- @app.post("/inference/image", response_model=InferenceImageResponse)
270
- async def inference_image(
271
- image: UploadFile = File(..., description="Image du terrain de football"),
272
- kp_threshold: float = Form(KP_THRESHOLD, description="Seuil pour les keypoints"),
273
- line_threshold: float = Form(LINE_THRESHOLD, description="Seuil pour les lignes")
274
- ):
275
- """
276
- Extraire automatiquement les paramètres de caméra à partir d'une image.
277
-
278
- Args:
279
- image: Image du terrain de football (formats: jpg, jpeg, png)
280
- kp_threshold: Seuil pour la détection des keypoints (défaut: 0.15)
281
- line_threshold: Seuil pour la détection des lignes (défaut: 0.15)
282
-
283
- Returns:
284
- Paramètres de calibration de la caméra extraits automatiquement
285
- """
286
- try:
287
- # Validation du format d'image - version robuste
288
- content_type = getattr(image, 'content_type', None) or ""
289
- filename = getattr(image, 'filename', "") or ""
290
-
291
- # Vérifier le type MIME ou l'extension du fichier
292
- image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif']
293
- is_image_content = content_type.startswith('image/') if content_type else False
294
- is_image_extension = any(filename.lower().endswith(ext) for ext in image_extensions)
295
-
296
- if not is_image_content and not is_image_extension:
297
- raise HTTPException(
298
- status_code=400,
299
- detail=f"Le fichier doit être une image. Type détecté: {content_type}, Fichier: {filename}"
300
- )
301
-
302
- # Sauvegarde temporaire de l'image
303
- file_extension = os.path.splitext(filename)[1] if filename else '.jpg'
304
- with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as temp_file:
305
- content = await image.read()
306
- temp_file.write(content)
307
- temp_image_path = temp_file.name
308
-
309
- try:
310
- # Charger les modèles
311
- model, model_l, device = load_inference_models()
312
-
313
- # Lire l'image
314
- frame = cv2.imread(temp_image_path)
315
- if frame is None:
316
- raise HTTPException(status_code=400, detail="Impossible de lire l'image")
317
-
318
- frame_height, frame_width = frame.shape[:2]
319
-
320
- # Mettre à jour les seuils globaux
321
- global KP_THRESHOLD, LINE_THRESHOLD
322
- KP_THRESHOLD = kp_threshold
323
- LINE_THRESHOLD = line_threshold
324
-
325
- # Traitement
326
- params = process_frame_inference(frame, model, model_l, device, frame_width, frame_height)
327
- # Formatage de la réponse
328
- response = InferenceImageResponse(
329
- status="success" if params is not None else "failed",
330
- camera_parameters=params,
331
- image_info={
332
- "filename": filename,
333
- "width": frame_width,
334
- "height": frame_height,
335
- "kp_threshold": kp_threshold,
336
- "line_threshold": line_threshold
337
- },
338
- message="Paramètres extraits avec succès" if params is not None else "Échec de l'extraction des paramètres"
339
- )
340
-
341
- return response
342
-
343
- except Exception as e:
344
- raise HTTPException(
345
- status_code=500,
346
- detail=f"Erreur lors de l'inférence: {str(e)} \n params:\n{params}"
347
- )
348
-
349
- finally:
350
- # Nettoyage du fichier temporaire
351
- if os.path.exists(temp_image_path):
352
- os.unlink(temp_image_path)
353
-
354
- except HTTPException:
355
- raise
356
- except Exception as e:
357
- raise HTTPException(status_code=500, detail=f"Erreur interne: {str(e)}")
358
-
359
- @app.post("/inference/video", response_model=InferenceVideoResponse)
360
- async def inference_video(
361
- video: UploadFile = File(..., description="Vidéo du terrain de football"),
362
- kp_threshold: float = Form(KP_THRESHOLD, description="Seuil pour les keypoints"),
363
- line_threshold: float = Form(LINE_THRESHOLD, description="Seuil pour les lignes"),
364
- frame_step: int = Form(FRAME_STEP, description="Traiter 1 frame sur N")
365
- ):
366
- """
367
- Extraire automatiquement les paramètres de caméra à partir d'une vidéo.
368
-
369
- Args:
370
- video: Vidéo du terrain de football (formats: mp4, avi, mov, etc.)
371
- kp_threshold: Seuil pour la détection des keypoints (défaut: 0.15)
372
- line_threshold: Seuil pour la détection des lignes (défaut: 0.15)
373
- frame_step: Traiter 1 frame sur N pour accélérer le traitement (défaut: 5)
374
-
375
- Returns:
376
- Liste des paramètres de calibration de la caméra pour chaque frame traitée
377
- """
378
- try:
379
- # Validation du format vidéo - version robuste
380
- content_type = getattr(video, 'content_type', None) or ""
381
- filename = getattr(video, 'filename', "") or ""
382
-
383
- video_extensions = ['.mp4', '.avi', '.mov', '.mkv', '.wmv', '.flv']
384
- is_video_content = content_type.startswith('video/') if content_type else False
385
- is_video_extension = any(filename.lower().endswith(ext) for ext in video_extensions)
386
-
387
- if not is_video_content and not is_video_extension:
388
- raise HTTPException(
389
- status_code=400,
390
- detail=f"Le fichier doit être une vidéo. Type détecté: {content_type}, Fichier: {filename}"
391
- )
392
-
393
- # Sauvegarde temporaire de la vidéo
394
- file_extension = os.path.splitext(filename)[1] if filename else '.mp4'
395
- with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as temp_file:
396
- content = await video.read()
397
- temp_file.write(content)
398
- temp_video_path = temp_file.name
399
-
400
- try:
401
- # Charger les modèles
402
- model, model_l, device = load_inference_models()
403
-
404
- # Ouvrir la vidéo
405
- cap = cv2.VideoCapture(temp_video_path)
406
- if not cap.isOpened():
407
- raise HTTPException(status_code=400, detail="Impossible d'ouvrir la vidéo")
408
-
409
- frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
410
- frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
411
- total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
412
- fps = int(cap.get(cv2.CAP_PROP_FPS))
413
-
414
- # Mettre à jour les seuils globaux
415
- global KP_THRESHOLD, LINE_THRESHOLD
416
- KP_THRESHOLD = kp_threshold
417
- LINE_THRESHOLD = line_threshold
418
-
419
- all_params = []
420
- frame_count = 0
421
- processed_count = 0
422
-
423
- while cap.isOpened():
424
- ret, frame = cap.read()
425
- if not ret:
426
- break
427
-
428
- # Traiter seulement 1 frame sur frame_step
429
- if frame_count % frame_step != 0:
430
- frame_count += 1
431
- continue
432
-
433
- # Traitement
434
- params = process_frame_inference(frame, model, model_l, device, frame_width, frame_height)
435
-
436
- if params is not None:
437
- params['frame_number'] = frame_count
438
- params['timestamp_seconds'] = frame_count / fps
439
- all_params.append(params)
440
- processed_count += 1
441
-
442
- frame_count += 1
443
-
444
- cap.release()
445
-
446
- # Formatage de la réponse
447
- response = InferenceVideoResponse(
448
- status="success" if all_params else "failed",
449
- camera_parameters=all_params,
450
- video_info={
451
- "filename": filename,
452
- "width": frame_width,
453
- "height": frame_height,
454
- "total_frames": total_frames,
455
- "fps": fps,
456
- "duration_seconds": total_frames / fps,
457
- "kp_threshold": kp_threshold,
458
- "line_threshold": line_threshold,
459
- "frame_step": frame_step
460
- },
461
- frames_processed=processed_count,
462
- message=f"Paramètres extraits de {processed_count} frames" if all_params else "Aucun paramètre extrait"
463
- )
464
-
465
- return response
466
-
467
- except Exception as e:
468
- raise HTTPException(
469
- status_code=500,
470
- detail=f"Erreur lors de l'inférence vidéo: {str(e)}"
471
- )
472
-
473
- finally:
474
- # Nettoyage du fichier temporaire
475
- if os.path.exists(temp_video_path):
476
- os.unlink(temp_video_path)
477
-
478
- except HTTPException:
479
- raise
480
- except Exception as e:
481
- raise HTTPException(status_code=500, detail=f"Erreur interne: {str(e)}")
482
-
483
- # Point d'entrée pour Vercel
 
484
  app_instance = app
 
1
+ from fastapi import FastAPI, HTTPException, UploadFile, File, Form
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ from pydantic import BaseModel
4
+ from typing import Dict, List, Any, Optional
5
+ import json
6
+ import tempfile
7
+ import os
8
+ from PIL import Image
9
+ import numpy as np
10
+ import cv2
11
+ import torch
12
+ import torchvision.transforms as T
13
+ import torchvision.transforms.functional as f
14
+ import yaml
15
+ from tqdm import tqdm
16
+
17
+ from get_camera_params import get_camera_parameters
18
+
19
+ # Imports pour l'inférence automatique
20
+ from model.cls_hrnet import get_cls_net
21
+ from model.cls_hrnet_l import get_cls_net as get_cls_net_l
22
+ from utils.utils_calib import FramebyFrameCalib
23
+ from utils.utils_heatmap import get_keypoints_from_heatmap_batch_maxpool, get_keypoints_from_heatmap_batch_maxpool_l, complete_keypoints, coords_to_dict
24
+
25
+ app = FastAPI(
26
+ title="Football Vision Calibration API",
27
+ description="API pour la calibration de caméras à partir de lignes de terrain de football",
28
+ version="1.0.0"
29
+ )
30
+
31
+ # Configuration CORS pour autoriser les requêtes depuis le frontend
32
+ app.add_middleware(
33
+ CORSMiddleware,
34
+ allow_origins=["*"], # En production, spécifiez les domaines autorisés
35
+ allow_credentials=True,
36
+ allow_methods=["*"],
37
+ allow_headers=["*"],
38
+ )
39
+
40
+ # Paramètres par défaut pour l'inférence
41
+ WEIGHTS_KP = "models/SV_FT_TSWC_kp"
42
+ WEIGHTS_LINE = "models/SV_FT_TSWC_lines"
43
+ DEVICE = "cuda:0"
44
+ KP_THRESHOLD = 0.15
45
+ LINE_THRESHOLD = 0.15
46
+ PNL_REFINE = True
47
+ FRAME_STEP = 5
48
+
49
+ # Cache pour les modèles (éviter de les recharger à chaque requête)
50
+ _models_cache = None
51
+
52
+ def load_inference_models():
53
+ """Charge les modèles d'inférence (avec cache)"""
54
+ global _models_cache
55
+
56
+ if _models_cache is not None:
57
+ return _models_cache
58
+
59
+ device = torch.device(DEVICE if torch.cuda.is_available() else 'cpu')
60
+
61
+ # Charger les configurations
62
+ cfg = yaml.safe_load(open("config/hrnetv2_w48.yaml", 'r'))
63
+ cfg_l = yaml.safe_load(open("config/hrnetv2_w48_l.yaml", 'r'))
64
+
65
+ # Modèle keypoints
66
+ model = get_cls_net(cfg)
67
+ model.load_state_dict(torch.load(WEIGHTS_KP, map_location=device))
68
+ model.to(device)
69
+ model.eval()
70
+
71
+ # Modèle lignes
72
+ model_l = get_cls_net_l(cfg_l)
73
+ model_l.load_state_dict(torch.load(WEIGHTS_LINE, map_location=device))
74
+ model_l.to(device)
75
+ model_l.eval()
76
+
77
+ _models_cache = (model, model_l, device)
78
+ return _models_cache
79
+
80
+ def process_frame_inference(frame, model, model_l, device, frame_width, frame_height):
81
+ """Traite une frame et retourne les paramètres de caméra"""
82
+ transform = T.Resize((540, 960))
83
+
84
+ # Préparer la frame pour l'inférence
85
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
86
+ frame_pil = Image.fromarray(frame_rgb)
87
+ frame_tensor = f.to_tensor(frame_pil).float().unsqueeze(0)
88
+
89
+ if frame_tensor.size()[-1] != 960:
90
+ frame_tensor = transform(frame_tensor)
91
+
92
+ frame_tensor = frame_tensor.to(device)
93
+ b, c, h, w = frame_tensor.size()
94
+
95
+ # Inférence
96
+ with torch.no_grad():
97
+ heatmaps = model(frame_tensor)
98
+ heatmaps_l = model_l(frame_tensor)
99
+
100
+ # Extraire les keypoints et lignes
101
+ kp_coords = get_keypoints_from_heatmap_batch_maxpool(heatmaps[:,:-1,:,:])
102
+ line_coords = get_keypoints_from_heatmap_batch_maxpool_l(heatmaps_l[:,:-1,:,:])
103
+ kp_dict = coords_to_dict(kp_coords, threshold=KP_THRESHOLD)
104
+ lines_dict = coords_to_dict(line_coords, threshold=LINE_THRESHOLD)
105
+ kp_dict, lines_dict = complete_keypoints(kp_dict[0], lines_dict[0], w=w, h=h, normalize=True)
106
+
107
+ # Calibration
108
+ cam = FramebyFrameCalib(iwidth=frame_width, iheight=frame_height, denormalize=True)
109
+ cam.update(kp_dict, lines_dict)
110
+ final_params_dict = cam.heuristic_voting(refine_lines=PNL_REFINE)
111
+
112
+ return final_params_dict
113
+
114
+ # Modèles Pydantic pour la validation des données
115
+ class Point(BaseModel):
116
+ x: float
117
+ y: float
118
+
119
+ class LinePolygon(BaseModel):
120
+ points: List[Point]
121
+
122
+ class CalibrationRequest(BaseModel):
123
+ lines: Dict[str, List[Point]]
124
+
125
+ class CalibrationResponse(BaseModel):
126
+ status: str
127
+ camera_parameters: Dict[str, Any]
128
+ input_lines: Dict[str, List[Point]]
129
+ message: str
130
+
131
+ class InferenceImageResponse(BaseModel):
132
+ status: str
133
+ camera_parameters: Optional[Dict[str, Any]]
134
+ image_info: Dict[str, Any]
135
+ message: str
136
+
137
+ class InferenceVideoResponse(BaseModel):
138
+ status: str
139
+ camera_parameters: List[Dict[str, Any]]
140
+ video_info: Dict[str, Any]
141
+ frames_processed: int
142
+ message: str
143
+
144
+ @app.get("/")
145
+ async def root():
146
+ return {
147
+ "message": "Football Vision Calibration API",
148
+ "version": "1.0.0",
149
+ "endpoints": {
150
+ "/calibrate": "POST - Calibrer une caméra à partir d'une image et de lignes",
151
+ "/inference/image": "POST - Extraire les paramètres de caméra d'une image automatiquement",
152
+ "/inference/video": "POST - Extraire les paramètres de caméra d'une vidéo automatiquement",
153
+ "/health": "GET - Vérifier l'état de l'API"
154
+ }
155
+ }
156
+
157
+ @app.get("/health")
158
+ async def health_check():
159
+ return {"status": "healthy", "message": "API is running"}
160
+
161
+ @app.post("/calibrate", response_model=CalibrationResponse)
162
+ async def calibrate_camera(
163
+ image: UploadFile = File(..., description="Image du terrain de football"),
164
+ lines_data: str = Form(..., description="JSON des lignes du terrain")
165
+ ):
166
+ """
167
+ Calibrer une caméra à partir d'une image et des lignes du terrain.
168
+
169
+ Args:
170
+ image: Image du terrain de football (formats: jpg, jpeg, png)
171
+ lines_data: JSON contenant les lignes du terrain au format:
172
+ {"nom_ligne": [{"x": float, "y": float}, ...], ...}
173
+
174
+ Returns:
175
+ Paramètres de calibration de la caméra et lignes d'entrée
176
+ """
177
+ try:
178
+ # Validation du format d'image - version robuste
179
+ content_type = getattr(image, 'content_type', None) or ""
180
+ filename = getattr(image, 'filename', "") or ""
181
+
182
+ # Vérifier le type MIME ou l'extension du fichier
183
+ image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif']
184
+ is_image_content = content_type.startswith('image/') if content_type else False
185
+ is_image_extension = any(filename.lower().endswith(ext) for ext in image_extensions)
186
+
187
+ if not is_image_content and not is_image_extension:
188
+ raise HTTPException(
189
+ status_code=400,
190
+ detail=f"Le fichier doit être une image. Type détecté: {content_type}, Fichier: {filename}"
191
+ )
192
+
193
+ # Parse des données de lignes
194
+ try:
195
+ lines_dict = json.loads(lines_data)
196
+ except json.JSONDecodeError:
197
+ raise HTTPException(status_code=400, detail="Format JSON invalide pour les lignes")
198
+
199
+ # Validation de la structure des lignes
200
+ validated_lines = {}
201
+ for line_name, points in lines_dict.items():
202
+ if not isinstance(points, list):
203
+ raise HTTPException(
204
+ status_code=400,
205
+ detail=f"Les points de la ligne '{line_name}' doivent être une liste"
206
+ )
207
+
208
+ validated_points = []
209
+ for i, point in enumerate(points):
210
+ if not isinstance(point, dict) or 'x' not in point or 'y' not in point:
211
+ raise HTTPException(
212
+ status_code=400,
213
+ detail=f"Point {i} de la ligne '{line_name}' doit avoir les clés 'x' et 'y'"
214
+ )
215
+ try:
216
+ validated_points.append({
217
+ "x": float(point['x']),
218
+ "y": float(point['y'])
219
+ })
220
+ except (ValueError, TypeError):
221
+ raise HTTPException(
222
+ status_code=400,
223
+ detail=f"Coordonnées invalides pour le point {i} de la ligne '{line_name}'"
224
+ )
225
+
226
+ validated_lines[line_name] = validated_points
227
+
228
+ # Sauvegarde temporaire de l'image
229
+ file_extension = os.path.splitext(filename)[1] if filename else '.jpg'
230
+ with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as temp_file:
231
+ content = await image.read()
232
+ temp_file.write(content)
233
+ temp_image_path = temp_file.name
234
+
235
+ try:
236
+ # Validation de l'image
237
+ pil_image = Image.open(temp_image_path)
238
+ pil_image.verify() # Vérification de l'intégrité de l'image
239
+
240
+ # Calibration de la caméra
241
+ camera_params = get_camera_parameters(temp_image_path, validated_lines)
242
+
243
+ # Formatage de la réponse
244
+ response = CalibrationResponse(
245
+ status="success",
246
+ camera_parameters=camera_params,
247
+ input_lines=validated_lines,
248
+ message="Calibration réussie"
249
+ )
250
+
251
+ return response
252
+
253
+ except Exception as e:
254
+ raise HTTPException(
255
+ status_code=500,
256
+ detail=f"Erreur lors de la calibration: {str(e)}"
257
+ )
258
+
259
+ finally:
260
+ # Nettoyage du fichier temporaire
261
+ if os.path.exists(temp_image_path):
262
+ os.unlink(temp_image_path)
263
+
264
+ except HTTPException:
265
+ raise
266
+ except Exception as e:
267
+ raise HTTPException(status_code=500, detail=f"Erreur interne: {str(e)}")
268
+
269
+ @app.post("/inference/image", response_model=InferenceImageResponse)
270
+ async def inference_image(
271
+ image: UploadFile = File(..., description="Image du terrain de football"),
272
+ kp_threshold: float = Form(KP_THRESHOLD, description="Seuil pour les keypoints"),
273
+ line_threshold: float = Form(LINE_THRESHOLD, description="Seuil pour les lignes")
274
+ ):
275
+ """
276
+ Extraire automatiquement les paramètres de caméra à partir d'une image.
277
+
278
+ Args:
279
+ image: Image du terrain de football (formats: jpg, jpeg, png)
280
+ kp_threshold: Seuil pour la détection des keypoints (défaut: 0.15)
281
+ line_threshold: Seuil pour la détection des lignes (défaut: 0.15)
282
+
283
+ Returns:
284
+ Paramètres de calibration de la caméra extraits automatiquement
285
+ """
286
+ params = None # Initialiser params
287
+ try:
288
+ # Validation du format d'image - version robuste
289
+ content_type = getattr(image, 'content_type', None) or ""
290
+ filename = getattr(image, 'filename', "") or ""
291
+
292
+ # Vérifier le type MIME ou l'extension du fichier
293
+ image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif']
294
+ is_image_content = content_type.startswith('image/') if content_type else False
295
+ is_image_extension = any(filename.lower().endswith(ext) for ext in image_extensions)
296
+
297
+ if not is_image_content and not is_image_extension:
298
+ raise HTTPException(
299
+ status_code=400,
300
+ detail=f"Le fichier doit être une image. Type détecté: {content_type}, Fichier: {filename}"
301
+ )
302
+
303
+ # Sauvegarde temporaire de l'image
304
+ file_extension = os.path.splitext(filename)[1] if filename else '.jpg'
305
+ with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as temp_file:
306
+ content = await image.read()
307
+ temp_file.write(content)
308
+ temp_image_path = temp_file.name
309
+
310
+ try:
311
+ # Charger les modèles
312
+ model, model_l, device = load_inference_models()
313
+
314
+ # Lire l'image
315
+ frame = cv2.imread(temp_image_path)
316
+ if frame is None:
317
+ raise HTTPException(status_code=400, detail="Impossible de lire l'image")
318
+
319
+ frame_height, frame_width = frame.shape[:2]
320
+
321
+ # Mettre à jour les seuils globaux
322
+ global KP_THRESHOLD, LINE_THRESHOLD
323
+ KP_THRESHOLD = kp_threshold
324
+ LINE_THRESHOLD = line_threshold
325
+
326
+ # Traitement
327
+ params = process_frame_inference(frame, model, model_l, device, frame_width, frame_height)
328
+
329
+ # Formatage de la réponse
330
+ response = InferenceImageResponse(
331
+ status="success" if params is not None else "failed",
332
+ camera_parameters=params,
333
+ image_info={
334
+ "filename": filename,
335
+ "width": frame_width,
336
+ "height": frame_height,
337
+ "kp_threshold": kp_threshold,
338
+ "line_threshold": line_threshold
339
+ },
340
+ message="Paramètres extraits avec succès" if params is not None else "Échec de l'extraction des paramètres"
341
+ )
342
+
343
+ return response
344
+
345
+ except Exception as e:
346
+ raise HTTPException(
347
+ status_code=500,
348
+ detail=f"Erreur lors de l'inférence: {str(e)}"
349
+ )
350
+
351
+ finally:
352
+ # Nettoyage du fichier temporaire
353
+ if os.path.exists(temp_image_path):
354
+ os.unlink(temp_image_path)
355
+
356
+ except HTTPException:
357
+ raise
358
+ except Exception as e:
359
+ raise HTTPException(status_code=500, detail=f"Erreur interne: {str(e)}")
360
+
361
+ @app.post("/inference/video", response_model=InferenceVideoResponse)
362
+ async def inference_video(
363
+ video: UploadFile = File(..., description="Vidéo du terrain de football"),
364
+ kp_threshold: float = Form(KP_THRESHOLD, description="Seuil pour les keypoints"),
365
+ line_threshold: float = Form(LINE_THRESHOLD, description="Seuil pour les lignes"),
366
+ frame_step: int = Form(FRAME_STEP, description="Traiter 1 frame sur N")
367
+ ):
368
+ """
369
+ Extraire automatiquement les paramètres de caméra à partir d'une vidéo.
370
+
371
+ Args:
372
+ video: Vidéo du terrain de football (formats: mp4, avi, mov, etc.)
373
+ kp_threshold: Seuil pour la détection des keypoints (défaut: 0.15)
374
+ line_threshold: Seuil pour la détection des lignes (défaut: 0.15)
375
+ frame_step: Traiter 1 frame sur N pour accélérer le traitement (défaut: 5)
376
+
377
+ Returns:
378
+ Liste des paramètres de calibration de la caméra pour chaque frame traitée
379
+ """
380
+ try:
381
+ # Validation du format vidéo - version robuste
382
+ content_type = getattr(video, 'content_type', None) or ""
383
+ filename = getattr(video, 'filename', "") or ""
384
+
385
+ video_extensions = ['.mp4', '.avi', '.mov', '.mkv', '.wmv', '.flv']
386
+ is_video_content = content_type.startswith('video/') if content_type else False
387
+ is_video_extension = any(filename.lower().endswith(ext) for ext in video_extensions)
388
+
389
+ if not is_video_content and not is_video_extension:
390
+ raise HTTPException(
391
+ status_code=400,
392
+ detail=f"Le fichier doit être une vidéo. Type détecté: {content_type}, Fichier: {filename}"
393
+ )
394
+
395
+ # Sauvegarde temporaire de la vidéo
396
+ file_extension = os.path.splitext(filename)[1] if filename else '.mp4'
397
+ with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as temp_file:
398
+ content = await video.read()
399
+ temp_file.write(content)
400
+ temp_video_path = temp_file.name
401
+
402
+ try:
403
+ # Charger les modèles
404
+ model, model_l, device = load_inference_models()
405
+
406
+ # Ouvrir la vidéo
407
+ cap = cv2.VideoCapture(temp_video_path)
408
+ if not cap.isOpened():
409
+ raise HTTPException(status_code=400, detail="Impossible d'ouvrir la vidéo")
410
+
411
+ frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
412
+ frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
413
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
414
+ fps = int(cap.get(cv2.CAP_PROP_FPS))
415
+
416
+ # Mettre à jour les seuils globaux
417
+ global KP_THRESHOLD, LINE_THRESHOLD
418
+ KP_THRESHOLD = kp_threshold
419
+ LINE_THRESHOLD = line_threshold
420
+
421
+ all_params = []
422
+ frame_count = 0
423
+ processed_count = 0
424
+
425
+ while cap.isOpened():
426
+ ret, frame = cap.read()
427
+ if not ret:
428
+ break
429
+
430
+ # Traiter seulement 1 frame sur frame_step
431
+ if frame_count % frame_step != 0:
432
+ frame_count += 1
433
+ continue
434
+
435
+ # Traitement
436
+ params = process_frame_inference(frame, model, model_l, device, frame_width, frame_height)
437
+
438
+ if params is not None:
439
+ params['frame_number'] = frame_count
440
+ params['timestamp_seconds'] = frame_count / fps
441
+ all_params.append(params)
442
+ processed_count += 1
443
+
444
+ frame_count += 1
445
+
446
+ cap.release()
447
+
448
+ # Formatage de la réponse
449
+ response = InferenceVideoResponse(
450
+ status="success" if all_params else "failed",
451
+ camera_parameters=all_params,
452
+ video_info={
453
+ "filename": filename,
454
+ "width": frame_width,
455
+ "height": frame_height,
456
+ "total_frames": total_frames,
457
+ "fps": fps,
458
+ "duration_seconds": total_frames / fps,
459
+ "kp_threshold": kp_threshold,
460
+ "line_threshold": line_threshold,
461
+ "frame_step": frame_step
462
+ },
463
+ frames_processed=processed_count,
464
+ message=f"Paramètres extraits de {processed_count} frames" if all_params else "Aucun paramètre extrait"
465
+ )
466
+
467
+ return response
468
+
469
+ except Exception as e:
470
+ raise HTTPException(
471
+ status_code=500,
472
+ detail=f"Erreur lors de l'inférence vidéo: {str(e)}"
473
+ )
474
+
475
+ finally:
476
+ # Nettoyage du fichier temporaire
477
+ if os.path.exists(temp_video_path):
478
+ os.unlink(temp_video_path)
479
+
480
+ except HTTPException:
481
+ raise
482
+ except Exception as e:
483
+ raise HTTPException(status_code=500, detail=f"Erreur interne: {str(e)}")
484
+
485
  app_instance = app