# SparkTTS / inference.py
# Repository: spark-tts (commit 7aa1435, "init"), file size 2.1 kB
import os
import argparse
import torch
import soundfile as sf
import logging
from datetime import datetime
from cli.SparkTTS import SparkTTS
def parse_args():
    """Build the TTS inference CLI and parse the process arguments.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``, carrying
        ``model_dir``, ``save_dir``, ``device``, ``text``, ``prompt_text``
        and ``prompt_speech_path``.
    """
    cli = argparse.ArgumentParser(description="Run TTS inference.")

    # Locations: where the model lives and where results are written.
    cli.add_argument(
        "--model_dir",
        default="pretrained_models/Spark-TTS-0.5B",
        type=str,
        help="Path to the model directory",
    )
    cli.add_argument(
        "--save_dir",
        default="example/results",
        type=str,
        help="Directory to save generated audio files",
    )

    # Hardware selection (an integer CUDA index).
    cli.add_argument(
        "--device",
        default=0,
        type=int,
        help="CUDA device number",
    )

    # Synthesis inputs: the target text and prompt audio are mandatory,
    # the prompt transcript is optional (defaults to None).
    cli.add_argument(
        "--text",
        required=True,
        type=str,
        help="Text for TTS generation",
    )
    cli.add_argument(
        "--prompt_text",
        type=str,
        help="Transcript of prompt audio",
    )
    cli.add_argument(
        "--prompt_speech_path",
        required=True,
        type=str,
        help="Path to the prompt audio file",
    )

    namespace = cli.parse_args()
    return namespace
def run_tts(args):
    """Perform TTS inference and save the generated audio.

    Loads the model from ``args.model_dir``, synthesizes ``args.text`` using
    the prompt audio at ``args.prompt_speech_path`` (and the optional
    ``args.prompt_text`` transcript), and writes the result to
    ``<args.save_dir>/<YYYYmmddHHMMSS>.wav``.

    Args:
        args: Parsed CLI namespace providing ``model_dir``, ``save_dir``,
            ``device``, ``text``, ``prompt_text`` and ``prompt_speech_path``.
    """
    logging.info(f"Using model from: {args.model_dir}")
    logging.info(f"Saving audio to: {args.save_dir}")

    # Ensure the save directory exists
    os.makedirs(args.save_dir, exist_ok=True)

    # Convert device argument to torch.device. Requesting a CUDA device on a
    # machine without CUDA fails as soon as tensors are moved to it, so fall
    # back to the CPU instead of crashing.
    if torch.cuda.is_available():
        device = torch.device(f"cuda:{args.device}")
    else:
        logging.warning("CUDA is not available; falling back to CPU.")
        device = torch.device("cpu")
    logging.info(f"Using device: {device}")

    # Initialize the model
    model = SparkTTS(args.model_dir, device)

    # Generate unique filename using timestamp (one-second resolution, so
    # two runs within the same second target the same file).
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    save_path = os.path.join(args.save_dir, f"{timestamp}.wav")

    logging.info("Starting inference...")

    # Perform inference without gradient tracking; only the model call needs
    # to sit inside the no_grad context.
    with torch.no_grad():
        wav = model.inference(
            args.text,
            args.prompt_speech_path,
            prompt_text=args.prompt_text,
        )

    # NOTE(review): assumes the model emits 16 kHz audio — confirm against
    # the model configuration.
    sf.write(save_path, wav, samplerate=16000)

    logging.info(f"Audio saved at: {save_path}")
if __name__ == "__main__":
    # Script entry point: set up INFO-level logging first so every message
    # emitted during the run is shown, then parse the CLI and synthesize.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(message)s",
        level=logging.INFO,
    )
    run_tts(parse_args())