Tbruand
commited on
Commit
·
402868e
1
Parent(s):
0cf4a6c
feat(models): ajoute un retour (label, score) au modèle few-shot pour uniformiser la sortie
Browse files- models/few_shot.py +4 -3
models/few_shot.py
CHANGED
|
@@ -6,8 +6,9 @@ class FewShotModel(BaseModel):
|
|
| 6 |
# On utilise un modèle préentraîné pour la classification de texte
|
| 7 |
self.classifier = pipeline("text-classification", model="textattack/roberta-base-rotten-tomatoes")
|
| 8 |
|
| 9 |
-
def predict(self, text: str) -> str:
|
| 10 |
result = self.classifier(text, truncation=True)[0]
|
| 11 |
label = result["label"].lower()
|
| 12 |
-
|
| 13 |
-
|
|
|
|
|
|
| 6 |
# On utilise un modèle préentraîné pour la classification de texte
|
| 7 |
self.classifier = pipeline("text-classification", model="textattack/roberta-base-rotten-tomatoes")
|
| 8 |
|
| 9 |
+
def predict(self, text: str) -> list[tuple[str, float]]:
|
| 10 |
result = self.classifier(text, truncation=True)[0]
|
| 11 |
label = result["label"].lower()
|
| 12 |
+
score = result["score"]
|
| 13 |
+
label = "non-toxique" if "pos" in label else "toxique"
|
| 14 |
+
return [(label, score)]
|