import torch
import torch.nn as nn
import numpy as np
import random
import enum
import traceback
import os
import sys
import json

F_DIR = os.path.dirname(os.path.realpath(__file__))


class XlitError(enum.Enum):
    lang_err = "Unsupported language ID requested ;( Please check available languages."
    string_err = "String passed is incompatible ;("
    internal_err = "Internal crash ;("
    unknown_err = "Unknown Failure"
    loading_err = "Loading failed ;( Check if metadata/paths are correctly configured."


class Encoder(nn.Module):
    """
    Simple RNN based encoder network
    """

    def __init__(
        self,
        input_dim,
        embed_dim,
        hidden_dim,
        rnn_type="gru",
        layers=1,
        bidirectional=False,
        dropout=0,
        device="cpu",
    ):
        super(Encoder, self).__init__()

        self.input_dim = input_dim  # src_vocab_sz
        self.enc_embed_dim = embed_dim
        self.enc_hidden_dim = hidden_dim
        self.enc_rnn_type = rnn_type
        self.enc_layers = layers
        self.enc_directions = 2 if bidirectional else 1
        self.device = device

        self.embedding = nn.Embedding(self.input_dim, self.enc_embed_dim)

        if self.enc_rnn_type == "gru":
            self.enc_rnn = nn.GRU(
                input_size=self.enc_embed_dim,
                hidden_size=self.enc_hidden_dim,
                num_layers=self.enc_layers,
                bidirectional=bidirectional,
            )
        elif self.enc_rnn_type == "lstm":
            self.enc_rnn = nn.LSTM(
                input_size=self.enc_embed_dim,
                hidden_size=self.enc_hidden_dim,
                num_layers=self.enc_layers,
                bidirectional=bidirectional,
            )
        else:
            raise Exception("Unknown RNN type mentioned; use 'gru' or 'lstm'")

    def forward(self, x, x_sz, hidden=None):
        """
        x_sz: (batch_size, 1) - Unpadded sequence lengths used for pack_pad
        Return:
            output: (batch_size, max_length, hidden_dim * num_directions)
            hidden: (n_layers * num_directions, batch_size, hidden_dim) | if LSTM, tuple (h_n, c_n)
        """
        batch_sz = x.shape[0]
        # x: batch_size, max_length, enc_embed_dim
        x = self.embedding(x)

        ## pack the padded data
        # x: max_length, batch_size, enc_embed_dim -> for pack_pad
        x = x.permute(1, 0, 2)
        x = nn.utils.rnn.pack_padded_sequence(x, x_sz, enforce_sorted=False)  # unpad

        # output: packed_size, enc_hidden_dim*directions --> hidden states from all timesteps
        # hidden: n_layers*num_directions, batch_size, hidden_dim | if LSTM (h_n, c_n)
        output, hidden = self.enc_rnn(x)

        ## pad the sequence to the max length in the batch
        # output: max_length, batch_size, enc_hidden_dim*directions
        output, _ = nn.utils.rnn.pad_packed_sequence(output)

        # output: batch_size, max_length, enc_hidden_dim*directions
        output = output.permute(1, 0, 2)

        return output, hidden
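
# Illustrative shape sketch for the Encoder above (comments only, with hypothetical
# dimensions, so nothing runs at import time):
#   enc = Encoder(input_dim=60, embed_dim=300, hidden_dim=512, bidirectional=True)
#   x = torch.randint(0, 60, (4, 10))               # (batch_size=4, max_length=10) padded ids
#   out, hid = enc(x, torch.tensor([10, 8, 7, 5]))  # unpadded lengths per sample
#   out.shape -> (4, 10, 1024)                      # batch, max_length, hidden_dim * 2 directions
#   hid.shape -> (2, 4, 512)                        # layers * directions, batch, hidden_dim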
class Decoder(nn.Module):
    """
    Used as decoder stage
    """

    def __init__(
        self,
        output_dim,
        embed_dim,
        hidden_dim,
        rnn_type="gru",
        layers=1,
        use_attention=True,
        enc_outstate_dim=None,  # enc_directions * enc_hidden_dim
        dropout=0,
        device="cpu",
    ):
        super(Decoder, self).__init__()

        self.output_dim = output_dim  # tgt_vocab_sz
        self.dec_hidden_dim = hidden_dim
        self.dec_embed_dim = embed_dim
        self.dec_rnn_type = rnn_type
        self.dec_layers = layers
        self.use_attention = use_attention
        self.device = device
        if self.use_attention:
            self.enc_outstate_dim = enc_outstate_dim if enc_outstate_dim else hidden_dim
        else:
            self.enc_outstate_dim = 0

        self.embedding = nn.Embedding(self.output_dim, self.dec_embed_dim)

        if self.dec_rnn_type == "gru":
            self.dec_rnn = nn.GRU(
                input_size=self.dec_embed_dim
                + self.enc_outstate_dim,  # to concat attention_output
                hidden_size=self.dec_hidden_dim,  # previous hidden
                num_layers=self.dec_layers,
                batch_first=True,
            )
        elif self.dec_rnn_type == "lstm":
            self.dec_rnn = nn.LSTM(
                input_size=self.dec_embed_dim
                + self.enc_outstate_dim,  # to concat attention_output
                hidden_size=self.dec_hidden_dim,  # previous hidden
                num_layers=self.dec_layers,
                batch_first=True,
            )
        else:
            raise Exception("Unknown RNN type mentioned; use 'gru' or 'lstm'")

        self.fc = nn.Sequential(
            nn.Linear(self.dec_hidden_dim, self.dec_embed_dim),
            nn.LeakyReLU(),
            # nn.Linear(self.dec_embed_dim, self.dec_embed_dim), nn.LeakyReLU(), # removed to reduce size
            nn.Linear(self.dec_embed_dim, self.output_dim),
        )

        ##----- Attention ----------
        if self.use_attention:
            self.W1 = nn.Linear(self.enc_outstate_dim, self.dec_hidden_dim)
            self.W2 = nn.Linear(self.dec_hidden_dim, self.dec_hidden_dim)
            self.V = nn.Linear(self.dec_hidden_dim, 1)

    def attention(self, x, hidden, enc_output):
        """
        x: (batch_size, 1, dec_embed_dim) -> after Embedding
        enc_output: (batch_size, max_length, enc_hidden_dim * num_directions)
        hidden: (n_layers, batch_size, hidden_size) | if LSTM (h_n, c_n)
        """
        ## perform addition to calculate the score

        # hidden_with_time_axis: batch_size, 1, hidden_dim
        ## hidden_with_time_axis = hidden.permute(1, 0, 2) ## replaced with the 2 lines below
        hidden_with_time_axis = torch.sum(hidden, dim=0)
        hidden_with_time_axis = hidden_with_time_axis.unsqueeze(1)

        # score: batch_size, max_length, hidden_dim
        score = torch.tanh(self.W1(enc_output) + self.W2(hidden_with_time_axis))

        # attention_weights: batch_size, max_length, 1
        # we get 1 at the last axis because we are applying score to self.V
        attention_weights = torch.softmax(self.V(score), dim=1)

        # context_vector shape after sum == (batch_size, hidden_dim)
        context_vector = attention_weights * enc_output
        context_vector = torch.sum(context_vector, dim=1)
        # context_vector: batch_size, 1, hidden_dim
        context_vector = context_vector.unsqueeze(1)

        # attend_out: (batch_size, 1, dec_embed_dim + hidden_size)
        attend_out = torch.cat((context_vector, x), -1)

        return attend_out, attention_weights

    def forward(self, x, hidden, enc_output):
        """
        x: (batch_size, 1)
        enc_output: (batch_size, max_length, enc_hidden_dim * num_directions)
        hidden: (n_layers, batch_size, hidden_size) | lstm: (h_n, c_n)
        """
        if (hidden is None) and (self.use_attention is False):
            raise Exception("No use of a decoder with no attention and no hidden state")

        batch_sz = x.shape[0]

        if hidden is None:
            # hidden: n_layers, batch_size, hidden_dim
            hid_for_att = torch.zeros(
                (self.dec_layers, batch_sz, self.dec_hidden_dim)
            ).to(self.device)
        elif self.dec_rnn_type == "lstm":
            hid_for_att = hidden[0]  # h_n
        else:
            hid_for_att = hidden

        # x: (batch_size, 1, dec_embed_dim) -> after embedding
        x = self.embedding(x)

        if self.use_attention:
            # x: (batch_size, 1, dec_embed_dim + hidden_size) -> after attention
            # aw: (batch_size, max_length, 1)
            x, aw = self.attention(x, hid_for_att, enc_output)
        else:
            x, aw = x, 0

        # passing the concatenated vector to the GRU/LSTM
        # output: (batch_size, 1, hidden_size)
        # hidden: (n_layers, batch_size, hidden_size) | if LSTM (h_n, c_n)
        output, hidden = (
            self.dec_rnn(x, hidden) if hidden is not None else self.dec_rnn(x)
        )

        # output :shp: (batch_size * 1, hidden_size)
        output = output.view(-1, output.size(2))

        # output :shp: (batch_size * 1, output_dim)
        output = self.fc(output)

        return output, hidden, aw
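
# Note: the attention implemented above is additive (Bahdanau-style) scoring:
#     score = V( tanh( W1·enc_output + W2·dec_hidden ) )
# softmax over the source-length axis (dim=1) gives the attention weights, and the
# context vector is the attention-weighted sum of encoder outputs, which is
# concatenated with the decoder input embedding before the RNN step.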
class Seq2Seq(nn.Module):
    """
    Used to construct seq2seq architecture with encoder decoder objects
    """

    def __init__(
        self, encoder, decoder, pass_enc2dec_hid=False, dropout=0, device="cpu"
    ):
        super(Seq2Seq, self).__init__()

        self.encoder = encoder
        self.decoder = decoder
        self.device = device
        self.pass_enc2dec_hid = pass_enc2dec_hid

        if self.pass_enc2dec_hid:
            assert (
                decoder.dec_hidden_dim == encoder.enc_hidden_dim
            ), "Hidden dimensions of encoder and decoder must be the same, or unset `pass_enc2dec_hid`"
        if decoder.use_attention:
            assert (
                decoder.enc_outstate_dim
                == encoder.enc_directions * encoder.enc_hidden_dim
            ), "Set `enc_outstate_dim` correctly in decoder"
        assert (
            self.pass_enc2dec_hid or decoder.use_attention
        ), "No use of a decoder with no attention and no hidden state from encoder"

    def forward(self, src, tgt, src_sz, teacher_forcing_ratio=0):
        """
        src: (batch_size, sequence_len.padded)
        tgt: (batch_size, sequence_len.padded)
        src_sz: [batch_size, 1] - Unpadded sequence lengths
        """
        batch_size = tgt.shape[0]

        # enc_output: (batch_size, padded_seq_length, enc_hidden_dim*num_direction)
        # enc_hidden: (enc_layers*num_direction, batch_size, hidden_dim)
        enc_output, enc_hidden = self.encoder(src, src_sz)

        if self.pass_enc2dec_hid:
            # dec_hidden: dec_layers, batch_size, dec_hidden_dim
            dec_hidden = enc_hidden
        else:
            # dec_hidden -> will be initialized to zeros internally
            dec_hidden = None

        # pred_vecs: (batch_size, output_dim, sequence_sz) -> shape required for CELoss
        pred_vecs = torch.zeros(batch_size, self.decoder.output_dim, tgt.size(1)).to(
            self.device
        )

        # dec_input: (batch_size, 1)
        dec_input = tgt[:, 0].unsqueeze(1)  # initialize to start token
        pred_vecs[:, 1, 0] = 1  # mark the start token at position 0 for all batches
        for t in range(1, tgt.size(1)):
            # dec_hidden: dec_layers, batch_size, dec_hidden_dim
            # dec_output: batch_size, output_dim
            # dec_input: (batch_size, 1)
            dec_output, dec_hidden, _ = self.decoder(
                dec_input,
                dec_hidden,
                enc_output,
            )
            pred_vecs[:, :, t] = dec_output

            # prediction: batch_size
            prediction = torch.argmax(dec_output, dim=1)

            # Teacher Forcing
            if random.random() < teacher_forcing_ratio:
                dec_input = tgt[:, t].unsqueeze(1)
            else:
                dec_input = prediction.unsqueeze(1)

        return pred_vecs  # (batch_size, output_dim, sequence_sz)
    def inference(self, src, max_tgt_sz=50, debug=0):
        """
        Single input only, no batch inferencing
        src: (sequence_len)
        debug: if True, also return attention weights
        """
        batch_size = 1
        start_tok = src[0]
        end_tok = src[-1]
        src_sz = torch.tensor([len(src)])
        src_ = src.unsqueeze(0)

        # enc_output: (batch_size, padded_seq_length, enc_hidden_dim*num_direction)
        # enc_hidden: (enc_layers*num_direction, batch_size, hidden_dim)
        enc_output, enc_hidden = self.encoder(src_, src_sz)

        if self.pass_enc2dec_hid:
            # dec_hidden: dec_layers, batch_size, dec_hidden_dim
            dec_hidden = enc_hidden
        else:
            # dec_hidden -> will be initialized to zeros internally
            dec_hidden = None

        # pred_arr: (max_tgt_sz, 1)
        pred_arr = torch.zeros(max_tgt_sz, 1).to(self.device)
        if debug:
            attend_weight_arr = torch.zeros(max_tgt_sz, len(src)).to(self.device)

        # dec_input: (batch_size, 1)
        dec_input = start_tok.view(1, 1)  # initialize to start token
        pred_arr[0] = start_tok.view(1, 1)  # initialize to start token
        for t in range(max_tgt_sz):
            # dec_hidden: dec_layers, batch_size, dec_hidden_dim
            # dec_output: batch_size, output_dim
            # dec_input: (batch_size, 1)
            dec_output, dec_hidden, aw = self.decoder(
                dec_input,
                dec_hidden,
                enc_output,
            )
            # prediction :shp: (1,)
            prediction = torch.argmax(dec_output, dim=1)
            dec_input = prediction.unsqueeze(1)
            pred_arr[t] = prediction
            if debug:
                attend_weight_arr[t] = aw.squeeze(-1)

            if torch.eq(prediction, end_tok):
                break

        if debug:
            return pred_arr.squeeze(), attend_weight_arr

        # pred_arr :shp: (sequence_len)
        return pred_arr.squeeze().to(dtype=torch.long)
    def active_beam_inference(self, src, beam_width=3, max_tgt_sz=50):
        """Active beam search based decoding
        src: (sequence_len)
        """

        def _avg_score(p_tup):
            """Used for sorting
            TODO: divide by (sequence length ** alpha), with alpha as a hyperparameter
            """
            return p_tup[0]

        batch_size = 1
        start_tok = src[0]
        end_tok = src[-1]
        src_sz = torch.tensor([len(src)])
        src_ = src.unsqueeze(0)

        # enc_output: (batch_size, padded_seq_length, enc_hidden_dim*num_direction)
        # enc_hidden: (enc_layers*num_direction, batch_size, hidden_dim)
        enc_output, enc_hidden = self.encoder(src_, src_sz)

        if self.pass_enc2dec_hid:
            # dec_hidden: dec_layers, batch_size, dec_hidden_dim
            init_dec_hidden = enc_hidden
        else:
            # dec_hidden -> will be initialized to zeros internally
            init_dec_hidden = None

        # top_pred[][0] = Σ log_softmax (accumulated log-probability)
        # top_pred[][1] = sequence torch.tensor shape: (1)
        # top_pred[][2] = dec_hidden
        top_pred_list = [(0, start_tok.unsqueeze(0), init_dec_hidden)]

        for t in range(max_tgt_sz):
            cur_pred_list = []
            for p_tup in top_pred_list:
                if p_tup[1][-1] == end_tok:
                    cur_pred_list.append(p_tup)
                    continue

                # dec_hidden: dec_layers, 1, hidden_dim
                # dec_output: 1, output_dim
                dec_output, dec_hidden, _ = self.decoder(
                    x=p_tup[1][-1].view(1, 1),  # dec_input: (1,1)
                    hidden=p_tup[2],
                    enc_output=enc_output,
                )

                ## π{prob} = Σ{log(prob)} -> sum log-probs to prevent underflow
                # dec_output: (1, output_dim)
                dec_output = nn.functional.log_softmax(dec_output, dim=1)
                # pred_topk.values & pred_topk.indices: (1, beam_width)
                pred_topk = torch.topk(dec_output, k=beam_width, dim=1)

                for i in range(beam_width):
                    sig_logsmx_ = p_tup[0] + pred_topk.values[0][i]
                    # seq_tensor_ : (seq_len)
                    seq_tensor_ = torch.cat((p_tup[1], pred_topk.indices[0][i].view(1)))

                    cur_pred_list.append((sig_logsmx_, seq_tensor_, dec_hidden))

            cur_pred_list.sort(key=_avg_score, reverse=True)  # maximized order
            top_pred_list = cur_pred_list[:beam_width]

            # check if all topk candidates end with end_tok
            end_flags_ = [1 if t[1][-1] == end_tok else 0 for t in top_pred_list]
            if beam_width == sum(end_flags_):
                break

        pred_tnsr_list = [t[1] for t in top_pred_list]

        return pred_tnsr_list
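
    # Note: `active_beam_inference` returns up to `beam_width` candidate sequences,
    # already ordered best-first by accumulated log-probability (the descending sort
    # above), so callers such as XlitPiston.character_model can treat index 0 as the
    # top hypothesis.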
    def passive_beam_inference(self, src, beam_width=7, max_tgt_sz=50):
        """
        Passive beam search based inference
        src: (sequence_len)
        """

        def _avg_score(p_tup):
            """Used for sorting
            TODO: divide by (sequence length ** alpha), with alpha as a hyperparameter
            """
            return p_tup[0]

        def _beam_search_topk(topk_obj, start_tok, beam_width):
            """Search for the sequences with maximum probability
            topk_obj[x]: .values & .indices shape:(1, beam_width)
            """
            # top_pred_list[x]: tuple(prob, seq_tensor)
            top_pred_list = [
                (0, start_tok.unsqueeze(0)),
            ]

            for obj in topk_obj:
                new_lst_ = list()
                for itm in top_pred_list:
                    for i in range(beam_width):
                        sig_logsmx_ = itm[0] + obj.values[0][i]
                        seq_tensor_ = torch.cat((itm[1], obj.indices[0][i].view(1)))
                        new_lst_.append((sig_logsmx_, seq_tensor_))

                new_lst_.sort(key=_avg_score, reverse=True)
                top_pred_list = new_lst_[:beam_width]
            return top_pred_list

        batch_size = 1
        start_tok = src[0]
        end_tok = src[-1]
        src_sz = torch.tensor([len(src)])
        src_ = src.unsqueeze(0)

        enc_output, enc_hidden = self.encoder(src_, src_sz)

        if self.pass_enc2dec_hid:
            # dec_hidden: dec_layers, batch_size, dec_hidden_dim
            dec_hidden = enc_hidden
        else:
            # dec_hidden -> will be initialized to zeros internally
            dec_hidden = None

        # dec_input: (1, 1)
        dec_input = start_tok.view(1, 1)  # initialize to start token

        topk_obj = []
        for t in range(max_tgt_sz):
            dec_output, dec_hidden, aw = self.decoder(
                dec_input,
                dec_hidden,
                enc_output,
            )

            ## π{prob} = Σ{log(prob)} -> sum log-probs to prevent underflow
            # dec_output: (1, output_dim)
            dec_output = nn.functional.log_softmax(dec_output, dim=1)
            # pred_topk.values & pred_topk.indices: (1, beam_width)
            pred_topk = torch.topk(dec_output, k=beam_width, dim=1)

            topk_obj.append(pred_topk)

            # dec_input: (1, 1)
            dec_input = pred_topk.indices[0][0].view(1, 1)
            if torch.eq(dec_input, end_tok):
                break

        top_pred_list = _beam_search_topk(topk_obj, start_tok, beam_width)
        pred_tnsr_list = [t[1] for t in top_pred_list]

        return pred_tnsr_list
class GlyphStrawboss:
    def __init__(self, glyphs="en"):
        """List of letters of a language in unicode
        glyphs: "en" for the built-in English set, or path to a JSON glyph config
        """
        if glyphs == "en":
            # lowercase alone
            self.glyphs = [chr(alpha) for alpha in range(97, 123)] + ["é", "è", "á"]
        else:
            self.dossier = json.load(open(glyphs, encoding="utf-8"))
            self.numsym_map = self.dossier["numsym_map"]
            self.glyphs = self.dossier["glyphs"]

        self.indoarab_num = [chr(alpha) for alpha in range(48, 58)]

        self.char2idx = {}
        self.idx2char = {}
        self._create_index()

    def _create_index(self):
        self.char2idx["_"] = 0  # pad
        self.char2idx["$"] = 1  # start
        self.char2idx["#"] = 2  # end
        self.char2idx["*"] = 3  # Mask
        self.char2idx["'"] = 4  # apostrophe U+0027
        self.char2idx["%"] = 5  # unused
        self.char2idx["!"] = 6  # unused
        self.char2idx["?"] = 7
        self.char2idx[":"] = 8
        self.char2idx[" "] = 9
        self.char2idx["-"] = 10
        self.char2idx[","] = 11
        self.char2idx["."] = 12
        self.char2idx["("] = 13
        self.char2idx[")"] = 14
        self.char2idx["/"] = 15
        self.char2idx["^"] = 16

        for idx, char in enumerate(self.indoarab_num):
            self.char2idx[char] = idx + 17

        # letter to index mapping
        for idx, char in enumerate(self.glyphs):
            self.char2idx[char] = idx + 27  # offset 27 = 17 special tokens + 10 digits

        # index to letter mapping
        for char, idx in self.char2idx.items():
            self.idx2char[idx] = char

    def size(self):
        return len(self.char2idx)

    def word2xlitvec(self, word):
        """Converts a given string of glyphs (word) to a numpy vector
        Also adds tokens for start and end
        """
        try:
            vec = [self.char2idx["$"]]  # start token
            for i in list(word):
                vec.append(self.char2idx[i])
            vec.append(self.char2idx["#"])  # end token

            vec = np.asarray(vec, dtype=np.int64)
            return vec

        except Exception as error:
            print("Error in word:", word, "| char not in token map:", error)
            sys.exit()

    def xlitvec2word(self, vector):
        """Converts a numpy vector to a string of glyphs (word)"""
        char_list = []
        for i in vector:
            char_list.append(self.idx2char[i])

        word = "".join(char_list).replace("$", "").replace("#", "")  # remove tokens
        word = word.replace("_", "").replace("*", "")  # remove tokens
        return word
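
# Illustrative index mapping for the default English GlyphStrawboss (comments only;
# the values follow directly from _create_index above): "a".."z" occupy indices 27..52, so
#   GlyphStrawboss("en").word2xlitvec("ab")           -> array([ 1, 27, 28,  2])  # $, a, b, #
#   GlyphStrawboss("en").xlitvec2word([1, 27, 28, 2]) -> "ab"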
class XlitPiston:
    """
    For handling prediction & post-processing of transliteration for a single language
    Class dependency: Seq2Seq, GlyphStrawboss
    Global Variables: F_DIR
    """

    def __init__(
        self, weight_path, tglyph_cfg_file, iglyph_cfg_file="en", device="cpu"
    ):
        self.device = device
        self.in_glyph_obj = GlyphStrawboss(iglyph_cfg_file)
        self.tgt_glyph_obj = GlyphStrawboss(glyphs=tglyph_cfg_file)
        self._numsym_set = set(
            json.load(open(tglyph_cfg_file, encoding="utf-8"))["numsym_map"].keys()
        )
        self._inchar_set = set("abcdefghijklmnopqrstuvwxyzéèá")
        self._natscr_set = set().union(
            self.tgt_glyph_obj.glyphs, sum(self.tgt_glyph_obj.numsym_map.values(), [])
        )

        ## Static model config. TODO: support defining this in the JSON config
        input_dim = self.in_glyph_obj.size()
        output_dim = self.tgt_glyph_obj.size()
        enc_emb_dim = 300
        dec_emb_dim = 300
        enc_hidden_dim = 512
        dec_hidden_dim = 512
        rnn_type = "lstm"
        enc2dec_hid = True
        attention = True
        enc_layers = 1
        dec_layers = 2
        m_dropout = 0
        enc_bidirect = True
        enc_outstate_dim = enc_hidden_dim * (2 if enc_bidirect else 1)

        enc = Encoder(
            input_dim=input_dim,
            embed_dim=enc_emb_dim,
            hidden_dim=enc_hidden_dim,
            rnn_type=rnn_type,
            layers=enc_layers,
            dropout=m_dropout,
            device=self.device,
            bidirectional=enc_bidirect,
        )
        dec = Decoder(
            output_dim=output_dim,
            embed_dim=dec_emb_dim,
            hidden_dim=dec_hidden_dim,
            rnn_type=rnn_type,
            layers=dec_layers,
            dropout=m_dropout,
            use_attention=attention,
            enc_outstate_dim=enc_outstate_dim,
            device=self.device,
        )
        self.model = Seq2Seq(enc, dec, pass_enc2dec_hid=enc2dec_hid, device=self.device)
        self.model = self.model.to(self.device)
        weights = torch.load(weight_path, map_location=torch.device(self.device))

        self.model.load_state_dict(weights)
        self.model.eval()

    def character_model(self, word, beam_width=1):
        in_vec = torch.from_numpy(self.in_glyph_obj.word2xlitvec(word)).to(self.device)
        ## change to active or passive beam
        p_out_list = self.model.active_beam_inference(in_vec, beam_width=beam_width)
        result = [
            self.tgt_glyph_obj.xlitvec2word(out.cpu().numpy()) for out in p_out_list
        ]

        # List type
        return result

    def numsym_model(self, seg):
        """tgt_glyph_obj.numsym_map[x] returns a list object"""
        if len(seg) == 1:
            return [seg] + self.tgt_glyph_obj.numsym_map[seg]

        a = [self.tgt_glyph_obj.numsym_map[n][0] for n in seg]
        return [seg] + ["".join(a)]

    def _word_segementer(self, sequence):
        sequence = sequence.lower()
        accepted = set().union(self._numsym_set, self._inchar_set, self._natscr_set)
        # sequence = ''.join([i for i in sequence if i in accepted])

        segment = []
        idx = 0
        seq_ = list(sequence)
        while len(seq_):
            # for Number-Symbol
            temp = ""
            while len(seq_) and seq_[0] in self._numsym_set:
                temp += seq_[0]
                seq_.pop(0)
            if temp != "":
                segment.append(temp)

            # for Target Chars
            temp = ""
            while len(seq_) and seq_[0] in self._natscr_set:
                temp += seq_[0]
                seq_.pop(0)
            if temp != "":
                segment.append(temp)

            # for Input-Roman Chars
            temp = ""
            while len(seq_) and seq_[0] in self._inchar_set:
                temp += seq_[0]
                seq_.pop(0)
            if temp != "":
                segment.append(temp)

            # for unaccepted characters, passed through untouched
            temp = ""
            while len(seq_) and seq_[0] not in accepted:
                temp += seq_[0]
                seq_.pop(0)
            if temp != "":
                segment.append(temp)

        return segment
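
    # Illustrative behaviour of _word_segementer (comments only; this assumes the loaded
    # target glyph config has the digits "0"-"9" as keys of its numsym_map, which is an
    # assumption about the config rather than a guarantee):
    #   _word_segementer("namaste123")  ->  ["namaste", "123"]
    # "namaste" is then routed to character_model and "123" to numsym_model by
    # inferencer below.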
    def inferencer(self, sequence, beam_width=10):
        seg = self._word_segementer(sequence[:120])
        lit_seg = []

        p = 0
        while p < len(seg):
            if seg[p][0] in self._natscr_set:
                lit_seg.append([seg[p]])
                p += 1

            elif seg[p][0] in self._inchar_set:
                lit_seg.append(self.character_model(seg[p], beam_width=beam_width))
                p += 1

            elif seg[p][0] in self._numsym_set:  # num & punc
                lit_seg.append(self.numsym_model(seg[p]))
                p += 1
            else:
                lit_seg.append([seg[p]])
                p += 1

        ## IF there are 2 or fewer segments, return their combinatorial products;
        ## ELSE return only the top-1 of each segment, concatenated
        if len(lit_seg) == 1:
            final_result = lit_seg[0]

        elif len(lit_seg) == 2:
            final_result = [""]
            for seg in lit_seg:
                new_result = []
                for s in seg:
                    for f in final_result:
                        new_result.append(f + s)
                final_result = new_result

        else:
            new_result = []
            for seg in lit_seg:
                new_result.append(seg[0])
            final_result = ["".join(new_result)]

        return final_result
class XlitEngine:
    """
    For managing the top level tasks and applications of transliteration
    Global Variables: F_DIR
    """

    def __init__(self, lang2use="hi", config_path="models/default_lineup.json"):
        lineup = json.load(open(os.path.join(F_DIR, config_path), encoding="utf-8"))
        models_path = os.path.join(F_DIR, "models")
        self.lang_config = {}
        if lang2use in lineup:
            self.lang_config[lang2use] = lineup[lang2use]
        else:
            raise Exception(
                "XlitError: The entered language code was not found. Available: {}".format(
                    lineup.keys()
                )
            )

        self.langs = {}
        self.lang_model = {}
        for la in self.lang_config:
            try:
                print("Loading {}...".format(la))
                self.lang_model[la] = XlitPiston(
                    weight_path=os.path.join(
                        models_path, self.lang_config[la]["weight"]
                    ),
                    tglyph_cfg_file=os.path.join(
                        models_path, self.lang_config[la]["script"]
                    ),
                    iglyph_cfg_file="en",
                )
                self.langs[la] = self.lang_config[la]["name"]
            except Exception as error:
                print("XlitError: Failure in loading {} \n".format(la), error)
                print(XlitError.loading_err.value)

    def translit_word(self, eng_word, lang_code="hi", topk=7, beam_width=10):
        if eng_word == "":
            return []

        if lang_code in self.langs:
            try:
                res_list = self.lang_model[lang_code].inferencer(
                    eng_word, beam_width=beam_width
                )
                return res_list[:topk]

            except Exception as error:
                print("XlitError:", traceback.format_exc())
                print(XlitError.internal_err.value)
                return XlitError.internal_err
        else:
            print("XlitError: Unknown language requested", lang_code)
            print(XlitError.lang_err.value)
            return XlitError.lang_err

    def translit_sentence(self, eng_sentence, lang_code="hi", beam_width=10):
        if eng_sentence == "":
            return []

        if lang_code in self.langs:
            try:
                out_str = ""
                for word in eng_sentence.split():
                    res_ = self.lang_model[lang_code].inferencer(
                        word, beam_width=beam_width
                    )
                    out_str = out_str + res_[0] + " "
                return out_str[:-1]

            except Exception as error:
                print("XlitError:", traceback.format_exc())
                print(XlitError.internal_err.value)
                return XlitError.internal_err

        else:
            print("XlitError: Unknown language requested", lang_code)
            print(XlitError.lang_err.value)
            return XlitError.lang_err


if __name__ == "__main__":
    engine = XlitEngine()
    y = engine.translit_sentence("Hello World !")
    print(y)
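
    # Illustrative word-level call (a sketch, assuming the default "hi" model and its
    # weights were loaded successfully by XlitEngine() above): translit_word returns
    # up to `topk` candidate transliterations, best first.
    candidates = engine.translit_word("namaste", lang_code="hi", topk=5, beam_width=10)
    print(candidates)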