"""
Alias module to redirect whisper imports to whisperx.

This allows OuteTTS to use whisperx instead of the standard whisper package.
"""
|
|
|
import sys |
|
import importlib.util |
|
|
|
def setup_whisper_alias():
    """Register a whisperx-backed stand-in as ``sys.modules['whisper']``.

    After a successful call, ``import whisper`` resolves to an object that
    exposes ``load_model(name, **kwargs)``. Models returned by it provide
    ``transcribe(audio, **kwargs)``, which returns a dict that always has
    ``segments`` and ``text`` keys, plus ``language`` when the underlying
    whisperx model reported one.

    The function does not raise: when whisperx cannot be found, or anything
    goes wrong during setup, a warning is printed instead.
    """
    try:
        if importlib.util.find_spec("whisperx") is None:
            print("Warning: whisperx not found, falling back to regular whisper")
            return

        import whisperx

        def _resolve_device(requested):
            """Map a whisper-style ``device`` argument to ``"cuda"`` or ``"cpu"``."""
            if requested is None or requested == "auto":
                # "Pick for me": use CUDA only when torch reports it usable.
                try:
                    import torch
                except ImportError:
                    return "cpu"
                return "cuda" if torch.cuda.is_available() else "cpu"
            # Strings such as "cuda:0" and device objects whose str() starts
            # with "cuda" all select the GPU backend; anything else is CPU.
            return "cuda" if str(requested).startswith("cuda") else "cpu"

        class WhisperXModelWrapper:
            """Wrapper to make whisperx compatible with whisper interface."""

            def __init__(self, model, device):
                self.model = model
                self.device = device
                # language code -> (align_model, metadata); filled lazily so
                # repeated transcriptions do not reload the alignment model.
                self._align_models = {}

            def _get_align_model(self, language_code):
                """Return the cached alignment model pair for ``language_code``."""
                if language_code not in self._align_models:
                    self._align_models[language_code] = whisperx.load_align_model(
                        language_code=language_code,
                        device=self.device,
                    )
                return self._align_models[language_code]

            def transcribe(self, audio, **kwargs):
                """Transcribe audio with whisper-compatible interface.

                ``audio`` is either a path string (loaded through
                ``whisperx.load_audio``) or already-loaded audio data that is
                passed through unchanged. Recognised keyword arguments are
                ``word_timestamps`` (default ``False``) and ``batch_size``
                (default ``16``); other keyword arguments are ignored.
                """
                want_words = kwargs.get('word_timestamps', False)
                audio_data = whisperx.load_audio(audio) if isinstance(audio, str) else audio

                result = self.model.transcribe(
                    audio_data,
                    batch_size=kwargs.get('batch_size', 16),
                )
                # Remember the language now: the alignment step below
                # replaces ``result`` with a new dict.
                language = result.get("language")

                if want_words and result.get("segments"):
                    try:
                        model_a, metadata = self._get_align_model(language or "en")
                        result = whisperx.align(
                            result["segments"],
                            model_a,
                            metadata,
                            audio_data,
                            self.device,
                            return_char_alignments=False,
                        )
                    except Exception as e:
                        # Best effort: keep the unaligned transcription.
                        print(f"Warning: Could not perform alignment: {e}")

                if language is not None:
                    # Restore the language in case the aligned dict lacks it.
                    result.setdefault("language", language)

                segments = result.setdefault("segments", [])

                if "text" not in result:
                    # Strip each piece so segments that carry their own
                    # leading space do not produce doubled spaces.
                    pieces = (segment.get("text", "").strip() for segment in segments)
                    result["text"] = " ".join(piece for piece in pieces if piece)

                if want_words:
                    # Callers asking for word timestamps get a "words" list
                    # on every segment, even when alignment produced none.
                    for segment in segments:
                        segment.setdefault("words", [])

                return result

        class WhisperAlias:
            """Object installed in ``sys.modules`` under the name ``whisper``."""

            def __init__(self):
                self.model = getattr(whisperx, 'WhisperModel', None)
                self.load_model = self._load_model

            def _load_model(self, name, **kwargs):
                """Load model with whisperx compatible interface."""
                device = _resolve_device(kwargs.get("device", "auto"))
                # float16 on GPU, int8 on CPU.
                compute_type = "float16" if device == "cuda" else "int8"
                model = whisperx.load_model(
                    name,
                    device=device,
                    compute_type=compute_type,
                )
                return WhisperXModelWrapper(model, device)

        sys.modules['whisper'] = WhisperAlias()

        print("✅ Successfully aliased whisper to whisperx")

    except ImportError as e:
        print(f"Warning: Could not setup whisper alias: {e}")
        print("Falling back to regular whisper (if available)")
    except Exception as e:
        print(f"Warning: Error setting up whisper alias: {e}")
|
|
|
|
|
# Install the alias as an import side effect: once this module has been
# imported, a later ``import whisper`` picks up the entry placed in sys.modules.
setup_whisper_alias()