"""FastAPI service: AI-generated-image detection with Grad-CAM explanations.

The module loads a Swin Transformer classifier at import time and exposes
three routes: a health check (``/``), a multipart upload endpoint
(``/predict``) and a JSON/base64 endpoint (``/predict-base64``).
"""

import base64
import json
from io import BytesIO

import numpy as np
import timm
import torch
import torchvision.transforms as transforms
import uvicorn
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from transformers import AutoFeatureExtractor

app = FastAPI(
    title="VerifAI GradCAM API",
    description="API pour la détection d'images IA avec GradCAM",
)

# CORS configuration.
# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# site issue credentialed requests. Left unchanged to preserve behavior;
# consider listing explicit origins or disabling credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Built once: tensor conversion followed by the standard ImageNet mean/std
# normalization applied to every request.
_TO_MODEL_INPUT = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])


def _swin_reshape_transform(tensor):
    """Rearrange Swin activations/gradients into (B, C, H, W) for Grad-CAM.

    A 3-D input is treated as (B, tokens, C) with a square token grid; a 4-D
    input is assumed to be channels-last (B, H, W, C) — TODO confirm against
    the installed timm version. Any other rank makes ``permute`` raise, which
    the Grad-CAM caller handles by falling back to the plain image.
    """
    if tensor.dim() == 3:
        batch, tokens, channels = tensor.shape
        side = int(round(tokens ** 0.5))
        tensor = tensor.reshape(batch, side, side, channels)
    return tensor.permute(0, 3, 1, 2)


class AIDetectionGradCAM:
    """Wrap the classifier(s) and produce predictions with Grad-CAM overlays."""

    def __init__(self):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # All registries are keyed by model type (currently only 'swin').
        self.models = {}
        self.feature_extractors = {}
        self.target_layers = {}
        self.reshape_transforms = {}

        # Load the models eagerly so the first request does not pay for it.
        self._load_models()

    def _load_models(self):
        """Load the detection models; failures are printed, not raised.

        A model is registered only after every loading step succeeded, so a
        partial failure never leaves a half-configured entry behind.
        """
        try:
            # Swin Transformer backbone with a 2-class head.
            # NOTE(review): with num_classes=2 the head presumably does not
            # come from the pretrained checkpoint — confirm fine-tuned weights
            # are loaded somewhere.
            model_name = "microsoft/swin-base-patch4-window7-224-in22k"
            model = timm.create_model(
                'swin_base_patch4_window7_224', pretrained=True, num_classes=2
            )
            extractor = AutoFeatureExtractor.from_pretrained(model_name)

            # Grad-CAM target: first norm of the last block of the last stage.
            target_layers = [model.layers[-1].blocks[-1].norm1]

            # Inference mode on the selected device.
            model.eval()
            model.to(self.device)
        except Exception as e:
            print(f"Erreur lors du chargement des modèles: {e}")
            return

        self.models['swin'] = model
        self.feature_extractors['swin'] = extractor
        self.target_layers['swin'] = target_layers
        self.reshape_transforms['swin'] = _swin_reshape_transform

    def _preprocess_image(self, image, model_type='swin'):
        """Turn an image into a model input tensor and a [0, 1] RGB array.

        Args:
            image: a PIL image, a ``data:image...`` base64 URI, or a file path.
            model_type: kept for interface compatibility; currently unused.

        Returns:
            ``(tensor, rgb)``: a normalized tensor with a leading batch
            dimension on ``self.device``, and the resized 224x224 image as a
            float array scaled to [0, 1].
        """
        if isinstance(image, str):
            if image.startswith('data:image'):
                # Data URI: strip the header and decode the base64 payload.
                _header, data = image.split(',', 1)
                image = Image.open(BytesIO(base64.b64decode(data)))
            else:
                # Anything else is treated as a filesystem path.
                image = Image.open(image)

        if image.mode != 'RGB':
            image = image.convert('RGB')

        image = image.resize((224, 224))

        tensor = _TO_MODEL_INPUT(image).unsqueeze(0).to(self.device)
        return tensor, np.array(image) / 255.0

    def _generate_gradcam(self, image_tensor, rgb_img, model_type='swin'):
        """Return the Grad-CAM heatmap blended over ``rgb_img`` as uint8 RGB.

        On any failure the error is printed and the un-annotated image is
        returned instead (best effort).
        """
        try:
            model = self.models[model_type]
            target_layers = self.target_layers[model_type]
            reshape = self.reshape_transforms.get(model_type)

            # The context manager releases the hooks GradCAM registers on the
            # shared model once the map is computed.
            with GradCAM(model=model, target_layers=target_layers,
                         reshape_transform=reshape) as cam:
                # targets=None explains the highest-scoring class.
                grayscale_cam = cam(input_tensor=image_tensor, targets=None)

            return show_cam_on_image(rgb_img, grayscale_cam[0, :], use_rgb=True)
        except Exception as e:
            print(f"Erreur GradCAM: {e}")
            # Round before casting so pixel values are not truncated by one.
            return np.round(rgb_img * 255).astype(np.uint8)

    def predict_and_explain(self, image):
        """Classify ``image`` and attach a Grad-CAM overlay.

        Returns:
            On success, a dict with ``prediction`` (class index),
            ``confidence``, ``class_probabilities``, ``cam_image`` (PNG data
            URI) and ``status == 'success'``. On failure, a dict with
            ``status == 'error'`` and a ``message``; this method does not
            raise.
        """
        if 'swin' not in self.models:
            return {'status': 'error', 'message': "Modèle 'swin' non chargé"}

        try:
            image_tensor, rgb_img = self._preprocess_image(image)

            # Forward pass only; gradients are not needed for the prediction.
            with torch.no_grad():
                outputs = self.models['swin'](image_tensor)
                probabilities = torch.nn.functional.softmax(outputs, dim=1)
                confidence = probabilities.max().item()
                prediction = probabilities.argmax().item()

            # Grad-CAM needs gradients, so it runs outside the no_grad block.
            cam_image = self._generate_gradcam(image_tensor, rgb_img)

            # Encode the overlay as a base64 PNG.
            buffer = BytesIO()
            Image.fromarray(cam_image.astype(np.uint8)).save(buffer, format='PNG')
            cam_base64 = base64.b64encode(buffer.getvalue()).decode()

            return {
                'prediction': prediction,
                'confidence': confidence,
                'class_probabilities': {
                    'Real': probabilities[0][0].item(),
                    'AI-Generated': probabilities[0][1].item(),
                },
                'cam_image': f"data:image/png;base64,{cam_base64}",
                'status': 'success',
            }
        except Exception as e:
            return {'status': 'error', 'message': str(e)}


