# Styleformer: seq2seq style transfer (casual<->formal, active<->passive) with adequacy-ranked candidates.
class Styleformer():
    """Text style-transfer wrapper around four Hugging Face seq2seq models.

    Supported styles (chosen once at construction via ``style``):
        0 — Casual to Formal (CTF)
        1 — Formal to Casual (FTC)
        2 — Active to Passive (ATP)
        3 — Passive to Active (PTA)

    Styles 0/1 sample several candidates and rank them with an adequacy
    scorer; styles 2/3 return a single sampled rewrite.
    """

    # style id -> (model-tag kwarg name, attribute prefix, human label)
    _STYLES = {
        0: ("ctf", "Casual to Formal"),
        1: ("ftc", "Formal to Casual"),
        2: ("atp", "Active to Passive"),
        3: ("pta", "Passive to Active"),
    }

    def __init__(
        self,
        style=0,
        ctf_model_tag="jaimin/Informal_to_formal",
        ftc_model_tag="jaimin/formal_to_informal",
        atp_model_tag="jaimin/Active_to_passive",
        pta_model_tag="jaimin/Passive_to_active",
        adequacy_model_tag="jaimin/parrot_adequacy_model",
    ):
        """Load the tokenizer/model pair for the requested style.

        Args:
            style: one of 0..3 (see class docstring). Any other value
                only prints a notice and leaves ``model_loaded`` False.
            *_model_tag: Hugging Face model identifiers.
            adequacy_model_tag: tag for the adequacy scorer; a falsy value
                disables it (``self.adequacy`` is then falsy).
        """
        from transformers import AutoTokenizer
        from transformers import AutoModelForSeq2SeqLM
        from adequacy import Adequacy

        self.style = style
        # Optional adequacy scorer: short-circuits to the falsy tag when disabled.
        self.adequacy = adequacy_model_tag and Adequacy(model_tag=adequacy_model_tag, use_auth_token="access")
        self.model_loaded = False

        tags = {0: ctf_model_tag, 1: ftc_model_tag, 2: atp_model_tag, 3: pta_model_tag}
        if style in self._STYLES:
            prefix, label = self._STYLES[style]
            model_tag = tags[style]
            # NOTE(review): use_auth_token="access" looks like a placeholder
            # rather than a real HF token — confirm credentials are injected.
            setattr(self, prefix + "_tokenizer",
                    AutoTokenizer.from_pretrained(model_tag, use_auth_token="access"))
            setattr(self, prefix + "_model",
                    AutoModelForSeq2SeqLM.from_pretrained(model_tag, use_auth_token="access"))
            print(label + " model loaded...")
            self.model_loaded = True
        else:
            print("Only CTF, FTC, ATP and PTA are supported in the pre-release...stay tuned")

    def transfer(self, input_sentence, inference_on=-1, quality_filter=0.95, max_candidates=5):
        """Rewrite ``input_sentence`` in the style chosen at construction.

        Args:
            input_sentence: sentence to restyle.
            inference_on: -1 for CPU, 0..998 for a CUDA device ordinal;
                anything else falls back to CPU with a notice.
            quality_filter: adequacy threshold (styles 0/1 only).
            max_candidates: sampled candidates to rank (styles 0/1 only).

        Returns:
            The best rewrite as a string, or None when no candidate passes
            the quality filter or no model is loaded.
        """
        if not self.model_loaded:
            print("Models aren't loaded for this style, please use the right style during init")
            return None

        if inference_on == -1:
            device = "cpu"
        elif 0 <= inference_on < 999:
            # BUGFIX: GPU ordinals must map to CUDA devices. The original
            # built "cpu:N", which silently kept inference on the CPU and
            # made the inference_on parameter a no-op.
            device = "cuda:" + str(inference_on)
        else:
            device = "cpu"
            print("Onnx + Quantisation is not supported in the pre-release...stay tuned.")

        if self.style == 0:
            return self._casual_to_formal(input_sentence, device, quality_filter, max_candidates)
        elif self.style == 1:
            return self._formal_to_casual(input_sentence, device, quality_filter, max_candidates)
        elif self.style == 2:
            return self._active_to_passive(input_sentence, device)
        elif self.style == 3:
            return self._passive_to_active(input_sentence, device)
        return None

    def _generate(self, tokenizer, model, prompt, device, num_return_sequences):
        """Run one sampled generation pass; return decoded candidate strings.

        Shared by all four style methods — the only differences between them
        were the prompt prefix and the number of returned sequences.
        """
        input_ids = tokenizer.encode(prompt, return_tensors='pt')
        model = model.to(device)  # nn.Module.to moves in place and returns self
        input_ids = input_ids.to(device)
        preds = model.generate(
            input_ids,
            do_sample=True,
            max_length=32,
            top_k=50,
            top_p=0.95,
            early_stopping=True,
            num_return_sequences=num_return_sequences)
        return [tokenizer.decode(p, skip_special_tokens=True).strip() for p in preds]

    def _ranked_transfer(self, tokenizer, model, prefix, input_sentence, device,
                         quality_filter, max_candidates):
        """Sample candidates, score them for adequacy, return the best or None."""
        candidates = set(self._generate(tokenizer, model, prefix + input_sentence,
                                        device, max_candidates))
        scored = self.adequacy.score(input_sentence, list(candidates), quality_filter, device)
        ranked = sorted(scored.items(), key=lambda x: x[1], reverse=True)
        # score() only keeps candidates above quality_filter, so ranked may be empty.
        return ranked[0][0] if ranked else None

    def _formal_to_casual(self, input_sentence, device, quality_filter, max_candidates):
        """Formal -> casual with adequacy ranking."""
        return self._ranked_transfer(self.ftc_tokenizer, self.ftc_model,
                                     "transfer Formal to Casual: ",
                                     input_sentence, device, quality_filter, max_candidates)

    def _casual_to_formal(self, input_sentence, device, quality_filter, max_candidates):
        """Casual -> formal with adequacy ranking."""
        return self._ranked_transfer(self.ctf_tokenizer, self.ctf_model,
                                     "transfer Casual to Formal: ",
                                     input_sentence, device, quality_filter, max_candidates)

    def _active_to_passive(self, input_sentence, device):
        """Active -> passive: single sampled rewrite, no adequacy ranking."""
        return self._generate(self.atp_tokenizer, self.atp_model,
                              "transfer Active to Passive: " + input_sentence,
                              device, 1)[0]

    def _passive_to_active(self, input_sentence, device):
        """Passive -> active: single sampled rewrite, no adequacy ranking."""
        return self._generate(self.pta_tokenizer, self.pta_model,
                              "transfer Passive to Active: " + input_sentence,
                              device, 1)[0]