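"""Command-line interface for generating speech audio with the Dia model.

Illustrative usage (the script name, example text, and paths below are
assumptions added for documentation, not part of the original file):

    python cli.py "[S1] Hello there. [S2] Hi, how are you?" --output out.wav --seed 42
"""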
import argparse
import os
import random
import numpy as np
import soundfile as sf
import torch
from dia.model import Dia


def set_seed(seed: int):
    """Sets the random seed for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    # Ensure deterministic behavior for cuDNN (if used)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def main():
    parser = argparse.ArgumentParser(description="Generate audio using the Dia model.")

    parser.add_argument("text", type=str, help="Input text for speech generation.")
    parser.add_argument(
        "--output", type=str, required=True, help="Path to save the generated audio file (e.g., output.wav)."
    )
    parser.add_argument(
        "--repo-id",
        type=str,
        default="nari-labs/Dia-1.6B",
        help="Hugging Face repository ID (e.g., nari-labs/Dia-1.6B).",
    )
    parser.add_argument(
        "--local-paths", action="store_true", help="Load model from local config and checkpoint files."
    )
    parser.add_argument(
        "--config", type=str, help="Path to local config.json file (required if --local-paths is set)."
    )
    parser.add_argument(
        "--checkpoint", type=str, help="Path to local model checkpoint .pth file (required if --local-paths is set)."
    )
    parser.add_argument(
        "--audio-prompt", type=str, default=None, help="Path to an optional audio prompt WAV file for voice cloning."
    )

    gen_group = parser.add_argument_group("Generation Parameters")
    gen_group.add_argument(
        "--max-tokens",
        type=int,
        default=None,
        help="Maximum number of audio tokens to generate (defaults to config value).",
    )
    gen_group.add_argument(
        "--cfg-scale", type=float, default=3.0, help="Classifier-Free Guidance scale (default: 3.0)."
    )
    gen_group.add_argument(
        "--temperature", type=float, default=1.3, help="Sampling temperature (higher is more random, default: 1.3)."
    )
gen_group.add_argument("--top-p", type=float, default=0.95, help="Nucleus sampling probability (default: 0.95).")
infra_group = parser.add_argument_group("Infrastructure")
infra_group.add_argument("--seed", type=int, default=None, help="Random seed for reproducibility.")
infra_group.add_argument(
"--device",
type=str,
default="cuda" if torch.cuda.is_available() else "cpu",
help="Device to run inference on (e.g., 'cuda', 'cpu', default: auto).",
)
args = parser.parse_args()
    # Validation for local paths
    if args.local_paths:
        if not args.config:
            parser.error("--config is required when --local-paths is set.")
        if not args.checkpoint:
            parser.error("--checkpoint is required when --local-paths is set.")
        if not os.path.exists(args.config):
            parser.error(f"Config file not found: {args.config}")
        if not os.path.exists(args.checkpoint):
            parser.error(f"Checkpoint file not found: {args.checkpoint}")

    # Set seed if provided
    if args.seed is not None:
        set_seed(args.seed)
        print(f"Using random seed: {args.seed}")

    # Determine device
    device = torch.device(args.device)
    print(f"Using device: {device}")
    # Load model
    print("Loading model...")
    if args.local_paths:
        print(f"Loading from local paths: config='{args.config}', checkpoint='{args.checkpoint}'")
        try:
            model = Dia.from_local(args.config, args.checkpoint, device=device)
        except Exception as e:
            print(f"Error loading local model: {e}")
            exit(1)
    else:
        print(f"Loading from Hugging Face Hub: repo_id='{args.repo_id}'")
        try:
            model = Dia.from_pretrained(args.repo_id, device=device)
        except Exception as e:
            print(f"Error loading model from Hub: {e}")
            exit(1)
    print("Model loaded.")
    # Generate audio
    print("Generating audio...")
    try:
        sample_rate = 44100  # Default assumption

        output_audio = model.generate(
            text=args.text,
            audio_prompt=args.audio_prompt,
            max_tokens=args.max_tokens,
            cfg_scale=args.cfg_scale,
            temperature=args.temperature,
            top_p=args.top_p,
        )
        print("Audio generation complete.")
        print(f"Saving audio to {args.output}...")

        os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True)

        sf.write(args.output, output_audio, sample_rate)
        print(f"Audio successfully saved to {args.output}")
    except Exception as e:
        print(f"Error during audio generation or saving: {e}")
        exit(1)


if __name__ == "__main__":
    main()