deep-div commited on
Commit
fc9ce6a
·
verified ·
1 Parent(s): d9ace2b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -130
app.py CHANGED
@@ -1,132 +1,52 @@
1
  import streamlit as st
2
  import os
3
- import io
4
- from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration
5
  import time
6
- import json
7
- from typing import List
8
  import torch
9
- import random
10
  import logging
 
 
 
 
11
 
 
12
  if torch.cuda.is_available():
13
  device = torch.device("cuda:0")
14
  else:
15
  device = torch.device("cpu")
16
  logging.warning("GPU not found, using CPU, translation will be very slow.")
17
 
18
- st.cache(suppress_st_warning=True, allow_output_mutation=True)
19
- st.set_page_config(page_title="M2M100 Translator")
20
-
21
  lang_id = {
22
- "Afrikaans": "af",
23
- "Amharic": "am",
24
- "Arabic": "ar",
25
- "Asturian": "ast",
26
- "Azerbaijani": "az",
27
- "Bashkir": "ba",
28
- "Belarusian": "be",
29
- "Bulgarian": "bg",
30
- "Bengali": "bn",
31
- "Breton": "br",
32
- "Bosnian": "bs",
33
- "Catalan": "ca",
34
- "Cebuano": "ceb",
35
- "Czech": "cs",
36
- "Welsh": "cy",
37
- "Danish": "da",
38
- "German": "de",
39
- "Greeek": "el",
40
- "English": "en",
41
- "Spanish": "es",
42
- "Estonian": "et",
43
- "Persian": "fa",
44
- "Fulah": "ff",
45
- "Finnish": "fi",
46
- "French": "fr",
47
- "Western Frisian": "fy",
48
- "Irish": "ga",
49
- "Gaelic": "gd",
50
- "Galician": "gl",
51
- "Gujarati": "gu",
52
- "Hausa": "ha",
53
- "Hebrew": "he",
54
- "Hindi": "hi",
55
- "Croatian": "hr",
56
- "Haitian": "ht",
57
- "Hungarian": "hu",
58
- "Armenian": "hy",
59
- "Indonesian": "id",
60
- "Igbo": "ig",
61
- "Iloko": "ilo",
62
- "Icelandic": "is",
63
- "Italian": "it",
64
- "Japanese": "ja",
65
- "Javanese": "jv",
66
- "Georgian": "ka",
67
- "Kazakh": "kk",
68
- "Central Khmer": "km",
69
- "Kannada": "kn",
70
- "Korean": "ko",
71
- "Luxembourgish": "lb",
72
- "Ganda": "lg",
73
- "Lingala": "ln",
74
- "Lao": "lo",
75
- "Lithuanian": "lt",
76
- "Latvian": "lv",
77
- "Malagasy": "mg",
78
- "Macedonian": "mk",
79
- "Malayalam": "ml",
80
- "Mongolian": "mn",
81
- "Marathi": "mr",
82
- "Malay": "ms",
83
- "Burmese": "my",
84
- "Nepali": "ne",
85
- "Dutch": "nl",
86
- "Norwegian": "no",
87
- "Northern Sotho": "ns",
88
- "Occitan": "oc",
89
- "Oriya": "or",
90
- "Panjabi": "pa",
91
- "Polish": "pl",
92
- "Pushto": "ps",
93
- "Portuguese": "pt",
94
- "Romanian": "ro",
95
- "Russian": "ru",
96
- "Sindhi": "sd",
97
- "Sinhala": "si",
98
- "Slovak": "sk",
99
- "Slovenian": "sl",
100
- "Somali": "so",
101
- "Albanian": "sq",
102
- "Serbian": "sr",
103
- "Swati": "ss",
104
- "Sundanese": "su",
105
- "Swedish": "sv",
106
- "Swahili": "sw",
107
- "Tamil": "ta",
108
- "Thai": "th",
109
- "Tagalog": "tl",
110
- "Tswana": "tn",
111
- "Turkish": "tr",
112
- "Ukrainian": "uk",
113
- "Urdu": "ur",
114
- "Uzbek": "uz",
115
- "Vietnamese": "vi",
116
- "Wolof": "wo",
117
- "Xhosa": "xh",
118
- "Yiddish": "yi",
119
- "Yoruba": "yo",
120
- "Chinese": "zh",
121
- "Zulu": "zu",
122
  }
