# tts_engine.py - TTS engine wrapper for CPU-friendly SpeechT5
import logging
import os
import tempfile
from typing import Optional

import numpy as np
import soundfile as sf
import torch
from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
from datasets import load_dataset  # To fetch pre-computed speaker embeddings (CMU ARCTIC x-vectors)

logger = logging.getLogger(__name__)


class CPUMultiSpeakerTTS:
    def __init__(self):
        self.processor = None
        self.model = None
        self.vocoder = None
        self.speaker_embeddings = {}  # Will store speaker embeddings for S1, S2, etc.
        self._initialize_model()

    def _initialize_model(self):
        """Initialize the SpeechT5 model and vocoder on CPU."""
        try:
            logger.info("Initializing SpeechT5 model for CPU...")
            self.processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
            self.model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
            self.vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
            # Ensure all components are on CPU explicitly
            self.model.to("cpu")
            self.vocoder.to("cpu")
            logger.info("SpeechT5 model and vocoder initialized successfully on CPU.")

            # Load pre-computed x-vector speaker embeddings. This is the CMU
            # ARCTIC x-vector dataset used in the official SpeechT5 examples
            # (not VCTK, so VCTK IDs like 'p280' do not apply here).
            logger.info("Loading CMU ARCTIC x-vectors for speaker embeddings...")
            embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
            # Map 'S1' and 'S2' to two distinct voices. Entries are ordered by
            # utterance filename (inspect one via embeddings_dataset[0]["filename"]),
            # so adjacent indices usually belong to the same speaker. Index 0 lies
            # in the "awb" (male) range, and index 7306 is the female speaker "slt"
            # from the SpeechT5 documentation, giving S1 a male and S2 a female voice.
            self.speaker_embeddings["S1"] = torch.tensor(embeddings_dataset[0]["xvector"]).unsqueeze(0)
            self.speaker_embeddings["S2"] = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)
            # Ensure embeddings are also on CPU
            self.speaker_embeddings["S1"] = self.speaker_embeddings["S1"].to("cpu")
            self.speaker_embeddings["S2"] = self.speaker_embeddings["S2"].to("cpu")
            logger.info("Speaker embeddings loaded for S1 and S2.")
        except Exception as e:
            logger.error(f"Failed to initialize TTS model (SpeechT5): {e}", exc_info=True)
            self.processor = None
            self.model = None
            self.vocoder = None
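
    # Hypothetical convenience helper (not in the original file): a minimal
    # sketch of registering an extra voice label (e.g. "S3") against any row
    # of the x-vector dataset, following the same pattern _initialize_model
    # uses for S1 and S2. The method name and signature are assumptions.
    def add_speaker(self, label: str, dataset_index: int) -> bool:
        """Register an additional voice label backed by the given dataset row."""
        try:
            embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
            xvector = torch.tensor(embeddings_dataset[dataset_index]["xvector"]).unsqueeze(0)
            self.speaker_embeddings[label] = xvector.to("cpu")
            return True
        except Exception as e:
            logger.error(f"Failed to add speaker '{label}': {e}", exc_info=True)
            return False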

    def synthesize_segment(
        self,
        text: str,
        speaker: str,  # 'S1' or 'S2' from the segmenter
        output_path: str,
    ) -> Optional[str]:
        """
        Synthesize speech for a text segment using SpeechT5.

        Args:
            text: Text to synthesize
            speaker: Speaker identifier ('S1' or 'S2' expected from the segmenter)
            output_path: Path to save the audio file

        Returns:
            Path to the generated audio file, or None if synthesis failed
        """
        if not self.model or not self.processor or not self.vocoder:
            logger.error("SpeechT5 model, processor, or vocoder not initialized. Cannot synthesize speech.")
            return None
        try:
            # Get the matching speaker embedding, falling back to S1
            speaker_embedding = self.speaker_embeddings.get(speaker)
            if speaker_embedding is None:
                logger.warning(f"Speaker '{speaker}' not found in pre-loaded embeddings. Defaulting to S1.")
                speaker_embedding = self.speaker_embeddings["S1"]

            logger.info(f"Synthesizing text for speaker {speaker}: {text[:100]}...")
            # Prepare inputs and ensure they are on CPU
            inputs = self.processor(text=text, return_tensors="pt")
            inputs = {k: v.to("cpu") for k, v in inputs.items()}

            with torch.no_grad():
                # With a vocoder passed in, generate_speech runs the acoustic
                # model and the vocoder together and returns the final waveform.
                speech = self.model.generate_speech(
                    inputs["input_ids"],
                    speaker_embedding,
                    vocoder=self.vocoder,
                )

            audio_waveform = speech.cpu().numpy().squeeze()
            # Sampling rate from the vocoder config (16000 Hz for SpeechT5)
            sampling_rate = getattr(self.vocoder.config, "sampling_rate", 16000)
            sf.write(output_path, audio_waveform, sampling_rate)
            logger.info(f"Generated audio for {speaker}: {len(text)} characters to {output_path}")
            return output_path
        except Exception as e:
            logger.error(f"Failed to synthesize segment with SpeechT5: {e}", exc_info=True)
            return None
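

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of driving the wrapper from a segmenter's output. The
# segment list and output naming scheme below are assumptions for the demo;
# only CPUMultiSpeakerTTS itself comes from this file.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    tts = CPUMultiSpeakerTTS()
    # Hypothetical two-speaker dialogue, shaped like segmenter output
    segments = [
        ("S1", "Hello, and welcome to the show."),
        ("S2", "Thanks! It's great to be here."),
    ]
    with tempfile.TemporaryDirectory() as tmp_dir:
        for i, (speaker, text) in enumerate(segments):
            out_path = os.path.join(tmp_dir, f"segment_{i:03d}.wav")
            result = tts.synthesize_segment(text, speaker, out_path)
            print(f"{speaker}: {result if result else 'synthesis failed'}")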