Spaces:
Running
Running
import numpy as np | |
import joblib | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import pytorch_lightning as pl | |
from sklearn.metrics.pairwise import cosine_distances | |
from sentence_transformers import SentenceTransformer | |
class IntentClassifier(pl.LightningModule):
    """Single-hidden-layer MLP that maps an input vector to class logits.

    Layer order: Linear -> BatchNorm1d -> ReLU -> Dropout(0.3) -> Linear.

    Args:
        input_dim: Size of each input vector.
        hidden_dim: Width of the hidden layer.
        output_dim: Number of output logits (one per class).
        lr: Learning rate handed to Adam.
        weight_decay: Weight decay handed to Adam.
    """

    def __init__(self, input_dim=384, hidden_dim=256, output_dim=150, lr=1e-3, weight_decay=1e-4):
        super().__init__()
        # Makes the constructor arguments available on ``self.hparams``;
        # ``configure_optimizers`` reads ``lr`` and ``weight_decay`` from there.
        self.save_hyperparameters()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.3)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Return raw (un-normalised) logits for a batch of input vectors."""
        x = self.fc1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.dropout(x)
        return self.fc2(x)

    def training_step(self, batch, batch_idx):
        """Compute and log the cross-entropy loss for one ``(x, y)`` batch."""
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log loss and accuracy over the samples whose label is not ``-1``.

        Samples labelled ``-1`` are excluded from both metrics
        (presumably they mark out-of-scope examples -- confirm against the
        dataset code).
        """
        x, y = batch
        logits = self(x)
        preds = torch.argmax(logits, dim=1)
        mask = y != -1
        if mask.sum() > 0:
            val_loss = F.cross_entropy(logits[mask], y[mask])
            val_acc = (preds[mask] == y[mask]).float().mean()
        else:
            # NOTE(review): a batch without any valid label logs 0.0 for both
            # metrics, which still enters whatever aggregate is built from
            # the logged values.
            val_loss = torch.tensor(0.0, device=self.device)
            val_acc = torch.tensor(0.0, device=self.device)
        self.log("val_loss", val_loss, prog_bar=True)
        self.log("val_acc", val_acc, prog_bar=True)

    def configure_optimizers(self):
        """Build the Adam optimizer from the saved hyperparameters."""
        return torch.optim.Adam(
            self.parameters(),
            lr=self.hparams.lr,
            # Use the constructor argument instead of a hard-coded 1e-4 so a
            # caller-supplied ``weight_decay`` is actually honoured.
            weight_decay=self.hparams.weight_decay,
        )
class IntentClassifierWithOOS:
    """Inference wrapper that pairs an intent classifier with an out-of-scope (OOS) detector.

    Args:
        embedder: Object exposing ``encode(sentence)`` (a SentenceTransformer
            in this file).
        classifier: Torch module returning logits; it is switched to eval
            mode and moved to ``device``.
        oos_detector: Estimator exposing ``predict_proba``; column 1 of its
            output is read as the OOS score.
        label_encoder: Object exposing ``inverse_transform`` that maps class
            ids back to intent labels.
        centroids_dict: Mapping ``{class_id: centroid vector}``.
        oos_threshold: Scores greater than or equal to this value are
            reported as OOS.
        device: Torch device used for the classifier and its input.
    """

    def __init__(self, embedder, classifier, oos_detector, label_encoder, centroids_dict, oos_threshold=0.5, device="cpu"):
        self.embedder = embedder
        self.classifier = classifier.eval().to(device)
        self.oos_detector = oos_detector
        self.label_encoder = label_encoder
        self.centroids_dict = centroids_dict
        self.threshold = oos_threshold
        self.device = device

    def _compute_features(self, embedding, logits, predicted_class):
        """Return the detector features ``[entropy, msp, dist]`` for one sample.

        Args:
            embedding: NumPy embedding of the sentence.
            logits: 1-D tensor of classifier logits for the sentence.
            predicted_class: Class id used to look up the centroid.

        Returns:
            ``np.ndarray`` of length 3: softmax entropy, maximum softmax
            probability, and Euclidean distance from ``embedding`` to the
            centroid of ``predicted_class`` (``nan`` when that class has no
            centroid).
        """
        probs = F.softmax(logits, dim=0).cpu().numpy()
        # The small epsilon keeps log() finite when a probability is 0.
        entropy = -np.sum(probs * np.log(probs + 1e-10))
        msp = np.max(probs)

        centroid = self.centroids_dict.get(predicted_class)
        dist = np.linalg.norm(embedding - centroid) if centroid is not None else np.nan

        # Only these three values are passed on; their order must stay in
        # sync with whatever ``oos_detector`` expects.
        return np.array([entropy, msp, dist])

    def predict(self, sentence):
        """Classify ``sentence`` and decide whether it is out of scope.

        Returns:
            dict with keys ``intent`` (decoded label, or ``"oos"``),
            ``is_oos`` (bool), ``confidence`` (softmax probability of the
            predicted class, ``None`` when OOS) and ``oos_score`` (float).
        """
        # 1. Embedding
        embedding = np.array(self.embedder.encode(sentence))
        embedding_tensor = torch.tensor(embedding, dtype=torch.float32).unsqueeze(0).to(self.device)

        # 2. Intent prediction (MLP)
        with torch.no_grad():
            logits = self.classifier(embedding_tensor).squeeze(0)
            probs = F.softmax(logits, dim=0)
            predicted_class = torch.argmax(probs).item()
            confidence = probs[predicted_class].item()

        # 3. Feature extraction (single sample -> one row)
        features = self._compute_features(embedding, logits, predicted_class).reshape(1, -1)

        # 4. OOS detection; cast to a built-in float so the result dict holds
        #    plain Python values, like ``is_oos`` below.
        oos_score = float(self.oos_detector.predict_proba(features)[0, 1])
        is_oos = oos_score >= self.threshold

        # 5. Output
        return {
            "intent": "oos" if is_oos else self.label_encoder.inverse_transform([predicted_class])[0],
            "is_oos": bool(is_oos),
            "confidence": None if is_oos else confidence,
            "oos_score": oos_score,
        }
# Restore every persisted artifact from the current directory.
# ``map_location`` places the checkpoint's tensors on the CPU.
best_model = IntentClassifier.load_from_checkpoint(
    "intent_classifier.ckpt", map_location=torch.device("cpu")
)

# The joblib artifacts are read lazily, in the order listed, and bound to
# the matching names on the left-hand side.
oos_detector, label_encoder, class_centroids, best_threshold = (
    joblib.load(artifact)
    for artifact in (
        "oos_detector.pkl",
        "label_encoder.pkl",
        "class_centroids.pkl",
        "oos_threshold.pkl",
    )
)

# Reload the sentence-embedding model.
# NOTE(review): the e5 model family is documented to expect a "query: " /
# "passage: " prefix on its inputs -- confirm that inference here matches
# whatever convention was used when the classifier was trained.
embedder = SentenceTransformer("intfloat/e5-small-v2")

# Assemble the complete inference pipeline; everything runs on the CPU.
model = IntentClassifierWithOOS(
    embedder=embedder,
    classifier=best_model,
    oos_detector=oos_detector,
    label_encoder=label_encoder,
    centroids_dict=class_centroids,
    oos_threshold=best_threshold,
    device="cpu",
)

# Smoke test: run one sample query through the pipeline and show the output.
result = model.predict("Can you play some jazz music?")
print(result)