# Amis-Zh-MT / app.py
# hunterschep's picture
# Update app.py
# 52f4023 verified
# raw
# history blame
# 3.19 kB
import gradio as gr
from transformers import AutoModelForSeq2SeqLM, NllbTokenizer
import torch
from sacremoses import MosesPunctNormalizer
import re
import unicodedata
import sys
# Run on the GPU when torch reports one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the small model: tokenizer and seq2seq weights come from the same
# checkpoint, and the model is moved to the selected device.
small_tokenizer = NllbTokenizer.from_pretrained("hunterschep/amis-zh-600M")
small_model = AutoModelForSeq2SeqLM.from_pretrained("hunterschep/amis-zh-600M")
small_model = small_model.to(device)
# Fix tokenizer
def fix_tokenizer(tokenizer, new_lang='ami_Latn'):
    """Register ``new_lang`` as a language code in the tokenizer's lookup tables.

    The tokenizer is patched in place: the language-code maps, the fairseq
    token/id maps and the additional-special-tokens list are updated, and the
    added-token encoder/decoder dicts are reset to empty.

    NOTE(review): this writes to internal attributes (``lang_code_to_id``,
    ``fairseq_tokens_to_ids``, ``_additional_special_tokens``, ...) and so
    presumably depends on the installed tokenizer version exposing them --
    confirm when upgrading.

    Args:
        tokenizer: The tokenizer object to patch.
        new_lang: Language code to register.
    """
    # Size of the vocabulary without counting new_lang, if it is already
    # present as an added token.
    already_added = new_lang in tokenizer.added_tokens_encoder
    base_size = len(tokenizer) - int(already_added)

    # The new language code takes the id one below that size.
    lang_id = base_size - 1
    tokenizer.lang_code_to_id[new_lang] = lang_id
    tokenizer.id_to_lang_code[lang_id] = new_lang

    # "<mask>" is placed after the sentencepiece pieces, the language codes
    # (now including new_lang) and the fairseq offset.
    mask_id = (
        len(tokenizer.sp_model)
        + len(tokenizer.lang_code_to_id)
        + tokenizer.fairseq_offset
    )
    tokenizer.fairseq_tokens_to_ids["<mask>"] = mask_id

    # Keep the forward and reverse fairseq maps consistent with the language codes.
    tokenizer.fairseq_tokens_to_ids.update(tokenizer.lang_code_to_id)
    tokenizer.fairseq_ids_to_tokens = {
        token_id: token
        for token, token_id in tokenizer.fairseq_tokens_to_ids.items()
    }

    if new_lang not in tokenizer._additional_special_tokens:
        tokenizer._additional_special_tokens.append(new_lang)

    # Reset the added-token tables to empty.
    tokenizer.added_tokens_encoder = {}
    tokenizer.added_tokens_decoder = {}
# Patch the loaded tokenizer once at import time so it knows the Amis code.
fix_tokenizer(small_tokenizer, new_lang='ami_Latn')
# Translation function
def translate(text, src_lang, tgt_lang):
    """Translate ``text`` from ``src_lang`` to ``tgt_lang`` with the small model.

    Args:
        text: The text to translate. ``None``, empty or whitespace-only input
            is returned as an empty string without running the model.
        src_lang: Language code of the input (the UI offers "zho_Hant" and
            "ami_Latn"). Chinese input is normalized with ``preproc_chinese``
            before tokenization.
        tgt_lang: Language code of the output; used as the forced first
            generated token.

    Returns:
        The decoded translation with special tokens removed, or ``""`` when
        there is nothing to translate.
    """
    # Nothing to translate: skip tokenization and generation so a blank
    # textbox yields a blank result instead of model-generated text.
    if not text or not text.strip():
        return ""

    tokenizer, model = small_tokenizer, small_model

    if src_lang == "zho_Hant":
        text = preproc_chinese(text)

    # The tokenizer reads the language codes from its own attributes.
    tokenizer.src_lang = src_lang
    tokenizer.tgt_lang = tgt_lang
    inputs = tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=1024,
    )

    model.eval()
    result = model.generate(
        **inputs.to(model.device),
        forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_lang),
        max_new_tokens=256,
        num_beams=4,
    )
    # Only one input is passed, so the first decoded sequence is the answer.
    return tokenizer.batch_decode(result, skip_special_tokens=True)[0]
# Preprocessing for Chinese
# Punctuation normalizer for Chinese. Its (pattern, replacement) pairs are
# compiled once up front so later calls can use the compiled patterns directly.
mpn_chinese = MosesPunctNormalizer(lang="zh")
mpn_chinese.substitutions = [
    (re.compile(raw_pattern), replacement)
    for raw_pattern, replacement in mpn_chinese.substitutions
]
def get_non_printing_char_replacer(replace_by=" "):
    """Build a function that replaces non-printing characters in a string.

    Every code point whose Unicode general category is one of the "Other"
    categories listed below is mapped to ``replace_by``. The table is built
    once, when this factory is called.

    Args:
        replace_by: Replacement string substituted for each matching character.

    Returns:
        A callable taking a string and returning it with the matching
        characters replaced.
    """
    other_categories = {"C", "Cc", "Cf", "Cs", "Co", "Cn"}

    translation_table = {}
    for code_point in range(sys.maxunicode + 1):
        if unicodedata.category(chr(code_point)) in other_categories:
            translation_table[code_point] = replace_by

    def replace_non_printing(line):
        return line.translate(translation_table)

    return replace_non_printing


# Shared replacer that swaps each non-printing character for a single space.
replace_nonprint = get_non_printing_char_replacer(replace_by=" ")
def preproc_chinese(text):
    """Normalize Chinese input text before it is tokenized.

    Applies, in order: each compiled substitution held in
    ``mpn_chinese.substitutions``, replacement of non-printing characters via
    ``replace_nonprint``, and Unicode NFKC normalization.

    Args:
        text: The raw input string.

    Returns:
        The normalized string.
    """
    normalized = text
    for compiled_pattern, replacement in mpn_chinese.substitutions:
        normalized = compiled_pattern.sub(replacement, normalized)
    without_nonprint = replace_nonprint(normalized)
    return unicodedata.normalize("NFKC", without_nonprint)
# Gradio UI: language selectors, an input box, a read-only output box and a
# button that sends the three inputs to translate().
with gr.Blocks() as demo:
    gr.Markdown("# AMIS - Chinese Translation Tool")

    src_lang = gr.Radio(
        choices=["zho_Hant", "ami_Latn"],
        value="zho_Hant",
        label="Source Language",
    )
    tgt_lang = gr.Radio(
        choices=["ami_Latn", "zho_Hant"],
        value="ami_Latn",
        label="Target Language",
    )

    input_text = gr.Textbox(
        label="Input Text",
        placeholder="Enter text here...",
    )
    output_text = gr.Textbox(
        label="Translated Text",
        interactive=False,
    )

    translate_btn = gr.Button("Translate")
    translate_btn.click(
        fn=translate,
        inputs=[input_text, src_lang, tgt_lang],
        outputs=output_text,
    )

# Start the web app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()