import numpy as np
import joblib
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from sklearn.metrics.pairwise import cosine_distances
from sentence_transformers import SentenceTransformer


class IntentClassifier(pl.LightningModule):
    """Single-hidden-layer MLP mapping sentence embeddings to intent logits.

    Architecture: Linear -> BatchNorm1d -> ReLU -> Dropout(0.3) -> Linear.

    Args:
        input_dim: Size of each input embedding vector.
        hidden_dim: Width of the hidden layer.
        output_dim: Number of intent classes (length of the logit vector).
        lr: Adam learning rate.
        weight_decay: Adam weight decay (L2 penalty).
    """

    def __init__(self, input_dim=384, hidden_dim=256, output_dim=150,
                 lr=1e-3, weight_decay=1e-4):
        super().__init__()
        # Records every constructor argument in self.hparams (and in checkpoints).
        self.save_hyperparameters()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.3)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Return raw (unnormalized) logits for a batch of embeddings.

        x is presumably (batch, input_dim); predict() below passes (1, input_dim).
        """
        x = self.fc1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.dropout(x)
        return self.fc2(x)

    def training_step(self, batch, batch_idx):
        """Compute and log the cross-entropy loss for one training batch.

        NOTE(review): unlike validation_step, labels equal to -1 are not masked
        here; F.cross_entropy rejects such targets, so this presumably assumes
        training batches contain in-scope samples only - confirm.
        """
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log loss and accuracy on the in-scope samples of a validation batch.

        Samples labelled -1 are excluded from both metrics. A batch with no
        in-scope samples logs 0.0 for both metrics.
        NOTE(review): those 0.0 values still enter the epoch average.
        """
        x, y = batch
        logits = self(x)
        preds = torch.argmax(logits, dim=1)
        mask = y != -1
        if mask.any():
            val_loss = F.cross_entropy(logits[mask], y[mask])
            val_acc = (preds[mask] == y[mask]).float().mean()
        else:
            val_loss = torch.tensor(0.0, device=self.device)
            val_acc = torch.tensor(0.0, device=self.device)
        self.log("val_loss", val_loss, prog_bar=True)
        self.log("val_acc", val_acc, prog_bar=True)

    def configure_optimizers(self):
        """Adam configured from the saved lr and weight_decay hyperparameters."""
        # Read weight_decay from hparams so the constructor argument is honoured
        # (it was previously hard-coded to 1e-4 and silently ignored).
        return torch.optim.Adam(
            self.parameters(),
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay,
        )


class IntentClassifierWithOOS:
    """Inference wrapper: embed a sentence, classify it, then gate it as OOS.

    Args:
        embedder: Sentence embedding model exposing ``encode(sentence)``
            (a SentenceTransformer in this file).
        classifier: Trained intent MLP; switched to eval mode and moved to
            ``device``.
        oos_detector: Fitted model exposing ``predict_proba``; it receives
            the 3-feature vector built by ``_compute_features``.
        label_encoder: Fitted encoder exposing ``inverse_transform`` that maps
            class ids back to intent labels.
        centroids_dict: Mapping ``{class_id: centroid}`` used for the
            distance feature.
        oos_threshold: An input is flagged OOS when the detector score is
            ``>=`` this value.
        device: Torch device on which the classifier runs.
    """

    def __init__(self, embedder, classifier, oos_detector, label_encoder,
                 centroids_dict, oos_threshold=0.5, device="cpu"):
        self.embedder = embedder  # SentenceTransformer
        self.classifier = classifier.eval().to(device)  # MLP
        self.oos_detector = oos_detector  # sklearn pipeline
        self.label_encoder = label_encoder  # fitted LabelEncoder
        self.centroids_dict = centroids_dict  # {class_id: centroid}
        self.threshold = oos_threshold
        self.device = device

    def _compute_features(self, embedding, logits, predicted_class):
        """Build the feature vector passed to the OOS detector.

        Args:
            embedding: Sentence embedding as a NumPy array.
            logits: 1-D logit tensor for that sentence.
            predicted_class: Integer class id chosen by the classifier.

        Returns:
            np.ndarray ``[entropy, max_softmax_prob, centroid_distance]``.
            The order presumably matches the features the detector was fitted
            on, so it must not change. ``centroid_distance`` is NaN when no
            centroid is stored for ``predicted_class``.
        """
        probs = F.softmax(logits, dim=0).cpu().numpy()
        # Shannon entropy of the predictive distribution; epsilon avoids log(0).
        entropy = -np.sum(probs * np.log(probs + 1e-10))
        # Maximum softmax probability (MSP).
        msp = np.max(probs)
        # Euclidean distance from the embedding to the predicted class centroid.
        centroid = self.centroids_dict.get(predicted_class)
        dist = np.linalg.norm(embedding - centroid) if centroid is not None else np.nan
        return np.array([entropy, msp, dist])

    def predict(self, sentence):
        """Classify one sentence and decide whether it is out-of-scope.

        Args:
            sentence: Input text.

        Returns:
            dict with keys:
              "intent": decoded intent label, or "oos" when flagged;
              "is_oos": bool;
              "confidence": max softmax probability of the MLP, or None if OOS;
              "oos_score": float, column 1 of the detector's predict_proba
                  (presumably the OOS class).
        """
        # 1. Embedding
        # NOTE(review): e5 models are usually fed "query: "-prefixed text;
        # confirm this matches how the training embeddings were produced.
        embedding = np.asarray(self.embedder.encode(sentence))
        # Add a batch dimension: the MLP's BatchNorm1d expects 2-D input.
        embedding_tensor = (
            torch.tensor(embedding, dtype=torch.float32).unsqueeze(0).to(self.device)
        )

        # 2. Intent prediction (MLP)
        with torch.no_grad():
            logits = self.classifier(embedding_tensor).squeeze(0)
        probs = F.softmax(logits, dim=0)
        predicted_class = torch.argmax(probs).item()
        confidence = probs[predicted_class].item()

        # 3. Feature extraction (single-row 2-D array, as predict_proba expects)
        features = self._compute_features(embedding, logits, predicted_class).reshape(1, -1)

        # 4. OOS detection
        oos_score = float(self.oos_detector.predict_proba(features)[0, 1])
        is_oos = oos_score >= self.threshold

        # 5. Output
        return {
            "intent": "oos" if is_oos else self.label_encoder.inverse_transform([predicted_class])[0],
            "is_oos": bool(is_oos),
            "confidence": None if is_oos else confidence,
            "oos_score": oos_score,
        }


# Load all saved components from the current directory.
# NOTE: joblib.load unpickles arbitrary objects - only load trusted artifacts.
best_model = IntentClassifier.load_from_checkpoint(
    "intent_classifier.ckpt", map_location=torch.device("cpu")
)
oos_detector = joblib.load("oos_detector.pkl")
label_encoder = joblib.load("label_encoder.pkl")
class_centroids = joblib.load("class_centroids.pkl")
best_threshold = joblib.load("oos_threshold.pkl")

print("Model charging")
# Reload the sentence-embedding model.
embedder = SentenceTransformer("intfloat/e5-small-v2")

# Build the full inference model.
model = IntentClassifierWithOOS(
    embedder=embedder,
    classifier=best_model,
    oos_detector=oos_detector,
    label_encoder=label_encoder,
    centroids_dict=class_centroids,
    oos_threshold=best_threshold,
    device="cpu",
)