"""Command-line TTS inference script: synthesizes speech for the given text,
conditioned on a prompt audio clip, and writes the result to a WAV file."""

import os
import argparse
import logging
from datetime import datetime

import torch
import soundfile as sf

from cli.SparkTTS import SparkTTS


def parse_args():
    """Parse command-line arguments."""
    parser = argparse.ArgumentParser(description="Run TTS inference.")

    parser.add_argument("--model_dir", type=str, default="pretrained_models/Spark-TTS-0.5B",
                        help="Path to the model directory")
    parser.add_argument("--save_dir", type=str, default="example/results",
                        help="Directory to save generated audio files")
    parser.add_argument("--device", type=int, default=0, help="CUDA device number")
    parser.add_argument("--text", type=str, required=True, help="Text for TTS generation")
    parser.add_argument("--prompt_text", type=str, help="Transcript of prompt audio")
    parser.add_argument("--prompt_speech_path", type=str, required=True,
                        help="Path to the prompt audio file")

    return parser.parse_args()


def run_tts(args):
    """Perform TTS inference and save the generated audio."""
    logging.info(f"Using model from: {args.model_dir}")
    logging.info(f"Saving audio to: {args.save_dir}")

    # Make sure the output directory exists.
    os.makedirs(args.save_dir, exist_ok=True)

    # Select the CUDA device; this script assumes a CUDA-capable GPU is available.
    device = torch.device(f"cuda:{args.device}")

    model = SparkTTS(args.model_dir, device)

    # Timestamped filename so repeated runs do not overwrite earlier results.
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    save_path = os.path.join(args.save_dir, f"{timestamp}.wav")

    logging.info("Starting inference...")

    # Run inference without tracking gradients and write the waveform to disk.
    with torch.no_grad():
        wav = model.inference(args.text, args.prompt_speech_path, prompt_text=args.prompt_text)
        sf.write(save_path, wav, samplerate=16000)

    logging.info(f"Audio saved at: {save_path}")


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")

    args = parse_args()
    run_tts(args)
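
# Example invocation (illustrative only; assumes this script lives at cli/inference.py,
# is run from the repository root, and that the referenced prompt WAV exists):
#
#   python -m cli.inference \
#       --text "Hello, world." \
#       --prompt_speech_path example/prompt.wav \
#       --prompt_text "Transcript of the prompt audio." \
#       --model_dir pretrained_models/Spark-TTS-0.5B \
#       --save_dir example/results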