CloudNativeDDL / app.py
Timing0311's picture
test translation func
a88db4c
raw
history blame
656 Bytes
from transformers import MT5ForConditionalGeneration, AutoTokenizer
import gradio as grad
# Single checkpoint identifier shared by the model and its tokenizer so the
# two can never be loaded from different checkpoints by accident.
# NOTE(review): google/mt5-small appears to be a pre-trained-only checkpoint
# (no supervised translation fine-tuning) — confirm against its model card and
# consider a translation-tuned checkpoint if outputs are not usable.
_CHECKPOINT = "google/mt5-small"

mdl = MT5ForConditionalGeneration.from_pretrained(_CHECKPOINT)
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
def translation_CN2EN(text, max_new_tokens=128):
    """Translate Chinese *text* to English with the module-level mT5 model.

    Args:
        text: Source text from the "Chinese Text" textbox.
        max_new_tokens: Upper bound on the number of tokens to generate.
            Without an explicit limit ``generate`` falls back to the
            generation config's default ``max_length``, which can cut
            longer translations short.

    Returns:
        The decoded translation as a single string, with special tokens such
        as ``<pad>`` and ``</s>`` removed. Returns ``""`` for empty,
        whitespace-only, or ``None`` input.
    """
    # Skip the model entirely for blank input; generating from a bare task
    # prefix would only produce noise.
    if not text or not text.strip():
        return ""
    # The task prefix must match what the function name and UI promise
    # (Chinese in, English out). The previous prompt asked for the opposite
    # direction and contained a stray double colon.
    inp = "translate Chinese to English: " + text
    enc = tokenizer(inp, return_tensors="pt")
    tokens = mdl.generate(**enc, max_new_tokens=max_new_tokens)
    # batch_decode yields one string per generated sequence. Only one input
    # is encoded here, and the Textbox output expects a plain string, so
    # unwrap the single result.
    response = tokenizer.batch_decode(tokens, skip_special_tokens=True)
    return response[0]
# Gradio UI: a single Chinese input box wired to a single English output box.
para = grad.Textbox(
    lines=1,
    label="Chinese Text",
    placeholder="Text in Chinese",
)
out = grad.Textbox(lines=1, label="English Translation")

# Build the interface object first, then start serving it.
demo = grad.Interface(fn=translation_CN2EN, inputs=para, outputs=out)
demo.launch()