admchiem committed on
Commit
5121a37
·
1 Parent(s): ef8bc9e

Initial commit with LFS-tracked model and assets

Browse files
app.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from inference import model # Your model, prepared in `inference.py`
3
+
4
def predict_intent(text):
    """Classify ``text`` with the shared model and format the result for display.

    Calls ``model.predict(text)`` and renders the returned mapping as four
    newline-separated lines: intent, out-of-scope flag, confidence, and the
    out-of-scope score rounded to three decimals.
    """
    prediction = model.predict(text)
    # One entry per displayed line; joined below so the layout is easy to scan.
    report_lines = (
        f"Intent: {prediction['intent']}",
        f"Is OOS: {prediction['is_oos']}",
        f"Confidence: {prediction['confidence']}",
        f"OOS Score: {prediction['oos_score']:.3f}",
    )
    return "\n".join(report_lines)
7
+
8
# Two-line text input shown to the user.
sentence_box = gr.Textbox(lines=2, placeholder="Enter a sentence...")

# Gradio UI: one text input, one text output, backed by predict_intent.
iface = gr.Interface(
    fn=predict_intent,
    inputs=sentence_box,
    outputs="text",
    title="Intent & OOS Detector",
    description="Type a sentence and get its predicted intent or detect if it's out-of-scope.",
)

iface.launch()
class_centroids.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8f267dbd996b67e662479e83e2fb644af4b67a84db0cfe453962404479a1541f
3
+ size 242642
inference.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import joblib
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import pytorch_lightning as pl
7
+ from sklearn.metrics.pairwise import cosine_distances
8
+ from sentence_transformers import SentenceTransformer
9
+
10
class IntentClassifier(pl.LightningModule):
    """Small MLP that maps an embedding vector to intent-class logits.

    Layers, in order: Linear -> BatchNorm1d -> ReLU -> Dropout(0.3) -> Linear.

    Args:
        input_dim: Size of each input vector.
        hidden_dim: Width of the hidden layer.
        output_dim: Number of output logits (one per class).
        lr: Learning rate given to Adam.
        weight_decay: Weight decay given to Adam.
    """

    def __init__(self, input_dim=384, hidden_dim=256, output_dim=150, lr=1e-3, weight_decay=1e-4):
        super().__init__()
        # Makes the constructor arguments available as ``self.hparams``.
        self.save_hyperparameters()

        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.3)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Return raw (unnormalized) logits for a batch of input vectors."""
        x = self.fc1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.dropout(x)
        return self.fc2(x)

    def training_step(self, batch, batch_idx):
        """Compute and log the cross-entropy loss for one training batch."""
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation loss and accuracy, ignoring samples labelled ``-1``."""
        x, y = batch
        logits = self(x)
        preds = torch.argmax(logits, dim=1)

        # Samples labelled -1 are excluded from both metrics
        # (presumably out-of-scope examples with no class - TODO confirm).
        mask = y != -1

        if mask.sum() > 0:
            val_loss = F.cross_entropy(logits[mask], y[mask])
            val_acc = (preds[mask] == y[mask]).float().mean()
        else:
            # NOTE(review): a batch with no labelled sample still logs 0.0 for
            # both metrics, which pulls the aggregated values towards zero.
            val_loss = torch.tensor(0.0, device=self.device)
            val_acc = torch.tensor(0.0, device=self.device)

        self.log("val_loss", val_loss, prog_bar=True)
        self.log("val_acc", val_acc, prog_bar=True)

    def configure_optimizers(self):
        """Return an Adam optimizer built from the stored hyperparameters."""
        return torch.optim.Adam(
            self.parameters(),
            lr=self.hparams.lr,
            # Use the constructor's weight_decay instead of a hard-coded 1e-4,
            # so a non-default value is actually applied.
            weight_decay=self.hparams.weight_decay,
        )
