Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	fix audiotools version + sampling trick (#7)
Browse files- sampling tricks, fix audiotools pin (1062aecaf5ba2e866553c342f9a7ad78b6ec695a)
- remove share from demo (4fd6833b5f5c90fa51a330c389e2be09d8b057c5)
- .gitignore +2 -0
 - app.py +53 -14
 - requirements.txt +1 -1
 - scripts/exp/train.py +7 -5
 - vampnet/modules/transformer.py +109 -37
 
    	
        .gitignore
    CHANGED
    
    | 
         @@ -175,6 +175,7 @@ lyrebird-audio-codec 
     | 
|
| 175 | 
         
             
            samples-*/**
         
     | 
| 176 | 
         | 
| 177 | 
         
             
            gradio-outputs/
         
     | 
| 
         | 
|
| 178 | 
         
             
            samples*/
         
     | 
| 179 | 
         
             
            models-all/
         
     | 
| 180 | 
         
             
            models.zip
         
     | 
| 
         @@ -183,3 +184,4 @@ descript-audio-codec/ 
     | 
|
| 183 | 
         
             
            # *.pth
         
     | 
| 184 | 
         
             
            .git-old
         
     | 
| 185 | 
         
             
            conf/generated/*
         
     | 
| 
         | 
| 
         | 
|
| 175 | 
         
             
            samples-*/**
         
     | 
| 176 | 
         | 
| 177 | 
         
             
            gradio-outputs/
         
     | 
| 178 | 
         
            +
            models/
         
     | 
| 179 | 
         
             
            samples*/
         
     | 
| 180 | 
         
             
            models-all/
         
     | 
| 181 | 
         
             
            models.zip
         
     | 
| 
         | 
|
| 184 | 
         
             
            # *.pth
         
     | 
| 185 | 
         
             
            .git-old
         
     | 
| 186 | 
         
             
            conf/generated/*
         
     | 
| 187 | 
         
            +
            runs*/
         
     | 
    	
        app.py
    CHANGED
    
    | 
         @@ -107,24 +107,36 @@ def _vamp(data, return_mask=False): 
     | 
|
| 107 | 
         
             
                mask = pmask.codebook_unmask(mask, ncc)
         
     | 
| 108 | 
         | 
| 109 | 
         | 
| 110 | 
         
            -
                print( 
     | 
| 
         | 
|
| 111 | 
         
             
                # save the mask as a txt file
         
     | 
| 112 | 
         
             
                np.savetxt(out_dir / "mask.txt", mask[:,0,:].long().cpu().numpy())
         
     | 
| 113 | 
         | 
| 
         | 
|
| 114 | 
         
             
                zv, mask_z = interface.coarse_vamp(
         
     | 
| 115 | 
         
             
                    z, 
         
     | 
| 116 | 
         
             
                    mask=mask,
         
     | 
| 117 | 
         
             
                    sampling_steps=data[num_steps],
         
     | 
| 118 | 
         
            -
                     
     | 
| 
         | 
|
| 119 | 
         
             
                    return_mask=True, 
         
     | 
| 120 | 
         
             
                    typical_filtering=data[typical_filtering], 
         
     | 
| 121 | 
         
             
                    typical_mass=data[typical_mass], 
         
     | 
| 122 | 
         
             
                    typical_min_tokens=data[typical_min_tokens], 
         
     | 
| 
         | 
|
| 123 | 
         
             
                    gen_fn=interface.coarse.generate,
         
     | 
| 
         | 
|
| 124 | 
         
             
                )
         
     | 
| 125 | 
         | 
| 126 | 
         
             
                if use_coarse2fine: 
         
     | 
| 127 | 
         
            -
                    zv = interface.coarse_to_fine( 
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 128 | 
         | 
| 129 | 
         
             
                sig = interface.to_signal(zv).cpu()
         
     | 
| 130 | 
         
             
                print("done")
         
     | 
| 
         @@ -157,7 +169,9 @@ def save_vamp(data): 
     | 
|
| 157 | 
         
             
                sig_out.write(out_dir / "output.wav")
         
     | 
| 158 | 
         | 
| 159 | 
         
             
                _data = {
         
     | 
| 160 | 
         
            -
                    " 
     | 
| 
         | 
|
| 
         | 
|
| 161 | 
         
             
                    "prefix_s": data[prefix_s],
         
     | 
| 162 | 
         
             
                    "suffix_s": data[suffix_s],
         
     | 
| 163 | 
         
             
                    "rand_mask_intensity": data[rand_mask_intensity],
         
     | 
| 
         @@ -168,6 +182,7 @@ def save_vamp(data): 
     | 
|
| 168 | 
         
             
                    "n_conditioning_codebooks": data[n_conditioning_codebooks], 
         
     | 
| 169 | 
         
             
                    "use_coarse2fine": data[use_coarse2fine],
         
     | 
| 170 | 
         
             
                    "stretch_factor": data[stretch_factor],
         
     | 
| 
         | 
|
| 171 | 
         
             
                }
         
     | 
| 172 | 
         | 
| 173 | 
         
             
                # save with yaml
         
     | 
| 
         @@ -183,13 +198,14 @@ def save_vamp(data): 
     | 
|
| 183 | 
         
             
                return f"saved! your save code is {out_dir.stem}", zip_path
         
     | 
| 184 | 
         | 
| 185 | 
         | 
| 
         | 
|
| 186 | 
         
             
            with gr.Blocks() as demo:
         
     | 
| 187 | 
         | 
| 188 | 
         
             
                with gr.Row():
         
     | 
| 189 | 
         
             
                    with gr.Column():
         
     | 
| 190 | 
         
            -
                        gr.Markdown("# VampNet")
         
     | 
| 191 | 
         
             
                        gr.Markdown("""## Description:
         
     | 
| 192 | 
         
            -
                        This is a demo of VampNet, a  
     | 
| 193 | 
         
             
                        You can control the extent and nature of variation with a set of manual controls and presets. 
         
     | 
| 194 | 
         
             
                        Use this interface to experiment with different mask settings and explore the audio outputs.
         
     | 
| 195 | 
         
             
                        """)
         
     | 
| 
         @@ -197,8 +213,8 @@ with gr.Blocks() as demo: 
     | 
|
| 197 | 
         
             
                        gr.Markdown("""
         
     | 
| 198 | 
         
             
                        ## Instructions:
         
     | 
| 199 | 
         
             
                        1. You can start by uploading some audio, or by loading the example audio. 
         
     | 
| 200 | 
         
            -
                        2. Choose a preset for the vamp operation, or manually adjust the controls to customize the mask settings.  
     | 
| 201 | 
         
            -
                        3. Click the "generate (vamp)!!!" button to  
     | 
| 202 | 
         
             
                        4. Optionally, you can add some notes and save the result. 
         
     | 
| 203 | 
         
             
                        5. You can also use the output as the new input and continue experimenting!
         
     | 
| 204 | 
         
             
                        """)
         
     | 
| 
         @@ -377,16 +393,28 @@ with gr.Blocks() as demo: 
     | 
|
| 377 | 
         
             
                                value=0.0
         
     | 
| 378 | 
         
             
                            )
         
     | 
| 379 | 
         | 
| 380 | 
         
            -
                         
     | 
| 381 | 
         
            -
                            label="temperature",
         
     | 
| 382 | 
         
             
                            minimum=0.0,
         
     | 
| 383 | 
         
             
                            maximum=10.0,
         
     | 
| 384 | 
         
            -
                            value=1. 
     | 
| 385 | 
         
             
                        )
         
     | 
| 386 | 
         
            -
             
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 387 | 
         | 
| 388 | 
         | 
| 389 | 
         
             
                        with gr.Accordion("sampling settings", open=False):
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 390 | 
         
             
                            typical_filtering = gr.Checkbox(
         
     | 
| 391 | 
         
             
                                label="typical filtering ",
         
     | 
| 392 | 
         
             
                                value=False
         
     | 
| 
         @@ -428,6 +456,14 @@ with gr.Blocks() as demo: 
     | 
|
| 428 | 
         
             
                        )
         
     | 
| 429 | 
         | 
| 430 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 431 | 
         
             
                    # mask settings
         
     | 
| 432 | 
         
             
                    with gr.Column():
         
     | 
| 433 | 
         
             
                        vamp_button = gr.Button("generate (vamp)!!!")
         
     | 
| 
         @@ -455,7 +491,9 @@ with gr.Blocks() as demo: 
     | 
|
| 455 | 
         
             
                _inputs = {
         
     | 
| 456 | 
         
             
                        input_audio, 
         
     | 
| 457 | 
         
             
                        num_steps,
         
     | 
| 458 | 
         
            -
                         
     | 
| 
         | 
|
| 
         | 
|
| 459 | 
         
             
                        prefix_s, suffix_s, 
         
     | 
| 460 | 
         
             
                        rand_mask_intensity, 
         
     | 
| 461 | 
         
             
                        periodic_p, periodic_w,
         
     | 
| 
         @@ -468,6 +506,7 @@ with gr.Blocks() as demo: 
     | 
|
| 468 | 
         
             
                        typical_mass,
         
     | 
| 469 | 
         
             
                        typical_min_tokens,
         
     | 
| 470 | 
         
             
                        beat_mask_width,
         
     | 
| 
         | 
|
| 471 | 
         
             
                        beat_mask_downbeats
         
     | 
| 472 | 
         
             
                    }
         
     | 
| 473 | 
         | 
| 
         @@ -498,4 +537,4 @@ with gr.Blocks() as demo: 
     | 
|
| 498 | 
         
             
                    outputs=[thank_you, download_file]
         
     | 
| 499 | 
         
             
                )
         
     | 
