# dia/cli.py
# Uploaded by developer052
# Commit message: Upload folder using huggingface_hub
# Commit a932168 (verified)
import argparse
import os
import random
import numpy as np
import soundfile as sf
import torch
from dia.model import Dia
def set_seed(seed: int):
    """Seed every random number generator used by this script.

    Seeds Python's ``random``, NumPy and PyTorch (plus the CUDA generators
    when CUDA is available) with the same value, then switches cuDNN to
    deterministic mode with autotuning (benchmark) turned off.

    Args:
        seed: Value passed to each generator's seeding function.
    """
    # Same order as always: stdlib, NumPy, then PyTorch's generator.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        cuda = torch.cuda
        cuda.manual_seed(seed)
        cuda.manual_seed_all(seed)

    # Ensure deterministic behavior for cuDNN (if used).
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
def _build_parser() -> argparse.ArgumentParser:
    """Build the argument parser describing every CLI option of this script."""
    parser = argparse.ArgumentParser(description="Generate audio using the Dia model.")
    parser.add_argument("text", type=str, help="Input text for speech generation.")
    parser.add_argument(
        "--output", type=str, required=True, help="Path to save the generated audio file (e.g., output.wav)."
    )
    parser.add_argument(
        "--repo-id",
        type=str,
        default="nari-labs/Dia-1.6B",
        help="Hugging Face repository ID (e.g., nari-labs/Dia-1.6B).",
    )
    parser.add_argument(
        "--local-paths", action="store_true", help="Load model from local config and checkpoint files."
    )
    parser.add_argument(
        "--config", type=str, help="Path to local config.json file (required if --local-paths is set)."
    )
    parser.add_argument(
        "--checkpoint", type=str, help="Path to local model checkpoint .pth file (required if --local-paths is set)."
    )
    parser.add_argument(
        "--audio-prompt", type=str, default=None, help="Path to an optional audio prompt WAV file for voice cloning."
    )

    gen_group = parser.add_argument_group("Generation Parameters")
    gen_group.add_argument(
        "--max-tokens",
        type=int,
        default=None,
        help="Maximum number of audio tokens to generate (defaults to config value).",
    )
    gen_group.add_argument(
        "--cfg-scale", type=float, default=3.0, help="Classifier-Free Guidance scale (default: 3.0)."
    )
    # The help text previously advertised 0.7 while the real default is 1.3.
    gen_group.add_argument(
        "--temperature", type=float, default=1.3, help="Sampling temperature (higher is more random, default: 1.3)."
    )
    gen_group.add_argument("--top-p", type=float, default=0.95, help="Nucleus sampling probability (default: 0.95).")

    infra_group = parser.add_argument_group("Infrastructure")
    infra_group.add_argument("--seed", type=int, default=None, help="Random seed for reproducibility.")
    infra_group.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device to run inference on (e.g., 'cuda', 'cpu', default: auto).",
    )
    return parser


def _load_model(args: argparse.Namespace, device: torch.device):
    """Load the Dia model from local files or the Hub; exit with status 1 on failure."""
    if args.local_paths:
        print(f"Loading from local paths: config='{args.config}', checkpoint='{args.checkpoint}'")
        try:
            return Dia.from_local(args.config, args.checkpoint, device=device)
        except Exception as e:
            print(f"Error loading local model: {e}")
            # SystemExit instead of the site-provided exit(), which is not
            # guaranteed to exist (e.g. under ``python -S`` or frozen builds).
            raise SystemExit(1) from e

    print(f"Loading from Hugging Face Hub: repo_id='{args.repo_id}'")
    try:
        return Dia.from_pretrained(args.repo_id, device=device)
    except Exception as e:
        print(f"Error loading model from Hub: {e}")
        raise SystemExit(1) from e


def main():
    """Command-line entry point: parse arguments, load Dia, generate and save audio.

    Reads its arguments from ``sys.argv`` via argparse. Invalid argument
    combinations or missing input files are reported through
    ``parser.error`` (usage message on stderr, exit status 2). Failures
    while loading the model, generating audio or writing the output file
    print an error message and exit with status 1.

    Raises:
        SystemExit: On argument errors (status 2) or on model loading,
            generation or saving failures (status 1).
    """
    parser = _build_parser()
    args = parser.parse_args()

    # Validation for local paths
    if args.local_paths:
        if not args.config:
            parser.error("--config is required when --local-paths is set.")
        if not args.checkpoint:
            parser.error("--checkpoint is required when --local-paths is set.")
        if not os.path.exists(args.config):
            parser.error(f"Config file not found: {args.config}")
        if not os.path.exists(args.checkpoint):
            parser.error(f"Checkpoint file not found: {args.checkpoint}")

    # Fail fast on a missing audio prompt instead of discovering it only
    # after the (expensive) model load.
    if args.audio_prompt is not None and not os.path.exists(args.audio_prompt):
        parser.error(f"Audio prompt file not found: {args.audio_prompt}")

    # Set seed if provided
    if args.seed is not None:
        set_seed(args.seed)
        print(f"Using random seed: {args.seed}")

    # Determine device
    device = torch.device(args.device)
    print(f"Using device: {device}")

    # Load model
    print("Loading model...")
    model = _load_model(args, device)
    print("Model loaded.")

    # Generate audio
    print("Generating audio...")
    sample_rate = 44100  # Default assumption
    try:
        output_audio = model.generate(
            text=args.text,
            audio_prompt=args.audio_prompt,
            max_tokens=args.max_tokens,
            cfg_scale=args.cfg_scale,
            temperature=args.temperature,
            top_p=args.top_p,
        )
        print("Audio generation complete.")

        print(f"Saving audio to {args.output}...")
        # dirname is "" for a bare filename; fall back to the current directory.
        os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True)
        sf.write(args.output, output_audio, sample_rate)
        print(f"Audio successfully saved to {args.output}")
    except Exception as e:
        print(f"Error during audio generation or saving: {e}")
        raise SystemExit(1) from e
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()