feat: merge de la fonctionnalité d'affichage des deux scores de classification
Browse files- Ajout de markdown dans le handler
- Modification du modèle ZeroShotModel pour retourner tous les labels
- Affichage markdown dans Gradio
- Test adapté
- app/handler.py +5 -1
- app/interface.py +3 -8
- models/zero_shot.py +2 -4
- tests/test_handler.py +8 -9
app/handler.py
CHANGED
@@ -3,4 +3,8 @@ from models.zero_shot import ZeroShotModel
|
|
3 |
model = ZeroShotModel()
|
4 |
|
5 |
def predict(text: str) -> str:
|
6 |
-
|
|
|
|
|
|
|
|
|
|
3 |
model = ZeroShotModel()
|
4 |
|
5 |
def predict(text: str) -> str:
|
6 |
+
results = model.predict(text)
|
7 |
+
output = "### Résultat de la classification :\n\n"
|
8 |
+
for label, score in results:
|
9 |
+
output += f"- **{label}** : {score*100:.1f}%\n"
|
10 |
+
return output
|
app/interface.py
CHANGED
@@ -1,17 +1,12 @@
|
|
1 |
import gradio as gr
|
2 |
-
from app.handler import predict
|
3 |
-
|
4 |
-
def predict(text: str) -> str:
|
5 |
-
label, score = model.predict(text)
|
6 |
-
return f"{label} ({score*100:.1f}%)"
|
7 |
-
|
8 |
|
9 |
def launch_app():
|
10 |
iface = gr.Interface(
|
11 |
fn=predict,
|
12 |
inputs="text",
|
13 |
-
outputs="
|
14 |
title="🧪 ToxiCheck",
|
15 |
-
description="Entrez un texte pour détecter s'il est toxique. Résultat avec score de confiance."
|
16 |
)
|
17 |
iface.launch()
|
|
|
1 |
import gradio as gr
|
2 |
+
from app.handler import predict
|
|
|
|
|
|
|
|
|
|
|
3 |
|
4 |
def launch_app():
|
5 |
iface = gr.Interface(
|
6 |
fn=predict,
|
7 |
inputs="text",
|
8 |
+
outputs="markdown",
|
9 |
title="🧪 ToxiCheck",
|
10 |
+
description="Entrez un texte pour détecter s'il est toxique. Résultat avec score de confiance pour chaque label."
|
11 |
)
|
12 |
iface.launch()
|
models/zero_shot.py
CHANGED
@@ -5,9 +5,7 @@ class ZeroShotModel(BaseModel):
|
|
5 |
def __init__(self):
|
6 |
self.classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
|
7 |
|
8 |
-
def predict(self, text: str) -> tuple[str, float]:
|
9 |
labels = ["toxique", "non-toxique"]
|
10 |
result = self.classifier(text, candidate_labels=labels)
|
11 |
-
|
12 |
-
score = result["scores"][0]
|
13 |
-
return label, score
|
|
|
5 |
def __init__(self):
|
6 |
self.classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
|
7 |
|
8 |
+
def predict(self, text: str) -> list[tuple[str, float]]:
|
9 |
labels = ["toxique", "non-toxique"]
|
10 |
result = self.classifier(text, candidate_labels=labels)
|
11 |
+
return list(zip(result["labels"], result["scores"]))
|
|
|
|
tests/test_handler.py
CHANGED
@@ -5,14 +5,13 @@ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")
|
|
5 |
from app.handler import predict
|
6 |
|
7 |
def test_zero_shot_prediction_output():
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
label_1, score_1 = predict(text_1)
|
12 |
-
label_2, score_2 = predict(text_2)
|
13 |
|
14 |
-
print(
|
15 |
-
print(f"Prediction 2: {label_2} ({score_2:.2f})")
|
16 |
|
17 |
-
|
18 |
-
assert
|
|
|
|
|
|
|
|
5 |
from app.handler import predict
|
6 |
|
7 |
def test_zero_shot_prediction_output():
|
8 |
+
text = "Tu es complètement stupide"
|
9 |
+
output = predict(text)
|
|
|
|
|
|
|
10 |
|
11 |
+
print("Résultat brut :", output)
|
|
|
12 |
|
13 |
+
# Vérifie que le format markdown est respecté
|
14 |
+
assert "### Résultat de la classification" in output
|
15 |
+
assert "**toxique**" in output
|
16 |
+
assert "**non-toxique**" in output
|
17 |
+
assert "%" in output
|