| 500 | 
         | 
| 501 | 
         
            -
            demo. 
     | 
| 
         | 
|
| 107 | 
         
             
                mask = pmask.codebook_unmask(mask, ncc)
         
     | 
| 108 | 
         | 
| 109 | 
         | 
| 110 | 
         
            +
                print(data)
         
     | 
| 111 | 
         
            +
                _top_p = data[top_p] if data[top_p] > 0 else None
         
     | 
| 112 | 
         
             
                # save the mask as a txt file
         
     | 
| 113 | 
         
             
                np.savetxt(out_dir / "mask.txt", mask[:,0,:].long().cpu().numpy())
         
     | 
| 114 | 
         | 
| 115 | 
         
            +
                _seed = data[seed] if data[seed] > 0 else None
         
     | 
| 116 | 
         
             
                zv, mask_z = interface.coarse_vamp(
         
     | 
| 117 | 
         
             
                    z, 
         
     | 
| 118 | 
         
             
                    mask=mask,
         
     | 
| 119 | 
         
             
                    sampling_steps=data[num_steps],
         
     | 
| 120 | 
         
            +
                    mask_temperature=data[masktemp]*10,
         
     | 
| 121 | 
         
            +
                    sampling_temperature=data[sampletemp],
         
     | 
| 122 | 
         
             
                    return_mask=True, 
         
     | 
| 123 | 
         
             
                    typical_filtering=data[typical_filtering], 
         
     | 
| 124 | 
         
             
                    typical_mass=data[typical_mass], 
         
     | 
| 125 | 
         
             
                    typical_min_tokens=data[typical_min_tokens], 
         
     | 
| 126 | 
         
            +
                    top_p=_top_p,
         
     | 
| 127 | 
         
             
                    gen_fn=interface.coarse.generate,
         
     | 
| 128 | 
         
            +
                    seed=_seed,
         
     | 
| 129 | 
         
             
                )
         
     | 
