jaimin commited on
Commit
6cfb2f0
·
1 Parent(s): 402f3a4

Create styleformer.py

Browse files
Files changed (1) hide show
  1. styleformer.py +161 -0
styleformer.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ class Styleformer():
2
+
3
+ def __init__(
4
+ self,
5
+ style=0,
6
+ ctf_model_tag="jaimin/Informal_to_formal",
7
+ ftc_model_tag="jaimin/formal_to_informal",
8
+ atp_model_tag="jaimin/Active_to_passive",
9
+ pta_model_tag="jaimin/Passive_to_active",
10
+ adequacy_model_tag="jaimin/parrot_adequacy_model",
11
+ ):
12
+ from transformers import AutoTokenizer
13
+ from transformers import AutoModelForSeq2SeqLM
14
+ from adequacy import Adequacy
15
+
16
+ self.style = style
17
+ self.adequacy = adequacy_model_tag and Adequacy(model_tag=adequacy_model_tag, use_auth_token="access")
18
+ self.model_loaded = False
19
+
20
+ if self.style == 0:
21
+ self.ctf_tokenizer = AutoTokenizer.from_pretrained(ctf_model_tag, use_auth_token="access")
22
+ self.ctf_model = AutoModelForSeq2SeqLM.from_pretrained(ctf_model_tag, use_auth_token="access")
23
+ print("Casual to Formal model loaded...")
24
+ self.model_loaded = True
25
+ elif self.style == 1:
26
+ self.ftc_tokenizer = AutoTokenizer.from_pretrained(ftc_model_tag, use_auth_token="access")
27
+ self.ftc_model = AutoModelForSeq2SeqLM.from_pretrained(ftc_model_tag, use_auth_token="access")
28
+ print("Formal to Casual model loaded...")
29
+ self.model_loaded = True
30
+ elif self.style == 2:
31
+ self.atp_tokenizer = AutoTokenizer.from_pretrained(atp_model_tag,use_auth_token="access")
32
+ self.atp_model = AutoModelForSeq2SeqLM.from_pretrained(atp_model_tag,use_auth_token="access")
33
+ print("Active to Passive model loaded...")
34
+ self.model_loaded = True
35
+ elif self.style == 3:
36
+ self.pta_tokenizer = AutoTokenizer.from_pretrained(pta_model_tag,use_auth_token="access")
37
+ self.pta_model = AutoModelForSeq2SeqLM.from_pretrained(pta_model_tag,use_auth_token="access")
38
+ print("Passive to Active model loaded...")
39
+ self.model_loaded = True
40
+ else:
41
+ print("Only CTF, FTC, ATP and PTA are supported in the pre-release...stay tuned")
42
+
43
+ def transfer(self, input_sentence, inference_on=-1, quality_filter=0.95, max_candidates=5):
44
+ if self.model_loaded:
45
+ if inference_on == -1:
46
+ device = "cpu"
47
+ elif inference_on >= 0 and inference_on < 999:
48
+ device = "cpu:" + str(inference_on)
49
+ else:
50
+ device = "cpu"
51
+ print("Onnx + Quantisation is not supported in the pre-release...stay tuned.")
52
+
53
+ if self.style == 0:
54
+ output_sentence = self._casual_to_formal(input_sentence, device, quality_filter, max_candidates)
55
+ return output_sentence
56
+ elif self.style == 1:
57
+ output_sentence = self._formal_to_casual(input_sentence, device, quality_filter, max_candidates)
58
+ return output_sentence
59
+ elif self.style == 2:
60
+ output_sentence = self._active_to_passive(input_sentence, device)
61
+ return output_sentence
62
+ elif self.style == 3:
63
+ output_sentence = self._passive_to_active(input_sentence, device)
64
+ return output_sentence
65
+
66
+ else:
67
+ print("Models aren't loaded for this style, please use the right style during init")
68
+
69
+ def _formal_to_casual(self, input_sentence, device, quality_filter, max_candidates):
70
+ ftc_prefix = "transfer Formal to Casual: "
71
+ src_sentence = input_sentence
72
+ input_sentence = ftc_prefix + input_sentence
73
+ input_ids = self.ftc_tokenizer.encode(input_sentence, return_tensors='pt')
74
+ self.ftc_model = self.ftc_model.to(device)
75
+ input_ids = input_ids.to(device)
76
+
77
+ preds = self.ftc_model.generate(
78
+ input_ids,
79
+ do_sample=True,
80
+ max_length=32,
81
+ top_k=50,
82
+ top_p=0.95,
83
+ early_stopping=True,
84
+ num_return_sequences=max_candidates)
85
+
86
+ gen_sentences = set()
87
+ for pred in preds:
88
+ gen_sentences.add(self.ftc_tokenizer.decode(pred, skip_special_tokens=True).strip())
89
+
90
+ adequacy_scored_phrases = self.adequacy.score(src_sentence, list(gen_sentences), quality_filter, device)
91
+ ranked_sentences = sorted(adequacy_scored_phrases.items(), key=lambda x: x[1], reverse=True)
92
+ if len(ranked_sentences) > 0:
93
+ return ranked_sentences[0][0]
94
+ else:
95
+ return None
96
+
97
+ def _casual_to_formal(self, input_sentence, device, quality_filter, max_candidates):
98
+ ctf_prefix = "transfer Casual to Formal: "
99
+ src_sentence = input_sentence
100
+ input_sentence = ctf_prefix + input_sentence
101
+ input_ids = self.ctf_tokenizer.encode(input_sentence, return_tensors='pt')
102
+ self.ctf_model = self.ctf_model.to(device)
103
+ input_ids = input_ids.to(device)
104
+
105
+ preds = self.ctf_model.generate(
106
+ input_ids,
107
+ do_sample=True,
108
+ max_length=32,
109
+ top_k=50,
110
+ top_p=0.95,
111
+ early_stopping=True,
112
+ num_return_sequences=max_candidates)
113
+
114
+ gen_sentences = set()
115
+ for pred in preds:
116
+ gen_sentences.add(self.ctf_tokenizer.decode(pred, skip_special_tokens=True).strip())
117
+
118
+ adequacy_scored_phrases = self.adequacy.score(src_sentence, list(gen_sentences), quality_filter, device)
119
+ ranked_sentences = sorted(adequacy_scored_phrases.items(), key=lambda x: x[1], reverse=True)
120
+ if len(ranked_sentences) > 0:
121
+ return ranked_sentences[0][0]
122
+ else:
123
+ return None
124
+
125
+ def _active_to_passive(self, input_sentence, device):
126
+ atp_prefix = "transfer Active to Passive: "
127
+ src_sentence = input_sentence
128
+ input_sentence = atp_prefix + input_sentence
129
+ input_ids = self.atp_tokenizer.encode(input_sentence, return_tensors='pt')
130
+ self.atp_model = self.atp_model.to(device)
131
+ input_ids = input_ids.to(device)
132
+
133
+ preds = self.atp_model.generate(
134
+ input_ids,
135
+ do_sample=True,
136
+ max_length=32,
137
+ top_k=50,
138
+ top_p=0.95,
139
+ early_stopping=True,
140
+ num_return_sequences=1)
141
+
142
+ return self.atp_tokenizer.decode(preds[0], skip_special_tokens=True).strip()
143
+
144
+ def _passive_to_active(self, input_sentence, device):
145
+ pta_prefix = "transfer Passive to Active: "
146
+ src_sentence = input_sentence
147
+ input_sentence = pta_prefix + input_sentence
148
+ input_ids = self.pta_tokenizer.encode(input_sentence, return_tensors='pt')
149
+ self.pta_model = self.pta_model.to(device)
150
+ input_ids = input_ids.to(device)
151
+
152
+ preds = self.pta_model.generate(
153
+ input_ids,
154
+ do_sample=True,
155
+ max_length=32,
156
+ top_k=50,
157
+ top_p=0.95,
158
+ early_stopping=True,
159
+ num_return_sequences=1)
160
+
161
+ return self.pta_tokenizer.decode(preds[0], skip_special_tokens=True).strip()