# Module-level detector shared by all requests.
detector = AIDetectionGradCAM()


# Health check. Documented with a comment rather than a docstring so the
# generated OpenAPI entry for this route stays unchanged.
@app.get("/")
async def root():
    return {"message": "VerifAI GradCAM API", "status": "running"}


@app.post("/predict", description="Endpoint pour analyser une image")
async def predict_image(file: UploadFile = File(...)):
    """Analyze an uploaded image file.

    Raises:
        HTTPException: 400 when the upload cannot be opened as an image,
            500 on any other unexpected failure.
    """
    image_data = await file.read()
    try:
        image = Image.open(BytesIO(image_data))
    except (OSError, ValueError) as e:
        # Unreadable or unsupported image data is a client error.
        raise HTTPException(status_code=400, detail=f"Image invalide: {e}") from e

    try:
        return detector.predict_and_explain(image)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.post("/predict-base64",
          description="Endpoint pour analyser une image en base64")
async def predict_base64(data: dict):
    """Analyze an image supplied as a base64 data URI in ``data['image']``.

    Raises:
        HTTPException: 400 when ``image`` is missing or is not a
            ``data:image...`` URI, 500 on any other unexpected failure.
    """
    # Validation happens outside the catch-all below so that these 400
    # responses are not converted into 500s.
    if 'image' not in data:
        raise HTTPException(status_code=400, detail="Champ 'image' requis")

    image_b64 = data['image']
    # Only data URIs are accepted: any other string would be opened as a
    # local file path by the detector, which must not be client-controlled.
    if not isinstance(image_b64, str) or not image_b64.startswith('data:image'):
        raise HTTPException(
            status_code=400,
            detail="Le champ 'image' doit être une data URI base64 (data:image/...)",
        )

    try:
        return detector.predict_and_explain(image_b64)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)