# adequacy.py — Parrot adequacy-scoring model wrapper
# (Hugging Face Hub page residue: commit a8edf80, "Create adequacy.py", by jaimin)
# Hugging Face transformers entry points for loading the model and tokenizer.
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# NOTE(review): these module-level loads download/initialise the model at import
# time and duplicate the work done in Adequacy.__init__ below; the class never
# reads these globals. Consider removing them unless another module relies on
# the module-level `tokenizer` / `model` names — TODO confirm with callers.
tokenizer = AutoTokenizer.from_pretrained("jaimin/parrot_adequacy_model")
model = AutoModelForSequenceClassification.from_pretrained("jaimin/parrot_adequacy_model")
class Adequacy():
    """Scores paraphrases for semantic adequacy against an input phrase.

    Wraps a Hugging Face sequence-classification model whose class index 1 is
    read as "the paraphrase adequately preserves the input's meaning"
    (presumably an NLI/paraphrase-entailment head — TODO confirm against the
    model card).
    """

    def __init__(self, model_tag='jaimin/parrot_adequacy_model', use_auth_token="access"):
        """Load the adequacy model and its tokenizer from the Hugging Face Hub.

        Args:
            model_tag: Hub model id (or local path) to load from.
            use_auth_token: token forwarded to ``from_pretrained`` for private repos.
        """
        from transformers import AutoModelForSequenceClassification, AutoTokenizer
        # Bug fix: the original passed the literal string "access" here, so the
        # caller-supplied use_auth_token argument was silently ignored.
        self.adequacy_model = AutoModelForSequenceClassification.from_pretrained(
            model_tag, use_auth_token=use_auth_token)
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_tag, use_auth_token=use_auth_token)

    def filter(self, input_phrase, para_phrases, adequacy_threshold, device="cpu"):
        """Return the paraphrases whose adequacy score is >= adequacy_threshold.

        Args:
            input_phrase: the original sentence.
            para_phrases: iterable of candidate paraphrase strings.
            adequacy_threshold: minimum P(adequate) to keep a paraphrase.
            device: torch device string for model and inputs (e.g. "cpu", "cuda").

        Returns:
            List of surviving paraphrases, in their original order.
        """
        import torch  # local import, matching this file's local-import style
        # Hoisted out of the loop: the model only needs moving to `device` once.
        self.adequacy_model = self.adequacy_model.to(device)
        top_adequacy_phrases = []
        with torch.no_grad():  # pure inference — don't build autograd graphs
            for para_phrase in para_phrases:
                x = self.tokenizer(input_phrase, para_phrase, return_tensors='pt',
                                   max_length=128, truncation=True)
                # Bug fix: the original never moved the encoded inputs to
                # `device`, which fails for any non-CPU device (score() below
                # already did this correctly).
                x = x.to(device)
                logits = self.adequacy_model(**x).logits
                probs = logits.softmax(dim=1)
                adequacy_score = probs[:, 1].item()  # P(label index 1) = P(adequate)
                if adequacy_score >= adequacy_threshold:
                    top_adequacy_phrases.append(para_phrase)
        return top_adequacy_phrases

    def score(self, input_phrase, para_phrases, adequacy_threshold, device="cpu"):
        """Return ``{paraphrase: score}`` for paraphrases scoring >= the threshold.

        Same scoring as :meth:`filter`, but keeps the numeric scores.
        """
        import torch
        self.adequacy_model = self.adequacy_model.to(device)
        adequacy_scores = {}
        with torch.no_grad():
            for para_phrase in para_phrases:
                x = self.tokenizer(input_phrase, para_phrase, return_tensors='pt',
                                   max_length=128, truncation=True)
                x = x.to(device)
                logits = self.adequacy_model(**x).logits
                probs = logits.softmax(dim=1)
                adequacy_score = probs[:, 1].item()
                if adequacy_score >= adequacy_threshold:
                    adequacy_scores[para_phrase] = adequacy_score
        return adequacy_scores