Yilin0601 commited on
Commit
c7f56a8
·
verified ·
1 Parent(s): 4f0da2b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -28
app.py CHANGED
@@ -28,51 +28,94 @@ translation_models = {
28
  "Korean": "Helsinki-NLP/opus-mt-en-ko"
29
  }
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  tts_models = {
32
  "Spanish": "tts_models/es/tacotron2-DDC",
33
  "French": "tts_models/fr/tacotron2",
34
  "German": "tts_models/de/tacotron2",
35
- "Chinese": "tts_models/zh/tacotron2",
36
- "Russian": "tts_models/ru/tacotron2",
37
- "Arabic": "tts_models/ar/tacotron2",
38
- "Portuguese": "tts_models/pt/tacotron2",
39
- "Japanese": "tts_models/ja/tacotron2",
40
- "Italian": "tts_models/it/tacotron2",
41
- "Korean": "tts_models/ko/tacotron2"
42
  }
43
 
 
44
  # Caches for translator and TTS pipelines
 
45
  translator_cache = {}
46
  tts_cache = {}
47
 
48
  def get_translator(target_language):
 
 
 
49
  if target_language in translator_cache:
50
  return translator_cache[target_language]
 
51
  model_name = translation_models[target_language]
52
- # Pipeline task naming is case sensitive; here we assume task "translation_en_to_<lang>"
53
- translator = pipeline("translation_en_to_" + target_language.lower(), model=model_name)
 
54
  translator_cache[target_language] = translator
55
  return translator
56
 
57
  def get_tts(target_language):
 
 
 
58
  if target_language in tts_cache:
59
  return tts_cache[target_language]
60
- model_name = tts_models[target_language]
61
- tts = pipeline("text-to-speech", model=model_name)
62
- tts_cache[target_language] = tts
63
- return tts
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
  # --------------------------------------------------
66
  # Prediction Function
67
  # --------------------------------------------------
68
  def predict(audio, text, target_language):
69
- # Use text input if provided; otherwise, use ASR on audio
70
- if text.strip() != "":
 
 
 
 
 
71
  english_text = text.strip()
72
  elif audio is not None:
73
  sample_rate, audio_data = audio
74
 
75
- # Ensure the audio is floating-point for librosa
76
  if audio_data.dtype not in [np.float32, np.float64]:
77
  audio_data = audio_data.astype(np.float32)
78
 
@@ -90,16 +133,24 @@ def predict(audio, text, target_language):
90
  else:
91
  return "No input provided.", "", None
92
 
93
- # Translation step
94
  translator = get_translator(target_language)
95
- translation_result = translator(english_text)
96
- translated_text = translation_result[0]["translation_text"]
97
-
98
- # TTS step: synthesize speech from the translated text
99
- tts = get_tts(target_language)
100
- tts_result = tts(translated_text)
101
- # The TTS pipeline returns a dict with "wav" and "sample_rate"
102
- synthesized_audio = (tts_result["sample_rate"], tts_result["wav"])
 
 
 
 
 
 
 
 
103
 
104
  return english_text, translated_text, synthesized_audio
105
 
@@ -122,10 +173,11 @@ iface = gr.Interface(
122
  description=(
123
  "This app helps language learners by providing three outputs:\n"
124
  "1. English transcription (from ASR or text input),\n"
125
- "2. Translation to a target language, and\n"
126
  "3. Synthetic speech in the target language.\n\n"
127
- "Choose one of the top 10 commonly used languages from the dropdown.\n"
128
- "You can either record/upload an English audio sample or enter English text directly."
 
129
  ),
130
  allow_flagging="never"
131
  )
 
28
  "Korean": "Helsinki-NLP/opus-mt-en-ko"
29
  }
30
 
31
+ # Each language often requires a specific pipeline task name
32
+ # (e.g., "translation_en_to_zh" rather than "translation_en_to_chinese")
33
+ translation_tasks = {
34
+ "Spanish": "translation_en_to_es",
35
+ "French": "translation_en_to_fr",
36
+ "German": "translation_en_to_de",
37
+ "Chinese": "translation_en_to_zh",
38
+ "Russian": "translation_en_to_ru",
39
+ "Arabic": "translation_en_to_ar",
40
+ "Portuguese": "translation_en_to_pt",
41
+ "Japanese": "translation_en_to_ja",
42
+ "Italian": "translation_en_to_it",
43
+ "Korean": "translation_en_to_ko"
44
+ }
45
+
46
+ # TTS models (some may not exist or may be unofficial)
47
  tts_models = {
48
  "Spanish": "tts_models/es/tacotron2-DDC",
49
  "French": "tts_models/fr/tacotron2",
50
  "German": "tts_models/de/tacotron2",
51
+ "Chinese": "tts_models/zh/tacotron2", # Verify if this actually exists on Hugging Face
52
+ "Russian": "tts_models/ru/tacotron2", # Same note
53
+ "Arabic": "tts_models/ar/tacotron2", # Same note
54
+ "Portuguese": "tts_models/pt/tacotron2", # Same note
55
+ "Japanese": "tts_models/ja/tacotron2", # Same note
56
+ "Italian": "tts_models/it/tacotron2", # Same note
57
+ "Korean": "tts_models/ko/tacotron2" # Same note
58
  }
