File size: 3,459 Bytes
030c851
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import os
import time
import logging
import torch
import numpy as np
import soundfile as sf
from pathlib import Path
from typing import Optional

from dia.model import Dia

# Configure logging; module-level logger per stdlib convention
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Constants
DEFAULT_SAMPLE_RATE = 44100  # Hz; sample rate used when writing WAV output
DEFAULT_MODEL_NAME = "nari-labs/Dia-1.6B"  # HuggingFace model id loaded by _get_model()

# Global model instance (lazy loaded on first _get_model() call)
_model: Optional[Dia] = None


def _get_model() -> Dia:
    """Return the process-wide Dia model, loading it on first use.

    The loaded model is cached in the module-level ``_model`` so that
    subsequent calls reuse the same instance instead of reloading weights.

    Raises:
        Exception: re-raised unchanged if model loading fails.
    """
    global _model
    if _model is not None:
        return _model
    logger.info("Loading Dia model...")
    try:
        _model = Dia.from_pretrained(DEFAULT_MODEL_NAME, compute_dtype="float16")
        logger.info("Dia model loaded successfully")
    except Exception as e:
        logger.error(f"Error loading Dia model: {e}", exc_info=True)
        raise
    return _model


def generate_speech(text: str, language: str = "zh") -> str:
    """Synthesize speech for *text* with the Dia TTS model.

    Args:
        text: Input text to synthesize.
        language: Language code; unused by the Dia model, kept for API
            compatibility with callers.

    Returns:
        Path to the generated WAV file. On failure or empty model output,
        the fallback path "temp/outputs/dummy.wav" is returned instead.
        NOTE(review): nothing here creates that dummy file — confirm a
        caller or deployment step provides it.
    """
    logger.info("Generating speech for text length: %d", len(text))

    try:
        # Create output directory if it doesn't exist
        os.makedirs("temp/outputs", exist_ok=True)

        # Nanosecond timestamp avoids filename collisions: the previous
        # int(time.time()) was second-granular, so two requests within the
        # same second would overwrite each other's output file.
        output_path = f"temp/outputs/output_{time.time_ns()}.wav"

        model = _get_model()

        start_time = time.time()
        with torch.inference_mode():
            output_audio_np = model.generate(
                text,
                max_tokens=None,  # use default from model config
                cfg_scale=3.0,
                temperature=1.3,
                top_p=0.95,
                cfg_filter_top_k=35,
                use_torch_compile=False,  # keep False for stability
                verbose=False,
            )
        logger.info("Generation finished in %.2f seconds", time.time() - start_time)

        # Guard clause: no audio produced — fall back to the dummy file.
        if output_audio_np is None:
            logger.warning("Generation produced no output, returning dummy audio")
            return "temp/outputs/dummy.wav"

        # Apply a slight slowdown (0.94x speed) for better quality by
        # linearly interpolating onto a longer sample grid.
        # Assumes mono 1-D audio — np.interp cannot handle multi-channel
        # arrays (TODO confirm Dia's output shape).
        speed_factor = 0.94
        original_len = len(output_audio_np)
        target_len = int(original_len / speed_factor)

        if target_len != original_len and target_len > 0:
            x_original = np.arange(original_len)
            x_resampled = np.linspace(0, original_len - 1, target_len)
            output_audio_np = np.interp(x_resampled, x_original, output_audio_np)
            logger.info(
                "Resampled audio from %d to %d samples for %.2fx speed",
                original_len,
                target_len,
                speed_factor,
            )

        # Save the audio file
        sf.write(output_path, output_audio_np, DEFAULT_SAMPLE_RATE)
        logger.info("Audio saved to %s", output_path)
        return output_path

    except Exception as e:
        # Deliberate best-effort boundary: log the failure and return the
        # dummy path so the calling API keeps responding instead of crashing.
        logger.error("TTS generation failed: %s", e, exc_info=True)
        return "temp/outputs/dummy.wav"