| 130 | 
         | 
| 131 | 
         
             
                if use_coarse2fine: 
         
     | 
| 132 | 
         
            +
                    zv = interface.coarse_to_fine(
         
     | 
| 133 | 
         
            +
                        zv, 
         
     | 
| 134 | 
         
            +
                        mask_temperature=data[masktemp]*10, 
         
     | 
| 135 | 
         
            +
                        sampling_temperature=data[sampletemp],
         
     | 
| 136 | 
         
            +
                        mask=mask,
         
     | 
| 137 | 
         
            +
                        sampling_steps=data[num_steps], 
         
     | 
| 138 | 
         
            +
                        seed=_seed,
         
     | 
| 139 | 
         
            +
                    )
         
     | 
| 140 | 
         | 
| 141 | 
         
             
                sig = interface.to_signal(zv).cpu()
         
     | 
| 142 | 
         
             
                print("done")
         
     | 
| 
         | 
|
| 169 | 
         
             
                sig_out.write(out_dir / "output.wav")
         
     | 
| 170 | 
         | 
| 171 | 
         
             
                _data = {
         
     | 
| 172 | 
         
            +
                    "masktemp": data[masktemp],
         
     | 
| 173 | 
         
            +
                    "sampletemp": data[sampletemp],
         
     | 
| 174 | 
         
            +
                    "top_p": data[top_p],
         
     | 
| 175 | 
         
             
                    "prefix_s": data[prefix_s],
         
     | 
| 176 | 
         
             
                    "suffix_s": data[suffix_s],
         
     | 
| 177 | 
         
             
                    "rand_mask_intensity": data[rand_mask_intensity],
         
     | 
| 
         | 
|
| 182 | 
         
             
                    "n_conditioning_codebooks": data[n_conditioning_codebooks], 
         
     | 
| 183 | 
         
             
                    "use_coarse2fine": data[use_coarse2fine],
         
     | 
| 184 | 
         
             
                    "stretch_factor": data[stretch_factor],
         
     | 
| 185 | 
         
            +
                    "seed": data[seed],
         
     | 
| 186 | 
         
             
                }
         
     | 
| 187 | 
         | 
| 188 | 
         
             
                # save with yaml
         
     | 
| 
         | 
|
| 198 | 
         
             
                return f"saved! your save code is {out_dir.stem}", zip_path
         
     | 
| 199 | 
         | 
| 200 | 
         | 
| 201 | 
         
            +
             
     | 
| 202 | 
         
             
            with gr.Blocks() as demo:
         
     | 
| 203 | 
         | 
| 204 | 
         
             
                with gr.Row():
         
     | 
| 205 | 
         
             
                    with gr.Column():
         
     | 
| 206 | 
         
            +
                        gr.Markdown("# VampNet Audio Vamping")
         
     | 
| 207 | 
         
             
                        gr.Markdown("""## Description:
         
     | 
| 208 | 
         
            +
                        This is a demo of the VampNet, a generative audio model that transforms the input audio based on the chosen settings. 
         
     | 
| 209 | 
         
             
                        You can control the extent and nature of variation with a set of manual controls and presets. 
         
     | 
| 210 | 
         
             
                        Use this interface to experiment with different mask settings and explore the audio outputs.
         
     | 
| 211 | 
         
             
                        """)
         
     | 
| 
         | 
|
| 213 | 
         
             
                        gr.Markdown("""
         
     | 
| 214 | 
         
             
                        ## Instructions:
         
     | 
| 215 | 
         
             
                        1. You can start by uploading some audio, or by loading the example audio. 
         
     | 
| 216 | 
         
            +
                        2. Choose a preset for the vamp operation, or manually adjust the controls to customize the mask settings. 
         
     | 
| 217 | 
         
            +
                        3. Click the "generate (vamp)!!!" button to apply the vamp operation. Listen to the output audio.
         
     | 
| 218 | 
         
             
                        4. Optionally, you can add some notes and save the result. 
         
     | 
| 219 | 
         
             
                        5. You can also use the output as the new input and continue experimenting!
         
     | 
| 220 | 
         
             
                        """)
         
     | 
| 
         | 
|
| 393 | 
         
             
                                value=0.0
         
     | 
| 394 | 
         
             
                            )
         
     | 
| 395 | 
         | 
| 396 | 
         
            +
                        masktemp = gr.Slider(
         
     | 
| 397 | 
         
            +
                            label="mask temperature",
         
     | 
| 398 | 
         
             
                            minimum=0.0,
         
     | 
| 399 | 
         
             
                            maximum=10.0,
         
     | 
| 400 | 
         
            +
                            value=1.5
         
     | 
| 401 | 
         
             
                        )
         
     | 
