# Copyright (c) 2024 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import json

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer
from transformers.models.bert.modeling_bert import *
from onnxruntime import InferenceSession, GraphOptimizationLevel, SessionOptions

class PolyDataset(Dataset):
    def __init__(self, words, labels, word_pad_idx=0, label_pad_idx=-1):
        self.dataset = self.preprocess(words, labels)
        self.word_pad_idx = word_pad_idx
        self.label_pad_idx = label_pad_idx

    def preprocess(self, origin_sentences, origin_labels):
        """
        Maps tokens and tags to their indices and stores them in `data`.
        Example:
            word: ['[CLS]', '浙', '商', '银', '行', '企', '业', '信', '贷', '部']
            sentence: ([101, 3851, 1555, 7213, 6121, 821, 689, 928, 6587, 6956],
                       array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10]))
            label: [3, 13, 13, 13, 0, 0, 0, 0, 0]
        """
        data = []
        labels = []
        sentences = []
        # tokenize
        for line in origin_sentences:
            # Keep tokens one-to-one with labels; we cannot use encode_plus
            # because the sentences are already aligned to labels as lists.
            words = []
            word_lens = []
            for token in line:
                words.append(token)
                word_lens.append(1)
            # index of the first sub-token of each word (position 0 is [CLS])
            token_start_idxs = 1 + np.cumsum([0] + word_lens[:-1])
            sentences.append(((words, token_start_idxs), 0))
        for tag in origin_labels:
            labels.append(tag)
        for sentence, label in zip(sentences, labels):
            data.append((sentence, label))
        return data
    def __getitem__(self, idx):
        """sample data to get batch"""
        word = self.dataset[idx][0]
        label = self.dataset[idx][1]
        return [word, label]

    def __len__(self):
        """get dataset size"""
        return len(self.dataset)

    def collate_fn(self, batch):
        sentences = [x[0][0] for x in batch]
        ori_sents = [x[0][1] for x in batch]
        labels = [x[1] for x in batch]
        batch_len = len(sentences)

        # compute length of the longest sentence in the batch
        max_len = max([len(s[0]) for s in sentences])
        max_label_len = 0
        batch_data = np.ones((batch_len, max_len))
        batch_label_starts = []

        # padding and aligning
        for j in range(batch_len):
            cur_len = len(sentences[j][0])
            batch_data[j][:cur_len] = sentences[j][0]
            label_start_idx = sentences[j][-1]
            label_starts = np.zeros(max_len)
            label_starts[[idx for idx in label_start_idx if idx < max_len]] = 1
            batch_label_starts.append(label_starts)
            max_label_len = max(int(sum(label_starts)), max_label_len)

        # pad labels; pmasks marks which label positions are polyphonic
        batch_labels = self.label_pad_idx * np.ones((batch_len, max_label_len))
        batch_pmasks = self.label_pad_idx * np.ones((batch_len, max_label_len))
        for j in range(batch_len):
            cur_tags_len = len(labels[j])
            batch_labels[j][:cur_tags_len] = labels[j]
            batch_pmasks[j][:cur_tags_len] = [
                1 if item > 0 else 0 for item in labels[j]
            ]

        # convert data to torch LongTensors
        batch_data = torch.tensor(batch_data, dtype=torch.long)
        batch_label_starts = torch.tensor(batch_label_starts, dtype=torch.long)
        batch_labels = torch.tensor(batch_labels, dtype=torch.long)
        batch_pmasks = torch.tensor(batch_pmasks, dtype=torch.long)
        return [batch_data, batch_label_starts, batch_labels, batch_pmasks, ori_sents]
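
# Illustrative sketch (not part of the original module): how PolyDataset and its
# collate_fn are expected to be driven. The token ids below are made up; real
# inputs come from BertTokenizer.convert_tokens_to_ids, as in get_examples_po.
#
#   words = [[101, 2496, 4197]]     # [CLS] id followed by two character ids
#   labels = [[1, 0]]               # 1 marks the polyphonic character
#   dataset = PolyDataset(words, labels)
#   loader = DataLoader(dataset, batch_size=1, shuffle=False,
#                       collate_fn=dataset.collate_fn)
#   batch_data, batch_label_starts, batch_labels, batch_pmasks, ori_sents = next(iter(loader))
#   # batch_data:   (1, 3) LongTensor of token ids (buffer initialized to ones)
#   # batch_pmasks: (1, 2) LongTensor, 1 where a pronunciation must be predicted
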
class BertPolyPredict:
    def __init__(self, bert_model, jsonr_file, json_file):
        self.tokenizer = BertTokenizer.from_pretrained(bert_model, do_lower_case=True)
        with open(jsonr_file, "r", encoding="utf8") as fp:
            self.pron_dict = json.load(fp)
        with open(json_file, "r", encoding="utf8") as fp:
            self.pron_dict_id_2_pinyin = json.load(fp)
        self.num_polyphone = len(self.pron_dict)
        self.device = "cpu"
        self.polydataset = PolyDataset

        options = SessionOptions()  # initialize session options
        options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
        print(os.path.join(bert_model, "poly_bert_model.onnx"))
        self.session = InferenceSession(
            os.path.join(bert_model, "poly_bert_model.onnx"),
            sess_options=options,
            providers=[
                "CUDAExecutionProvider",
                "CPUExecutionProvider",
            ],  # tried in order: CUDA first, then CPU
        )
        # self.session.set_providers(['CUDAExecutionProvider', "CPUExecutionProvider"], [{'device_id': 0}])
        # Disable the session.run() fallback mechanism; this prevents a silent
        # reset of the execution provider.
        self.session.disable_fallback()
    def predict_process(self, txt_list):
        word_test, label_test, texts_test = self.get_examples_po(txt_list)
        data = self.polydataset(word_test, label_test)
        predict_loader = DataLoader(
            data, batch_size=1, shuffle=False, collate_fn=data.collate_fn
        )
        pred_tags = self.predict_onnx(predict_loader)
        return pred_tags

    def predict_onnx(self, dev_loader):
        pred_tags = []
        with torch.no_grad():
            for idx, batch_samples in enumerate(dev_loader):
                # [batch_data, batch_label_starts, batch_labels, batch_pmasks, ori_sents]
                batch_data, batch_label_starts, batch_labels, batch_pmasks, _ = (
                    batch_samples
                )
                # move tensors to the configured device
                batch_data = batch_data.to(self.device)
                batch_label_starts = batch_label_starts.to(self.device)
                batch_labels = batch_labels.to(self.device)
                batch_pmasks = batch_pmasks.to(self.device)
                # convert to numpy int32 arrays for the ONNX Runtime session
                batch_data = np.asarray(batch_data, dtype=np.int32)
                batch_pmasks = np.asarray(batch_pmasks, dtype=np.int32)
                # batch_output = self.session.run(output_names=['outputs'], input_feed={"input_ids": batch_data, "input_pmasks": batch_pmasks})[0][0]
                batch_output = self.session.run(
                    output_names=["outputs"], input_feed={"input_ids": batch_data}
                )[0]
                # keep predictions only at positions marked as polyphonic
                label_masks = batch_pmasks == 1
                batch_labels = batch_labels.to("cpu").numpy()
                for i, indices in enumerate(np.argmax(batch_output, axis=2)):
                    for j, idx in enumerate(indices):
                        if label_masks[i][j]:
                            # pred_tag.append(idx)
                            pred_tags.append(self.pron_dict_id_2_pinyin[str(idx + 1)])
        return pred_tags
    def get_examples_po(self, text_list):
        word_list = []
        label_list = []
        sentence_list = []
        id = 0
        for line in [text_list]:
            sentence = line[0]
            tokens = line[0]
            index = line[-1]
            # the character at `index` is the polyphone to be predicted
            front = index
            back = len(tokens) - index - 1
            labels = [0] * front + [1] + [0] * back
            words = ["[CLS]"] + [item for item in sentence]
            words = self.tokenizer.convert_tokens_to_ids(words)
            word_list.append(words)
            label_list.append(labels)
            sentence_list.append(sentence)
            id += 1
            # mask_list.append(masks)
            assert len(labels) + 1 == len(words), (
                "Number of labels does not match number of words: "
                f"{sentence}, {words}, {labels}, "
                f"{len(sentence)}, {len(words)}, {len(labels)}"
            )
            assert len(labels) == len(
                sentence
            ), "Number of labels does not match sentence length"
            assert len(word_list) == len(
                label_list
            ), "Number of label sentences does not match number of word sentences"
        return word_list, label_list, text_list
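
# Minimal usage sketch (illustrative; the paths below are placeholders, except
# poly_bert_model.onnx, which __init__ loads from the model directory):
#
#   predictor = BertPolyPredict(
#       "path/to/bert_model_dir",        # tokenizer files + poly_bert_model.onnx
#       "path/to/polyphone_dict.json",   # polyphone dictionary (jsonr_file)
#       "path/to/id_to_pinyin.json",     # label id -> pinyin mapping (json_file)
#   )
#   # predict_process expects [characters_of_the_sentence, polyphone_index],
#   # matching how get_examples_po above unpacks its input.
#   pred = predictor.predict_process([list("当然"), 0])
#   print(pred)  # e.g. a list with one pinyin string for the marked character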