import streamlit as st
import time
import torch
import logging
from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration

# Set Streamlit page configuration
st.set_page_config(page_title="M2M100 Translator")

# Check device
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
    logging.warning("GPU not found; falling back to CPU. Translation will be slow.")

# Language code mapping
lang_id = {
    "Afrikaans": "af", "Amharic": "am", "Arabic": "ar", "Asturian": "ast",
    "Azerbaijani": "az", "Bashkir": "ba", "Belarusian": "be", "Bulgarian": "bg",
    "Bengali": "bn", "Breton": "br", "Bosnian": "bs", "Catalan": "ca",
    "Cebuano": "ceb", "Czech": "cs", "Welsh": "cy", "Danish": "da",
    "German": "de", "Greeek": "el", "English": "en", "Spanish": "es",
    "Estonian": "et", "Persian": "fa", "Fulah": "ff", "Finnish": "fi",
    "French": "fr", "Western Frisian": "fy", "Irish": "ga", "Gaelic": "gd",
    "Galician": "gl", "Gujarati": "gu", "Hausa": "ha", "Hebrew": "he",
    "Hindi": "hi", "Croatian": "hr", "Haitian": "ht", "Hungarian": "hu",
    "Armenian": "hy", "Indonesian": "id", "Igbo": "ig", "Iloko": "ilo",
    "Icelandic": "is", "Italian": "it", "Japanese": "ja", "Javanese": "jv",
    "Georgian": "ka", "Kazakh": "kk", "Central Khmer": "km", "Kannada": "kn",
    "Korean": "ko", "Luxembourgish": "lb", "Ganda": "lg", "Lingala": "ln",
    "Lao": "lo", "Lithuanian": "lt", "Latvian": "lv", "Malagasy": "mg",
    "Macedonian": "mk", "Malayalam": "ml", "Mongolian": "mn", "Marathi": "mr",
    "Malay": "ms", "Burmese": "my", "Nepali": "ne", "Dutch": "nl",
    "Norwegian": "no", "Northern Sotho": "ns", "Occitan": "oc", "Oriya": "or",
    "Panjabi": "pa", "Polish": "pl", "Pushto": "ps", "Portuguese": "pt",
    "Romanian": "ro", "Russian": "ru", "Sindhi": "sd", "Sinhala": "si",
    "Slovak": "sk", "Slovenian": "sl", "Somali": "so", "Albanian": "sq",
    "Serbian": "sr", "Swati": "ss", "Sundanese": "su", "Swedish": "sv",
    "Swahili": "sw", "Tamil": "ta", "Thai": "th", "Tagalog": "tl",
    "Tswana": "tn", "Turkish": "tr", "Ukrainian": "uk", "Urdu": "ur",
    "Uzbek": "uz", "Vietnamese": "vi", "Wolof": "wo", "Xhosa": "xh",
    "Yiddish": "yi", "Yoruba": "yo", "Chinese": "zh", "Zulu": "zu",
}

# Cache the model and tokenizer with st.cache_resource so they load only once per process
@st.cache_resource
def load_model(pretrained_model="facebook/m2m100_1.2B", cache_dir="models/"):
    tokenizer = M2M100Tokenizer.from_pretrained(pretrained_model, cache_dir=cache_dir)
    model = M2M100ForConditionalGeneration.from_pretrained(
        pretrained_model, cache_dir=cache_dir
    ).to(device)
    model.eval()
    return tokenizer, model

# App Title and Intro
st.title("🌐 M2M100 Translator")
st.write("""
M2M100 is a multilingual encoder-decoder (sequence-to-sequence) model trained for many-to-many translation.
It supports **100 languages** and translates directly between any pair of them, covering **9,900 translation directions**.  
Model: `facebook/m2m100_1.2B`  
More info: [Paper](https://arxiv.org/abs/2010.11125) | [Repo](https://github.com/pytorch/fairseq/tree/master/examples/m2m_100)
""")

# Input Text Area
user_input = st.text_area(
    "Enter text to translate:",
    height=200,
    max_chars=5120,
    placeholder="Type your sentence here..."
)

# Language selectors
source_lang = st.selectbox("Select source language", sorted(lang_id.keys()))
target_lang = st.selectbox("Select target language", sorted(lang_id.keys()))

# Translate Button
if st.button("Translate"):
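    # Stop this script run early if there is nothing to translate.
    if not user_input.strip():
        st.warning("Please enter some text to translate.")
        st.stop()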
    with st.spinner("Translating... Please wait"):
        tokenizer, model = load_model()  # cached; only slow on the first run
        time_start = time.time()

        src_lang = lang_id[source_lang]
        trg_lang = lang_id[target_lang]

        tokenizer.src_lang = src_lang
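        # The tokenizer prepends the source-language token to the input; the
        # target language is selected at generation time by forcing its token
        # as the first generated token (forced_bos_token_id below).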
        with torch.no_grad():
            encoded_input = tokenizer(user_input, return_tensors="pt").to(device)
            generated_tokens = model.generate(
                **encoded_input,
                forced_bos_token_id=tokenizer.get_lang_id(trg_lang)
            )
            translated_text = tokenizer.batch_decode(
                generated_tokens, skip_special_tokens=True
            )[0]

        time_end = time.time()
        st.success("Translation complete!")
        st.markdown(f"**Translated Text:**\n\n{translated_text}")
        st.caption(f"Time taken: {round(time_end - time_start, 2)} seconds")