dia-tts-server / dia/model.py
Michael Hu
Initial check-in of the Dia TTS server
ac5de5b
# dia/model.py
import os
import logging
import time
import dac # Keep this import name
import numpy as np
import torch
import torchaudio
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file # <<< ADDED Import for safetensors
from .audio import audio_to_codebook, codebook_to_audio
from .config import (
DiaConfig,
) # Assuming this is the Pydantic config for model structure
from .layers import DiaModel, KVCache # Assuming these are the nn.Module definitions
# --- Module-level logger ---
logger = logging.getLogger(__name__)

# Import-time sanity check on the installed 'dac' package. Dia's original
# loading path relies on both `dac.utils.download()` and `dac.DAC`; warn early
# here so a broken install is visible in the logs before any model loading.
# (_load_dac_model performs the same check and raises if it fails.)
if not (
    hasattr(dac, "utils")
    and hasattr(dac.utils, "download")
    and hasattr(dac, "DAC")
):
    logger.warning(
        "The imported 'dac' module does not appear to have the 'utils.download' structure expected by the original Dia code."
    )
    logger.warning(
        "Ensure 'descript-audio-codec' is installed correctly (pip install descript-audio-codec)."
    )
def _sample_next_token(
logits_BCxV: torch.Tensor,
temperature: float,
top_p: float,
use_cfg_filter: bool,
cfg_filter_top_k: int | None = None,
) -> torch.Tensor:
"""Samples the next token based on logits, temperature, and top_p."""
if temperature == 0.0:
# Greedy sampling
return torch.argmax(logits_BCxV, dim=-1)
# Apply temperature scaling
logits_BCxV = logits_BCxV / temperature
# Apply CFG Top-K filtering (optional)
if use_cfg_filter and cfg_filter_top_k is not None:
# Get top K values and indices
_, top_k_indices_BCxV = torch.topk(logits_BCxV, k=cfg_filter_top_k, dim=-1)
# Create a mask to keep only top K logits
mask = torch.ones_like(logits_BCxV, dtype=torch.bool)
mask.scatter_(
dim=-1, index=top_k_indices_BCxV, value=False
) # Set top K positions to False (don't mask)
# Mask out logits not in the top K
logits_BCxV = logits_BCxV.masked_fill(mask, -torch.inf)
# Apply Top-P (Nucleus) sampling
if top_p < 1.0:
# Convert logits to probabilities
probs_BCxV = torch.softmax(logits_BCxV, dim=-1)
# Sort probabilities in descending order
sorted_probs_BCxV, sorted_indices_BCxV = torch.sort(
probs_BCxV, dim=-1, descending=True
)
# Calculate cumulative probabilities
cumulative_probs_BCxV = torch.cumsum(sorted_probs_BCxV, dim=-1)
# Create mask for tokens to remove (those exceeding top_p threshold)
sorted_indices_to_remove_BCxV = cumulative_probs_BCxV > top_p
# Shift the mask: keep the first token that crosses the threshold
sorted_indices_to_remove_BCxV[..., 1:] = sorted_indices_to_remove_BCxV[
..., :-1
].clone()
sorted_indices_to_remove_BCxV[..., 0] = 0 # Always keep the most probable token
# Scatter the mask back to the original order
indices_to_remove_BCxV = torch.zeros_like(sorted_indices_to_remove_BCxV)
indices_to_remove_BCxV.scatter_(
dim=-1, index=sorted_indices_BCxV, src=sorted_indices_to_remove_BCxV
)
# Apply the mask to the logits
logits_BCxV = logits_BCxV.masked_fill(indices_to_remove_BCxV, -torch.inf)
# Calculate final probabilities after filtering
final_probs_BCxV = torch.softmax(logits_BCxV, dim=-1)
# Sample from the filtered distribution
# multinomial expects probabilities for each item in the batch
sampled_indices_BC = torch.multinomial(
final_probs_BCxV, num_samples=1
) # Shape [B*C, 1]
sampled_indices_C = sampled_indices_BC.squeeze(
-1
) # Shape [B*C] -> should be [C] if input was [C,V]
return sampled_indices_C
class Dia:
    """
    Main class for the Dia Text-to-Speech model, handling loading and generation.
    """

    def __init__(self, config: DiaConfig, device: torch.device = torch.device("cuda")):
        """
        Initializes the Dia model structure based on the provided configuration.
        Does not load weights here.

        Args:
            config: The DiaConfig object defining model parameters.
            device: The torch device (e.g., 'cuda', 'cpu') the model should eventually run on.
                Note: The model is instantiated but not moved to the device here.
        """
        super().__init__()
        logger.info(
            f"Initializing Dia model structure with config version: {config.version}"
        )
        self.config = config
        # Only remembered here; load_model_from_files moves the model to this device.
        self.target_device = device
        self.model = DiaModel(config)
        self.dac_model = None  # Populated by _load_dac_model().
        logger.info("Dia model structure initialized.")

    @classmethod
    def load_model_from_files(
        cls,
        config_path: str,
        weights_path: str,
        device: torch.device = torch.device("cuda"),
    ) -> "Dia":
        """
        Loads the Dia model from local configuration and weights files.
        Handles both .pth and .safetensors weight formats.

        Args:
            config_path: Path to the configuration JSON file (e.g., 'config.json').
            weights_path: Path to the model weights file (e.g., 'model.pth' or 'model.safetensors').
            device: The torch device ('cuda', 'cpu', etc.) to load the model onto.

        Returns:
            An instance of the Dia model loaded with weights and set to eval mode.

        Raises:
            FileNotFoundError: If the config or weights file is not found.
            ValueError: If the weights file format is unsupported.
            RuntimeError: If there is an error loading the config, weights, or DAC model.
        """
        logger.info("Loading Dia model from local files:")
        logger.info(f" Config: {config_path}")
        logger.info(f" Weights: {weights_path}")
        logger.info(f" Target Device: {device}")

        # 1. Load configuration. FileNotFoundError must propagate as documented,
        # so it is handled before (and outside) the generic RuntimeError wrapper.
        try:
            config = DiaConfig.load(config_path)
        except FileNotFoundError:
            logger.error(f"Configuration file not found at {config_path}")
            raise
        except Exception as e:
            logger.error(
                f"Error loading or validating configuration from {config_path}: {e}",
                exc_info=True,
            )
            raise RuntimeError(
                f"Failed to load configuration from {config_path}"
            ) from e
        if config is None:
            # DiaConfig.load returns None on FileNotFoundError.
            logger.error(f"Configuration file not found at {config_path}")
            raise FileNotFoundError(f"Configuration file not found at {config_path}")
        logger.info("Configuration loaded successfully.")

        # 2. Instantiate the model structure (device is stored, not applied yet).
        dia_instance = cls(config, device)

        # 3. Load weights. The format is validated before the try block so the
        # documented ValueError is not converted into a RuntimeError.
        weights_filename = os.path.basename(weights_path)
        is_safetensors = weights_filename.endswith(".safetensors")
        if not is_safetensors and not weights_filename.endswith(".pth"):
            logger.error(
                f"Unsupported weights file format: {weights_filename}. Expected .pth or .safetensors."
            )
            raise ValueError(f"Unsupported weights file format: {weights_filename}")

        try:
            logger.info(f"Loading weights from: {weights_path}")
            if is_safetensors:
                logger.info(
                    "Detected .safetensors format. Loading using safetensors library."
                )
                state_dict = load_file(weights_path, device=str(device))
                logger.info("Safetensors weights loaded.")
            else:
                logger.info("Detected .pth format. Loading using torch.load.")
                # NOTE(security): torch.load unpickles arbitrary Python objects;
                # only load trusted checkpoints (or pass weights_only=True on
                # torch versions that support it).
                state_dict = torch.load(weights_path, map_location=device)
                logger.info("PyTorch weights (.pth) loaded.")

            logger.info("Applying loaded weights to the model structure...")
            # strict=True surfaces any key mismatch between checkpoint and model.
            dia_instance.model.load_state_dict(state_dict, strict=True)
            logger.info("Weights applied successfully.")
        except FileNotFoundError as e:
            logger.error(f"Weights file not found at {weights_path}")
            raise FileNotFoundError(f"Weights file not found at {weights_path}") from e
        except Exception as e:
            logger.error(
                f"Error loading weights from {weights_path}: {e}", exc_info=True
            )
            raise RuntimeError(f"Error loading weights from {weights_path}") from e

        # 4. Move to device and switch to inference mode.
        logger.info(f"Moving model to device: {device}...")
        dia_instance.model.to(device)
        logger.info("Setting model to evaluation mode...")
        dia_instance.model.eval()

        # 5. Load the companion DAC codec (logs its own progress/errors).
        logger.info("Loading associated DAC model...")
        dia_instance._load_dac_model()
        logger.info("Dia model fully loaded and ready.")
        return dia_instance

    def _load_dac_model(self):
        """Loads the Descript Audio Codec (DAC) model onto ``self.target_device``.

        No-op if already loaded. Uses ``dac.utils.download()`` (as the original
        Dia code does) to obtain the model path, then ``dac.DAC.load``.

        Raises:
            RuntimeError: If the ``dac`` module lacks the expected API or loading fails.
        """
        if self.dac_model is not None:
            logger.info("DAC model already loaded.")
            return

        if (
            not hasattr(dac, "utils")
            or not hasattr(dac.utils, "download")
            or not hasattr(dac, "DAC")
        ):
            logger.error(
                "Imported 'dac' module structure mismatch. Expected 'dac.utils.download()' and 'dac.DAC'."
            )
            logger.error(
                "Ensure 'descript-audio-codec' is installed correctly via pip."
            )
            raise RuntimeError(
                "Failed to load DAC model: required functions/structure missing from 'dac' module."
            )

        try:
            logger.info("Downloading/finding DAC model using dac.utils.download()...")
            dac_model_path = dac.utils.download()
            logger.info(f"DAC model path determined: {dac_model_path}")
            logger.info("Loading DAC model from path...")
            dac_model = dac.DAC.load(dac_model_path).to(self.target_device)
            logger.info("DAC model loaded successfully.")
        except AttributeError as ae:
            logger.error(
                f"AttributeError loading DAC model: '{ae}'. The installed 'descript-audio-codec' version might be incompatible with Dia's original code which expects 'dac.utils.download()'."
            )
            logger.error(
                "Please check for specific version requirements of 'descript-audio-codec' for Dia, or potential installation issues."
            )
            raise RuntimeError(
                "Failed to load DAC model due to incompatible library version or structure"
            ) from ae
        except Exception as e:
            logger.error(f"General error loading DAC model: {e}", exc_info=True)
            raise RuntimeError("Failed to load DAC model") from e
        self.dac_model = dac_model

    def _create_attn_mask(
        self,
        q_padding_mask_1d: torch.Tensor,
        k_padding_mask_1d: torch.Tensor,
        is_causal: bool = False,
    ) -> torch.Tensor:
        """
        Creates the attention mask (self or cross) based on padding masks.

        Attention is allowed between two non-padding tokens OR between two
        padding tokens, never across that boundary (JAX segment-ID style).

        Args:
            q_padding_mask_1d: Boolean tensor [Batch, SeqLenQ] where True indicates non-padding.
            k_padding_mask_1d: Boolean tensor [Batch, SeqLenK] where True indicates non-padding.
            is_causal: If True, additionally applies a lower-triangular causal mask.

        Returns:
            Boolean mask [Batch, 1, SeqLenQ, SeqLenK] (True = attention allowed).
        """
        B1, Tq = q_padding_mask_1d.shape
        B2, Tk = k_padding_mask_1d.shape
        if B1 != B2:
            logger.warning(
                f"Query ({B1}) and key ({B2}) batch dimensions do not match in _create_attn_mask"
            )
        assert B1 == B2, "Query and key batch dimensions must match"

        p_mask_q = q_padding_mask_1d.unsqueeze(2)  # [B, Tq, 1]
        p_mask_k = k_padding_mask_1d.unsqueeze(1)  # [B, 1, Tk]
        non_pad_attends_non_pad = p_mask_q & p_mask_k
        pad_attends_pad = (~p_mask_q) & (~p_mask_k)
        mask = non_pad_attends_non_pad | pad_attends_pad  # [B, Tq, Tk]

        if is_causal:
            if Tq != Tk:
                logger.warning(f"Causal mask requested but Tq ({Tq}) != Tk ({Tk})")
            assert (
                Tq == Tk
            ), "Causal mask requires query and key sequence lengths to be equal"
            # Build on the inputs' device so the '&' below never mixes devices.
            causal_mask_2d = torch.tril(
                torch.ones(
                    (Tq, Tk), dtype=torch.bool, device=q_padding_mask_1d.device
                )
            )
            mask = mask & causal_mask_2d

        return mask.unsqueeze(1)  # Head dim for broadcasting: [B, 1, Tq, Tk]

    def _prepare_text_input(
        self, text: str
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Encodes the text prompt into byte tokens, pads/truncates to the
        configured text length, and builds positions and masks.

        "[S1]" / "[S2]" markers are replaced by the bytes 0x01 / 0x02.

        Args:
            text: The input text string.

        Returns:
            Tuple containing:
                - src_tokens: Padded token IDs [1, SeqLen].
                - src_positions: Position IDs [1, SeqLen].
                - src_padding_mask: Boolean mask (True=non-pad) [1, SeqLen].
                - enc_self_attn_mask: Attention mask for encoder [1, 1, SeqLen, SeqLen].
        """
        text_pad_value = self.config.data.text_pad_value
        max_len = self.config.data.text_length
        logger.debug(
            f"Preparing text input. Max length: {max_len}, Pad value: {text_pad_value}"
        )
        logger.debug(f"Original text (start): '{text[:100]}...'")

        byte_text = text.encode("utf-8")
        replaced_bytes = byte_text.replace(b"[S1]", b"\x01").replace(b"[S2]", b"\x02")
        text_tokens = list(replaced_bytes)
        logger.debug(
            f"Text tokens after byte conversion (first 10): {text_tokens[:10]}"
        )

        current_len = len(text_tokens)
        padding_needed = max_len - current_len
        if padding_needed <= 0:
            if current_len > max_len:
                logger.warning(
                    f"Input text length ({current_len}) exceeds max length ({max_len}). Truncating."
                )
            padded_text_np = np.array(text_tokens[:max_len], dtype=np.uint8)
        else:
            logger.debug(f"Padding text input with {padding_needed} pad tokens.")
            padded_text_np = np.pad(
                text_tokens,
                (0, padding_needed),
                mode="constant",
                constant_values=text_pad_value,
            ).astype(np.uint8)

        src_tokens = (
            torch.from_numpy(padded_text_np)
            .to(torch.long)
            .to(self.target_device)
            .unsqueeze(0)
        )
        src_positions = (
            torch.arange(max_len, device=self.target_device).to(torch.long).unsqueeze(0)
        )
        src_padding_mask = src_tokens != text_pad_value  # True where not padding.
        enc_self_attn_mask = self._create_attn_mask(
            src_padding_mask, src_padding_mask, is_causal=False
        )

        logger.debug(f"Prepared src_tokens shape: {src_tokens.shape}")
        logger.debug(f"Prepared src_positions shape: {src_positions.shape}")
        logger.debug(
            f"Prepared src_padding_mask shape: {src_padding_mask.shape} (True means non-padding)"
        )
        logger.debug(f"Prepared enc_self_attn_mask shape: {enc_self_attn_mask.shape}")
        return src_tokens, src_positions, src_padding_mask, enc_self_attn_mask

    def _load_audio_prompt(self, audio_prompt_path: str) -> torch.Tensor:
        """Load an audio prompt as a mono 44.1 kHz batch [1, 1, T] on the target device.

        torchaudio.load returns [channels, frames]. DAC's encoder takes a single
        input channel, so multi-channel (e.g. stereo) prompts are averaged to mono.
        """
        waveform, sr = torchaudio.load(audio_prompt_path)
        logger.debug(f"Loaded audio prompt: {waveform.shape}, Sample Rate: {sr}")
        if sr != 44100:
            logger.info(f"Resampling audio prompt from {sr}Hz to 44100Hz")
            waveform = torchaudio.functional.resample(waveform, sr, 44100)
        if waveform.ndim == 1:
            waveform = waveform.unsqueeze(0)  # Add channel dim.
        if waveform.shape[0] > 1:
            logger.info(
                f"Downmixing {waveform.shape[0]}-channel audio prompt to mono."
            )
            waveform = waveform.mean(dim=0, keepdim=True)
        return waveform.unsqueeze(0).to(self.target_device)  # Add batch dim.

    @torch.inference_mode()
    def generate(
        self,
        text: str,
        max_tokens: int | None = None,
        cfg_scale: float = 3.0,
        temperature: float = 1.3,
        top_p: float = 0.95,
        use_cfg_filter: bool = True,
        use_torch_compile: bool = False,  # Default to False for broader compatibility
        cfg_filter_top_k: int = 35,
        audio_prompt_path: str | None = None,
    ) -> np.ndarray:
        """
        Generates audio waveform from a text prompt, optionally conditioned on an audio prompt.

        Args:
            text: The input text string. For dialogue, use [S1]/[S2] markers.
                For voice cloning, prepend the transcript of the audio prompt.
            max_tokens: Maximum number of audio tokens (frames) to generate. Defaults to config value.
            cfg_scale: Classifier-Free Guidance scale. Higher values increase adherence to text.
            temperature: Sampling temperature. Higher values increase randomness.
            top_p: Nucleus sampling probability. Filters vocabulary during sampling.
            use_cfg_filter: Whether to apply Top-K filtering based on CFG logits.
            use_torch_compile: If True, attempts to compile the decoder step for potential speedup.
            cfg_filter_top_k: The 'K' value for CFG Top-K filtering.
            audio_prompt_path: Path to an audio file (e.g., WAV, MP3) to use as a voice prompt/clone target.

        Returns:
            A 1D NumPy array containing the generated audio waveform (empty if nothing was generated).

        Raises:
            RuntimeError: If the audio prompt cannot be processed or the DAC model is not loaded.
        """
        start_time_gen = time.time()
        logger.info("Starting audio generation...")
        logger.info(f" Text (start): '{text[:100]}...'")
        logger.info(
            f" Max tokens: {max_tokens if max_tokens is not None else 'Model Default'}"
        )
        logger.info(f" CFG Scale: {cfg_scale}")
        logger.info(f" Temperature: {temperature}")
        logger.info(f" Top P: {top_p}")
        logger.info(f" Use CFG Filter: {use_cfg_filter}, Top K: {cfg_filter_top_k}")
        logger.info(
            f" Audio Prompt: {audio_prompt_path if audio_prompt_path else 'None'}"
        )
        logger.info(f" Use torch.compile: {use_torch_compile}")
        logger.info(f" Target Device: {self.target_device}")

        # --- Parameter setup ---
        num_channels = self.config.data.channels
        audio_bos_value = self.config.data.audio_bos_value
        audio_eos_value = self.config.data.audio_eos_value
        audio_pad_value = self.config.data.audio_pad_value
        delay_pattern = self.config.data.delay_pattern
        effective_max_tokens = (
            max_tokens if max_tokens is not None else self.config.data.audio_length
        )
        logger.info(f" Effective max_tokens for generation: {effective_max_tokens}")

        if not isinstance(delay_pattern, list) or not delay_pattern:
            logger.warning("Delay pattern is invalid or empty. Using default [0].")
            delay_pattern = [0] * num_channels
        delay_tensor = torch.tensor(
            delay_pattern, dtype=torch.long, device=self.target_device
        )
        self.model.eval()

        # --- Conditional / unconditional text inputs (batch index 0 = uncond, 1 = cond) ---
        logger.info(
            "Preparing text inputs for conditional and unconditional generation..."
        )
        (
            cond_src_BxS,
            cond_src_positions_BxS,
            cond_src_padding_mask_BxS,
            cond_enc_self_attn_mask_Bx1xSxS,
        ) = self._prepare_text_input(text)

        # The unconditional row is the same length, filled with the text pad value.
        unc_src_BxS = torch.full_like(
            cond_src_BxS, fill_value=self.config.data.text_pad_value
        )
        src_BxS = torch.cat([unc_src_BxS, cond_src_BxS], dim=0)  # [2, S]
        src_positions_BxS = cond_src_positions_BxS.expand(2, -1)
        # Both rows share the conditional padding mask, matching the encoder mask
        # below. An all-False key mask for the unconditional row would make
        # _create_attn_mask disallow every (non-pad decoder query, key) pair,
        # leaving that row's cross-attention with nothing to attend to.
        src_padding_mask_BxS = cond_src_padding_mask_BxS.expand(2, -1)
        enc_self_attn_mask_Bx1xSxS = cond_enc_self_attn_mask_Bx1xSxS.expand(
            2, -1, -1, -1
        )
        logger.info("Text inputs prepared (batch size 2 for CFG).")

        # --- Encoder pass ---
        logger.info("Running encoder pass...")
        start_time_enc = time.time()
        encoder_out = self.model.encoder(
            x_ids=src_BxS,
            src_positions=src_positions_BxS,
            deterministic=True,  # No dropout during inference.
            attn_mask=enc_self_attn_mask_Bx1xSxS,
        )
        logger.info(
            f"Encoder pass completed in {time.time() - start_time_enc:.3f}s. Output shape: {encoder_out.shape}"
        )

        # --- Decoder KV caches ---
        logger.info("Preparing decoder inputs and KV cache...")
        start_time_kv = time.time()
        # Cross-attention K/V depend only on the encoder output: computed once.
        decoder_cross_attention_cache: list[KVCache] = (
            self.model.decoder.precompute_cross_attention_kv(
                effective_max_tokens, encoder_out, src_positions_BxS
            )
        )
        logger.debug(
            f"Precomputed cross-attention KV cache for {len(decoder_cross_attention_cache)} layers."
        )
        # NOTE(review): each self-attention cache is sized effective_max_tokens,
        # but with an audio prompt the loop below still runs up to
        # effective_max_tokens steps *after* the prefill, so the total length can
        # exceed this capacity. Confirm how KVCache behaves when full.
        decoder_self_attention_cache: list[KVCache] = [
            KVCache(
                self.config.model.decoder.gqa_query_heads,
                effective_max_tokens,
                self.config.model.decoder.gqa_head_dim,
                self.target_device,
            )
            for _ in range(self.model.decoder.num_layers)
        ]
        logger.debug(
            f"Initialized self-attention KV cache for {len(decoder_self_attention_cache)} layers."
        )
        logger.info(
            f"KV cache preparation completed in {time.time() - start_time_kv:.3f}s."
        )

        # Decoder history starts with a single BOS frame: [2, 1, C].
        generated_tokens_history = torch.full(
            (2, 1, num_channels),
            fill_value=audio_bos_value,
            dtype=torch.long,
            device=self.target_device,
        )
        logger.debug(f"Initial decoder input (BOS): {generated_tokens_history.shape}")
        current_step_index = 0  # Absolute position of the first loop step.

        # --- Optional audio prompt: encode with DAC and prefill the self-attention cache ---
        if audio_prompt_path is not None:
            logger.info("Processing audio prompt for prefilling...")
            start_time_prompt = time.time()
            try:
                audio_prompt_waveform = self._load_audio_prompt(audio_prompt_path)

                logger.info("Encoding audio prompt to codes using DAC...")
                if self.dac_model is None:
                    raise RuntimeError(
                        "DAC model not loaded, required for audio prompt."
                    )
                audio_prompt_codes = audio_to_codebook(
                    self.dac_model, audio_prompt_waveform, data_config=self.config.data
                )
                logger.info(
                    f"Encoded audio prompt to codes: {audio_prompt_codes.shape}"
                )

                generated_tokens_history = torch.cat(
                    [generated_tokens_history, audio_prompt_codes.expand(2, -1, -1)],
                    dim=1,
                )  # [2, 1 + T_codes, C]
                logger.debug(
                    f"Decoder input history after prompt concatenation: {generated_tokens_history.shape}"
                )

                prefill_len = generated_tokens_history.shape[1]
                logger.info(f"Prefilling KV cache with length {prefill_len}...")
                prefill_tgt_pos = (
                    torch.arange(prefill_len, device=self.target_device)
                    .unsqueeze(0)
                    .expand(2, -1)
                )
                prefill_tgt_padding_mask = (
                    generated_tokens_history != audio_pad_value
                ).any(dim=2)  # [2, T_prefill]
                prefill_self_attn_mask = self._create_attn_mask(
                    prefill_tgt_padding_mask,
                    prefill_tgt_padding_mask,
                    is_causal=True,
                )
                prefill_cross_attn_mask = self._create_attn_mask(
                    prefill_tgt_padding_mask,
                    src_padding_mask_BxS,
                    is_causal=False,
                )
                # Logits are discarded; this pass only fills the self-attention cache.
                self.model.decoder.forward(
                    tgt_ids_BxTxC=generated_tokens_history,
                    encoder_out=encoder_out,
                    tgt_positions=prefill_tgt_pos,
                    src_positions=src_positions_BxS,
                    deterministic=True,
                    self_attn_mask=prefill_self_attn_mask,
                    cross_attn_mask=prefill_cross_attn_mask,
                    self_attention_cache=decoder_self_attention_cache,
                    cross_attention_cache=decoder_cross_attention_cache,
                )
                # NOTE(review): the first loop step re-feeds the last prompt token
                # (already cached at position prefill_len - 1) at position
                # prefill_len. Confirm this hand-off matches the decoder/KVCache.
                current_step_index = prefill_len
                logger.info(
                    f"KV cache prefilled in {time.time() - start_time_prompt:.3f}s. Next step index: {current_step_index}"
                )
            except Exception as e:
                logger.error(f"Error processing audio prompt: {e}", exc_info=True)
                raise RuntimeError("Failed to process audio prompt") from e

        # --- Autoregressive generation ---
        logger.info("Starting autoregressive generation loop...")
        start_time_loop = time.time()
        eos_detected_channel_0 = False
        eos_countdown = -1  # > 0 while finishing the delay pattern after EOS.
        extra_steps_after_eos = 30  # Steps to let delayed channels emit EOS/PAD.

        num_steps_to_generate = effective_max_tokens
        newly_generated_tokens = torch.full(
            (2, num_steps_to_generate, num_channels),
            fill_value=audio_pad_value,
            dtype=torch.long,
            device=self.target_device,
        )
        logger.debug(
            f"Allocated tensor for newly generated tokens: {newly_generated_tokens.shape}"
        )

        decode_step_fn = self.model.decoder.decode_step
        if use_torch_compile:
            logger.info("Compiling decoder step function with torch.compile...")
            try:
                decode_step_fn = torch.compile(decode_step_fn, mode="reduce-overhead")
                logger.info("Decoder step function compiled.")
            except Exception as e:
                logger.warning(
                    f"torch.compile failed: {e}. Using eager mode.", exc_info=True
                )

        # Generated tokens are treated as non-padding, so the per-step cross
        # attention mask is fixed: [2, 1, 1, S].
        step_tgt_padding_mask = torch.ones(
            (2, 1), dtype=torch.bool, device=self.target_device
        )
        step_decoder_cross_attn_mask = self._create_attn_mask(
            step_tgt_padding_mask,
            src_padding_mask_BxS,
            is_causal=False,
        )
        # Loop-invariant BOS token used to hold delayed channels at BOS.
        bos_token = torch.tensor(
            audio_bos_value, device=self.target_device, dtype=torch.long
        )

        steps_taken = 0
        for step_offset in range(num_steps_to_generate):
            current_absolute_step = current_step_index + step_offset

            # Feed the previous frame: the last history frame on the first step,
            # otherwise the frame produced by the previous iteration.
            if step_offset == 0:
                input_token_ids = generated_tokens_history[:, -1, :].unsqueeze(1)
            else:
                input_token_ids = newly_generated_tokens[
                    :, step_offset - 1, :
                ].unsqueeze(1)
            tgt_pos_Bx1 = torch.full(
                (2, 1),
                fill_value=current_absolute_step,
                dtype=torch.long,
                device=self.target_device,
            )

            # Single-step decode; causality comes from the KV cache, so no self mask.
            logits_Bx1xCxV, new_self_kv_cache_list = decode_step_fn(
                tgt_ids_Bx1xC=input_token_ids,
                tgt_pos_Bx1=tgt_pos_Bx1,
                encoder_out=encoder_out,
                self_attn_mask=None,
                cross_attn_mask=step_decoder_cross_attn_mask,
                self_attention_cache=decoder_self_attention_cache,
                cross_attention_cache=decoder_cross_attention_cache,
            )

            for i, layer_cache in enumerate(decoder_self_attention_cache):
                if (
                    new_self_kv_cache_list
                    and i < len(new_self_kv_cache_list)
                    and new_self_kv_cache_list[i] is not None
                ):
                    # Entry i is a (k, v) pair for this step.
                    layer_cache.update_cache(
                        new_self_kv_cache_list[i][0], new_self_kv_cache_list[i][1]
                    )
                else:
                    logger.warning(
                        f"Missing KV cache update for layer {i} at step {current_absolute_step}"
                    )

            # Classifier-free guidance over [C, V] logits (row 0 uncond, row 1 cond).
            logits_last_BxCxV = logits_Bx1xCxV.squeeze(1)
            uncond_logits_CxV = logits_last_BxCxV[0, :, :]
            cond_logits_CxV = logits_last_BxCxV[1, :, :]
            # The arithmetic yields a fresh tensor, so masking it in place below
            # does not touch the decoder's output.
            cfg_logits_CxV = cond_logits_CxV + cfg_scale * (
                cond_logits_CxV - uncond_logits_CxV
            )
            cfg_logits_CxV[:, audio_pad_value] = -torch.inf  # Never sample PAD.
            cfg_logits_CxV[:, audio_bos_value] = -torch.inf  # Never sample BOS.

            pred_C = _sample_next_token(
                cfg_logits_CxV.float(),  # float32 for sampling stability.
                temperature=temperature,
                top_p=top_p,
                use_cfg_filter=use_cfg_filter,
                cfg_filter_top_k=cfg_filter_top_k,
            )

            # Without a prompt, each channel stays at BOS until its delay elapses.
            if audio_prompt_path is None:
                pred_C = torch.where(step_offset < delay_tensor, bos_token, pred_C)

            newly_generated_tokens[:, step_offset, :] = pred_C.unsqueeze(0).expand(
                2, -1
            )
            steps_taken += 1

            # --- EOS handling: after channel 0 emits EOS, force each channel to
            # EOS at its delay offset and to PAD afterwards, then stop. ---
            if not eos_detected_channel_0 and pred_C[0] == audio_eos_value:
                logger.info(
                    f"EOS token detected in channel 0 at step {current_absolute_step}. Starting countdown."
                )
                eos_detected_channel_0 = True
                eos_countdown = extra_steps_after_eos

            if eos_countdown > 0:
                step_after_eos = extra_steps_after_eos - eos_countdown
                logger.debug(
                    f"EOS countdown: {eos_countdown}, Step after EOS: {step_after_eos}"
                )
                for channel, delay in enumerate(delay_pattern):
                    if step_after_eos == delay:
                        logger.debug(
                            f" Forcing EOS in channel {channel} at step {current_absolute_step}"
                        )
                        newly_generated_tokens[:, step_offset, channel] = audio_eos_value
                    elif step_after_eos > delay:
                        logger.debug(
                            f" Forcing PAD in channel {channel} at step {current_absolute_step}"
                        )
                        newly_generated_tokens[:, step_offset, channel] = audio_pad_value
                eos_countdown -= 1
                if eos_countdown == 0:
                    logger.info(
                        f"EOS countdown finished at step {current_absolute_step}. Stopping generation."
                    )
                    break
        else:
            logger.info(
                f"Reached max generation steps ({num_steps_to_generate}). Stopping."
            )

        logger.info(
            f"Autoregressive loop finished after {steps_taken} steps in {time.time() - start_time_loop:.3f}s."
        )

        # Conditional row (index 1), only the steps actually produced: [T_generated, C].
        final_new_codes = newly_generated_tokens[1, :steps_taken, :]
        logger.info(f"Extracted newly generated codes shape: {final_new_codes.shape}")

        # --- Decode codes to audio with DAC ---
        logger.info("Converting generated codes to audio using DAC...")
        start_time_decode = time.time()
        if self.dac_model is None:
            raise RuntimeError("DAC model not loaded, required for audio decoding.")

        generated_codes_CxT = final_new_codes.transpose(0, 1)  # [C, T_generated]
        if generated_codes_CxT.numel() == 0:
            logger.warning("No new codes were generated. Returning empty audio.")
            return np.array([], dtype=np.float32)

        # codebook_to_audio undoes the delay pattern and runs the DAC decoder.
        audio_waveform = codebook_to_audio(
            generated_codes_CxT,
            self.dac_model,
            delay_pattern,
            B=1,
            T=generated_codes_CxT.shape[1],
            C=num_channels,
        )
        final_audio_np = audio_waveform.squeeze().cpu().numpy()
        logger.info(
            f"Audio decoding completed in {time.time() - start_time_decode:.3f}s. Output shape: {final_audio_np.shape}"
        )
        logger.info(f"Total generation time: {time.time() - start_time_gen:.3f}s")
        return final_audio_np