Commit 1062aec
Hugo Flores Garcia committed
Parent(s): 1fedcf3

sampling tricks, fix audiotools pin

Files changed:
- .gitignore +2 -0
- app.py +53 -14
- requirements.txt +1 -1
- scripts/exp/train.py +7 -5
- vampnet/modules/transformer.py +109 -37

.gitignore CHANGED
@@ -175,6 +175,7 @@ lyrebird-audio-codec
 samples-*/**
 
 gradio-outputs/
+models/
 samples*/
 models-all/
 models.zip
@@ -183,3 +184,4 @@ descript-audio-codec/
 # *.pth
 .git-old
 conf/generated/*
+runs*/

app.py CHANGED
@@ -107,24 +107,36 @@ def _vamp(data, return_mask=False):
     mask = pmask.codebook_unmask(mask, ncc)
 
 
-    print(
+    print(data)
+    _top_p = data[top_p] if data[top_p] > 0 else None
     # save the mask as a txt file
     np.savetxt(out_dir / "mask.txt", mask[:,0,:].long().cpu().numpy())
 
+    _seed = data[seed] if data[seed] > 0 else None
     zv, mask_z = interface.coarse_vamp(
         z,
         mask=mask,
         sampling_steps=data[num_steps],
-
+        mask_temperature=data[masktemp]*10,
+        sampling_temperature=data[sampletemp],
         return_mask=True,
         typical_filtering=data[typical_filtering],
         typical_mass=data[typical_mass],
         typical_min_tokens=data[typical_min_tokens],
+        top_p=_top_p,
         gen_fn=interface.coarse.generate,
+        seed=_seed,
     )
 
     if use_coarse2fine:
-        zv = interface.coarse_to_fine(
+        zv = interface.coarse_to_fine(
+            zv,
+            mask_temperature=data[masktemp]*10,
+            sampling_temperature=data[sampletemp],
+            mask=mask,
+            sampling_steps=data[num_steps],
+            seed=_seed,
+        )
 
     sig = interface.to_signal(zv).cpu()
     print("done")
@@ -157,7 +169,9 @@ def save_vamp(data):
     sig_out.write(out_dir / "output.wav")
 
     _data = {
-        "
+        "masktemp": data[masktemp],
+        "sampletemp": data[sampletemp],
+        "top_p": data[top_p],
         "prefix_s": data[prefix_s],
         "suffix_s": data[suffix_s],
         "rand_mask_intensity": data[rand_mask_intensity],
@@ -168,6 +182,7 @@ def save_vamp(data):
         "n_conditioning_codebooks": data[n_conditioning_codebooks],
         "use_coarse2fine": data[use_coarse2fine],
         "stretch_factor": data[stretch_factor],
+        "seed": data[seed],
     }
 
     # save with yaml
@@ -183,13 +198,14 @@ def save_vamp(data):
     return f"saved! your save code is {out_dir.stem}", zip_path
 
 
+
 with gr.Blocks() as demo:
 
     with gr.Row():
         with gr.Column():
-            gr.Markdown("# VampNet")
+            gr.Markdown("# VampNet Audio Vamping")
             gr.Markdown("""## Description:
-            This is a demo of VampNet, a
+            This is a demo of the VampNet, a generative audio model that transforms the input audio based on the chosen settings.
            You can control the extent and nature of variation with a set of manual controls and presets.
            Use this interface to experiment with different mask settings and explore the audio outputs.
            """)
@@ -197,8 +213,8 @@ with gr.Blocks() as demo:
            gr.Markdown("""
            ## Instructions:
            1. You can start by uploading some audio, or by loading the example audio.
-            2. Choose a preset for the vamp operation, or manually adjust the controls to customize the mask settings.
-            3. Click the "generate (vamp)!!!" button to
+            2. Choose a preset for the vamp operation, or manually adjust the controls to customize the mask settings.
+            3. Click the "generate (vamp)!!!" button to apply the vamp operation. Listen to the output audio.
            4. Optionally, you can add some notes and save the result.
            5. You can also use the output as the new input and continue experimenting!
            """)
@@ -377,16 +393,28 @@ with gr.Blocks() as demo:
                value=0.0
            )
 
-
-                label="temperature",
+            masktemp = gr.Slider(
+                label="mask temperature",
                minimum=0.0,
                maximum=10.0,
-                value=1.
+                value=1.5
            )
-
+            sampletemp = gr.Slider(
+                label="sample temperature",
+                minimum=0.1,
+                maximum=2.0,
+                value=1.0
+            )
+
 
 
            with gr.Accordion("sampling settings", open=False):
+                top_p = gr.Slider(
+                    label="top p (0.0 = off)",
+                    minimum=0.0,
+                    maximum=1.0,
+                    value=0.0
+                )
                typical_filtering = gr.Checkbox(
                    label="typical filtering ",
                    value=False
@@ -428,6 +456,14 @@ with gr.Blocks() as demo:
                )
 
 
+                seed = gr.Number(
+                    label="seed (0 for random)",
+                    value=0,
+                    precision=0,
+                )
+
+
+
            # mask settings
            with gr.Column():
                vamp_button = gr.Button("generate (vamp)!!!")
@@ -455,7 +491,9 @@ with gr.Blocks() as demo:
    _inputs = {
            input_audio,
            num_steps,
-
+            masktemp,
+            sampletemp,
+            top_p,
            prefix_s, suffix_s,
            rand_mask_intensity,
            periodic_p, periodic_w,
@@ -468,6 +506,7 @@ with gr.Blocks() as demo:
            typical_mass,
            typical_min_tokens,
            beat_mask_width,
+            seed,
            beat_mask_downbeats
        }
 
@@ -498,4 +537,4 @@ with gr.Blocks() as demo:
        outputs=[thank_you, download_file]
    )
 
-demo.
+demo.launch(share=True, enable_queue=False, debug=True)
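
Note on the app.py changes above: the new Gradio controls follow a "0 means off" convention before they reach `interface.coarse_vamp`. A `top_p` of 0.0 disables nucleus sampling, a `seed` of 0 leaves generation unseeded, and the mask-temperature slider value is scaled by 10. A minimal, self-contained sketch of that mapping (the helper name `resolve_sampling_kwargs` is hypothetical and not part of this commit):

```python
def resolve_sampling_kwargs(top_p: float, seed: int, masktemp: float, sampletemp: float) -> dict:
    """Map raw UI values to the keyword arguments passed to the vamp calls."""
    return {
        "top_p": top_p if top_p > 0 else None,      # 0.0 = nucleus sampling off
        "seed": seed if seed > 0 else None,          # 0 = unseeded / random
        "mask_temperature": masktemp * 10,           # slider value scaled by 10
        "sampling_temperature": sampletemp,
    }

# Example with the UI defaults (masktemp=1.5, sampletemp=1.0, top_p=0.0, seed=0):
print(resolve_sampling_kwargs(top_p=0.0, seed=0, masktemp=1.5, sampletemp=1.0))
```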

requirements.txt CHANGED
@@ -5,4 +5,4 @@ gradio
 loralib
 wavebeat @ git+https://github.com/hugofloresgarcia/wavebeat
 lac @ git+https://github.com/hugofloresgarcia/lac.git
-audiotools @ git+https://github.com/
+descript-audiotools @ git+https://github.com/descriptinc/audiotools.git@0.7.2

scripts/exp/train.py CHANGED
@@ -485,7 +485,6 @@ def load(
     save_path: str,
     resume: bool = False,
     tag: str = "latest",
-    load_weights: bool = False,
     fine_tune_checkpoint: Optional[str] = None,
     grad_clip_val: float = 5.0,
 ) -> State:
@@ -498,7 +497,7 @@ def load(
     kwargs = {
         "folder": f"{save_path}/{tag}",
         "map_location": "cpu",
-        "package":
+        "package": False,
     }
     tracker.print(f"Loading checkpoint from {kwargs['folder']}")
     if (Path(kwargs["folder"]) / "vampnet").exists():
@@ -511,11 +510,14 @@ def load(
 
     if args["fine_tune"]:
         assert fine_tune_checkpoint is not None, "Must provide a fine-tune checkpoint"
-        model =
-
+        model = torch.compile(
+            VampNet.load(location=Path(fine_tune_checkpoint),
+            map_location="cpu",
+            )
+        )
 
-    model = VampNet() if model is None else model
 
+    model = torch.compile(VampNet()) if model is None else model
     model = accel.prepare_model(model)
 
     # assert accel.unwrap(model).n_codebooks == codec.quantizer.n_codebooks
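
Note on the train.py changes above: both the loaded fine-tune checkpoint and a freshly constructed model are now wrapped in `torch.compile` before `accel.prepare_model`. A minimal sketch of what that wrapping does, using a stand-in module since VampNet's constructor arguments are not shown in this diff:

```python
import torch
import torch.nn as nn

class TinyModel(nn.Module):  # stand-in for VampNet in this sketch
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    def forward(self, x):
        return self.proj(x)

model = TinyModel()
# torch.compile (PyTorch >= 2.0) returns an optimized wrapper; compilation is
# triggered lazily on the first forward call, and the original module remains
# reachable via the wrapper's _orig_mod attribute.
compiled = torch.compile(model)
out = compiled(torch.randn(4, 8))
print(out.shape)  # torch.Size([4, 8])
```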

vampnet/modules/transformer.py CHANGED
@@ -367,6 +367,15 @@ class TransformerLayer(nn.Module):
 
         return x, position_bias, encoder_decoder_position_bias
 
+def t_schedule(n_steps, max_temp=1.0, min_temp=0.0, k=1.0):
+    x = np.linspace(0, 1, n_steps)
+    a = (0.5 - min_temp) / (max_temp - min_temp)
+
+    x = (x * 12) - 6
+    x0 = np.log((1 / a - 1) + 1e-5) / k
+    y = (1 / (1 + np.exp(- k * (x - x0))))[::-1]
+
+    return y
 
 class TransformerStack(nn.Module):
     def __init__(
@@ -580,20 +589,20 @@ class VampNet(at.ml.BaseModel):
         time_steps: int = 300,
         sampling_steps: int = 24,
         start_tokens: Optional[torch.Tensor] = None,
+        sampling_temperature: float = 1.0,
         mask: Optional[torch.Tensor] = None,
-
+        mask_temperature: float = 20.5,
         typical_filtering=False,
         typical_mass=0.2,
         typical_min_tokens=1,
+        top_p=None,
         return_signal=True,
+        seed: int = None
     ):
+        if seed is not None:
+            at.util.seed(seed)
         logging.debug(f"beginning generation with {sampling_steps} steps")
 
-        #####################
-        # resolve temperature #
-        #####################
-
-        logging.debug(f"temperature: {temperature}")
 
 
         #####################
@@ -641,13 +650,11 @@ class VampNet(at.ml.BaseModel):
         #################
         # begin sampling #
         #################
+        t_sched = t_schedule(sampling_steps, max_temp=sampling_temperature)
 
         for i in range(sampling_steps):
             logging.debug(f"step {i} of {sampling_steps}")
 
-            # our current temperature
-            logging.debug(f"temperature: {temperature}")
-
             # our current schedule step
             r = scalar_to_batch_tensor(
                 (i + 1) / sampling_steps,
@@ -664,39 +671,19 @@ class VampNet(at.ml.BaseModel):
             # NOTE: this collapses the codebook dimension into the sequence dimension
             logits = self.forward(latents, r)  # b, prob, seq
             logits = logits.permute(0, 2, 1)  # b, seq, prob
-
-            typical_filter(logits,
-                           typical_mass=typical_mass,
-                           typical_min_tokens=typical_min_tokens
-            )
-
+            b = logits.shape[0]
 
             logging.debug(f"permuted logits with shape: {logits.shape}")
 
+            sampled_z, selected_probs = sample_from_logits(
+                logits, sample=True, temperature=t_sched[i],
+                typical_filtering=typical_filtering, typical_mass=typical_mass,
+                typical_min_tokens=typical_min_tokens,
+                top_k=None, top_p=top_p, return_probs=True
+            )
 
-            # logits2probs
-            probs = torch.softmax(logits, dim=-1)
-            logging.debug(f"computed probs with shape: {probs.shape}")
-
-
-            # sample from logits with multinomial sampling
-            b = probs.shape[0]
-            probs = rearrange(probs, "b seq prob -> (b seq) prob")
-
-            sampled_z = torch.multinomial(probs, 1).squeeze(-1)
-
-            sampled_z = rearrange(sampled_z, "(b seq)-> b seq", b=b)
-            probs = rearrange(probs, "(b seq) prob -> b seq prob", b=b)
             logging.debug(f"sampled z with shape: {sampled_z.shape}")
 
-            # get the confidences: which tokens did we sample?
-            selected_probs = (
-                torch.take_along_dim(
-                    probs, sampled_z.long().unsqueeze(-1),
-                    dim=-1
-                ).squeeze(-1)
-            )
-
             # flatten z_masked and mask, so we can deal with the sampling logic
             # we'll unflatten them at the end of the loop for the next forward pass
             # remove conditioning codebooks, we'll add them back at the end
@@ -733,7 +720,7 @@ class VampNet(at.ml.BaseModel):
 
             # get our new mask
             mask = mask_by_random_topk(
-                num_to_mask, selected_probs,
+                num_to_mask, selected_probs, mask_temperature * (1-r)
             )
 
             # update the mask
@@ -766,6 +753,91 @@ class VampNet(at.ml.BaseModel):
         else:
             return sampled_z
 
+def sample_from_logits(
+        logits,
+        sample: bool = True,
+        temperature: float = 1.0,
+        top_k: int = None,
+        top_p: float = None,
+        typical_filtering: bool = False,
+        typical_mass: float = 0.2,
+        typical_min_tokens: int = 1,
+        return_probs: bool = False
+    ):
+    """Convenience function to sample from a categorial distribution with input as
+    unnormalized logits.
+
+    Parameters
+    ----------
+    logits : Tensor[..., vocab_size]
+    config: SamplingConfig
+        The set of hyperparameters to be used for sampling
+    sample : bool, optional
+        Whether to perform multinomial sampling, by default True
+    temperature : float, optional
+        Scaling parameter when multinomial samping, by default 1.0
+    top_k : int, optional
+        Restricts sampling to only `top_k` values acc. to probability,
+        by default None
+    top_p : float, optional
+        Restricts sampling to only those values with cumulative
+        probability = `top_p`, by default None
+
+    Returns
+    -------
+    Tensor[...]
+        Sampled tokens
+    """
+    shp = logits.shape[:-1]
+
+    if typical_filtering:
+        typical_filter(logits,
+                       typical_mass=typical_mass,
+                       typical_min_tokens=typical_min_tokens
+        )
+
+    # Apply top_k sampling
+    if top_k is not None:
+        v, _ = logits.topk(top_k)
+        logits[logits < v[..., [-1]]] = -float("inf")
+
+    # Apply top_p (nucleus) sampling
+    if top_p is not None and top_p < 1.0:
+        v, sorted_indices = logits.sort(descending=True)
+        cumulative_probs = v.softmax(dim=-1).cumsum(dim=-1)
+
+        sorted_indices_to_remove = cumulative_probs > top_p
+        # Right shift indices_to_remove to keep 1st token over threshold
+        sorted_indices_to_remove = F.pad(sorted_indices_to_remove, (1, 0), value=False)[
+            ..., :-1
+        ]
+
+        # Compute indices_to_remove in unsorted array
+        indices_to_remove = sorted_indices_to_remove.scatter(
+            -1, sorted_indices, sorted_indices_to_remove
+        )
+
+        logits[indices_to_remove] = -float("inf")
+
+    # Perform multinomial sampling after normalizing logits
+    probs = (
+        F.softmax(logits / temperature, dim=-1)
+        if temperature > 0
+        else logits.softmax(dim=-1)
+    )
+    token = (
+        probs.view(-1, probs.size(-1)).multinomial(1).squeeze(1).view(*shp)
+        if sample
+        else logits.argmax(-1)
+    )
+
+    if return_probs:
+        token_probs = probs.take_along_dim(token.unsqueeze(-1), dim=-1).squeeze(-1)
+        return token, token_probs
+    else:
+        return token
+
+
 
 def mask_by_random_topk(num_to_mask: int, probs: torch.Tensor, temperature: float = 1.0):
     """
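
Note on the new `t_schedule` helper above: it replaces the fixed sampling temperature with a per-step schedule. The function returns a reversed sigmoid with values in (0, 1) that decreases monotonically over the sampling steps, so early iterative-decoding steps sample at a higher temperature and later steps cool toward a near-greedy regime; `generate` builds it as `t_sched = t_schedule(sampling_steps, max_temp=sampling_temperature)` and uses `t_sched[i]` at step `i`. A small sketch that just evaluates the schedule:

```python
import numpy as np

# Copy of the t_schedule helper introduced in the diff above.
def t_schedule(n_steps, max_temp=1.0, min_temp=0.0, k=1.0):
    x = np.linspace(0, 1, n_steps)
    a = (0.5 - min_temp) / (max_temp - min_temp)

    x = (x * 12) - 6
    x0 = np.log((1 / a - 1) + 1e-5) / k
    y = (1 / (1 + np.exp(-k * (x - x0))))[::-1]

    return y

# 24 is the default number of sampling steps in VampNet.generate;
# the printed values decrease from near 1 toward near 0.
sched = t_schedule(24, max_temp=1.0)
print(np.round(sched, 3))
```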
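The new `sample_from_logits` helper consolidates typical filtering, top-k, top-p (nucleus) filtering, temperature scaling, and multinomial sampling into one function, replacing the inline softmax/multinomial code the sampling loop used before. A small, self-contained illustration of just the nucleus-filtering step on toy logits (the numbers are made up for the example; the real helper additionally supports typical filtering and top-k):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.tensor([[2.0, 1.0, 0.5, -1.0, -3.0]])
top_p = 0.9

# Keep the smallest set of tokens whose cumulative probability exceeds top_p,
# and mask everything else to -inf before sampling.
v, sorted_indices = logits.sort(descending=True)
cumulative_probs = v.softmax(dim=-1).cumsum(dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
# Shift right so the first token that crosses the threshold is kept.
sorted_indices_to_remove = F.pad(sorted_indices_to_remove, (1, 0), value=False)[..., :-1]
indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
filtered = logits.masked_fill(indices_to_remove, -float("inf"))

probs = filtered.softmax(dim=-1)
token = probs.multinomial(1)
print(filtered)  # low-probability tail is set to -inf
print(token)     # sampled from the remaining nucleus
```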