Yilin0601 commited on
Commit
c098e72
·
verified ·
1 Parent(s): df9ae3f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -54
app.py CHANGED
@@ -3,8 +3,6 @@ import torch
3
  import numpy as np
4
  import librosa
5
  from transformers import pipeline
6
- from transformers import VitsModel, AutoTokenizer
7
- import scipy # imported if needed for processing
8
 
9
  # --------------------------------------------------
10
  # ASR Pipeline (for English transcription)
@@ -15,48 +13,24 @@ asr = pipeline(
15
  )
16
 
17
  # --------------------------------------------------
18
- # Mapping for Target Languages and Translation Pipelines
19
  # --------------------------------------------------
20
  translation_models = {
21
  "Spanish": "Helsinki-NLP/opus-mt-en-es",
22
- "French": "Helsinki-NLP/opus-mt-en-fr",
23
- "German": "Helsinki-NLP/opus-mt-en-de",
24
  "Chinese": "Helsinki-NLP/opus-mt-en-zh",
25
- "Russian": "Helsinki-NLP/opus-mt-en-ru",
26
- "Arabic": "Helsinki-NLP/opus-mt-en-ar",
27
- "Portuguese": "Helsinki-NLP/opus-mt-en-pt",
28
- "Japanese": "Helsinki-NLP/opus-mt-en-ja",
29
- "Italian": "Helsinki-NLP/opus-mt-en-it",
30
- "Korean": "Helsinki-NLP/opus-mt-en-ko"
31
  }
32
 
33
  translation_tasks = {
34
  "Spanish": "translation_en_to_es",
35
- "French": "translation_en_to_fr",
36
- "German": "translation_en_to_de",
37
  "Chinese": "translation_en_to_zh",
38
- "Russian": "translation_en_to_ru",
39
- "Arabic": "translation_en_to_ar",
40
- "Portuguese": "translation_en_to_pt",
41
- "Japanese": "translation_en_to_ja",
42
- "Italian": "translation_en_to_it",
43
- "Korean": "translation_en_to_ko"
44
  }
45
 
46
- # --------------------------------------------------
47
- # TTS Models (using real Facebook MMS TTS & others)
48
- # --------------------------------------------------
49
  tts_models = {
50
  "Spanish": "facebook/mms-tts-spa",
51
- "French": "facebook/mms-tts-fra",
52
- "German": "facebook/mms-tts-deu",
53
  "Chinese": "facebook/mms-tts-che",
54
- "Russian": "facebook/mms-tts-rus",
55
- "Arabic": "facebook/mms-tts-ara",
56
- "Portuguese": "facebook/mms-tts-por",
57
- "Japanese": "esnya/japanese_speecht5_tts",
58
- "Italian": "tts_models/it/tacotron2",
59
- "Korean": "facebook/mms-tts-kor"
60
  }
61
 
62
  # --------------------------------------------------
@@ -66,12 +40,8 @@ translator_cache = {}
66
  tts_cache = {}
67
 
68
  def get_translator(target_language):
69
- """
70
- Retrieve or create a translation pipeline for the specified language.
71
- """
72
  if target_language in translator_cache:
73
  return translator_cache[target_language]
74
-
75
  model_name = translation_models[target_language]
76
  task_name = translation_tasks[target_language]
77
  translator = pipeline(task_name, model=model_name)
@@ -79,23 +49,15 @@ def get_translator(target_language):
79
  return translator
80
 
81
  def get_tts(target_language):
82
- """
83
- Retrieve or create a TTS pipeline for the specified language.
84
- """
85
  if target_language in tts_cache:
86
  return tts_cache[target_language]
87
-
88
  model_name = tts_models.get(target_language)
89
  if model_name is None:
90
  raise ValueError(f"No TTS model available for {target_language}.")
91
-
92
  try:
93
  tts_pipeline = pipeline("text-to-speech", model=model_name)
94
  except Exception as e:
95
- raise ValueError(
96
- f"Failed to load TTS model for {target_language} with model '{model_name}'.\nError: {e}"
97
- )
98
-
99
  tts_cache[target_language] = tts_pipeline
100
  return tts_pipeline
101
 
@@ -103,12 +65,7 @@ def get_tts(target_language):
103
  # Prediction Function
104
  # --------------------------------------------------
105
  def predict(audio, text, target_language):
106
- """
107
- 1. Obtain English text (from text input or ASR).
108
- 2. Translate English -> target_language.
109
- 3. Synthesize speech in target_language.
110
- """
111
- # Step 1: Get English text from text input (if provided) or from ASR.
112
  if text.strip():
113
  english_text = text.strip()
114
  elif audio is not None:
@@ -125,7 +82,7 @@ def predict(audio, text, target_language):
125
  else:
126
  return "No input provided.", "", None
127
 
128
- # Step 2: Translation
129
  translator = get_translator(target_language)
130
  try:
131
  translation_result = translator(english_text)