123
 
124
-
125
- @st.cache_resource(suppress_st_warning=True, allow_output_mutation=True)
126
- def load_model(
127
- pretrained_model: str = "facebook/m2m100_1.2B",
128
- cache_dir: str = "models/",
129
- ):
130
  tokenizer = M2M100Tokenizer.from_pretrained(pretrained_model, cache_dir=cache_dir)
131
  model = M2M100ForConditionalGeneration.from_pretrained(
132
  pretrained_model, cache_dir=cache_dir
@@ -134,40 +54,48 @@ def load_model(
134
  model.eval()
135
  return tokenizer, model
136
 
 
 
 
 
 
 
 
 
137
 
138
- st.title("M2M100 Translator")
139
- st.write("M2M100 is a multilingual encoder-decoder (seq-to-seq) model trained for Many-to-Many multilingual translation. It was introduced in this paper https://arxiv.org/abs/2010.11125 and first released in https://github.com/pytorch/fairseq/tree/master/examples/m2m_100 repository. The model that can directly translate between the 9,900 directions of 100 languages.\n")
140
-
141
- st.write(" This demo uses the facebook/m2m100_1.2B model. For local inference see https://github.com/ikergarcia1996/Easy-Translate")
142
-
143
-
144
- user_input: str = st.text_area(
145
- "Input text",
146
  height=200,
147
  max_chars=5120,
 
148
  )
149
 
150
- source_lang = st.selectbox(label="Source language", options=list(lang_id.keys()))
151
- target_lang = st.selectbox(label="Target language", options=list(lang_id.keys()))
 
152
 
153
- if st.button("Run"):
154
- with st.spinner("Translating... please wait..."):
 
155
  time_start = time.time()
156
  tokenizer, model = load_model()
157
 
158
  src_lang = lang_id[source_lang]
159
  trg_lang = lang_id[target_lang]
 
160
  tokenizer.src_lang = src_lang
161
  with torch.no_grad():
162
  encoded_input = tokenizer(user_input, return_tensors="pt").to(device)
163
  generated_tokens = model.generate(
164
- **encoded_input, forced_bos_token_id=tokenizer.get_lang_id(trg_lang)
 
165
  )
166
  translated_text = tokenizer.batch_decode(
167
  generated_tokens, skip_special_tokens=True
168
  )[0]
169
 
170
  time_end = time.time()
171
- st.success(translated_text)
172
- st.write(f"Computation time: {round((time_end - time_start), 3)} seconds")
173
-
 
1
  import streamlit as st
2
  import os
 
 
3
  import time
 
 
4
  import torch
 
5
  import logging
6
+ from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration
7
+
8
+ # Set Streamlit page configuration
9
+ st.set_page_config(page_title="M2M100 Translator")
10
 
11
+ # Check device
12
  if torch.cuda.is_available():
13
  device = torch.device("cuda:0")
14
  else:
15
  device = torch.device("cpu")
16
  logging.warning("GPU not found, using CPU, translation will be very slow.")
17
 
18
+ # Language code mapping
 
 
19
  lang_id = {
20
+ "Afrikaans": "af", "Amharic": "am", "Arabic": "ar", "Asturian": "ast",
21
+ "Azerbaijani": "az", "Bashkir": "ba", "Belarusian": "be", "Bulgarian": "bg",
22
+ "Bengali": "bn", "Breton": "br", "Bosnian": "bs", "Catalan": "ca",
23
+ "Cebuano": "ceb", "Czech": "cs", "Welsh": "cy", "Danish": "da",
24
+ "German": "de", "Greeek": "el", "English": "en", "Spanish": "es",
25
+ "Estonian": "et", "Persian": "fa", "Fulah": "ff", "Finnish": "fi",
26
+ "French": "fr", "Western Frisian": "fy", "Irish": "ga", "Gaelic": "gd",
27
+ "Galician": "gl", "Gujarati": "gu", "Hausa": "ha", "Hebrew": "he",
28
+ "Hindi": "hi", "Croatian": "hr", "Haitian": "ht", "Hungarian": "hu",
29
+ "Armenian": "hy", "Indonesian": "id", "Igbo": "ig", "Iloko": "ilo",
30
+ "Icelandic": "is", "Italian": "it", "Japanese": "ja", "Javanese": "jv",
31
+ "Georgian": "ka", "Kazakh": "kk", "Central Khmer": "km", "Kannada": "kn",
32
+ "Korean": "ko", "Luxembourgish": "lb", "Ganda": "lg", "Lingala": "ln",
33
+ "Lao": "lo", "Lithuanian": "lt", "Latvian": "lv", "Malagasy": "mg",
34
+ "Macedonian": "mk", "Malayalam": "ml", "Mongolian": "mn", "Marathi": "mr",
35
+ "Malay": "ms", "Burmese": "my", "Nepali": "ne", "Dutch": "nl",
36
+ "Norwegian": "no", "Northern Sotho": "ns", "Occitan": "oc", "Oriya": "or",
37
+ "Panjabi": "pa", "Polish": "pl", "Pushto": "ps", "Portuguese": "pt",
38
+ "Romanian": "ro", "Russian": "ru", "Sindhi": "sd", "Sinhala": "si",
39
+ "Slovak": "sk", "Slovenian": "sl", "Somali": "so", "Albanian": "sq",
40
+ "Serbian": "sr", "Swati": "ss", "Sundanese": "su", "Swedish": "sv",
41
+ "Swahili": "sw", "Tamil": "ta", "Thai": "th", "Tagalog": "tl",
42
+ "Tswana": "tn", "Turkish": "tr", "Ukrainian": "uk", "Urdu": "ur",
43
+ "Uzbek": "uz", "Vietnamese": "vi", "Wolof": "wo", "Xhosa": "xh",
44
+ "Yiddish": "yi", "Yoruba": "yo", "Chinese": "zh", "Zulu": "zu",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  }
46
 
47
+ # Cache the model and tokenizer using new API
48
+ @st.cache_resource
49
+ def load_model(pretrained_model="facebook/m2m100_1.2B", cache_dir="models/"):
 
 
 
50
  tokenizer = M2M100Tokenizer.from_pretrained(pretrained_model, cache_dir=cache_dir)
51
  model = M2M100ForConditionalGeneration.from_pretrained(
52
  pretrained_model, cache_dir=cache_dir
 
54
  model.eval()
55
  return tokenizer, model
56
 
57
+ # App Title and Intro
58
+ st.title("🌐 M2M100 Translator")
59
+ st.write("""
60
+ M2M100 is a multilingual encoder-decoder (seq-to-seq) model trained for Many-to-Many multilingual translation.
61
+ It supports **100 languages** and translates in **9900 directions**.
62
+ Model: `facebook/m2m100_1.2B`
63
+ More info: [Paper](https://arxiv.org/abs/2010.11125) | [Repo](https://github.com/pytorch/fairseq/tree/master/examples/m2m_100)
64
+ """)
65
 
66
+ # Input Text Area
67
+ user_input = st.text_area(
68
+ "Enter text to translate:",
 
 
 
 
 
69
  height=200,
70
  max_chars=5120,
71
+ placeholder="Type your sentence here..."
72
  )
73
 
74
+ # Language selectors
75
+ source_lang = st.selectbox("Select source language", sorted(lang_id.keys()))
76
+ target_lang = st.selectbox("Select target language", sorted(lang_id.keys()))
77
 
78
+ # Translate Button
79
+ if st.button("Translate"):
80
+ with st.spinner("Translating... Please wait"):
81
  time_start = time.time()
82
  tokenizer, model = load_model()
83
 
84
  src_lang = lang_id[source_lang]
85
  trg_lang = lang_id[target_lang]
86
+
87
  tokenizer.src_lang = src_lang
88
  with torch.no_grad():
89
  encoded_input = tokenizer(user_input, return_tensors="pt").to(device)
90
  generated_tokens = model.generate(
91
+ **encoded_input,
92
+ forced_bos_token_id=tokenizer.get_lang_id(trg_lang)
93
  )
94
  translated_text = tokenizer.batch_decode(
95
  generated_tokens, skip_special_tokens=True
96
  )[0]
97
 
98
  time_end = time.time()
99
+ st.success("Translation complete!")
100
+ st.markdown(f"**Translated Text:**\n\n{translated_text}")
101
+ st.caption(f"Time taken: {round(time_end - time_start, 2)} seconds")