import gradio as gr
from transformers import AutoModelForSeq2SeqLM, NllbTokenizer
import torch
from sacremoses import MosesPunctNormalizer
import re
import unicodedata
import sys
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load the small model
small_tokenizer = NllbTokenizer.from_pretrained("hunterschep/amis-zh-600M")
small_model = AutoModelForSeq2SeqLM.from_pretrained("hunterschep/amis-zh-600M").to(device)
# Patch the tokenizer so it recognizes the custom 'ami_Latn' language code
def fix_tokenizer(tokenizer, new_lang='ami_Latn'):
    """Add a new language token to the NLLB tokenizer vocabulary (do this carefully: adding tokens can break the model)."""
    old_len = len(tokenizer) - int(new_lang in tokenizer.added_tokens_encoder)
    tokenizer.lang_code_to_id[new_lang] = old_len - 1
    tokenizer.id_to_lang_code[old_len - 1] = new_lang
    # always keep <mask> at the last position
    tokenizer.fairseq_tokens_to_ids["<mask>"] = len(tokenizer.sp_model) + len(tokenizer.lang_code_to_id) + tokenizer.fairseq_offset
    tokenizer.fairseq_tokens_to_ids.update(tokenizer.lang_code_to_id)
    tokenizer.fairseq_ids_to_tokens = {v: k for k, v in tokenizer.fairseq_tokens_to_ids.items()}
    if new_lang not in tokenizer._additional_special_tokens:
        tokenizer._additional_special_tokens.append(new_lang)
    # clear the added-token encoders so the new code does not get re-added by mistake
    tokenizer.added_tokens_encoder = {}
    tokenizer.added_tokens_decoder = {}
fix_tokenizer(small_tokenizer)
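# Optional sanity check (illustrative, not part of the original app): after the patch,
# the new language code should resolve to a real token id, e.g.
#   small_tokenizer.convert_tokens_to_ids("ami_Latn")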
# Translation function
def translate(text, src_lang, tgt_lang):
    tokenizer, model = small_tokenizer, small_model
    if src_lang == "zho_Hant":
        text = preproc_chinese(text)
    tokenizer.src_lang = src_lang
    tokenizer.tgt_lang = tgt_lang
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=1024)
    model.eval()
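    # forced_bos_token_id forces the first decoded token to be the target-language code,
    # which is how NLLB selects the output language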
    result = model.generate(
        **inputs.to(model.device),
        forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_lang),
        max_new_tokens=256,
        num_beams=4
    )
    return tokenizer.batch_decode(result, skip_special_tokens=True)[0]
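# Example call (illustrative): translate a Traditional Chinese sentence into Amis
#   translate("你好嗎?", "zho_Hant", "ami_Latn")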
# Preprocessing for Chinese
mpn_chinese = MosesPunctNormalizer(lang="zh")
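# Pre-compile the Moses punctuation-normalization rules so they can be applied one by one below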
mpn_chinese.substitutions = [(re.compile(r), sub) for r, sub in mpn_chinese.substitutions]
def get_non_printing_char_replacer(replace_by=" "):
    non_printable_map = {ord(c): replace_by for c in (chr(i) for i in range(sys.maxunicode + 1)) if unicodedata.category(c) in {"C", "Cc", "Cf", "Cs", "Co", "Cn"}}
    return lambda line: line.translate(non_printable_map)
replace_nonprint = get_non_printing_char_replacer(" ")
def preproc_chinese(text):
    clean = text
    for pattern, sub in mpn_chinese.substitutions:
        clean = pattern.sub(sub, clean)
    clean = replace_nonprint(clean)
    return unicodedata.normalize("NFKC", clean)
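# Note: NFKC normalization also folds full-width forms, e.g.
#   unicodedata.normalize("NFKC", "１２３") == "123"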
with gr.Blocks() as demo:
    gr.Markdown("# AMIS - Chinese Translation Tool")
    src_lang = gr.Radio(choices=["zho_Hant", "ami_Latn"], value="zho_Hant", label="Source Language")
    tgt_lang = gr.Radio(choices=["ami_Latn", "zho_Hant"], value="ami_Latn", label="Target Language")
    input_text = gr.Textbox(label="Input Text", placeholder="Enter text here...")
    output_text = gr.Textbox(label="Translated Text", interactive=False)
    translate_btn = gr.Button("Translate")
    translate_btn.click(translate, inputs=[input_text, src_lang, tgt_lang], outputs=output_text)
if __name__ == "__main__":
    demo.launch()