Lorenzob commited on
Commit
d98c95d
·
verified ·
1 Parent(s): 93bd931

Fix speaker embeddings for SpeechT5 model

Browse files
Files changed (1) hide show
  1. app.py +61 -41
app.py CHANGED
@@ -21,10 +21,10 @@ MODEL_REPO = "Lorenzob/aurora-1.6b-complete" # Repository del modello completo
21
  CACHE_DIR = "./model_cache" # Directory per la cache del modello
22
  SAMPLE_RATE = 24000 # Frequenza di campionamento
23
 
24
- # Cache per il modello e il processor
25
  processor = None
26
  model = None
27
- speaker_embeddings = None
28
 
29
  def download_file(url, save_path):
30
  """Scarica un file da un URL"""
@@ -38,6 +38,45 @@ def download_file(url, save_path):
38
 
39
  return save_path
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  def ensure_model_config():
42
  """Assicura che il modello abbia una configurazione corretta"""
43
  try:
@@ -75,21 +114,6 @@ def ensure_model_config():
75
  print(f"Errore nella configurazione del modello: {e}")
76
  return None
77
 
78
- def get_speaker_embeddings():
79
- """Ottieni gli speaker embeddings per il modello TTS"""
80
- global speaker_embeddings
81
-
82
- if speaker_embeddings is None:
83
- try:
84
- # Crea gli speaker embeddings predefiniti (vettore di zeri)
85
- speaker_embeddings = torch.zeros(1, 512)
86
- print("Speaker embeddings creati con successo")
87
- except Exception as e:
88
- print(f"Errore nella creazione degli speaker embeddings: {e}")
89
- speaker_embeddings = None
90
-
91
- return speaker_embeddings
92
-
93
  def load_model_and_processor():
94
  """Carica il modello e il processor con caricamento manuale della configurazione"""
95
  global model, processor
@@ -150,6 +174,9 @@ def text_to_speech(text, language="it", speaker_id=0, speed=1.0, show_log=False)
150
  # Carica il modello e il processor
151
  model, processor = load_model_and_processor()
152
 
 
 
 
153
  # Controlla se stiamo usando il modello di Microsoft
154
  is_microsoft_model = "microsoft" in str(type(model))
155
 
@@ -158,16 +185,6 @@ def text_to_speech(text, language="it", speaker_id=0, speed=1.0, show_log=False)
158
  if show_log:
159
  print("Utilizzo del modello Microsoft SpeechT5...")
160
 
161
- # Carica speaker embeddings
162
- speaker_embeddings_path = f"https://huggingface.co/datasets/Matthijs/cmu-arctic-xvectors/resolve/main/cmu_us_{speaker_id:02d}_xvector.pt"
163
- try:
164
- tmp_file = os.path.join(tempfile.gettempdir(), f"speaker_{speaker_id}.pt")
165
- download_file(speaker_embeddings_path, tmp_file)
166
- speaker_embeddings = torch.load(tmp_file)
167
- except:
168
- # Usa un embedding predefinito
169
- speaker_embeddings = torch.zeros(1, 512)
170
-
171
  # Crea input IDs dal testo
172
  inputs = processor(text=text, return_tensors="pt")
173
 
@@ -175,7 +192,7 @@ def text_to_speech(text, language="it", speaker_id=0, speed=1.0, show_log=False)
175
  with torch.no_grad():
176
  speech = model.generate_speech(
177
  inputs["input_ids"],
178
- speaker_embeddings
179
  )
180
 
181
  # Imposta la frequenza di campionamento
@@ -185,9 +202,6 @@ def text_to_speech(text, language="it", speaker_id=0, speed=1.0, show_log=False)
185
  if show_log:
186
  print("Utilizzo del modello Aurora-1.6b-complete...")
187
 
188
- # Ottieni gli speaker embeddings
189
- speaker_emb = get_speaker_embeddings()
190
-
191
  # Prepara gli input
192
  inputs = processor(
193
  text=text,
@@ -200,7 +214,10 @@ def text_to_speech(text, language="it", speaker_id=0, speed=1.0, show_log=False)
200
  if hasattr(v, "to"):
201
  inputs[k] = v.to(model.device)
202
 
203
- # Genera il speech - NOTA: Qui non passiamo parametri non supportati
 
 
 
204
  with torch.no_grad():
205
  if hasattr(model, "generate_speech") and callable(model.generate_speech):
206
  # Usa generate_speech se disponibile
@@ -209,9 +226,10 @@ def text_to_speech(text, language="it", speaker_id=0, speed=1.0, show_log=False)
209
  speaker_emb
210
  )
211
  else:
212
- # Altrimenti prova con generate
213
  speech = model.generate(
214
- **inputs
 
215
  )
216
 
217
  # Imposta la frequenza di campionamento
