File size: 485 Bytes
9decc9d 7301ed6 9decc9d 7301ed6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 |
from transformers import pipeline
from models.base import BaseModel
class ZeroShotModel(BaseModel):
    """Toxicity classifier built on a Hugging Face zero-shot classification pipeline.

    Each input text is scored against a fixed set of candidate labels
    (French "toxique" / "non-toxique" by default). The top-ranked label is
    returned together with its score.
    """

    # Candidate labels used when none are supplied to the constructor.
    DEFAULT_LABELS: tuple[str, ...] = ("toxique", "non-toxique")
    # Hugging Face model id used when none is supplied to the constructor.
    DEFAULT_MODEL: str = "facebook/bart-large-mnli"

    def __init__(
        self,
        *,
        model_name: str = DEFAULT_MODEL,
        labels: tuple[str, ...] = DEFAULT_LABELS,
    ):
        """Create the zero-shot classification pipeline.

        Args:
            model_name: Model id passed to ``pipeline``. Defaults to
                ``facebook/bart-large-mnli``, the previously hard-coded model.
            labels: Candidate labels scored on every ``predict`` call.
                Defaults to ``("toxique", "non-toxique")``.

        Raises:
            ValueError: If ``labels`` is empty.
        """
        if not labels:
            raise ValueError("labels must contain at least one candidate label")
        # Copy once so later mutation of a caller-supplied sequence cannot
        # change the labels this instance scores against.
        self.candidate_labels = list(labels)
        # Build the pipeline once per instance rather than once per prediction.
        self.classifier = pipeline("zero-shot-classification", model=model_name)

    def predict(self, text: str) -> tuple[str, float]:
        """Classify ``text`` against the configured candidate labels.

        Args:
            text: Input text to classify.

        Returns:
            ``(label, score)`` taken from the first entry of the pipeline
            output, which the pipeline orders from most to least likely.
        """
        # NOTE(review): the default labels are French, but bart-large-mnli was
        # trained on English MNLI data and the pipeline's default hypothesis
        # template is English. Consider a French ``hypothesis_template`` or a
        # multilingual NLI model; confirm the choice on an evaluation set.
        result = self.classifier(text, candidate_labels=self.candidate_labels)
        return result["labels"][0], result["scores"][0]