File size: 2,021 Bytes
a8edf80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Module-level tokenizer/model pair for the adequacy checkpoint, loaded eagerly
# when this module is imported (tokenizer first, then model).
# NOTE(review): Adequacy.__init__ loads its own copies, so these globals look
# unused within this file; they are kept in case other code imports
# `tokenizer` / `model` from here — confirm before removing.
tokenizer, model = (
    AutoTokenizer.from_pretrained("jaimin/parrot_adequacy_model"),
    AutoModelForSequenceClassification.from_pretrained("jaimin/parrot_adequacy_model"),
)


class Adequacy:
    """Score how well candidate paraphrases preserve the meaning of an input phrase.

    Wraps a sequence-classification model that is fed (input_phrase, para_phrase)
    pairs; the softmax probability of class index 1 is used as the adequacy score.
    """

    def __init__(self, model_tag='jaimin/parrot_adequacy_model', use_auth_token="access"):
        """Load the classifier and tokenizer for ``model_tag``.

        Args:
            model_tag: Model identifier passed to ``from_pretrained``.
            use_auth_token: Forwarded to ``from_pretrained`` as ``use_auth_token``.
        """
        from transformers import AutoModelForSequenceClassification, AutoTokenizer
        # Forward the caller-supplied token instead of a hard-coded literal.
        # NOTE(review): newer transformers releases deprecate ``use_auth_token``
        # in favour of ``token`` — confirm against the pinned version.
        self.adequacy_model = AutoModelForSequenceClassification.from_pretrained(
            model_tag, use_auth_token=use_auth_token
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_tag, use_auth_token=use_auth_token)

    def _adequacy_score(self, input_phrase, para_phrase, device):
        """Return the class-1 probability for a single (input_phrase, para_phrase) pair.

        The model is expected to already be on ``device``; the public methods move
        it once before looping rather than on every pair.
        """
        import torch
        inputs = self.tokenizer(
            input_phrase, para_phrase, return_tensors='pt', max_length=128, truncation=True
        )
        # Inputs must live on the same device as the model.
        inputs = inputs.to(device)
        # Pure inference: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            logits = self.adequacy_model(**inputs).logits
        probs = logits.softmax(dim=1)
        return probs[:, 1].item()

    def filter(self, input_phrase, para_phrases, adequacy_threshold, device="cpu"):
        """Return the paraphrases whose adequacy score is >= ``adequacy_threshold``.

        Args:
            input_phrase: Reference text each paraphrase is compared against.
            para_phrases: Iterable of candidate paraphrases.
            adequacy_threshold: Minimum score (inclusive) for a paraphrase to be kept.
            device: Device the model and tokenized inputs are moved to.

        Returns:
            List of the kept paraphrases, in input order.
        """
        self.adequacy_model = self.adequacy_model.to(device)
        return [
            para_phrase
            for para_phrase in para_phrases
            if self._adequacy_score(input_phrase, para_phrase, device) >= adequacy_threshold
        ]

    def score(self, input_phrase, para_phrases, adequacy_threshold, device="cpu"):
        """Map each paraphrase scoring >= ``adequacy_threshold`` to its adequacy score.

        Args:
            input_phrase: Reference text each paraphrase is compared against.
            para_phrases: Iterable of candidate paraphrases.
            adequacy_threshold: Minimum score (inclusive) for a paraphrase to be kept.
            device: Device the model and tokenized inputs are moved to.

        Returns:
            Dict of kept paraphrase -> score, in input order; duplicate
            paraphrases collapse to a single key.
        """
        self.adequacy_model = self.adequacy_model.to(device)
        adequacy_scores = {}
        for para_phrase in para_phrases:
            adequacy_score = self._adequacy_score(input_phrase, para_phrase, device)
            if adequacy_score >= adequacy_threshold:
                adequacy_scores[para_phrase] = adequacy_score
        return adequacy_scores