File size: 3,888 Bytes
b36a746
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
# main.py

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer
from typing import Dict, Any

# --- Configuration ---
# Chemin vers votre modèle fine-tuné. Assurez-vous que ce dossier est correct.
MODEL_PATH = "./mon_modele_darija_final"

# --- Chargement du modèle (partie critique) ---
# Cette partie est exécutée une seule fois, au démarrage du serveur.
# C'est une bonne pratique pour éviter de recharger le modèle à chaque requête.
try:
    print("Chargement du tokenizer et du modèle MARBERT fine-tuné...")
    
    # On spécifie le device (GPU si disponible, sinon CPU)
    device = 0 if torch.cuda.is_available() else -1
    
    # Création du pipeline de classification de texte de Hugging Face.
    # C'est la manière la plus simple d'utiliser un modèle pour l'inférence.
    nlu_pipeline = pipeline(
        "text-classification",
        model=MODEL_PATH,
        tokenizer=MODEL_PATH,
        device=device  # Utilise le GPU si disponible
    )
    print("Modèle chargé avec succès !")

except Exception as e:
    # Si le modèle ne peut pas être chargé, on lève une erreur claire.
    print(f"Erreur critique lors du chargement du modèle: {e}")
    nlu_pipeline = None

# --- Définition de l'application FastAPI ---
app = FastAPI(
    title="API de NLU pour Darija Marocaine",
    description="Une API pour classifier l'intention d'un texte en Darija, basée sur MARBERT.",
    version="1.0.0"
)

# --- Définition des modèles de données (Pydantic) ---
# C'est pour la validation automatique des requêtes.

class TextInput(BaseModel):
    """Modèle pour le corps de la requête de prédiction."""
    text: str # Le champ doit s'appeler 'text'
    # Exemple de requête JSON attendue: {"text": "3afak bghit nchouf lfactura"}

class PredictionResponse(BaseModel):
    """Modèle pour la réponse de l'API."""
    intent: str
    confidence: float

# --- Définition des routes de l'API ---

@app.get("/", tags=["Général"])
def read_root() -> Dict[str, str]:
    """Route principale qui retourne un message de bienvenue."""
    return {"message": "Bienvenue sur l'API de NLU Darija. Utilisez le endpoint /predict pour faire une prédiction."}


@app.get("/health", tags=["Général"])
def health_check() -> Dict[str, str]:
    """Route de 'health check' pour vérifier si le service est en ligne et le modèle chargé."""
    if nlu_pipeline is None:
        raise HTTPException(status_code=500, detail="Erreur: Le modèle NLP n'a pas pu être chargé.")
    return {"status": "ok", "model_status": "loaded"}


@app.post("/predict", response_model=PredictionResponse, tags=["Prédiction"])
def predict_intent(request: TextInput) -> PredictionResponse:
    """
    Endpoint principal pour la prédiction d'intention.
    Prend un texte en entrée et retourne l'intention prédite et son score de confiance.
    """
    if nlu_pipeline is None:
        raise HTTPException(status_code=503, detail="Le service est indisponible car le modèle n'est pas chargé.")

    if not request.text or not request.text.strip():
        raise HTTPException(status_code=400, detail="Le champ 'text' ne peut pas être vide.")

    try:
        # Utilisation du pipeline pour faire la prédiction
        prediction = nlu_pipeline(request.text, top_k=1)[0]
        
        # Le pipeline retourne un dictionnaire avec 'label' et 'score'
        # On renomme pour correspondre à notre modèle de réponse
        intent = prediction['label']
        confidence = prediction['score']

        return PredictionResponse(intent=intent, confidence=confidence)

    except Exception as e:
        # Gestion d'erreurs inattendues pendant la prédiction
        raise HTTPException(status_code=500, detail=f"Une erreur interne est survenue: {str(e)}")