| 402 | 
         
            +
                        sampletemp = gr.Slider(
         
     | 
| 403 | 
         
            +
                            label="sample temperature",
         
     | 
| 404 | 
         
            +
                            minimum=0.1,
         
     | 
| 405 | 
         
            +
                            maximum=2.0,
         
     | 
| 406 | 
         
            +
                            value=1.0
         
     | 
| 407 | 
         
            +
                        )
         
     | 
| 408 | 
         
            +
                    
         
     | 
| 409 | 
         | 
| 410 | 
         | 
| 411 | 
         
             
                        with gr.Accordion("sampling settings", open=False):
         
     | 
| 412 | 
         
            +
                            top_p = gr.Slider(
         
     | 
| 413 | 
         
            +
                                label="top p (0.0 = off)",
         
     | 
| 414 | 
         
            +
                                minimum=0.0,
         
     | 
| 415 | 
         
            +
                                maximum=1.0,
         
     | 
| 416 | 
         
            +
                                value=0.0
         
     | 
| 417 | 
         
            +
                            )
         
     | 
| 418 | 
         
             
                            typical_filtering = gr.Checkbox(
         
     | 
| 419 | 
         
             
                                label="typical filtering ",
         
     | 
| 420 | 
         
             
                                value=False
         
     | 
| 
         | 
|
| 456 | 
         
             
                        )
         
     | 
| 457 | 
         | 
| 458 | 
         | 
| 459 | 
         
            +
                        seed = gr.Number(
         
     | 
| 460 | 
         
            +
                            label="seed (0 for random)",
         
     | 
| 461 | 
         
            +
                            value=0,
         
     | 
| 462 | 
         
            +
                            precision=0,
         
     | 
| 463 | 
         
            +
                        )
         
     | 
| 464 | 
         
            +
             
     | 
| 465 | 
         
            +
             
     | 
| 466 | 
         
            +
             
     | 
| 467 | 
         
             
                    # mask settings
         
     | 
| 468 | 
         
             
                    with gr.Column():
         
     | 
| 469 | 
         
             
                        vamp_button = gr.Button("generate (vamp)!!!")
         
     | 
| 
         | 
|
| 491 | 
         
             
                _inputs = {
         
     | 
| 492 | 
         
             
                        input_audio, 
         
     | 
| 493 | 
         
             
                        num_steps,
         
     | 
| 494 | 
         
            +
                        masktemp,
         
     | 
| 495 | 
         
            +
                        sampletemp,
         
     | 
| 496 | 
         
            +
                        top_p,
         
     | 
| 497 | 
         
             
                        prefix_s, suffix_s, 
         
     | 
| 498 | 
         
             
                        rand_mask_intensity, 
         
     | 
| 499 | 
         
             
                        periodic_p, periodic_w,
         
     | 
| 
         | 
|
| 506 | 
         
             
                        typical_mass,
         
     | 
| 507 | 
         
             
                        typical_min_tokens,
         
     | 
| 508 | 
         
             
                        beat_mask_width,
         
     | 
| 509 | 
         
            +
                        seed, 
         
     | 
| 510 | 
         
             
                        beat_mask_downbeats
         
     | 
| 511 | 
         
             
                    }
         
     | 
| 512 | 
         | 
| 
         | 
|
| 537 | 
         
             
                    outputs=[thank_you, download_file]
         
     | 
| 538 | 
         
             
                )
         
     | 
| 539 | 
         | 
| 540 | 
         
            +
            demo.launch()
         
     | 
    	
        requirements.txt
    CHANGED
    
    | 
         @@ -5,4 +5,4 @@ gradio 
     | 
|
| 5 | 
         
             
            loralib
         
     | 
| 6 | 
         
             
            wavebeat @ git+https://github.com/hugofloresgarcia/wavebeat
         
     | 
| 7 | 
         
             
            lac @ git+https://github.com/hugofloresgarcia/lac.git
         
     | 
| 8 | 
         
            -
            audiotools @ git+https://github.com/ 
     | 
| 
         | 
|
| 5 | 
         
             
            loralib
         
     | 
| 6 | 
         
             
            wavebeat @ git+https://github.com/hugofloresgarcia/wavebeat
         
     | 
| 7 | 
         
             
            lac @ git+https://github.com/hugofloresgarcia/lac.git
         
     | 
| 8 | 
         
            +
            descript-audiotools @ git+https://github.com/descriptinc/audiotools.git@0.7.2
         
     | 
    	
        scripts/exp/train.py
    CHANGED
    
    | 
         @@ -485,7 +485,6 @@ def load( 
     | 
|
| 485 | 
         
             
                save_path: str,
         
     | 
| 486 | 
         
             
                resume: bool = False,
         
     | 
| 487 | 
         
             
                tag: str = "latest",
         
     | 
| 488 | 
         
            -
                load_weights: bool = False,
         
     | 
| 489 | 
         
             
                fine_tune_checkpoint: Optional[str] = None,
         
     | 
| 490 | 
         
             
                grad_clip_val: float = 5.0,
         
     | 
| 491 | 
         
             
            ) -> State:
         
     | 
