intent-classifier / inference.py
admchiem
Ajout du print pour vérifier que le modèle ne se recharge qu'une fois
fd64fb4
import numpy as np
import joblib
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from sklearn.metrics.pairwise import cosine_distances
from sentence_transformers import SentenceTransformer
class IntentClassifier(pl.LightningModule):
    """Two-layer MLP that maps sentence embeddings to intent logits.

    Architecture: Linear -> BatchNorm1d -> ReLU -> Dropout(0.3) -> Linear.
    All constructor arguments are recorded with ``save_hyperparameters`` so
    ``load_from_checkpoint`` can rebuild the module and ``configure_optimizers``
    can read ``lr`` / ``weight_decay`` from ``self.hparams``.
    """

    def __init__(self, input_dim=384, hidden_dim=256, output_dim=150, lr=1e-3, weight_decay=1e-4):
        super().__init__()
        self.save_hyperparameters()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.3)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Return raw (unnormalized) logits of shape (batch, output_dim).

        Assumes ``x`` is a 2-D (batch, input_dim) tensor, which is what the
        BatchNorm1d over ``hidden_dim`` features expects -- confirm for callers.
        """
        x = self.fc1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.dropout(x)
        return self.fc2(x)

    def training_step(self, batch, batch_idx):
        """Compute and log the cross-entropy loss for one training batch."""
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation loss/accuracy, ignoring targets labelled -1.

        Rows whose target is -1 (presumably out-of-scope samples -- confirm)
        are excluded. If a batch contains only such rows, both metrics are
        logged as 0.0.
        """
        x, y = batch
        logits = self(x)
        preds = torch.argmax(logits, dim=1)
        mask = y != -1
        if mask.any():
            val_loss = F.cross_entropy(logits[mask], y[mask])
            val_acc = (preds[mask] == y[mask]).float().mean()
        else:
            val_loss = torch.tensor(0.0, device=self.device)
            val_acc = torch.tensor(0.0, device=self.device)
        self.log("val_loss", val_loss, prog_bar=True)
        self.log("val_acc", val_acc, prog_bar=True)

    def configure_optimizers(self):
        """Build an Adam optimizer from the saved hyperparameters."""
        # Read weight_decay from hparams so the constructor argument actually
        # takes effect instead of being shadowed by a hard-coded constant.
        return torch.optim.Adam(
            self.parameters(),
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay,
        )
class IntentClassifierWithOOS:
    """End-to-end intent classifier with out-of-scope (OOS) rejection.

    Pipeline: sentence -> embedding -> MLP logits -> hand-crafted features
    -> OOS detector probability -> either the decoded intent label or "oos".
    """

    def __init__(self, embedder, classifier, oos_detector, label_encoder, centroids_dict, oos_threshold=0.5, device="cpu"):
        self.embedder = embedder  # object exposing .encode(sentence), e.g. SentenceTransformer
        self.classifier = classifier.eval().to(device)  # MLP returning logits; eval mode disables dropout/BN updates
        self.oos_detector = oos_detector  # exposes predict_proba(X); column 1 is read as the OOS probability
        self.label_encoder = label_encoder  # fitted LabelEncoder-like object (inverse_transform)
        self.centroids_dict = centroids_dict  # {class_id: centroid embedding}
        self.threshold = oos_threshold  # OOS probability at/above which a sentence is rejected
        self.device = device

    def _compute_features(self, embedding, logits, predicted_class):
        """Build the feature vector fed to the OOS detector.

        Args:
            embedding: sentence embedding as a NumPy array.
            logits: 1-D tensor of classifier logits for this sentence.
            predicted_class: integer index of the arg-max class.

        Returns:
            ``np.ndarray`` ``[entropy, max_softmax_prob, centroid_distance]``.
            This order presumably must match the features the detector was
            fitted on -- verify before changing it. The distance is ``NaN`` when
            ``predicted_class`` has no centroid; the detector must tolerate that.
        """
        probs = F.softmax(logits, dim=0).cpu().numpy()
        # Shannon entropy of the predictive distribution; epsilon avoids log(0).
        entropy = -np.sum(probs * np.log(probs + 1e-10))
        # Maximum softmax probability (MSP).
        msp = np.max(probs)
        # Euclidean distance from the embedding to the predicted class centroid.
        centroid = self.centroids_dict.get(predicted_class)
        dist = np.linalg.norm(embedding - centroid) if centroid is not None else np.nan
        return np.array([entropy, msp, dist])

    def predict(self, sentence):
        """Classify one sentence, replacing the intent with "oos" when rejected.

        Returns:
            dict with keys:
              - "intent": decoded label, or "oos" when rejected;
              - "is_oos": Python bool;
              - "confidence": softmax probability of the predicted class as a
                Python float, or None when rejected;
              - "oos_score": detector OOS probability as a Python float.
        """
        # 1. Embedding
        embedding = np.array(self.embedder.encode(sentence))
        embedding_tensor = torch.tensor(embedding, dtype=torch.float32).unsqueeze(0).to(self.device)

        # 2. Intent prediction (MLP); batch of one, so drop the batch axis.
        with torch.no_grad():
            logits = self.classifier(embedding_tensor).squeeze(0)
        probs = F.softmax(logits, dim=0)
        predicted_class = torch.argmax(probs).item()
        confidence = probs[predicted_class].item()

        # 3. Feature extraction (detector expects a 2-D (1, n_features) input).
        features = self._compute_features(embedding, logits, predicted_class).reshape(1, -1)

        # 4. OOS detection; cast to a plain float so the result is a
        # serialisation-friendly Python scalar rather than a NumPy scalar.
        oos_score = float(self.oos_detector.predict_proba(features)[0, 1])
        is_oos = oos_score >= self.threshold

        # 5. Output
        return {
            "intent": "oos" if is_oos else self.label_encoder.inverse_transform([predicted_class])[0],
            "is_oos": bool(is_oos),
            "confidence": None if is_oos else confidence,
            "oos_score": oos_score,
        }
# Load every saved artifact at import time. The paths are relative, so they
# resolve against the process's current working directory, not this file's.
# NOTE(review): joblib.load unpickles its input -- only load trusted artifacts.
best_model = IntentClassifier.load_from_checkpoint(
    "intent_classifier.ckpt", map_location=torch.device("cpu")
)
oos_detector, label_encoder, class_centroids, best_threshold = (
    joblib.load(artifact_path)
    for artifact_path in (
        "oos_detector.pkl",
        "label_encoder.pkl",
        "class_centroids.pkl",
        "oos_threshold.pkl",
    )
)
# Emitted once per import; added to confirm the model is only loaded once.
print("Model charging")

# Reload the sentence-embedding model used to encode incoming queries.
embedder = SentenceTransformer("intfloat/e5-small-v2")

# Assemble the full inference pipeline (all components run on CPU).
model = IntentClassifierWithOOS(
    embedder,
    best_model,
    oos_detector,
    label_encoder,
    class_centroids,
    oos_threshold=best_threshold,
    device="cpu",
)