Upload 14 files

- app.py +561 -0
- qhash/autoencoder.py +26 -0
- qhash/backbone.py +50 -0
- qhash/codebook_pattern.py +12 -0
- qhash/conditioning.py +373 -0
- qhash/config.py +38 -0
- qhash/model.py +270 -0
- qhash/sampling.py +141 -0
- qhash/speaker_cloning.py +406 -0
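At its core, the added app.py (shown in full below) loads one of the two Qhash checkpoints through the uploaded qhash package and turns a conditioning dictionary into audio. The following is a minimal sketch distilled from the calls in app.py, not a verbatim excerpt: it assumes a CUDA device, the qhash modules from this commit on the path, and leaves all other conditioning values at their defaults; the hyperparameter values are the UI defaults from the app.

    # Minimal usage sketch, distilled from app.py below (CUDA device assumed).
    import torch
    from qhash.model import Zonos
    from qhash.conditioning import make_cond_dict

    model = Zonos.from_pretrained("Quantumhash/Qhash-v0.1-transformer", device="cuda")
    model.requires_grad_(False).eval()

    # Build conditioning; other keys (speaker, emotion, vqscore_8, ...) keep their defaults here.
    cond_dict = make_cond_dict(
        text="Qhash uses eSpeak for text to phoneme conversion!",
        language="en-us",
        device="cuda",
    )
    conditioning = model.prepare_conditioning(cond_dict)

    # UI defaults: cfg_scale=2.0, min_p=0.15, up to ~30 s of audio at 86 tokens/s.
    codes = model.generate(
        prefix_conditioning=conditioning,
        max_new_tokens=86 * 30,
        cfg_scale=2.0,
        batch_size=1,
        sampling_params=dict(min_p=0.15),
    )
    wav = model.autoencoder.decode(codes).cpu()  # decode codes back to a waveform tensor
    sr = model.autoencoder.sampling_rate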
 
    	
app.py  ADDED
@@ -0,0 +1,561 @@
import os
import shlex
import subprocess

subprocess.run(
    shlex.split("pip install flash-attn --no-build-isolation"),
    env=os.environ | {"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=True,
)
subprocess.run(
    shlex.split("pip install https://github.com/state-spaces/mamba/releases/download/v2.2.4/mamba_ssm-2.2.4+cu12torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"),
    check=True,
)
subprocess.run(
    shlex.split("pip install https://github.com/Dao-AILab/causal-conv1d/releases/download/v1.5.0.post8/causal_conv1d-1.5.0.post8+cu12torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"),
    check=True,
)

import spaces
import torch
import torchaudio
import gradio as gr
from os import getenv

from qhash.model import Zonos
from qhash.conditioning import make_cond_dict, supported_language_codes

device = "cuda"
MODEL_NAMES = ["Quantumhash/Qhash-v0.1-transformer", "Quantumhash/Qhash-v0.1-hybrid"]
MODELS = {name: Zonos.from_pretrained(name, device=device) for name in MODEL_NAMES}
for model in MODELS.values():
    model.requires_grad_(False).eval()


def update_ui(model_choice):
    """
    Dynamically show/hide UI elements based on the model's conditioners.
    We do NOT display 'language_id' or 'ctc_loss' even if they exist in the model.
    """
    model = MODELS[model_choice]
    cond_names = [c.name for c in model.prefix_conditioner.conditioners]
    print("Conditioners in this model:", cond_names)

    text_update = gr.update(visible=("espeak" in cond_names))
    language_update = gr.update(visible=("espeak" in cond_names))
    speaker_audio_update = gr.update(visible=("speaker" in cond_names))
    prefix_audio_update = gr.update(visible=True)
    emotion1_update = gr.update(visible=("emotion" in cond_names))
    emotion2_update = gr.update(visible=("emotion" in cond_names))
    emotion3_update = gr.update(visible=("emotion" in cond_names))
    emotion4_update = gr.update(visible=("emotion" in cond_names))
    emotion5_update = gr.update(visible=("emotion" in cond_names))
    emotion6_update = gr.update(visible=("emotion" in cond_names))
    emotion7_update = gr.update(visible=("emotion" in cond_names))
    emotion8_update = gr.update(visible=("emotion" in cond_names))
    vq_single_slider_update = gr.update(visible=("vqscore_8" in cond_names))
    fmax_slider_update = gr.update(visible=("fmax" in cond_names))
    pitch_std_slider_update = gr.update(visible=("pitch_std" in cond_names))
    speaking_rate_slider_update = gr.update(visible=("speaking_rate" in cond_names))
    dnsmos_slider_update = gr.update(visible=("dnsmos_ovrl" in cond_names))
    speaker_noised_checkbox_update = gr.update(visible=("speaker_noised" in cond_names))
    unconditional_keys_update = gr.update(
        choices=[name for name in cond_names if name not in ("espeak", "language_id")]
    )

    return (
        text_update,
        language_update,
        speaker_audio_update,
        prefix_audio_update,
        emotion1_update,
        emotion2_update,
        emotion3_update,
        emotion4_update,
        emotion5_update,
        emotion6_update,
        emotion7_update,
        emotion8_update,
        vq_single_slider_update,
        fmax_slider_update,
        pitch_std_slider_update,
        speaking_rate_slider_update,
        dnsmos_slider_update,
        speaker_noised_checkbox_update,
        unconditional_keys_update,
    )


@spaces.GPU(duration=120)
def generate_audio(
    model_choice,
    text,
    language,
    speaker_audio,
    prefix_audio,
    e1,
    e2,
    e3,
    e4,
    e5,
    e6,
    e7,
    e8,
    vq_single,
    fmax,
    pitch_std,
    speaking_rate,
    dnsmos_ovrl,
    speaker_noised,
    cfg_scale,
    min_p,
    seed,
    randomize_seed,
    unconditional_keys,
    progress=gr.Progress(),
):
    """
    Generates audio based on the provided UI parameters.
    We do NOT use language_id or ctc_loss even if the model has them.
    """
    selected_model = MODELS[model_choice]

    speaker_noised_bool = bool(speaker_noised)
    fmax = float(fmax)
    pitch_std = float(pitch_std)
    speaking_rate = float(speaking_rate)
    dnsmos_ovrl = float(dnsmos_ovrl)
    cfg_scale = float(cfg_scale)
    min_p = float(min_p)
    seed = int(seed)
    max_new_tokens = 86 * 30

    if randomize_seed:
        seed = torch.randint(0, 2**32 - 1, (1,)).item()
    torch.manual_seed(seed)

    speaker_embedding = None
    if speaker_audio is not None and "speaker" not in unconditional_keys:
        wav, sr = torchaudio.load(speaker_audio)
        speaker_embedding = selected_model.make_speaker_embedding(wav, sr)
        speaker_embedding = speaker_embedding.to(device, dtype=torch.bfloat16)

    audio_prefix_codes = None
    if prefix_audio is not None:
        wav_prefix, sr_prefix = torchaudio.load(prefix_audio)
        wav_prefix = wav_prefix.mean(0, keepdim=True)
        wav_prefix = torchaudio.functional.resample(wav_prefix, sr_prefix, selected_model.autoencoder.sampling_rate)
        wav_prefix = wav_prefix.to(device, dtype=torch.float32)
        with torch.autocast(device, dtype=torch.float32):
            audio_prefix_codes = selected_model.autoencoder.encode(wav_prefix.unsqueeze(0))

    emotion_tensor = torch.tensor(list(map(float, [e1, e2, e3, e4, e5, e6, e7, e8])), device=device)

    vq_val = float(vq_single)
    vq_tensor = torch.tensor([vq_val] * 8, device=device).unsqueeze(0)

    cond_dict = make_cond_dict(
        text=text,
        language=language,
        speaker=speaker_embedding,
        emotion=emotion_tensor,
        vqscore_8=vq_tensor,
        fmax=fmax,
        pitch_std=pitch_std,
        speaking_rate=speaking_rate,
        dnsmos_ovrl=dnsmos_ovrl,
        speaker_noised=speaker_noised_bool,
        device=device,
        unconditional_keys=unconditional_keys,
    )
    conditioning = selected_model.prepare_conditioning(cond_dict)

    estimated_generation_duration = 30 * len(text) / 400
    estimated_total_steps = int(estimated_generation_duration * 86)

    def update_progress(_frame: torch.Tensor, step: int, _total_steps: int) -> bool:
        progress((step, estimated_total_steps))
        return True

    codes = selected_model.generate(
        prefix_conditioning=conditioning,
        audio_prefix_codes=audio_prefix_codes,
        max_new_tokens=max_new_tokens,
        cfg_scale=cfg_scale,
        batch_size=1,
        sampling_params=dict(min_p=min_p),
        callback=update_progress,
    )

    wav_out = selected_model.autoencoder.decode(codes).cpu().detach()
    sr_out = selected_model.autoencoder.sampling_rate
    if wav_out.dim() == 2 and wav_out.size(0) > 1:
        wav_out = wav_out[0:1, :]
    return (sr_out, wav_out.squeeze().numpy()), seed


# Custom CSS for pastel gradient background and enhanced UI
custom_css = """
.gradio-container {
    background: #0C101B;
    background-size: 400% 400%;
    animation: gradient 15s ease infinite;
}

@keyframes gradient {
    0% {
        background-position: 0% 50%;
    }
    50% {
        background-position: 100% 50%;
    }
    100% {
        background-position: 0% 50%;
    }
}

.container {
    max-width: 1200px;
    margin: 0 auto;
    padding: 20px;
}

.panel {
    background-color: rgba(159, 153, 96, 0.9);
    border-radius: 16px;
    padding: 20px;
    box-shadow: 0 4px 12px rgba(0, 0, 0, 0.08);
    margin-bottom: 16px;
    backdrop-filter: blur(5px);
    transition: all 0.3s ease;
}

.panel p {
    font-size: 1.1em;
    color: black;
}
.panel:hover {
    box-shadow: 0 6px 16px rgba(0, 0, 0, 0.12);
    transform: translateY(-2px);
}

.title {
    font-size: 1.2em;
    font-weight: 600;
    margin-bottom: 12px;
    color: #6a3ea1;
    border-bottom: 2px solid #f0e6ff;
    padding-bottom: 8px;
}

.slider-container {
    background-color: rgba(255, 255, 255, 0.5);
    border-radius: 10px;
    padding: 10px;
    margin: 5px 0;
}

/* Make sliders more appealing */
input[type=range] {
    height: 5px;
    appearance: none;
    width: 100%;
    border-radius: 3px;
    background: linear-gradient(90deg, #9c83e0, #83b1e0);
}

.generate-button {
    background: linear-gradient(90deg, #a673ff, #7c4dff);
    color: white;
    border: none;
    border-radius: 8px;
    padding: 12px 24px;
    font-size: 16px;
    font-weight: 500;
    cursor: pointer;
    transition: all 0.3s ease;
    box-shadow: 0 4px 10px rgba(124, 77, 255, 0.2);
    display: block;
    width: 100%;
    margin: 20px 0;
}

.generate-button:hover {
    background: linear-gradient(90deg, #9c5eff, #6a3aff);
    box-shadow: 0 6px 15px rgba(124, 77, 255, 0.3);
    transform: translateY(-2px);
}

/* Tabs styling */
.tabs {
    display: flex;
    border-bottom: 1px solid #e0e0e0;
    margin-bottom: 20px;
}

.tab {
    padding: 10px 20px;
    cursor: pointer;
    transition: all 0.3s ease;
    background-color: transparent;
    border: none;
    color: #666;
}

.tab.active {
    color: #7c4dff;
    border-bottom: 3px solid #7c4dff;
    font-weight: 600;
}

/* Emotion sliders container */
.emotion-grid {
    display: grid;
    grid-template-columns: repeat(4, 1fr);
    gap: 12px;
}

/* Header styling */
.app-header {
    text-align: center;
    margin-bottom: 25px;
}

.app-header h1 {
    font-size: 2.5em;
    color: #6a3ea1;
    margin-bottom: 8px;
    font-weight: 700;
}

.app-header p {
    font-size: 1.1em;
    color: #6a3ea1;
    margin-bottom: 20px;
}

/* Audio player styling */
.audio-output {
    margin-top: 20px;
}

/* Make output area more prominent */
.output-container {
    background-color: rgba(24, 82, 79, 0.85);
    border-radius: 16px;
    padding: 24px;
    box-shadow: 0 8px 18px rgba(0, 0, 0, 0.1);
    margin-top: 20px;
}
"""


def build_interface():
    # Build interface with enhanced visual elements and layout
    with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
        # Header section
        with gr.Column(elem_classes="app-header"):
            gr.Markdown("# ✨ Qhash Text-to-Speech Clone ✨")
            gr.Markdown("Create natural-sounding speech with customizable voice characteristics")

        # Main content container
        with gr.Column(elem_classes="container"):
            # First panel - Text & Model Selection
            with gr.Column(elem_classes="panel"):
                gr.Markdown("💬 Text & Model Configuration")
                with gr.Row():
                    with gr.Column(scale=2):
                        model_choice = gr.Dropdown(
                            choices=MODEL_NAMES,
                            value="Quantumhash/Qhash-v0.1-transformer",
                            label="Qhash Model Type",
                            info="Select the model variant to use.",
                        )
                        text = gr.Textbox(
                            label="Text to Synthesize",
                            value="Qhash uses eSpeak for text to phoneme conversion!",
                            lines=4,
                            max_length=500,
                        )
                        language = gr.Dropdown(
                            choices=supported_language_codes,
                            value="en-us",
                            label="Language Code",
                            info="Select a language code.",
                        )
                    with gr.Column(scale=1):
                        prefix_audio = gr.Audio(
                            value="assets/silence_100ms.wav",
                            label="Optional Prefix Audio (continue from this audio)",
                            type="filepath",
                        )

            # Second panel - Voice Characteristics
            with gr.Column(elem_classes="panel"):
                gr.Markdown("🎤 Voice Characteristics")
                with gr.Row():
                    with gr.Column(scale=1):
                        speaker_audio = gr.Audio(
                            label="Optional Speaker Audio (for voice cloning)",
                            type="filepath",
                        )
                        speaker_noised_checkbox = gr.Checkbox(label="Denoise Speaker?", value=False)

                    with gr.Column(scale=2):
                        with gr.Row():
                            with gr.Column():
                                dnsmos_slider = gr.Slider(1.0, 5.0, value=4.0, step=0.1, label="Voice Quality", elem_classes="slider-container")
                                fmax_slider = gr.Slider(0, 24000, value=24000, step=1, label="Frequency Max (Hz)", elem_classes="slider-container")
                                vq_single_slider = gr.Slider(0.5, 0.8, 0.78, 0.01, label="Voice Clarity", elem_classes="slider-container")
                            with gr.Column():
                                pitch_std_slider = gr.Slider(0.0, 300.0, value=45.0, step=1, label="Pitch Variation", elem_classes="slider-container")
                                speaking_rate_slider = gr.Slider(5.0, 30.0, value=15.0, step=0.5, label="Speaking Rate", elem_classes="slider-container")

            # Third panel - Generation Parameters
            with gr.Column(elem_classes="panel"):
                gr.Markdown("⚙️ Generation Parameters")
                with gr.Row():
                    with gr.Column():
                        cfg_scale_slider = gr.Slider(1.0, 5.0, 2.0, 0.1, label="Guidance Scale", elem_classes="slider-container")
                        min_p_slider = gr.Slider(0.0, 1.0, 0.15, 0.01, label="Min P (Randomness)", elem_classes="slider-container")
                    with gr.Column():
                        seed_number = gr.Number(label="Seed", value=420, precision=0)
                        randomize_seed_toggle = gr.Checkbox(label="Randomize Seed (before generation)", value=True)

            # Emotion Panel with Tabbed Interface
            with gr.Accordion("🎭 Emotion Settings", open=False, elem_classes="panel"):
                gr.Markdown(
                    "Adjust these sliders to control the emotional tone of the generated speech.\n"
                    "For a neutral voice, keep 'Neutral' high and other emotions low."
                )
                with gr.Row(elem_classes="emotion-grid"):
                    emotion1 = gr.Slider(0.0, 1.0, 1.0, 0.05, label="Happiness", elem_classes="slider-container")
                    emotion2 = gr.Slider(0.0, 1.0, 0.05, 0.05, label="Sadness", elem_classes="slider-container")
                    emotion3 = gr.Slider(0.0, 1.0, 0.05, 0.05, label="Disgust", elem_classes="slider-container")
                    emotion4 = gr.Slider(0.0, 1.0, 0.05, 0.05, label="Fear", elem_classes="slider-container")
                with gr.Row(elem_classes="emotion-grid"):
                    emotion5 = gr.Slider(0.0, 1.0, 0.05, 0.05, label="Surprise", elem_classes="slider-container")
                    emotion6 = gr.Slider(0.0, 1.0, 0.05, 0.05, label="Anger", elem_classes="slider-container")
                    emotion7 = gr.Slider(0.0, 1.0, 0.1, 0.05, label="Other", elem_classes="slider-container")
                    emotion8 = gr.Slider(0.0, 1.0, 0.2, 0.05, label="Neutral", elem_classes="slider-container")

            # Advanced Settings Panel
            with gr.Accordion("⚡ Advanced Settings", open=False, elem_classes="panel"):
                gr.Markdown(
                    "### Unconditional Toggles\n"
                    "Checking a box will make the model ignore the corresponding conditioning value and make it unconditional.\n"
                    'Practically this means the given conditioning feature will be unconstrained and "filled in automatically".'
                )
                unconditional_keys = gr.CheckboxGroup(
                    [
                        "speaker",
                        "emotion",
                        "vqscore_8",
                        "fmax",
                        "pitch_std",
                        "speaking_rate",
                        "dnsmos_ovrl",
                        "speaker_noised",
                    ],
                    value=["emotion"],
                    label="Unconditional Keys",
                )

            # Generate Button and Output Area
            with gr.Column(elem_classes="panel output-container"):
                gr.Markdown("🔊 Generate & Output")
                generate_button = gr.Button("Generate Audio", elem_classes="generate-button")
                output_audio = gr.Audio(label="Generated Audio", type="numpy", autoplay=True, elem_classes="audio-output")

        model_choice.change(
            fn=update_ui,
            inputs=[model_choice],
            outputs=[
                text,
                language,
                speaker_audio,
                prefix_audio,
                emotion1,
                emotion2,
                emotion3,
                emotion4,
                emotion5,
                emotion6,
                emotion7,
                emotion8,
                vq_single_slider,
                fmax_slider,
                pitch_std_slider,
                speaking_rate_slider,
                dnsmos_slider,
                speaker_noised_checkbox,
                unconditional_keys,
            ],
        )

        # On page load, trigger the same UI refresh
        demo.load(
            fn=update_ui,
            inputs=[model_choice],
            outputs=[
                text,
                language,
                speaker_audio,
                prefix_audio,
                emotion1,
                emotion2,
                emotion3,
                emotion4,
                emotion5,
                emotion6,
                emotion7,
                emotion8,
                vq_single_slider,
                fmax_slider,
                pitch_std_slider,
                speaking_rate_slider,
                dnsmos_slider,
                speaker_noised_checkbox,
                unconditional_keys,
            ],
        )

        # Generate audio on button click
         
     | 
| 524 | 
         
            +
                    generate_button.click(
         
     | 
| 525 | 
         
            +
                        fn=generate_audio,
         
     | 
| 526 | 
         
            +
                        inputs=[
         
     | 
| 527 | 
         
            +
                            model_choice,
         
     | 
| 528 | 
         
            +
                            text,
         
     | 
| 529 | 
         
            +
                            language,
         
     | 
| 530 | 
         
            +
                            speaker_audio,
         
     | 
| 531 | 
         
            +
                            prefix_audio,
         
     | 
| 532 | 
         
            +
                            emotion1,
         
     | 
| 533 | 
         
            +
                            emotion2,
         
     | 
| 534 | 
         
            +
                            emotion3,
         
     | 
| 535 | 
         
            +
                            emotion4,
         
     | 
| 536 | 
         
            +
                            emotion5,
         
     | 
| 537 | 
         
            +
                            emotion6,
         
     | 
| 538 | 
         
            +
                            emotion7,
         
     | 
| 539 | 
         
            +
                            emotion8,
         
     | 
| 540 | 
         
            +
                            vq_single_slider,
         
     | 
| 541 | 
         
            +
                            fmax_slider,
         
     | 
| 542 | 
         
            +
                            pitch_std_slider,
         
     | 
| 543 | 
         
            +
                            speaking_rate_slider,
         
     | 
| 544 | 
         
            +
                            dnsmos_slider,
         
     | 
| 545 | 
         
            +
                            speaker_noised_checkbox,
         
     | 
| 546 | 
         
            +
                            cfg_scale_slider,
         
     | 
| 547 | 
         
            +
                            min_p_slider,
         
     | 
| 548 | 
         
            +
                            seed_number,
         
     | 
| 549 | 
         
            +
                            randomize_seed_toggle,
         
     | 
| 550 | 
         
            +
                            unconditional_keys,
         
     | 
| 551 | 
         
            +
                        ],
         
     | 
| 552 | 
         
            +
                        outputs=[output_audio, seed_number],
         
     | 
| 553 | 
         
            +
                    )
         
     | 
| 554 | 
         
            +
             
     | 
| 555 | 
         
            +
                return demo
         
     | 
| 556 | 
         
            +
             
     | 
| 557 | 
         
            +
             
     | 
| 558 | 
         
            +
            if __name__ == "__main__":
         
     | 
| 559 | 
         
            +
                demo = build_interface()
         
     | 
| 560 | 
         
            +
                share = getenv("GRADIO_SHARE", "False").lower() in ("true", "1", "t")
         
     | 
| 561 | 
         
            +
                demo.launch(server_name="0.0.0.0", server_port=7860, share=share, mcp_server=True)
         
     | 
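Note (not part of the uploaded files): demo.launch binds the app to 0.0.0.0:7860; setting the GRADIO_SHARE environment variable to "true", "1", or "t" (case-insensitive) enables a public share link, and mcp_server=True additionally exposes the app's endpoints through Gradio's MCP server support.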
    	
qhash/autoencoder.py
ADDED
@@ -0,0 +1,26 @@
import math

import torch
import torchaudio
from transformers.models.dac import DacModel


class DACAutoencoder:
    def __init__(self):
        super().__init__()
        self.dac = DacModel.from_pretrained("Quantumhash/dac_44khz")
        self.dac.eval().requires_grad_(False)
        self.codebook_size = self.dac.config.codebook_size
        self.num_codebooks = self.dac.quantizer.n_codebooks
        self.sampling_rate = self.dac.config.sampling_rate

    def preprocess(self, wav: torch.Tensor, sr: int) -> torch.Tensor:
        wav = torchaudio.functional.resample(wav, sr, 44_100)
        right_pad = math.ceil(wav.shape[-1] / 512) * 512 - wav.shape[-1]
        return torch.nn.functional.pad(wav, (0, right_pad))

    def encode(self, wav: torch.Tensor) -> torch.Tensor:
        return self.dac.encode(wav).audio_codes

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        return self.dac.decode(audio_codes=codes).audio_values.unsqueeze(1)
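Usage sketch (not part of the uploaded files) for DACAutoencoder above, assuming the Quantumhash/dac_44khz checkpoint is downloadable: preprocess resamples to 44.1 kHz and right-pads to a multiple of 512 samples, encode produces discrete codes, and decode reconstructs audio.

import torch

from qhash.autoencoder import DACAutoencoder

dac = DACAutoencoder()
wav = torch.zeros(1, 1, 16_000)        # one second of silence at 16 kHz: [batch, channels, samples]
wav = dac.preprocess(wav, sr=16_000)   # resample to 44.1 kHz and pad to a multiple of 512
codes = dac.encode(wav)                # discrete codes: [batch, num_codebooks, frames]
audio = dac.decode(codes)              # reconstructed waveform: [batch, 1, samples] at 44.1 kHz
print(codes.shape, audio.shape)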
    	
qhash/backbone.py
ADDED
@@ -0,0 +1,50 @@
import torch
import torch.nn as nn
from mamba_ssm.models.mixer_seq_simple import create_block
from mamba_ssm.ops.triton.layer_norm import layer_norm_fn
from mamba_ssm.utils.generation import InferenceParams

from qhash.config import BackboneConfig


class ZonosBackbone(nn.Module):
    def __init__(self, config: BackboneConfig):
        super().__init__()
        self.config = config

        self.layers = nn.ModuleList(
            [
                create_block(
                    d_model=config.d_model,
                    d_intermediate=config.d_intermediate
                    if (i not in config.attn_layer_idx)
                    else config.attn_mlp_d_intermediate,
                    ssm_cfg=config.ssm_cfg,
                    layer_idx=i,
                    attn_layer_idx=config.attn_layer_idx,
                    attn_cfg=config.attn_cfg,
                    norm_epsilon=config.norm_epsilon,
                    residual_in_fp32=config.residual_in_fp32,
                    fused_add_norm=True,
                    rms_norm=config.rms_norm,
                )
                for i in range(config.n_layer)
            ]
        )

        self.norm_f = nn.LayerNorm(config.d_model, eps=config.norm_epsilon)

    def forward(self, hidden_states: torch.Tensor, inference_params: InferenceParams | None = None):
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(hidden_states, residual, inference_params)

        return layer_norm_fn(
            hidden_states,
            self.norm_f.weight,
            self.norm_f.bias,
            residual,
            eps=self.norm_f.eps,
            residual_in_fp32=self.config.residual_in_fp32,
            is_rms_norm=self.config.rms_norm,
        )
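Forward-pass sketch (not part of the uploaded files) for ZonosBackbone above. It assumes BackboneConfig in qhash/config.py provides usable defaults and that a CUDA device with mamba-ssm and its Triton kernels is available, since the fused add-norm path is GPU-only.

import torch

from qhash.backbone import ZonosBackbone
from qhash.config import BackboneConfig

config = BackboneConfig()                      # assumption: the config dataclass defines default values
backbone = ZonosBackbone(config).cuda().eval()

hidden = torch.randn(1, 16, config.d_model, device="cuda")  # [batch, seq_len, d_model]
with torch.no_grad():
    out = backbone(hidden)                     # normalized hidden states, same shape as the input
print(out.shape)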
    	
qhash/codebook_pattern.py
ADDED
@@ -0,0 +1,12 @@
import torch
import torch.nn.functional as F


def apply_delay_pattern(codes: torch.Tensor, mask_token: int):
    codes = F.pad(codes, (0, codes.shape[1]), value=mask_token)
    return torch.stack([codes[:, k].roll(k + 1) for k in range(codes.shape[1])], dim=1)


def revert_delay_pattern(codes: torch.Tensor):
    _, n_q, seq_len = codes.shape
    return torch.stack([codes[:, k, k + 1 : seq_len - n_q + k + 1] for k in range(n_q)], dim=1)
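Round-trip sketch (not part of the uploaded files) for the two helpers above: apply_delay_pattern shifts the k-th codebook right by k+1 steps and pads with a mask token (a MusicGen-style delay pattern), and revert_delay_pattern realigns the codebooks and drops the padding, recovering the original codes.

import torch

from qhash.codebook_pattern import apply_delay_pattern, revert_delay_pattern

mask_token = 1024                                  # assumed mask id; the model supplies its own value
codes = torch.randint(0, 1024, (1, 9, 4))          # [batch, n_codebooks, seq_len]
delayed = apply_delay_pattern(codes, mask_token)   # [1, 9, 4 + 9]
restored = revert_delay_pattern(delayed)           # [1, 9, 4]
assert torch.equal(restored, codes)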
    	
        qhash/conditioning.py
    ADDED
    
    | 
         @@ -0,0 +1,373 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            from functools import cache
         
     | 
| 2 | 
         
            +
            from typing import Any, Literal, Iterable
         
     | 
| 3 | 
         
            +
             
     | 
| 4 | 
         
            +
            import torch
         
     | 
| 5 | 
         
            +
            import torch.nn as nn
         
     | 
| 6 | 
         
            +
             
     | 
| 7 | 
         
            +
            from qhash.config import PrefixConditionerConfig
         
     | 
| 8 | 
         
            +
             
     | 
| 9 | 
         
            +
             
     | 
| 10 | 
         
            +
            class Conditioner(nn.Module):
         
     | 
| 11 | 
         
            +
                def __init__(
         
     | 
| 12 | 
         
            +
                    self,
         
     | 
| 13 | 
         
            +
                    output_dim: int,
         
     | 
| 14 | 
         
            +
                    name: str,
         
     | 
| 15 | 
         
            +
                    cond_dim: int | None = None,
         
     | 
| 16 | 
         
            +
                    projection: Literal["none", "linear", "mlp"] = "none",
         
     | 
| 17 | 
         
            +
                    uncond_type: Literal["learned", "none"] = "none",
         
     | 
| 18 | 
         
            +
                    **kwargs,
         
     | 
| 19 | 
         
            +
                ):
         
     | 
| 20 | 
         
            +
                    super().__init__()
         
     | 
| 21 | 
         
            +
                    self.name = name
         
     | 
| 22 | 
         
            +
                    self.output_dim = output_dim
         
     | 
| 23 | 
         
            +
                    self.cond_dim = cond_dim = cond_dim or output_dim
         
     | 
| 24 | 
         
            +
             
     | 
| 25 | 
         
            +
                    if projection == "linear":
         
     | 
| 26 | 
         
            +
                        self.project = nn.Linear(cond_dim, output_dim)
         
     | 
| 27 | 
         
            +
                    elif projection == "mlp":
         
     | 
| 28 | 
         
            +
                        self.project = nn.Sequential(
         
     | 
| 29 | 
         
            +
                            nn.Linear(cond_dim, output_dim),
         
     | 
| 30 | 
         
            +
                            nn.SiLU(),
         
     | 
| 31 | 
         
            +
                            nn.Linear(output_dim, output_dim),
         
     | 
| 32 | 
         
            +
                        )
         
     | 
| 33 | 
         
            +
                    else:
         
     | 
| 34 | 
         
            +
                        self.project = nn.Identity()
         
     | 
| 35 | 
         
            +
             
     | 
| 36 | 
         
            +
                    self.uncond_vector = None
         
     | 
| 37 | 
         
            +
                    if uncond_type == "learned":
         
     | 
| 38 | 
         
            +
                        self.uncond_vector = nn.Parameter(torch.zeros(output_dim))
         
     | 
| 39 | 
         
            +
             
     | 
| 40 | 
         
            +
                def apply_cond(self, *inputs: Any) -> torch.Tensor:
         
     | 
| 41 | 
         
            +
                    raise NotImplementedError()
         
     | 
| 42 | 
         
            +
             
     | 
| 43 | 
         
            +
                def forward(self, inputs: tuple[Any, ...] | None) -> torch.Tensor:
         
     | 
| 44 | 
         
            +
                    if inputs is None:
         
     | 
| 45 | 
         
            +
                        assert self.uncond_vector is not None
         
     | 
| 46 | 
         
            +
                        return self.uncond_vector.data.view(1, 1, -1)
         
     | 
| 47 | 
         
            +
             
     | 
| 48 | 
         
            +
                    cond = self.apply_cond(*inputs)
         
     | 
| 49 | 
         
            +
                    cond = self.project(cond)
         
     | 
| 50 | 
         
            +
                    return cond
         
     | 
| 51 | 
         
            +
             
     | 
| 52 | 
         
            +
             
     | 
| 53 | 
         
            +
            # ------- ESPEAK CONTAINMENT ZONE ------------------------------------------------------------------------------------------------------------------------------------------------
         
     | 
| 54 | 
         
            +
            import re
         
     | 
| 55 | 
         
            +
            import unicodedata
         
     | 
| 56 | 
         
            +
             
     | 
| 57 | 
         
            +
            import inflect
         
     | 
| 58 | 
         
            +
            import torch
         
     | 
| 59 | 
         
            +
            import torch.nn as nn
         
     | 
| 60 | 
         
            +
            from kanjize import number2kanji
         
     | 
| 61 | 
         
            +
            from phonemizer.backend import EspeakBackend
         
     | 
| 62 | 
         
            +
            from sudachipy import Dictionary, SplitMode
         
     | 
| 63 | 
         
            +
             
     | 
| 64 | 
         
            +
            # --- Number normalization code from https://github.com/daniilrobnikov/vits2/blob/main/text/normalize_numbers.py ---
         
     | 
| 65 | 
         
            +
             
     | 
| 66 | 
         
            +
            _inflect = inflect.engine()
         
     | 
| 67 | 
         
            +
            _comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
         
     | 
| 68 | 
         
            +
            _decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
         
     | 
| 69 | 
         
            +
            _pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
         
     | 
| 70 | 
         
            +
            _dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
         
     | 
| 71 | 
         
            +
            _ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
         
     | 
| 72 | 
         
            +
            _number_re = re.compile(r"[0-9]+")
         
     | 
| 73 | 
         
            +
             
     | 
| 74 | 
         
            +
             
     | 
| 75 | 
         
            +
            def _remove_commas(m: re.Match) -> str:
         
     | 
| 76 | 
         
            +
                return m.group(1).replace(",", "")
         
     | 
| 77 | 
         
            +
             
     | 
| 78 | 
         
            +
             
     | 
| 79 | 
         
            +
            def _expand_decimal_point(m: re.Match) -> str:
         
     | 
| 80 | 
         
            +
                return m.group(1).replace(".", " point ")
         
     | 
| 81 | 
         
            +
             
     | 
| 82 | 
         
            +
             
     | 
| 83 | 
         
            +
            def _expand_dollars(m: re.Match) -> str:
         
     | 
| 84 | 
         
            +
                match = m.group(1)
         
     | 
| 85 | 
         
            +
                parts = match.split(".")
         
     | 
| 86 | 
         
            +
                if len(parts) > 2:
         
     | 
| 87 | 
         
            +
                    return match + " dollars"  # Unexpected format
         
     | 
| 88 | 
         
            +
                dollars = int(parts[0]) if parts[0] else 0
         
     | 
| 89 | 
         
            +
                cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
         
     | 
| 90 | 
         
            +
                if dollars and cents:
         
     | 
| 91 | 
         
            +
                    dollar_unit = "dollar" if dollars == 1 else "dollars"
         
     | 
| 92 | 
         
            +
                    cent_unit = "cent" if cents == 1 else "cents"
         
     | 
| 93 | 
         
            +
                    return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit)
         
     | 
| 94 | 
         
            +
                elif dollars:
         
     | 
| 95 | 
         
            +
                    dollar_unit = "dollar" if dollars == 1 else "dollars"
         
     | 
| 96 | 
         
            +
                    return "%s %s" % (dollars, dollar_unit)
         
     | 
| 97 | 
         
            +
                elif cents:
         
     | 
| 98 | 
         
            +
                    cent_unit = "cent" if cents == 1 else "cents"
         
     | 
| 99 | 
         
            +
                    return "%s %s" % (cents, cent_unit)
         
     | 
| 100 | 
         
            +
                else:
         
     | 
| 101 | 
         
            +
                    return "zero dollars"
         
     | 
| 102 | 
         
            +
             
     | 
| 103 | 
         
            +
             
     | 
| 104 | 
         
            +
            def _expand_ordinal(m: re.Match) -> str:
         
     | 
| 105 | 
         
            +
                return _inflect.number_to_words(m.group(0))
         
     | 
| 106 | 
         
            +
             
     | 
| 107 | 
         
            +
             
     | 
| 108 | 
         
            +
            def _expand_number(m: re.Match) -> str:
         
     | 
| 109 | 
         
            +
                num = int(m.group(0))
         
     | 
| 110 | 
         
            +
                if num > 1000 and num < 3000:
         
     | 
| 111 | 
         
            +
                    if num == 2000:
         
     | 
| 112 | 
         
            +
                        return "two thousand"
         
     | 
| 113 | 
         
            +
                    elif num > 2000 and num < 2010:
         
     | 
| 114 | 
         
            +
                        return "two thousand " + _inflect.number_to_words(num % 100)
         
     | 
| 115 | 
         
            +
                    elif num % 100 == 0:
         
     | 
| 116 | 
         
            +
                        return _inflect.number_to_words(num // 100) + " hundred"
         
     | 
| 117 | 
         
            +
                    else:
         
     | 
| 118 | 
         
            +
                        return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ")
         
     | 
| 119 | 
         
            +
                else:
         
     | 
| 120 | 
         
            +
                    return _inflect.number_to_words(num, andword="")
         
     | 
| 121 | 
         
            +
             
     | 
| 122 | 
         
            +
             
     | 
| 123 | 
         
            +
            def normalize_numbers(text: str) -> str:
         
     | 
| 124 | 
         
            +
                text = re.sub(_comma_number_re, _remove_commas, text)
         
     | 
| 125 | 
         
            +
                text = re.sub(_pounds_re, r"\1 pounds", text)
         
     | 
| 126 | 
         
            +
                text = re.sub(_dollars_re, _expand_dollars, text)
         
     | 
| 127 | 
         
            +
                text = re.sub(_decimal_number_re, _expand_decimal_point, text)
         
     | 
| 128 | 
         
            +
                text = re.sub(_ordinal_re, _expand_ordinal, text)
         
     | 
| 129 | 
         
            +
                text = re.sub(_number_re, _expand_number, text)
         
     | 
| 130 | 
         
            +
                return text
         
     | 
| 131 | 
         
            +
             
     | 
| 132 | 
         
            +
             
     | 
| 133 | 
         
            +
            # --- Number normalization code end ---
         
     | 
| 134 | 
         
            +
             
     | 
| 135 | 
         
            +
             
     | 
| 136 | 
         
            +
            PAD_ID, UNK_ID, BOS_ID, EOS_ID = 0, 1, 2, 3
         
     | 
| 137 | 
         
            +
            SPECIAL_TOKEN_IDS = [PAD_ID, UNK_ID, BOS_ID, EOS_ID]
         
     | 
| 138 | 
         
            +
             
     | 
| 139 | 
         
            +
            _punctuation = ';:,.!?¡¿—…"«»“”() *~-/\\&'
         
     | 
| 140 | 
         
            +
            _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
         
     | 
| 141 | 
         
            +
            _letters_ipa = (
         
     | 
| 142 | 
         
            +
                "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
         
     | 
| 143 | 
         
            +
            )
         
     | 
| 144 | 
         
            +
             
     | 
| 145 | 
         
            +
            symbols = [*_punctuation, *_letters, *_letters_ipa]
         
     | 
| 146 | 
         
            +
            _symbol_to_id = {s: i for i, s in enumerate(symbols, start=len(SPECIAL_TOKEN_IDS))}
         
     | 
| 147 | 
         
            +
             
     | 
| 148 | 
         
            +
             
     | 
| 149 | 
         
            +
            def _get_symbol_id(s: str) -> int:
         
     | 
| 150 | 
         
            +
                return _symbol_to_id.get(s, 1)
         
     | 
| 151 | 
         
            +
             
     | 
| 152 | 
         
            +
             
     | 
| 153 | 
         
            +
            def get_symbol_ids(text: str) -> list[int]:
         
     | 
| 154 | 
         
            +
                return list(map(_get_symbol_id, text))
         
     | 
| 155 | 
         
            +
             
     | 
| 156 | 
         
            +
             
     | 
| 157 | 
         
            +
            def tokenize_phonemes(phonemes: list[str]) -> tuple[torch.Tensor, list[int]]:
         
     | 
| 158 | 
         
            +
                phoneme_ids = [[BOS_ID, *get_symbol_ids(phonemes), EOS_ID] for phonemes in phonemes]
         
     | 
| 159 | 
         
            +
                lengths = list(map(len, phoneme_ids))
         
     | 
| 160 | 
         
            +
                longest = max(lengths)
         
     | 
| 161 | 
         
            +
                phoneme_ids = [[PAD_ID] * (longest - len(ids)) + ids for ids in phoneme_ids]
         
     | 
| 162 | 
         
            +
                return torch.tensor(phoneme_ids), lengths
         
     | 
| 163 | 
         
            +
             
     | 
| 164 | 
         
            +
             
     | 
| 165 | 
         
            +
            def normalize_jp_text(text: str, tokenizer=Dictionary(dict="full").create()) -> str:
         
     | 
| 166 | 
         
            +
                text = unicodedata.normalize("NFKC", text)
         
     | 
| 167 | 
         
            +
                text = re.sub(r"\d+", lambda m: number2kanji(int(m[0])), text)
         
     | 
| 168 | 
         
            +
                final_text = " ".join([x.reading_form() for x in tokenizer.tokenize(text, SplitMode.A)])
         
     | 
| 169 | 
         
            +
                return final_text
         
     | 
| 170 | 
         
            +
             
     | 
| 171 | 
         
            +
             
     | 
| 172 | 
         
            +
            def clean(texts: list[str], languages: list[str]) -> list[str]:
         
     | 
| 173 | 
         
            +
                texts_out = []
         
     | 
| 174 | 
         
            +
                for text, language in zip(texts, languages):
         
     | 
| 175 | 
         
            +
                    if "ja" in language:
         
     | 
| 176 | 
         
            +
                        text = normalize_jp_text(text)
         
     | 
| 177 | 
         
            +
                    else:
         
     | 
| 178 | 
         
            +
                        text = normalize_numbers(text)
         
     | 
| 179 | 
         
            +
                    texts_out.append(text)
         
     | 
| 180 | 
         
            +
                return texts_out
         
     | 
| 181 | 
         
            +
             
     | 
| 182 | 
         
            +
             
     | 
| 183 | 
         
            +
            @cache
         
     | 
| 184 | 
         
            +
            def get_backend(language: str) -> "EspeakBackend":
         
     | 
| 185 | 
         
            +
                import logging
         
     | 
| 186 | 
         
            +
             
     | 
| 187 | 
         
            +
                from phonemizer.backend import EspeakBackend
         
     | 
| 188 | 
         
            +
             
     | 
| 189 | 
         
            +
                logger = logging.getLogger("phonemizer")
         
     | 
| 190 | 
         
            +
                backend = EspeakBackend(
         
     | 
| 191 | 
         
            +
                    language,
         
     | 
| 192 | 
         
            +
                    preserve_punctuation=True,
         
     | 
| 193 | 
         
            +
                    with_stress=True,
         
     | 
| 194 | 
         
            +
                    punctuation_marks=_punctuation,
         
     | 
| 195 | 
         
            +
                    logger=logger,
         
     | 
| 196 | 
         
            +
                )
         
     | 
| 197 | 
         
            +
                logger.setLevel(logging.ERROR)
         
     | 
| 198 | 
         
            +
                return backend
         
     | 
| 199 | 
         
            +
             
     | 
| 200 | 
         
            +
             
     | 
| 201 | 
         
            +
            def phonemize(texts: list[str], languages: list[str]) -> list[str]:
         
     | 
| 202 | 
         
            +
                texts = clean(texts, languages)
         
     | 
| 203 | 
         
            +
             
     | 
| 204 | 
         
            +
                batch_phonemes = []
         
     | 
| 205 | 
         
            +
                for text, language in zip(texts, languages):
         
     | 
| 206 | 
         
            +
                    backend = get_backend(language)
         
     | 
| 207 | 
         
            +
                    phonemes = backend.phonemize([text], strip=True)
         
     | 
| 208 | 
         
            +
                    batch_phonemes.append(phonemes[0])
         
     | 
| 209 | 
         
            +
             
     | 
| 210 | 
         
            +
                return batch_phonemes
         
     | 
| 211 | 
         
            +
             
     | 
| 212 | 
         
            +
             
     | 
| 213 | 
         
            +
            class EspeakPhonemeConditioner(Conditioner):
         
     | 
| 214 | 
         
            +
                def __init__(self, output_dim: int, **kwargs):
         
     | 
| 215 | 
         
            +
                    super().__init__(output_dim, **kwargs)
         
     | 
| 216 | 
         
            +
                    self.phoneme_embedder = nn.Embedding(len(SPECIAL_TOKEN_IDS) + len(symbols), output_dim)
         
     | 
| 217 | 
         
            +
             
     | 
| 218 | 
         
            +
                def apply_cond(self, texts: list[str], languages: list[str]) -> torch.Tensor:
         
     | 
| 219 | 
         
            +
                    """
         
     | 
| 220 | 
         
            +
                    Args:
         
     | 
| 221 | 
         
            +
                        texts: list of texts to convert to phonemes
         
     | 
| 222 | 
         
            +
                        languages: ISO 639-1 -or otherwise eSpeak compatible- language code
         
     | 
| 223 | 
         
            +
                    """
         
     | 
| 224 | 
         
            +
                    device = self.phoneme_embedder.weight.device
         
     | 
| 225 | 
         
            +
             
     | 
| 226 | 
         
            +
                    phonemes = phonemize(texts, languages)
         
     | 
| 227 | 
         
            +
                    phoneme_ids, _ = tokenize_phonemes(phonemes)
         
     | 
| 228 | 
         
            +
                    phoneme_embeds = self.phoneme_embedder(phoneme_ids.to(device))
         
     | 
| 229 | 
         
            +
             
     | 
| 230 | 
         
            +
                    return phoneme_embeds
         
     | 
| 231 | 
         
            +
             
     | 
| 232 | 
         
            +
             
     | 
| 233 | 
         
            +
            # ------- ESPEAK CONTAINMENT ZONE ------------------------------------------------------------------------------------------------------------------------------------------------
         
     | 
| 234 | 
         
            +
             
     | 
| 235 | 
         
            +
             
     | 
| 236 | 
         
            +
            class FourierConditioner(Conditioner):
         
     | 
| 237 | 
         
            +
                def __init__(
         
     | 
| 238 | 
         
            +
                    self,
         
     | 
| 239 | 
         
            +
                    output_dim: int,
         
     | 
| 240 | 
         
            +
                    input_dim: int = 1,
         
     | 
| 241 | 
         
            +
                    std: float = 1.0,
         
     | 
| 242 | 
         
            +
                    min_val: float = 0.0,
         
     | 
| 243 | 
         
            +
                    max_val: float = 1.0,
         
     | 
| 244 | 
         
            +
                    **kwargs,
         
     | 
| 245 | 
         
            +
                ):
         
     | 
| 246 | 
         
            +
                    assert output_dim % 2 == 0
         
     | 
| 247 | 
         
            +
                    super().__init__(output_dim, **kwargs)
         
     | 
| 248 | 
         
            +
                    self.register_buffer("weight", torch.randn([output_dim // 2, input_dim]) * std)
         
     | 
| 249 | 
         
            +
                    self.input_dim, self.min_val, self.max_val = input_dim, min_val, max_val
         
     | 
| 250 | 
         
            +
             
     | 
| 251 | 
         
            +
                def apply_cond(self, x: torch.Tensor) -> torch.Tensor:
         
     | 
| 252 | 
         
            +
                    assert x.shape[-1] == self.input_dim
         
     | 
| 253 | 
         
            +
                    x = (x - self.min_val) / (self.max_val - self.min_val)  # [batch_size, seq_len, input_dim]
         
     | 
| 254 | 
         
            +
                    f = 2 * torch.pi * x.to(self.weight.dtype) @ self.weight.T  # [batch_size, seq_len, output_dim // 2]
         
     | 
| 255 | 
         
            +
                    return torch.cat([f.cos(), f.sin()], dim=-1)  # [batch_size, seq_len, output_dim]
         
     | 
| 256 | 
         
            +
             
     | 
| 257 | 
         
            +
             
     | 
| 258 | 
         
            +
            class IntegerConditioner(Conditioner):
         
     | 
| 259 | 
         
            +
                def __init__(self, output_dim: int, min_val: int = 0, max_val: int = 512, **kwargs):
         
     | 
| 260 | 
         
            +
                    super().__init__(output_dim, **kwargs)
         
     | 
| 261 | 
         
            +
                    self.min_val = min_val
         
     | 
| 262 | 
         
            +
                    self.max_val = max_val
         
     | 
| 263 | 
         
            +
                    self.int_embedder = nn.Embedding(max_val - min_val + 1, output_dim)
         
     | 
| 264 | 
         
            +
             
     | 
| 265 | 
         
            +
                def apply_cond(self, x: torch.Tensor) -> torch.Tensor:
         
     | 
| 266 | 
         
            +
                    assert x.shape[-1] == 1
         
     | 
| 267 | 
         
            +
                    return self.int_embedder(x.squeeze(-1) - self.min_val)  # [batch_size, seq_len, output_dim]
         
     | 
| 268 | 
         
            +
             
     | 
| 269 | 
         
            +
             
     | 
| 270 | 
         
            +
            class PassthroughConditioner(Conditioner):
         
     | 
| 271 | 
         
            +
                def __init__(self, output_dim: int, **kwargs):
         
     | 
| 272 | 
         
            +
                    super().__init__(output_dim, **kwargs)
         
     | 
| 273 | 
         
            +
             
     | 
| 274 | 
         
            +
                def apply_cond(self, x: torch.Tensor) -> torch.Tensor:
         
     | 
| 275 | 
         
            +
                    assert x.shape[-1] == self.cond_dim
         
     | 
| 276 | 
         
            +
                    return x
         
     | 
| 277 | 
         
            +
             
     | 
| 278 | 
         
            +
             
     | 
| 279 | 
         
            +
            _cond_cls_map = {
         
     | 
| 280 | 
         
            +
                "PassthroughConditioner": PassthroughConditioner,
         
     | 
| 281 | 
         
            +
                "EspeakPhonemeConditioner": EspeakPhonemeConditioner,
         
     | 
| 282 | 
         
            +
                "FourierConditioner": FourierConditioner,
         
     | 
| 283 | 
         
            +
                "IntegerConditioner": IntegerConditioner,
         
     | 
| 284 | 
         
            +
            }
         
     | 
| 285 | 
         
            +
             
     | 
| 286 | 
         
            +
             
     | 
| 287 | 
         
            +
            def build_conditioners(conditioners: list[dict], output_dim: int) -> list[Conditioner]:
         
     | 
| 288 | 
         
            +
                return [_cond_cls_map[config["type"]](output_dim, **config) for config in conditioners]
         
     | 
| 289 | 
         
            +
             
     | 
| 290 | 
         
            +
             
     | 
| 291 | 
         
            +
            class PrefixConditioner(Conditioner):
         
     | 
| 292 | 
         
            +
                def __init__(self, config: PrefixConditionerConfig, output_dim: int):
         
     | 
| 293 | 
         
            +
                    super().__init__(output_dim, "prefix", projection=config.projection)
         
     | 
| 294 | 
         
            +
                    self.conditioners = nn.ModuleList(build_conditioners(config.conditioners, output_dim))
         
     | 
| 295 | 
         
            +
                    self.norm = nn.LayerNorm(output_dim)
         
     | 
| 296 | 
         
            +
                    self.required_keys = {c.name for c in self.conditioners if c.uncond_vector is None}
         
     | 
| 297 | 
         
            +
             
     | 
| 298 | 
         
            +
                def forward(self, cond_dict: dict) -> torch.Tensor:
         
     | 
| 299 | 
         
            +
                    if not set(cond_dict).issuperset(self.required_keys):
         
     | 
| 300 | 
         
            +
                        raise ValueError(f"Missing required keys: {self.required_keys - set(cond_dict)}")
         
     | 
| 301 | 
         
            +
                    conds = []
         
     | 
| 302 | 
         
            +
                    for conditioner in self.conditioners:
         
     | 
| 303 | 
         
            +
                        conds.append(conditioner(cond_dict.get(conditioner.name)))
         
     | 
| 304 | 
         
            +
                    max_bsz = max(map(len, conds))
         
     | 
| 305 | 
         
            +
                    assert all(c.shape[0] in (max_bsz, 1) for c in conds)
         
     | 
| 306 | 
         
            +
                    conds = [c.expand(max_bsz, -1, -1) for c in conds]
         
     | 
| 307 | 
         
            +
                    return self.norm(self.project(torch.cat(conds, dim=-2)))
         
     | 
| 308 | 
         
            +
             
     | 
| 309 | 
         
            +
             
     | 
| 310 | 
         
            +
supported_language_codes = [
    'af', 'am', 'an', 'ar', 'as', 'az', 'ba', 'bg', 'bn', 'bpy', 'bs', 'ca', 'cmn',
    'cs', 'cy', 'da', 'de', 'el', 'en-029', 'en-gb', 'en-gb-scotland', 'en-gb-x-gbclan',
    'en-gb-x-gbcwmd', 'en-gb-x-rp', 'en-us', 'eo', 'es', 'es-419', 'et', 'eu', 'fa',
    'fa-latn', 'fi', 'fr-be', 'fr-ch', 'fr-fr', 'ga', 'gd', 'gn', 'grc', 'gu', 'hak',
    'hi', 'hr', 'ht', 'hu', 'hy', 'hyw', 'ia', 'id', 'is', 'it', 'ja', 'jbo', 'ka',
    'kk', 'kl', 'kn', 'ko', 'kok', 'ku', 'ky', 'la', 'lfn', 'lt', 'lv', 'mi', 'mk',
    'ml', 'mr', 'ms', 'mt', 'my', 'nb', 'nci', 'ne', 'nl', 'om', 'or', 'pa', 'pap',
    'pl', 'pt', 'pt-br', 'py', 'quc', 'ro', 'ru', 'ru-lv', 'sd', 'shn', 'si', 'sk',
    'sl', 'sq', 'sr', 'sv', 'sw', 'ta', 'te', 'tn', 'tr', 'tt', 'ur', 'uz', 'vi',
    'vi-vn-x-central', 'vi-vn-x-south', 'yue'
]  # fmt: off


def make_cond_dict(
    text: str = "It would be nice to have time for testing, indeed.",
    language: str = "en-us",
    speaker: torch.Tensor | None = None,
    emotion: list[float] = [0.3077, 0.0256, 0.0256, 0.0256, 0.0256, 0.0256, 0.2564, 0.3077],
    fmax: float = 22050.0,
    pitch_std: float = 20.0,
    speaking_rate: float = 15.0,
    vqscore_8: list[float] = [0.78] * 8,
    ctc_loss: float = 0.0,
    dnsmos_ovrl: float = 4.0,
    speaker_noised: bool = False,
    unconditional_keys: Iterable[str] = {"vqscore_8", "dnsmos_ovrl"},
    device: str = "cuda",
) -> dict:
    """
    A helper to build the 'cond_dict' that the model expects.
    By default, it will generate a random speaker embedding
    """
    assert language.lower() in supported_language_codes, "Please pick a supported language"

    language_code_to_id = {lang: i for i, lang in enumerate(supported_language_codes)}

    cond_dict = {
        "espeak": ([text], [language]),
        "speaker": speaker,
        "emotion": emotion,
        "fmax": fmax,
        "pitch_std": pitch_std,
        "speaking_rate": speaking_rate,
        "language_id": language_code_to_id[language],
        "vqscore_8": vqscore_8,
        "ctc_loss": ctc_loss,
        "dnsmos_ovrl": dnsmos_ovrl,
        "speaker_noised": int(speaker_noised),
    }

    for k in unconditional_keys:
        cond_dict.pop(k, None)

    for k, v in cond_dict.items():
        if isinstance(v, (float, int, list)):
            v = torch.tensor(v)
        if isinstance(v, torch.Tensor):
            cond_dict[k] = v.view(1, 1, -1).to(device)

        if k == "emotion":
            cond_dict[k] /= cond_dict[k].sum(dim=-1)

    return cond_dict
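For orientation, a minimal usage sketch of this helper. The input text and the pitch_std override below are illustrative values chosen for this example, not anything shipped with this upload:

# Illustrative only: build a conditioning dict on CPU.
from qhash.conditioning import make_cond_dict

cond = make_cond_dict(
    text="Hello from the demo.",  # hypothetical input text
    language="en-us",
    pitch_std=45.0,               # hypothetical override of the default 20.0
    device="cpu",
)
# Scalars and lists were tensorized and reshaped to [1, 1, -1] by the loop above;
# "vqscore_8" and "dnsmos_ovrl" were dropped via the default unconditional_keys.
print(sorted(cond.keys()))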
    	
qhash/config.py
ADDED
@@ -0,0 +1,38 @@
from dataclasses import dataclass, field
from typing import Literal


@dataclass
class BackboneConfig:
    d_model: int = 1024
    d_intermediate: int = 0
    attn_mlp_d_intermediate: int = 0
    n_layer: int = 16
    ssm_cfg: dict = field(default_factory=dict)
    attn_layer_idx: list = field(default_factory=list)
    attn_cfg: dict = field(default_factory=dict)
    rms_norm: bool = False
    residual_in_fp32: bool = False
    norm_epsilon: float = 1e-5


@dataclass
class PrefixConditionerConfig:
    conditioners: list[dict]
    projection: Literal["none", "linear", "mlp"]


@dataclass
class ZonosConfig:
    backbone: BackboneConfig
    prefix_conditioner: PrefixConditionerConfig
    eos_token_id: int = 1024
    masked_token_id: int = 1025

    @classmethod
    def from_dict(cls, d: dict) -> "ZonosConfig":
        d = d.copy()
        backbone_config = BackboneConfig(**d.pop("backbone"))
        prefix_conditioner_config = PrefixConditionerConfig(**d.pop("prefix_conditioner"))
        config = cls(backbone_config, prefix_conditioner_config, **d)
        return config
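To illustrate how from_dict splits a raw dict into the nested dataclasses, here is a small sketch; the field values are made up for the example, the real ones come from the checkpoint's config.json:

# Illustrative config dict; the shipped config.json will differ.
from qhash.config import BackboneConfig, ZonosConfig

raw = {
    "backbone": {"d_model": 1024, "n_layer": 16},
    "prefix_conditioner": {"conditioners": [], "projection": "none"},
    "eos_token_id": 1024,
}
cfg = ZonosConfig.from_dict(raw)
assert isinstance(cfg.backbone, BackboneConfig)
assert cfg.masked_token_id == 1025  # falls back to the dataclass default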
    	
qhash/model.py
ADDED
@@ -0,0 +1,270 @@
import json
from typing import Callable

import safetensors
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from mamba_ssm.utils.generation import InferenceParams
from tqdm import tqdm

from qhash.backbone import ZonosBackbone
from qhash.autoencoder import DACAutoencoder
from qhash.codebook_pattern import apply_delay_pattern, revert_delay_pattern
from qhash.conditioning import PrefixConditioner
from qhash.config import ZonosConfig
from qhash.sampling import sample_from_logits
from qhash.speaker_cloning import SpeakerEmbeddingLDA


class Zonos(nn.Module):
    def __init__(self, config: ZonosConfig):
        super().__init__()
        self.config = config
        dim = config.backbone.d_model
        self.eos_token_id = config.eos_token_id
        self.masked_token_id = config.masked_token_id

        self.autoencoder = DACAutoencoder()
        self.backbone = ZonosBackbone(config.backbone)
        self.prefix_conditioner = PrefixConditioner(config.prefix_conditioner, dim)
        self.spk_clone_model = None

        # TODO: pad to multiple of at least 8
        self.embeddings = nn.ModuleList([nn.Embedding(1026, dim) for _ in range(self.autoencoder.num_codebooks)])
        self.heads = nn.ModuleList([nn.Linear(dim, 1025, bias=False) for _ in range(self.autoencoder.num_codebooks)])

        self._cg_graph = None
        self._cg_batch_size = None
        self._cg_input_ids = None
        self._cg_logits = None
        self._cg_inference_params = None
        self._cg_scale = None

    @classmethod
    def from_pretrained(cls, repo_id: str, revision: str | None = None, device: str = "cuda") -> "Zonos":
        config_path = hf_hub_download(repo_id=repo_id, filename="config.json", revision=revision)
        model_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors", revision=revision)
        return cls.from_local(config_path, model_path, device)

    @classmethod
    def from_local(cls, config_path: str, model_path: str, device: str = "cuda") -> "Zonos":
        config = ZonosConfig.from_dict(json.load(open(config_path)))
        model = cls(config).to(device, torch.bfloat16)
        model.autoencoder.dac.to(device)

        sd = model.state_dict()
        with safetensors.safe_open(model_path, framework="pt") as f:
            for k in f.keys():
                sd[k] = f.get_tensor(k)
        model.load_state_dict(sd)

        return model

    def make_speaker_embedding(self, wav: torch.Tensor, sr: int) -> torch.Tensor:
        """Generate a speaker embedding from an audio clip."""
        if self.spk_clone_model is None:
            self.spk_clone_model = SpeakerEmbeddingLDA()
        _, spk_embedding = self.spk_clone_model(wav.to(self.spk_clone_model.device), sr)
        return spk_embedding.unsqueeze(0).bfloat16()

    def embed_codes(self, codes: torch.Tensor) -> torch.Tensor:
        return sum(emb(codes[:, i]) for i, emb in enumerate(self.embeddings))

    def apply_heads(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return torch.stack([head(hidden_states) for head in self.heads], dim=1)

    def _compute_logits(
        self, hidden_states: torch.Tensor, inference_params: InferenceParams, cfg_scale: float
    ) -> torch.Tensor:
        """
        Pass `hidden_states` into `backbone` and `multi_head`, applying
        classifier-free guidance if `cfg_scale != 1.0`.
        """
        last_hidden_states = self.backbone(hidden_states, inference_params)[:, -1, :].unsqueeze(1)
        logits = self.apply_heads(last_hidden_states).squeeze(2).float()
        if cfg_scale != 1.0:
            cond_logits, uncond_logits = logits.chunk(2)
            logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale
        return logits

    def _decode_one_token(
        self,
        input_ids: torch.Tensor,
        inference_params: InferenceParams,
        cfg_scale: float,
    ) -> torch.Tensor:
        """
        Single-step decode. Prepares the hidden states, possibly replicates them
        for CFG, and then delegates to `_compute_logits`.

        Below we wrap this function with a simple CUDA Graph capturing mechanism,
        doing 3 warmup steps if needed and then capturing or replaying the graph.
        We only recapture if the batch size changes.
        """
        # TODO: support cfg_scale==1
        if cfg_scale == 1.0:
            hidden_states = self.embed_codes(input_ids)
            return self._compute_logits(hidden_states, inference_params, cfg_scale)

        bsz = input_ids.size(0)

        need_capture = (self._cg_graph is None) or (self._cg_batch_size != bsz)

        if need_capture:
            self._cg_graph = None

            self._cg_batch_size = bsz
            self._cg_inference_params = inference_params
            self._cg_scale = cfg_scale

            for _ in range(3):
                hidden_states = self.embed_codes(input_ids)
                hidden_states = hidden_states.repeat(2, 1, 1)  # because cfg != 1.0
                logits = self._compute_logits(hidden_states, inference_params, cfg_scale)

            self._cg_input_ids = input_ids.clone()
            self._cg_logits = torch.empty_like(logits)

            g = torch.cuda.CUDAGraph()

            def capture_region():
                hidden_states_local = self.embed_codes(self._cg_input_ids)
                hidden_states_local = hidden_states_local.repeat(2, 1, 1)
                self._cg_logits = self._compute_logits(hidden_states_local, self._cg_inference_params, self._cg_scale)

            with torch.cuda.graph(g):
                capture_region()

            self._cg_graph = g

        else:
            self._cg_input_ids.copy_(input_ids)

        self._cg_graph.replay()

        return self._cg_logits

    def _prefill(
        self,
        prefix_hidden_states: torch.Tensor,
        input_ids: torch.Tensor,
        inference_params: InferenceParams,
        cfg_scale: float,
    ) -> torch.Tensor:
        """
        "Prefill" mode: we already have `prefix_hidden_states`, and we want
        to append new embeddings, then compute the logits.
        """
        # Replicate input_ids if CFG is enabled
        if cfg_scale != 1.0:
            input_ids = input_ids.expand(prefix_hidden_states.shape[0], -1, -1)
        hidden_states = torch.cat([prefix_hidden_states, self.embed_codes(input_ids)], dim=1)
        return self._compute_logits(hidden_states, inference_params, cfg_scale)

    def setup_cache(self, batch_size: int, max_seqlen: int, dtype: torch.dtype = torch.bfloat16) -> InferenceParams:
        key_value_memory_dict = {
            i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype)
            for i, layer in enumerate(self.backbone.layers)
        }
        lengths_per_sample = torch.full((batch_size,), 0, dtype=torch.int32, device="cuda")
        return InferenceParams(max_seqlen, batch_size, 0, 0, key_value_memory_dict, lengths_per_sample)

    def prepare_conditioning(self, cond_dict: dict, uncond_dict: dict | None = None) -> torch.Tensor:
        if uncond_dict is None:
            uncond_dict = {k: cond_dict[k] for k in self.prefix_conditioner.required_keys}
        return torch.cat(
            [
                self.prefix_conditioner(cond_dict),
                self.prefix_conditioner(uncond_dict),
            ]
        )

    @torch.inference_mode()
    def generate(
        self,
        prefix_conditioning: torch.Tensor,  # [bsz, cond_seq_len, d_model]
        audio_prefix_codes: torch.Tensor | None = None,  # [bsz, 9, prefix_audio_seq_len]
        max_new_tokens: int = 86 * 30,
        cfg_scale: float = 2.0,
        batch_size: int = 1,
        sampling_params: dict = dict(min_p=0.1),
        progress_bar: bool = True,
        callback: Callable[[torch.Tensor, int, int], bool] | None = None,
    ):
        assert cfg_scale != 1, "TODO: add support for cfg_scale=1"
        prefix_audio_len = 0 if audio_prefix_codes is None else audio_prefix_codes.shape[2]

        unknown_token = -1
        audio_seq_len = prefix_audio_len + max_new_tokens
        seq_len = prefix_conditioning.shape[1] + audio_seq_len

        inference_params = self.setup_cache(batch_size=batch_size * 2, max_seqlen=seq_len)

        codes = torch.full((batch_size, 9, audio_seq_len), unknown_token, device="cuda")
        if audio_prefix_codes is not None:
            codes[..., :prefix_audio_len] = audio_prefix_codes

        delayed_codes = apply_delay_pattern(codes, self.masked_token_id)

        delayed_prefix_audio_codes = delayed_codes[..., : prefix_audio_len + 1]

        logits = self._prefill(prefix_conditioning, delayed_prefix_audio_codes, inference_params, cfg_scale)
        next_token = sample_from_logits(logits, **sampling_params)

        offset = delayed_prefix_audio_codes.shape[2]
        frame = delayed_codes[..., offset : offset + 1]
        frame.masked_scatter_(frame == unknown_token, next_token)

        prefix_length = prefix_conditioning.shape[1] + prefix_audio_len + 1
        inference_params.seqlen_offset += prefix_length
        inference_params.lengths_per_sample[:] += prefix_length

        logit_bias = torch.zeros_like(logits)
        logit_bias[:, 1:, self.eos_token_id] = -torch.inf  # only allow codebook 0 to predict EOS

        stopping = torch.zeros(batch_size, dtype=torch.bool, device="cuda")
        max_steps = delayed_codes.shape[2] - offset
        remaining_steps = torch.full((batch_size,), max_steps, device="cuda")
        progress = tqdm(total=max_steps, desc="Generating", disable=not progress_bar)

        step = 0
        while torch.max(remaining_steps) > 0:
            offset += 1
            input_ids = delayed_codes[..., offset - 1 : offset]
            logits = self._decode_one_token(input_ids, inference_params, cfg_scale)

            next_token = sample_from_logits(logits, generated_tokens=delayed_codes[..., :offset], **sampling_params)
            eos_in_cb0 = next_token[:, 0] == self.eos_token_id

            remaining_steps[eos_in_cb0[:, 0]] = torch.minimum(remaining_steps[eos_in_cb0[:, 0]], torch.tensor(9))
            stopping |= eos_in_cb0[:, 0]

            eos_codebook_idx = 9 - remaining_steps
            eos_codebook_idx = torch.clamp(eos_codebook_idx, max=9 - 1)
            for i in range(next_token.shape[0]):
                if stopping[i]:
                    idx = eos_codebook_idx[i].item()
                    next_token[i, :idx] = self.masked_token_id
                    next_token[i, idx] = self.eos_token_id

            frame = delayed_codes[..., offset : offset + 1]
            frame.masked_scatter_(frame == unknown_token, next_token)
            inference_params.seqlen_offset += 1
            inference_params.lengths_per_sample[:] += 1

            remaining_steps -= 1

            progress.update()
            step += 1

            if callback is not None and not callback(frame, step, max_steps):
                break

        out_codes = revert_delay_pattern(delayed_codes)
        out_codes.masked_fill_(out_codes >= 1024, 0)
        out_codes = out_codes[..., : offset - 9]

        self._cg_graph = None  # reset cuda graph to avoid cache changes

        return out_codes
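Putting the pieces together, a rough end-to-end sketch. The repo id below is a placeholder, and the decode()/sampling_rate members of DACAutoencoder are assumptions about qhash/autoencoder.py, which is not reproduced in this section:

import torchaudio

from qhash.conditioning import make_cond_dict
from qhash.model import Zonos

model = Zonos.from_pretrained("your-org/your-zonos-checkpoint")  # placeholder repo id

cond_dict = make_cond_dict(text="A short test sentence.", language="en-us")
conditioning = model.prepare_conditioning(cond_dict)           # [2, cond_seq_len, d_model]: cond + uncond rows
codes = model.generate(conditioning, max_new_tokens=86 * 10)   # roughly 10 s at 86 frames per second

# Assumed DACAutoencoder interface (decode, sampling_rate); see qhash/autoencoder.py.
wav = model.autoencoder.decode(codes).cpu()
torchaudio.save("sample.wav", wav[0], model.autoencoder.sampling_rate)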
    	
qhash/sampling.py
ADDED
@@ -0,0 +1,141 @@
import torch


def multinomial(input: torch.Tensor, num_samples: int, replacement=False, *, generator=None):
    """torch.multinomial with arbitrary number of dimensions, and number of candidates on the last dimension.

    Args:
        input (torch.Tensor): The input tensor containing probabilities.
        num_samples (int): Number of samples to draw.
        replacement (bool): Whether to draw with replacement or not.
    Keyword args:
        generator (torch.Generator): A pseudorandom number generator for sampling.
    Returns:
        torch.Tensor: Last dimension contains num_samples indices
            sampled from the multinomial probability distribution
            located in the last dimension of tensor input.
    """

    if num_samples == 1:
        q = torch.empty_like(input).exponential_(1, generator=generator)
        return torch.argmax(input / q, dim=-1, keepdim=True).to(torch.int64)

    input_ = input.reshape(-1, input.shape[-1])
    output_ = torch.multinomial(input_, num_samples=num_samples, replacement=replacement, generator=generator)
    output = output_.reshape(*list(input.shape[:-1]), -1)
    return output


def apply_top_k(
    probs: torch.Tensor,
    k: int,
) -> torch.Tensor:
    """Keep only the top K values along the last dimension of the input probs tensor and renormalize.

    Args:
        probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
        k (int): The k in “top-k”.
    Returns:
        torch.Tensor: Renormalized probabilities with all but the top-k candidates zeroed out.
    """
    v, _ = torch.topk(probs, min(k, probs.size(-1)))
    pivot = v.select(-1, -1).unsqueeze(-1)
    probs = torch.where(probs < pivot, 0.0, probs)
    probs.div_(probs.sum(dim=-1, keepdim=True))
    return probs


def apply_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
    """Keep the smallest set of tokens whose cumulative probability exceeds P along the last dimension and renormalize.

    Args:
        probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
        p (float): The p in “top-p”.
    Returns:
        torch.Tensor: Renormalized probabilities after top-p (nucleus) filtering.
    """
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = probs_sum - probs_sort > p
    probs_sort *= (~mask).float()
    probs = probs.scatter(-1, probs_idx, probs_sort)
    probs.div_(probs.sum(dim=-1, keepdim=True))
    return probs


def apply_min_p(probs: torch.Tensor, min_p: float) -> torch.Tensor:
    """Zero out tokens whose probability falls below min_p times that of the most likely token, then renormalize.

    Args:
        probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
        min_p (float): Minimum token probability, scaled by the probability of the most likely token.
                       Must be between 0 and 1. Typical values are in the 0.01-0.2 range.
    Returns:
        torch.Tensor: Renormalized probabilities after min-p filtering.
    """
    top_probs, _ = probs.max(dim=-1, keepdim=True)
    tokens_to_remove = probs < (min_p * top_probs)
    probs = probs.masked_fill(tokens_to_remove, 0.0)
    probs.div_(probs.sum(dim=-1, keepdim=True))
    return probs


def modify_logit_for_repetition_penalty(
    logits: torch.Tensor,
    generated_tokens: torch.Tensor,
    repetition_penalty: float,
    repetition_penalty_window: int,
):
    """See https://arxiv.org/abs/1909.05858
    Apply repetition penalty over a sliding window of the last `repetition_penalty_window` tokens.
    logits: (batch_size, n_codebooks, vocab_size)
    generated_tokens: (batch_size, n_codebooks, seq_len)
    """
    generated_tokens = generated_tokens[..., -repetition_penalty_window:]
    generated_tokens = generated_tokens.clamp_max(logits.shape[-1] - 1).to(torch.int64)
    rp = torch.full_like(logits, repetition_penalty)
    factors = torch.ones_like(logits).scatter_reduce(2, generated_tokens, rp, reduce="prod")
    return torch.where(logits <= 0, logits * factors, logits / factors)


def sample_from_logits(
    logits: torch.Tensor,
    temperature: float = 1.0,
    top_p: float = 0.0,
    top_k: int = 0,
    min_p: float = 0.0,
    generated_tokens: torch.Tensor | None = None,
    repetition_penalty: float = 3.0,
    repetition_penalty_window: int = 2,
) -> torch.Tensor:
    """Sample next token from logits using temperature, top-p, top-k, or min-p sampling.

    Args:
        logits (torch.Tensor): Input logits with token candidates on the last dimension.
        temperature (float): Sampling temperature. Lower temperature results in more deterministic samples.
        top_p (float): The p in “top-p”.
        top_k (int): The k in “top-k”.
        min_p (float): Minimum token probability, scaled by the probability of the most likely token.
                       Must be between 0 and 1. Typical values are in the 0.01-0.2 range.

    Returns:
        torch.Tensor: Sampled tokens.
    """
    if repetition_penalty != 1.0 and generated_tokens is not None:
        logits = modify_logit_for_repetition_penalty(logits, generated_tokens, repetition_penalty, repetition_penalty_window)

    if temperature > 0:
        probs = torch.softmax(logits / temperature, dim=-1)

        if top_p > 0:
            probs = apply_top_p(probs, top_p)
        if top_k > 0:
            probs = apply_top_k(probs, top_k)
        if min_p > 0:
            probs = apply_min_p(probs, min_p)

        next_token = multinomial(probs, num_samples=1)
    else:
        next_token = torch.argmax(logits, dim=-1, keepdim=True)

    return next_token  # [batch_size, num_codebooks, 1]
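Finally, a self-contained sketch of the sampling entry point on random logits, shaped to the [batch, codebooks, vocab] convention used by the model above; the seed and shapes are chosen only for illustration:

import torch

from qhash.sampling import sample_from_logits

torch.manual_seed(0)
logits = torch.randn(1, 9, 1025)  # [batch_size, n_codebooks, vocab_size]

# min-p keeps only tokens whose probability is at least 10% of the top token's.
tok = sample_from_logits(logits, min_p=0.1)
print(tok.shape)  # torch.Size([1, 9, 1])

# temperature=0 falls back to greedy argmax decoding.
greedy = sample_from_logits(logits, temperature=0.0)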
    	
qhash/speaker_cloning.py
ADDED
@@ -0,0 +1,406 @@
import math
from functools import cache

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from huggingface_hub import hf_hub_download
import os


class logFbankCal(nn.Module):
    def __init__(
        self,
        sample_rate: int = 16_000,
        n_fft: int = 512,
        win_length: float = 0.025,
        hop_length: float = 0.01,
        n_mels: int = 80,
    ):
        super().__init__()
        self.fbankCal = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=n_fft,
            win_length=int(win_length * sample_rate),
            hop_length=int(hop_length * sample_rate),
            n_mels=n_mels,
        )

    def forward(self, x):
        out = self.fbankCal(x)
        out = torch.log(out + 1e-6)
        out = out - out.mean(axis=2).unsqueeze(dim=2)
        return out

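An aside, not part of the file being added: a quick shape check of what logFbankCal produces under its defaults, assuming one second of 16 kHz mono audio (25 ms window, 10 ms hop, 80 mel bins).

import torch

fbank = logFbankCal()           # defaults from the class above
wav = torch.randn(1, 16_000)    # (batch, samples): one second of fake audio at 16 kHz
feats = fbank(wav)
print(feats.shape)              # roughly torch.Size([1, 80, 101]) — (batch, n_mels, frames)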
class ASP(nn.Module):
    # Attentive statistics pooling
    def __init__(self, in_planes, acoustic_dim):
        super(ASP, self).__init__()
        outmap_size = int(acoustic_dim / 8)
        self.out_dim = in_planes * 8 * outmap_size * 2

        self.attention = nn.Sequential(
            nn.Conv1d(in_planes * 8 * outmap_size, 128, kernel_size=1),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Conv1d(128, in_planes * 8 * outmap_size, kernel_size=1),
            nn.Softmax(dim=2),
        )

    def forward(self, x):
        x = x.reshape(x.size()[0], -1, x.size()[-1])
        w = self.attention(x)
        mu = torch.sum(x * w, dim=2)
        sg = torch.sqrt((torch.sum((x**2) * w, dim=2) - mu**2).clamp(min=1e-5))
        x = torch.cat((mu, sg), 1)

        x = x.view(x.size()[0], -1)
        return x


class SimAMBasicBlock(nn.Module):
    expansion = 1

    def __init__(self, ConvLayer, NormLayer, in_planes, planes, stride=1, block_id=1):
        super(SimAMBasicBlock, self).__init__()
        self.conv1 = ConvLayer(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = NormLayer(planes)
        self.conv2 = ConvLayer(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = NormLayer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()

        self.downsample = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.downsample = nn.Sequential(
                ConvLayer(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                NormLayer(self.expansion * planes),
            )

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = self.SimAM(out)
        out += self.downsample(x)
        out = self.relu(out)
        return out

    def SimAM(self, X, lambda_p=1e-4):
        n = X.shape[2] * X.shape[3] - 1
        d = (X - X.mean(dim=[2, 3], keepdim=True)).pow(2)
        v = d.sum(dim=[2, 3], keepdim=True) / n
        E_inv = d / (4 * (v + lambda_p)) + 0.5
        return X * self.sigmoid(E_inv)

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, ConvLayer, NormLayer, in_planes, planes, stride=1, block_id=1):
        super(BasicBlock, self).__init__()
        self.conv1 = ConvLayer(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = NormLayer(planes)
        self.conv2 = ConvLayer(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = NormLayer(planes)
        self.relu = nn.ReLU(inplace=True)

        self.downsample = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.downsample = nn.Sequential(
                ConvLayer(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                NormLayer(self.expansion * planes),
            )

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.downsample(x)
        out = self.relu(out)
        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, ConvLayer, NormLayer, in_planes, planes, stride=1, block_id=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion * planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes),
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self, in_planes, block, num_blocks, in_ch=1, feat_dim="2d", **kwargs):
        super(ResNet, self).__init__()
        if feat_dim == "1d":
            self.NormLayer = nn.BatchNorm1d
            self.ConvLayer = nn.Conv1d
        elif feat_dim == "2d":
            self.NormLayer = nn.BatchNorm2d
            self.ConvLayer = nn.Conv2d
        elif feat_dim == "3d":
            self.NormLayer = nn.BatchNorm3d
            self.ConvLayer = nn.Conv3d
        else:
            print("error")

        self.in_planes = in_planes

        self.conv1 = self.ConvLayer(in_ch, in_planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = self.NormLayer(in_planes)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(block, in_planes, num_blocks[0], stride=1, block_id=1)
        self.layer2 = self._make_layer(block, in_planes * 2, num_blocks[1], stride=2, block_id=2)
        self.layer3 = self._make_layer(block, in_planes * 4, num_blocks[2], stride=2, block_id=3)
        self.layer4 = self._make_layer(block, in_planes * 8, num_blocks[3], stride=2, block_id=4)

    def _make_layer(self, block, planes, num_blocks, stride, block_id=1):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.ConvLayer, self.NormLayer, self.in_planes, planes, stride, block_id))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return x


def ResNet293(in_planes: int, **kwargs):
    return ResNet(in_planes, SimAMBasicBlock, [10, 20, 64, 3], **kwargs)


class ResNet293_based(nn.Module):
    def __init__(
        self,
        in_planes: int = 64,
        embd_dim: int = 256,
        acoustic_dim: int = 80,
        featCal=None,
        dropout: float = 0,
        **kwargs,
    ):
        super(ResNet293_based, self).__init__()
        self.featCal = featCal
        self.front = ResNet293(in_planes)
        block_expansion = SimAMBasicBlock.expansion
        self.pooling = ASP(in_planes * block_expansion, acoustic_dim)
        self.bottleneck = nn.Linear(self.pooling.out_dim, embd_dim)
        self.drop = nn.Dropout(dropout) if dropout else None

    def forward(self, x):
        x = self.featCal(x)
        x = self.front(x.unsqueeze(dim=1))
        x = self.pooling(x)
        if self.drop:
            x = self.drop(x)
        x = self.bottleneck(x)
        return x

class SEModule(nn.Module):
    def __init__(self, channels, bottleneck=128):
        super(SEModule, self).__init__()
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),
            nn.Conv1d(channels, bottleneck, kernel_size=1, padding=0),
            nn.ReLU(),
            # nn.BatchNorm1d(bottleneck), # Removed
            nn.Conv1d(bottleneck, channels, kernel_size=1, padding=0),
            nn.Sigmoid(),
        )

    def forward(self, input):
        x = self.se(input)
        return input * x


class Bottle2neck(nn.Module):
    def __init__(self, inplanes, planes, kernel_size=None, dilation=None, scale=8):
        super(Bottle2neck, self).__init__()
        width = int(math.floor(planes / scale))
        self.conv1 = nn.Conv1d(inplanes, width * scale, kernel_size=1)
        self.bn1 = nn.BatchNorm1d(width * scale)
        self.nums = scale - 1
        convs = []
        bns = []
        num_pad = math.floor(kernel_size / 2) * dilation
        for i in range(self.nums):
            convs.append(nn.Conv1d(width, width, kernel_size=kernel_size, dilation=dilation, padding=num_pad))
            bns.append(nn.BatchNorm1d(width))
        self.convs = nn.ModuleList(convs)
        self.bns = nn.ModuleList(bns)
        self.conv3 = nn.Conv1d(width * scale, planes, kernel_size=1)
        self.bn3 = nn.BatchNorm1d(planes)
        self.relu = nn.ReLU()
        self.width = width
        self.se = SEModule(planes)

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.relu(out)
        out = self.bn1(out)

        spx = torch.split(out, self.width, 1)
        for i in range(self.nums):
            if i == 0:
                sp = spx[i]
            else:
                sp = sp + spx[i]
            sp = self.convs[i](sp)
            sp = self.relu(sp)
            sp = self.bns[i](sp)
            if i == 0:
                out = sp
            else:
                out = torch.cat((out, sp), 1)
        out = torch.cat((out, spx[self.nums]), 1)

        out = self.conv3(out)
        out = self.relu(out)
        out = self.bn3(out)

        out = self.se(out)
        out += residual
        return out


class ECAPA_TDNN(nn.Module):
    def __init__(self, C, featCal):
        super(ECAPA_TDNN, self).__init__()
        self.featCal = featCal
        self.conv1 = nn.Conv1d(80, C, kernel_size=5, stride=1, padding=2)
        self.relu = nn.ReLU()
        self.bn1 = nn.BatchNorm1d(C)
        self.layer1 = Bottle2neck(C, C, kernel_size=3, dilation=2, scale=8)
        self.layer2 = Bottle2neck(C, C, kernel_size=3, dilation=3, scale=8)
        self.layer3 = Bottle2neck(C, C, kernel_size=3, dilation=4, scale=8)
        # I fixed the shape of the output from MFA layer, that is close to the setting from ECAPA paper.
        self.layer4 = nn.Conv1d(3 * C, 1536, kernel_size=1)
        self.attention = nn.Sequential(
            nn.Conv1d(4608, 256, kernel_size=1),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Tanh(),  # Added
            nn.Conv1d(256, 1536, kernel_size=1),
            nn.Softmax(dim=2),
        )
        self.bn5 = nn.BatchNorm1d(3072)
        self.fc6 = nn.Linear(3072, 192)
        self.bn6 = nn.BatchNorm1d(192)

    def forward(self, x):
        x = self.featCal(x)
        x = self.conv1(x)
        x = self.relu(x)
        x = self.bn1(x)

        x1 = self.layer1(x)
        x2 = self.layer2(x + x1)
        x3 = self.layer3(x + x1 + x2)

        x = self.layer4(torch.cat((x1, x2, x3), dim=1))
        x = self.relu(x)

        t = x.size()[-1]

        global_x = torch.cat(
            (
                x,
                torch.mean(x, dim=2, keepdim=True).repeat(1, 1, t),
                torch.sqrt(torch.var(x, dim=2, keepdim=True).clamp(min=1e-4)).repeat(1, 1, t),
            ),
            dim=1,
        )

        w = self.attention(global_x)

        mu = torch.sum(x * w, dim=2)
        sg = torch.sqrt((torch.sum((x**2) * w, dim=2) - mu**2).clamp(min=1e-4))

        x = torch.cat((mu, sg), 1)
        x = self.bn5(x)
        x = self.fc6(x)
        x = self.bn6(x)

        return x


class SpeakerEmbedding(nn.Module):
    def __init__(self, ckpt_path: str = "ResNet293_SimAM_ASP_base.pt", device: str = "cuda"):
        super().__init__()
        self.device = device
        with torch.device(device):
            self.model = ResNet293_based()
            self.model.load_state_dict(torch.load(ckpt_path, weights_only=True, mmap=True))
            self.model.featCal = logFbankCal()

        self.requires_grad_(False).eval()

    @property
    def dtype(self):
        return next(self.parameters()).dtype

    @cache
    def _get_resampler(self, orig_sample_rate: int):
        return torchaudio.transforms.Resample(orig_sample_rate, 16_000).to(self.device)

    def prepare_input(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
        assert wav.ndim < 3
        if wav.ndim == 2:
            wav = wav.mean(0, keepdim=True)
        wav = self._get_resampler(sample_rate)(wav)
        return wav

    def forward(self, wav: torch.Tensor, sample_rate: int):
        wav = self.prepare_input(wav, sample_rate).to(self.device, self.dtype)
        return self.model(wav).to(wav.device)


class SpeakerEmbeddingLDA(nn.Module):
    def __init__(
        self,
        device: str = "cuda",
    ):
        super().__init__()
        spk_model_path = hf_hub_download(repo_id="Quantumhash/Qhash-v0.1-speaker-embedding", filename="ResNet293_SimAM_ASP_base.pt")
        lda_spk_model_path = hf_hub_download(repo_id="Quantumhash/Qhash-v0.1-speaker-embedding", filename="ResNet293_SimAM_ASP_base_LDA-128.pt")

        self.device = device
        with torch.device(device):
            self.model = SpeakerEmbedding(spk_model_path, device)
            lda_sd = torch.load(lda_spk_model_path, weights_only=True)
            out_features, in_features = lda_sd["weight"].shape
            self.lda = nn.Linear(in_features, out_features, bias=True, dtype=torch.float32)
            self.lda.load_state_dict(lda_sd)

        self.requires_grad_(False).eval()

    def forward(self, wav: torch.Tensor, sample_rate: int):
        emb = self.model(wav, sample_rate).to(torch.float32)
        return emb, self.lda(emb)
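A minimal usage sketch for the embedder defined above, assuming a CUDA device and a local reference clip; reference_voice.wav and the variable names are illustrative, and the 128-dim LDA output size is inferred from the checkpoint filename rather than stated in the code.

import torch
import torchaudio

spk = SpeakerEmbeddingLDA(device="cuda")            # downloads both checkpoints via hf_hub_download
wav, sr = torchaudio.load("reference_voice.wav")    # hypothetical reference clip; any sample rate
wav = wav.to("cuda")                                # resampling runs on the embedder's device
with torch.no_grad():
    emb, emb_lda = spk(wav, sr)                     # 256-dim embedding and its LDA projection (presumably 128-dim)
print(emb.shape, emb_lda.shape)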