Nick021402 committed
Commit 0bdd1cb · verified · 1 Parent(s): 05f10a7

Update tts_engine.py

Files changed (1):
  tts_engine.py +75 -57
tts_engine.py CHANGED

@@ -1,53 +1,71 @@
-# tts_engine.py - TTS engine wrapper for Nari DIA
+# tts_engine.py - TTS engine wrapper for CPU-friendly SpeechT5
 import logging
 import os
 from typing import Optional
 import tempfile
 import numpy as np
 import soundfile as sf
-import torch  # Import torch for model operations
+import torch
 
-# Import the actual Nari DIA model
-try:
-    from dia.model import Dia
-except ImportError:
-    logging.error("Nari DIA library not found. Please ensure 'git+https://github.com/nari-labs/dia.git' is in your requirements.txt and installed.")
-    Dia = None  # Set to None to prevent further errors
+from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
+from datasets import load_dataset  # To get pre-computed x-vector speaker embeddings
 
 logger = logging.getLogger(__name__)
 
-class NariDIAEngine:
+class CPUMultiSpeakerTTS:
     def __init__(self):
+        self.processor = None
         self.model = None
-        # No separate processor object for Dia, it handles internal processing
+        self.vocoder = None
+        self.speaker_embeddings = {}  # Will store speaker embeddings for S1, S2, etc.
         self._initialize_model()
 
     def _initialize_model(self):
-        """Initialize the Nari DIA 1.6B model."""
-        if Dia is None:
-            logger.error("Nari DIA library is not available. Cannot initialize model.")
-            return
-
+        """Initialize the SpeechT5 model and vocoder on CPU."""
         try:
-            logger.info("Initializing Nari DIA 1.6B model from nari-labs/Dia-1.6B...")
+            logger.info("Initializing SpeechT5 model for CPU...")
 
-            # Load the Nari DIA model
-            # Use compute_dtype="float16" for potentially better performance/memory on GPU
-            # Ensure you have a GPU with ~10GB VRAM for this.
-            self.model = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="float16")
+            self.processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
+            self.model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
+            self.vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
 
-            # Move model to GPU if available
-            if torch.cuda.is_available():
-                self.model.to("cuda")
-                logger.info("Nari DIA model moved to GPU (CUDA).")
-            else:
-                logger.warning("CUDA not available. Nari DIA model will run on CPU, which is not officially supported and will be very slow.")
+            # Ensure all components are on CPU explicitly
+            self.model.to("cpu")
+            self.vocoder.to("cpu")
+            logger.info("SpeechT5 model and vocoder initialized successfully on CPU.")
 
-            logger.info("Nari DIA model initialized successfully.")
+            # Load speaker embeddings for multiple voices
+            logger.info("Loading CMU ARCTIC x-vectors for speaker embeddings...")
+            # Matthijs/cmu-arctic-xvectors provides pre-computed x-vector speaker
+            # embeddings for the CMU ARCTIC speakers (not VCTK); we pick one per label.
+            embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
+
+            # Map 'S1' and 'S2' to two distinct voices. Rows are consecutive
+            # utterances per speaker, so adjacent indices usually belong to the
+            # same voice; the 'filename' field of each row (e.g.
+            # 'cmu_us_bdl_arctic-wav-arctic_a0001') shows which speaker it is.
+            # Index 7306 is the female 'slt' voice used in the official SpeechT5
+            # example; index 0 is a different (male) ARCTIC speaker.
+            self.speaker_embeddings["S1"] = torch.tensor(embeddings_dataset[0]["xvector"]).unsqueeze(0)
+            self.speaker_embeddings["S2"] = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)
+
+            # Ensure embeddings are also on CPU
+            self.speaker_embeddings["S1"] = self.speaker_embeddings["S1"].to("cpu")
+            self.speaker_embeddings["S2"] = self.speaker_embeddings["S2"].to("cpu")
+
+            logger.info("Speaker embeddings loaded for S1 and S2.")
 
         except Exception as e:
-            logger.error(f"Failed to initialize Nari DIA model: {e}", exc_info=True)
+            logger.error(f"Failed to initialize TTS model (SpeechT5): {e}", exc_info=True)
+            self.processor = None
             self.model = None
+            self.vocoder = None
 
     def synthesize_segment(
         self,
@@ -56,7 +74,7 @@ class NariDIAEngine:
         output_path: str
     ) -> Optional[str]:
         """
-        Synthesize speech for a text segment using Nari DIA.
+        Synthesize speech for a text segment using SpeechT5.
 
         Args:
             text: Text to synthesize
@@ -66,45 +84,45 @@ class NariDIAEngine:
         Returns:
             Path to the generated audio file, or None if failed
         """
-        if not self.model:
-            logger.error("Nari DIA model not initialized. Cannot synthesize speech.")
+        if not self.model or not self.processor or not self.vocoder:
+            logger.error("SpeechT5 model, processor, or vocoder not initialized. Cannot synthesize speech.")
            return None
 
         try:
-            # Nari DIA expects [S1] or [S2] tags.
-            # The segmenter is directly outputting "S1" or "S2".
-            # We just need to wrap it in brackets.
-            if speaker in ["S1", "S2"]:
-                dia_speaker_tag = f"[{speaker}]"
-            else:
-                # Fallback in case segmenter outputs something unexpected
-                logger.warning(f"Unexpected speaker tag '{speaker}' from segmenter. Defaulting to [S1].")
-                dia_speaker_tag = "[S1]"
-
-            # Nari DIA expects the speaker tag at the beginning of the segment
-            full_text_input = f"{dia_speaker_tag} {text}"
+            # Get the correct speaker embedding
+            speaker_embedding = self.speaker_embeddings.get(speaker)
+            if speaker_embedding is None:
+                logger.warning(f"Speaker '{speaker}' not found in pre-loaded embeddings. Defaulting to S1.")
+                speaker_embedding = self.speaker_embeddings["S1"]  # Fallback to S1
 
-            # Generate audio using the Nari DIA model
-            logger.info(f"Synthesizing with Nari DIA: {full_text_input[:100]}...")  # Log beginning of text
+            logger.info(f"Synthesizing text for speaker {speaker}: {text[:100]}...")
 
-            # Pass the text directly to the model's generate method
-            # Nari DIA's Dia class handles internal processing/tokenization
+            # Prepare inputs
+            inputs = self.processor(text=text, return_tensors="pt")
+
+            # Ensure inputs are on CPU
+            inputs = {k: v.to("cpu") for k, v in inputs.items()}
+
             with torch.no_grad():
-                # The .generate method should return audio waveform as a PyTorch tensor
-                audio_waveform_tensor = self.model.generate(full_text_input)
-                audio_waveform = audio_waveform_tensor.cpu().numpy().squeeze()
+                # generate_speech produces a mel spectrogram; when a vocoder is
+                # passed, it returns the audio waveform tensor directly
+                speech = self.model.generate_speech(
+                    inputs["input_ids"],
+                    speaker_embedding,  # Pass the speaker embedding here
+                    vocoder=self.vocoder
+                )
+
+            audio_waveform = speech.cpu().numpy().squeeze()
 
-            # Nari DIA's sampling rate is typically 22050 Hz.
-            # If the Dia model object itself exposes a sampling_rate attribute, use it.
-            # Otherwise, default to 22050 as it's common for TTS models.
-            sampling_rate = getattr(self.model, 'sampling_rate', 22050)
+            # Sampling rate from the vocoder config (16000 Hz for SpeechT5)
+            sampling_rate = getattr(self.vocoder.config, 'sampling_rate', 16000)
 
-            # Save as WAV file
             sf.write(output_path, audio_waveform, sampling_rate)
 
-            logger.info(f"Generated audio for {speaker} ({dia_speaker_tag}): {len(text)} characters to {output_path}")
+            logger.info(f"Generated audio for {speaker}: {len(text)} characters to {output_path}")
             return output_path
 
         except Exception as e:
-            logger.error(f"Failed to synthesize segment with Nari DIA: {e}", exc_info=True)
+            logger.error(f"Failed to synthesize segment with SpeechT5: {e}", exc_info=True)
             return None
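
The commit pins S1 and S2 to two fixed rows of the x-vector dataset. To choose different voices, a small hypothetical helper like the one below (not part of the commit) can map each ARCTIC speaker to its first row index; it assumes the dataset's `filename` column follows the `cmu_us_<speaker>_arctic-wav-...` pattern visible in the dataset viewer.

# pick_speakers.py - hypothetical helper, not part of this commit:
# list the ARCTIC speakers in Matthijs/cmu-arctic-xvectors so you can
# pick distinct row indices for S1, S2, etc.
from datasets import load_dataset

ds = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")

first_index = {}
for i, row in enumerate(ds):
    # filenames look like 'cmu_us_bdl_arctic-wav-arctic_a0001';
    # the third underscore-separated token is the speaker ID
    speaker = row["filename"].split("_")[2]
    first_index.setdefault(speaker, i)

print(first_index)  # first row index for each ARCTIC voice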
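
For context, here is a minimal usage sketch of the updated engine (not part of the commit). It assumes the unchanged middle of the `synthesize_segment` signature is `(self, text: str, speaker: str, output_path: str)`, as the docstring and body imply, and that `transformers`, `datasets`, `torch`, and `soundfile` are installed; model weights and x-vectors download on first run.

# usage_sketch.py - hypothetical driver, not part of this commit
import logging

from tts_engine import CPUMultiSpeakerTTS

logging.basicConfig(level=logging.INFO)

tts = CPUMultiSpeakerTTS()  # loads SpeechT5, the HiFi-GAN vocoder, and embeddings

# Two dialogue turns tagged with the S1/S2 labels the engine maps to voices.
segments = [
    ("S1", "Welcome back to the show. Today we're talking about text to speech."),
    ("S2", "Thanks for having me. Running everything on CPU keeps deployment simple."),
]

for idx, (speaker, text) in enumerate(segments):
    # Assumed signature: synthesize_segment(text, speaker, output_path)
    wav_path = tts.synthesize_segment(
        text=text,
        speaker=speaker,
        output_path=f"segment_{idx:02d}.wav",
    )
    if wav_path is None:
        logging.error("Synthesis failed for segment %d", idx)
    else:
        logging.info("Wrote %s", wav_path)

Since everything runs on CPU, long segments synthesize slowly; constructing the engine once and reusing it across calls avoids reloading the models and embeddings each time.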