File size: 485 Bytes
9decc9d
 
 
 
 
 
 
7301ed6
9decc9d
 
7301ed6
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
from transformers import pipeline
from models.base import BaseModel

class ZeroShotModel(BaseModel):
    """Toxicity classifier built on a Hugging Face zero-shot pipeline.

    The text is scored against a small set of candidate labels and the
    best-scoring label is returned together with its score.
    """

    # Defaults kept as class attributes so subclasses can override them
    # without re-implementing ``__init__``.
    DEFAULT_MODEL_NAME = "facebook/bart-large-mnli"
    DEFAULT_LABELS = ("toxique", "non-toxique")

    def __init__(self, *, model_name: str = DEFAULT_MODEL_NAME, candidate_labels=None):
        """Load the zero-shot classification pipeline.

        Args:
            model_name: Checkpoint identifier passed to ``pipeline`` as
                ``model``. Defaults to ``DEFAULT_MODEL_NAME``.
            candidate_labels: Optional iterable of label strings to score
                the text against. Defaults to ``DEFAULT_LABELS``.

        Raises:
            ValueError: If ``candidate_labels`` is given but empty.
        """
        # NOTE(review): BaseModel.__init__ is not called here, matching the
        # previous behavior; its signature is not visible from this file.
        if candidate_labels is None:
            candidate_labels = self.DEFAULT_LABELS
        # Stored as a tuple so the configured labels cannot be mutated
        # between calls.
        self.candidate_labels = tuple(candidate_labels)
        if not self.candidate_labels:
            raise ValueError("candidate_labels must contain at least one label")
        self.classifier = pipeline("zero-shot-classification", model=model_name)

    def predict(self, text: str) -> tuple[str, float]:
        """Classify ``text`` and return the top label with its score.

        Args:
            text: The text to classify.

        Returns:
            A ``(label, score)`` tuple taken from the first entry of the
            pipeline's ``"labels"`` and ``"scores"`` lists.
        """
        result = self.classifier(text, candidate_labels=list(self.candidate_labels))
        # Per the transformers documentation the pipeline returns labels
        # sorted by likelihood, so index 0 is the best match.
        return result["labels"][0], result["scores"][0]