File size: 19,109 Bytes
3d1f2c9 fe10d2a 3d1f2c9 fe10d2a 3d1f2c9 fe10d2a 3d1f2c9 fe10d2a 3d1f2c9 fe10d2a 3d1f2c9 fe10d2a 3d1f2c9 fe10d2a 3d1f2c9 fe10d2a 3d1f2c9 fe10d2a 3d1f2c9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 |
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import Dict, List, Any, Optional
import json
import tempfile
import os
from PIL import Image
import numpy as np
import cv2
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as f
import yaml
from tqdm import tqdm
from get_camera_params import get_camera_parameters
# Imports pour l'inférence automatique
from model.cls_hrnet import get_cls_net
from model.cls_hrnet_l import get_cls_net as get_cls_net_l
from utils.utils_calib import FramebyFrameCalib
from utils.utils_heatmap import get_keypoints_from_heatmap_batch_maxpool, get_keypoints_from_heatmap_batch_maxpool_l, complete_keypoints, coords_to_dict
app = FastAPI(
title="Football Vision Calibration API",
description="API pour la calibration de caméras à partir de lignes de terrain de football",
version="1.0.0"
)
# Configuration CORS pour autoriser les requêtes depuis le frontend
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # En production, spécifiez les domaines autorisés
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Paramètres par défaut pour l'inférence
WEIGHTS_KP = "models/SV_FT_TSWC_kp"
WEIGHTS_LINE = "models/SV_FT_TSWC_lines"
DEVICE = "cuda:0"
KP_THRESHOLD = 0.15
LINE_THRESHOLD = 0.15
PNL_REFINE = True
FRAME_STEP = 5
# Cache pour les modèles (éviter de les recharger à chaque requête)
_models_cache = None
def load_inference_models():
"""Charge les modèles d'inférence (avec cache)"""
global _models_cache
if _models_cache is not None:
return _models_cache
device = torch.device(DEVICE if torch.cuda.is_available() else 'cpu')
# Charger les configurations
cfg = yaml.safe_load(open("config/hrnetv2_w48.yaml", 'r'))
cfg_l = yaml.safe_load(open("config/hrnetv2_w48_l.yaml", 'r'))
# Modèle keypoints
model = get_cls_net(cfg)
model.load_state_dict(torch.load(WEIGHTS_KP, map_location=device))
model.to(device)
model.eval()
# Modèle lignes
model_l = get_cls_net_l(cfg_l)
model_l.load_state_dict(torch.load(WEIGHTS_LINE, map_location=device))
model_l.to(device)
model_l.eval()
_models_cache = (model, model_l, device)
return _models_cache
def process_frame_inference(frame, model, model_l, device, frame_width, frame_height):
"""Traite une frame et retourne les paramètres de caméra"""
transform = T.Resize((540, 960))
# Préparer la frame pour l'inférence
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
frame_pil = Image.fromarray(frame_rgb)
frame_tensor = f.to_tensor(frame_pil).float().unsqueeze(0)
if frame_tensor.size()[-1] != 960:
frame_tensor = transform(frame_tensor)
frame_tensor = frame_tensor.to(device)
b, c, h, w = frame_tensor.size()
# Inférence
with torch.no_grad():
heatmaps = model(frame_tensor)
heatmaps_l = model_l(frame_tensor)
# Extraire les keypoints et lignes
kp_coords = get_keypoints_from_heatmap_batch_maxpool(heatmaps[:,:-1,:,:])
line_coords = get_keypoints_from_heatmap_batch_maxpool_l(heatmaps_l[:,:-1,:,:])
kp_dict = coords_to_dict(kp_coords, threshold=KP_THRESHOLD)
lines_dict = coords_to_dict(line_coords, threshold=LINE_THRESHOLD)
kp_dict, lines_dict = complete_keypoints(kp_dict[0], lines_dict[0], w=w, h=h, normalize=True)
# Calibration
cam = FramebyFrameCalib(iwidth=frame_width, iheight=frame_height, denormalize=True)
cam.update(kp_dict, lines_dict)
final_params_dict = cam.heuristic_voting(refine_lines=PNL_REFINE)
return final_params_dict
# Modèles Pydantic pour la validation des données
class Point(BaseModel):
x: float
y: float
class LinePolygon(BaseModel):
points: List[Point]
class CalibrationRequest(BaseModel):
lines: Dict[str, List[Point]]
class CalibrationResponse(BaseModel):
status: str
camera_parameters: Dict[str, Any]
input_lines: Dict[str, List[Point]]
message: str
class InferenceImageResponse(BaseModel):
status: str
camera_parameters: Optional[Dict[str, Any]]
image_info: Dict[str, Any]
message: str
class InferenceVideoResponse(BaseModel):
status: str
camera_parameters: List[Dict[str, Any]]
video_info: Dict[str, Any]
frames_processed: int
message: str
@app.get("/")
async def root():
return {
"message": "Football Vision Calibration API",
"version": "1.0.0",
"endpoints": {
"/calibrate": "POST - Calibrer une caméra à partir d'une image et de lignes",
"/inference/image": "POST - Extraire les paramètres de caméra d'une image automatiquement",
"/inference/video": "POST - Extraire les paramètres de caméra d'une vidéo automatiquement",
"/health": "GET - Vérifier l'état de l'API"
}
}
@app.get("/health")
async def health_check():
return {"status": "healthy", "message": "API is running"}
@app.post("/calibrate", response_model=CalibrationResponse)
async def calibrate_camera(
image: UploadFile = File(..., description="Image du terrain de football"),
lines_data: str = Form(..., description="JSON des lignes du terrain")
):
"""
Calibrer une caméra à partir d'une image et des lignes du terrain.
Args:
image: Image du terrain de football (formats: jpg, jpeg, png)
lines_data: JSON contenant les lignes du terrain au format:
{"nom_ligne": [{"x": float, "y": float}, ...], ...}
Returns:
Paramètres de calibration de la caméra et lignes d'entrée
"""
try:
# Validation du format d'image - version robuste
content_type = getattr(image, 'content_type', None) or ""
filename = getattr(image, 'filename', "") or ""
# Vérifier le type MIME ou l'extension du fichier
image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif']
is_image_content = content_type.startswith('image/') if content_type else False
is_image_extension = any(filename.lower().endswith(ext) for ext in image_extensions)
if not is_image_content and not is_image_extension:
raise HTTPException(
status_code=400,
detail=f"Le fichier doit être une image. Type détecté: {content_type}, Fichier: {filename}"
)
# Parse des données de lignes
try:
lines_dict = json.loads(lines_data)
except json.JSONDecodeError:
raise HTTPException(status_code=400, detail="Format JSON invalide pour les lignes")
# Validation de la structure des lignes
validated_lines = {}
for line_name, points in lines_dict.items():
if not isinstance(points, list):
raise HTTPException(
status_code=400,
detail=f"Les points de la ligne '{line_name}' doivent être une liste"
)
validated_points = []
for i, point in enumerate(points):
if not isinstance(point, dict) or 'x' not in point or 'y' not in point:
raise HTTPException(
status_code=400,
detail=f"Point {i} de la ligne '{line_name}' doit avoir les clés 'x' et 'y'"
)
try:
validated_points.append({
"x": float(point['x']),
"y": float(point['y'])
})
except (ValueError, TypeError):
raise HTTPException(
status_code=400,
detail=f"Coordonnées invalides pour le point {i} de la ligne '{line_name}'"
)
validated_lines[line_name] = validated_points
# Sauvegarde temporaire de l'image
file_extension = os.path.splitext(filename)[1] if filename else '.jpg'
with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as temp_file:
content = await image.read()
temp_file.write(content)
temp_image_path = temp_file.name
try:
# Validation de l'image
pil_image = Image.open(temp_image_path)
pil_image.verify() # Vérification de l'intégrité de l'image
# Calibration de la caméra
camera_params = get_camera_parameters(temp_image_path, validated_lines)
# Formatage de la réponse
response = CalibrationResponse(
status="success",
camera_parameters=camera_params,
input_lines=validated_lines,
message="Calibration réussie"
)
return response
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Erreur lors de la calibration: {str(e)}"
)
finally:
# Nettoyage du fichier temporaire
if os.path.exists(temp_image_path):
os.unlink(temp_image_path)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"Erreur interne: {str(e)}")
@app.post("/inference/image", response_model=InferenceImageResponse)
async def inference_image(
image: UploadFile = File(..., description="Image du terrain de football"),
kp_threshold: float = Form(KP_THRESHOLD, description="Seuil pour les keypoints"),
line_threshold: float = Form(LINE_THRESHOLD, description="Seuil pour les lignes")
):
"""
Extraire automatiquement les paramètres de caméra à partir d'une image.
Args:
image: Image du terrain de football (formats: jpg, jpeg, png)
kp_threshold: Seuil pour la détection des keypoints (défaut: 0.15)
line_threshold: Seuil pour la détection des lignes (défaut: 0.15)
Returns:
Paramètres de calibration de la caméra extraits automatiquement
"""
try:
# Validation du format d'image - version robuste
content_type = getattr(image, 'content_type', None) or ""
filename = getattr(image, 'filename', "") or ""
# Vérifier le type MIME ou l'extension du fichier
image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif']
is_image_content = content_type.startswith('image/') if content_type else False
is_image_extension = any(filename.lower().endswith(ext) for ext in image_extensions)
if not is_image_content and not is_image_extension:
raise HTTPException(
status_code=400,
detail=f"Le fichier doit être une image. Type détecté: {content_type}, Fichier: {filename}"
)
# Sauvegarde temporaire de l'image
file_extension = os.path.splitext(filename)[1] if filename else '.jpg'
with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as temp_file:
content = await image.read()
temp_file.write(content)
temp_image_path = temp_file.name
try:
# Charger les modèles
model, model_l, device = load_inference_models()
# Lire l'image
frame = cv2.imread(temp_image_path)
if frame is None:
raise HTTPException(status_code=400, detail="Impossible de lire l'image")
frame_height, frame_width = frame.shape[:2]
# Mettre à jour les seuils globaux
global KP_THRESHOLD, LINE_THRESHOLD
KP_THRESHOLD = kp_threshold
LINE_THRESHOLD = line_threshold
# Traitement
params = process_frame_inference(frame, model, model_l, device, frame_width, frame_height)
# Formatage de la réponse
response = InferenceImageResponse(
status="success" if params is not None else "failed",
camera_parameters=params,
image_info={
"filename": filename,
"width": frame_width,
"height": frame_height,
"kp_threshold": kp_threshold,
"line_threshold": line_threshold
},
message="Paramètres extraits avec succès" if params is not None else "Échec de l'extraction des paramètres"
)
return response
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Erreur lors de l'inférence: {str(e)} \n params:\n{params}"
)
finally:
# Nettoyage du fichier temporaire
if os.path.exists(temp_image_path):
os.unlink(temp_image_path)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"Erreur interne: {str(e)}")
@app.post("/inference/video", response_model=InferenceVideoResponse)
async def inference_video(
video: UploadFile = File(..., description="Vidéo du terrain de football"),
kp_threshold: float = Form(KP_THRESHOLD, description="Seuil pour les keypoints"),
line_threshold: float = Form(LINE_THRESHOLD, description="Seuil pour les lignes"),
frame_step: int = Form(FRAME_STEP, description="Traiter 1 frame sur N")
):
"""
Extraire automatiquement les paramètres de caméra à partir d'une vidéo.
Args:
video: Vidéo du terrain de football (formats: mp4, avi, mov, etc.)
kp_threshold: Seuil pour la détection des keypoints (défaut: 0.15)
line_threshold: Seuil pour la détection des lignes (défaut: 0.15)
frame_step: Traiter 1 frame sur N pour accélérer le traitement (défaut: 5)
Returns:
Liste des paramètres de calibration de la caméra pour chaque frame traitée
"""
try:
# Validation du format vidéo - version robuste
content_type = getattr(video, 'content_type', None) or ""
filename = getattr(video, 'filename', "") or ""
video_extensions = ['.mp4', '.avi', '.mov', '.mkv', '.wmv', '.flv']
is_video_content = content_type.startswith('video/') if content_type else False
is_video_extension = any(filename.lower().endswith(ext) for ext in video_extensions)
if not is_video_content and not is_video_extension:
raise HTTPException(
status_code=400,
detail=f"Le fichier doit être une vidéo. Type détecté: {content_type}, Fichier: {filename}"
)
# Sauvegarde temporaire de la vidéo
file_extension = os.path.splitext(filename)[1] if filename else '.mp4'
with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as temp_file:
content = await video.read()
temp_file.write(content)
temp_video_path = temp_file.name
try:
# Charger les modèles
model, model_l, device = load_inference_models()
# Ouvrir la vidéo
cap = cv2.VideoCapture(temp_video_path)
if not cap.isOpened():
raise HTTPException(status_code=400, detail="Impossible d'ouvrir la vidéo")
frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
fps = int(cap.get(cv2.CAP_PROP_FPS))
# Mettre à jour les seuils globaux
global KP_THRESHOLD, LINE_THRESHOLD
KP_THRESHOLD = kp_threshold
LINE_THRESHOLD = line_threshold
all_params = []
frame_count = 0
processed_count = 0
while cap.isOpened():
ret, frame = cap.read()
if not ret:
break
# Traiter seulement 1 frame sur frame_step
if frame_count % frame_step != 0:
frame_count += 1
continue
# Traitement
params = process_frame_inference(frame, model, model_l, device, frame_width, frame_height)
if params is not None:
params['frame_number'] = frame_count
params['timestamp_seconds'] = frame_count / fps
all_params.append(params)
processed_count += 1
frame_count += 1
cap.release()
# Formatage de la réponse
response = InferenceVideoResponse(
status="success" if all_params else "failed",
camera_parameters=all_params,
video_info={
"filename": filename,
"width": frame_width,
"height": frame_height,
"total_frames": total_frames,
"fps": fps,
"duration_seconds": total_frames / fps,
"kp_threshold": kp_threshold,
"line_threshold": line_threshold,
"frame_step": frame_step
},
frames_processed=processed_count,
message=f"Paramètres extraits de {processed_count} frames" if all_params else "Aucun paramètre extrait"
)
return response
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Erreur lors de l'inférence vidéo: {str(e)}"
)
finally:
# Nettoyage du fichier temporaire
if os.path.exists(temp_video_path):
os.unlink(temp_video_path)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"Erreur interne: {str(e)}")
# Point d'entrée pour Vercel
app_instance = app |