Spaces:
Running
Running
admchiem
committed on
Commit
·
5121a37
1
Parent(s):
ef8bc9e
Initial commit with LFS-tracked model and assets
Browse files- app.py +16 -0
- class_centroids.pkl +3 -0
- inference.py +151 -0
- intent_classifier.ckpt +3 -0
- label_encoder.pkl +3 -0
- oos_detector.pkl +3 -0
- oos_threshold.pkl +3 -0
- requirements.txt +6 -0
app.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from inference import model # Your model, prepared in `inference.py`
|
3 |
+
|
4 |
+
def predict_intent(text):
    """Classify ``text`` and render the prediction as a four-line text report.

    Delegates to ``model.predict(text)`` and prints the ``intent``,
    ``is_oos``, ``confidence`` and ``oos_score`` entries of the mapping it
    returns, one per line. ``oos_score`` is shown with three decimals; the
    other values are rendered with their default string form.
    """
    prediction = model.predict(text)
    # Adjacent f-strings are joined at compile time into a single report string.
    return (
        f"Intent: {prediction['intent']}"
        f"\nIs OOS: {prediction['is_oos']}"
        f"\nConfidence: {prediction['confidence']}"
        f"\nOOS Score: {prediction['oos_score']:.3f}"
    )
|
7 |
+
|
8 |
+
# Two-line free-text input box shown to the user.
sentence_box = gr.Textbox(lines=2, placeholder="Enter a sentence...")

# Gradio UI: feeds the text box into predict_intent and shows the report as plain text.
iface = gr.Interface(
    fn=predict_intent,
    inputs=sentence_box,
    outputs="text",
    title="Intent & OOS Detector",
    description="Type a sentence and get its predicted intent or detect if it's out-of-scope.",
)

# Start serving the interface when this script is executed.
iface.launch()
|
class_centroids.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8f267dbd996b67e662479e83e2fb644af4b67a84db0cfe453962404479a1541f
|
3 |
+
size 242642
|
inference.py
ADDED
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import joblib
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
import pytorch_lightning as pl
|
7 |
+
from sklearn.metrics.pairwise import cosine_distances
|
8 |
+
from sentence_transformers import SentenceTransformer
|
9 |
+
|
10 |
+
# NOTE(review): this class requires `pytorch_lightning`, which does not appear in
# the requirements.txt shipped alongside this file — confirm it is installed.
class IntentClassifier(pl.LightningModule):
    """Small MLP that maps a sentence embedding to intent-class logits.

    Architecture: Linear -> BatchNorm1d -> ReLU -> Dropout(0.3) -> Linear.

    Args:
        input_dim: Size of the input embedding vector.
        hidden_dim: Width of the single hidden layer.
        output_dim: Number of intent classes (size of the logit vector).
        lr: Learning rate passed to Adam.
        weight_decay: L2 penalty passed to Adam.
    """

    def __init__(self, input_dim=384, hidden_dim=256, output_dim=150, lr=1e-3, weight_decay=1e-4):
        super().__init__()
        # Records all constructor arguments on self.hparams.
        self.save_hyperparameters()

        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.3)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Return raw (unnormalised) class logits for a batch of embeddings."""
        x = self.fc1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.dropout(x)
        return self.fc2(x)

    def training_step(self, batch, batch_idx):
        """Compute and log the cross-entropy loss for one ``(x, y)`` batch."""
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log ``val_loss`` and ``val_acc`` over the samples whose label is not -1."""
        x, y = batch
        logits = self(x)
        preds = torch.argmax(logits, dim=1)

        # Samples labelled -1 are excluded from both metrics
        # (presumably out-of-scope examples — confirm against the data pipeline).
        mask = y != -1

        if mask.sum() > 0:
            val_loss = F.cross_entropy(logits[mask], y[mask])
            val_acc = (preds[mask] == y[mask]).float().mean()
        else:
            # No usable label in this batch: log zeros so both metric keys still exist.
            val_loss = torch.tensor(0.0, device=self.device)
            val_acc = torch.tensor(0.0, device=self.device)

        self.log("val_loss", val_loss, prog_bar=True)
        self.log("val_acc", val_acc, prog_bar=True)

    def configure_optimizers(self):
        """Return an Adam optimizer built from the stored hyperparameters."""
        return torch.optim.Adam(
            self.parameters(),
            lr=self.hparams.lr,
            # Honour the constructor argument instead of a hard-coded 1e-4.
            weight_decay=self.hparams.weight_decay,
        )
