Commit 4aa0f34 · Parent(s): 9577cb2

update to faster inference

Changed files:
- app.py +17 -31
- dia/audio.py +27 -104
- dia/config.py +17 -26
- dia/layers.py +106 -337
- dia/model.py +314 -257
- dia/state.py +234 -0
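The headline change is loading Dia with a bfloat16 compute dtype; the loading call below is copied from the app.py hunk further down, while the shape of the generate() call is an assumption pieced together from the keyword arguments visible in the same hunk, with placeholder sampling values rather than the Space's defaults:

from dia.model import Dia

# from_pretrained call as in the app.py diff below; compute_dtype="bfloat16" is the new argument.
model = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="bfloat16")

# Assumed call shape; the keyword names appear in the diff, the values are illustrative.
audio = model.generate(
    "[S1] Example dialogue text.",
    cfg_scale=3.0,
    temperature=1.3,
    top_p=0.95,
    cfg_filter_top_k=35,
    use_torch_compile=False,
    audio_prompt=None,
)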
 
    	
app.py CHANGED

@@ -1,9 +1,7 @@
-import argparse
 import tempfile
 import time
 from pathlib import Path
 from typing import Optional, Tuple
-import spaces
 
 import gradio as gr
 import numpy as np
@@ -12,40 +10,17 @@ import torch
 
 from dia.model import Dia
 
-# --- Global Setup ---
-parser = argparse.ArgumentParser(description="Gradio interface for Nari TTS")
-parser.add_argument(
-    "--device", type=str, default=None, help="Force device (e.g., 'cuda', 'mps', 'cpu')"
-)
-parser.add_argument("--share", action="store_true", help="Enable Gradio sharing")
-
-args = parser.parse_args()
-
-
-# Determine device
-if args.device:
-    device = torch.device(args.device)
-elif torch.cuda.is_available():
-    device = torch.device("cuda")
-# Simplified MPS check for broader compatibility
-elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
-    # Basic check is usually sufficient, detailed check can be problematic
-    device = torch.device("mps")
-else:
-    device = torch.device("cpu")
-
-print(f"Using device: {device}")
 
 # Load Nari model and config
 print("Loading Nari model...")
 try:
     # Use the function from inference.py
-    model = Dia.from_pretrained("nari-labs/Dia-1.6B")
+    model = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="bfloat16")
 except Exception as e:
     print(f"Error loading Nari model: {e}")
     raise
 
-
+
 def run_inference(
     text_input: str,
     audio_prompt_input: Optional[Tuple[int, np.ndarray]],
@@ -60,7 +35,7 @@ def run_inference(
     Runs Nari inference using the globally loaded model and provided inputs.
     Uses temporary files for text and audio prompt compatibility with inference.generate.
     """
-
+    global model, device  # Access global model, config, device
 
     if not text_input or text_input.isspace():
         raise gr.Error("Text input cannot be empty.")
@@ -146,10 +121,9 @@ def run_inference(
                 cfg_scale=cfg_scale,
                 temperature=temperature,
                 top_p=top_p,
-                use_cfg_filter=True,
                 cfg_filter_top_k=cfg_filter_top_k,  # Pass the value here
                 use_torch_compile=False,  # Keep False for Gradio stability
-
+                audio_prompt=prompt_path_for_generate,
             )
 
         end_time = time.time()
@@ -192,6 +166,16 @@ def run_inference(
                 f"Audio conversion successful. Final shape: {output_audio[1].shape}, Sample Rate: {output_sr}"
             )
 
+            # Explicitly convert to int16 to prevent Gradio warning
+            if (
+                output_audio[1].dtype == np.float32
+                or output_audio[1].dtype == np.float64
+            ):
+                audio_for_gradio = np.clip(output_audio[1], -1.0, 1.0)
+                audio_for_gradio = (audio_for_gradio * 32767).astype(np.int16)
+                output_audio = (output_sr, audio_for_gradio)
+                print("Converted audio to int16 for Gradio output.")
+
         else:
             print("\nGeneration finished, but no valid tokens were produced.")
             # Return default silence
@@ -383,8 +367,10 @@ with gr.Blocks(css=css) as demo:
     else:
         gr.Markdown("_(No examples configured or example prompt file missing)_")
 
-
 # --- Launch the App ---
 if __name__ == "__main__":
     print("Launching Gradio interface...")
+
+    # set `GRADIO_SERVER_NAME`, `GRADIO_SERVER_PORT` env vars to override default values
+    # use `GRADIO_SERVER_NAME=0.0.0.0` for Docker
    demo.launch()
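The int16 conversion this commit adds to run_inference is plain NumPy and can be tried on its own; a minimal sketch, with a made-up sine wave standing in for the model's float output (the array and sample rate here are illustrative only):

import numpy as np

# Stand-in for the float audio the model returns (values roughly in [-1, 1]).
sample_rate = 44100
float_audio = np.sin(2 * np.pi * 440 * np.arange(sample_rate) / sample_rate).astype(np.float32)

# Same steps as the added block above: clip to the valid range, scale to the
# int16 range, and hand Gradio a (sample_rate, int16_array) tuple so it does
# not warn about float audio.
audio_for_gradio = np.clip(float_audio, -1.0, 1.0)
audio_for_gradio = (audio_for_gradio * 32767).astype(np.int16)
output_audio = (sample_rate, audio_for_gradio)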
    	
dia/audio.py CHANGED

@@ -2,10 +2,10 @@ import typing as tp
 
 import torch
 
-from .config import DataConfig
 
-
-def build_delay_indices(B: int, T: int, C: int, delay_pattern: tp.List[int]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
+def build_delay_indices(
+    B: int, T: int, C: int, delay_pattern: tp.List[int]
+) -> tp.Tuple[torch.Tensor, torch.Tensor]:
     """
     Precompute (t_idx_BxTxC, indices_BTCx3) so that out[t, c] = in[t - delay[c], c].
     Negative t_idx => BOS; t_idx >= T => PAD.
@@ -69,7 +69,9 @@ def apply_audio_delay(
 
     # Equivalent of tf.gather_nd using advanced indexing
    # Ensure indices are long type if not already (build_delay_indices should handle this)
-    gathered_flat = audio_BxTxC[indices_BTCx3[:, 0], indices_BTCx3[:, 1], indices_BTCx3[:, 2]]
+    gathered_flat = audio_BxTxC[
+        indices_BTCx3[:, 0], indices_BTCx3[:, 1], indices_BTCx3[:, 2]
+    ]
     gathered_BxTxC = gathered_flat.view(audio_BxTxC.shape)
 
     # Create masks on the correct device
@@ -82,65 +84,16 @@ def apply_audio_delay(
 
     # If mask_bos, BOS; else if mask_pad, PAD; else original gather
     # All tensors should now be on the same device
-    result_BxTxC = torch.where(mask_bos, bos_tensor, torch.where(mask_pad, pad_tensor, gathered_BxTxC))
-
-    return result_BxTxC
-
-
-@torch.no_grad()
-@torch.inference_mode()
-def audio_to_codebook(
-    model,
-    input_values,
-    data_config: DataConfig,
-    padding_mask=None,
-    sample_rate=44100,
-):
-    """
-    Encodes the input audio waveform into discrete codes.
-
-    Args:
-        model: The model to use for encoding.
-        input_values (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
-            Float values of the input audio waveform.
-        padding_mask (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
-            Padding mask used to pad the `input_values`.
-        sample_rate (`int`, *optional*) :
-            Signal sampling_rate
-
-    Returns:
-        A list of frames containing the discrete encoded codes for the input audio waveform, along with rescaling
-        factors for each chunk when `normalize` is True. Each frames is a tuple `(codebook, scale)`, with
-        `codebook` of shape `[batch_size, num_codebooks, frames]`.
-        Scale is not used here.
-
-    """
-    audio_data = model.preprocess(input_values, sample_rate)
-
-    if padding_mask is None:
-        padding_mask = torch.ones_like(input_values).bool()
-
-    _, encoded_frame, _, _, _ = model.encode(audio_data, n_quantizers=None)  # 1, C, T
-    seq_length = encoded_frame.shape[2]
-
-    t_idx_BxTxC, indices_BTCx3 = build_delay_indices(
-        B=1,
-        T=seq_length,
-        C=data_config.channels,
-        delay_pattern=data_config.delay_pattern,
+    result_BxTxC = torch.where(
+        mask_bos, bos_tensor, torch.where(mask_pad, pad_tensor, gathered_BxTxC)
     )
 
-    encoded_frame = apply_audio_delay(
-        audio_BxTxC=encoded_frame.transpose(1, 2),  # 1, T, C
-        pad_value=data_config.audio_pad_value,
-        bos_value=data_config.audio_bos_value,
-        precomp=(t_idx_BxTxC, indices_BTCx3),
-    )
-
-    return encoded_frame
+    return result_BxTxC
 
 
-def build_revert_indices(B: int, T: int, C: int, delay_pattern: tp.List[int]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
+def build_revert_indices(
+    B: int, T: int, C: int, delay_pattern: tp.List[int]
+) -> tp.Tuple[torch.Tensor, torch.Tensor]:
     """
     Precompute indices for the revert operation using PyTorch.
 
@@ -162,8 +115,12 @@ def build_revert_indices(B: int, T: int, C: int, delay_pattern: tp.List[int]) ->
         t_idx_BT1 + delay_arr.view(1, 1, C),
         torch.tensor(T - 1, device=device),
     )
-    b_idx_BxTxC = torch.broadcast_to(torch.arange(B, device=device).view(B, 1, 1), [B, T, C])
-    c_idx_BxTxC = torch.broadcast_to(torch.arange(C, device=device).view(1, 1, C), [B, T, C])
+    b_idx_BxTxC = torch.broadcast_to(
+        torch.arange(B, device=device).view(B, 1, 1), [B, T, C]
+    )
+    c_idx_BxTxC = torch.broadcast_to(
+        torch.arange(C, device=device).view(1, 1, C), [B, T, C]
+    )
 
     indices_BTCx3 = torch.stack(
         [
@@ -205,15 +162,21 @@ def revert_audio_delay(
     indices_BTCx3 = indices_BTCx3.to(device)
 
     # Using PyTorch advanced indexing (equivalent to tf.gather_nd or np equivalent)
-    gathered_flat = audio_BxTxC[indices_BTCx3[:, 0], indices_BTCx3[:, 1], indices_BTCx3[:, 2]]
-    gathered_BxTxC = gathered_flat.view(audio_BxTxC.size())  # Use .size() for robust reshaping
+    gathered_flat = audio_BxTxC[
+        indices_BTCx3[:, 0], indices_BTCx3[:, 1], indices_BTCx3[:, 2]
+    ]
+    gathered_BxTxC = gathered_flat.view(
+        audio_BxTxC.size()
+    )  # Use .size() for robust reshaping
 
     # Create pad_tensor on the correct device
     pad_tensor = torch.tensor(pad_value, dtype=audio_BxTxC.dtype, device=device)
     # Create T tensor on the correct device for comparison
     T_tensor = torch.tensor(T, device=device)
 
-    result_BxTxC = torch.where(t_idx_BxTxC >= T_tensor, pad_tensor, gathered_BxTxC)  # Changed np.where to torch.where
+    result_BxTxC = torch.where(
+        t_idx_BxTxC >= T_tensor, pad_tensor, gathered_BxTxC
+    )  # Changed np.where to torch.where
 
     return result_BxTxC
 
@@ -238,43 +201,3 @@ def decode(
     except Exception as e:
         print(f"Error in decode method: {str(e)}")
         raise
-
-
-def codebook_to_audio(generated_codes: torch.Tensor, model, delay_pattern, B=1, T=2600, C=9):
-    """Process a single codebook file to generate audio"""
-    # Remove BOS token
-    generated_codes = generated_codes[:, 1:]
-
-    if generated_codes.shape[1] > T:
-        generated_codes = generated_codes[:, :T]
-
-    seq_length = generated_codes.shape[1]
-
-    # Build revert indices
-    t_idx_BxTxC, indices_BTCx3 = build_revert_indices(B=B, T=seq_length, C=C, delay_pattern=delay_pattern)
-
-    # Transpose and add batch dimension
-    audio_BxTxC = generated_codes.transpose(1, 0).unsqueeze(0)
-    reverted_codebook = revert_audio_delay(
-        audio_BxTxC=audio_BxTxC,
-        pad_value=0,
-        precomp=(t_idx_BxTxC, indices_BTCx3),
-        T=seq_length,
-    )
-    reverted_codebook = reverted_codebook[:, :-30, :]
-
-    codebook = reverted_codebook.transpose(1, 2)
-
-    min_valid_index = 0
-    max_valid_index = 1023
-    invalid_mask = (codebook < min_valid_index) | (codebook > max_valid_index)
-
-    num_invalid = torch.sum(invalid_mask).item()
-    if num_invalid > 0:
-        print(f"Warning: Clamping {num_invalid} indices outside range [{min_valid_index}, {max_valid_index}] to 0.")
-
-    # Set invalid values to 0 (modify the tensor in-place)
-    codebook[invalid_mask] = 0
-    audio_array = decode(model, codebook)
-
-    return audio_array
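The reflowed gather in apply_audio_delay and revert_audio_delay is ordinary advanced indexing standing in for tf.gather_nd; a minimal sketch of the same pattern on a toy tensor (the shapes and delay pattern are made up, and out-of-range time indices are simply clamped here rather than mapped to BOS/PAD as the real helpers do):

import torch

B, T, C = 1, 6, 3
delay_pattern = [0, 1, 2]  # toy per-channel delays

audio_BxTxC = torch.arange(B * T * C).view(B, T, C)

# Build (batch, time, channel) triples so that out[b, t, c] reads in[b, t - delay[c], c].
t_idx = torch.arange(T).view(1, T, 1) - torch.tensor(delay_pattern).view(1, 1, C)
t_idx = t_idx.clamp(0, T - 1).expand(B, T, C)
b_idx = torch.arange(B).view(B, 1, 1).expand(B, T, C)
c_idx = torch.arange(C).view(1, 1, C).expand(B, T, C)
indices_BTCx3 = torch.stack(
    [b_idx.reshape(-1), t_idx.reshape(-1), c_idx.reshape(-1)], dim=1
)

# The tf.gather_nd equivalent used in the diff: index with the three columns,
# then reshape the flat result back to (B, T, C).
gathered_flat = audio_BxTxC[indices_BTCx3[:, 0], indices_BTCx3[:, 1], indices_BTCx3[:, 2]]
gathered_BxTxC = gathered_flat.view(audio_BxTxC.shape)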
    	
dia/config.py CHANGED

@@ -33,14 +33,20 @@ class DataConfig(BaseModel, frozen=True):
         delay_pattern: List of delay values for each audio channel.
     """
 
-    text_length: Annotated[int, BeforeValidator(lambda x: (x + 127) // 128 * 128)] = Field(gt=0, multiple_of=128)
-    audio_length: Annotated[int, BeforeValidator(lambda x: (x + 127) // 128 * 128)] = Field(gt=0, multiple_of=128)
+    text_length: Annotated[int, BeforeValidator(lambda x: (x + 127) // 128 * 128)] = (
+        Field(gt=0, multiple_of=128)
+    )
+    audio_length: Annotated[int, BeforeValidator(lambda x: (x + 127) // 128 * 128)] = (
+        Field(gt=0, multiple_of=128)
+    )
     channels: int = Field(default=9, gt=0, multiple_of=1)
     text_pad_value: int = Field(default=0)
     audio_eos_value: int = Field(default=1024)
     audio_pad_value: int = Field(default=1025)
     audio_bos_value: int = Field(default=1026)
-    delay_pattern: list[Annotated[int, Field(ge=0)]] = Field(default_factory=lambda: [0, 8, 9, 10, 11, 12, 13, 14, 15])
+    delay_pattern: list[Annotated[int, Field(ge=0)]] = Field(
+        default_factory=lambda: [0, 8, 9, 10, 11, 12, 13, 14, 15]
+    )
 
     def __hash__(self) -> int:
         """Generate a hash based on all fields of the config."""
@@ -67,8 +73,6 @@ class EncoderConfig(BaseModel, frozen=True):
         n_hidden: Hidden dimension size in the MLP layers.
         n_head: Number of attention heads.
         head_dim: Dimension per attention head.
-        mlp_activations: List of activation functions for the MLP layers.
-        use_pre_norm: Whether to use pre-normalization (LayerNorm before attention/MLP).
     """
 
     n_layer: int = Field(gt=0)
@@ -76,8 +80,6 @@ class EncoderConfig(BaseModel, frozen=True):
     n_hidden: int = Field(gt=0)
     n_head: int = Field(gt=0)
     head_dim: int = Field(gt=0)
-    mlp_activations: list[str] = Field(default=["silu", "linear"])
-    use_pre_norm: bool = Field(default=False)
 
 
 class DecoderConfig(BaseModel, frozen=True):
@@ -92,8 +94,6 @@ class DecoderConfig(BaseModel, frozen=True):
         gqa_head_dim: Dimension per query head for grouped-query self-attention.
         cross_query_heads: Number of query heads for cross-attention.
         cross_head_dim: Dimension per cross-attention head.
-        mlp_activations: List of activation functions for the MLP layers.
-        use_pre_norm: Whether to use pre-normalization.
     """
 
     n_layer: int = Field(gt=0)
@@ -104,8 +104,6 @@ class DecoderConfig(BaseModel, frozen=True):
     gqa_head_dim: int = Field(gt=0)
     cross_query_heads: int = Field(gt=0)
     cross_head_dim: int = Field(gt=0)
-    mlp_activations: list[str] = Field(default=["silu", "linear"])
-    use_pre_norm: bool = Field(default=False)
 
 
 class ModelConfig(BaseModel, frozen=True):
@@ -130,24 +128,16 @@ class ModelConfig(BaseModel, frozen=True):
     dropout: float = Field(default=0.0, ge=0.0, lt=1.0)
     normalization_layer_epsilon: float = Field(default=1.0e-5, ge=0.0)
     weight_dtype: str = Field(default="float32", description="Weight precision")
-    rope_min_timescale: int = Field(default=1, description="Timescale For global Attention")
-    rope_max_timescale: int = Field(default=10_000, description="Timescale For global Attention")
+    rope_min_timescale: int = Field(
+        default=1, description="Timescale For global Attention"
+    )
+    rope_max_timescale: int = Field(
+        default=10_000, description="Timescale For global Attention"
+    )
 
 
 class TrainingConfig(BaseModel, frozen=True):
-    """
-
-    Note: This configuration currently only includes precision settings.
-    Other training parameters (like batch size, learning rate, optimizer settings)
-    are assumed to be handled externally.
-
-    Attributes:
-        dtype: Data type for activations during training (e.g., "bfloat16", "float32").
-        logits_dot_in_fp32: Whether to compute the final logits dot product in fp32 for stability.
-    """
-
-    dtype: str = Field(default="bfloat16", description="Activation precision")
-    logits_dot_in_fp32: bool = Field(default=False)
+    pass
 
 
 class DiaConfig(BaseModel, frozen=True):
@@ -164,6 +154,7 @@ class DiaConfig(BaseModel, frozen=True):
 
     version: str = Field(default="1.0")
     model: ModelConfig
+    # TODO: remove training. this is just for backwards-compatability
     training: TrainingConfig
     data: DataConfig
 
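The text_length and audio_length fields reformatted above keep their BeforeValidator, which rounds any raw value up to the next multiple of 128 before the multiple_of=128 constraint runs; a small standalone sketch of that pydantic pattern (the Example model is hypothetical, only the validator expression mirrors the diff):

from typing import Annotated

from pydantic import BaseModel, BeforeValidator, Field


class Example(BaseModel):
    # The validator rounds up first, so gt=0 / multiple_of=128 pass for any positive input.
    length: Annotated[int, BeforeValidator(lambda x: (x + 127) // 128 * 128)] = Field(
        gt=0, multiple_of=128
    )


print(Example(length=1000).length)  # 1024
print(Example(length=1024).length)  # 1024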
    	
dia/layers.py CHANGED

@@ -1,5 +1,3 @@
-from typing import Any
-
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -7,26 +5,13 @@ from torch import Tensor
 from torch.nn import RMSNorm
 
 from .config import DiaConfig
 
 
 def _normalize_axes(axes: tuple[int, ...], ndim: int) -> tuple[int, ...]:
     return tuple(ax if ax >= 0 else ndim + ax for ax in axes)
 
 
-def _str_to_dtype(dtype_str: str) -> torch.dtype | None:
-    # Allow None for default behavior
-    if dtype_str is None or dtype_str.lower() == "none":
-        return None
-    if dtype_str == "float32":
-        return torch.float32
-    elif dtype_str == "float16":
-        return torch.float16
-    elif dtype_str == "bfloat16":
-        return torch.bfloat16
-    else:
-        raise ValueError(f"Unsupported dtype string: {dtype_str}")
-
-
 class DenseGeneral(nn.Module):
     """
     PyTorch equivalent of flax.linen.DenseGeneral with shapes defined at init.
@@ -50,7 +35,6 @@ class DenseGeneral(nn.Module):
         in_shapes: tuple[int, ...],
         out_features: tuple[int, ...],
         axis: tuple[int, ...] = (-1,),
-        dtype: torch.dtype | None = None,
         weight_dtype: torch.dtype | None = None,
         device: torch.device | None = None,
     ):
@@ -58,7 +42,6 @@ class DenseGeneral(nn.Module):
         self.in_shapes = in_shapes
         self.out_features = out_features
         self.axis = axis
-        self.dtype = dtype
         self.kernel_shape = self.in_shapes + self.out_features
 
         factory_kwargs = {"device": device, "dtype": weight_dtype}
@@ -70,95 +53,44 @@ class DenseGeneral(nn.Module):
         kernel_contract_axes = tuple(range(len(norm_axis)))
 
         output = torch.tensordot(
-            inputs.to(self.dtype),
-            self.weight,
             dims=(norm_axis, kernel_contract_axes),
         ).to(inputs.dtype)
         return output
 
 
-def get_activation_fn(activation_string: str) -> nn.Module:  # Return Module instance
-    """Maps activation string to PyTorch activation function module."""
-    if activation_string == "gelu":
-        return nn.GELU()
-    elif activation_string == "relu":
-        return nn.ReLU()
-    elif activation_string == "silu" or activation_string == "swish":
-        return nn.SiLU()
-    elif activation_string == "linear":
-        return nn.Identity()
-    else:
-        raise ValueError(f"Unsupported activation function: {activation_string}")
-
-
 class MlpBlock(nn.Module):
     """MLP block using DenseGeneral."""
 
     def __init__(
-        self,
-        config: DiaConfig,
-        embed_dim: int,
-        intermediate_dim: int,
-        dropout_rate: float,
-        activations: list[str] = ["silu", "linear"],
-        use_pre_norm: bool = False,
     ):
         super().__init__()
-        self.use_pre_norm = use_pre_norm
-        num_activations = len(activations)
-        compute_dtype = _str_to_dtype(config.training.dtype)
-        weight_dtype = _str_to_dtype(config.model.weight_dtype)
         self.dtype = compute_dtype
-        # Assume default device for now, could be passed in config
-
-        if use_pre_norm:
-            self.pre_norm = RMSNorm(
-                embed_dim,
-                eps=config.model.normalization_layer_epsilon,
-                dtype=torch.float32,
-            )
 
         self.wi_fused = DenseGeneral(
             in_shapes=(embed_dim,),
-            out_features=(
-                num_activations,
-                intermediate_dim,
-            ),
             axis=(-1,),
-            dtype=compute_dtype,
-            weight_dtype=weight_dtype,
         )
 
-        self.activation_fn_0 = get_activation_fn(activations[0])  # silu
-        self.activation_fn_1 = get_activation_fn(activations[1])  # linear
-
-        self.dropout = nn.Dropout(dropout_rate)
-
-        # Output layer using DenseGeneral
         self.wo = DenseGeneral(
             in_shapes=(intermediate_dim,),
             out_features=(embed_dim,),
         
             
                        axis=(-1,),
         
     | 
| 142 | 
         
            -
                         
     | 
| 143 | 
         
            -
                        weight_dtype=weight_dtype,
         
     | 
| 144 | 
         
             
                    )
         
     | 
| 145 | 
         | 
| 146 | 
         
            -
                def forward(self, x: torch.Tensor 
     | 
| 147 | 
         
             
                    """Forward pass."""
         
     | 
| 148 | 
         
            -
                    if self.use_pre_norm and hasattr(self, "pre_norm"):
         
     | 
| 149 | 
         
            -
                        x = self.pre_norm(x)
         
     | 
| 150 | 
         
            -
             
     | 
| 151 | 
         
             
                    fused_x = self.wi_fused(x)
         
     | 
| 152 | 
         | 
| 153 | 
         
            -
                     
     | 
| 154 | 
         
            -
                     
     | 
| 155 | 
         
            -
             
     | 
| 156 | 
         
            -
                    gate = self.activation_fn_0(gate_input)
         
     | 
| 157 | 
         
            -
                    up = self.activation_fn_1(up_input)
         
     | 
| 158 | 
         
            -
                    hidden = torch.mul(gate, up).to(self.dtype)
         
     | 
| 159 | 
         | 
| 160 | 
         
            -
                     
     | 
| 161 | 
         
            -
                        hidden = self.dropout(hidden)
         
     | 
| 162 | 
         | 
| 163 | 
         
             
                    output = self.wo(hidden)
         
     | 
| 164 | 
         
             
                    return output
         
     | 
| 
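Note: the MlpBlock in the hunk above fuses the gate and up projections of a SiLU-gated MLP into a single `wi_fused` projection and multiplies the two activations before the output projection `wo`. A minimal, self-contained sketch of that pattern, with plain `nn.Linear` layers standing in for `DenseGeneral` (class and variable names below are illustrative, not the repository's):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GatedMlpSketch(nn.Module):
    def __init__(self, embed_dim: int, intermediate_dim: int):
        super().__init__()
        # Fused projection: emits gate and up halves in one matmul.
        self.wi_fused = nn.Linear(embed_dim, 2 * intermediate_dim, bias=False)
        self.wo = nn.Linear(intermediate_dim, embed_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate, up = self.wi_fused(x).chunk(2, dim=-1)  # split the fused output
        return self.wo(F.silu(gate) * up)             # SiLU(gate) * up, then project back

x = torch.randn(2, 16, 64)
print(GatedMlpSketch(64, 256)(x).shape)  # torch.Size([2, 16, 64])
```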
@@ -207,37 +139,6 @@ class RotaryEmbedding(nn.Module):
         return torch.cat((first_part, second_part), dim=-1)


-class KVCache:
-    def __init__(self, num_heads, max_len, head_dim, device, k=None, v=None):
-        self.k = torch.zeros((2, num_heads, max_len, head_dim), device=device) if k is None else k
-        self.v = torch.zeros((2, num_heads, max_len, head_dim), device=device) if v is None else v
-        self.current_idx = 0
-        self.max_len = max_len
-
-    def get_kv_for_attention(self, current_k, current_v):
-        if self.current_idx == 0:
-            return current_k, current_v
-        else:
-            past_k = self.k[:, :, : self.current_idx, :]
-            past_v = self.v[:, :, : self.current_idx, :]
-            attn_k = torch.cat((past_k, current_k), dim=2)
-            attn_v = torch.cat((past_v, current_v), dim=2)
-            return attn_k, attn_v
-
-    def update_cache(self, k, v):
-        assert self.current_idx < self.max_len
-        self.k[:, :, self.current_idx : self.current_idx + 1, :] = k
-        self.v[:, :, self.current_idx : self.current_idx + 1, :] = v
-        self.current_idx += 1
-
-    def prefill_kv(self, k, v):
-        prefill_len = k.shape[2]
-        assert prefill_len <= self.max_len
-        self.k[:, :, :prefill_len, :] = k
-        self.v[:, :, :prefill_len, :] = v
-        self.current_idx = prefill_len
-
-
 class Attention(nn.Module):
     """Attention using DenseGeneral."""
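Note: the deleted `KVCache` class preallocated key/value buffers of shape `(2, num_heads, max_len, head_dim)` and either filled a prefix (`prefill_kv`) or wrote one step at a time (`update_cache`); this commit moves that responsibility into `dia/state.py`. A rough sketch of the same pattern, with illustrative names and a batch-first layout (not the repository's API):

```python
import torch

class SimpleKVCache:
    def __init__(self, batch: int, num_heads: int, max_len: int, head_dim: int, device="cpu"):
        self.k = torch.zeros(batch, num_heads, max_len, head_dim, device=device)
        self.v = torch.zeros(batch, num_heads, max_len, head_dim, device=device)
        self.current_idx = 0
        self.max_len = max_len

    def update(self, k_step: torch.Tensor, v_step: torch.Tensor):
        # k_step / v_step: (B, H, 1, D) for a single autoregressive step.
        assert self.current_idx < self.max_len
        self.k[:, :, self.current_idx : self.current_idx + 1] = k_step
        self.v[:, :, self.current_idx : self.current_idx + 1] = v_step
        self.current_idx += 1
        # Return only the filled prefix for attention.
        return self.k[:, :, : self.current_idx], self.v[:, :, : self.current_idx]

cache = SimpleKVCache(batch=2, num_heads=4, max_len=8, head_dim=16)
k1, v1 = cache.update(torch.randn(2, 4, 1, 16), torch.randn(2, 4, 1, 16))
print(k1.shape)  # torch.Size([2, 4, 1, 16])
```

Preallocating and slicing this way avoids re-concatenating the whole history with `torch.cat` at every decode step, which is the main point of the refactor.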
@@ -249,7 +150,7 @@ class Attention(nn.Module):
         num_query_heads: int,
         num_kv_heads: int,
         head_dim: int,
-
         is_cross_attn: bool = False,
         out_embed_dim: int | None = None,
     ):
@@ -258,13 +159,12 @@ class Attention(nn.Module):
         self.num_kv_heads = num_kv_heads
         self.head_dim = head_dim
         self.is_cross_attn = is_cross_attn
-        self.dropout_rate = dropout_rate
-        compute_dtype = _str_to_dtype(config.training.dtype)
-        weight_dtype = _str_to_dtype(config.model.weight_dtype)
         self.output_dim = out_embed_dim if out_embed_dim is not None else q_embed_dim
         self.projected_query_dim = num_query_heads * head_dim
         if num_query_heads % num_kv_heads != 0:
-            raise ValueError(
         self.num_gqa_groups = num_query_heads // num_kv_heads

         # --- Projection Layers using DenseGeneral ---
| 
         @@ -272,29 +172,25 @@ class Attention(nn.Module): 
     | 
|
| 272 | 
         
             
                        in_shapes=(q_embed_dim,),
         
     | 
| 273 | 
         
             
                        out_features=(num_query_heads, head_dim),
         
     | 
| 274 | 
         
             
                        axis=(-1,),
         
     | 
| 275 | 
         
            -
                         
     | 
| 276 | 
         
            -
                        weight_dtype=weight_dtype,
         
     | 
| 277 | 
         
             
                    )
         
     | 
| 278 | 
         
             
                    self.k_proj = DenseGeneral(
         
     | 
| 279 | 
         
             
                        in_shapes=(kv_embed_dim,),
         
     | 
| 280 | 
         
             
                        out_features=(num_kv_heads, head_dim),
         
     | 
| 281 | 
         
             
                        axis=(-1,),
         
     | 
| 282 | 
         
            -
                         
     | 
| 283 | 
         
            -
                        weight_dtype=weight_dtype,
         
     | 
| 284 | 
         
             
                    )
         
     | 
| 285 | 
         
             
                    self.v_proj = DenseGeneral(
         
     | 
| 286 | 
         
             
                        in_shapes=(kv_embed_dim,),
         
     | 
| 287 | 
         
             
                        out_features=(num_kv_heads, head_dim),
         
     | 
| 288 | 
         
             
                        axis=(-1,),
         
     | 
| 289 | 
         
            -
                         
     | 
| 290 | 
         
            -
                        weight_dtype=weight_dtype,
         
     | 
| 291 | 
         
             
                    )
         
     | 
| 292 | 
         
             
                    self.o_proj = DenseGeneral(
         
     | 
| 293 | 
         
             
                        in_shapes=(num_query_heads, head_dim),
         
     | 
| 294 | 
         
             
                        out_features=(self.output_dim,),
         
     | 
| 295 | 
         
             
                        axis=(-2, -1),
         
     | 
| 296 | 
         
            -
                         
     | 
| 297 | 
         
            -
                        weight_dtype=weight_dtype,
         
     | 
| 298 | 
         
             
                    )
         
     | 
| 299 | 
         | 
| 300 | 
         
             
                    # --- Rotary Embedding ---
         
     | 
| 
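Note: the four projections above map `(..., q_embed_dim) -> (..., num_query_heads, head_dim)` for queries, `(..., kv_embed_dim) -> (..., num_kv_heads, head_dim)` for keys/values, and `(..., num_query_heads, head_dim) -> (..., output_dim)` for the output. The same shape flow can be checked with `einsum`; the random weights below are stand-ins, not the layer's parameters:

```python
import torch

B, T, E = 2, 10, 1024
num_query_heads, num_kv_heads, head_dim = 16, 4, 128

w_q = torch.randn(E, num_query_heads, head_dim)   # q_proj kernel shape
w_k = torch.randn(E, num_kv_heads, head_dim)      # k_proj / v_proj kernel shape
w_o = torch.randn(num_query_heads, head_dim, E)   # o_proj kernel shape

x = torch.randn(B, T, E)
q = torch.einsum("bte,enh->btnh", x, w_q)    # (B, T, N, H)
k = torch.einsum("bte,ekh->btkh", x, w_k)    # (B, T, K, H)
out = torch.einsum("btnh,nhe->bte", q, w_o)  # back to (B, T, E)
print(q.shape, k.shape, out.shape)
```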
@@ -311,10 +207,11 @@ class Attention(nn.Module):
         Xkv: torch.Tensor,  # (B, S, E) S = 1 in AR generation
         q_positions: torch.Tensor,  # (B, T)
         kv_positions: torch.Tensor | None = None,  # (B, S)
-
-
         cache: KVCache | None = None,  # None in Encoder, KVCache in Decoder
-        prefill: bool = False,
     ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor] | None]:
         """
         Performs attention calculation with optional KV caching.
@@ -324,7 +221,6 @@ class Attention(nn.Module):
             Xkv: Key/Value source tensor (B, S, E). S=1 during single-step decoding for self-attn.
             q_positions: Positions for queries (B, T).
             kv_positions: Positions for keys/values (B, S). If None, uses q_positions.
-            deterministic: If True, disable dropout.
             attn_mask: Attention mask.
             cache: KVCache.
             prefill: If True, use prefill mode.
@@ -342,72 +238,51 @@ class Attention(nn.Module):
         Xq_BxTxNxH = self.rotary_emb(Xq_BxTxNxH, position=q_positions)
         Xq_BxNxTxH = Xq_BxTxNxH.transpose(1, 2)

-        # Input values into attention calculation
         attn_k: torch.Tensor | None = None
         attn_v: torch.Tensor | None = None
-        new_kv_cache: tuple[torch.Tensor, torch.Tensor] | None = None

-        # Decoder Cross Attention
         if self.is_cross_attn:
-            # Directly use cache (no need to check index)
             attn_k, attn_v = cache.k, cache.v
-            if attn_k.shape[1] != self.num_query_heads or attn_v.shape[1] != self.num_query_heads:
-                raise ValueError(
-                    f"Cross-attention cache head dimension ({attn_k.shape[1]}) "
-                    f"does not match num_query_heads ({self.num_query_heads}). "
-                    "Cache should be pre-repeated for GQA."
-                )
-        # Self Attention
         else:
             Xk_BxSxKxH = self.k_proj(Xkv)  # (B, S, K, H)
             Xv_BxSxKxH = self.v_proj(Xkv)  # (B, S, K, H)
-            Xk_BxSxKxH = self.rotary_emb(

             Xk_BxKxSxH = Xk_BxSxKxH.transpose(1, 2)  # (B, K, S, H)
             Xv_BxKxSxH = Xv_BxSxKxH.transpose(1, 2)  # (B, K, S, H)
-            # S=1 for Decode Step
-
-            if self.num_gqa_groups > 1:
-                Xk_BxNxSxH = Xk_BxKxSxH.repeat_interleave(self.num_gqa_groups, dim=1)
-                Xv_BxNxSxH = Xv_BxKxSxH.repeat_interleave(self.num_gqa_groups, dim=1)
-            else:
-                Xk_BxNxSxH = Xk_BxKxSxH
-                Xv_BxNxSxH = Xv_BxKxSxH

-            # Encoder Self Attention
             if cache is None:
-                attn_k =
-                attn_v =
-            # Decoder Self Attention
             else:
-                # In prefill mode, we fill in cache until prefill length
                 if prefill:
-                    attn_k, attn_v =
-                    cache.
-                # In decode step, we add current K/V to cache step by step
                 else:
-
-                    attn_k, attn_v = cache.get_kv_for_attention(Xk_BxNxSxH, Xv_BxNxSxH)

         attn_output = F.scaled_dot_product_attention(
             Xq_BxNxTxH,
             attn_k,
             attn_v,
             attn_mask=attn_mask,
-            dropout_p=self.dropout_rate if not deterministic else 0.0,
             scale=1.0,
         )

         attn_output = attn_output.transpose(1, 2).contiguous()  # (B, T, N, H)
         output = self.o_proj(attn_output)

-        return output.to(original_dtype)


 class EncoderLayer(nn.Module):
     """Transformer Encoder Layer using DenseGeneral."""

-    def __init__(self, config: DiaConfig):
         super().__init__()
         self.config = config
         model_config = config.model
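Note: for grouped-query self-attention, the hunk above repeats each KV head `num_gqa_groups` times and then calls `F.scaled_dot_product_attention` with `scale=1.0`. A small shape-only sketch of that step (all dimensions below are illustrative):

```python
import torch
import torch.nn.functional as F

B, T, S, H = 2, 4, 12, 64
num_query_heads, num_kv_heads = 8, 2
groups = num_query_heads // num_kv_heads

q = torch.randn(B, num_query_heads, T, H)
k = torch.randn(B, num_kv_heads, S, H)
v = torch.randn(B, num_kv_heads, S, H)

# Expand KV heads so they line up with the query heads.
k = k.repeat_interleave(groups, dim=1)  # (B, N, S, H)
v = v.repeat_interleave(groups, dim=1)

out = F.scaled_dot_product_attention(q, k, v, scale=1.0)  # scale=1.0 as in the layer above
print(out.shape)  # torch.Size([2, 8, 4, 64])
```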
| 
         @@ -420,13 +295,13 @@ class EncoderLayer(nn.Module): 
     | 
|
| 420 | 
         
             
                        dtype=torch.float32,
         
     | 
| 421 | 
         
             
                    )
         
     | 
| 422 | 
         
             
                    self.self_attention = Attention(
         
     | 
| 423 | 
         
            -
                        config 
     | 
| 424 | 
         
             
                        q_embed_dim=embed_dim,
         
     | 
| 425 | 
         
             
                        kv_embed_dim=embed_dim,
         
     | 
| 426 | 
         
             
                        num_query_heads=enc_config.n_head,
         
     | 
| 427 | 
         
             
                        num_kv_heads=enc_config.n_head,
         
     | 
| 428 | 
         
             
                        head_dim=enc_config.head_dim,
         
     | 
| 429 | 
         
            -
                         
     | 
| 430 | 
         
             
                        is_cross_attn=False,
         
     | 
| 431 | 
         
             
                        out_embed_dim=embed_dim,
         
     | 
| 432 | 
         
             
                    )
         
     | 
| 
         @@ -436,62 +311,52 @@ class EncoderLayer(nn.Module): 
     | 
|
| 436 | 
         
             
                        dtype=torch.float32,
         
     | 
| 437 | 
         
             
                    )
         
     | 
| 438 | 
         
             
                    self.mlp = MlpBlock(
         
     | 
| 439 | 
         
            -
                        config=config,
         
     | 
| 440 | 
         
             
                        embed_dim=embed_dim,
         
     | 
| 441 | 
         
             
                        intermediate_dim=enc_config.n_hidden,
         
     | 
| 442 | 
         
            -
                         
     | 
| 443 | 
         
            -
                        dropout_rate=model_config.dropout,
         
     | 
| 444 | 
         
            -
                        use_pre_norm=enc_config.use_pre_norm,
         
     | 
| 445 | 
         
             
                    )
         
     | 
| 446 | 
         
            -
                    self.dropout = nn.Dropout(model_config.dropout)
         
     | 
| 447 | 
         | 
| 448 | 
         
             
                def forward(
         
     | 
| 449 | 
         
             
                    self,
         
     | 
| 450 | 
         
             
                    x: torch.Tensor,
         
     | 
| 451 | 
         
            -
                     
     | 
| 452 | 
         
            -
                    deterministic: bool = True,
         
     | 
| 453 | 
         
            -
                    attn_mask: torch.Tensor | None = None,
         
     | 
| 454 | 
         
             
                ) -> torch.Tensor:
         
     | 
| 455 | 
         
             
                    residual = x
         
     | 
| 456 | 
         
             
                    x_norm = self.pre_sa_norm(x)
         
     | 
| 457 | 
         
            -
             
     | 
| 458 | 
         
            -
                    sa_out, _ = self.self_attention(
         
     | 
| 459 | 
         
             
                        Xq=x_norm,
         
     | 
| 460 | 
         
             
                        Xkv=x_norm,
         
     | 
| 461 | 
         
            -
                        q_positions= 
     | 
| 462 | 
         
            -
                        kv_positions= 
     | 
| 463 | 
         
            -
                         
     | 
| 464 | 
         
            -
                        attn_mask=attn_mask,
         
     | 
| 465 | 
         
             
                    )
         
     | 
| 466 | 
         
             
                    x = residual + sa_out
         
     | 
| 467 | 
         | 
| 468 | 
         
             
                    residual = x
         
     | 
| 469 | 
         
             
                    x_norm = self.post_sa_norm(x)
         
     | 
| 470 | 
         
            -
                    mlp_out = self.mlp(x_norm 
     | 
| 471 | 
         
             
                    x = residual + mlp_out
         
     | 
| 472 | 
         | 
| 473 | 
         
            -
                    if not deterministic:
         
     | 
| 474 | 
         
            -
                        x = self.dropout(x)
         
     | 
| 475 | 
         
             
                    return x
         
     | 
| 476 | 
         | 
| 477 | 
         | 
| 478 | 
         
             
            class Encoder(nn.Module):
         
     | 
| 479 | 
         
             
                """Transformer Encoder Stack using DenseGeneral."""
         
     | 
| 480 | 
         | 
| 481 | 
         
            -
                def __init__(self, config: DiaConfig):
         
     | 
| 482 | 
         
             
                    super().__init__()
         
     | 
| 483 | 
         
             
                    self.config = config
         
     | 
| 484 | 
         
             
                    model_config = config.model
         
     | 
| 485 | 
         
             
                    enc_config = config.model.encoder
         
     | 
| 486 | 
         
            -
                    compute_dtype = _str_to_dtype(config.training.dtype)
         
     | 
| 487 | 
         | 
| 488 | 
         
             
                    self.embedding = nn.Embedding(
         
     | 
| 489 | 
         
             
                        model_config.src_vocab_size,
         
     | 
| 490 | 
         
             
                        enc_config.n_embd,
         
     | 
| 491 | 
         
             
                        dtype=compute_dtype,
         
     | 
| 492 | 
         
             
                    )
         
     | 
| 493 | 
         
            -
                    self. 
     | 
| 494 | 
         
            -
             
     | 
| 
         | 
|
| 495 | 
         
             
                    self.norm = RMSNorm(
         
     | 
| 496 | 
         
             
                        enc_config.n_embd,
         
     | 
| 497 | 
         
             
                        eps=model_config.normalization_layer_epsilon,
         
     | 
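Note: the encoder layer keeps a plain pre-norm residual structure — normalize, self-attend, add the residual, then normalize, run the MLP, add the residual; the dropout and `deterministic` plumbing removed above only mattered for training. A toy sketch of that structure using stock PyTorch modules in place of the repository's `Attention`/`MlpBlock` (names and sizes are illustrative):

```python
import torch
import torch.nn as nn

class EncoderLayerSketch(nn.Module):
    def __init__(self, dim: int, heads: int):
        super().__init__()
        self.pre_sa_norm = nn.RMSNorm(dim)
        self.post_sa_norm = nn.RMSNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.SiLU(), nn.Linear(4 * dim, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.pre_sa_norm(x)
        x = x + self.attn(h, h, h, need_weights=False)[0]  # residual + self-attention
        x = x + self.mlp(self.post_sa_norm(x))             # residual + MLP
        return x

print(EncoderLayerSketch(64, 4)(torch.randn(2, 10, 64)).shape)  # torch.Size([2, 10, 64])
```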
| 
         @@ -501,32 +366,21 @@ class Encoder(nn.Module): 
     | 
|
| 501 | 
         
             
                def forward(
         
     | 
| 502 | 
         
             
                    self,
         
     | 
| 503 | 
         
             
                    x_ids: torch.Tensor,
         
     | 
| 504 | 
         
            -
                     
     | 
| 505 | 
         
            -
                    deterministic: bool = True,
         
     | 
| 506 | 
         
            -
                    attn_mask: torch.Tensor | None = None,
         
     | 
| 507 | 
         
             
                ) -> torch.Tensor:
         
     | 
| 508 | 
         
             
                    x = self.embedding(x_ids)
         
     | 
| 509 | 
         | 
| 510 | 
         
            -
                    if not deterministic:
         
     | 
| 511 | 
         
            -
                        x = self.dropout(x)
         
     | 
| 512 | 
         
            -
             
     | 
| 513 | 
         
             
                    for layer in self.layers:
         
     | 
| 514 | 
         
            -
                        x = layer(
         
     | 
| 515 | 
         
            -
             
     | 
| 516 | 
         
            -
                            src_positions=src_positions,
         
     | 
| 517 | 
         
            -
                            deterministic=deterministic,
         
     | 
| 518 | 
         
            -
                            attn_mask=attn_mask,
         
     | 
| 519 | 
         
            -
                        )
         
     | 
| 520 | 
         
             
                    x = self.norm(x)
         
     | 
| 521 | 
         
            -
                    if not deterministic:
         
     | 
| 522 | 
         
            -
                        x = self.dropout(x)
         
     | 
| 523 | 
         
             
                    return x
         
     | 
| 524 | 
         | 
| 525 | 
         | 
| 526 | 
         
             
            class DecoderLayer(nn.Module):
         
     | 
| 527 | 
         
             
                """Transformer Decoder Layer using DenseGeneral."""
         
     | 
| 528 | 
         | 
| 529 | 
         
            -
                def __init__(self, config: DiaConfig):
         
     | 
| 530 | 
         
             
                    super().__init__()
         
     | 
| 531 | 
         
             
                    self.config = config
         
     | 
| 532 | 
         
             
                    model_config = config.model
         
     | 
| 
@@ -554,13 +408,13 @@ class DecoderLayer(nn.Module):

         # Self-Attention (GQA) with Causal Masking
         self.self_attention = Attention(
-            config
             q_embed_dim=dec_embed_dim,
             kv_embed_dim=dec_embed_dim,
             num_query_heads=dec_config.gqa_query_heads,
             num_kv_heads=dec_config.kv_heads,
             head_dim=dec_config.gqa_head_dim,
-
             is_cross_attn=False,
             out_embed_dim=dec_embed_dim,
         )
@@ -572,116 +426,105 @@ class DecoderLayer(nn.Module):
             num_query_heads=dec_config.cross_query_heads,
             num_kv_heads=dec_config.cross_query_heads,
             head_dim=dec_config.cross_head_dim,
-
             is_cross_attn=True,
             out_embed_dim=dec_embed_dim,
         )
         # MLP
         self.mlp = MlpBlock(
-            config=config,
             embed_dim=dec_embed_dim,
             intermediate_dim=dec_config.n_hidden,
-
-            dropout_rate=model_config.dropout,
-            use_pre_norm=dec_config.use_pre_norm,
         )

     def forward(
         self,
         x: torch.Tensor,
-
-
-
-        deterministic: bool,
-        self_attn_mask: torch.Tensor,
-        cross_attn_mask: torch.Tensor,
-        self_attn_cache: KVCache,
-        cross_attn_cache: KVCache,
         prefill: bool = False,
     ) -> torch.Tensor:
         residual = x
         x_norm = self.pre_sa_norm(x)

-        sa_out
             Xq=x_norm,  # (2, 1, D)
             Xkv=x_norm,  # (2, 1, D)
-            q_positions=
-            kv_positions=
-
-            attn_mask=self_attn_mask,  # (2, 1, 1, S_max)
             cache=self_attn_cache,
             prefill=prefill,
         )

         x = residual + sa_out

-        # 2. Cross-Attention
         residual = x
         x_norm = self.pre_ca_norm(x)
-        ca_out
             Xq=x_norm,
-            Xkv=
-            q_positions=
-            kv_positions=
-
-            attn_mask=cross_attn_mask,
             cache=cross_attn_cache,
         )
         x = residual + ca_out

-        # 3. MLP
         residual = x
         x_norm = self.pre_mlp_norm(x)
-        mlp_out = self.mlp(x_norm
         x = residual + mlp_out

-        return x


 class Decoder(nn.Module):
     """Transformer Decoder Stack using DenseGeneral."""

-    def __init__(self, config: DiaConfig):
         super().__init__()
         self.config = config
         model_config = config.model
         dec_config = config.model.decoder
-        train_config = config.training
         data_config = config.data
-        compute_dtype = _str_to_dtype(config.training.dtype)
-        weight_dtype = _str_to_dtype(config.model.weight_dtype)
         self.num_channels = data_config.channels
         self.num_layers = dec_config.n_layer

         self.embeddings = nn.ModuleList(
             [
-                nn.Embedding(
                 for _ in range(self.num_channels)
             ]
         )
-        self.
-
         self.norm = RMSNorm(
             dec_config.n_embd,
             eps=model_config.normalization_layer_epsilon,
             dtype=torch.float32,
         )

-        # Final Logits Projection using DenseGeneral
         self.logits_dense = DenseGeneral(
             in_shapes=(dec_config.n_embd,),
             out_features=(self.num_channels, model_config.tgt_vocab_size),
             axis=(-1,),
-
-            weight_dtype=weight_dtype,
         )
-        self.logits_in_fp32 = train_config.logits_dot_in_fp32

-    def
         self,
-
-
-        src_positions: torch.Tensor | None,  # (B, S)
     ) -> list[KVCache]:
         """
         Computes the Key and Value tensors for cross-attention for each layer from the encoder output.
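Note: the decoder layer runs three pre-norm residual sub-blocks in order — causal self-attention (against its KV cache), cross-attention over the encoder output (against a precomputed cache), then the MLP. A simplified sketch of that ordering with standard modules and illustrative names (not the repository's `Attention`/`MlpBlock`):

```python
import torch
import torch.nn as nn

class DecoderLayerSketch(nn.Module):
    def __init__(self, dim: int, heads: int):
        super().__init__()
        self.pre_sa_norm = nn.RMSNorm(dim)
        self.pre_ca_norm = nn.RMSNorm(dim)
        self.pre_mlp_norm = nn.RMSNorm(dim)
        self.self_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.SiLU(), nn.Linear(4 * dim, dim))

    def forward(self, x, enc_out, causal_mask):
        h = self.pre_sa_norm(x)
        x = x + self.self_attn(h, h, h, attn_mask=causal_mask, need_weights=False)[0]
        h = self.pre_ca_norm(x)
        x = x + self.cross_attn(h, enc_out, enc_out, need_weights=False)[0]
        x = x + self.mlp(self.pre_mlp_norm(x))
        return x

T, S, D = 6, 12, 64
mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)  # True = masked
layer = DecoderLayerSketch(D, 4)
print(layer(torch.randn(2, T, D), torch.randn(2, S, D), mask).shape)  # (2, 6, 64)
```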
@@ -690,35 +533,21 @@ class Decoder(nn.Module):

         for layer in self.layers:
             cross_attn_module = layer.cross_attention
-            k_proj = cross_attn_module.k_proj(
-            v_proj = cross_attn_module.v_proj(

-            k_proj = cross_attn_module.rotary_emb(k_proj, position=
             k = k_proj.transpose(1, 2)
             v = v_proj.transpose(1, 2)

-            per_layer_kv_cache.append(
-                KVCache(
-                    cross_attn_module.num_kv_heads,
-                    max_len,
-                    cross_attn_module.head_dim,
-                    k.device,
-                    k=k,
-                    v=v,
-                )
-            )

         return per_layer_kv_cache

     def decode_step(
         self,
         tgt_ids_Bx1xC: torch.Tensor,  # [B, 1, C]
-
-        encoder_out: torch.Tensor,  # [B, S, E]
-        self_attn_mask: Any,  # None
-        cross_attn_mask: torch.Tensor,  # [B, 1, 1, S]
-        self_attention_cache: list[KVCache],
-        cross_attention_cache: list[KVCache],
     ) -> torch.Tensor:
         """
         Performs a single decoding step, managing KV caches layer by layer.
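Note: because the encoder output is fixed for the whole generation, each layer's cross-attention keys and values can be projected (and transposed to `(B, H, S, D)`) once and reused at every decode step — that is what the per-layer KV caches in the hunk above hold. A minimal sketch of that precomputation with illustrative projection layers:

```python
import torch
import torch.nn as nn

def precompute_cross_kv(encoder_out, k_proj, v_proj, num_heads: int, head_dim: int):
    # encoder_out: (B, S, E); returns per-layer K/V shaped (B, H, S, D).
    B, S, _ = encoder_out.shape
    k = k_proj(encoder_out).view(B, S, num_heads, head_dim).transpose(1, 2)
    v = v_proj(encoder_out).view(B, S, num_heads, head_dim).transpose(1, 2)
    return k, v

enc = torch.randn(2, 20, 256)
k_proj = nn.Linear(256, 8 * 32, bias=False)
v_proj = nn.Linear(256, 8 * 32, bias=False)
k, v = precompute_cross_kv(enc, k_proj, v_proj, num_heads=8, head_dim=32)
print(k.shape, v.shape)  # torch.Size([2, 8, 20, 32]) each
```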
@@ -727,7 +556,6 @@ class Decoder(nn.Module):
             A tuple containing:
             - logits_Bx1xCV: The final output logits for the current step (B, 1, C*V), cast to float32.
         """
-        assert self_attn_mask is None, "Self-attention mask should be None, kept for pattern"

         x = None
         for i in range(self.num_channels):
@@ -735,40 +563,23 @@ class Decoder(nn.Module):
             channel_embed = self.embeddings[i](channel_tokens)
             x = channel_embed if x is None else x + channel_embed

-        new_cache = []
-
         for i, layer in enumerate(self.layers):
-            self_cache =
-            cross_cache =
-            x
                 x,  # (2, 1, D)
-
-                src_positions=None,  # CA KV is already computed
-                tgt_positions=tgt_pos_Bx1,  # (2, 1)
-                deterministic=True,
-                self_attn_mask=None,
-                cross_attn_mask=cross_attn_mask,
                 self_attn_cache=self_cache,
                 cross_attn_cache=cross_cache,
             )
-            new_cache.append(new_kv_cache)

         x = self.norm(x)
         logits_Bx1xCxV = self.logits_dense(x)

-        return logits_Bx1xCxV.to(torch.float32)

     def forward(
-        self,
-        tgt_ids_BxTxC: torch.Tensor,
-        encoder_out: torch.Tensor,
-        tgt_positions: torch.Tensor,
-        src_positions: torch.Tensor,
-        deterministic: bool,
-        self_attn_mask: torch.Tensor,
-        cross_attn_mask: torch.Tensor,
-        self_attention_cache: list[KVCache],
-        cross_attention_cache: list[KVCache],
     ) -> torch.Tensor:
         """
         Forward pass for the Decoder stack, managing KV caches.
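Note: `decode_step` embeds the current token of each of the C codebook channels, sums the channel embeddings, runs the layer stack against the per-layer caches, and projects to `(B, 1, C, V)` logits in float32. A stripped-down sketch of the embedding-sum and logits steps (the layer stack and caches are elided; names and sizes are illustrative):

```python
import torch
import torch.nn as nn

num_channels, vocab, dim = 9, 1028, 64
embeddings = nn.ModuleList(nn.Embedding(vocab, dim) for _ in range(num_channels))
logits_dense = nn.Linear(dim, num_channels * vocab, bias=False)

tgt_ids_Bx1xC = torch.randint(0, vocab, (2, 1, num_channels))

# Sum the per-channel embeddings for the current step.
x = None
for i in range(num_channels):
    channel_embed = embeddings[i](tgt_ids_Bx1xC[..., i])  # (B, 1, D)
    x = channel_embed if x is None else x + channel_embed

# (the decoder layer stack with its KV caches would run here)
logits_Bx1xCxV = logits_dense(x).view(2, 1, num_channels, vocab).float()
print(logits_Bx1xCxV.shape)  # torch.Size([2, 1, 9, 1028])
```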
@@ -778,7 +589,6 @@ class Decoder(nn.Module):
             encoder_out: Output from the encoder (B, S, E).
             tgt_positions: Positions for target sequence (B, T).
             src_positions: Positions for source sequence (B, S).
-            deterministic: Disable dropout if True.
             self_attn_mask: Mask for self-attention.
             cross_attn_mask: Mask for cross-attention.
             past_key_values: List containing the self-attention KV cache for each layer
@@ -804,20 +614,14 @@ class Decoder(nn.Module):
             channel_embed = self.embeddings[i](channel_tokens)
             x = channel_embed if x is None else x + channel_embed

-        if not deterministic:
-            x = self.dropout(x)
-
        for i, layer in enumerate(self.layers):
-
                 x,
-
-
-
-                deterministic=deterministic,
-                self_attn_mask=self_attn_mask,
-                cross_attn_mask=cross_attn_mask,
-                self_attn_cache=self_attention_cache[i],
-                cross_attn_cache=cross_attention_cache[i],
                 prefill=True,
             )
         @@ -831,43 +635,8 @@ class Decoder(nn.Module): 
     | 
|
| 831 | 
         
             
            class DiaModel(nn.Module):
         
     | 
| 832 | 
         
             
                """PyTorch Dia Model using DenseGeneral."""
         
     | 
| 833 | 
         | 
| 834 | 
         
            -
                def __init__(self, config: DiaConfig):
         
     | 
| 835 | 
         
             
                    super().__init__()
         
     | 
| 836 | 
         
             
                    self.config = config
         
     | 
| 837 | 
         
            -
                    self.encoder = Encoder(config)
         
     | 
| 838 | 
         
            -
                    self.decoder = Decoder(config)
         
     | 
| 839 | 
         
            -
             
     | 
| 840 | 
         
            -
                def forward(
         
     | 
| 841 | 
         
            -
                    self,
         
     | 
| 842 | 
         
            -
                    src_BxS: torch.Tensor,
         
     | 
| 843 | 
         
            -
                    tgt_BxTxC: torch.Tensor,
         
     | 
| 844 | 
         
            -
                    src_positions: torch.Tensor | None = None,
         
     | 
| 845 | 
         
            -
                    tgt_positions: torch.Tensor | None = None,
         
     | 
| 846 | 
         
            -
                    enc_self_attn_mask: torch.Tensor | None = None,
         
     | 
| 847 | 
         
            -
                    dec_self_attn_mask: torch.Tensor | None = None,
         
     | 
| 848 | 
         
            -
                    dec_cross_attn_mask: torch.Tensor | None = None,
         
     | 
| 849 | 
         
            -
                    enable_dropout: bool = True,
         
     | 
| 850 | 
         
            -
                ):
         
     | 
| 851 | 
         
            -
                    deterministic = not enable_dropout
         
     | 
| 852 | 
         
            -
             
     | 
| 853 | 
         
            -
                    # --- Encoder Pass ---
         
     | 
| 854 | 
         
            -
                    encoder_out = self.encoder(
         
     | 
| 855 | 
         
            -
                        x_ids=src_BxS,
         
     | 
| 856 | 
         
            -
                        src_positions=src_positions,
         
     | 
| 857 | 
         
            -
                        deterministic=deterministic,
         
     | 
| 858 | 
         
            -
                        attn_mask=enc_self_attn_mask,
         
     | 
| 859 | 
         
            -
                    )
         
     | 
| 860 | 
         
            -
             
     | 
| 861 | 
         
            -
                    # --- Decoder Pass ---
         
     | 
| 862 | 
         
            -
                    logits, _ = self.decoder(
         
     | 
| 863 | 
         
            -
                        tgt_ids_BxTxC=tgt_BxTxC,
         
     | 
| 864 | 
         
            -
                        encoder_out=encoder_out,
         
     | 
| 865 | 
         
            -
                        tgt_positions=tgt_positions,
         
     | 
| 866 | 
         
            -
                        src_positions=src_positions,
         
     | 
| 867 | 
         
            -
                        deterministic=deterministic,
         
     | 
| 868 | 
         
            -
                        self_attn_mask=dec_self_attn_mask,
         
     | 
| 869 | 
         
            -
                        cross_attn_mask=dec_cross_attn_mask,
         
     | 
| 870 | 
         
            -
                        precomputed_cross_attn_kv=None,
         
     | 
| 871 | 
         
            -
                    )
         
     | 
| 872 | 
         
            -
             
     | 
| 873 | 
         
            -
                    return logits
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
The remaining dia/layers.py hunks below show the updated file (the new side of the diff):

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn import RMSNorm
 
 from .config import DiaConfig
+from .state import DecoderInferenceState, EncoderInferenceState, KVCache
 
 
 def _normalize_axes(axes: tuple[int, ...], ndim: int) -> tuple[int, ...]:
     return tuple(ax if ax >= 0 else ndim + ax for ax in axes)
 
 
 class DenseGeneral(nn.Module):
     """
     PyTorch equivalent of flax.linen.DenseGeneral with shapes defined at init.
@@ ... @@
         in_shapes: tuple[int, ...],
         out_features: tuple[int, ...],
         axis: tuple[int, ...] = (-1,),
         weight_dtype: torch.dtype | None = None,
         device: torch.device | None = None,
     ):
         self.in_shapes = in_shapes
         self.out_features = out_features
         self.axis = axis
         self.kernel_shape = self.in_shapes + self.out_features
 
         factory_kwargs = {"device": device, "dtype": weight_dtype}
@@ ... @@
         kernel_contract_axes = tuple(range(len(norm_axis)))
 
         output = torch.tensordot(
+            inputs.to(self.weight.dtype),
+            self.weight,
             dims=(norm_axis, kernel_contract_axes),
         ).to(inputs.dtype)
         return output
 
 
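For intuition (not part of the commit): with the default axis=(-1,) and a single output dimension, DenseGeneral reduces to an ordinary bias-free linear layer, and the tensordot call above simply contracts the last input axis against the first kernel axis. A minimal sketch of that equivalence, using hypothetical shapes:

    import torch

    # DenseGeneral with axis=(-1,) and out_features=(out_dim,) is a bias-free matmul
    # over the last axis; kernel_shape = in_shapes + out_features = (in_dim, out_dim).
    in_dim, out_dim = 8, 16
    weight = torch.randn(in_dim, out_dim)
    x = torch.randn(2, 4, in_dim)  # (batch, seq, in_dim)

    y_tensordot = torch.tensordot(x, weight, dims=([2], [0]))
    y_matmul = x @ weight

    assert torch.allclose(y_tensordot, y_matmul, atol=1e-5)
    print(y_tensordot.shape)  # torch.Size([2, 4, 16])

The general class only adds the bookkeeping needed to contract several axes at once (as o_proj does with axis=(-2, -1)).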
 class MlpBlock(nn.Module):
     """MLP block using DenseGeneral."""
 
     def __init__(
+        self, embed_dim: int, intermediate_dim: int, compute_dtype: torch.dtype
     ):
         super().__init__()
         self.dtype = compute_dtype
 
         self.wi_fused = DenseGeneral(
             in_shapes=(embed_dim,),
+            out_features=(2, intermediate_dim),
             axis=(-1,),
+            weight_dtype=compute_dtype,
         )
 
         self.wo = DenseGeneral(
             in_shapes=(intermediate_dim,),
             out_features=(embed_dim,),
             axis=(-1,),
+            weight_dtype=compute_dtype,
         )
 
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Forward pass."""
         fused_x = self.wi_fused(x)
 
+        gate = fused_x[..., 0, :]
+        up = fused_x[..., 1, :]
 
+        hidden = torch.mul(F.silu(gate), up).to(self.dtype)
 
         output = self.wo(hidden)
         return output
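As an aside (not part of the diff): the fused wi_fused projection packs the SwiGLU-style gate and up projections into one weight with out_features=(2, intermediate_dim), so a single matmul produces both halves, and slicing index 0/1 on the second-to-last axis recovers them. A small sketch of that equivalence with hypothetical tensors:

    import torch
    import torch.nn.functional as F

    embed_dim, intermediate_dim = 8, 32
    x = torch.randn(2, 5, embed_dim)

    # One fused weight holding both the gate and the up projection.
    w_fused = torch.randn(embed_dim, 2, intermediate_dim)

    fused = torch.tensordot(x, w_fused, dims=([2], [0]))  # (2, 5, 2, intermediate_dim)
    gate, up = fused[..., 0, :], fused[..., 1, :]

    # Equivalent pair of separate projections.
    gate_ref = x @ w_fused[:, 0, :]
    up_ref = x @ w_fused[:, 1, :]
    assert torch.allclose(gate, gate_ref, atol=1e-5)
    assert torch.allclose(up, up_ref, atol=1e-5)

    hidden = F.silu(gate) * up  # the gating computed in MlpBlock.forward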
@@ ... @@
         return torch.cat((first_part, second_part), dim=-1)
 
 
 class Attention(nn.Module):
     """Attention using DenseGeneral."""
 
@@ ... @@
         num_query_heads: int,
         num_kv_heads: int,
         head_dim: int,
+        compute_dtype: torch.dtype,
         is_cross_attn: bool = False,
         out_embed_dim: int | None = None,
     ):
@@ ... @@
         self.num_kv_heads = num_kv_heads
         self.head_dim = head_dim
         self.is_cross_attn = is_cross_attn
         self.output_dim = out_embed_dim if out_embed_dim is not None else q_embed_dim
         self.projected_query_dim = num_query_heads * head_dim
         if num_query_heads % num_kv_heads != 0:
+            raise ValueError(
+                f"num_query_heads ({num_query_heads}) must be divisible by num_kv_heads ({num_kv_heads})"
+            )
         self.num_gqa_groups = num_query_heads // num_kv_heads
 
         # --- Projection Layers using DenseGeneral ---
@@ ... @@
             in_shapes=(q_embed_dim,),
             out_features=(num_query_heads, head_dim),
             axis=(-1,),
+            weight_dtype=compute_dtype,
         )
         self.k_proj = DenseGeneral(
             in_shapes=(kv_embed_dim,),
             out_features=(num_kv_heads, head_dim),
             axis=(-1,),
+            weight_dtype=compute_dtype,
         )
         self.v_proj = DenseGeneral(
             in_shapes=(kv_embed_dim,),
             out_features=(num_kv_heads, head_dim),
             axis=(-1,),
+            weight_dtype=compute_dtype,
         )
         self.o_proj = DenseGeneral(
             in_shapes=(num_query_heads, head_dim),
             out_features=(self.output_dim,),
             axis=(-2, -1),
+            weight_dtype=compute_dtype,
         )
 
         # --- Rotary Embedding ---
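A brief note on the grouped-query attention (GQA) bookkeeping: num_gqa_groups is how many query heads share each key/value head, which is why num_query_heads must divide evenly by num_kv_heads. The forward pass below leans on SDPA's enable_gqa flag (available in recent PyTorch) instead of repeating KV heads by hand; the sketch here, with hypothetical head counts, shows that the two are equivalent:

    import torch
    import torch.nn.functional as F

    # Hypothetical shapes to illustrate the enable_gqa path used in Attention.forward.
    B, T, S, H = 2, 3, 7, 8
    num_query_heads, num_kv_heads = 16, 4
    groups = num_query_heads // num_kv_heads  # num_gqa_groups

    q = torch.randn(B, num_query_heads, T, H)
    k = torch.randn(B, num_kv_heads, S, H)
    v = torch.randn(B, num_kv_heads, S, H)

    # SDPA broadcasts the KV heads across query-head groups...
    out_gqa = F.scaled_dot_product_attention(q, k, v, scale=1.0, enable_gqa=True)

    # ...which matches explicitly repeating each KV head `groups` times.
    out_ref = F.scaled_dot_product_attention(
        q,
        k.repeat_interleave(groups, dim=1),
        v.repeat_interleave(groups, dim=1),
        scale=1.0,
    )
    assert torch.allclose(out_gqa, out_ref, atol=1e-5)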
@@ ... @@
         Xkv: torch.Tensor,  # (B, S, E) S = 1 in AR generation
         q_positions: torch.Tensor,  # (B, T)
         kv_positions: torch.Tensor | None = None,  # (B, S)
+        attn_mask: torch.Tensor
+        | None = None,  # None in Decoder Self Attention, Valid mask in Others
         cache: KVCache | None = None,  # None in Encoder, KVCache in Decoder
+        prefill: bool = False,
+        is_causal: bool = False,
     ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor] | None]:
         """
         Performs attention calculation with optional KV caching.
@@ ... @@
             Xkv: Key/Value source tensor (B, S, E). S=1 during single-step decoding for self-attn.
             q_positions: Positions for queries (B, T).
             kv_positions: Positions for keys/values (B, S). If None, uses q_positions.
             attn_mask: Attention mask.
             cache: KVCache.
             prefill: If True, use prefill mode.
@@ ... @@
         Xq_BxTxNxH = self.rotary_emb(Xq_BxTxNxH, position=q_positions)
         Xq_BxNxTxH = Xq_BxTxNxH.transpose(1, 2)
 
         attn_k: torch.Tensor | None = None
         attn_v: torch.Tensor | None = None
 
         if self.is_cross_attn:
             attn_k, attn_v = cache.k, cache.v
         else:
             Xk_BxSxKxH = self.k_proj(Xkv)  # (B, S, K, H)
             Xv_BxSxKxH = self.v_proj(Xkv)  # (B, S, K, H)
+            Xk_BxSxKxH = self.rotary_emb(
+                Xk_BxSxKxH, position=kv_positions
+            )  # (B, S, K, H)
 
             Xk_BxKxSxH = Xk_BxSxKxH.transpose(1, 2)  # (B, K, S, H)
             Xv_BxKxSxH = Xv_BxSxKxH.transpose(1, 2)  # (B, K, S, H)
 
             if cache is None:
+                attn_k = Xk_BxKxSxH
+                attn_v = Xv_BxKxSxH
             else:
                 if prefill:
+                    attn_k, attn_v = Xk_BxKxSxH, Xv_BxKxSxH
+                    cache.prefill(attn_k, attn_v)
                 else:
+                    attn_k, attn_v = cache.update(Xk_BxKxSxH, Xv_BxKxSxH)
 
         attn_output = F.scaled_dot_product_attention(
             Xq_BxNxTxH,
             attn_k,
             attn_v,
             attn_mask=attn_mask,
             scale=1.0,
+            enable_gqa=self.num_gqa_groups > 1,
+            is_causal=is_causal,
         )
 
         attn_output = attn_output.transpose(1, 2).contiguous()  # (B, T, N, H)
         output = self.o_proj(attn_output)
 
+        return output.to(original_dtype)
 
 
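The cache plumbing above relies on the KVCache class from the new dia/state.py, which is not shown in this view. The sketch below is only a minimal stand-in that supports the three usages visible in this file — KVCache.from_kv for fixed cross-attention keys/values, prefill for the prompt pass, and update for single-step decoding — and it assumes a simple grow-by-concatenation strategy rather than whatever preallocation the real class uses:

    import torch
    from dataclasses import dataclass

    @dataclass
    class SketchKVCache:
        """Minimal stand-in for dia.state.KVCache (interface only, not the real code)."""
        k: torch.Tensor | None = None  # (B, num_kv_heads, S, head_dim)
        v: torch.Tensor | None = None

        @classmethod
        def from_kv(cls, k: torch.Tensor, v: torch.Tensor) -> "SketchKVCache":
            # Cross-attention: keys/values are fixed for the whole generation.
            return cls(k=k, v=v)

        def prefill(self, k: torch.Tensor, v: torch.Tensor) -> None:
            # Prefill: store the full prefix keys and values in one shot.
            self.k, self.v = k, v

        def update(self, k_step: torch.Tensor, v_step: torch.Tensor):
            # Single-step decode: append this step's K/V and return the full cache.
            self.k = torch.cat([self.k, k_step], dim=2)
            self.v = torch.cat([self.v, v_step], dim=2)
            return self.k, self.v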
 class EncoderLayer(nn.Module):
     """Transformer Encoder Layer using DenseGeneral."""
 
+    def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
         super().__init__()
         self.config = config
         model_config = config.model
@@ ... @@
             dtype=torch.float32,
         )
         self.self_attention = Attention(
+            config,
             q_embed_dim=embed_dim,
             kv_embed_dim=embed_dim,
             num_query_heads=enc_config.n_head,
             num_kv_heads=enc_config.n_head,
             head_dim=enc_config.head_dim,
+            compute_dtype=compute_dtype,
             is_cross_attn=False,
             out_embed_dim=embed_dim,
         )
@@ ... @@
             dtype=torch.float32,
         )
         self.mlp = MlpBlock(
             embed_dim=embed_dim,
             intermediate_dim=enc_config.n_hidden,
+            compute_dtype=compute_dtype,
         )
 
     def forward(
         self,
         x: torch.Tensor,
+        state: EncoderInferenceState,
     ) -> torch.Tensor:
         residual = x
         x_norm = self.pre_sa_norm(x)
+        sa_out = self.self_attention(
             Xq=x_norm,
             Xkv=x_norm,
+            q_positions=state.positions,
+            kv_positions=state.positions,
+            attn_mask=state.attn_mask,
         )
         x = residual + sa_out
 
         residual = x
         x_norm = self.post_sa_norm(x)
+        mlp_out = self.mlp(x_norm)
         x = residual + mlp_out
 
         return x
 
 
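Stripped of the attention details, the layer follows the usual pre-norm residual pattern: normalize, transform, add back the residual. A sketch of that skeleton (not the actual module code):

    # Pre-norm residual skeleton shared by EncoderLayer and DecoderLayer (sketch only).
    def prenorm_block(x, norm, sublayer):
        residual = x
        return residual + sublayer(norm(x))

    # EncoderLayer.forward is effectively:
    #   x = prenorm_block(x, pre_sa_norm, self_attention)
    #   x = prenorm_block(x, post_sa_norm, mlp)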
 class Encoder(nn.Module):
     """Transformer Encoder Stack using DenseGeneral."""
 
+    def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
         super().__init__()
         self.config = config
         model_config = config.model
         enc_config = config.model.encoder
 
         self.embedding = nn.Embedding(
             model_config.src_vocab_size,
             enc_config.n_embd,
             dtype=compute_dtype,
         )
+        self.layers = nn.ModuleList(
+            [EncoderLayer(config, compute_dtype) for _ in range(enc_config.n_layer)]
+        )
         self.norm = RMSNorm(
             enc_config.n_embd,
             eps=model_config.normalization_layer_epsilon,
@@ ... @@
     def forward(
         self,
         x_ids: torch.Tensor,
+        state: EncoderInferenceState,
     ) -> torch.Tensor:
         x = self.embedding(x_ids)
 
         for layer in self.layers:
+            x = layer(x, state)
+
         x = self.norm(x)
         return x
 
 
 class DecoderLayer(nn.Module):
     """Transformer Decoder Layer using DenseGeneral."""
 
+    def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
         super().__init__()
         self.config = config
         model_config = config.model
@@ ... @@
         # Self-Attention (GQA) with Causal Masking
         self.self_attention = Attention(
+            config,
             q_embed_dim=dec_embed_dim,
             kv_embed_dim=dec_embed_dim,
             num_query_heads=dec_config.gqa_query_heads,
             num_kv_heads=dec_config.kv_heads,
             head_dim=dec_config.gqa_head_dim,
+            compute_dtype=compute_dtype,
             is_cross_attn=False,
             out_embed_dim=dec_embed_dim,
         )
@@ ... @@
             num_query_heads=dec_config.cross_query_heads,
             num_kv_heads=dec_config.cross_query_heads,
             head_dim=dec_config.cross_head_dim,
+            compute_dtype=compute_dtype,
             is_cross_attn=True,
             out_embed_dim=dec_embed_dim,
         )
         # MLP
         self.mlp = MlpBlock(
             embed_dim=dec_embed_dim,
             intermediate_dim=dec_config.n_hidden,
+            compute_dtype=compute_dtype,
         )
 
     def forward(
         self,
         x: torch.Tensor,
+        state: DecoderInferenceState,
+        self_attn_cache: KVCache | None = None,
+        cross_attn_cache: KVCache | None = None,
         prefill: bool = False,
     ) -> torch.Tensor:
         residual = x
         x_norm = self.pre_sa_norm(x)
 
+        sa_out = self.self_attention(
             Xq=x_norm,  # (2, 1, D)
             Xkv=x_norm,  # (2, 1, D)
+            q_positions=state.dec_positions,  # (2, 1)
+            kv_positions=state.dec_positions,  # (2, 1)
+            attn_mask=None,
             cache=self_attn_cache,
             prefill=prefill,
+            is_causal=prefill,
         )
 
         x = residual + sa_out
 
         residual = x
         x_norm = self.pre_ca_norm(x)
+        ca_out = self.cross_attention(
             Xq=x_norm,
+            Xkv=state.enc_out,
+            q_positions=state.dec_positions,
+            kv_positions=state.enc_positions,
+            attn_mask=state.dec_cross_attn_mask,
             cache=cross_attn_cache,
         )
         x = residual + ca_out
 
         residual = x
         x_norm = self.pre_mlp_norm(x)
+        mlp_out = self.mlp(x_norm)
         x = residual + mlp_out
 
+        return x
 
 
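One detail worth calling out (an observation, not text from the commit): the self-attention call passes attn_mask=None and is_causal=prefill. During prefill the whole target prefix is processed at once, so SDPA's built-in causal mask is what enforces autoregressive ordering; during single-step decoding the query is a single token that may attend to everything already in the cache, so no mask is needed. A schematic of the two call modes, with hypothetical tensors standing in for the real state:

    import torch
    import torch.nn.functional as F

    B, H, D = 2, 4, 16          # batch, heads, head_dim (hypothetical sizes)
    prefix_len = 6

    q_prefill = torch.randn(B, H, prefix_len, D)
    kv_prefill = torch.randn(B, H, prefix_len, D)

    # Prefill: full prefix in one call, causal masking enforced by SDPA itself.
    out_prefill = F.scaled_dot_product_attention(
        q_prefill, kv_prefill, kv_prefill, scale=1.0, is_causal=True
    )

    # Decode step: one query token attends to the whole cache, no mask needed.
    q_step = torch.randn(B, H, 1, D)
    cache_k = torch.randn(B, H, prefix_len + 1, D)
    cache_v = torch.randn(B, H, prefix_len + 1, D)
    out_step = F.scaled_dot_product_attention(
        q_step, cache_k, cache_v, scale=1.0, is_causal=False
    )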
 class Decoder(nn.Module):
     """Transformer Decoder Stack using DenseGeneral."""
 
+    def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
         super().__init__()
         self.config = config
         model_config = config.model
         dec_config = config.model.decoder
         data_config = config.data
         self.num_channels = data_config.channels
         self.num_layers = dec_config.n_layer
 
         self.embeddings = nn.ModuleList(
             [
+                nn.Embedding(
+                    model_config.tgt_vocab_size, dec_config.n_embd, dtype=compute_dtype
+                )
                 for _ in range(self.num_channels)
             ]
         )
+        self.layers = nn.ModuleList(
+            [
+                DecoderLayer(config=config, compute_dtype=compute_dtype)
+                for _ in range(self.num_layers)
+            ]
+        )
+
         self.norm = RMSNorm(
             dec_config.n_embd,
             eps=model_config.normalization_layer_epsilon,
             dtype=torch.float32,
         )
 
         self.logits_dense = DenseGeneral(
             in_shapes=(dec_config.n_embd,),
             out_features=(self.num_channels, model_config.tgt_vocab_size),
             axis=(-1,),
+            weight_dtype=compute_dtype,
         )
 
+    def precompute_cross_attn_cache(
         self,
+        enc_out: torch.Tensor,  # (B, S, E)
+        enc_positions: torch.Tensor,  # (B, S)
     ) -> list[KVCache]:
         """
         Computes the Key and Value tensors for cross-attention for each layer from the encoder output.
@@ ... @@
         for layer in self.layers:
             cross_attn_module = layer.cross_attention
+            k_proj = cross_attn_module.k_proj(enc_out)
+            v_proj = cross_attn_module.v_proj(enc_out)
 
+            k_proj = cross_attn_module.rotary_emb(k_proj, position=enc_positions)
             k = k_proj.transpose(1, 2)
             v = v_proj.transpose(1, 2)
 
+            per_layer_kv_cache.append(KVCache.from_kv(k, v))
 
         return per_layer_kv_cache
 
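The point of precompute_cross_attn_cache is that cross-attention keys and values depend only on the encoder output, which is fixed for an entire generation, so they can be projected (and rotated) once up front instead of once per decoding step. A rough sketch of how the inference code presumably wires this together — the wrapper name is hypothetical, only the encoder/decoder methods it calls appear in this diff:

    import torch

    @torch.inference_mode()
    def sketch_setup(model, src_ids, enc_state):
        # Run the encoder once for the prompt.
        enc_out = model.encoder(src_ids, enc_state)

        # Project K/V for every decoder layer once; reused at every decode step.
        cross_caches = model.decoder.precompute_cross_attn_cache(
            enc_out, enc_state.positions
        )
        return enc_out, cross_caches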
     def decode_step(
         self,
         tgt_ids_Bx1xC: torch.Tensor,  # [B, 1, C]
+        state: DecoderInferenceState,
     ) -> torch.Tensor:
         """
         Performs a single decoding step, managing KV caches layer by layer.
@@ ... @@
             A tuple containing:
             - logits_Bx1xCV: The final output logits for the current step (B, 1, C*V), cast to float32.
         """
 
         x = None
         for i in range(self.num_channels):
@@ ... @@
             channel_embed = self.embeddings[i](channel_tokens)
             x = channel_embed if x is None else x + channel_embed
 
         for i, layer in enumerate(self.layers):
+            self_cache = state.self_attn_cache[i]
+            cross_cache = state.cross_attn_cache[i]
+            x = layer(
                 x,  # (2, 1, D)
+                state,
                 self_attn_cache=self_cache,
                 cross_attn_cache=cross_cache,
             )
 
         x = self.norm(x)
         logits_Bx1xCxV = self.logits_dense(x)
 
+        return logits_Bx1xCxV.to(torch.float32)
 
     def forward(
+        self, tgt_ids_BxTxC: torch.Tensor, state: DecoderInferenceState
     ) -> torch.Tensor:
         """
         Forward pass for the Decoder stack, managing KV caches.
@@ ... @@
             encoder_out: Output from the encoder (B, S, E).
             tgt_positions: Positions for target sequence (B, T).
             src_positions: Positions for source sequence (B, S).
             self_attn_mask: Mask for self-attention.
             cross_attn_mask: Mask for cross-attention.
             past_key_values: List containing the self-attention KV cache for each layer
@@ ... @@
             channel_embed = self.embeddings[i](channel_tokens)
             x = channel_embed if x is None else x + channel_embed
 
         for i, layer in enumerate(self.layers):
+            self_cache = state.self_attn_cache[i]
+            cross_cache = state.cross_attn_cache[i]
+            x = layer(
                 x,
+                state,
+                self_attn_cache=self_cache,
+                cross_attn_cache=cross_cache,
                 prefill=True,
             )
 
@@ ... @@
 class DiaModel(nn.Module):
     """PyTorch Dia Model using DenseGeneral."""
 
+    def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
         super().__init__()
         self.config = config
+        self.encoder = Encoder(config, compute_dtype)
+        self.decoder = Decoder(config, compute_dtype)
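Putting the two decoder entry points together: forward is used once to prefill the self-attention caches over the whole target prefix, after which generation proceeds one frame at a time through decode_step. The loop below is only a schematic of that flow — it uses greedy selection instead of the real temperature/top-p sampling and omits the position/EOS bookkeeping the actual inference code performs:

    import torch

    @torch.inference_mode()
    def sketch_generate(decoder, dec_state, prefix_BxTxC, max_steps):
        # Prefill: run the known prefix once so every layer's KV cache is populated.
        decoder(prefix_BxTxC, dec_state)

        frames = []
        current_Bx1xC = prefix_BxTxC[:, -1:, :]
        for _ in range(max_steps):
            logits_Bx1xCxV = decoder.decode_step(current_Bx1xC, dec_state)  # (B, 1, C, V)
            # Greedy pick per codebook channel; the real code samples via _sample_next_token.
            current_Bx1xC = logits_Bx1xCxV.argmax(dim=-1)
            frames.append(current_Bx1xC)
            # The real loop also advances dec_state positions and stops on an EOS frame.
        return torch.cat(frames, dim=1)  # (B, max_steps, C)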
dia/model.py
CHANGED

@@ -1,26 +1,46 @@
 import dac
 import numpy as np
 import torch
 import torchaudio
 from huggingface_hub import hf_hub_download
 
-from .audio import 
 from .config import DiaConfig
-from .layers import DiaModel
 
 
 def _sample_next_token(
     logits_BCxV: torch.Tensor,
     temperature: float,
     top_p: float,
-    use_cfg_filter: bool,
     cfg_filter_top_k: int | None = None,
 ) -> torch.Tensor:
     if temperature == 0.0:
         return torch.argmax(logits_BCxV, dim=-1)
 
     logits_BCxV = logits_BCxV / temperature
-    if 
         _, top_k_indices_BCxV = torch.topk(logits_BCxV, k=cfg_filter_top_k, dim=-1)
         mask = torch.ones_like(logits_BCxV, dtype=torch.bool)
         mask.scatter_(dim=-1, index=top_k_indices_BCxV, value=False)

@@ -28,17 +48,21 @@ def _sample_next_token(
 
     if top_p < 1.0:
         probs_BCxV = torch.softmax(logits_BCxV, dim=-1)
-        sorted_probs_BCxV, sorted_indices_BCxV = torch.sort(
         cumulative_probs_BCxV = torch.cumsum(sorted_probs_BCxV, dim=-1)
 
-        # Calculate indices to remove based on top_p
         sorted_indices_to_remove_BCxV = cumulative_probs_BCxV > top_p
-
-
-
 
         indices_to_remove_BCxV = torch.zeros_like(sorted_indices_to_remove_BCxV)
-        indices_to_remove_BCxV.scatter_(
         logits_BCxV = logits_BCxV.masked_fill(indices_to_remove_BCxV, -torch.inf)
 
     final_probs_BCxV = torch.softmax(logits_BCxV, dim=-1)

@@ -48,31 +72,61 @@ def _sample_next_token(
     return sampled_indices_C
 
 
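The two sampling hunks above drop the use_cfg_filter flag, so top-k filtering now applies whenever cfg_filter_top_k is given, followed by the usual nucleus (top-p) cut before the final softmax and sampling. Because several of the updated lines are truncated in this view, here is a compact, self-contained re-implementation of the same two filters for illustration only, not the exact code that ships in dia/model.py:

    import torch

    def filter_logits(logits: torch.Tensor, top_k: int | None, top_p: float) -> torch.Tensor:
        """Sketch of the two filters in _sample_next_token: top-k, then nucleus (top-p)."""
        if top_k is not None:
            # Keep only the k largest logits per row (the cfg_filter_top_k path).
            _, top_k_indices = torch.topk(logits, k=top_k, dim=-1)
            keep = torch.zeros_like(logits, dtype=torch.bool).scatter_(-1, top_k_indices, True)
            logits = logits.masked_fill(~keep, -torch.inf)

        if top_p < 1.0:
            probs = torch.softmax(logits, dim=-1)
            sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True)
            cumulative = torch.cumsum(sorted_probs, dim=-1)
            # Drop tokens once the cumulative mass already exceeds top_p; keep the top token.
            sorted_remove = cumulative > top_p
            sorted_remove[..., 1:] = sorted_remove[..., :-1].clone()
            sorted_remove[..., 0] = False
            remove = torch.zeros_like(sorted_remove).scatter_(-1, sorted_indices, sorted_remove)
            logits = logits.masked_fill(remove, -torch.inf)

        return logits

    filtered = filter_logits(torch.tensor([[2.0, 1.0, 0.0, -1.0, -2.0]]), top_k=4, top_p=0.9)
    probs = torch.softmax(filtered, dim=-1)  # the distribution actually sampled from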
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 51 | 
         
             
            class Dia:
         
     | 
| 52 | 
         
            -
                def __init__( 
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 53 | 
         
             
                    """Initializes the Dia model.
         
     | 
| 54 | 
         | 
| 55 | 
         
             
                    Args:
         
     | 
| 56 | 
         
             
                        config: The configuration object for the model.
         
     | 
| 57 | 
         
            -
                        device: The device to load the model onto.
         
     | 
| 58 | 
         | 
| 59 | 
         
             
                    Raises:
         
     | 
| 60 | 
         
             
                        RuntimeError: If there is an error loading the DAC model.
         
     | 
| 61 | 
         
             
                    """
         
     | 
| 62 | 
         
             
                    super().__init__()
         
     | 
| 63 | 
         
             
                    self.config = config
         
     | 
| 64 | 
         
            -
                    self.device = device
         
     | 
| 65 | 
         
            -
                     
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 66 | 
         
             
                    self.dac_model = None
         
     | 
| 67 | 
         | 
| 68 | 
         
             
                @classmethod
         
     | 
| 69 | 
         
            -
                def from_local( 
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 70 | 
         
             
                    """Loads the Dia model from local configuration and checkpoint files.
         
     | 
| 71 | 
         | 
| 72 | 
         
             
                    Args:
         
     | 
| 73 | 
         
             
                        config_path: Path to the configuration JSON file.
         
     | 
| 74 | 
         
             
                        checkpoint_path: Path to the model checkpoint (.pth) file.
         
     | 
| 75 | 
         
            -
                        device: The device to load the model onto.
         
     | 
| 76 | 
         | 
| 77 | 
         
             
                    Returns:
         
     | 
| 78 | 
         
             
                        An instance of the Dia model loaded with weights and set to eval mode.
         
     | 
| 
         @@ -85,23 +139,29 @@ class Dia: 
     | 
|
| 85 | 
         
             
                    if config is None:
         
     | 
| 86 | 
         
             
                        raise FileNotFoundError(f"Config file not found at {config_path}")
         
     | 
| 87 | 
         | 
| 88 | 
         
            -
                    dia = cls(config, device)
         
     | 
| 89 | 
         | 
| 90 | 
         
             
                    try:
         
     | 
| 91 | 
         
            -
                         
     | 
| 
         | 
|
| 92 | 
         
             
                    except FileNotFoundError:
         
     | 
| 93 | 
         
             
                        raise FileNotFoundError(f"Checkpoint file not found at {checkpoint_path}")
         
     | 
| 94 | 
         
             
                    except Exception as e:
         
     | 
| 95 | 
         
            -
                        raise RuntimeError( 
     | 
| 
         | 
|
| 
         | 
|
| 96 | 
         | 
| 97 | 
         
            -
                    dia.model.to(device)
         
     | 
| 98 | 
         
             
                    dia.model.eval()
         
     | 
| 99 | 
         
             
                    dia._load_dac_model()
         
     | 
| 100 | 
         
             
                    return dia
         
     | 
| 101 | 
         | 
| 102 | 
         
             
                @classmethod
         
     | 
| 103 | 
         
             
                def from_pretrained(
         
     | 
| 104 | 
         
            -
                    cls, 
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 105 | 
         
             
                ) -> "Dia":
         
     | 
| 106 | 
         
             
                    """Loads the Dia model from a Hugging Face Hub repository.
         
     | 
| 107 | 
         | 
| 
         @@ -110,7 +170,7 @@ class Dia: 

         Args:
             model_name: The Hugging Face Hub repository ID (e.g., "NariLabs/Dia-1.6B").
-            device: The device to load the model onto.

         Returns:
             An instance of the Dia model loaded with weights and set to eval mode.
         @@ -121,7 +181,7 @@ class Dia: 
         """
         config_path = hf_hub_download(repo_id=model_name, filename="config.json")
         checkpoint_path = hf_hub_download(repo_id=model_name, filename="dia-v0_1.pth")
-        return cls.from_local(config_path, checkpoint_path, device)

     def _load_dac_model(self):
         try:
         @@ -131,44 +191,7 @@ class Dia: 
             raise RuntimeError("Failed to load DAC model") from e
         self.dac_model = dac_model

-    def _create_attn_mask(
-        self,
-        q_padding_mask_1d: torch.Tensor,
-        k_padding_mask_1d: torch.Tensor,
-        is_causal: bool = False,
-    ) -> torch.Tensor:
-        """
-        Creates the attention mask (self or cross) mimicking JAX segment ID logic.
-        """
-        B1, Tq = q_padding_mask_1d.shape
-        B2, Tk = k_padding_mask_1d.shape
-        assert B1 == B2, "Query and key batch dimensions must match"
-
-        p_mask_q = q_padding_mask_1d.unsqueeze(2)  # Shape [B, Tq, 1]
-        p_mask_k = k_padding_mask_1d.unsqueeze(1)  # Shape [B, 1, Tk]
-
-        # Condition A: Non-padding query attends to non-padding key
-        non_pad_attends_non_pad = p_mask_q & p_mask_k  # Shape [B, Tq, Tk]
-
-        # Condition B: Padding query attends to padding key
-        pad_attends_pad = (~p_mask_q) & (~p_mask_k)  # Shape [B, Tq, Tk]
-
-        # Combine: True if padding status is compatible (both non-pad OR both pad)
-        # This implementation follows Jax TPU splash attention kernel
-        mask = non_pad_attends_non_pad | pad_attends_pad  # Shape [B, Tq, Tk]
-
-        if is_causal:
-            # Ensure causality for self-attention (Tq == Tk)
-            assert Tq == Tk, "Causal mask requires query and key sequence lengths to be equal"
-            # Standard lower-triangular causal mask (True means allow)
-            causal_mask_2d = torch.tril(torch.ones((Tq, Tk), dtype=torch.bool, device=self.device))  # Shape [Tq, Tk]
-            causal_mask = mask & causal_mask_2d  # Shape [B, Tq, Tk]
-            return causal_mask.unsqueeze(1)  # Shape [B, 1, Tq, Tk] for broadcasting across heads
-        else:
-            # For cross-attention or non-causal self-attention
-            return mask.unsqueeze(1)  # Shape [B, 1, Tq, Tk] for broadcasting across heads
-
-    def _prepare_text_input(self, text: str) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         """Encodes text prompt, pads, and creates attention mask and positions."""
         text_pad_value = self.config.data.text_pad_value
         max_len = self.config.data.text_length
         @@ -190,14 +213,168 @@ class Dia: 
             constant_values=text_pad_value,
         ).astype(np.uint8)

-        src_tokens = torch.from_numpy(padded_text_np).to(torch.long).to(self.device).unsqueeze(0)  # [1, S]
-        src_positions = torch.arange(max_len, device=self.device).to(torch.long).unsqueeze(0)  # [1, S]
-
-        src_padding_mask = (src_tokens != text_pad_value).to(torch.bool)  # [1, S]
-
-        enc_self_attn_mask = self._create_attn_mask(src_padding_mask, src_padding_mask, is_causal=False)
-
-        return src_tokens, src_positions, src_padding_mask, enc_self_attn_mask

     @torch.inference_mode()
     def generate(
         @@ -207,225 +384,105 @@ class Dia: 
         cfg_scale: float = 3.0,
         temperature: float = 1.3,
         top_p: float = 0.95,
-        use_cfg_filter: bool = True,
-        use_torch_compile: bool = False,
-        cfg_filter_top_k: int = 35,
         audio_prompt_path: str | None = None,
     ) -> np.ndarray:
-        """
-        Generates audio from a text prompt (and optional audio prompt) using the Nari model.
-
-        Returns:
-            A tensor of generated audio codes (shape: [max_tokens, num_channels]).
-        """
-        num_channels = self.config.data.channels
-        audio_bos_value = self.config.data.audio_bos_value
         audio_eos_value = self.config.data.audio_eos_value
         audio_pad_value = self.config.data.audio_pad_value
         delay_pattern = self.config.data.delay_pattern
         max_tokens = self.config.data.audio_length if max_tokens is None else max_tokens
-        delay_tensor = torch.tensor(delay_pattern, dtype=torch.long, device=self.device)
         max_delay_pattern = max(delay_pattern)
         self.model.eval()

-        (
-            cond_src_BxS,
-            cond_src_positions_BxS,
-            cond_src_padding_mask_BxS,
-            cond_enc_self_attn_mask_Bx1xSxS,
-        ) = self._prepare_text_input(text)
-
-        unc_src_BxS = torch.zeros_like(cond_src_BxS)
-        src_BxS = torch.cat([unc_src_BxS, cond_src_BxS], dim=0)
-        src_positions_BxS = cond_src_positions_BxS.expand(2, -1)
-        src_padding_mask_BxS = cond_src_padding_mask_BxS.expand(2, -1)
-        enc_self_attn_mask_Bx1xSxS = cond_enc_self_attn_mask_Bx1xSxS.expand(2, -1, -1, -1)
-
-        # 2. Encoder Pass
-        # with torch.autocast(device_type="cuda", dtype=forward_dtype):
-        encoder_out = self.model.encoder(
-            x_ids=src_BxS,
-            src_positions=src_positions_BxS,
-            deterministic=True,
-            attn_mask=enc_self_attn_mask_Bx1xSxS,
-        )  # Shape: (B, S, E)
-
-        # 3. Prepare Decoder Inputs
-        # 3-1. Allocate KV Cache (Static)
-        decoder_cross_attention_cache: list[KVCache] = self.model.decoder.precompute_cross_attention_kv(
-            max_tokens, encoder_out, src_positions_BxS
-        )
-
-        decoder_self_attention_cache: list[KVCache] = []
-        for _ in range(self.model.decoder.num_layers):
-            decoder_self_attention_cache.append(
-                KVCache(
-                    self.config.model.decoder.gqa_query_heads,
-                    max_tokens,
-                    self.config.model.decoder.gqa_head_dim,
-                    self.device,
-                )
-            )
-
-        # 3-2. Initialize Decoder Inputs
-        generated_BxTxC = torch.full(
-            (2, 1, num_channels),
-            fill_value=audio_bos_value,
-            dtype=torch.long,
-            device=self.device,
-        )
-
-        current_step = 0
-        prompt_len_inc_bos = 1  # Start with BOS length
-
-        # 3-3. Load Audio Prompt (if provided)
-        if audio_prompt_path is not None:
-            audio_prompt, sr = torchaudio.load(audio_prompt_path, channels_first=True)  # C, T
-            if sr != 44100:  # Resample to 44.1kHz
-                audio_prompt = torchaudio.functional.resample(audio_prompt, sr, 44100)
-            audio_prompt = audio_prompt.to(self.device).unsqueeze(0)  # 1, C, T
-            audio_prompt = audio_to_codebook(self.dac_model, audio_prompt, data_config=self.config.data)
-            generated_BxTxC = torch.cat([generated_BxTxC, audio_prompt.expand(2, -1, -1)], dim=1)
-
-            prefill_len = generated_BxTxC.shape[1]
-            prompt_len_inc_bos = prefill_len
-            prefill_tgt_pos = torch.arange(prefill_len, device=self.device).unsqueeze(0).expand(2, -1)
-            prefill_tgt_padding_mask = (generated_BxTxC != audio_pad_value).any(dim=2)
-
-            prefill_self_attn_mask = self._create_attn_mask(
-                prefill_tgt_padding_mask,
-                prefill_tgt_padding_mask,
-                is_causal=True,
-            )
-            prefill_cross_attn_mask = self._create_attn_mask(
-                prefill_tgt_padding_mask,
-                src_padding_mask_BxS,
-                is_causal=False,
-            )

-            _ = self.model.decoder.forward(
-                encoder_out=encoder_out,
-                tgt_positions=prefill_tgt_pos,
-                src_positions=src_positions_BxS,
-                deterministic=True,
-                self_attn_mask=prefill_self_attn_mask,
-                cross_attn_mask=prefill_cross_attn_mask,
-                self_attention_cache=decoder_self_attention_cache,
-                cross_attention_cache=decoder_cross_attention_cache,
-            )
-
-        eos_detected_channel_0 = False
         eos_countdown = -1
-        extra_steps_after_eos = 30
-        # Make generated_BxTxC a fixed size tensor
-        # Length is either 1 + max tokens or 1 + prompt len + max tokens
-        generated_BxTxC = torch.cat(
-            [
-                generated_BxTxC,
-                torch.full(
-                    (2, max_tokens, num_channels),
-                    fill_value=-1,
-                    dtype=torch.long,
-                    device=self.device,
-                ),
-            ],
-            dim=1,
-        )
-
-        decode_step = self.model.decoder.decode_step
         if use_torch_compile:
-            decode_step = torch.compile(
-                decode_step,
-                mode="reduce-overhead",
-            )
-
-        tgt_padding_mask = (
-            (generated_BxTxC[:, -1, :].unsqueeze(1) != audio_pad_value).any(dim=2)
-        )
-        decoder_cross_attn_mask = self._create_attn_mask(
-            tgt_padding_mask,
-            src_padding_mask_BxS,
-            is_causal=False,
-        )  # [B, 1, 1, S]
-
-        for step in range(current_step, current_step + max_tokens):
-            tgt_ids_Bx1xC = generated_BxTxC[:, step, :].unsqueeze(1)
-            tgt_pos_Bx1 = torch.full(
-                (2, 1),
-                fill_value=step,
-                dtype=torch.long,
-                device=self.device,
-            )

-            logits_Bx1xCxV = decode_step(
-                tgt_ids_Bx1xC,
-                tgt_pos_Bx1,
-                encoder_out,
-                self_attn_mask=None,
-                cross_attn_mask=decoder_cross_attn_mask,
-                self_attention_cache=decoder_self_attention_cache,
-                cross_attention_cache=decoder_cross_attention_cache,
             )
-
-            V = self.config.model.tgt_vocab_size
-            logits_last_BxCxV = logits_Bx1xCxV[:, -1, :, :]  # B, C, V
-            uncond_logits_CxV = logits_last_BxCxV[0, :, :]
-            cond_logits_CxV = logits_last_BxCxV[1, :, :]
-
-            cfg_logits_CxV = cond_logits_CxV + cfg_scale * (cond_logits_CxV - uncond_logits_CxV)
-
-            logits_CxV = cfg_logits_CxV.reshape((-1, V))  # C, V
-            logits_CxV[:, 1025:] = -torch.inf
-
-            # Sample next token
-            pred_C = _sample_next_token(
-                logits_CxV.float(),
-                temperature=temperature,
-                top_p=top_p,
-                use_cfg_filter=use_cfg_filter,
-                cfg_filter_top_k=cfg_filter_top_k,
             )
-
-            generation_step_index = step - current_step
-            if audio_prompt_path is None:
-                pred_C = torch.where(
-                    generation_step_index >= delay_tensor,
-                    pred_C,
-                    audio_bos_value,
-                )
-
-            generated_BxTxC[:, step + 1, :] = pred_C.unsqueeze(0).expand(2, -1)
-
-            if not eos_detected_channel_0 and pred_C[0] == audio_eos_value:
-                eos_detected_channel_0 = True
-                eos_countdown = extra_steps_after_eos

             if eos_countdown > 0:
                 step_after_eos = max_delay_pattern - eos_countdown
                 for i, d in enumerate(delay_pattern):
                     if step_after_eos == d:
-                        generated_BxTxC[:, step + 1, i] = audio_eos_value
                     elif step_after_eos > d:
-                        generated_BxTxC[:, step + 1, i] = audio_pad_value
                 eos_countdown -= 1
-                if eos_countdown == 0:
-                    break

+import time
+from enum import Enum
+
 import dac
 import numpy as np
 import torch
 import torchaudio
 from huggingface_hub import hf_hub_download

+from .audio import (
+    apply_audio_delay,
+    build_delay_indices,
+    build_revert_indices,
+    decode,
+    revert_audio_delay,
+)
 from .config import DiaConfig
+from .layers import DiaModel
+from .state import DecoderInferenceState, DecoderOutput, EncoderInferenceState
+
+
+DEFAULT_SAMPLE_RATE = 44100
+
+
+def _get_default_device():
+    if torch.cuda.is_available():
+        return torch.device("cuda")
+    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+        return torch.device("mps")
+    return torch.device("cpu")

         
             
            def _sample_next_token(
         
     | 
| 34 | 
         
             
                logits_BCxV: torch.Tensor,
         
     | 
| 35 | 
         
             
                temperature: float,
         
     | 
| 36 | 
         
             
                top_p: float,
         
     | 
| 
         | 
|
| 37 | 
         
             
                cfg_filter_top_k: int | None = None,
         
     | 
| 38 | 
         
             
            ) -> torch.Tensor:
         
     | 
| 39 | 
         
             
                if temperature == 0.0:
         
     | 
| 40 | 
         
             
                    return torch.argmax(logits_BCxV, dim=-1)
         
     | 
| 41 | 
         | 
| 42 | 
         
             
                logits_BCxV = logits_BCxV / temperature
         
     | 
| 43 | 
         
            +
                if cfg_filter_top_k is not None:
         
     | 
| 44 | 
         
             
                    _, top_k_indices_BCxV = torch.topk(logits_BCxV, k=cfg_filter_top_k, dim=-1)
         
     | 
| 45 | 
         
             
                    mask = torch.ones_like(logits_BCxV, dtype=torch.bool)
         
     | 
| 46 | 
         
             
                    mask.scatter_(dim=-1, index=top_k_indices_BCxV, value=False)
         
     | 
| 
         | 
|
| 48 | 
         | 
| 49 | 
         
             
                if top_p < 1.0:
         
     | 
| 50 | 
         
             
                    probs_BCxV = torch.softmax(logits_BCxV, dim=-1)
         
     | 
| 51 | 
         
            +
                    sorted_probs_BCxV, sorted_indices_BCxV = torch.sort(
         
     | 
| 52 | 
         
            +
                        probs_BCxV, dim=-1, descending=True
         
     | 
| 53 | 
         
            +
                    )
         
     | 
| 54 | 
         
             
                    cumulative_probs_BCxV = torch.cumsum(sorted_probs_BCxV, dim=-1)
         
     | 
| 55 | 
         | 
| 
         | 
|
| 56 | 
         
             
                    sorted_indices_to_remove_BCxV = cumulative_probs_BCxV > top_p
         
     | 
| 57 | 
         
            +
                    sorted_indices_to_remove_BCxV[..., 1:] = sorted_indices_to_remove_BCxV[
         
     | 
| 58 | 
         
            +
                        ..., :-1
         
     | 
| 59 | 
         
            +
                    ].clone()
         
     | 
| 60 | 
         
            +
                    sorted_indices_to_remove_BCxV[..., 0] = 0
         
     | 
| 61 | 
         | 
| 62 | 
         
             
                    indices_to_remove_BCxV = torch.zeros_like(sorted_indices_to_remove_BCxV)
         
     | 
| 63 | 
         
            +
                    indices_to_remove_BCxV.scatter_(
         
     | 
| 64 | 
         
            +
                        dim=-1, index=sorted_indices_BCxV, src=sorted_indices_to_remove_BCxV
         
     | 
| 65 | 
         
            +
                    )
         
     | 
| 66 | 
         
             
                    logits_BCxV = logits_BCxV.masked_fill(indices_to_remove_BCxV, -torch.inf)
         
     | 
| 67 | 
         | 
| 68 | 
         
             
                final_probs_BCxV = torch.softmax(logits_BCxV, dim=-1)
         
     | 
| 
         | 
|
| 72 | 
         
             
                return sampled_indices_C
         
     | 
| 73 | 
         | 
| 74 | 
         | 
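# Example (editor's illustration, not part of this commit): exercising _sample_next_token
# above on made-up logits. The 9 channels and vocabulary size of 1028 are assumptions chosen
# only for this example; torch is already imported at the top of the file.
example_logits_CxV = torch.randn(9, 1028)
example_next_C = _sample_next_token(
    example_logits_CxV,
    temperature=1.3,      # matches the generate() default shown elsewhere in this diff
    top_p=0.95,           # nucleus sampling keeps the smallest token set covering 95% mass
    cfg_filter_top_k=35,  # an arbitrary top-k cutoff applied before top-p
)
assert example_next_C.shape == (9,)  # one sampled code index per codebook channel
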
+class ComputeDtype(str, Enum):
+    FLOAT32 = "float32"
+    FLOAT16 = "float16"
+    BFLOAT16 = "bfloat16"
+
+    def to_dtype(self) -> torch.dtype:
+        if self == ComputeDtype.FLOAT32:
+            return torch.float32
+        elif self == ComputeDtype.FLOAT16:
+            return torch.float16
+        elif self == ComputeDtype.BFLOAT16:
+            return torch.bfloat16
+        else:
+            raise ValueError(f"Unsupported compute dtype: {self}")

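# Example (editor's illustration, not part of this commit): because ComputeDtype subclasses
# str, the enum accepts either a member or its string value, which is what lets callers pass
# compute_dtype="float16" to the constructors below.
assert ComputeDtype("bfloat16") is ComputeDtype.BFLOAT16
assert ComputeDtype.FLOAT16.to_dtype() == torch.float16
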
 class Dia:
+    def __init__(
+        self,
+        config: DiaConfig,
+        compute_dtype: str | ComputeDtype = ComputeDtype.FLOAT32,
+        device: torch.device | None = None,
+    ):
         """Initializes the Dia model.

         Args:
             config: The configuration object for the model.
+            device: The device to load the model onto. If None, will automatically select the best available device.

         Raises:
             RuntimeError: If there is an error loading the DAC model.
         """
         super().__init__()
         self.config = config
+        self.device = device if device is not None else _get_default_device()
+        if isinstance(compute_dtype, str):
+            compute_dtype = ComputeDtype(compute_dtype)
+        self.compute_dtype = compute_dtype.to_dtype()
+        self.model = DiaModel(config, self.compute_dtype)
         self.dac_model = None

     @classmethod
+    def from_local(
+        cls,
+        config_path: str,
+        checkpoint_path: str,
+        compute_dtype: str | ComputeDtype = ComputeDtype.FLOAT32,
+        device: torch.device | None = None,
+    ) -> "Dia":
         """Loads the Dia model from local configuration and checkpoint files.

         Args:
             config_path: Path to the configuration JSON file.
             checkpoint_path: Path to the model checkpoint (.pth) file.
+            device: The device to load the model onto. If None, will automatically select the best available device.

         Returns:
             An instance of the Dia model loaded with weights and set to eval mode.

         if config is None:
             raise FileNotFoundError(f"Config file not found at {config_path}")

+        dia = cls(config, compute_dtype, device)

         try:
+            state_dict = torch.load(checkpoint_path, map_location=dia.device)
+            dia.model.load_state_dict(state_dict)
         except FileNotFoundError:
             raise FileNotFoundError(f"Checkpoint file not found at {checkpoint_path}")
         except Exception as e:
+            raise RuntimeError(
+                f"Error loading checkpoint from {checkpoint_path}"
+            ) from e

+        dia.model.to(dia.device)
         dia.model.eval()
         dia._load_dac_model()
         return dia

     @classmethod
     def from_pretrained(
+        cls,
+        model_name: str = "nari-labs/Dia-1.6B",
+        compute_dtype: str | ComputeDtype = ComputeDtype.FLOAT32,
+        device: torch.device | None = None,
     ) -> "Dia":
         """Loads the Dia model from a Hugging Face Hub repository.

         Args:
             model_name: The Hugging Face Hub repository ID (e.g., "NariLabs/Dia-1.6B").
+            device: The device to load the model onto. If None, will automatically select the best available device.

         Returns:
             An instance of the Dia model loaded with weights and set to eval mode.

         """
         config_path = hf_hub_download(repo_id=model_name, filename="config.json")
         checkpoint_path = hf_hub_download(repo_id=model_name, filename="dia-v0_1.pth")
+        return cls.from_local(config_path, checkpoint_path, compute_dtype, device)

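    # Usage sketch (editor's illustration, not part of this commit). The repository id and the
    # compute_dtype/device parameters come from the from_pretrained signature above; the
    # generate() keyword arguments are assumed from the generate signature shown earlier in
    # this diff and may not cover the full new API.
    #
    #   model = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="float16")
    #   audio = model.generate(
    #       "An example sentence for Dia to speak.",
    #       cfg_scale=3.0,
    #       temperature=1.3,
    #       top_p=0.95,
    #   )  # returns a np.ndarray of waveform samples (DEFAULT_SAMPLE_RATE is 44100)
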
         | 
| 186 | 
         
             
                def _load_dac_model(self):
         
     | 
| 187 | 
         
             
                    try:
         
     | 
| 
         | 
|
| 191 | 
         
             
                        raise RuntimeError("Failed to load DAC model") from e
         
     | 
| 192 | 
         
             
                    self.dac_model = dac_model
         
     | 
| 193 | 
         | 
| 194 | 
         
            +
                def _prepare_text_input(self, text: str) -> torch.Tensor:
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 195 | 
         
             
                    """Encodes text prompt, pads, and creates attention mask and positions."""
         
     | 
| 196 | 
         
             
                    text_pad_value = self.config.data.text_pad_value
         
     | 
| 197 | 
         
             
                    max_len = self.config.data.text_length
         
     | 
| 
         | 
|
| 213 | 
         
             
                            constant_values=text_pad_value,
         
     | 
| 214 | 
         
             
                        ).astype(np.uint8)
         
     | 
| 215 | 
         | 
| 216 | 
         
            +
                    src_tokens = (
         
     | 
| 217 | 
         
            +
                        torch.from_numpy(padded_text_np).to(torch.long).to(self.device).unsqueeze(0)
         
     | 
| 218 | 
         
            +
                    )  # [1, S]
         
     | 
| 219 | 
         
            +
                    return src_tokens
         
     | 
| 220 | 
         | 
| 221 | 
         
            +
                def _prepare_audio_prompt(
         
     | 
| 222 | 
         
            +
                    self, audio_prompt: torch.Tensor | None
         
     | 
| 223 | 
         
            +
                ) -> tuple[torch.Tensor, int]:
         
     | 
| 224 | 
         
            +
                    num_channels = self.config.data.channels
         
     | 
| 225 | 
         
            +
                    audio_bos_value = self.config.data.audio_bos_value
         
     | 
| 226 | 
         
            +
                    audio_pad_value = self.config.data.audio_pad_value
         
     | 
| 227 | 
         
            +
                    delay_pattern = self.config.data.delay_pattern
         
     | 
| 228 | 
         
            +
                    max_delay_pattern = max(delay_pattern)
         
     | 
| 229 | 
         | 
| 230 | 
         
            +
                    prefill = torch.full(
         
     | 
| 231 | 
         
            +
                        (1, num_channels),
         
     | 
| 232 | 
         
            +
                        fill_value=audio_bos_value,
         
     | 
| 233 | 
         
            +
                        dtype=torch.int,
         
     | 
| 234 | 
         
            +
                        device=self.device,
         
     | 
| 235 | 
         
            +
                    )
         
     | 
| 236 | 
         | 
| 237 | 
         
            +
                    prefill_step = 1
         
     | 
| 238 | 
         
            +
             
     | 
| 239 | 
         
            +
                    if audio_prompt is not None:
         
     | 
| 240 | 
         
            +
                        prefill_step += audio_prompt.shape[0]
         
     | 
| 241 | 
         
            +
                        prefill = torch.cat([prefill, audio_prompt], dim=0)
         
     | 
| 242 | 
         
            +
             
     | 
| 243 | 
         
            +
                    delay_pad_tensor = torch.full(
         
     | 
| 244 | 
         
            +
                        (max_delay_pattern, num_channels),
         
     | 
| 245 | 
         
            +
                        fill_value=-1,
         
     | 
| 246 | 
         
            +
                        dtype=torch.int,
         
     | 
| 247 | 
         
            +
                        device=self.device,
         
     | 
| 248 | 
         
            +
                    )
         
     | 
| 249 | 
         
            +
                    prefill = torch.cat([prefill, delay_pad_tensor], dim=0)
         
     | 
| 250 | 
         
            +
             
     | 
| 251 | 
         
            +
                    delay_precomp = build_delay_indices(
         
     | 
| 252 | 
         
            +
                        B=1,
         
     | 
| 253 | 
         
            +
                        T=prefill.shape[0],
         
     | 
| 254 | 
         
            +
                        C=num_channels,
         
     | 
| 255 | 
         
            +
                        delay_pattern=delay_pattern,
         
     | 
| 256 | 
         
            +
                    )
         
     | 
| 257 | 
         
            +
             
     | 
| 258 | 
         
            +
                    prefill = apply_audio_delay(
         
     | 
| 259 | 
         
            +
                        audio_BxTxC=prefill.unsqueeze(0),
         
     | 
| 260 | 
         
            +
                        pad_value=audio_pad_value,
         
     | 
| 261 | 
         
            +
                        bos_value=audio_bos_value,
         
     | 
| 262 | 
         
            +
                        precomp=delay_precomp,
         
     | 
| 263 | 
         
            +
                    ).squeeze(0)
         
     | 
| 264 | 
         
            +
             
     | 
| 265 | 
         
            +
                    return prefill, prefill_step
         
     | 
| 266 | 
         
            +
             
     | 
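    # Worked example (editor's note, not part of this commit) of the prefill built above:
    # it is one BOS frame, then the optional prompt codes, then max(delay_pattern) pad rows,
    # so prefill.shape[0] == 1 + T_prompt + max_delay_pattern before the delay is applied.
    # With an assumed 100-frame prompt and an assumed max delay of 15 that is a
    # [116, num_channels] tensor, and prefill_step == 1 + 100 == 101 is the position later
    # handed to dec_output.prefill().
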
+    def _prepare_generation(
+        self, text: str, audio_prompt: str | torch.Tensor | None, verbose: bool
+    ):
+        enc_input_cond = self._prepare_text_input(text)
+        enc_input_uncond = torch.zeros_like(enc_input_cond)
+        enc_input = torch.cat([enc_input_uncond, enc_input_cond], dim=0)

+        if isinstance(audio_prompt, str):
+            audio_prompt = self.load_audio(audio_prompt)
+        prefill, prefill_step = self._prepare_audio_prompt(audio_prompt)

+        if verbose:
+            print("generate: data loaded")

+        enc_state = EncoderInferenceState.new(self.config, enc_input_cond)
+        encoder_out = self.model.encoder(enc_input, enc_state)

+        dec_cross_attn_cache = self.model.decoder.precompute_cross_attn_cache(
+            encoder_out, enc_state.positions
+        )
+        dec_state = DecoderInferenceState.new(
+            self.config,
+            enc_state,
+            encoder_out,
+            dec_cross_attn_cache,
+            self.compute_dtype,
+        )
+        dec_output = DecoderOutput.new(self.config, self.device)
+        dec_output.prefill(prefill, prefill_step)

+        dec_step = prefill_step - 1
+        if dec_step > 0:
+            dec_state.prepare_step(0, dec_step)
+            tokens_BxTxC = (
+                dec_output.get_tokens_at(0, dec_step).unsqueeze(0).expand(2, -1, -1)
+            )
+            self.model.decoder.forward(tokens_BxTxC, dec_state)

+        return dec_state, dec_output

+    def _decoder_step(
+        self,
+        tokens_Bx1xC: torch.Tensor,
+        dec_state: DecoderInferenceState,
+        cfg_scale: float,
+        temperature: float,
+        top_p: float,
+        cfg_filter_top_k: int,
+    ) -> torch.Tensor:
+        audio_eos_value = self.config.data.audio_eos_value
+        logits_Bx1xCxV = self.model.decoder.decode_step(tokens_Bx1xC, dec_state)

+        logits_last_BxCxV = logits_Bx1xCxV[:, -1, :, :]
+        uncond_logits_CxV = logits_last_BxCxV[0, :, :]
+        cond_logits_CxV = logits_last_BxCxV[1, :, :]

+        logits_CxV = cond_logits_CxV + cfg_scale * (cond_logits_CxV - uncond_logits_CxV)
+        logits_CxV[:, audio_eos_value + 1 :] = -torch.inf
+        logits_CxV[1:, audio_eos_value:] = -torch.inf

+        pred_C = _sample_next_token(
+            logits_CxV.float(),
+            temperature=temperature,
+            top_p=top_p,
+            cfg_filter_top_k=cfg_filter_top_k,
+        )
+        return pred_C

| 335 | 
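    # --- Editor's note (illustration only, not part of this commit) ----------------
    # _decoder_step applies classifier-free guidance before sampling:
    #     guided = cond + cfg_scale * (cond - uncond)
    # For example, with cond = [1, 1, 0], uncond = [0, 1, 2] and cfg_scale = 3.0 the
    # guided logits are [4, 1, -6]: tokens the text condition favors are boosted and
    # tokens only the unconditional branch favors are suppressed, after which every
    # id above audio_eos_value is masked to -inf and _sample_next_token runs.
    # --------------------------------------------------------------------------------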
    def _generate_output(self, generated_codes: torch.Tensor) -> np.ndarray:
        num_channels = self.config.data.channels
        seq_length = generated_codes.shape[0]
        delay_pattern = self.config.data.delay_pattern
        audio_pad_value = self.config.data.audio_pad_value
        max_delay_pattern = max(delay_pattern)

        revert_precomp = build_revert_indices(
            B=1,
            T=seq_length,
            C=num_channels,
            delay_pattern=delay_pattern,
        )

        codebook = revert_audio_delay(
            audio_BxTxC=generated_codes.unsqueeze(0),
            pad_value=audio_pad_value,
            precomp=revert_precomp,
            T=seq_length,
        )[:, :-max_delay_pattern, :]

        min_valid_index = 0
        max_valid_index = 1023
        invalid_mask = (codebook < min_valid_index) | (codebook > max_valid_index)
        codebook[invalid_mask] = 0

        audio = decode(self.dac_model, codebook.transpose(1, 2))

        return audio.squeeze().cpu().numpy()

    def load_audio(self, audio_path: str) -> torch.Tensor:
        audio, sr = torchaudio.load(audio_path, channels_first=True)  # C, T
        if sr != DEFAULT_SAMPLE_RATE:
            audio = torchaudio.functional.resample(audio, sr, DEFAULT_SAMPLE_RATE)
        audio = audio.to(self.device).unsqueeze(0)  # 1, C, T
        audio_data = self.dac_model.preprocess(audio, DEFAULT_SAMPLE_RATE)
        _, encoded_frame, _, _, _ = self.dac_model.encode(audio_data)  # 1, C, T
        return encoded_frame.squeeze(0).transpose(0, 1)

    def save_audio(self, path: str, audio: np.ndarray):
        import soundfile as sf

        sf.write(path, audio, DEFAULT_SAMPLE_RATE)

    @torch.inference_mode()
    def generate(
        # (unchanged parameters on lines 381-383 are collapsed in this diff view)
        cfg_scale: float = 3.0,
        temperature: float = 1.3,
        top_p: float = 0.95,
        use_torch_compile: bool = False,
        cfg_filter_top_k: int = 35,
        audio_prompt: str | torch.Tensor | None = None,
        audio_prompt_path: str | None = None,
        use_cfg_filter: bool | None = None,
        verbose: bool = False,
    ) -> np.ndarray:
        audio_eos_value = self.config.data.audio_eos_value
        audio_pad_value = self.config.data.audio_pad_value
        delay_pattern = self.config.data.delay_pattern
        max_tokens = self.config.data.audio_length if max_tokens is None else max_tokens
        max_delay_pattern = max(delay_pattern)
        self.model.eval()

        if audio_prompt_path:
            print("Warning: audio_prompt_path is deprecated. Use audio_prompt instead.")
            audio_prompt = audio_prompt_path
        if use_cfg_filter is not None:
            print("Warning: use_cfg_filter is deprecated.")

        if verbose:
            total_start_time = time.time()

        dec_state, dec_output = self._prepare_generation(text, audio_prompt, verbose)
        dec_step = dec_output.prefill_step - 1

        bos_countdown = max_delay_pattern
        eos_detected = False
        eos_countdown = -1

        if use_torch_compile:
            step_fn = torch.compile(self._decoder_step, mode="default")
        else:
            step_fn = self._decoder_step

        if verbose:
            print("generate: starting generation loop")
            if use_torch_compile:
                print(
                    "generate: with use_torch_compile=True, the first step is slow because the decode step is being compiled"
                )
            start_time = time.time()

        while dec_step < max_tokens:
            dec_state.prepare_step(dec_step)
            tokens_Bx1xC = (
                dec_output.get_tokens_at(dec_step).unsqueeze(0).expand(2, -1, -1)
            )
            pred_C = step_fn(
                tokens_Bx1xC,
                dec_state,
                cfg_scale,
                temperature,
                top_p,
                cfg_filter_top_k,
            )

            if (
                not eos_detected and pred_C[0] == audio_eos_value
            ) or dec_step == max_tokens - max_delay_pattern - 1:
                eos_detected = True
                eos_countdown = max_delay_pattern

            if eos_countdown > 0:
                step_after_eos = max_delay_pattern - eos_countdown
                for i, d in enumerate(delay_pattern):
                    if step_after_eos == d:
                        pred_C[i] = audio_eos_value
                    elif step_after_eos > d:
                        pred_C[i] = audio_pad_value
                eos_countdown -= 1

            bos_countdown = max(0, bos_countdown - 1)
            dec_output.update_one(pred_C, dec_step + 1, bos_countdown > 0)

            if eos_countdown == 0:
                break

            dec_step += 1
            if verbose and dec_step % 86 == 0:
                duration = time.time() - start_time
                print(
                    f"generate step {dec_step}: speed={86 / duration:.3f} tokens/s, realtime factor={1 / duration:.3f}x"
                )
                start_time = time.time()

        if dec_output.prefill_step >= dec_step + 1:
            print("Warning: Nothing generated")
            return None

        generated_codes = dec_output.generated_tokens[
            dec_output.prefill_step : dec_step + 1, :
        ]

        if verbose:
            total_step = dec_step + 1 - dec_output.prefill_step
            total_duration = time.time() - total_start_time
            print(
                f"generate: total step={total_step}, total duration={total_duration:.3f}s"
            )

        return self._generate_output(generated_codes)
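A minimal usage sketch of the faster inference path above, as an editor's illustration rather than part of the commit: the Dia.from_pretrained loader and the [S1]/[S2] dialogue tags are assumptions taken from the upstream project documentation, while generate, use_torch_compile, and save_audio are the methods shown in this diff.

from dia.model import Dia

# Editor's sketch; Dia.from_pretrained is assumed and does not appear in this diff.
model = Dia.from_pretrained("nari-labs/Dia-1.6B")
audio = model.generate(
    "[S1] Dia now compiles the decode step. [S2] The first call is slow, the rest are fast.",
    use_torch_compile=True,  # wraps _decoder_step in torch.compile
    cfg_scale=3.0,
    temperature=1.3,
    top_p=0.95,
    verbose=True,            # prints tokens/s and realtime factor every 86 steps
)
model.save_audio("example.wav", audio)  # writes a wav at DEFAULT_SAMPLE_RATE via soundfile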
    	
dia/state.py ADDED

@@ -0,0 +1,234 @@
     | 
from dataclasses import dataclass

import torch

from .config import DiaConfig


def create_attn_mask(
    q_padding_mask_1d: torch.Tensor,
    k_padding_mask_1d: torch.Tensor,
    device: torch.device,
    is_causal: bool = False,
) -> torch.Tensor:
    """
    Creates the attention mask (self or cross) mimicking JAX segment ID logic.
    """
    B1, Tq = q_padding_mask_1d.shape
    B2, Tk = k_padding_mask_1d.shape
    assert B1 == B2, "Query and key batch dimensions must match"

    p_mask_q = q_padding_mask_1d.unsqueeze(2)  # Shape [B, Tq, 1]
    p_mask_k = k_padding_mask_1d.unsqueeze(1)  # Shape [B, 1, Tk]

    # Condition A: Non-padding query attends to non-padding key
    non_pad_attends_non_pad = p_mask_q & p_mask_k  # Shape [B, Tq, Tk]

    # Condition B: Padding query attends to padding key
    pad_attends_pad = (~p_mask_q) & (~p_mask_k)  # Shape [B, Tq, Tk]

    # Combine: True if padding status is compatible (both non-pad OR both pad)
    mask = non_pad_attends_non_pad | pad_attends_pad  # Shape [B, Tq, Tk]

    if is_causal:
        assert Tq == Tk, (
            "Causal mask requires query and key sequence lengths to be equal"
        )
        causal_mask_2d = torch.tril(
            torch.ones((Tq, Tk), dtype=torch.bool, device=device)
        )  # Shape [Tq, Tk]
        causal_mask = mask & causal_mask_2d  # Shape [B, Tq, Tk]
        return causal_mask.unsqueeze(1)  # Shape [B, 1, Tq, Tk]
    else:
        return mask.unsqueeze(1)  # Shape [B, 1, Tq, Tk]
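# --- Editor's note (illustration only, not part of the committed file) ----------
# create_attn_mask broadcasts two 1-D padding masks into a [B, 1, Tq, Tk] boolean
# mask: a query position may attend to a key position only if both are real tokens
# or both are padding. A tiny check with made-up flags:
#
#   q_pad = torch.tensor([[True, True, False]])   # last query position is padding
#   k_pad = torch.tensor([[True, False, False]])  # keys 1 and 2 are padding
#   m = create_attn_mask(q_pad, k_pad, torch.device("cpu"))
#   m.shape   # torch.Size([1, 1, 3, 3])
#   m[0, 0]   # queries 0-1 attend only key 0; the padded query attends keys 1-2
# ---------------------------------------------------------------------------------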
@dataclass
class EncoderInferenceState:
    """Parameters specifically for encoder inference."""

    max_seq_len: int
    device: torch.device
    positions: torch.Tensor
    padding_mask: torch.Tensor
    attn_mask: torch.Tensor

    @classmethod
    def new(cls, config: DiaConfig, cond_src: torch.Tensor) -> "EncoderInferenceState":
        """Creates EncoderInferenceState from DiaConfig and the conditioned text input."""
        device = cond_src.device

        positions = (
            torch.arange(config.data.text_length, device=device)
            .to(torch.long)
            .unsqueeze(0)
            .expand(2, -1)
        )
        padding_mask = (cond_src != config.data.text_pad_value).to(device).expand(2, -1)
        attn_mask = create_attn_mask(
            padding_mask, padding_mask, device, is_causal=False
        )

        return cls(
            max_seq_len=config.data.text_length,
            device=device,
            positions=positions,
            padding_mask=padding_mask,
            attn_mask=attn_mask,
        )


class KVCache:
    def __init__(
        self,
        num_heads: int,
        max_len: int,
        head_dim: int,
        dtype: torch.dtype,
        device: torch.device,
        k: torch.Tensor | None = None,
        v: torch.Tensor | None = None,
    ):
        self.k = (
            torch.zeros((2, num_heads, max_len, head_dim), dtype=dtype, device=device)
            if k is None
            else k
        )
        self.v = (
            torch.zeros((2, num_heads, max_len, head_dim), dtype=dtype, device=device)
            if v is None
            else v
        )
        self.current_idx = torch.tensor(0)

    @classmethod
    def from_kv(cls, k: torch.Tensor, v: torch.Tensor) -> "KVCache":
        return cls(
            num_heads=k.shape[1],
            max_len=k.shape[2],
            head_dim=k.shape[3],
            dtype=k.dtype,
            device=k.device,
            k=k,
            v=v,
        )

    def update(
        self, k: torch.Tensor, v: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        self.k[:, :, self.current_idx : self.current_idx + 1, :] = k
        self.v[:, :, self.current_idx : self.current_idx + 1, :] = v
        self.current_idx += 1
        return self.k[:, :, : self.current_idx, :], self.v[:, :, : self.current_idx, :]

    def prefill(
        self, k: torch.Tensor, v: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        prefill_len = k.shape[2]
        self.k[:, :, :prefill_len, :] = k
        self.v[:, :, :prefill_len, :] = v
        self.current_idx = prefill_len - 1
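# --- Editor's sketch (illustration only, not part of the committed file) ---------
# KVCache pre-allocates [2, heads, max_len, head_dim] buffers (batch 2 holds the
# unconditional/conditional CFG pair) and is written in place: prefill() copies a
# prompt's keys and values, update() writes one step at current_idx and advances it.
# _kv_cache_demo is a hypothetical helper that exists only for this illustration.
def _kv_cache_demo() -> None:
    cache = KVCache(
        num_heads=4, max_len=16, head_dim=8,
        dtype=torch.float32, device=torch.device("cpu"),
    )
    k = torch.randn(2, 4, 5, 8)
    cache.prefill(k, k)                            # fills slots 0-4, current_idx -> 4
    k_step = torch.randn(2, 4, 1, 8)
    k_all, v_all = cache.update(k_step, k_step)    # writes slot 4, current_idx -> 5
    assert k_all.shape == (2, 4, 5, 8)             # returns the filled prefix
# ----------------------------------------------------------------------------------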
@dataclass
class DecoderInferenceState:
    """Parameters specifically for decoder inference."""

    device: torch.device
    dtype: torch.dtype
    enc_out: torch.Tensor
    enc_positions: torch.Tensor
    dec_positions: torch.Tensor
    dec_cross_attn_mask: torch.Tensor
    self_attn_cache: list[KVCache]
    cross_attn_cache: list[KVCache]

    @classmethod
    def new(
        cls,
        config: DiaConfig,
        enc_state: EncoderInferenceState,
        enc_out: torch.Tensor,
        dec_cross_attn_cache: list[KVCache],
        compute_dtype: torch.dtype,
    ) -> "DecoderInferenceState":
        """Creates DecoderInferenceState from DiaConfig and the encoder state."""
        device = enc_out.device
        max_audio_len = config.data.audio_length

        dec_positions = torch.full(
            (2, 1), fill_value=0, dtype=torch.long, device=device
        )
        tgt_padding_mask = torch.ones((2, 1), dtype=torch.bool, device=device)
        dec_cross_attn_mask = create_attn_mask(
            tgt_padding_mask, enc_state.padding_mask, device, is_causal=False
        )

        self_attn_cache = [
            KVCache(
                config.model.decoder.kv_heads,
                max_audio_len,
                config.model.decoder.gqa_head_dim,
                compute_dtype,
                device,
            )
            for _ in range(config.model.decoder.n_layer)
        ]

        return cls(
            device=device,
            dtype=compute_dtype,
            enc_out=enc_out,
            enc_positions=enc_state.positions,
            dec_positions=dec_positions,
            dec_cross_attn_mask=dec_cross_attn_mask,
            self_attn_cache=self_attn_cache,
            cross_attn_cache=dec_cross_attn_cache,
        )

    def prepare_step(self, step_from: int, step_to: int | None = None) -> None:
        if step_to is None:
            step_to = step_from + 1
        self.dec_positions = (
            torch.arange(step_from, step_to, device=self.device)
            .unsqueeze(0)
            .expand(2, -1)
        )


@dataclass
class DecoderOutput:
    generated_tokens: torch.Tensor
    prefill_step: int

    @classmethod
    def new(cls, config: DiaConfig, device: torch.device) -> "DecoderOutput":
        max_audio_len = config.data.audio_length
        return cls(
            generated_tokens=torch.full(
                (max_audio_len, config.data.channels),
                fill_value=-1,
                dtype=torch.int,
                device=device,
            ),
            prefill_step=0,
        )

    def get_tokens_at(self, step_from: int, step_to: int | None = None) -> torch.Tensor:
        if step_to is None:
            step_to = step_from + 1
        return self.generated_tokens[step_from:step_to, :]

    def update_one(self, dec_out: torch.Tensor, step: int, apply_mask: bool = False):
        if apply_mask:
            mask = self.generated_tokens[step : step + 1, :] == -1
            self.generated_tokens[step : step + 1, :] = torch.where(
                mask, dec_out, self.generated_tokens[step : step + 1, :]
            )
        else:
            self.generated_tokens[step : step + 1, :] = dec_out

    def prefill(self, dec_out: torch.Tensor, prefill_step: int):
        length = dec_out.shape[0]
        self.generated_tokens[0:length, :] = dec_out
        self.prefill_step = prefill_step
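# --- Editor's note (illustration only, not part of the committed file) -----------
# DecoderOutput starts as a (max_audio_len, channels) buffer of -1 placeholders.
# update_one(..., apply_mask=True) writes only into slots still holding -1, which
# is how the BOS countdown in Dia.generate avoids overwriting delayed-channel
# tokens copied in by prefill():
#
#   out = DecoderOutput.new(config, device)         # all -1
#   out.prefill(prompt_codes, prefill_step)         # copy the audio prompt codes
#   out.update_one(pred_C, step, apply_mask=True)   # fill only the remaining -1 slots
# ----------------------------------------------------------------------------------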