| 
         @@ -498,7 +497,7 @@ def load( 
     | 
|
| 498 | 
         
             
                    kwargs = {
         
     | 
| 499 | 
         
             
                        "folder": f"{save_path}/{tag}",
         
     | 
| 500 | 
         
             
                        "map_location": "cpu",
         
     | 
| 501 | 
         
            -
                        "package":  
     | 
| 502 | 
         
             
                    }
         
     | 
| 503 | 
         
             
                    tracker.print(f"Loading checkpoint from {kwargs['folder']}")
         
     | 
| 504 | 
         
             
                    if (Path(kwargs["folder"]) / "vampnet").exists():
         
     | 
| 
         @@ -511,11 +510,14 @@ def load( 
     | 
|
| 511 | 
         | 
| 512 | 
         
             
                if args["fine_tune"]:
         
     | 
| 513 | 
         
             
                    assert fine_tune_checkpoint is not None, "Must provide a fine-tune checkpoint"
         
     | 
| 514 | 
         
            -
                    model =  
     | 
| 515 | 
         
            -
             
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 516 | 
         | 
| 517 | 
         
            -
                model = VampNet() if model is None else model
         
     | 
| 518 | 
         | 
| 
         | 
|
| 519 | 
         
             
                model = accel.prepare_model(model)
         
     | 
| 520 | 
         | 
| 521 | 
         
             
                # assert accel.unwrap(model).n_codebooks == codec.quantizer.n_codebooks
         
     | 
| 
         | 
|
| 485 | 
         
             
                save_path: str,
         
     | 
| 486 | 
         
             
                resume: bool = False,
         
     | 
| 487 | 
         
             
                tag: str = "latest",
         
     | 
| 
         | 
|
| 488 | 
         
             
                fine_tune_checkpoint: Optional[str] = None,
         
     | 
| 489 | 
         
             
                grad_clip_val: float = 5.0,
         
     | 
| 490 | 
         
             
            ) -> State:
         
     | 
| 
         | 
|
| 497 | 
         
             
                    kwargs = {
         
     | 
| 498 | 
         
             
                        "folder": f"{save_path}/{tag}",
         
     | 
| 499 | 
         
             
                        "map_location": "cpu",
         
     | 
| 500 | 
         
            +
                        "package": False,
         
     | 
| 501 | 
         
             
                    }
         
     | 
| 502 | 
         
             
                    tracker.print(f"Loading checkpoint from {kwargs['folder']}")
         
     | 
| 503 | 
         
             
                    if (Path(kwargs["folder"]) / "vampnet").exists():
         
     | 
| 
         | 
|
| 510 | 
         | 
| 511 | 
         
             
                if args["fine_tune"]:
         
     | 
| 512 | 
         
             
                    assert fine_tune_checkpoint is not None, "Must provide a fine-tune checkpoint"
         
     | 
| 513 | 
         
            +
                    model = torch.compile(
         
     | 
| 514 | 
         
            +
                        VampNet.load(location=Path(fine_tune_checkpoint), 
         
     | 
| 515 | 
         
            +
                                     map_location="cpu", 
         
     | 
| 516 | 
         
            +
                        )
         
     | 
| 517 | 
         
            +
                    )
         
     | 
| 518 | 
         | 
| 
         | 
|
| 519 | 
         | 
| 520 | 
         
            +
                model = torch.compile(VampNet()) if model is None else model
         
     | 
| 521 | 
         
             
                model = accel.prepare_model(model)
         
     | 
| 522 | 
         | 
| 523 | 
         
             
                # assert accel.unwrap(model).n_codebooks == codec.quantizer.n_codebooks
         
     | 
    	
        vampnet/modules/transformer.py
    CHANGED
    
    | 
         @@ -367,6 +367,15 @@ class TransformerLayer(nn.Module): 
     | 
|
| 367 | 
         | 
| 368 | 
         
             
                    return x, position_bias, encoder_decoder_position_bias
         
     | 
| 369 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 370 | 
         | 
| 371 | 
         
             
            class TransformerStack(nn.Module):
         
     | 
| 372 | 
         
             
                def __init__(
         
     | 
| 
         @@ -580,20 +589,20 @@ class VampNet(at.ml.BaseModel): 
     | 
|
| 580 | 
         
             
                    time_steps: int = 300,
         
     | 
| 581 | 
         
             
                    sampling_steps: int = 24,
         
     | 
| 582 | 
         
             
                    start_tokens: Optional[torch.Tensor] = None,
         
     | 
| 
         | 
|
| 583 | 
         
             
                    mask: Optional[torch.Tensor] = None,
         
     | 
| 584 | 
         
            -
                     
     | 
| 585 | 
         
             
                    typical_filtering=False,
         
     | 
| 586 | 
         
             
                    typical_mass=0.2,
         
     | 
| 587 | 
         
             
                    typical_min_tokens=1,
         
     | 
| 
         | 
|
| 588 | 
         
             
                    return_signal=True,
         
     | 
| 
         | 
|
| 589 | 
         
             
                ):
         
     | 
| 
         | 
|
| 
         | 
|
| 590 | 
         
             
                    logging.debug(f"beginning generation with {sampling_steps} steps")
         
     | 
| 591 | 
         | 
| 592 | 
         
            -
                    #####################
         
     | 
| 593 | 
         
            -
                    # resolve temperature #
         
     | 
| 594 | 
         
            -
                    #####################
         
     | 
| 595 | 
         
            -
             
     | 
| 596 | 
         
            -
                    logging.debug(f"temperature: {temperature}")
         
     | 
| 597 | 
         | 
| 598 | 
         | 
| 599 | 
         
             
                    ##################### 
         
     | 
| 
         @@ -641,13 +650,11 @@ class VampNet(at.ml.BaseModel): 
     | 
|
| 641 | 
         
             
                    #################
         
     | 
| 642 | 
         
             
                    # begin sampling #
         
     | 
| 643 | 
         
             
                    #################
         
     | 
| 
         | 
|
| 644 | 
         | 
| 645 | 
         
             
                    for i in range(sampling_steps):
         
     | 
| 646 | 
         
             
                        logging.debug(f"step {i} of {sampling_steps}")
         
     | 
| 647 | 
         | 
| 648 | 
         
            -
                        # our current temperature
         
     | 
| 649 | 
         
            -
                        logging.debug(f"temperature: {temperature}")
         
     | 
| 650 | 
         
            -
             
     | 
| 651 | 
         
             
                        # our current schedule step
         
     | 
| 652 | 
         
             
                        r = scalar_to_batch_tensor(
         
     | 
| 653 | 
         
             
                            (i + 1) / sampling_steps, 
         
     | 
| 
         @@ -664,39 +671,19 @@ class VampNet(at.ml.BaseModel): 
     | 
|
| 664 | 
         
             
                        # NOTE: this collapses the codebook dimension into the sequence dimension
         
     | 
| 665 | 
         
             
                        logits = self.forward(latents, r) # b, prob, seq
         
     | 
| 666 | 
         
             
                        logits = logits.permute(0, 2, 1)  # b, seq, prob
         
     | 
| 667 | 
         
            -
                         
     | 
| 668 | 
         
            -
                            typical_filter(logits, 
         
     | 
| 669 | 
         
            -
                                           typical_mass=typical_mass, 
         
     | 
| 670 | 
         
            -
                                           typical_min_tokens=typical_min_tokens
         
     | 
| 671 | 
         
            -
                            )
         
     | 
| 672 | 
         
            -
             
     | 
| 673 | 
         | 
| 674 | 
         
             
                        logging.debug(f"permuted logits with shape: {logits.shape}")
         
     | 
| 675 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 676 | 
         | 
| 677 | 
         
            -
                        # logits2probs
         
     | 
| 678 | 
         
            -
                        probs = torch.softmax(logits, dim=-1)
         
     | 
| 679 | 
         
            -
                        logging.debug(f"computed probs with shape: {probs.shape}")
         
     | 
| 680 | 
         
            -
             
     | 
| 681 | 
         
            -
             
     | 
| 682 | 
         
            -
                        # sample from logits with multinomial sampling
         
     | 
| 683 | 
         
            -
                        b = probs.shape[0]
         
     | 
| 684 | 
         
            -
                        probs = rearrange(probs, "b seq prob -> (b seq) prob")
         
     | 
| 685 | 
         
            -
             
     | 
| 686 | 
         
            -
                        sampled_z =  torch.multinomial(probs, 1).squeeze(-1)
         
     | 
| 687 | 
         
            -
             
     | 
| 688 | 
         
            -
                        sampled_z = rearrange(sampled_z, "(b seq)-> b seq", b=b)
         
     | 
| 689 | 
         
            -
                        probs = rearrange(probs, "(b seq) prob -> b seq prob", b=b)
         
     | 
| 690 | 
         
             
                        logging.debug(f"sampled z with shape: {sampled_z.shape}")
         
     | 
| 691 | 
         | 
| 692 | 
         
            -
                        # get the confidences: which tokens did we sample? 
         
     | 
| 693 | 
         
            -
                        selected_probs = (
         
     | 
| 694 | 
         
            -
                            torch.take_along_dim(
         
     | 
| 695 | 
         
            -
                                probs, sampled_z.long().unsqueeze(-1), 
         
     | 
| 696 | 
         
            -
                                dim=-1
         
     | 
| 697 | 
         
            -
                            ).squeeze(-1)
         
     | 
| 698 | 
         
            -
                        )
         
     | 
| 699 | 
         
            -
             
     | 
| 700 | 
         
             
                        # flatten z_masked and mask, so we can deal with the sampling logic
         
     | 
| 701 | 
         
             
                        # we'll unflatten them at the end of the loop for the next forward pass
         
     | 
| 702 | 
         
             
                        # remove conditioning codebooks, we'll add them back at the end
         
     | 
| 
         @@ -733,7 +720,7 @@ class VampNet(at.ml.BaseModel): 
     | 
|
| 733 | 
         | 
| 734 | 
         
             
                        # get our new mask
         
     | 
| 735 | 
         
             
                        mask = mask_by_random_topk(
         
     | 
| 736 | 
         
            -
                            num_to_mask, selected_probs,  
     | 
| 737 | 
         
             
                        )  
         
     | 
| 738 | 
         | 
| 739 | 
         
             
                        # update the mask
         
     | 
| 
         @@ -766,6 +753,91 @@ class VampNet(at.ml.BaseModel): 
     | 
|
| 766 | 
         
             
                    else:
         
     | 
| 767 | 
         
             
                        return sampled_z
         
     | 
| 768 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 769 | 
         | 
| 770 | 
         
             
            def mask_by_random_topk(num_to_mask: int, probs: torch.Tensor, temperature: float = 1.0):
         
     | 
| 771 | 
         
             
                """
         
     | 
| 
         | 
|
| 367 | 
         | 
| 368 | 
         
             
                    return x, position_bias, encoder_decoder_position_bias
         
     | 
| 369 | 
         | 
| 370 | 
         
            +
            def t_schedule(n_steps, max_temp=1.0, min_temp=0.0, k=1.0):
                """Build a descending logistic curve sampled at ``n_steps`` points.

                The unit interval is stretched onto [-6, 6], passed through a logistic
                function with slope ``k``, and the result is reversed so the first
                entry is the largest and the last entry is the smallest.

                Parameters
                ----------
                n_steps : int
                    Number of points in the returned schedule.
                max_temp : float, optional
                    Used together with ``min_temp`` only to position the curve's
                    midpoint; the output itself is not rescaled by it. By default 1.0
                min_temp : float, optional
                    See ``max_temp``. By default 0.0
                k : float, optional
                    Slope of the logistic curve, by default 1.0

                Returns
                -------
                numpy.ndarray
                    Array of length ``n_steps`` with logistic values between 0 and 1,
                    in descending order.

                NOTE(review): ``max_temp == min_temp`` divides by zero, and a
                midpoint ratio well outside (0, 1) can hand a non-positive value to
                the log (yielding NaN) -- confirm callers never pass such values.
                """
                grid = np.linspace(0, 1, n_steps)
                # where 0.5 falls inside the [min_temp, max_temp] span, as a ratio
                midpoint = (0.5 - min_temp) / (max_temp - min_temp)

                # map [0, 1] onto [-6, 6] so the logistic covers both of its tails
                grid = (grid * 12) - 6
                # horizontal offset of the curve; 1e-5 keeps the log finite at midpoint == 1
                offset = np.log((1 / midpoint - 1) + 1e-5) / k
                rising = 1 / (1 + np.exp(- k * (grid - offset)))

                # flip so the schedule starts high and ends low
                return rising[::-1]
         
     | 
| 379 | 
         | 
| 380 | 
         
             
            class TransformerStack(nn.Module):
         
     | 
| 381 | 
         
             
                def __init__(
         
     | 
| 
         | 
|
| 589 | 
         
             
                    time_steps: int = 300,
         
     | 
| 590 | 
         
             
                    sampling_steps: int = 24,
         
     | 
| 591 | 
         
             
                    start_tokens: Optional[torch.Tensor] = None,
         
     | 
| 592 | 
         
            +
                    sampling_temperature: float = 1.0,
         
     | 
| 593 | 
         
             
                    mask: Optional[torch.Tensor] = None,
         
     | 
| 594 | 
         
            +
                    mask_temperature: float = 20.5,
         
     | 
| 595 | 
         
             
                    typical_filtering=False,
         
     | 
| 596 | 
         
             
                    typical_mass=0.2,
         
     | 
| 597 | 
         
             
                    typical_min_tokens=1,
         
     | 
| 598 | 
         
            +
                    top_p=None,
         
     | 
| 599 | 
         
             
                    return_signal=True,
         
     | 
| 600 | 
         
            +
                    seed: int = None
         
     | 
| 601 | 
         
             
                ):
         
     | 
| 602 | 
         
            +
                    if seed is not None:
         
     | 
| 603 | 
         
            +
                        at.util.seed(seed)
         
     | 
| 604 | 
         
             
                    logging.debug(f"beginning generation with {sampling_steps} steps")
         
     | 
| 605 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 606 | 
         | 
| 607 | 
         | 
| 608 | 
         
             
                    ##################### 
         
     | 
| 
         | 
|
| 650 | 
         
             
                    #################
         
     | 
| 651 | 
         
             
                    # begin sampling #
         
     | 
| 652 | 
         
             
                    #################
         
     | 
| 653 | 
         
            +
                    t_sched = t_schedule(sampling_steps, max_temp=sampling_temperature)
         
     | 
| 654 | 
         | 
| 655 | 
         
             
                    for i in range(sampling_steps):
         
     | 
| 656 | 
         
             
                        logging.debug(f"step {i} of {sampling_steps}")
         
     | 
| 657 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 658 | 
         
             
                        # our current schedule step
         
     | 
| 659 | 
         
             
                        r = scalar_to_batch_tensor(
         
     | 
| 660 | 
         
             
                            (i + 1) / sampling_steps, 
         
     | 
| 
         | 
|
| 671 | 
         
             
                        # NOTE: this collapses the codebook dimension into the sequence dimension
         
     | 
| 672 | 
         
             
                        logits = self.forward(latents, r) # b, prob, seq
         
     | 
| 673 | 
         
             
                        logits = logits.permute(0, 2, 1)  # b, seq, prob
         
     | 
| 674 | 
         
            +
                        b = logits.shape[0]
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 675 | 
         | 
| 676 | 
         
             
                        logging.debug(f"permuted logits with shape: {logits.shape}")
         
     | 
| 677 | 
         | 
| 678 | 
         
            +
                        sampled_z, selected_probs = sample_from_logits(
         
     | 
| 679 | 
         
            +
                            logits, sample=True, temperature=t_sched[i],
         
     | 
| 680 | 
         
            +
                            typical_filtering=typical_filtering, typical_mass=typical_mass,
         
     | 
| 681 | 
         
            +
                            typical_min_tokens=typical_min_tokens,
         
     | 
| 682 | 
         
            +
                            top_k=None, top_p=top_p, return_probs=True
         
     | 
| 683 | 
         
            +
                        )
         
     | 
| 684 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 685 | 
         
             
                        logging.debug(f"sampled z with shape: {sampled_z.shape}")
         
     | 
| 686 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 687 | 
         
             
                        # flatten z_masked and mask, so we can deal with the sampling logic
         
     | 
| 688 | 
         
             
                        # we'll unflatten them at the end of the loop for the next forward pass
         
     | 
| 689 | 
         
             
                        # remove conditioning codebooks, we'll add them back at the end
         
     | 
| 
         | 
|
| 720 | 
         | 
| 721 | 
         
             
                        # get our new mask
         
     | 
| 722 | 
         
             
                        mask = mask_by_random_topk(
         
     | 
| 723 | 
         
            +
                            num_to_mask, selected_probs, mask_temperature * (1-r)
         
     | 
| 724 | 
         
             
                        )  
         
     | 
| 725 | 
         | 
| 726 | 
         
             
                        # update the mask
         
     | 
| 
         | 
|
| 753 | 
         
             
                    else:
         
     | 
| 754 | 
         
             
                        return sampled_z
         
     | 
| 755 | 
         | 
| 756 | 
         
            +
            def sample_from_logits(
                    logits,
                    sample: bool = True,
                    temperature: float = 1.0,
                    top_k: int = None,
                    top_p: float = None,
                    typical_filtering: bool = False,
                    typical_mass: float = 0.2,
                    typical_min_tokens: int = 1,
                    return_probs: bool = False
                ):
                """Convenience function to sample from a categorical distribution with
                input as unnormalized logits.

                Filters run in this order: typical filtering, top-k, top-p. The
                temperature is applied afterwards, when the filtered logits are turned
                into probabilities.

                Parameters
                ----------
                logits : Tensor[..., vocab_size]
                    Unnormalized scores. top-k / top-p filtering is done out of place,
                    so the caller's tensor is not overwritten by those two filters.
                sample : bool, optional
                    Whether to perform multinomial sampling (True) or take the argmax
                    of the filtered logits (False), by default True
                temperature : float, optional
                    Scaling parameter when multinomial sampling, by default 1.0.
                    Values that are not greater than 0 fall back to an unscaled
                    softmax.
                top_k : int, optional
                    Restricts sampling to only `top_k` values acc. to probability,
                    by default None. Values larger than the vocabulary size are
                    clamped to the vocabulary size.
                top_p : float, optional
                    Restricts sampling to only those values with cumulative
                    probability = `top_p`, by default None
                typical_filtering : bool, optional
                    Whether to call `typical_filter` on the logits first,
                    by default False
                typical_mass : float, optional
                    Forwarded to `typical_filter`, by default 0.2
                typical_min_tokens : int, optional
                    Forwarded to `typical_filter`, by default 1
                return_probs : bool, optional
                    Whether to also return the probability of each chosen token,
                    by default False

                Returns
                -------
                Tensor[...]
                    Sampled tokens, shaped like `logits` without its last dimension.
                    When `return_probs` is True, a tuple `(tokens, token_probs)` is
                    returned instead, where `token_probs` has the same shape as
                    `tokens`.
                """
                shp = logits.shape[:-1]

                if typical_filtering:
                    # NOTE(review): the return value is discarded, so this only has an
                    # effect if typical_filter edits `logits` in place -- confirm
                    # against its definition.
                    typical_filter(logits,
                                    typical_mass=typical_mass,
                                    typical_min_tokens=typical_min_tokens
                    )

                # Apply top_k sampling
                if top_k is not None:
                    # clamp so an oversized top_k keeps every token instead of
                    # raising inside topk()
                    k = min(top_k, logits.size(-1))
                    v, _ = logits.topk(k)
                    # masked_fill returns a new tensor: the caller's logits stay intact
                    logits = logits.masked_fill(logits < v[..., [-1]], -float("inf"))

                # Apply top_p (nucleus) sampling
                if top_p is not None and top_p < 1.0:
                    v, sorted_indices = logits.sort(descending=True)
                    cumulative_probs = v.softmax(dim=-1).cumsum(dim=-1)

                    sorted_indices_to_remove = cumulative_probs > top_p
                    # Right shift indices_to_remove to keep 1st token over threshold
                    sorted_indices_to_remove = F.pad(
                        sorted_indices_to_remove, (1, 0), value=False
                    )[..., :-1]

                    # Compute indices_to_remove in unsorted array
                    indices_to_remove = sorted_indices_to_remove.scatter(
                        -1, sorted_indices, sorted_indices_to_remove
                    )

                    logits = logits.masked_fill(indices_to_remove, -float("inf"))

                # Perform multinomial sampling after normalizing logits
                if temperature > 0:
                    probs = F.softmax(logits / temperature, dim=-1)
                else:
                    probs = logits.softmax(dim=-1)

                if sample:
                    # reshape (not view) so the flatten never depends on memory layout
                    flat_probs = probs.reshape(-1, probs.size(-1))
                    token = flat_probs.multinomial(1).squeeze(1).reshape(*shp)
                else:
                    token = logits.argmax(-1)

                if return_probs:
                    # probability assigned to each chosen token
                    token_probs = probs.take_along_dim(token.unsqueeze(-1), dim=-1).squeeze(-1)
                    return token, token_probs
                else:
                    return token
         
     | 
| 839 | 
         
            +
                
         
     | 
| 840 | 
         
            +
             
     | 
| 841 | 
         | 
| 842 | 
         
             
            def mask_by_random_topk(num_to_mask: int, probs: torch.Tensor, temperature: float = 1.0):
         
     | 
| 843 | 
         
             
                """
         
     |