61
+
62
class IntentClassifierWithOOS:
    """Intent classifier combined with an out-of-scope (OOS) detector.

    ``predict`` embeds a sentence, runs the classifier to get logits, derives
    three features from them (softmax entropy, maximum softmax probability,
    Euclidean distance to the predicted class centroid), and feeds those to
    ``oos_detector.predict_proba``. The sentence is flagged as OOS when the
    resulting score reaches ``oos_threshold``.

    Args:
        embedder: Object exposing ``encode(sentence)`` (a SentenceTransformer).
        classifier: Torch module mapping a batch of embeddings to logits.
        oos_detector: Object exposing ``predict_proba(features)``.
        label_encoder: Object exposing ``inverse_transform([class_id])``.
        centroids_dict: Mapping ``{class_id: centroid vector}``.
        oos_threshold: Minimum OOS score at which a sentence is flagged.
        device: Torch device the classifier and its inputs are moved to.
    """

    def __init__(self, embedder, classifier, oos_detector, label_encoder, centroids_dict, oos_threshold=0.5, device="cpu"):
        self.embedder = embedder
        # eval() so dropout/batch-norm run in inference mode.
        self.classifier = classifier.eval().to(device)
        self.oos_detector = oos_detector
        self.label_encoder = label_encoder
        self.centroids_dict = centroids_dict
        self.threshold = oos_threshold
        self.device = device

    def _compute_features(self, embedding, logits, predicted_class):
        """Build the OOS feature vector ``[entropy, msp, dist]``.

        Args:
            embedding: 1-D NumPy embedding of the sentence.
            logits: 1-D tensor of class logits for that sentence.
            predicted_class: Class id used to look up the centroid.

        Returns:
            NumPy array of three values: softmax entropy, maximum softmax
            probability, and Euclidean distance from ``embedding`` to the
            centroid of ``predicted_class`` (NaN when no centroid is stored).
        """
        probs = F.softmax(logits, dim=0).detach().cpu().numpy()
        # The epsilon keeps log() finite when a probability underflows to 0.
        entropy = -np.sum(probs * np.log(probs + 1e-10))
        msp = np.max(probs)

        centroid = self.centroids_dict.get(predicted_class)
        dist = np.linalg.norm(embedding - centroid) if centroid is not None else np.nan

        # Only these three features are consumed by the detector; the order
        # must stay as-is to match what predict() passes to predict_proba.
        return np.array([entropy, msp, dist])

    def predict(self, sentence):
        """Predict the intent of ``sentence`` and whether it is out of scope.

        Returns:
            Dict with keys ``intent`` (``"oos"`` when flagged, otherwise the
            decoded label), ``is_oos`` (bool), ``confidence`` (softmax
            probability of the predicted class, or ``None`` when flagged) and
            ``oos_score`` (float score from the OOS detector).
        """
        # 1. Embedding
        embedding = np.asarray(self.embedder.encode(sentence))
        # unsqueeze(0): the classifier is called on a batch of one.
        embedding_tensor = torch.tensor(embedding, dtype=torch.float32).unsqueeze(0).to(self.device)

        # 2. Intent prediction (MLP)
        with torch.no_grad():
            logits = self.classifier(embedding_tensor).squeeze(0)
            probs = F.softmax(logits, dim=0)
            predicted_class = torch.argmax(probs).item()
            confidence = probs[predicted_class].item()

        # 3. Feature extraction (single row for the detector)
        features = self._compute_features(embedding, logits, predicted_class).reshape(1, -1)

        # 4. OOS detection: column 1 of predict_proba is used as the OOS score.
        oos_score = float(self.oos_detector.predict_proba(features)[0, 1])
        is_oos = oos_score >= self.threshold

        # 5. Output
        return {
            "intent": "oos" if is_oos else self.label_encoder.inverse_transform([predicted_class])[0],
            "is_oos": bool(is_oos),
            "confidence": None if is_oos else confidence,
            "oos_score": oos_score,
        }
122
+
123
+
124
# Load all saved components. The paths are relative, so they resolve against
# the current working directory of the process.
# NOTE(review): joblib.load unpickles arbitrary Python objects; only load
# artifact files that come from a trusted source.
best_model = IntentClassifier.load_from_checkpoint(
    "intent_classifier.ckpt",
    map_location=torch.device("cpu"),
)

oos_detector = joblib.load("oos_detector.pkl")
label_encoder = joblib.load("label_encoder.pkl")
class_centroids = joblib.load("class_centroids.pkl")
best_threshold = joblib.load("oos_threshold.pkl")

# Reload the embedding model
embedder = SentenceTransformer("intfloat/e5-small-v2")

# Build the full inference model
model = IntentClassifierWithOOS(
    embedder=embedder,
    classifier=best_model,
    oos_detector=oos_detector,
    label_encoder=label_encoder,
    centroids_dict=class_centroids,
    oos_threshold=best_threshold,
    device="cpu",
)

if __name__ == "__main__":
    # Sample query, run only when this file is executed directly so that
    # importing ``model`` from another module does not trigger a prediction
    # and a print.
    result = model.predict("Can you play some jazz music?")
    print(result)
intent_classifier.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a35fa01002fc1cfba9e4b69cae1d2c8c8355836971c2fb35a7a6ce8153bbea20
3
+ size 1664127
label_encoder.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:65443e0fe2e184180333ea5498c8f9a9b6ee0e60dc4e990aeeeccef021d9bd42
3
+ size 15327
oos_detector.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:300780a220a6934f3266073b8e1de7a0a0fb93f2e097f66c6739f8c598eb6e68
3
+ size 1422
oos_threshold.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:48c52d53a5fff98b3a4a44a5d9f9bb903d63f5149ec727aba8efe6fb2f2276c1
3
+ size 21
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch
2
+ sentence-transformers
3
+ scikit-learn
4
+ numpy
5
+ joblib
6
+ gradio