|
61 |
+
|
62 |
+
class IntentClassifierWithOOS:
    """Intent prediction with out-of-scope (OOS) rejection.

    For each sentence: embed it, classify the embedding, derive three
    confidence features, score them with the OOS detector and compare the
    score to a threshold.

    Args:
        embedder: Object with ``encode(sentence)`` returning an embedding vector.
        classifier: Torch module mapping a ``(1, dim)`` tensor to class logits.
        oos_detector: Object with ``predict_proba(features)``; column 1 is
            read as the OOS probability.
        label_encoder: Object with ``inverse_transform([class_id])``.
        centroids_dict: Mapping ``{class_id: centroid vector}``.
        oos_threshold: Scores greater than or equal to this are flagged as OOS.
        device: Torch device the classifier runs on.
    """

    def __init__(self, embedder, classifier, oos_detector, label_encoder, centroids_dict, oos_threshold=0.5, device="cpu"):
        self.embedder = embedder
        # Inference only: switch off dropout / batch-norm updates, then move to the device.
        self.classifier = classifier.eval().to(device)
        self.oos_detector = oos_detector
        self.label_encoder = label_encoder
        self.centroids_dict = centroids_dict
        self.threshold = oos_threshold
        self.device = device

    def _compute_features(self, embedding, logits, predicted_class):
        """Return the OOS feature vector ``[entropy, msp, centroid_distance]``.

        Args:
            embedding: 1-D numpy embedding of the sentence.
            logits: 1-D tensor of class logits for the sentence.
            predicted_class: Key used to look up the class centroid.

        Returns:
            A numpy array of three values: softmax entropy, maximum softmax
            probability, and Euclidean distance from ``embedding`` to the
            predicted class centroid (``nan`` when no centroid is stored).
        """
        probs = F.softmax(logits, dim=0).cpu().numpy()
        # The small epsilon keeps log() finite when a probability is exactly 0.
        entropy = -np.sum(probs * np.log(probs + 1e-10))
        msp = np.max(probs)

        centroid = self.centroids_dict.get(predicted_class)
        dist = np.linalg.norm(embedding - centroid) if centroid is not None else np.nan

        return np.array([entropy, msp, dist])

    def predict(self, sentence):
        """Predict the intent of ``sentence`` or flag it as out-of-scope.

        Returns:
            A dict with keys ``intent`` (label string, or ``"oos"``),
            ``is_oos`` (bool), ``confidence`` (softmax probability of the
            predicted class, or ``None`` when OOS) and ``oos_score`` (float).
        """
        # 1. Embedding
        embedding = np.array(self.embedder.encode(sentence))
        embedding_tensor = torch.tensor(embedding, dtype=torch.float32).unsqueeze(0).to(self.device)

        # 2. Intent prediction (MLP)
        with torch.no_grad():
            logits = self.classifier(embedding_tensor).squeeze(0)
            probs = F.softmax(logits, dim=0)
            predicted_class = torch.argmax(probs).item()
            confidence = probs[predicted_class].item()

        # 3. Feature extraction (the detector expects a 2-D array, one row per sample)
        features = self._compute_features(embedding, logits, predicted_class).reshape(1, -1)

        # 4. OOS detection; cast to a built-in float so the result holds only plain Python types
        oos_score = float(self.oos_detector.predict_proba(features)[0, 1])
        is_oos = bool(oos_score >= self.threshold)

        # 5. Output
        return {
            "intent": "oos" if is_oos else self.label_encoder.inverse_transform([predicted_class])[0],
            "is_oos": is_oos,
            "confidence": None if is_oos else confidence,
            "oos_score": oos_score,
        }
|
122 |
+
|
123 |
+
|
124 |
+
# Load all saved components from the current directory
best_model = IntentClassifier.load_from_checkpoint(
    "intent_classifier.ckpt",
    map_location=torch.device("cpu"),
)

oos_detector = joblib.load("oos_detector.pkl")
label_encoder = joblib.load("label_encoder.pkl")
class_centroids = joblib.load("class_centroids.pkl")
best_threshold = joblib.load("oos_threshold.pkl")

# Reload the sentence-embedding model
embedder = SentenceTransformer("intfloat/e5-small-v2")

# Build the full inference model
model = IntentClassifierWithOOS(
    embedder=embedder,
    classifier=best_model,
    oos_detector=oos_detector,
    label_encoder=label_encoder,
    centroids_dict=class_centroids,
    oos_threshold=best_threshold,
    device="cpu",
)

if __name__ == "__main__":
    # Smoke test with a sample query. Guarded so that importing this module
    # (e.g. `from inference import model`) does not run an extra prediction
    # and print to stdout on every start-up.
    result = model.predict("Can you play some jazz music?")
    print(result)
|
intent_classifier.ckpt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a35fa01002fc1cfba9e4b69cae1d2c8c8355836971c2fb35a7a6ce8153bbea20
|
3 |
+
size 1664127
|
label_encoder.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:65443e0fe2e184180333ea5498c8f9a9b6ee0e60dc4e990aeeeccef021d9bd42
|
3 |
+
size 15327
|
oos_detector.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:300780a220a6934f3266073b8e1de7a0a0fb93f2e097f66c6739f8c598eb6e68
|
3 |
+
size 1422
|
oos_threshold.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:48c52d53a5fff98b3a4a44a5d9f9bb903d63f5149ec727aba8efe6fb2f2276c1
|
3 |
+
size 21
|
requirements.txt
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch
|
2 |
+
sentence-transformers
|
3 |
+
scikit-learn
|
4 |
+
numpy
|
5 |
+
joblib
|
6 |
+
gradio
|