DHEIVER's picture
Create app.py
23dd469 verified
raw
history blame
4.06 kB
from pathlib import Path
from typing import Optional, Union

import numpy
import scipy.io.wavfile
import torch
import torchaudio
from transformers import AutoProcessor, SeamlessM4Tv2Model
class SeamlessTranslator:
    """
    A wrapper class for Facebook's SeamlessM4T translation model.

    Handles both text-to-speech and speech-to-speech translation. Every
    public method converts failures into ``RuntimeError`` and attaches the
    original exception as ``__cause__`` so the underlying traceback is kept.
    """

    # Sample rate (Hz) that input audio is resampled to before it is handed
    # to the processor. Subclasses may override it if a model expects another rate.
    INPUT_SAMPLE_RATE = 16_000

    def __init__(self, model_name: str = "facebook/seamless-m4t-v2-large"):
        """
        Initialize the translator with the specified model.

        Args:
            model_name (str): Name (or local path) of the model to load with
                ``from_pretrained``.

        Raises:
            RuntimeError: If the processor or model cannot be loaded.
        """
        try:
            self.processor = AutoProcessor.from_pretrained(model_name)
            self.model = SeamlessM4Tv2Model.from_pretrained(model_name)
            # Sample rate of the generated speech; save_audio writes WAVs at it.
            self.sample_rate = self.model.config.sampling_rate
        except Exception as e:
            raise RuntimeError(f"Failed to initialize model: {e}") from e

    def _generate_waveform(self, inputs, tgt_lang: str) -> numpy.ndarray:
        """Run ``model.generate`` on processor ``inputs`` and return a NumPy waveform."""
        # The first element of generate()'s result is treated as the waveform
        # tensor; squeeze() drops singleton dimensions such as the batch axis.
        waveform = self.model.generate(**inputs, tgt_lang=tgt_lang)[0]
        return waveform.cpu().numpy().squeeze()

    def translate_text(self, text: str, src_lang: str, tgt_lang: str) -> numpy.ndarray:
        """
        Translate text to speech in the target language.

        Args:
            text (str): Input text to translate.
            src_lang (str): Source language code (e.g., 'eng').
            tgt_lang (str): Target language code (e.g., 'rus').

        Returns:
            numpy.ndarray: Audio waveform array at ``self.sample_rate``.

        Raises:
            RuntimeError: If preprocessing or generation fails.
        """
        try:
            inputs = self.processor(text=text, src_lang=src_lang, return_tensors="pt")
            return self._generate_waveform(inputs, tgt_lang)
        except Exception as e:
            raise RuntimeError(f"Text translation failed: {e}") from e

    def translate_audio(self, audio_path: Union[str, Path], tgt_lang: str) -> numpy.ndarray:
        """
        Translate audio to speech in the target language.

        The file is loaded with torchaudio. Multi-channel recordings are
        averaged to a single channel, then resampled to ``INPUT_SAMPLE_RATE``.

        Args:
            audio_path (str or Path): Path to input audio file.
            tgt_lang (str): Target language code (e.g., 'rus').

        Returns:
            numpy.ndarray: Audio waveform array at ``self.sample_rate``.

        Raises:
            RuntimeError: If loading, preprocessing, or generation fails.
        """
        try:
            # torchaudio.load returns (waveform[channels, frames], sample_rate).
            audio, orig_freq = torchaudio.load(audio_path)
            if audio.shape[0] > 1:
                # Downmix so every channel contributes, not just one of them.
                audio = audio.mean(dim=0, keepdim=True)
            audio = torchaudio.functional.resample(
                audio,
                orig_freq=orig_freq,
                new_freq=self.INPUT_SAMPLE_RATE,
            )
            # Tell the processor the actual rate instead of letting it assume one.
            inputs = self.processor(
                audios=audio,
                sampling_rate=self.INPUT_SAMPLE_RATE,
                return_tensors="pt",
            )
            return self._generate_waveform(inputs, tgt_lang)
        except Exception as e:
            raise RuntimeError(f"Audio translation failed: {e}") from e

    def save_audio(self, audio_array: numpy.ndarray, output_path: Union[str, Path]) -> None:
        """
        Save an audio array to a WAV file at ``self.sample_rate``.

        Args:
            audio_array (numpy.ndarray): Audio data to save.
            output_path (str or Path): Path where to save the WAV file.

        Raises:
            RuntimeError: If the file cannot be written.
        """
        try:
            scipy.io.wavfile.write(output_path, rate=self.sample_rate, data=audio_array)
        except Exception as e:
            raise RuntimeError(f"Failed to save audio: {e}") from e
def main():
    """
    Run the example text-to-speech and speech-to-speech translations.

    Writes ``output_from_text.wav`` (English sentence -> Russian speech) and
    ``output_from_audio.wav`` (``input_audio.wav`` -> Russian speech). Paths
    are relative to the current working directory.

    Returns:
        int: 0 on success, 1 if any step failed, so the value can be used
        as the process exit status.
    """
    import sys  # only needed to send the failure message to stderr

    try:
        translator = SeamlessTranslator()

        # Example text translation
        text_audio = translator.translate_text(
            text="Hello, my dog is cute",
            src_lang="eng",
            tgt_lang="rus",
        )
        translator.save_audio(text_audio, "output_from_text.wav")

        # Example audio translation
        audio_audio = translator.translate_audio(
            audio_path="input_audio.wav",
            tgt_lang="rus",
        )
        translator.save_audio(audio_audio, "output_from_audio.wav")
    except Exception as e:
        # Top-level boundary: report the failure and signal it via the return code.
        print(f"Translation failed: {str(e)}", file=sys.stderr)
        return 1
    return 0
if __name__ == "__main__":
    # Use main()'s return value as the exit status (None -> 0).
    raise SystemExit(main())