jaimin committed on
Commit
a8edf80
·
1 Parent(s): 58c96a8

Create adequacy.py

Browse files
Files changed (1) hide show
  1. adequacy.py +40 -0
adequacy.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
2
+
3
# Repository id shared by the module-level tokenizer and model below.
_ADEQUACY_REPO = "jaimin/parrot_adequacy_model"

# Loaded at import time; both objects come from the same repository.
tokenizer = AutoTokenizer.from_pretrained(_ADEQUACY_REPO)
model = AutoModelForSequenceClassification.from_pretrained(_ADEQUACY_REPO)
6
+
7
+
8
class Adequacy():
    """Score paraphrases for adequacy against an input phrase.

    Wraps a sequence-classification model and its tokenizer. Each
    (input phrase, paraphrase) pair is encoded together and run through the
    model; the softmax probability at class index 1 is used as the pair's
    adequacy score.
    """

    def __init__(self, model_tag='jaimin/parrot_adequacy_model', use_auth_token="access"):
        """Load the classifier and tokenizer identified by ``model_tag``.

        Args:
            model_tag: Identifier passed to ``from_pretrained`` for both the
                model and the tokenizer.
            use_auth_token: Forwarded unchanged to both ``from_pretrained``
                calls.
        """
        # Local import: transformers is only needed once an instance is built.
        from transformers import AutoModelForSequenceClassification, AutoTokenizer
        self.adequacy_model = AutoModelForSequenceClassification.from_pretrained(
            model_tag, use_auth_token=use_auth_token)
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_tag, use_auth_token=use_auth_token)

    def _adequacy_score(self, input_phrase, para_phrase, device):
        """Return the class-1 probability for one (input, paraphrase) pair.

        The caller is responsible for having moved the model to ``device``.
        """
        import torch

        encoded = self.tokenizer(input_phrase, para_phrase, return_tensors='pt',
                                 max_length=128, truncation=True)
        # Inputs must live on the same device as the model.
        encoded = encoded.to(device)
        # Only a scalar is read from the output, so skip gradient tracking.
        with torch.no_grad():
            logits = self.adequacy_model(**encoded).logits
        probs = logits.softmax(dim=1)
        # One pair per forward pass, so the column holds exactly one value.
        return probs[:, 1].item()

    def filter(self, input_phrase, para_phrases, adequacy_threshold, device="cpu"):
        """Return the paraphrases whose adequacy score meets the threshold.

        Args:
            input_phrase: Phrase every paraphrase is compared against.
            para_phrases: Iterable of candidate paraphrases.
            adequacy_threshold: Minimum score (inclusive) for a paraphrase
                to be kept.
            device: Device the model and the encoded inputs are moved to.

        Returns:
            A list of the accepted paraphrases, in input order.
        """
        # Move the model once up front rather than on every iteration.
        self.adequacy_model = self.adequacy_model.to(device)
        return [
            para_phrase
            for para_phrase in para_phrases
            if self._adequacy_score(input_phrase, para_phrase, device) >= adequacy_threshold
        ]

    def score(self, input_phrase, para_phrases, adequacy_threshold, device="cpu"):
        """Return adequacy scores for paraphrases that meet the threshold.

        Args:
            input_phrase: Phrase every paraphrase is compared against.
            para_phrases: Iterable of candidate paraphrases.
            adequacy_threshold: Minimum score (inclusive) for a paraphrase
                to be included.
            device: Device the model and the encoded inputs are moved to.

        Returns:
            A dict mapping each accepted paraphrase to its score. A paraphrase
            that appears more than once occupies a single key.
        """
        # Move the model once up front rather than on every iteration.
        self.adequacy_model = self.adequacy_model.to(device)
        adequacy_scores = {}
        for para_phrase in para_phrases:
            adequacy_score = self._adequacy_score(input_phrase, para_phrase, device)
            if adequacy_score >= adequacy_threshold:
                adequacy_scores[para_phrase] = adequacy_score
        return adequacy_scores