Michael Hu commited on
Commit
c7f7521
·
1 Parent(s): cf7f5a3

remove legacy impl

Browse files
utils/stt.py DELETED
@@ -1,175 +0,0 @@
1
- """
2
- Speech Recognition Module
3
- Supports multiple ASR models including Whisper and Parakeet
4
- Handles audio preprocessing and transcription
5
- """
6
-
7
- import logging
8
- import numpy as np
9
- import os
10
- from abc import ABC, abstractmethod
11
-
12
- logger = logging.getLogger(__name__)
13
-
14
- from faster_whisper import WhisperModel as FasterWhisperModel
15
- from pydub import AudioSegment
16
-
17
- class ASRModel(ABC):
18
- """Base class for ASR models"""
19
-
20
- @abstractmethod
21
- def load_model(self):
22
- """Load the ASR model"""
23
- pass
24
-
25
- @abstractmethod
26
- def transcribe(self, audio_path):
27
- """Transcribe audio to text"""
28
- pass
29
-
30
- def preprocess_audio(self, audio_path):
31
- """Convert audio to required format"""
32
- logger.info("Converting audio format")
33
- audio = AudioSegment.from_file(audio_path)
34
- processed_audio = audio.set_frame_rate(16000).set_channels(1)
35
- wav_path = audio_path.replace(".mp3", ".wav") if audio_path.endswith(".mp3") else audio_path
36
- if not wav_path.endswith(".wav"):
37
- wav_path = f"{os.path.splitext(wav_path)[0]}.wav"
38
- processed_audio.export(wav_path, format="wav")
39
- logger.info(f"Audio converted to: {wav_path}")
40
- return wav_path
41
-
42
-
43
- class WhisperModel(ASRModel):
44
- """Faster Whisper ASR model implementation"""
45
-
46
- def __init__(self):
47
- self.model = None
48
- # Check for CUDA availability without torch dependency
49
- try:
50
- import torch
51
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
52
- except ImportError:
53
- # Fallback to CPU if torch is not available
54
- self.device = "cpu"
55
- self.compute_type = "float16" if self.device == "cuda" else "int8"
56
-
57
- def load_model(self):
58
- """Load Faster Whisper model"""
59
- logger.info("Loading Faster Whisper model")
60
- logger.info(f"Using device: {self.device}")
61
- logger.info(f"Using compute type: {self.compute_type}")
62
-
63
- # Use large-v3 model with appropriate compute type based on device
64
- self.model = FasterWhisperModel(
65
- "large-v3",
66
- device=self.device,
67
- compute_type=self.compute_type
68
- )
69
- logger.info("Faster Whisper model loaded successfully")
70
-
71
- def transcribe(self, audio_path):
72
- """Transcribe audio using Faster Whisper"""
73
- if self.model is None:
74
- self.load_model()
75
-
76
- wav_path = self.preprocess_audio(audio_path)
77
-
78
- # Transcription with Faster Whisper
79
- logger.info("Generating transcription with Faster Whisper")
80
- segments, info = self.model.transcribe(
81
- wav_path,
82
- beam_size=5,
83
- language="en",
84
- task="transcribe"
85
- )
86
-
87
- logger.info(f"Detected language '{info.language}' with probability {info.language_probability}")
88
-
89
- # Collect all segments into a single text
90
- result_text = ""
91
- for segment in segments:
92
- result_text += segment.text + " "
93
- logger.info(f"[{segment.start:.2f}s -> {segment.end:.2f}s] {segment.text}")
94
-
95
- result = result_text.strip()
96
- logger.info(f"Transcription completed successfully")
97
- return result
98
-
99
-
100
- class ParakeetModel(ASRModel):
101
- """Parakeet ASR model implementation"""
102
-
103
- def __init__(self):
104
- self.model = None
105
-
106
- def load_model(self):
107
- """Load Parakeet model"""
108
- try:
109
- import nemo.collections.asr as nemo_asr
110
- logger.info("Loading Parakeet model")
111
- self.model = nemo_asr.models.ASRModel.from_pretrained(model_name="nvidia/parakeet-tdt-0.6b-v2")
112
- logger.info("Parakeet model loaded successfully")
113
- except ImportError:
114
- logger.error("Failed to import nemo_toolkit. Please install with: pip install -U 'nemo_toolkit[asr]'")
115
- raise
116
-
117
- def transcribe(self, audio_path):
118
- """Transcribe audio using Parakeet"""
119
- if self.model is None:
120
- self.load_model()
121
-
122
- wav_path = self.preprocess_audio(audio_path)
123
-
124
- # Transcription
125
- logger.info("Generating transcription with Parakeet")
126
- output = self.model.transcribe([wav_path])
127
- result = output[0].text
128
- logger.info(f"Transcription completed successfully")
129
- return result
130
-
131
-
132
- class ASRFactory:
133
- """Factory for creating ASR model instances"""
134
-
135
- @staticmethod
136
- def get_model(model_name="parakeet"):
137
- """
138
- Get ASR model by name
139
- Args:
140
- model_name: Name of the model to use (whisper or parakeet)
141
- Returns:
142
- ASR model instance
143
- """
144
- if model_name.lower() == "whisper":
145
- return WhisperModel()
146
- elif model_name.lower() == "parakeet":
147
- return ParakeetModel()
148
- else:
149
- logger.warning(f"Unknown model: {model_name}, falling back to Whisper")
150
- return WhisperModel()
151
-
152
-
153
- def transcribe_audio(audio_path, model_name="parakeet"):
154
- """
155
- Convert audio file to text using specified ASR model
156
- Args:
157
- audio_path: Path to input audio file
158
- model_name: Name of the ASR model to use (whisper or parakeet)
159
- Returns:
160
- Transcribed English text
161
- """
162
- logger.info(f"Starting transcription for: {audio_path} using {model_name} model")
163
-
164
- try:
165
- # Get the appropriate model
166
- asr_model = ASRFactory.get_model(model_name)
167
-
168
- # Transcribe audio
169
- result = asr_model.transcribe(audio_path)
170
- logger.info(f"transcription: %s" % result)
171
- return result
172
-
173
- except Exception as e:
174
- logger.error(f"Transcription failed: {str(e)}", exc_info=True)
175
- raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils/translation.py DELETED
@@ -1,65 +0,0 @@
1
- """
2
- Text Translation Module using NLLB-3.3B model
3
- Handles text segmentation and batch translation
4
- """
5
-
6
- import logging
7
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
8
-
9
- logger = logging.getLogger(__name__)
10
-
11
- def translate_text(text):
12
- """
13
- Translate English text to Simplified Chinese
14
- Args:
15
- text: Input English text
16
- Returns:
17
- Translated Chinese text
18
- """
19
- logger.info(f"Starting translation for text length: {len(text)}")
20
-
21
- try:
22
- # Model initialization with explicit language codes
23
- logger.info("Loading NLLB model")
24
- tokenizer = AutoTokenizer.from_pretrained(
25
- "facebook/nllb-200-3.3B",
26
- src_lang="eng_Latn" # Specify source language
27
- )
28
- model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-3.3B")
29
- logger.info("Translation model loaded")
30
-
31
- # Text processing
32
- max_chunk_length = 1000
33
- text_chunks = [text[i:i+max_chunk_length] for i in range(0, len(text), max_chunk_length)]
34
- logger.info(f"Split text into {len(text_chunks)} chunks")
35
-
36
- translated_chunks = []
37
- for i, chunk in enumerate(text_chunks):
38
- logger.info(f"Processing chunk {i+1}/{len(text_chunks)}")
39
-
40
- # Tokenize with source language specification
41
- inputs = tokenizer(
42
- chunk,
43
- return_tensors="pt",
44
- max_length=1024,
45
- truncation=True
46
- )
47
-
48
- # Generate translation with target language specification
49
- outputs = model.generate(
50
- **inputs,
51
- forced_bos_token_id=tokenizer.convert_tokens_to_ids("zho_Hans"),
52
- max_new_tokens=1024
53
- )
54
-
55
- translated = tokenizer.decode(outputs[0], skip_special_tokens=True)
56
- translated_chunks.append(translated)
57
- logger.info(f"Chunk {i+1} translated successfully")
58
-
59
- result = "".join(translated_chunks)
60
- logger.info(f"Translation completed. Total length: {len(result)}")
61
- return result
62
-
63
- except Exception as e:
64
- logger.error(f"Translation failed: {str(e)}", exc_info=True)
65
- raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils/tts.py DELETED
@@ -1,126 +0,0 @@
1
- import logging
2
- from typing import Optional, Generator, Tuple, List, Dict, Any
3
- import numpy as np
4
-
5
- # Import the base class and dummy implementation
6
- from utils.tts_base import TTSBase
7
- from utils.tts_dummy import DummyTTS
8
-
9
- # Import the specific TTS implementations
10
- from utils.tts_kokoro import KokoroTTS, KOKORO_AVAILABLE
11
- from utils.tts_dia import DiaTTS, DIA_AVAILABLE
12
- from utils.tts_cosyvoice2 import CosyVoice2TTS, COSYVOICE2_AVAILABLE
13
-
14
- # Configure logging
15
- logger = logging.getLogger(__name__)
16
-
17
-
18
- def get_available_engines() -> List[str]:
19
- """Get a list of available TTS engines
20
-
21
- Returns:
22
- List[str]: List of available engine names
23
- """
24
- available = []
25
-
26
- if KOKORO_AVAILABLE:
27
- available.append('kokoro')
28
-
29
- if DIA_AVAILABLE:
30
- available.append('dia')
31
-
32
- if COSYVOICE2_AVAILABLE:
33
- available.append('cosyvoice2')
34
-
35
- # Dummy is always available
36
- available.append('dummy')
37
-
38
- return available
39
-
40
-
41
- def get_tts_engine(engine_type: Optional[str] = None, lang_code: str = 'z') -> TTSBase:
42
- """Get a TTS engine instance
43
-
44
- Args:
45
- engine_type (str, optional): Type of engine to create ('kokoro', 'dia', 'cosyvoice2', 'dummy')
46
- If None, the best available engine will be used
47
- lang_code (str): Language code for the engine
48
-
49
- Returns:
50
- TTSBase: An instance of a TTS engine
51
- """
52
- # Get available engines
53
- available_engines = get_available_engines()
54
- logger.info(f"Available TTS engines: {available_engines}")
55
-
56
- # If engine_type is specified, try to create that specific engine
57
- if engine_type is not None:
58
- if engine_type == 'kokoro' and KOKORO_AVAILABLE:
59
- logger.info("Creating Kokoro TTS engine")
60
- return KokoroTTS(lang_code)
61
- elif engine_type == 'dia' and DIA_AVAILABLE:
62
- logger.info("Creating Dia TTS engine")
63
- return DiaTTS(lang_code)
64
- elif engine_type == 'cosyvoice2' and COSYVOICE2_AVAILABLE:
65
- logger.info("Creating CosyVoice2 TTS engine")
66
- return CosyVoice2TTS(lang_code)
67
- elif engine_type == 'dummy':
68
- logger.info("Creating Dummy TTS engine")
69
- return DummyTTS(lang_code)
70
- else:
71
- logger.warning(f"Requested engine '{engine_type}' is not available")
72
-
73
- # If no specific engine is requested or the requested engine is not available,
74
- # use the best available engine based on priority
75
- priority_order = ['cosyvoice2', 'kokoro', 'dia', 'dummy']
76
- for engine in priority_order:
77
- if engine in available_engines:
78
- logger.info(f"Using best available engine: {engine}")
79
- if engine == 'kokoro':
80
- return KokoroTTS(lang_code)
81
- elif engine == 'dia':
82
- return DiaTTS(lang_code)
83
- elif engine == 'cosyvoice2':
84
- return CosyVoice2TTS(lang_code)
85
- elif engine == 'dummy':
86
- return DummyTTS(lang_code)
87
-
88
- # Fallback to dummy engine if no engines are available
89
- logger.warning("No TTS engines available, falling back to dummy engine")
90
- return DummyTTS(lang_code)
91
-
92
-
93
- def generate_speech(text: str, engine_type: Optional[str] = None, lang_code: str = 'z',
94
- voice: str = 'default', speed: float = 1.0) -> Optional[str]:
95
- """Generate speech using the specified or best available TTS engine
96
-
97
- Args:
98
- text (str): Input text to synthesize
99
- engine_type (str, optional): Type of engine to use
100
- lang_code (str): Language code
101
- voice (str): Voice ID to use
102
- speed (float): Speech speed multiplier
103
-
104
- Returns:
105
- Optional[str]: Path to the generated audio file or None if generation fails
106
- """
107
- engine = get_tts_engine(engine_type, lang_code)
108
- return engine.generate_speech(text, voice, speed)
109
-
110
-
111
- def generate_speech_stream(text: str, engine_type: Optional[str] = None, lang_code: str = 'z',
112
- voice: str = 'default', speed: float = 1.0) -> Generator[Tuple[int, np.ndarray], None, None]:
113
- """Generate speech stream using the specified or best available TTS engine
114
-
115
- Args:
116
- text (str): Input text to synthesize
117
- engine_type (str, optional): Type of engine to use
118
- lang_code (str): Language code
119
- voice (str): Voice ID to use
120
- speed (float): Speech speed multiplier
121
-
122
- Yields:
123
- tuple: (sample_rate, audio_data) pairs for each segment
124
- """
125
- engine = get_tts_engine(engine_type, lang_code)
126
- yield from engine.generate_speech_stream(text, voice, speed)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils/tts_README.md DELETED
@@ -1,64 +0,0 @@
1
- # TTS Structure
2
-
3
- This directory contains a Text-to-Speech (TTS) implementation that supports three specific models:
4
-
5
- 1. Kokoro: https://github.com/hexgrad/kokoro
6
- 2. Dia: https://github.com/nari-labs/dia
7
- 3. CosyVoice2: https://github.com/nari-labs/dia
8
-
9
- ## Structure
10
-
11
- The TTS implementation follows a simple, clean structure:
12
-
13
- - `tts.py`: Contains the base `TTSBase` abstract class and `DummyTTS` implementation
14
- - `tts_kokoro.py`: Kokoro TTS implementation
15
- - `tts_dia.py`: Dia TTS implementation
16
- - `tts_cosyvoice2.py`: CosyVoice2 TTS implementation
17
- - `tts_main.py`: Main entry point for TTS functionality
18
-
19
- ## Usage
20
-
21
- ```python
22
- # Import the main TTS functions
23
- from utils.tts_main import generate_speech, generate_speech_stream, get_tts_engine
24
-
25
- # Generate speech using the best available engine
26
- audio_path = generate_speech("Hello, world!")
27
-
28
- # Generate speech using a specific engine
29
- audio_path = generate_speech("Hello, world!", engine_type="kokoro")
30
-
31
- # Generate speech with specific parameters
32
- audio_path = generate_speech(
33
- "Hello, world!",
34
- engine_type="dia",
35
- lang_code="en",
36
- voice="default",
37
- speed=1.0
38
- )
39
-
40
- # Generate speech stream
41
- for sample_rate, audio_data in generate_speech_stream("Hello, world!"):
42
- # Process audio data
43
- pass
44
-
45
- # Get a specific TTS engine instance
46
- engine = get_tts_engine("kokoro")
47
- audio_path = engine.generate_speech("Hello, world!")
48
- ```
49
-
50
- ## Error Handling
51
-
52
- All TTS implementations include robust error handling:
53
-
54
- 1. Each implementation checks for the availability of its dependencies
55
- 2. If a specific engine fails, it automatically falls back to the `DummyTTS` implementation
56
- 3. The main module prioritizes engines based on availability
57
-
58
- ## Adding New Engines
59
-
60
- To add a new TTS engine:
61
-
62
- 1. Create a new file `tts_<engine_name>.py`
63
- 2. Implement a class that inherits from `TTSBase`
64
- 3. Add the engine to the available engines list in `tts_main.py`
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils/tts_base.py DELETED
@@ -1,69 +0,0 @@
1
- import logging
2
- import os
3
- import time
4
- import numpy as np
5
- import soundfile as sf
6
- from typing import Optional, Generator, Tuple, List
7
- from abc import ABC, abstractmethod
8
-
9
- # Configure logging
10
- logger = logging.getLogger(__name__)
11
-
12
-
13
- class TTSBase(ABC):
14
- """Base class for all TTS engines
15
-
16
- This abstract class defines the interface that all TTS engines must implement.
17
- """
18
-
19
- def __init__(self, lang_code: str = 'z'):
20
- """Initialize the TTS engine
21
-
22
- Args:
23
- lang_code (str): Language code for the engine
24
- """
25
- self.lang_code = lang_code
26
-
27
- @abstractmethod
28
- def generate_speech(self, text: str, voice: str = 'default', speed: float = 1.0) -> Optional[str]:
29
- """Generate speech from text
30
-
31
- Args:
32
- text (str): Input text to synthesize
33
- voice (str): Voice ID to use
34
- speed (float): Speech speed multiplier
35
-
36
- Returns:
37
- Optional[str]: Path to the generated audio file or None if generation fails
38
- """
39
- pass
40
-
41
- @abstractmethod
42
- def generate_speech_stream(self, text: str, voice: str = 'default', speed: float = 1.0) -> Generator[Tuple[int, np.ndarray], None, None]:
43
- """Generate speech stream from text
44
-
45
- Args:
46
- text (str): Input text to synthesize
47
- voice (str): Voice ID to use
48
- speed (float): Speech speed multiplier
49
-
50
- Yields:
51
- tuple: (sample_rate, audio_data) pairs for each segment
52
- """
53
- pass
54
-
55
- def _generate_output_path(self, prefix: str = "tts", extension: str = "wav") -> str:
56
- """Generate a unique output path for the audio file
57
-
58
- Args:
59
- prefix (str): Prefix for the filename
60
- extension (str): File extension
61
-
62
- Returns:
63
- str: Path to the output file
64
- """
65
- timestamp = int(time.time() * 1000)
66
- filename = f"{prefix}_{timestamp}.{extension}"
67
- output_dir = os.path.join(os.getcwd(), "output")
68
- os.makedirs(output_dir, exist_ok=True)
69
- return os.path.join(output_dir, filename)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils/tts_cosyvoice2.py DELETED
@@ -1,209 +0,0 @@
1
- import logging
2
- import numpy as np
3
- import soundfile as sf
4
- from typing import Optional, Generator, Tuple
5
-
6
- from utils.tts_base import TTSBase
7
-
8
- # Configure logging
9
- logger = logging.getLogger(__name__)
10
-
11
- # Flag to track CosyVoice2 availability
12
- COSYVOICE2_AVAILABLE = False
13
- DEFAULT_SAMPLE_RATE = 24000
14
-
15
- # Try to import CosyVoice2 dependencies
16
- try:
17
- import torch
18
- import torchaudio
19
- # Import CosyVoice2 from the correct package
20
- # Based on https://github.com/FunAudioLLM/CosyVoice
21
- from cosyvoice.cli.cosyvoice import CosyVoice
22
- COSYVOICE2_AVAILABLE = True
23
- logger.info("CosyVoice2 TTS engine is available")
24
- except ImportError as e:
25
- logger.warning(f"CosyVoice2 TTS engine is not available - ImportError: {str(e)}")
26
- COSYVOICE2_AVAILABLE = False
27
- except ModuleNotFoundError as e:
28
- logger.warning(f"CosyVoice2 TTS engine is not available - ModuleNotFoundError: {str(e)}")
29
- COSYVOICE2_AVAILABLE = False
30
-
31
-
32
- def _get_model():
33
- """Lazy-load the CosyVoice2 model
34
-
35
- Returns:
36
- CosyVoice2 or None: The CosyVoice2 model or None if not available
37
- """
38
- if not COSYVOICE2_AVAILABLE:
39
- logger.warning("CosyVoice2 TTS engine is not available")
40
- return None
41
-
42
- try:
43
- import torch
44
- import torchaudio
45
- from cosyvoice.cli.cosyvoice import CosyVoice
46
-
47
- # Initialize the model with correct path
48
- model = CosyVoice('pretrained_models/CosyVoice-300M')
49
- logger.info("CosyVoice2 model successfully loaded")
50
- return model
51
- except ImportError as e:
52
- logger.error(f"Failed to import CosyVoice2 dependencies: {str(e)}")
53
- return None
54
- except FileNotFoundError as e:
55
- logger.error(f"Failed to load CosyVoice2 model files: {str(e)}")
56
- return None
57
- except Exception as e:
58
- logger.error(f"Failed to initialize CosyVoice2 model: {str(e)}")
59
- return None
60
-
61
-
62
- class CosyVoice2TTS(TTSBase):
63
- """CosyVoice2 TTS engine implementation
64
-
65
- This engine uses the CosyVoice2 model for TTS generation.
66
- """
67
-
68
- def __init__(self, lang_code: str = 'z'):
69
- """Initialize the CosyVoice2 TTS engine
70
-
71
- Args:
72
- lang_code (str): Language code for the engine
73
- """
74
- super().__init__(lang_code)
75
- self.model = None
76
-
77
- def _ensure_model(self):
78
- """Ensure the model is loaded
79
-
80
- Returns:
81
- bool: True if model is available, False otherwise
82
- """
83
- if self.model is None:
84
- self.model = _get_model()
85
-
86
- return self.model is not None
87
-
88
- def generate_speech(self, text: str, voice: str = 'default', speed: float = 1.0) -> Optional[str]:
89
- """Generate speech using CosyVoice2 TTS engine
90
-
91
- Args:
92
- text (str): Input text to synthesize
93
- voice (str): Voice ID (may not be used in CosyVoice2)
94
- speed (float): Speech speed multiplier (may not be used in CosyVoice2)
95
-
96
- Returns:
97
- Optional[str]: Path to the generated audio file or None if generation fails
98
- """
99
- logger.info(f"Generating speech with CosyVoice2 for text length: {len(text)}")
100
-
101
- # Check if CosyVoice2 is available
102
- if not COSYVOICE2_AVAILABLE:
103
- logger.error("CosyVoice2 TTS engine is not available")
104
- return None
105
-
106
- # Ensure model is loaded
107
- if not self._ensure_model():
108
- logger.error("Failed to load CosyVoice2 model")
109
- return None
110
-
111
- try:
112
- import torch
113
-
114
- # Generate unique output path
115
- output_path = self._generate_output_path(prefix="cosyvoice2")
116
-
117
- # Generate audio using CosyVoice2
118
- try:
119
- # Use the inference method from CosyVoice
120
- output_audio_tensor = self.model.inference_sft(text, '中文女')
121
-
122
- # Convert tensor to numpy array
123
- if isinstance(output_audio_tensor, torch.Tensor):
124
- output_audio_np = output_audio_tensor.cpu().numpy()
125
- else:
126
- output_audio_np = output_audio_tensor
127
- except Exception as api_error:
128
- # Try alternative API if the first one fails
129
- try:
130
- output_audio_tensor = self.model.inference_zero_shot(text, '请输入提示文本', '中文女')
131
- if isinstance(output_audio_tensor, torch.Tensor):
132
- output_audio_np = output_audio_tensor.cpu().numpy()
133
- else:
134
- output_audio_np = output_audio_tensor
135
- except Exception as alt_error:
136
- logger.error(f"CosyVoice2 inference failed: {str(api_error)}")
137
- return None
138
-
139
- if output_audio_np is not None:
140
- logger.info(f"Successfully generated audio with CosyVoice2 (length: {len(output_audio_np)})")
141
- sf.write(output_path, output_audio_np, DEFAULT_SAMPLE_RATE)
142
- logger.info(f"CosyVoice2 audio generation complete: {output_path}")
143
- return output_path
144
- else:
145
- logger.error("CosyVoice2 model returned None for audio output")
146
- return None
147
-
148
- except Exception as e:
149
- logger.error(f"Error generating speech with CosyVoice2: {str(e)}", exc_info=True)
150
- return None
151
-
152
- def generate_speech_stream(self, text: str, voice: str = 'default', speed: float = 1.0) -> Generator[Tuple[int, np.ndarray], None, None]:
153
- """Generate speech stream using CosyVoice2 TTS engine
154
-
155
- Args:
156
- text (str): Input text to synthesize
157
- voice (str): Voice ID (may not be used in CosyVoice2)
158
- speed (float): Speech speed multiplier (may not be used in CosyVoice2)
159
-
160
- Yields:
161
- tuple: (sample_rate, audio_data) pairs for each segment
162
- """
163
- logger.info(f"Generating speech stream with CosyVoice2 for text length: {len(text)}")
164
-
165
- # Check if CosyVoice2 is available
166
- if not COSYVOICE2_AVAILABLE:
167
- logger.error("CosyVoice2 TTS engine is not available")
168
- return
169
-
170
- # Ensure model is loaded
171
- if not self._ensure_model():
172
- logger.error("Failed to load CosyVoice2 model")
173
- return
174
-
175
- try:
176
- import torch
177
-
178
- # Generate audio using CosyVoice2
179
- try:
180
- # Use the inference method from CosyVoice
181
- output_audio_tensor = self.model.inference_sft(text, '中文女')
182
-
183
- # Convert tensor to numpy array
184
- if isinstance(output_audio_tensor, torch.Tensor):
185
- output_audio_np = output_audio_tensor.cpu().numpy()
186
- else:
187
- output_audio_np = output_audio_tensor
188
- except Exception as api_error:
189
- # Try alternative API if the first one fails
190
- try:
191
- output_audio_tensor = self.model.inference_zero_shot(text, '请输入提示文本', '中文女')
192
- if isinstance(output_audio_tensor, torch.Tensor):
193
- output_audio_np = output_audio_tensor.cpu().numpy()
194
- else:
195
- output_audio_np = output_audio_tensor
196
- except Exception as alt_error:
197
- logger.error(f"CosyVoice2 inference failed: {str(api_error)}")
198
- return
199
-
200
- if output_audio_np is not None:
201
- logger.info(f"Successfully generated audio with CosyVoice2 (length: {len(output_audio_np)})")
202
- yield DEFAULT_SAMPLE_RATE, output_audio_np
203
- else:
204
- logger.error("CosyVoice2 model returned None for audio output")
205
- return
206
-
207
- except Exception as e:
208
- logger.error(f"Error generating speech stream with CosyVoice2: {str(e)}", exc_info=True)
209
- return
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils/tts_dia.py DELETED
@@ -1,201 +0,0 @@
1
- import logging
2
- import numpy as np
3
- import soundfile as sf
4
- from typing import Optional, Generator, Tuple
5
-
6
- from utils.tts_base import TTSBase
7
-
8
- # Configure logging
9
- logger = logging.getLogger(__name__)
10
-
11
- # Flag to track Dia availability
12
- DIA_AVAILABLE = False
13
- DEFAULT_SAMPLE_RATE = 24000
14
-
15
- # Try to import Dia dependencies
16
- try:
17
- import torch
18
- from dia.model import Dia
19
- DIA_AVAILABLE = True
20
- logger.info("Dia TTS engine is available")
21
- except ImportError:
22
- logger.warning("Dia TTS engine is not available")
23
- except ModuleNotFoundError as e:
24
- if "dac" in str(e):
25
- logger.warning("Dia TTS engine is not available due to missing 'dac' module")
26
- else:
27
- logger.warning(f"Dia TTS engine is not available: {str(e)}")
28
- DIA_AVAILABLE = False
29
-
30
-
31
- def _get_model():
32
- """Lazy-load the Dia model
33
-
34
- Returns:
35
- Dia or None: The Dia model or None if not available
36
- """
37
- if not DIA_AVAILABLE:
38
- logger.warning("Dia TTS engine is not available")
39
- return None
40
-
41
- try:
42
- import torch
43
- from dia.model import Dia
44
-
45
- # Initialize the model
46
- model = Dia.from_pretrained()
47
- logger.info("Dia model successfully loaded")
48
- return model
49
- except ImportError as e:
50
- logger.error(f"Failed to import Dia dependencies: {str(e)}")
51
- return None
52
- except FileNotFoundError as e:
53
- logger.error(f"Failed to load Dia model files: {str(e)}")
54
- return None
55
- except Exception as e:
56
- logger.error(f"Failed to initialize Dia model: {str(e)}")
57
- return None
58
-
59
-
60
- class DiaTTS(TTSBase):
61
- """Dia TTS engine implementation
62
-
63
- This engine uses the Dia model for TTS generation.
64
- """
65
-
66
- def __init__(self, lang_code: str = 'z'):
67
- """Initialize the Dia TTS engine
68
-
69
- Args:
70
- lang_code (str): Language code for the engine
71
- """
72
- super().__init__(lang_code)
73
- self.model = None
74
-
75
- def _ensure_model(self):
76
- """Ensure the model is loaded
77
-
78
- Returns:
79
- bool: True if model is available, False otherwise
80
- """
81
- if self.model is None:
82
- self.model = _get_model()
83
-
84
- return self.model is not None
85
-
86
- def generate_speech(self, text: str, voice: str = 'default', speed: float = 1.0) -> Optional[str]:
87
- """Generate speech using Dia TTS engine
88
-
89
- Args:
90
- text (str): Input text to synthesize
91
- voice (str): Voice ID (not used in Dia)
92
- speed (float): Speech speed multiplier (not used in Dia)
93
-
94
- Returns:
95
- Optional[str]: Path to the generated audio file or None if generation fails
96
- """
97
- logger.info(f"Generating speech with Dia for text length: {len(text)}")
98
-
99
- # Check if Dia is available
100
- if not DIA_AVAILABLE:
101
- logger.error("Dia TTS engine is not available")
102
- return None
103
-
104
- # Ensure model is loaded
105
- if not self._ensure_model():
106
- logger.error("Failed to load Dia model")
107
- return None
108
-
109
- try:
110
- import torch
111
-
112
- # Generate unique output path
113
- output_path = self._generate_output_path(prefix="dia")
114
-
115
- # Generate audio
116
- with torch.inference_mode():
117
- output_audio_np = self.model.generate(
118
- text,
119
- max_tokens=None,
120
- cfg_scale=3.0,
121
- temperature=1.3,
122
- top_p=0.95,
123
- cfg_filter_top_k=35,
124
- use_torch_compile=False,
125
- verbose=False
126
- )
127
-
128
- if output_audio_np is not None:
129
- logger.info(f"Successfully generated audio with Dia (length: {len(output_audio_np)})")
130
- sf.write(output_path, output_audio_np, DEFAULT_SAMPLE_RATE)
131
- logger.info(f"Dia audio generation complete: {output_path}")
132
- return output_path
133
- else:
134
- logger.error("Dia model returned None for audio output")
135
- return None
136
-
137
- except ModuleNotFoundError as e:
138
- if "dac" in str(e):
139
- logger.error("Dia TTS engine failed due to missing 'dac' module")
140
- else:
141
- logger.error(f"Module not found error in Dia TTS: {str(e)}")
142
- return None
143
- except Exception as e:
144
- logger.error(f"Error generating speech with Dia: {str(e)}", exc_info=True)
145
- return None
146
-
147
- def generate_speech_stream(self, text: str, voice: str = 'default', speed: float = 1.0) -> Generator[Tuple[int, np.ndarray], None, None]:
148
- """Generate speech stream using Dia TTS engine
149
-
150
- Args:
151
- text (str): Input text to synthesize
152
- voice (str): Voice ID (not used in Dia)
153
- speed (float): Speech speed multiplier (not used in Dia)
154
-
155
- Yields:
156
- tuple: (sample_rate, audio_data) pairs for each segment
157
- """
158
- logger.info(f"Generating speech stream with Dia for text length: {len(text)}")
159
-
160
- # Check if Dia is available
161
- if not DIA_AVAILABLE:
162
- logger.error("Dia TTS engine is not available")
163
- return
164
-
165
- # Ensure model is loaded
166
- if not self._ensure_model():
167
- logger.error("Failed to load Dia model")
168
- return
169
-
170
- try:
171
- import torch
172
-
173
- # Generate audio
174
- with torch.inference_mode():
175
- output_audio_np = self.model.generate(
176
- text,
177
- max_tokens=None,
178
- cfg_scale=3.0,
179
- temperature=1.3,
180
- top_p=0.95,
181
- cfg_filter_top_k=35,
182
- use_torch_compile=False,
183
- verbose=False
184
- )
185
-
186
- if output_audio_np is not None:
187
- logger.info(f"Successfully generated audio with Dia (length: {len(output_audio_np)})")
188
- yield DEFAULT_SAMPLE_RATE, output_audio_np
189
- else:
190
- logger.error("Dia model returned None for audio output")
191
- return
192
-
193
- except ModuleNotFoundError as e:
194
- if "dac" in str(e):
195
- logger.error("Dia TTS engine failed due to missing 'dac' module")
196
- else:
197
- logger.error(f"Module not found error in Dia TTS: {str(e)}")
198
- return
199
- except Exception as e:
200
- logger.error(f"Error generating speech stream with Dia: {str(e)}", exc_info=True)
201
- return
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils/tts_dummy.py DELETED
@@ -1,65 +0,0 @@
1
- import logging
2
- import os
3
- import time
4
- import numpy as np
5
- import soundfile as sf
6
- from typing import Optional, Generator, Tuple, List
7
- from .tts_base import TTSBase
8
-
9
- # Configure logging
10
- logger = logging.getLogger(__name__)
11
-
12
-
13
- class DummyTTS(TTSBase):
14
- """Dummy TTS engine that generates sine wave audio
15
-
16
- This class is used as a fallback when no other TTS engine is available.
17
- """
18
-
19
- def generate_speech(self, text: str, voice: str = 'default', speed: float = 1.0) -> str:
20
- """Generate a dummy sine wave audio file
21
-
22
- Args:
23
- text (str): Input text (not used)
24
- voice (str): Voice ID (not used)
25
- speed (float): Speech speed multiplier (not used)
26
-
27
- Returns:
28
- str: Path to the generated audio file
29
- """
30
- logger.info(f"Generating dummy speech for text length: {len(text)}")
31
-
32
- # Generate a simple sine wave
33
- sample_rate = 24000
34
- duration = min(len(text) / 20, 10) # Rough approximation of speech duration
35
- t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False)
36
- audio = 0.5 * np.sin(2 * np.pi * 440 * t) # 440 Hz sine wave
37
-
38
- # Save to file
39
- output_path = self._generate_output_path(prefix="dummy")
40
- sf.write(output_path, audio, sample_rate)
41
-
42
- logger.info(f"Generated dummy audio: {output_path}")
43
- return output_path
44
-
45
- def generate_speech_stream(self, text: str, voice: str = 'default', speed: float = 1.0) -> Generator[Tuple[int, np.ndarray], None, None]:
46
- """Generate a dummy sine wave audio stream
47
-
48
- Args:
49
- text (str): Input text (not used)
50
- voice (str): Voice ID (not used)
51
- speed (float): Speech speed multiplier (not used)
52
-
53
- Yields:
54
- tuple: (sample_rate, audio_data) pairs
55
- """
56
- logger.info(f"Generating dummy speech stream for text length: {len(text)}")
57
-
58
- # Generate a simple sine wave
59
- sample_rate = 24000
60
- duration = min(len(text) / 20, 10) # Rough approximation of speech duration
61
- t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False)
62
- audio = 0.5 * np.sin(2 * np.pi * 440 * t) # 440 Hz sine wave
63
-
64
- # Yield the audio data
65
- yield sample_rate, audio
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils/tts_kokoro.py DELETED
@@ -1,144 +0,0 @@
1
- import logging
2
- import numpy as np
3
- import soundfile as sf
4
- from typing import Optional, Generator, Tuple
5
-
6
- from utils.tts_base import TTSBase
7
-
8
- # Configure logging
9
- logger = logging.getLogger(__name__)
10
-
11
- # Flag to track Kokoro availability
12
- KOKORO_AVAILABLE = False
13
-
14
- # Try to import Kokoro
15
- try:
16
- from kokoro import KPipeline
17
- KOKORO_AVAILABLE = True
18
- logger.info("Kokoro TTS engine is available")
19
- except ImportError:
20
- logger.warning("Kokoro TTS engine is not available")
21
- except Exception as e:
22
- logger.error(f"Kokoro import failed with unexpected error: {str(e)}")
23
- KOKORO_AVAILABLE = False
24
-
25
-
26
- def _get_pipeline(lang_code: str = 'z'):
27
- """Lazy-load the Kokoro pipeline
28
-
29
- Args:
30
- lang_code (str): Language code for the pipeline
31
-
32
- Returns:
33
- KPipeline or None: The Kokoro pipeline or None if not available
34
- """
35
- if not KOKORO_AVAILABLE:
36
- logger.warning("Kokoro TTS engine is not available")
37
- return None
38
-
39
- try:
40
- pipeline = KPipeline(lang_code=lang_code)
41
- logger.info("Kokoro pipeline successfully loaded")
42
- return pipeline
43
- except Exception as e:
44
- logger.error(f"Failed to initialize Kokoro pipeline: {str(e)}")
45
- return None
46
-
47
-
48
- class KokoroTTS(TTSBase):
49
- """Kokoro TTS engine implementation
50
-
51
- This engine uses the Kokoro library for TTS generation.
52
- """
53
-
54
- def __init__(self, lang_code: str = 'z'):
55
- """Initialize the Kokoro TTS engine
56
-
57
- Args:
58
- lang_code (str): Language code for the engine
59
- """
60
- super().__init__(lang_code)
61
- self.pipeline = None
62
-
63
- def _ensure_pipeline(self):
64
- """Ensure the pipeline is loaded
65
-
66
- Returns:
67
- bool: True if pipeline is available, False otherwise
68
- """
69
- if self.pipeline is None:
70
- self.pipeline = _get_pipeline(self.lang_code)
71
-
72
- return self.pipeline is not None
73
-
74
- def generate_speech(self, text: str, voice: str = 'af_heart', speed: float = 1.0) -> Optional[str]:
75
- """Generate speech using Kokoro TTS engine
76
-
77
- Args:
78
- text (str): Input text to synthesize
79
- voice (str): Voice ID to use (e.g., 'af_heart', 'af_bella', etc.)
80
- speed (float): Speech speed multiplier (0.5 to 2.0)
81
-
82
- Returns:
83
- Optional[str]: Path to the generated audio file or None if generation fails
84
- """
85
- logger.info(f"Generating speech with Kokoro for text length: {len(text)}")
86
-
87
- # Check if Kokoro is available
88
- if not KOKORO_AVAILABLE:
89
- logger.error("Kokoro TTS engine is not available")
90
- return None
91
-
92
- # Ensure pipeline is loaded
93
- if not self._ensure_pipeline():
94
- logger.error("Failed to load Kokoro pipeline")
95
- return None
96
-
97
- try:
98
- # Generate unique output path
99
- output_path = self._generate_output_path(prefix="kokoro")
100
-
101
- # Generate speech
102
- generator = self.pipeline(text, voice=voice, speed=speed)
103
- for _, _, audio in generator:
104
- logger.info(f"Saving Kokoro audio to {output_path}")
105
- sf.write(output_path, audio, 24000)
106
- break
107
-
108
- logger.info(f"Kokoro audio generation complete: {output_path}")
109
- return output_path
110
- except Exception as e:
111
- logger.error(f"Error generating speech with Kokoro: {str(e)}", exc_info=True)
112
- return None
113
-
114
- def generate_speech_stream(self, text: str, voice: str = 'af_heart', speed: float = 1.0) -> Generator[Tuple[int, np.ndarray], None, None]:
115
- """Generate speech stream using Kokoro TTS engine
116
-
117
- Args:
118
- text (str): Input text to synthesize
119
- voice (str): Voice ID to use
120
- speed (float): Speech speed multiplier
121
-
122
- Yields:
123
- tuple: (sample_rate, audio_data) pairs for each segment
124
- """
125
- logger.info(f"Generating speech stream with Kokoro for text length: {len(text)}")
126
-
127
- # Check if Kokoro is available
128
- if not KOKORO_AVAILABLE:
129
- logger.error("Kokoro TTS engine is not available")
130
- return
131
-
132
- # Ensure pipeline is loaded
133
- if not self._ensure_pipeline():
134
- logger.error("Failed to load Kokoro pipeline")
135
- return
136
-
137
- try:
138
- # Generate speech stream
139
- generator = self.pipeline(text, voice=voice, speed=speed)
140
- for _, _, audio in generator:
141
- yield 24000, audio
142
- except Exception as e:
143
- logger.error(f"Error generating speech stream with Kokoro: {str(e)}", exc_info=True)
144
- return