@@ -239,11 +257,11 @@ def text_to_speech(text, language="it", speaker_id=0, speed=1.0, show_log=False)
239
  # Esempi predefiniti per l'interfaccia
240
  examples = [
241
  ["Ciao, mi chiamo Aurora e sono un assistente vocale italiano.", "it", 0, 1.0, False],
242
- ["Hello, my name is Aurora and I'm an Italian voice assistant.", "en", 0, 1.0, False],
243
- ["Hola, me llamo Aurora y soy un asistente de voz italiano.", "es", 0, 1.0, False],
244
  ["La vita è bella e il sole splende nel cielo azzurro.", "it", 0, 1.0, False],
245
- ["Mi piace viaggiare e scoprire nuove città e culture.", "it", 0, 1.2, False],
246
- ["L'intelligenza artificiale sta trasformando il modo in cui interagiamo con i computer e con il mondo che ci circonda.", "it", 0, 0.9, False]
247
  ]
248
 
249
  # Definizione dell'interfaccia Gradio
@@ -272,13 +290,13 @@ with gr.Blocks(title="Aurora-1.6b TTS Demo", theme=gr.themes.Soft()) as demo:
272
  value="it",
273
  info="Seleziona la lingua del testo"
274
  )
275
- speaker_input = gr.Number(
276
  label="Speaker ID",
277
  value=0,
278
  minimum=0,
279
- maximum=10,
280
  step=1,
281
- info="ID dello speaker (solo per modelli multi-speaker)"
282
  )
