|
import streamlit as st |
|
import os |
|
import io |
|
from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration |
|
import time |
|
import json |
|
from typing import List |
|
import torch |
|
import random |
|
import logging |
|
|
|
# Pick the compute device once at import time: first CUDA GPU if present,
# otherwise CPU (with a warning, since CPU inference of a 1.2B model is slow).
_gpu_available = torch.cuda.is_available()
device = torch.device("cuda:0" if _gpu_available else "cpu")
if not _gpu_available:
    logging.warning("GPU not found, using CPU, translation will be very slow.")
|
|
|
# NOTE(review): the original had a bare `st.cache(suppress_st_warning=True,
# allow_output_mutation=True)` call here. Without `@` it merely constructs a
# caching decorator and discards it — a no-op (the real caching is the
# @st.cache on load_model below), so it has been removed.
st.set_page_config(page_title="M2M100 Translator")
|
|
|
# Mapping from human-readable language name (shown in the UI selectboxes)
# to the ISO-639-ish code expected by the M2M100 tokenizer (100 languages).
# Fix: "Greeek" -> "Greek" (typo in the user-facing option label).
lang_id = {
    "Afrikaans": "af",
    "Amharic": "am",
    "Arabic": "ar",
    "Asturian": "ast",
    "Azerbaijani": "az",
    "Bashkir": "ba",
    "Belarusian": "be",
    "Bulgarian": "bg",
    "Bengali": "bn",
    "Breton": "br",
    "Bosnian": "bs",
    "Catalan": "ca",
    "Cebuano": "ceb",
    "Czech": "cs",
    "Welsh": "cy",
    "Danish": "da",
    "German": "de",
    "Greek": "el",
    "English": "en",
    "Spanish": "es",
    "Estonian": "et",
    "Persian": "fa",
    "Fulah": "ff",
    "Finnish": "fi",
    "French": "fr",
    "Western Frisian": "fy",
    "Irish": "ga",
    "Gaelic": "gd",
    "Galician": "gl",
    "Gujarati": "gu",
    "Hausa": "ha",
    "Hebrew": "he",
    "Hindi": "hi",
    "Croatian": "hr",
    "Haitian": "ht",
    "Hungarian": "hu",
    "Armenian": "hy",
    "Indonesian": "id",
    "Igbo": "ig",
    "Iloko": "ilo",
    "Icelandic": "is",
    "Italian": "it",
    "Japanese": "ja",
    "Javanese": "jv",
    "Georgian": "ka",
    "Kazakh": "kk",
    "Central Khmer": "km",
    "Kannada": "kn",
    "Korean": "ko",
    "Luxembourgish": "lb",
    "Ganda": "lg",
    "Lingala": "ln",
    "Lao": "lo",
    "Lithuanian": "lt",
    "Latvian": "lv",
    "Malagasy": "mg",
    "Macedonian": "mk",
    "Malayalam": "ml",
    "Mongolian": "mn",
    "Marathi": "mr",
    "Malay": "ms",
    "Burmese": "my",
    "Nepali": "ne",
    "Dutch": "nl",
    "Norwegian": "no",
    "Northern Sotho": "ns",
    "Occitan": "oc",
    "Oriya": "or",
    "Panjabi": "pa",
    "Polish": "pl",
    "Pushto": "ps",
    "Portuguese": "pt",
    "Romanian": "ro",
    "Russian": "ru",
    "Sindhi": "sd",
    "Sinhala": "si",
    "Slovak": "sk",
    "Slovenian": "sl",
    "Somali": "so",
    "Albanian": "sq",
    "Serbian": "sr",
    "Swati": "ss",
    "Sundanese": "su",
    "Swedish": "sv",
    "Swahili": "sw",
    "Tamil": "ta",
    "Thai": "th",
    "Tagalog": "tl",
    "Tswana": "tn",
    "Turkish": "tr",
    "Ukrainian": "uk",
    "Urdu": "ur",
    "Uzbek": "uz",
    "Vietnamese": "vi",
    "Wolof": "wo",
    "Xhosa": "xh",
    "Yiddish": "yi",
    "Yoruba": "yo",
    "Chinese": "zh",
    "Zulu": "zu",
}
|
|
|
|
|
@st.cache(suppress_st_warning=True, allow_output_mutation=True)
def load_model(
    pretrained_model: str = "facebook/m2m100_1.2B",
    cache_dir: str = "models/",
):
    """Load the M2M100 tokenizer and model, downloading into *cache_dir* if needed.

    The model is moved to the module-level ``device`` and put into eval
    mode. Wrapped in ``st.cache`` so the (very heavy) load runs only once
    per Streamlit session.
    """
    tok = M2M100Tokenizer.from_pretrained(pretrained_model, cache_dir=cache_dir)
    net = M2M100ForConditionalGeneration.from_pretrained(
        pretrained_model, cache_dir=cache_dir
    )
    net = net.to(device)
    net.eval()
    return tok, net
|
|
|
|
|
# --- Page body: title and model description -------------------------------
st.title("M2M100 Translator")
st.write("M2M100 is a multilingual encoder-decoder (seq-to-seq) model trained for Many-to-Many multilingual translation. It was introduced in this paper https://arxiv.org/abs/2010.11125 and first released in https://github.com/pytorch/fairseq/tree/master/examples/m2m_100 repository. The model that can directly translate between the 9,900 directions of 100 languages.\n")

st.write(" This demo uses the facebook/m2m100_1.2B model. For local inference see https://github.com/ikergarcia1996/Easy-Translate")

# Text to translate; capped at 5120 chars to keep generation time bounded.
user_input: str = st.text_area(
    "Input text",
    height=200,
    max_chars=5120,
)

# Source/target language pickers, keyed by the human-readable names in lang_id.
source_lang = st.selectbox(label="Source language", options=list(lang_id.keys()))
target_lang = st.selectbox(label="Target language", options=list(lang_id.keys()))
|
|
|
if st.button("Run"):
    # Guard against empty input BEFORE triggering the expensive model load:
    # previously an empty text area still loaded the 1.2B model and ran
    # generation on an empty prompt.
    if not user_input.strip():
        st.warning("Please enter some text to translate.")
    else:
        time_start = time.time()
        tokenizer, model = load_model()

        # Map the UI's display names to M2M100 language codes.
        src_lang = lang_id[source_lang]
        trg_lang = lang_id[target_lang]
        tokenizer.src_lang = src_lang

        with torch.no_grad():
            encoded_input = tokenizer(user_input, return_tensors="pt").to(device)
            # forced_bos_token_id makes the decoder start in the target language.
            generated_tokens = model.generate(
                **encoded_input, forced_bos_token_id=tokenizer.get_lang_id(trg_lang)
            )
            translated_text = tokenizer.batch_decode(
                generated_tokens, skip_special_tokens=True
            )[0]

        time_end = time.time()
        st.success(translated_text)

        # Fix: message said "segs"; use a proper English unit label.
        st.write(f"Computation time: {round((time_end - time_start), 3)} seconds")
|
|