Tbruand
commited on
Commit
·
0005ea7
1
Parent(s):
52e008b
feat(models): retourne tous les labels et scores dans ZeroShotModel.predict
Browse files- models/zero_shot.py +2 -4
models/zero_shot.py
CHANGED
|
@@ -5,9 +5,7 @@ class ZeroShotModel(BaseModel):
|
|
| 5 |
def __init__(self):
|
| 6 |
self.classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
|
| 7 |
|
| 8 |
-
def predict(self, text: str) -> tuple[str, float]:
|
| 9 |
labels = ["toxique", "non-toxique"]
|
| 10 |
result = self.classifier(text, candidate_labels=labels)
|
| 11 |
-
|
| 12 |
-
score = result["scores"][0]
|
| 13 |
-
return label, score
|
|
|
|
| 5 |
def __init__(self):
|
| 6 |
self.classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
|
| 7 |
|
| 8 |
+
def predict(self, text: str) -> list[tuple[str, float]]:
|
| 9 |
labels = ["toxique", "non-toxique"]
|
| 10 |
result = self.classifier(text, candidate_labels=labels)
|
| 11 |
+
return list(zip(result["labels"], result["scores"]))
|
|
|
|
|
|