283
  speed_input = gr.Slider(
284
  minimum=0.5,
@@ -312,6 +330,7 @@ with gr.Blocks(title="Aurora-1.6b TTS Demo", theme=gr.themes.Soft()) as demo:
312
 
313
  - Il modello funziona meglio con frasi di lunghezza media (fino a 20-30 parole)
314
  - Per l'italiano, il modello è stato ottimizzato per una pronuncia naturale
 
315
  - La velocità di generazione dipende dalle risorse disponibili sul server
316
 
317
  ## 🔗 Crediti
@@ -320,6 +339,7 @@ with gr.Blocks(title="Aurora-1.6b TTS Demo", theme=gr.themes.Soft()) as demo:
320
  - [Lorenzob/aurora-1.6b](https://huggingface.co/Lorenzob/aurora-1.6b) (versione fine-tuned)
321
  - [Lorenzob/aurora-1.6b-complete](https://huggingface.co/Lorenzob/aurora-1.6b-complete) (versione completa con pesi)
322
  - [nari-labs/Dia-1.6B](https://huggingface.co/nari-labs/Dia-1.6B) (modello base originale)
 
323
  """)
324
 
325
  # Configurazione degli eventi
 
21
  CACHE_DIR = "./model_cache" # Directory per la cache del modello
22
  SAMPLE_RATE = 24000 # Frequenza di campionamento
23
 
24
+ # Cache per il modello, processor e speaker embeddings
25
  processor = None
26
  model = None
27
+ speaker_embeddings_cache = {}
28
 
29
  def download_file(url, save_path):
30
  """Scarica un file da un URL"""
 
38
 
39
  return save_path
40
 
41
+ def get_speaker_embeddings(speaker_id=0):
42
+ """Ottieni gli speaker embeddings dal dataset CMU Arctic"""
43
+ global speaker_embeddings_cache
44
+
45
+ if speaker_id in speaker_embeddings_cache:
46
+ return speaker_embeddings_cache[speaker_id]
47
+
48
+ try:
49
+ # Limita lo speaker_id a un intervallo valido (0-9)
50
+ speaker_id = max(0, min(9, speaker_id))
51
+
52
+ # Genera l'URL per gli embeddings
53
+ url = f"https://huggingface.co/datasets/Matthijs/cmu-arctic-xvectors/resolve/main/cmu_us_{speaker_id:02d}_xvector.pt"
54
+
55
+ # Crea un file temporaneo per gli embeddings
56
+ tmp_dir = os.path.join(CACHE_DIR, "speakers")
57
+ os.makedirs(tmp_dir, exist_ok=True)
58
+ tmp_file = os.path.join(tmp_dir, f"speaker_{speaker_id:02d}.pt")
59
+
60
+ # Scarica gli embeddings se non esistono già
61
+ if not os.path.exists(tmp_file):
62
+ print(f"Scaricamento degli speaker embeddings per lo speaker {speaker_id}...")
63
+ download_file(url, tmp_file)
64
+
65
+ # Carica gli embeddings
66
+ speaker_embeddings = torch.load(tmp_file)
67
+
68
+ # Memorizza gli embeddings nella cache
69
+ speaker_embeddings_cache[speaker_id] = speaker_embeddings
70
+
71
+ print(f"Speaker embeddings caricati per lo speaker {speaker_id}")
72
+ return speaker_embeddings
73
+ except Exception as e:
74
+ print(f"Errore nel caricamento degli speaker embeddings: {e}")
75
+ # Crea dei default embeddings
76
+ default_embeddings = torch.zeros(1, 512)
77
+ speaker_embeddings_cache[speaker_id] = default_embeddings
78
+ return default_embeddings
79
+
80
  def ensure_model_config():
81
  """Assicura che il modello abbia una configurazione corretta"""
82
  try:
 
114
  print(f"Errore nella configurazione del modello: {e}")
115
  return None
116
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
  def load_model_and_processor():
118
  """Carica il modello e il processor con caricamento manuale della configurazione"""
119
  global model, processor
 
174
  # Carica il modello e il processor
175
  model, processor = load_model_and_processor()
176
 
177
+ # Ottieni gli speaker embeddings
178
+ speaker_emb = get_speaker_embeddings(speaker_id)
179
+
180
  # Controlla se stiamo usando il modello di Microsoft
181
  is_microsoft_model = "microsoft" in str(type(model))
182
 
 
185
  if show_log:
186
  print("Utilizzo del modello Microsoft SpeechT5...")
187
 
 
 
 
 
 
 
 
 
 
 
188
  # Crea input IDs dal testo
189
  inputs = processor(text=text, return_tensors="pt")
190
 
 
192
  with torch.no_grad():
193
  speech = model.generate_speech(
194
  inputs["input_ids"],
195
+ speaker_emb
196
  )
197
 
198
  # Imposta la frequenza di campionamento
 
202
  if show_log:
203
  print("Utilizzo del modello Aurora-1.6b-complete...")
204
 
 
 
 
205
  # Prepara gli input
206
  inputs = processor(
207
  text=text,
 
214
  if hasattr(v, "to"):
215
  inputs[k] = v.to(model.device)
216
 
217
+ # Sposta gli speaker embeddings sul dispositivo di calcolo
218
+ speaker_emb = speaker_emb.to(model.device)
219
+
220
+ # Genera il speech
221
  with torch.no_grad():
222
  if hasattr(model, "generate_speech") and callable(model.generate_speech):
223
  # Usa generate_speech se disponibile
 
226
  speaker_emb
227
  )
228
  else:
229
+ # Prova a passare gli speaker embeddings come parametro
230
  speech = model.generate(
231
+ **inputs,
232
+ speaker_embeddings=speaker_emb
233
  )
234
 
235
  # Imposta la frequenza di campionamento
 
257
  # Esempi predefiniti per l'interfaccia
258
  examples = [
259
  ["Ciao, mi chiamo Aurora e sono un assistente vocale italiano.", "it", 0, 1.0, False],
260
+ ["Hello, my name is Aurora and I'm an Italian voice assistant.", "en", 2, 1.0, False],
261
+ ["Hola, me llamo Aurora y soy un asistente de voz italiano.", "es", 4, 1.0, False],
262
  ["La vita è bella e il sole splende nel cielo azzurro.", "it", 0, 1.0, False],
263
+ ["Mi piace viaggiare e scoprire nuove città e culture.", "it", 7, 1.2, False],
264
+ ["L'intelligenza artificiale sta trasformando il modo in cui interagiamo con i computer e con il mondo che ci circonda.", "it", 9, 0.9, False]
265
  ]
266
 
267
  # Definizione dell'interfaccia Gradio
 
290
  value="it",
291
  info="Seleziona la lingua del testo"
292
  )
293
+ speaker_input = gr.Slider(
294
  label="Speaker ID",
295
  value=0,
296
  minimum=0,
297
+ maximum=9,
298
  step=1,
299
+ info="ID dello speaker (0-9, ogni ID ha caratteristiche vocali diverse)"
300
  )
301
  speed_input = gr.Slider(
302
  minimum=0.5,
 
330
 
331
  - Il modello funziona meglio con frasi di lunghezza media (fino a 20-30 parole)
332
  - Per l'italiano, il modello è stato ottimizzato per una pronuncia naturale
333
+ - Puoi cambiare lo Speaker ID per ottenere voci con caratteristiche diverse
334
  - La velocità di generazione dipende dalle risorse disponibili sul server
335
 
336
  ## 🔗 Crediti
 
339
  - [Lorenzob/aurora-1.6b](https://huggingface.co/Lorenzob/aurora-1.6b) (versione fine-tuned)
340
  - [Lorenzob/aurora-1.6b-complete](https://huggingface.co/Lorenzob/aurora-1.6b-complete) (versione completa con pesi)
341
  - [nari-labs/Dia-1.6B](https://huggingface.co/nari-labs/Dia-1.6B) (modello base originale)
342
+ - [CMU Arctic XVectors](https://huggingface.co/datasets/Matthijs/cmu-arctic-xvectors) (speaker embeddings)
343
  """)
344
 
345
  # Configurazione degli eventi