PnLCalib / api.py
2nzi's picture
ADD INFERENCE ENDPOINTS for images and video
fe10d2a
raw
history blame
19.1 kB
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import Dict, List, Any, Optional
import json
import tempfile
import os
from PIL import Image
import numpy as np
import cv2
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as f
import yaml
from tqdm import tqdm
from get_camera_params import get_camera_parameters
# Imports pour l'inférence automatique
from model.cls_hrnet import get_cls_net
from model.cls_hrnet_l import get_cls_net as get_cls_net_l
from utils.utils_calib import FramebyFrameCalib
from utils.utils_heatmap import get_keypoints_from_heatmap_batch_maxpool, get_keypoints_from_heatmap_batch_maxpool_l, complete_keypoints, coords_to_dict
app = FastAPI(
title="Football Vision Calibration API",
description="API pour la calibration de caméras à partir de lignes de terrain de football",
version="1.0.0"
)
# Configuration CORS pour autoriser les requêtes depuis le frontend
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # En production, spécifiez les domaines autorisés
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Paramètres par défaut pour l'inférence
WEIGHTS_KP = "models/SV_FT_TSWC_kp"
WEIGHTS_LINE = "models/SV_FT_TSWC_lines"
DEVICE = "cuda:0"
KP_THRESHOLD = 0.15
LINE_THRESHOLD = 0.15
PNL_REFINE = True
FRAME_STEP = 5
# Cache pour les modèles (éviter de les recharger à chaque requête)
_models_cache = None
def load_inference_models():
"""Charge les modèles d'inférence (avec cache)"""
global _models_cache
if _models_cache is not None:
return _models_cache
device = torch.device(DEVICE if torch.cuda.is_available() else 'cpu')
# Charger les configurations
cfg = yaml.safe_load(open("config/hrnetv2_w48.yaml", 'r'))
cfg_l = yaml.safe_load(open("config/hrnetv2_w48_l.yaml", 'r'))
# Modèle keypoints
model = get_cls_net(cfg)
model.load_state_dict(torch.load(WEIGHTS_KP, map_location=device))
model.to(device)
model.eval()
# Modèle lignes
model_l = get_cls_net_l(cfg_l)
model_l.load_state_dict(torch.load(WEIGHTS_LINE, map_location=device))
model_l.to(device)
model_l.eval()
_models_cache = (model, model_l, device)
return _models_cache
def process_frame_inference(frame, model, model_l, device, frame_width, frame_height):
"""Traite une frame et retourne les paramètres de caméra"""
transform = T.Resize((540, 960))
# Préparer la frame pour l'inférence
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
frame_pil = Image.fromarray(frame_rgb)
frame_tensor = f.to_tensor(frame_pil).float().unsqueeze(0)
if frame_tensor.size()[-1] != 960:
frame_tensor = transform(frame_tensor)
frame_tensor = frame_tensor.to(device)
b, c, h, w = frame_tensor.size()
# Inférence
with torch.no_grad():
heatmaps = model(frame_tensor)
heatmaps_l = model_l(frame_tensor)
# Extraire les keypoints et lignes
kp_coords = get_keypoints_from_heatmap_batch_maxpool(heatmaps[:,:-1,:,:])
line_coords = get_keypoints_from_heatmap_batch_maxpool_l(heatmaps_l[:,:-1,:,:])
kp_dict = coords_to_dict(kp_coords, threshold=KP_THRESHOLD)
lines_dict = coords_to_dict(line_coords, threshold=LINE_THRESHOLD)
kp_dict, lines_dict = complete_keypoints(kp_dict[0], lines_dict[0], w=w, h=h, normalize=True)
# Calibration
cam = FramebyFrameCalib(iwidth=frame_width, iheight=frame_height, denormalize=True)
cam.update(kp_dict, lines_dict)
final_params_dict = cam.heuristic_voting(refine_lines=PNL_REFINE)
return final_params_dict
# Modèles Pydantic pour la validation des données
class Point(BaseModel):
x: float
y: float
class LinePolygon(BaseModel):
points: List[Point]
class CalibrationRequest(BaseModel):
lines: Dict[str, List[Point]]
class CalibrationResponse(BaseModel):
status: str
camera_parameters: Dict[str, Any]
input_lines: Dict[str, List[Point]]
message: str
class InferenceImageResponse(BaseModel):
status: str
camera_parameters: Optional[Dict[str, Any]]
image_info: Dict[str, Any]
message: str
class InferenceVideoResponse(BaseModel):
status: str
camera_parameters: List[Dict[str, Any]]
video_info: Dict[str, Any]
frames_processed: int
message: str
@app.get("/")
async def root():
return {
"message": "Football Vision Calibration API",
"version": "1.0.0",
"endpoints": {
"/calibrate": "POST - Calibrer une caméra à partir d'une image et de lignes",
"/inference/image": "POST - Extraire les paramètres de caméra d'une image automatiquement",
"/inference/video": "POST - Extraire les paramètres de caméra d'une vidéo automatiquement",
"/health": "GET - Vérifier l'état de l'API"
}
}
@app.get("/health")
async def health_check():
return {"status": "healthy", "message": "API is running"}
@app.post("/calibrate", response_model=CalibrationResponse)
async def calibrate_camera(
image: UploadFile = File(..., description="Image du terrain de football"),
lines_data: str = Form(..., description="JSON des lignes du terrain")
):
"""
Calibrer une caméra à partir d'une image et des lignes du terrain.
Args:
image: Image du terrain de football (formats: jpg, jpeg, png)
lines_data: JSON contenant les lignes du terrain au format:
{"nom_ligne": [{"x": float, "y": float}, ...], ...}
Returns:
Paramètres de calibration de la caméra et lignes d'entrée
"""
try:
# Validation du format d'image - version robuste
content_type = getattr(image, 'content_type', None) or ""
filename = getattr(image, 'filename', "") or ""
# Vérifier le type MIME ou l'extension du fichier
image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif']
is_image_content = content_type.startswith('image/') if content_type else False
is_image_extension = any(filename.lower().endswith(ext) for ext in image_extensions)
if not is_image_content and not is_image_extension:
raise HTTPException(
status_code=400,
detail=f"Le fichier doit être une image. Type détecté: {content_type}, Fichier: {filename}"
)
# Parse des données de lignes
try:
lines_dict = json.loads(lines_data)
except json.JSONDecodeError:
raise HTTPException(status_code=400, detail="Format JSON invalide pour les lignes")
# Validation de la structure des lignes
validated_lines = {}
for line_name, points in lines_dict.items():
if not isinstance(points, list):
raise HTTPException(
status_code=400,
detail=f"Les points de la ligne '{line_name}' doivent être une liste"
)
validated_points = []
for i, point in enumerate(points):
if not isinstance(point, dict) or 'x' not in point or 'y' not in point:
raise HTTPException(
status_code=400,
detail=f"Point {i} de la ligne '{line_name}' doit avoir les clés 'x' et 'y'"
)
try:
validated_points.append({
"x": float(point['x']),
"y": float(point['y'])
})
except (ValueError, TypeError):
raise HTTPException(
status_code=400,
detail=f"Coordonnées invalides pour le point {i} de la ligne '{line_name}'"
)
validated_lines[line_name] = validated_points
# Sauvegarde temporaire de l'image
file_extension = os.path.splitext(filename)[1] if filename else '.jpg'
with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as temp_file:
content = await image.read()
temp_file.write(content)
temp_image_path = temp_file.name
try:
# Validation de l'image
pil_image = Image.open(temp_image_path)
pil_image.verify() # Vérification de l'intégrité de l'image
# Calibration de la caméra
camera_params = get_camera_parameters(temp_image_path, validated_lines)
# Formatage de la réponse
response = CalibrationResponse(
status="success",
camera_parameters=camera_params,
input_lines=validated_lines,
message="Calibration réussie"
)
return response
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Erreur lors de la calibration: {str(e)}"
)
finally:
# Nettoyage du fichier temporaire
if os.path.exists(temp_image_path):
os.unlink(temp_image_path)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"Erreur interne: {str(e)}")
@app.post("/inference/image", response_model=InferenceImageResponse)
async def inference_image(
image: UploadFile = File(..., description="Image du terrain de football"),
kp_threshold: float = Form(KP_THRESHOLD, description="Seuil pour les keypoints"),
line_threshold: float = Form(LINE_THRESHOLD, description="Seuil pour les lignes")
):
"""
Extraire automatiquement les paramètres de caméra à partir d'une image.
Args:
image: Image du terrain de football (formats: jpg, jpeg, png)
kp_threshold: Seuil pour la détection des keypoints (défaut: 0.15)
line_threshold: Seuil pour la détection des lignes (défaut: 0.15)
Returns:
Paramètres de calibration de la caméra extraits automatiquement
"""
try:
# Validation du format d'image - version robuste
content_type = getattr(image, 'content_type', None) or ""
filename = getattr(image, 'filename', "") or ""
# Vérifier le type MIME ou l'extension du fichier
image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif']
is_image_content = content_type.startswith('image/') if content_type else False
is_image_extension = any(filename.lower().endswith(ext) for ext in image_extensions)
if not is_image_content and not is_image_extension:
raise HTTPException(
status_code=400,
detail=f"Le fichier doit être une image. Type détecté: {content_type}, Fichier: {filename}"
)
# Sauvegarde temporaire de l'image
file_extension = os.path.splitext(filename)[1] if filename else '.jpg'
with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as temp_file:
content = await image.read()
temp_file.write(content)
temp_image_path = temp_file.name
try:
# Charger les modèles
model, model_l, device = load_inference_models()
# Lire l'image
frame = cv2.imread(temp_image_path)
if frame is None:
raise HTTPException(status_code=400, detail="Impossible de lire l'image")
frame_height, frame_width = frame.shape[:2]
# Mettre à jour les seuils globaux
global KP_THRESHOLD, LINE_THRESHOLD
KP_THRESHOLD = kp_threshold
LINE_THRESHOLD = line_threshold
# Traitement
params = process_frame_inference(frame, model, model_l, device, frame_width, frame_height)
# Formatage de la réponse
response = InferenceImageResponse(
status="success" if params is not None else "failed",
camera_parameters=params,
image_info={
"filename": filename,
"width": frame_width,
"height": frame_height,
"kp_threshold": kp_threshold,
"line_threshold": line_threshold
},
message="Paramètres extraits avec succès" if params is not None else "Échec de l'extraction des paramètres"
)
return response
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Erreur lors de l'inférence: {str(e)} \n params:\n{params}"
)
finally:
# Nettoyage du fichier temporaire
if os.path.exists(temp_image_path):
os.unlink(temp_image_path)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"Erreur interne: {str(e)}")
@app.post("/inference/video", response_model=InferenceVideoResponse)
async def inference_video(
video: UploadFile = File(..., description="Vidéo du terrain de football"),
kp_threshold: float = Form(KP_THRESHOLD, description="Seuil pour les keypoints"),
line_threshold: float = Form(LINE_THRESHOLD, description="Seuil pour les lignes"),
frame_step: int = Form(FRAME_STEP, description="Traiter 1 frame sur N")
):
"""
Extraire automatiquement les paramètres de caméra à partir d'une vidéo.
Args:
video: Vidéo du terrain de football (formats: mp4, avi, mov, etc.)
kp_threshold: Seuil pour la détection des keypoints (défaut: 0.15)
line_threshold: Seuil pour la détection des lignes (défaut: 0.15)
frame_step: Traiter 1 frame sur N pour accélérer le traitement (défaut: 5)
Returns:
Liste des paramètres de calibration de la caméra pour chaque frame traitée
"""
try:
# Validation du format vidéo - version robuste
content_type = getattr(video, 'content_type', None) or ""
filename = getattr(video, 'filename', "") or ""
video_extensions = ['.mp4', '.avi', '.mov', '.mkv', '.wmv', '.flv']
is_video_content = content_type.startswith('video/') if content_type else False
is_video_extension = any(filename.lower().endswith(ext) for ext in video_extensions)
if not is_video_content and not is_video_extension:
raise HTTPException(
status_code=400,
detail=f"Le fichier doit être une vidéo. Type détecté: {content_type}, Fichier: {filename}"
)
# Sauvegarde temporaire de la vidéo
file_extension = os.path.splitext(filename)[1] if filename else '.mp4'
with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as temp_file:
content = await video.read()
temp_file.write(content)
temp_video_path = temp_file.name
try:
# Charger les modèles
model, model_l, device = load_inference_models()
# Ouvrir la vidéo
cap = cv2.VideoCapture(temp_video_path)
if not cap.isOpened():
raise HTTPException(status_code=400, detail="Impossible d'ouvrir la vidéo")
frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
fps = int(cap.get(cv2.CAP_PROP_FPS))
# Mettre à jour les seuils globaux
global KP_THRESHOLD, LINE_THRESHOLD
KP_THRESHOLD = kp_threshold
LINE_THRESHOLD = line_threshold
all_params = []
frame_count = 0
processed_count = 0
while cap.isOpened():
ret, frame = cap.read()
if not ret:
break
# Traiter seulement 1 frame sur frame_step
if frame_count % frame_step != 0:
frame_count += 1
continue
# Traitement
params = process_frame_inference(frame, model, model_l, device, frame_width, frame_height)
if params is not None:
params['frame_number'] = frame_count
params['timestamp_seconds'] = frame_count / fps
all_params.append(params)
processed_count += 1
frame_count += 1
cap.release()
# Formatage de la réponse
response = InferenceVideoResponse(
status="success" if all_params else "failed",
camera_parameters=all_params,
video_info={
"filename": filename,
"width": frame_width,
"height": frame_height,
"total_frames": total_frames,
"fps": fps,
"duration_seconds": total_frames / fps,
"kp_threshold": kp_threshold,
"line_threshold": line_threshold,
"frame_step": frame_step
},
frames_processed=processed_count,
message=f"Paramètres extraits de {processed_count} frames" if all_params else "Aucun paramètre extrait"
)
return response
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Erreur lors de l'inférence vidéo: {str(e)}"
)
finally:
# Nettoyage du fichier temporaire
if os.path.exists(temp_video_path):
os.unlink(temp_video_path)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"Erreur interne: {str(e)}")
# Point d'entrée pour Vercel
app_instance = app