Merge pull request #2 from Tbruand/feat/few-shot-model
Browse filesfeat(handler): fusion du modèle few-shot et sélection dynamique dans l’interface
- Ajout d’un modèle few-shot basé sur roberta
- Support de la sélection dynamique du modèle (zero-shot / few-shot)
- Affichage formaté Markdown avec scores
- Couverture testée par `pytest`
- app/handler.py +16 -7
- app/interface.py +4 -1
- models/few_shot.py +14 -0
- tests/test_handler.py +11 -1
app/handler.py
CHANGED
@@ -1,10 +1,19 @@
|
|
1 |
from models.zero_shot import ZeroShotModel
|
|
|
2 |
|
3 |
-
|
|
|
4 |
|
5 |
-
def predict(text: str) -> str:
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from models.zero_shot import ZeroShotModel
|
2 |
+
from models.few_shot import FewShotModel
|
3 |
|
4 |
+
zero_shot_model = ZeroShotModel()
|
5 |
+
few_shot_model = FewShotModel()
|
6 |
|
7 |
+
def predict(text: str, model_type: str = "zero-shot") -> str:
|
8 |
+
if model_type == "few-shot":
|
9 |
+
results = few_shot_model.predict(text)
|
10 |
+
output = "### Résultat de la classification (Few-Shot) :\n\n"
|
11 |
+
for label, score in results:
|
12 |
+
output += f"- **{label}** : {score*100:.1f}%\n"
|
13 |
+
return output
|
14 |
+
else:
|
15 |
+
results = zero_shot_model.predict(text)
|
16 |
+
output = "### Résultat de la classification (Zero-Shot) :\n\n"
|
17 |
+
for label, score in results:
|
18 |
+
output += f"- **{label}** : {score*100:.1f}%\n"
|
19 |
+
return output
|
app/interface.py
CHANGED
@@ -4,7 +4,10 @@ from app.handler import predict
|
|
4 |
def launch_app():
|
5 |
iface = gr.Interface(
|
6 |
fn=predict,
|
7 |
-
inputs=
|
|
|
|
|
|
|
8 |
outputs="markdown",
|
9 |
title="🧪 ToxiCheck",
|
10 |
description="Entrez un texte pour détecter s'il est toxique. Résultat avec score de confiance pour chaque label."
|
|
|
4 |
def launch_app():
|
5 |
iface = gr.Interface(
|
6 |
fn=predict,
|
7 |
+
inputs=[
|
8 |
+
gr.Textbox(label="Texte à analyser"),
|
9 |
+
gr.Dropdown(choices=["zero-shot", "few-shot"], label="Type de modèle", value="zero-shot")
|
10 |
+
],
|
11 |
outputs="markdown",
|
12 |
title="🧪 ToxiCheck",
|
13 |
description="Entrez un texte pour détecter s'il est toxique. Résultat avec score de confiance pour chaque label."
|
models/few_shot.py
CHANGED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from models.base import BaseModel
|
2 |
+
from transformers import pipeline
|
3 |
+
|
4 |
+
class FewShotModel(BaseModel):
|
5 |
+
def __init__(self):
|
6 |
+
# On utilise un modèle préentraîné pour la classification de texte
|
7 |
+
self.classifier = pipeline("text-classification", model="textattack/roberta-base-rotten-tomatoes")
|
8 |
+
|
9 |
+
def predict(self, text: str) -> list[tuple[str, float]]:
|
10 |
+
result = self.classifier(text, truncation=True)[0]
|
11 |
+
label = result["label"].lower()
|
12 |
+
score = result["score"]
|
13 |
+
label = "non-toxique" if "pos" in label else "toxique"
|
14 |
+
return [(label, score)]
|
tests/test_handler.py
CHANGED
@@ -14,4 +14,14 @@ def test_zero_shot_prediction_output():
|
|
14 |
assert "### Résultat de la classification" in output
|
15 |
assert "**toxique**" in output
|
16 |
assert "**non-toxique**" in output
|
17 |
-
assert "%" in output
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
assert "### Résultat de la classification" in output
|
15 |
assert "**toxique**" in output
|
16 |
assert "**non-toxique**" in output
|
17 |
+
assert "%" in output
|
18 |
+
|
19 |
+
def test_few_shot_prediction_output():
|
20 |
+
from app.handler import predict
|
21 |
+
text = "Tu es un abruti fini"
|
22 |
+
output = predict(text, model_type="few-shot")
|
23 |
+
|
24 |
+
print("Résultat few-shot :", output)
|
25 |
+
|
26 |
+
assert "### Résultat de la classification" in output
|
27 |
+
assert "toxique" in output or "non-toxique" in output
|