# Amis-Zh-MT / app.py
import gradio as gr
from transformers import AutoModelForSeq2SeqLM, NllbTokenizer
import torch
from sacremoses import MosesPunctNormalizer
import re
import unicodedata
import sys
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load the small model
small_tokenizer = NllbTokenizer.from_pretrained("hunterschep/amis-zh-600M")
small_model = AutoModelForSeq2SeqLM.from_pretrained("hunterschep/amis-zh-600M").to(device)
# Patch the NLLB tokenizer so it recognizes the new language code (ami_Latn)
def fix_tokenizer(tokenizer, new_lang='ami_Latn'):
    old_len = len(tokenizer) - int(new_lang in tokenizer.added_tokens_encoder)
    tokenizer.lang_code_to_id[new_lang] = old_len - 1
    tokenizer.id_to_lang_code[old_len - 1] = new_lang
    tokenizer.fairseq_tokens_to_ids["<mask>"] = len(tokenizer.sp_model) + len(tokenizer.lang_code_to_id) + tokenizer.fairseq_offset
    tokenizer.fairseq_tokens_to_ids.update(tokenizer.lang_code_to_id)
    tokenizer.fairseq_ids_to_tokens = {v: k for k, v in tokenizer.fairseq_tokens_to_ids.items()}
    if new_lang not in tokenizer._additional_special_tokens:
        tokenizer._additional_special_tokens.append(new_lang)
    # Clear the added-token caches so lookups go through the updated fairseq mappings
    tokenizer.added_tokens_encoder = {}
    tokenizer.added_tokens_decoder = {}

fix_tokenizer(small_tokenizer)
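# Optional sanity check (an assumption-level sketch, not required by the app): after the
# patch, the new language code should resolve to a real token id rather than the unknown
# token, e.g.
#   small_tokenizer.convert_tokens_to_ids("ami_Latn") != small_tokenizer.unk_token_id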
# Map the UI display labels to NLLB language codes (the Radio components below return labels)
LANG_CODES = {"汉语 Chinese": "zho_Hant", "Amis": "ami_Latn"}

# Translation function
def translate(text, src_lang, tgt_lang):
    tokenizer, model = small_tokenizer, small_model
    # Accept either UI labels or raw NLLB codes
    src_lang = LANG_CODES.get(src_lang, src_lang)
    tgt_lang = LANG_CODES.get(tgt_lang, tgt_lang)
    if src_lang == "zho_Hant":
        text = preproc_chinese(text)
    tokenizer.src_lang = src_lang
    tokenizer.tgt_lang = tgt_lang
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=1024)
    model.eval()
    result = model.generate(
        **inputs.to(model.device),
        forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_lang),
        max_new_tokens=256,
        num_beams=4,
    )
    return tokenizer.batch_decode(result, skip_special_tokens=True)[0]
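# Quick local test (a sketch; the sentence is illustrative only and is not executed by the app):
#   translate("今天天氣很好。", "汉语 Chinese", "Amis")
# should return a single decoded Amis string.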
# Preprocessing for Chinese
mpn_chinese = MosesPunctNormalizer(lang="zh")
mpn_chinese.substitutions = [(re.compile(r), sub) for r, sub in mpn_chinese.substitutions]
def get_non_printing_char_replacer(replace_by=" "):
    # Map every Unicode character in a "C" (control/format/surrogate/private/unassigned) category to replace_by
    non_printable_map = {
        ord(c): replace_by
        for c in (chr(i) for i in range(sys.maxunicode + 1))
        if unicodedata.category(c) in {"C", "Cc", "Cf", "Cs", "Co", "Cn"}
    }
    return lambda line: line.translate(non_printable_map)

replace_nonprint = get_non_printing_char_replacer(" ")
def preproc_chinese(text):
    clean = text
    for pattern, sub in mpn_chinese.substitutions:
        clean = pattern.sub(sub, clean)
    clean = replace_nonprint(clean)
    return unicodedata.normalize("NFKC", clean)
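# Illustrative behaviour (an assumption, not asserted anywhere in this file): the final NFKC
# step folds full-width compatibility characters, e.g.
#   preproc_chinese("ｈｅｌｌｏ　世界")  # -> "hello 世界"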
with gr.Blocks() as demo:
    gr.Markdown("# AMIS - Chinese Translation Tool")
    # The Radio components return the selected label, which translate() maps to NLLB codes
    src_lang = gr.Radio(choices=["汉语 Chinese", "Amis"], value="汉语 Chinese", label="Source Language")
    tgt_lang = gr.Radio(choices=["Amis", "汉语 Chinese"], value="Amis", label="Target Language")
    input_text = gr.Textbox(label="Input Text", placeholder="Enter text here...")
    output_text = gr.Textbox(label="Translated Text", interactive=False)
    translate_btn = gr.Button("Translate")
    translate_btn.click(translate, inputs=[input_text, src_lang, tgt_lang], outputs=output_text)
    gr.Markdown("感謝您在此專案上的辛勤工作。這只是這些模型能力的一小部分展示。電子郵件: [email protected]")
if __name__ == "__main__":
    demo.launch()