59
 
60
+ # --------------------------------------------------
61
  # Caches for translator and TTS pipelines
62
+ # --------------------------------------------------
63
  translator_cache = {}
64
  tts_cache = {}
65
 
66
  def get_translator(target_language):
67
+ """
68
+ Retrieve or create a translation pipeline for the specified language.
69
+ """
70
  if target_language in translator_cache:
71
  return translator_cache[target_language]
72
+
73
  model_name = translation_models[target_language]
74
+ task_name = translation_tasks[target_language]
75
+
76
+ translator = pipeline(task_name, model=model_name)
77
  translator_cache[target_language] = translator
78
  return translator
79
 
80
  def get_tts(target_language):
81
+ """
82
+ Retrieve or create a TTS pipeline for the specified language, if available.
83
+ """
84
  if target_language in tts_cache:
85
  return tts_cache[target_language]
86
+
87
+ model_name = tts_models.get(target_language)
88
+ if model_name is None:
89
+ # If no TTS model is mapped, raise an error or handle gracefully
90
+ raise ValueError(f"No TTS model available for {target_language}.")
91
+
92
+ try:
93
+ tts_pipeline = pipeline("text-to-speech", model=model_name)
94
+ except Exception as e:
95
+ raise ValueError(
96
+ f"Failed to load TTS model for {target_language}. "
97
+ f"Make sure '{model_name}' exists on Hugging Face.\nError: {e}"
98
+ )
99
+
100
+ tts_cache[target_language] = tts_pipeline
101
+ return tts_pipeline
102
 
103
  # --------------------------------------------------
104
  # Prediction Function
105
  # --------------------------------------------------
106
  def predict(audio, text, target_language):
107
+ """
108
+ 1. Obtain English text (from text input or ASR).
109
+ 2. Translate English -> target_language.
110
+ 3. Synthesize speech in target_language.
111
+ """
112
+ # 1. English text from text input (if provided), else from audio via ASR
113
+ if text.strip():
114
  english_text = text.strip()
115
  elif audio is not None:
116
  sample_rate, audio_data = audio
117
 
118
+ # Ensure the audio is float32 for librosa
119
  if audio_data.dtype not in [np.float32, np.float64]:
120
  audio_data = audio_data.astype(np.float32)
121
 
 
133
  else:
134
  return "No input provided.", "", None
135
 
136
+ # 2. Translation step
137
  translator = get_translator(target_language)
138
+ try:
139
+ translation_result = translator(english_text)
140
+ translated_text = translation_result[0]["translation_text"]
141
+ except Exception as e:
142
+ # If there's an error in translation, return partial results
143
+ return english_text, f"Translation error: {e}", None
144
+
145
+ # 3. TTS step: synthesize speech from the translated text
146
+ try:
147
+ tts_pipeline = get_tts(target_language)
148
+ tts_result = tts_pipeline(translated_text)
149
+ # The TTS pipeline returns a dict with "wav" and "sample_rate"
150
+ synthesized_audio = (tts_result["sample_rate"], tts_result["wav"])
151
+ except Exception as e:
152
+ # If TTS fails, return partial results
153
+ return english_text, translated_text, f"TTS error: {e}"
154
 
155
  return english_text, translated_text, synthesized_audio
156
 
 
173
  description=(
174
  "This app helps language learners by providing three outputs:\n"
175
  "1. English transcription (from ASR or text input),\n"
176
+ "2. Translation to a target language (using Helsinki-NLP models), and\n"
177
  "3. Synthetic speech in the target language.\n\n"
178
+ "Select one of the top 10 commonly used languages from the dropdown.\n"
179
+ "Either record/upload an English audio sample or enter English text directly.\n\n"
180
+ "Note: Some TTS models may not exist or be unstable for certain languages."
181
  ),
182
  allow_flagging="never"
183
  )