@@ -133,11 +90,10 @@ def predict(audio, text, target_language):
133
  except Exception as e:
134
  return english_text, f"Translation error: {e}", None
135
 
136
- # Step 3: TTS synthesis using Facebook MMS TTS (or alternative) pipeline.
137
  try:
138
  tts_pipeline = get_tts(target_language)
139
  tts_result = tts_pipeline(translated_text)
140
- # Expected output: a dict with "wav" and "sample_rate"
141
  synthesized_audio = (tts_result["sample_rate"], tts_result["wav"])
142
  except Exception as e:
143
  return english_text, translated_text, f"TTS error: {e}"
@@ -163,9 +119,8 @@ iface = gr.Interface(
163
  description=(
164
  "This app provides three outputs:\n"
165
  "1. English transcription (from ASR or text input),\n"
166
- "2. Translation to a target language (using Helsinki-NLP models), and\n"
167
  "3. Synthetic speech in the target language (using Facebook MMS TTS or equivalent).\n\n"
168
- "Select one of the top 10 commonly used languages from the dropdown.\n"
169
  "Either record/upload an English audio sample or enter English text directly."
170
  ),
171
  allow_flagging="never"
 
3
  import numpy as np
4
  import librosa
5
  from transformers import pipeline
 
 
6
 
7
  # --------------------------------------------------
8
  # ASR Pipeline (for English transcription)
 
13
  )
14
 
15
  # --------------------------------------------------
16
+ # Mapping for Target Languages (Spanish, Chinese, Japanese)
17
  # --------------------------------------------------
18
  translation_models = {
19
  "Spanish": "Helsinki-NLP/opus-mt-en-es",
 
 
20
  "Chinese": "Helsinki-NLP/opus-mt-en-zh",
21
+ "Japanese": "Helsinki-NLP/opus-mt-en-ja"
 
 
 
 
 
22
  }
23
 
24
  translation_tasks = {
25
  "Spanish": "translation_en_to_es",
 
 
26
  "Chinese": "translation_en_to_zh",
27
+ "Japanese": "translation_en_to_ja"
 
 
 
 
 
28
  }
29
 
 
 
 
30
  tts_models = {
31
  "Spanish": "facebook/mms-tts-spa",
 
 
32
  "Chinese": "facebook/mms-tts-che",
33
+ "Japanese": "esnya/japanese_speecht5_tts"
 
 
 
 
 
34
  }
35
 
36
  # --------------------------------------------------
 
40
  tts_cache = {}
41
 
42
  def get_translator(target_language):
 
 
 
43
  if target_language in translator_cache:
44
  return translator_cache[target_language]
 
45
  model_name = translation_models[target_language]
46
  task_name = translation_tasks[target_language]
47
  translator = pipeline(task_name, model=model_name)
 
49
  return translator
50
 
51
  def get_tts(target_language):
 
 
 
52
  if target_language in tts_cache:
53
  return tts_cache[target_language]
 
54
  model_name = tts_models.get(target_language)
55
  if model_name is None:
56
  raise ValueError(f"No TTS model available for {target_language}.")
 
57
  try:
58
  tts_pipeline = pipeline("text-to-speech", model=model_name)
59
  except Exception as e:
60
+ raise ValueError(f"Failed to load TTS model for {target_language} with model '{model_name}'.\nError: {e}")
 
 
 
61
  tts_cache[target_language] = tts_pipeline
62
  return tts_pipeline
63
 
 
65
  # Prediction Function
66
  # --------------------------------------------------
67
  def predict(audio, text, target_language):
68
+ # Step 1: Obtain English text from text input if provided, otherwise use ASR.
 
 
 
 
 
69
  if text.strip():
70
  english_text = text.strip()
71
  elif audio is not None:
 
82
  else:
83
  return "No input provided.", "", None
84
 
85
+ # Step 2: Translate the English text to the target language.
86
  translator = get_translator(target_language)
87
  try:
88
  translation_result = translator(english_text)
 
90
  except Exception as e:
91
  return english_text, f"Translation error: {e}", None
92
 
93
+ # Step 3: Synthesize speech using the TTS pipeline.
94
  try:
95
  tts_pipeline = get_tts(target_language)
96
  tts_result = tts_pipeline(translated_text)
 
97
  synthesized_audio = (tts_result["sample_rate"], tts_result["wav"])
98
  except Exception as e:
99
  return english_text, translated_text, f"TTS error: {e}"
 
119
  description=(
120
  "This app provides three outputs:\n"
121
  "1. English transcription (from ASR or text input),\n"
122
+ "2. Translation to Spanish, Chinese, or Japanese (using Helsinki-NLP models), and\n"
123
  "3. Synthetic speech in the target language (using Facebook MMS TTS or equivalent).\n\n"
 
124
  "Either record/upload an English audio sample or enter English text directly."
125
  ),
126
  allow_flagging="never"