Spaces:
Runtime error
Runtime error
c2f prompts (#5)
Browse files- use c2f prompt tokens (43aa23938e66e458e92c32fc8e1adc959c431af9)
This view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +1 -0
- README.md +5 -3
- app.py +3 -2
- conf/generated-v0/berta-goldman-speech/c2f.yml +0 -15
- conf/generated-v0/berta-goldman-speech/coarse.yml +0 -8
- conf/generated-v0/berta-goldman-speech/interface.yml +0 -5
- conf/generated-v0/gamelan-xeno-canto/c2f.yml +0 -17
- conf/generated-v0/gamelan-xeno-canto/coarse.yml +0 -10
- conf/generated-v0/gamelan-xeno-canto/interface.yml +0 -6
- conf/generated-v0/nasralla/c2f.yml +0 -15
- conf/generated-v0/nasralla/coarse.yml +0 -8
- conf/generated-v0/nasralla/interface.yml +0 -5
- conf/generated/breaks-steps/c2f.yml +0 -15
- conf/generated/breaks-steps/coarse.yml +0 -8
- conf/generated/breaks-steps/interface.yml +0 -7
- conf/generated/bulgarian-tv-choir/c2f.yml +0 -15
- conf/generated/bulgarian-tv-choir/coarse.yml +0 -8
- conf/generated/bulgarian-tv-choir/interface.yml +0 -7
- conf/generated/dariacore/c2f.yml +0 -15
- conf/generated/dariacore/coarse.yml +0 -8
- conf/generated/dariacore/interface.yml +0 -7
- conf/generated/musica-bolero-marimba/c2f.yml +0 -18
- conf/generated/musica-bolero-marimba/coarse.yml +0 -11
- conf/generated/musica-bolero-marimba/interface.yml +0 -8
- conf/generated/panchos/c2f.yml +0 -15
- conf/generated/panchos/coarse.yml +0 -8
- conf/generated/panchos/interface.yml +0 -7
- conf/generated/titi-monkey/c2f.yml +0 -15
- conf/generated/titi-monkey/coarse.yml +0 -8
- conf/generated/titi-monkey/interface.yml +0 -7
- conf/generated/xeno-canto/c2f.yml +0 -15
- conf/generated/xeno-canto/coarse.yml +0 -8
- conf/generated/xeno-canto/interface.yml +0 -7
- conf/lora/birds.yml +0 -10
- conf/lora/birdss.yml +0 -12
- conf/lora/constructions.yml +0 -10
- conf/lora/ella-baila-sola.yml +0 -10
- conf/lora/gas-station.yml +0 -10
- conf/lora/lora-is-this-charlie-parker.yml +0 -10
- conf/lora/lora.yml +4 -6
- conf/lora/underworld.yml +0 -10
- conf/lora/xeno-canto/c2f.yml +0 -21
- conf/lora/xeno-canto/coarse.yml +0 -10
- conf/vampnet-musdb-drums.yml +0 -22
- conf/vampnet.yml +8 -18
- scripts/exp/fine_tune.py +5 -5
- scripts/exp/train.py +468 -418
- setup.py +1 -1
- vampnet/beats.py +2 -1
- vampnet/interface.py +40 -21
.gitignore
CHANGED
|
@@ -182,3 +182,4 @@ audiotools/
|
|
| 182 |
descript-audio-codec/
|
| 183 |
# *.pth
|
| 184 |
.git-old
|
|
|
|
|
|
| 182 |
descript-audio-codec/
|
| 183 |
# *.pth
|
| 184 |
.git-old
|
| 185 |
+
conf/generated/*
|
README.md
CHANGED
|
@@ -7,12 +7,14 @@ sdk: gradio
|
|
| 7 |
sdk_version: 3.36.1
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
-
duplicated_from: hugggof/vampnet
|
| 11 |
---
|
| 12 |
|
| 13 |
# VampNet
|
| 14 |
|
| 15 |
-
This repository contains recipes for training generative music models on top of the
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
# Setting up
|
| 18 |
|
|
@@ -35,7 +37,7 @@ Config files are stored in the `conf/` folder.
|
|
| 35 |
### Licensing for Pretrained Models:
|
| 36 |
The weights for the models are licensed [`CC BY-NC-SA 4.0`](https://creativecommons.org/licenses/by-nc-sa/4.0/deed.ml). Likewise, any VampNet models fine-tuned on the pretrained models are also licensed [`CC BY-NC-SA 4.0`](https://creativecommons.org/licenses/by-nc-sa/4.0/deed.ml).
|
| 37 |
|
| 38 |
-
Download the pretrained models from [this link](https://zenodo.org/record/
|
| 39 |
|
| 40 |
|
| 41 |
# Usage
|
|
|
|
| 7 |
sdk_version: 3.36.1
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
|
|
|
| 10 |
---
|
| 11 |
|
| 12 |
# VampNet
|
| 13 |
|
| 14 |
+
This repository contains recipes for training generative music models on top of the Descript Audio Codec.
|
| 15 |
+
|
| 16 |
+
## try `unloop`
|
| 17 |
+
you can try vampnet in a co-creative looper called unloop. see this link: https://github.com/hugofloresgarcia/unloop
|
| 18 |
|
| 19 |
# Setting up
|
| 20 |
|
|
|
|
| 37 |
### Licensing for Pretrained Models:
|
| 38 |
The weights for the models are licensed [`CC BY-NC-SA 4.0`](https://creativecommons.org/licenses/by-nc-sa/4.0/deed.ml). Likewise, any VampNet models fine-tuned on the pretrained models are also licensed [`CC BY-NC-SA 4.0`](https://creativecommons.org/licenses/by-nc-sa/4.0/deed.ml).
|
| 39 |
|
| 40 |
+
Download the pretrained models from [this link](https://zenodo.org/record/8136629). Then, extract the models to the `models/` folder.
|
| 41 |
|
| 42 |
|
| 43 |
# Usage
|
app.py
CHANGED
|
@@ -124,7 +124,7 @@ def _vamp(data, return_mask=False):
|
|
| 124 |
)
|
| 125 |
|
| 126 |
if use_coarse2fine:
|
| 127 |
-
zv = interface.coarse_to_fine(zv, temperature=data[temp])
|
| 128 |
|
| 129 |
sig = interface.to_signal(zv).cpu()
|
| 130 |
print("done")
|
|
@@ -407,7 +407,8 @@ with gr.Blocks() as demo:
|
|
| 407 |
|
| 408 |
use_coarse2fine = gr.Checkbox(
|
| 409 |
label="use coarse2fine",
|
| 410 |
-
value=True
|
|
|
|
| 411 |
)
|
| 412 |
|
| 413 |
num_steps = gr.Slider(
|
|
|
|
| 124 |
)
|
| 125 |
|
| 126 |
if use_coarse2fine:
|
| 127 |
+
zv = interface.coarse_to_fine(zv, temperature=data[temp], mask=mask)
|
| 128 |
|
| 129 |
sig = interface.to_signal(zv).cpu()
|
| 130 |
print("done")
|
|
|
|
| 407 |
|
| 408 |
use_coarse2fine = gr.Checkbox(
|
| 409 |
label="use coarse2fine",
|
| 410 |
+
value=True,
|
| 411 |
+
visible=False
|
| 412 |
)
|
| 413 |
|
| 414 |
num_steps = gr.Slider(
|
conf/generated-v0/berta-goldman-speech/c2f.yml
DELETED
|
@@ -1,15 +0,0 @@
|
|
| 1 |
-
$include:
|
| 2 |
-
- conf/lora/lora.yml
|
| 3 |
-
AudioDataset.duration: 3.0
|
| 4 |
-
AudioDataset.loudness_cutoff: -40.0
|
| 5 |
-
VampNet.embedding_dim: 1280
|
| 6 |
-
VampNet.n_codebooks: 14
|
| 7 |
-
VampNet.n_conditioning_codebooks: 4
|
| 8 |
-
VampNet.n_heads: 20
|
| 9 |
-
VampNet.n_layers: 16
|
| 10 |
-
fine_tune: true
|
| 11 |
-
save_path: ./runs/berta-goldman-speech/c2f
|
| 12 |
-
train/AudioLoader.sources:
|
| 13 |
-
- /media/CHONK/hugo/Berta-Caceres-2015-Goldman-Speech.mp3
|
| 14 |
-
val/AudioLoader.sources:
|
| 15 |
-
- /media/CHONK/hugo/Berta-Caceres-2015-Goldman-Speech.mp3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
conf/generated-v0/berta-goldman-speech/coarse.yml
DELETED
|
@@ -1,8 +0,0 @@
|
|
| 1 |
-
$include:
|
| 2 |
-
- conf/lora/lora.yml
|
| 3 |
-
fine_tune: true
|
| 4 |
-
save_path: ./runs/berta-goldman-speech/coarse
|
| 5 |
-
train/AudioLoader.sources:
|
| 6 |
-
- /media/CHONK/hugo/Berta-Caceres-2015-Goldman-Speech.mp3
|
| 7 |
-
val/AudioLoader.sources:
|
| 8 |
-
- /media/CHONK/hugo/Berta-Caceres-2015-Goldman-Speech.mp3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
conf/generated-v0/berta-goldman-speech/interface.yml
DELETED
|
@@ -1,5 +0,0 @@
|
|
| 1 |
-
AudioLoader.sources:
|
| 2 |
-
- /media/CHONK/hugo/Berta-Caceres-2015-Goldman-Speech.mp3
|
| 3 |
-
Interface.coarse2fine_ckpt: ./runs/berta-goldman-speech/c2f/best/vampnet/weights.pth
|
| 4 |
-
Interface.coarse_ckpt: ./runs/berta-goldman-speech/coarse/best/vampnet/weights.pth
|
| 5 |
-
Interface.codec_ckpt: ./models/spotdl/codec.pth
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
conf/generated-v0/gamelan-xeno-canto/c2f.yml
DELETED
|
@@ -1,17 +0,0 @@
|
|
| 1 |
-
$include:
|
| 2 |
-
- conf/lora/lora.yml
|
| 3 |
-
AudioDataset.duration: 3.0
|
| 4 |
-
AudioDataset.loudness_cutoff: -40.0
|
| 5 |
-
VampNet.embedding_dim: 1280
|
| 6 |
-
VampNet.n_codebooks: 14
|
| 7 |
-
VampNet.n_conditioning_codebooks: 4
|
| 8 |
-
VampNet.n_heads: 20
|
| 9 |
-
VampNet.n_layers: 16
|
| 10 |
-
fine_tune: true
|
| 11 |
-
save_path: ./runs/gamelan-xeno-canto/c2f
|
| 12 |
-
train/AudioLoader.sources:
|
| 13 |
-
- /media/CHONK/hugo/loras/Sound Tracker - Gamelan (Indonesia) [UEWCCSuHsuQ].mp3
|
| 14 |
-
- /media/CHONK/hugo/loras/xeno-canto-2
|
| 15 |
-
val/AudioLoader.sources:
|
| 16 |
-
- /media/CHONK/hugo/loras/Sound Tracker - Gamelan (Indonesia) [UEWCCSuHsuQ].mp3
|
| 17 |
-
- /media/CHONK/hugo/loras/xeno-canto-2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
conf/generated-v0/gamelan-xeno-canto/coarse.yml
DELETED
|
@@ -1,10 +0,0 @@
|
|
| 1 |
-
$include:
|
| 2 |
-
- conf/lora/lora.yml
|
| 3 |
-
fine_tune: true
|
| 4 |
-
save_path: ./runs/gamelan-xeno-canto/coarse
|
| 5 |
-
train/AudioLoader.sources:
|
| 6 |
-
- /media/CHONK/hugo/loras/Sound Tracker - Gamelan (Indonesia) [UEWCCSuHsuQ].mp3
|
| 7 |
-
- /media/CHONK/hugo/loras/xeno-canto-2
|
| 8 |
-
val/AudioLoader.sources:
|
| 9 |
-
- /media/CHONK/hugo/loras/Sound Tracker - Gamelan (Indonesia) [UEWCCSuHsuQ].mp3
|
| 10 |
-
- /media/CHONK/hugo/loras/xeno-canto-2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
conf/generated-v0/gamelan-xeno-canto/interface.yml
DELETED
|
@@ -1,6 +0,0 @@
|
|
| 1 |
-
AudioLoader.sources:
|
| 2 |
-
- /media/CHONK/hugo/loras/Sound Tracker - Gamelan (Indonesia) [UEWCCSuHsuQ].mp3
|
| 3 |
-
- /media/CHONK/hugo/loras/xeno-canto-2
|
| 4 |
-
Interface.coarse2fine_ckpt: ./runs/gamelan-xeno-canto/c2f/best/vampnet/weights.pth
|
| 5 |
-
Interface.coarse_ckpt: ./runs/gamelan-xeno-canto/coarse/best/vampnet/weights.pth
|
| 6 |
-
Interface.codec_ckpt: ./models/spotdl/codec.pth
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
conf/generated-v0/nasralla/c2f.yml
DELETED
|
@@ -1,15 +0,0 @@
|
|
| 1 |
-
$include:
|
| 2 |
-
- conf/lora/lora.yml
|
| 3 |
-
AudioDataset.duration: 3.0
|
| 4 |
-
AudioDataset.loudness_cutoff: -40.0
|
| 5 |
-
VampNet.embedding_dim: 1280
|
| 6 |
-
VampNet.n_codebooks: 14
|
| 7 |
-
VampNet.n_conditioning_codebooks: 4
|
| 8 |
-
VampNet.n_heads: 20
|
| 9 |
-
VampNet.n_layers: 16
|
| 10 |
-
fine_tune: true
|
| 11 |
-
save_path: ./runs/nasralla/c2f
|
| 12 |
-
train/AudioLoader.sources:
|
| 13 |
-
- /media/CHONK/hugo/nasralla
|
| 14 |
-
val/AudioLoader.sources:
|
| 15 |
-
- /media/CHONK/hugo/nasralla
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
conf/generated-v0/nasralla/coarse.yml
DELETED
|
@@ -1,8 +0,0 @@
|
|
| 1 |
-
$include:
|
| 2 |
-
- conf/lora/lora.yml
|
| 3 |
-
fine_tune: true
|
| 4 |
-
save_path: ./runs/nasralla/coarse
|
| 5 |
-
train/AudioLoader.sources:
|
| 6 |
-
- /media/CHONK/hugo/nasralla
|
| 7 |
-
val/AudioLoader.sources:
|
| 8 |
-
- /media/CHONK/hugo/nasralla
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
conf/generated-v0/nasralla/interface.yml
DELETED
|
@@ -1,5 +0,0 @@
|
|
| 1 |
-
AudioLoader.sources:
|
| 2 |
-
- /media/CHONK/hugo/nasralla
|
| 3 |
-
Interface.coarse2fine_ckpt: ./runs/nasralla/c2f/best/vampnet/weights.pth
|
| 4 |
-
Interface.coarse_ckpt: ./runs/nasralla/coarse/best/vampnet/weights.pth
|
| 5 |
-
Interface.codec_ckpt: ./models/spotdl/codec.pth
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
conf/generated/breaks-steps/c2f.yml
DELETED
|
@@ -1,15 +0,0 @@
|
|
| 1 |
-
$include:
|
| 2 |
-
- conf/lora/lora.yml
|
| 3 |
-
AudioDataset.duration: 3.0
|
| 4 |
-
AudioDataset.loudness_cutoff: -40.0
|
| 5 |
-
VampNet.embedding_dim: 1280
|
| 6 |
-
VampNet.n_codebooks: 14
|
| 7 |
-
VampNet.n_conditioning_codebooks: 4
|
| 8 |
-
VampNet.n_heads: 20
|
| 9 |
-
VampNet.n_layers: 16
|
| 10 |
-
fine_tune: true
|
| 11 |
-
fine_tune_checkpoint: ./models/spotdl/c2f.pth
|
| 12 |
-
save_path: ./runs/breaks-steps/c2f
|
| 13 |
-
train/AudioLoader.sources: &id001
|
| 14 |
-
- /media/CHONK/hugo/breaks-steps
|
| 15 |
-
val/AudioLoader.sources: *id001
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
conf/generated/breaks-steps/coarse.yml
DELETED
|
@@ -1,8 +0,0 @@
|
|
| 1 |
-
$include:
|
| 2 |
-
- conf/lora/lora.yml
|
| 3 |
-
fine_tune: true
|
| 4 |
-
fine_tune_checkpoint: ./models/spotdl/coarse.pth
|
| 5 |
-
save_path: ./runs/breaks-steps/coarse
|
| 6 |
-
train/AudioLoader.sources: &id001
|
| 7 |
-
- /media/CHONK/hugo/breaks-steps
|
| 8 |
-
val/AudioLoader.sources: *id001
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
conf/generated/breaks-steps/interface.yml
DELETED
|
@@ -1,7 +0,0 @@
|
|
| 1 |
-
AudioLoader.sources:
|
| 2 |
-
- - /media/CHONK/hugo/breaks-steps
|
| 3 |
-
Interface.coarse2fine_ckpt: ./models/spotdl/c2f.pth
|
| 4 |
-
Interface.coarse2fine_lora_ckpt: ./runs/breaks-steps/c2f/latest/lora.pth
|
| 5 |
-
Interface.coarse_ckpt: ./models/spotdl/coarse.pth
|
| 6 |
-
Interface.coarse_lora_ckpt: ./runs/breaks-steps/coarse/latest/lora.pth
|
| 7 |
-
Interface.codec_ckpt: ./models/spotdl/codec.pth
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
conf/generated/bulgarian-tv-choir/c2f.yml
DELETED
|
@@ -1,15 +0,0 @@
|
|
| 1 |
-
$include:
|
| 2 |
-
- conf/lora/lora.yml
|
| 3 |
-
AudioDataset.duration: 3.0
|
| 4 |
-
AudioDataset.loudness_cutoff: -40.0
|
| 5 |
-
VampNet.embedding_dim: 1280
|
| 6 |
-
VampNet.n_codebooks: 14
|
| 7 |
-
VampNet.n_conditioning_codebooks: 4
|
| 8 |
-
VampNet.n_heads: 20
|
| 9 |
-
VampNet.n_layers: 16
|
| 10 |
-
fine_tune: true
|
| 11 |
-
fine_tune_checkpoint: ./models/spotdl/c2f.pth
|
| 12 |
-
save_path: ./runs/bulgarian-tv-choir/c2f
|
| 13 |
-
train/AudioLoader.sources: &id001
|
| 14 |
-
- /media/CHONK/hugo/loras/bulgarian-female-tv-choir/
|
| 15 |
-
val/AudioLoader.sources: *id001
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
conf/generated/bulgarian-tv-choir/coarse.yml
DELETED
|
@@ -1,8 +0,0 @@
|
|
| 1 |
-
$include:
|
| 2 |
-
- conf/lora/lora.yml
|
| 3 |
-
fine_tune: true
|
| 4 |
-
fine_tune_checkpoint: ./models/spotdl/coarse.pth
|
| 5 |
-
save_path: ./runs/bulgarian-tv-choir/coarse
|
| 6 |
-
train/AudioLoader.sources: &id001
|
| 7 |
-
- /media/CHONK/hugo/loras/bulgarian-female-tv-choir/
|
| 8 |
-
val/AudioLoader.sources: *id001
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
conf/generated/bulgarian-tv-choir/interface.yml
DELETED
|
@@ -1,7 +0,0 @@
|
|
| 1 |
-
AudioLoader.sources:
|
| 2 |
-
- - /media/CHONK/hugo/loras/bulgarian-female-tv-choir/
|
| 3 |
-
Interface.coarse2fine_ckpt: ./models/spotdl/c2f.pth
|
| 4 |
-
Interface.coarse2fine_lora_ckpt: ./runs/bulgarian-tv-choir/c2f/latest/lora.pth
|
| 5 |
-
Interface.coarse_ckpt: ./models/spotdl/coarse.pth
|
| 6 |
-
Interface.coarse_lora_ckpt: ./runs/bulgarian-tv-choir/coarse/latest/lora.pth
|
| 7 |
-
Interface.codec_ckpt: ./models/spotdl/codec.pth
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
conf/generated/dariacore/c2f.yml
DELETED
|
@@ -1,15 +0,0 @@
|
|
| 1 |
-
$include:
|
| 2 |
-
- conf/lora/lora.yml
|
| 3 |
-
AudioDataset.duration: 3.0
|
| 4 |
-
AudioDataset.loudness_cutoff: -40.0
|
| 5 |
-
VampNet.embedding_dim: 1280
|
| 6 |
-
VampNet.n_codebooks: 14
|
| 7 |
-
VampNet.n_conditioning_codebooks: 4
|
| 8 |
-
VampNet.n_heads: 20
|
| 9 |
-
VampNet.n_layers: 16
|
| 10 |
-
fine_tune: true
|
| 11 |
-
fine_tune_checkpoint: ./models/spotdl/c2f.pth
|
| 12 |
-
save_path: ./runs/dariacore/c2f
|
| 13 |
-
train/AudioLoader.sources: &id001
|
| 14 |
-
- /media/CHONK/hugo/loras/dariacore
|
| 15 |
-
val/AudioLoader.sources: *id001
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
conf/generated/dariacore/coarse.yml
DELETED
|
@@ -1,8 +0,0 @@
|
|
| 1 |
-
$include:
|
| 2 |
-
- conf/lora/lora.yml
|
| 3 |
-
fine_tune: true
|
| 4 |
-
fine_tune_checkpoint: ./models/spotdl/coarse.pth
|
| 5 |
-
save_path: ./runs/dariacore/coarse
|
| 6 |
-
train/AudioLoader.sources: &id001
|
| 7 |
-
- /media/CHONK/hugo/loras/dariacore
|
| 8 |
-
val/AudioLoader.sources: *id001
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
conf/generated/dariacore/interface.yml
DELETED
|
@@ -1,7 +0,0 @@
|
|
| 1 |
-
AudioLoader.sources:
|
| 2 |
-
- - /media/CHONK/hugo/loras/dariacore
|
| 3 |
-
Interface.coarse2fine_ckpt: ./models/spotdl/c2f.pth
|
| 4 |
-
Interface.coarse2fine_lora_ckpt: ./runs/dariacore/c2f/latest/lora.pth
|
| 5 |
-
Interface.coarse_ckpt: ./models/spotdl/coarse.pth
|
| 6 |
-
Interface.coarse_lora_ckpt: ./runs/dariacore/coarse/latest/lora.pth
|
| 7 |
-
Interface.codec_ckpt: ./models/spotdl/codec.pth
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
conf/generated/musica-bolero-marimba/c2f.yml
DELETED
|
@@ -1,18 +0,0 @@
|
|
| 1 |
-
$include:
|
| 2 |
-
- conf/lora/lora.yml
|
| 3 |
-
AudioDataset.duration: 3.0
|
| 4 |
-
AudioDataset.loudness_cutoff: -40.0
|
| 5 |
-
VampNet.embedding_dim: 1280
|
| 6 |
-
VampNet.n_codebooks: 14
|
| 7 |
-
VampNet.n_conditioning_codebooks: 4
|
| 8 |
-
VampNet.n_heads: 20
|
| 9 |
-
VampNet.n_layers: 16
|
| 10 |
-
fine_tune: true
|
| 11 |
-
fine_tune_checkpoint: ./models/spotdl/c2f.pth
|
| 12 |
-
save_path: ./runs/musica-bolero-marimba/c2f
|
| 13 |
-
train/AudioLoader.sources:
|
| 14 |
-
- /media/CHONK/hugo/loras/boleros
|
| 15 |
-
- /media/CHONK/hugo/loras/marimba-honduras
|
| 16 |
-
val/AudioLoader.sources:
|
| 17 |
-
- /media/CHONK/hugo/loras/boleros
|
| 18 |
-
- /media/CHONK/hugo/loras/marimba-honduras
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
conf/generated/musica-bolero-marimba/coarse.yml
DELETED
|
@@ -1,11 +0,0 @@
|
|
| 1 |
-
$include:
|
| 2 |
-
- conf/lora/lora.yml
|
| 3 |
-
fine_tune: true
|
| 4 |
-
fine_tune_checkpoint: ./models/spotdl/coarse.pth
|
| 5 |
-
save_path: ./runs/musica-bolero-marimba/coarse
|
| 6 |
-
train/AudioLoader.sources:
|
| 7 |
-
- /media/CHONK/hugo/loras/boleros
|
| 8 |
-
- /media/CHONK/hugo/loras/marimba-honduras
|
| 9 |
-
val/AudioLoader.sources:
|
| 10 |
-
- /media/CHONK/hugo/loras/boleros
|
| 11 |
-
- /media/CHONK/hugo/loras/marimba-honduras
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
conf/generated/musica-bolero-marimba/interface.yml
DELETED
|
@@ -1,8 +0,0 @@
|
|
| 1 |
-
AudioLoader.sources:
|
| 2 |
-
- /media/CHONK/hugo/loras/boleros
|
| 3 |
-
- /media/CHONK/hugo/loras/marimba-honduras
|
| 4 |
-
Interface.coarse2fine_ckpt: ./models/spotdl/c2f.pth
|
| 5 |
-
Interface.coarse2fine_lora_ckpt: ./runs/musica-bolero-marimba/c2f/latest/lora.pth
|
| 6 |
-
Interface.coarse_ckpt: ./models/spotdl/coarse.pth
|
| 7 |
-
Interface.coarse_lora_ckpt: ./runs/musica-bolero-marimba/coarse/latest/lora.pth
|
| 8 |
-
Interface.codec_ckpt: ./models/spotdl/codec.pth
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
conf/generated/panchos/c2f.yml
DELETED
|
@@ -1,15 +0,0 @@
|
|
| 1 |
-
$include:
|
| 2 |
-
- conf/lora/lora.yml
|
| 3 |
-
AudioDataset.duration: 3.0
|
| 4 |
-
AudioDataset.loudness_cutoff: -40.0
|
| 5 |
-
VampNet.embedding_dim: 1280
|
| 6 |
-
VampNet.n_codebooks: 14
|
| 7 |
-
VampNet.n_conditioning_codebooks: 4
|
| 8 |
-
VampNet.n_heads: 20
|
| 9 |
-
VampNet.n_layers: 16
|
| 10 |
-
fine_tune: true
|
| 11 |
-
fine_tune_checkpoint: ./models/spotdl/c2f.pth
|
| 12 |
-
save_path: ./runs/panchos/c2f
|
| 13 |
-
train/AudioLoader.sources: &id001
|
| 14 |
-
- /media/CHONK/hugo/loras/panchos/
|
| 15 |
-
val/AudioLoader.sources: *id001
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
conf/generated/panchos/coarse.yml
DELETED
|
@@ -1,8 +0,0 @@
|
|
| 1 |
-
$include:
|
| 2 |
-
- conf/lora/lora.yml
|
| 3 |
-
fine_tune: true
|
| 4 |
-
fine_tune_checkpoint: ./models/spotdl/coarse.pth
|
| 5 |
-
save_path: ./runs/panchos/coarse
|
| 6 |
-
train/AudioLoader.sources: &id001
|
| 7 |
-
- /media/CHONK/hugo/loras/panchos/
|
| 8 |
-
val/AudioLoader.sources: *id001
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
conf/generated/panchos/interface.yml
DELETED
|
@@ -1,7 +0,0 @@
|
|
| 1 |
-
AudioLoader.sources:
|
| 2 |
-
- - /media/CHONK/hugo/loras/panchos/
|
| 3 |
-
Interface.coarse2fine_ckpt: ./models/spotdl/c2f.pth
|
| 4 |
-
Interface.coarse2fine_lora_ckpt: ./runs/panchos/c2f/latest/lora.pth
|
| 5 |
-
Interface.coarse_ckpt: ./models/spotdl/coarse.pth
|
| 6 |
-
Interface.coarse_lora_ckpt: ./runs/panchos/coarse/latest/lora.pth
|
| 7 |
-
Interface.codec_ckpt: ./models/spotdl/codec.pth
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
conf/generated/titi-monkey/c2f.yml
DELETED
|
@@ -1,15 +0,0 @@
|
|
| 1 |
-
$include:
|
| 2 |
-
- conf/lora/lora.yml
|
| 3 |
-
AudioDataset.duration: 3.0
|
| 4 |
-
AudioDataset.loudness_cutoff: -40.0
|
| 5 |
-
VampNet.embedding_dim: 1280
|
| 6 |
-
VampNet.n_codebooks: 14
|
| 7 |
-
VampNet.n_conditioning_codebooks: 4
|
| 8 |
-
VampNet.n_heads: 20
|
| 9 |
-
VampNet.n_layers: 16
|
| 10 |
-
fine_tune: true
|
| 11 |
-
fine_tune_checkpoint: ./models/spotdl/c2f.pth
|
| 12 |
-
save_path: ./runs/titi-monkey/c2f
|
| 13 |
-
train/AudioLoader.sources: &id001
|
| 14 |
-
- /media/CHONK/hugo/loras/titi-monkey.mp3
|
| 15 |
-
val/AudioLoader.sources: *id001
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
conf/generated/titi-monkey/coarse.yml
DELETED
|
@@ -1,8 +0,0 @@
|
|
| 1 |
-
$include:
|
| 2 |
-
- conf/lora/lora.yml
|
| 3 |
-
fine_tune: true
|
| 4 |
-
fine_tune_checkpoint: ./models/spotdl/coarse.pth
|
| 5 |
-
save_path: ./runs/titi-monkey/coarse
|
| 6 |
-
train/AudioLoader.sources: &id001
|
| 7 |
-
- /media/CHONK/hugo/loras/titi-monkey.mp3
|
| 8 |
-
val/AudioLoader.sources: *id001
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
conf/generated/titi-monkey/interface.yml
DELETED
|
@@ -1,7 +0,0 @@
|
|
| 1 |
-
AudioLoader.sources:
|
| 2 |
-
- - /media/CHONK/hugo/loras/titi-monkey.mp3
|
| 3 |
-
Interface.coarse2fine_ckpt: ./models/spotdl/c2f.pth
|
| 4 |
-
Interface.coarse2fine_lora_ckpt: ./runs/titi-monkey/c2f/latest/lora.pth
|
| 5 |
-
Interface.coarse_ckpt: ./models/spotdl/coarse.pth
|
| 6 |
-
Interface.coarse_lora_ckpt: ./runs/titi-monkey/coarse/latest/lora.pth
|
| 7 |
-
Interface.codec_ckpt: ./models/spotdl/codec.pth
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
conf/generated/xeno-canto/c2f.yml
DELETED
|
@@ -1,15 +0,0 @@
|
|
| 1 |
-
$include:
|
| 2 |
-
- conf/lora/lora.yml
|
| 3 |
-
AudioDataset.duration: 3.0
|
| 4 |
-
AudioDataset.loudness_cutoff: -40.0
|
| 5 |
-
VampNet.embedding_dim: 1280
|
| 6 |
-
VampNet.n_codebooks: 14
|
| 7 |
-
VampNet.n_conditioning_codebooks: 4
|
| 8 |
-
VampNet.n_heads: 20
|
| 9 |
-
VampNet.n_layers: 16
|
| 10 |
-
fine_tune: true
|
| 11 |
-
fine_tune_checkpoint: ./models/spotdl/c2f.pth
|
| 12 |
-
save_path: ./runs/xeno-canto/c2f
|
| 13 |
-
train/AudioLoader.sources: &id001
|
| 14 |
-
- /media/CHONK/hugo/loras/xeno-canto-2/
|
| 15 |
-
val/AudioLoader.sources: *id001
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
conf/generated/xeno-canto/coarse.yml
DELETED
|
@@ -1,8 +0,0 @@
|
|
| 1 |
-
$include:
|
| 2 |
-
- conf/lora/lora.yml
|
| 3 |
-
fine_tune: true
|
| 4 |
-
fine_tune_checkpoint: ./models/spotdl/coarse.pth
|
| 5 |
-
save_path: ./runs/xeno-canto/coarse
|
| 6 |
-
train/AudioLoader.sources: &id001
|
| 7 |
-
- /media/CHONK/hugo/loras/xeno-canto-2/
|
| 8 |
-
val/AudioLoader.sources: *id001
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
conf/generated/xeno-canto/interface.yml
DELETED
|
@@ -1,7 +0,0 @@
|
|
| 1 |
-
AudioLoader.sources:
|
| 2 |
-
- - /media/CHONK/hugo/loras/xeno-canto-2/
|
| 3 |
-
Interface.coarse2fine_ckpt: ./models/spotdl/c2f.pth
|
| 4 |
-
Interface.coarse2fine_lora_ckpt: ./runs/xeno-canto/c2f/latest/lora.pth
|
| 5 |
-
Interface.coarse_ckpt: ./models/spotdl/coarse.pth
|
| 6 |
-
Interface.coarse_lora_ckpt: ./runs/xeno-canto/coarse/latest/lora.pth
|
| 7 |
-
Interface.codec_ckpt: ./models/spotdl/codec.pth
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
conf/lora/birds.yml
DELETED
|
@@ -1,10 +0,0 @@
|
|
| 1 |
-
$include:
|
| 2 |
-
- conf/lora/lora.yml
|
| 3 |
-
|
| 4 |
-
fine_tune: True
|
| 5 |
-
|
| 6 |
-
train/AudioLoader.sources:
|
| 7 |
-
- /media/CHONK/hugo/spotdl/subsets/birds
|
| 8 |
-
|
| 9 |
-
val/AudioLoader.sources:
|
| 10 |
-
- /media/CHONK/hugo/spotdl/subsets/birds
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
conf/lora/birdss.yml
DELETED
|
@@ -1,12 +0,0 @@
|
|
| 1 |
-
$include:
|
| 2 |
-
- conf/lora/lora.yml
|
| 3 |
-
|
| 4 |
-
fine_tune: True
|
| 5 |
-
|
| 6 |
-
train/AudioLoader.sources:
|
| 7 |
-
- /media/CHONK/hugo/spotdl/subsets/birds
|
| 8 |
-
- /media/CHONK/hugo/spotdl/subsets/this-is-charlie-parker/
|
| 9 |
-
|
| 10 |
-
val/AudioLoader.sources:
|
| 11 |
-
- /media/CHONK/hugo/spotdl/subsets/birds
|
| 12 |
-
- /media/CHONK/hugo/spotdl/subsets/this-is-charlie-parker/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
conf/lora/constructions.yml
DELETED
|
@@ -1,10 +0,0 @@
|
|
| 1 |
-
$include:
|
| 2 |
-
- conf/lora/lora.yml
|
| 3 |
-
|
| 4 |
-
fine_tune: True
|
| 5 |
-
|
| 6 |
-
train/AudioLoader.sources:
|
| 7 |
-
- /media/CHONK/hugo/spotdl/subsets/constructions/third.mp3
|
| 8 |
-
|
| 9 |
-
val/AudioLoader.sources:
|
| 10 |
-
- /media/CHONK/hugo/spotdl/subsets/constructions/third.mp3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
conf/lora/ella-baila-sola.yml
DELETED
|
@@ -1,10 +0,0 @@
|
|
| 1 |
-
$include:
|
| 2 |
-
- conf/lora/lora.yml
|
| 3 |
-
|
| 4 |
-
fine_tune: True
|
| 5 |
-
|
| 6 |
-
train/AudioLoader.sources:
|
| 7 |
-
- /media/CHONK/hugo/spotdl/subsets/ella-baila-sola.mp3
|
| 8 |
-
|
| 9 |
-
val/AudioLoader.sources:
|
| 10 |
-
- /media/CHONK/hugo/spotdl/subsets/ella-baila-sola.mp3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
conf/lora/gas-station.yml
DELETED
|
@@ -1,10 +0,0 @@
|
|
| 1 |
-
$include:
|
| 2 |
-
- conf/lora/lora.yml
|
| 3 |
-
|
| 4 |
-
fine_tune: True
|
| 5 |
-
|
| 6 |
-
train/AudioLoader.sources:
|
| 7 |
-
- /media/CHONK/hugo/spotdl/subsets/gas-station-sushi.mp3
|
| 8 |
-
|
| 9 |
-
val/AudioLoader.sources:
|
| 10 |
-
- /media/CHONK/hugo/spotdl/subsets/gas-station-sushi.mp3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
conf/lora/lora-is-this-charlie-parker.yml
DELETED
|
@@ -1,10 +0,0 @@
|
|
| 1 |
-
$include:
|
| 2 |
-
- conf/lora/lora.yml
|
| 3 |
-
|
| 4 |
-
fine_tune: True
|
| 5 |
-
|
| 6 |
-
train/AudioLoader.sources:
|
| 7 |
-
- /media/CHONK/hugo/spotdl/subsets/this-is-charlie-parker/Charlie Parker - Donna Lee.mp3
|
| 8 |
-
|
| 9 |
-
val/AudioLoader.sources:
|
| 10 |
-
- /media/CHONK/hugo/spotdl/subsets/this-is-charlie-parker/Charlie Parker - Donna Lee.mp3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
conf/lora/lora.yml
CHANGED
|
@@ -3,20 +3,18 @@ $include:
|
|
| 3 |
|
| 4 |
fine_tune: True
|
| 5 |
|
| 6 |
-
train/AudioDataset.n_examples:
|
| 7 |
-
|
| 8 |
-
val/AudioDataset.n_examples: 10
|
| 9 |
|
| 10 |
|
| 11 |
NoamScheduler.warmup: 500
|
| 12 |
|
| 13 |
batch_size: 7
|
| 14 |
num_workers: 7
|
| 15 |
-
|
| 16 |
-
save_audio_epochs: 10
|
| 17 |
|
| 18 |
AdamW.lr: 0.0001
|
| 19 |
|
| 20 |
# lets us organize sound classes into folders and choose from those sound classes uniformly
|
| 21 |
AudioDataset.without_replacement: False
|
| 22 |
-
|
|
|
|
| 3 |
|
| 4 |
fine_tune: True
|
| 5 |
|
| 6 |
+
train/AudioDataset.n_examples: 100000000
|
| 7 |
+
val/AudioDataset.n_examples: 100
|
|
|
|
| 8 |
|
| 9 |
|
| 10 |
NoamScheduler.warmup: 500
|
| 11 |
|
| 12 |
batch_size: 7
|
| 13 |
num_workers: 7
|
| 14 |
+
save_iters: [100000, 200000, 300000, 4000000, 500000]
|
|
|
|
| 15 |
|
| 16 |
AdamW.lr: 0.0001
|
| 17 |
|
| 18 |
# lets us organize sound classes into folders and choose from those sound classes uniformly
|
| 19 |
AudioDataset.without_replacement: False
|
| 20 |
+
num_iters: 500000
|
conf/lora/underworld.yml
DELETED
|
@@ -1,10 +0,0 @@
|
|
| 1 |
-
$include:
|
| 2 |
-
- conf/lora/lora.yml
|
| 3 |
-
|
| 4 |
-
fine_tune: True
|
| 5 |
-
|
| 6 |
-
train/AudioLoader.sources:
|
| 7 |
-
- /media/CHONK/hugo/spotdl/subsets/underworld.mp3
|
| 8 |
-
|
| 9 |
-
val/AudioLoader.sources:
|
| 10 |
-
- /media/CHONK/hugo/spotdl/subsets/underworld.mp3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
conf/lora/xeno-canto/c2f.yml
DELETED
|
@@ -1,21 +0,0 @@
|
|
| 1 |
-
$include:
|
| 2 |
-
- conf/lora/lora.yml
|
| 3 |
-
|
| 4 |
-
fine_tune: True
|
| 5 |
-
|
| 6 |
-
train/AudioLoader.sources:
|
| 7 |
-
- /media/CHONK/hugo/xeno-canto-2
|
| 8 |
-
|
| 9 |
-
val/AudioLoader.sources:
|
| 10 |
-
- /media/CHONK/hugo/xeno-canto-2
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
VampNet.n_codebooks: 14
|
| 14 |
-
VampNet.n_conditioning_codebooks: 4
|
| 15 |
-
|
| 16 |
-
VampNet.embedding_dim: 1280
|
| 17 |
-
VampNet.n_layers: 16
|
| 18 |
-
VampNet.n_heads: 20
|
| 19 |
-
|
| 20 |
-
AudioDataset.duration: 3.0
|
| 21 |
-
AudioDataset.loudness_cutoff: -40.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
conf/lora/xeno-canto/coarse.yml
DELETED
|
@@ -1,10 +0,0 @@
|
|
| 1 |
-
$include:
|
| 2 |
-
- conf/lora/lora.yml
|
| 3 |
-
|
| 4 |
-
fine_tune: True
|
| 5 |
-
|
| 6 |
-
train/AudioLoader.sources:
|
| 7 |
-
- /media/CHONK/hugo/xeno-canto-2
|
| 8 |
-
|
| 9 |
-
val/AudioLoader.sources:
|
| 10 |
-
- /media/CHONK/hugo/xeno-canto-2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
conf/vampnet-musdb-drums.yml
DELETED
|
@@ -1,22 +0,0 @@
|
|
| 1 |
-
$include:
|
| 2 |
-
- conf/vampnet.yml
|
| 3 |
-
|
| 4 |
-
VampNet.embedding_dim: 512
|
| 5 |
-
VampNet.n_layers: 12
|
| 6 |
-
VampNet.n_heads: 8
|
| 7 |
-
|
| 8 |
-
AudioDataset.duration: 12.0
|
| 9 |
-
|
| 10 |
-
train/AudioDataset.n_examples: 10000000
|
| 11 |
-
train/AudioLoader.sources:
|
| 12 |
-
- /data/musdb18hq/train/**/*drums.wav
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
val/AudioDataset.n_examples: 500
|
| 16 |
-
val/AudioLoader.sources:
|
| 17 |
-
- /data/musdb18hq/test/**/*drums.wav
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
test/AudioDataset.n_examples: 1000
|
| 21 |
-
test/AudioLoader.sources:
|
| 22 |
-
- /data/musdb18hq/test/**/*drums.wav
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
conf/vampnet.yml
CHANGED
|
@@ -1,21 +1,17 @@
|
|
| 1 |
|
| 2 |
-
codec_ckpt: ./models/
|
| 3 |
save_path: ckpt
|
| 4 |
-
max_epochs: 1000
|
| 5 |
-
epoch_length: 1000
|
| 6 |
-
save_audio_epochs: 2
|
| 7 |
-
val_idx: [0,1,2,3,4,5,6,7,8,9]
|
| 8 |
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
|
|
|
| 13 |
|
| 14 |
batch_size: 8
|
| 15 |
num_workers: 10
|
| 16 |
|
| 17 |
# Optimization
|
| 18 |
-
detect_anomaly: false
|
| 19 |
amp: false
|
| 20 |
|
| 21 |
CrossEntropyLoss.label_smoothing: 0.1
|
|
@@ -25,9 +21,6 @@ AdamW.lr: 0.001
|
|
| 25 |
NoamScheduler.factor: 2.0
|
| 26 |
NoamScheduler.warmup: 10000
|
| 27 |
|
| 28 |
-
PitchShift.shift_amount: [const, 0]
|
| 29 |
-
PitchShift.prob: 0.0
|
| 30 |
-
|
| 31 |
VampNet.vocab_size: 1024
|
| 32 |
VampNet.n_codebooks: 4
|
| 33 |
VampNet.n_conditioning_codebooks: 0
|
|
@@ -48,12 +41,9 @@ AudioDataset.duration: 10.0
|
|
| 48 |
|
| 49 |
train/AudioDataset.n_examples: 10000000
|
| 50 |
train/AudioLoader.sources:
|
| 51 |
-
- /
|
| 52 |
|
| 53 |
val/AudioDataset.n_examples: 2000
|
| 54 |
val/AudioLoader.sources:
|
| 55 |
-
- /
|
| 56 |
|
| 57 |
-
test/AudioDataset.n_examples: 1000
|
| 58 |
-
test/AudioLoader.sources:
|
| 59 |
-
- /data/spotdl/audio/test
|
|
|
|
| 1 |
|
| 2 |
+
codec_ckpt: ./models/vampnet/codec.pth
|
| 3 |
save_path: ckpt
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
+
num_iters: 1000000000
|
| 6 |
+
save_iters: [10000, 50000, 100000, 300000, 500000]
|
| 7 |
+
val_idx: [0,1,2,3,4,5,6,7,8,9]
|
| 8 |
+
sample_freq: 10000
|
| 9 |
+
val_freq: 1000
|
| 10 |
|
| 11 |
batch_size: 8
|
| 12 |
num_workers: 10
|
| 13 |
|
| 14 |
# Optimization
|
|
|
|
| 15 |
amp: false
|
| 16 |
|
| 17 |
CrossEntropyLoss.label_smoothing: 0.1
|
|
|
|
| 21 |
NoamScheduler.factor: 2.0
|
| 22 |
NoamScheduler.warmup: 10000
|
| 23 |
|
|
|
|
|
|
|
|
|
|
| 24 |
VampNet.vocab_size: 1024
|
| 25 |
VampNet.n_codebooks: 4
|
| 26 |
VampNet.n_conditioning_codebooks: 0
|
|
|
|
| 41 |
|
| 42 |
train/AudioDataset.n_examples: 10000000
|
| 43 |
train/AudioLoader.sources:
|
| 44 |
+
- /media/CHONK/hugo/spotdl/audio-train
|
| 45 |
|
| 46 |
val/AudioDataset.n_examples: 2000
|
| 47 |
val/AudioLoader.sources:
|
| 48 |
+
- /media/CHONK/hugo/spotdl/audio-val
|
| 49 |
|
|
|
|
|
|
|
|
|
scripts/exp/fine_tune.py
CHANGED
|
@@ -35,7 +35,7 @@ def fine_tune(audio_files_or_folders: List[str], name: str):
|
|
| 35 |
"AudioDataset.duration": 3.0,
|
| 36 |
"AudioDataset.loudness_cutoff": -40.0,
|
| 37 |
"save_path": f"./runs/{name}/c2f",
|
| 38 |
-
"fine_tune_checkpoint": "./models/
|
| 39 |
}
|
| 40 |
|
| 41 |
finetune_coarse_conf = {
|
|
@@ -44,17 +44,17 @@ def fine_tune(audio_files_or_folders: List[str], name: str):
|
|
| 44 |
"train/AudioLoader.sources": audio_files_or_folders,
|
| 45 |
"val/AudioLoader.sources": audio_files_or_folders,
|
| 46 |
"save_path": f"./runs/{name}/coarse",
|
| 47 |
-
"fine_tune_checkpoint": "./models/
|
| 48 |
}
|
| 49 |
|
| 50 |
interface_conf = {
|
| 51 |
-
"Interface.coarse_ckpt": f"./models/
|
| 52 |
"Interface.coarse_lora_ckpt": f"./runs/{name}/coarse/latest/lora.pth",
|
| 53 |
|
| 54 |
-
"Interface.coarse2fine_ckpt": f"./models/
|
| 55 |
"Interface.coarse2fine_lora_ckpt": f"./runs/{name}/c2f/latest/lora.pth",
|
| 56 |
|
| 57 |
-
"Interface.codec_ckpt": "./models/
|
| 58 |
"AudioLoader.sources": [audio_files_or_folders],
|
| 59 |
}
|
| 60 |
|
|
|
|
| 35 |
"AudioDataset.duration": 3.0,
|
| 36 |
"AudioDataset.loudness_cutoff": -40.0,
|
| 37 |
"save_path": f"./runs/{name}/c2f",
|
| 38 |
+
"fine_tune_checkpoint": "./models/vampnet/c2f.pth"
|
| 39 |
}
|
| 40 |
|
| 41 |
finetune_coarse_conf = {
|
|
|
|
| 44 |
"train/AudioLoader.sources": audio_files_or_folders,
|
| 45 |
"val/AudioLoader.sources": audio_files_or_folders,
|
| 46 |
"save_path": f"./runs/{name}/coarse",
|
| 47 |
+
"fine_tune_checkpoint": "./models/vampnet/coarse.pth"
|
| 48 |
}
|
| 49 |
|
| 50 |
interface_conf = {
|
| 51 |
+
"Interface.coarse_ckpt": f"./models/vampnet/coarse.pth",
|
| 52 |
"Interface.coarse_lora_ckpt": f"./runs/{name}/coarse/latest/lora.pth",
|
| 53 |
|
| 54 |
+
"Interface.coarse2fine_ckpt": f"./models/vampnet/c2f.pth",
|
| 55 |
"Interface.coarse2fine_lora_ckpt": f"./runs/{name}/c2f/latest/lora.pth",
|
| 56 |
|
| 57 |
+
"Interface.codec_ckpt": "./models/vampnet/codec.pth",
|
| 58 |
"AudioLoader.sources": [audio_files_or_folders],
|
| 59 |
}
|
| 60 |
|
scripts/exp/train.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import os
|
| 2 |
-
import
|
| 3 |
-
import time
|
| 4 |
import warnings
|
| 5 |
from pathlib import Path
|
| 6 |
from typing import Optional
|
|
|
|
| 7 |
|
| 8 |
import argbind
|
| 9 |
import audiotools as at
|
|
@@ -23,6 +23,12 @@ from vampnet import mask as pmask
|
|
| 23 |
# from dac.model.dac import DAC
|
| 24 |
from lac.model.lac import LAC as DAC
|
| 25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
# Enable cudnn autotuner to speed up training
|
| 28 |
# (can be altered by the funcs.seed function)
|
|
@@ -85,11 +91,7 @@ def build_datasets(args, sample_rate: int):
|
|
| 85 |
)
|
| 86 |
with argbind.scope(args, "val"):
|
| 87 |
val_data = AudioDataset(AudioLoader(), sample_rate, transform=build_transform())
|
| 88 |
-
|
| 89 |
-
test_data = AudioDataset(
|
| 90 |
-
AudioLoader(), sample_rate, transform=build_transform()
|
| 91 |
-
)
|
| 92 |
-
return train_data, val_data, test_data
|
| 93 |
|
| 94 |
|
| 95 |
def rand_float(shape, low, high, rng):
|
|
@@ -100,16 +102,393 @@ def flip_coin(shape, p, rng):
|
|
| 100 |
return rng.draw(shape)[:, 0] < p
|
| 101 |
|
| 102 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
@argbind.bind(without_prefix=True)
|
| 104 |
def load(
|
| 105 |
args,
|
| 106 |
accel: at.ml.Accelerator,
|
|
|
|
| 107 |
save_path: str,
|
| 108 |
resume: bool = False,
|
| 109 |
tag: str = "latest",
|
| 110 |
load_weights: bool = False,
|
| 111 |
fine_tune_checkpoint: Optional[str] = None,
|
| 112 |
-
|
|
|
|
| 113 |
codec = DAC.load(args["codec_ckpt"], map_location="cpu")
|
| 114 |
codec.eval()
|
| 115 |
|
|
@@ -121,6 +500,7 @@ def load(
|
|
| 121 |
"map_location": "cpu",
|
| 122 |
"package": not load_weights,
|
| 123 |
}
|
|
|
|
| 124 |
if (Path(kwargs["folder"]) / "vampnet").exists():
|
| 125 |
model, v_extra = VampNet.load_from_folder(**kwargs)
|
| 126 |
else:
|
|
@@ -147,89 +527,57 @@ def load(
|
|
| 147 |
scheduler = NoamScheduler(optimizer, d_model=accel.unwrap(model).embedding_dim)
|
| 148 |
scheduler.step()
|
| 149 |
|
| 150 |
-
trainer_state = {"state_dict": None, "start_idx": 0}
|
| 151 |
-
|
| 152 |
if "optimizer.pth" in v_extra:
|
| 153 |
optimizer.load_state_dict(v_extra["optimizer.pth"])
|
| 154 |
-
if "scheduler.pth" in v_extra:
|
| 155 |
scheduler.load_state_dict(v_extra["scheduler.pth"])
|
| 156 |
-
if "
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
"model": model,
|
| 161 |
-
"codec": codec,
|
| 162 |
-
"optimizer": optimizer,
|
| 163 |
-
"scheduler": scheduler,
|
| 164 |
-
"trainer_state": trainer_state,
|
| 165 |
-
}
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
def num_params_hook(o, p):
|
| 170 |
-
return o + f" {p/1e6:<.3f}M params."
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
def add_num_params_repr_hook(model):
|
| 174 |
-
import numpy as np
|
| 175 |
-
from functools import partial
|
| 176 |
-
|
| 177 |
-
for n, m in model.named_modules():
|
| 178 |
-
o = m.extra_repr()
|
| 179 |
-
p = sum([np.prod(p.size()) for p in m.parameters()])
|
| 180 |
-
|
| 181 |
-
setattr(m, "extra_repr", partial(num_params_hook, o=o, p=p))
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
def accuracy(
|
| 185 |
-
preds: torch.Tensor,
|
| 186 |
-
target: torch.Tensor,
|
| 187 |
-
top_k: int = 1,
|
| 188 |
-
ignore_index: Optional[int] = None,
|
| 189 |
-
) -> torch.Tensor:
|
| 190 |
-
# Flatten the predictions and targets to be of shape (batch_size * sequence_length, n_class)
|
| 191 |
-
preds = rearrange(preds, "b p s -> (b s) p")
|
| 192 |
-
target = rearrange(target, "b s -> (b s)")
|
| 193 |
-
|
| 194 |
-
# return torchmetrics.functional.accuracy(preds, target, task='multiclass', top_k=topk, num_classes=preds.shape[-1], ignore_index=ignore_index)
|
| 195 |
-
if ignore_index is not None:
|
| 196 |
-
# Create a mask for the ignored index
|
| 197 |
-
mask = target != ignore_index
|
| 198 |
-
# Apply the mask to the target and predictions
|
| 199 |
-
preds = preds[mask]
|
| 200 |
-
target = target[mask]
|
| 201 |
|
| 202 |
-
|
| 203 |
-
_, pred_indices = torch.topk(preds, k=top_k, dim=-1)
|
| 204 |
|
| 205 |
-
#
|
| 206 |
-
|
| 207 |
|
| 208 |
-
#
|
| 209 |
-
|
|
|
|
|
|
|
|
|
|
| 210 |
|
| 211 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 212 |
|
| 213 |
|
| 214 |
@argbind.bind(without_prefix=True)
|
| 215 |
def train(
|
| 216 |
args,
|
| 217 |
accel: at.ml.Accelerator,
|
| 218 |
-
codec_ckpt: str = None,
|
| 219 |
seed: int = 0,
|
|
|
|
| 220 |
save_path: str = "ckpt",
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
batch_size: int =
|
| 226 |
-
grad_acc_steps: int = 1,
|
| 227 |
val_idx: list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
|
| 228 |
num_workers: int = 10,
|
| 229 |
-
detect_anomaly: bool = False,
|
| 230 |
-
grad_clip_val: float = 5.0,
|
| 231 |
fine_tune: bool = False,
|
| 232 |
-
quiet: bool = False,
|
| 233 |
):
|
| 234 |
assert codec_ckpt is not None, "codec_ckpt is required"
|
| 235 |
|
|
@@ -241,376 +589,76 @@ def train(
|
|
| 241 |
writer = SummaryWriter(log_dir=f"{save_path}/logs/")
|
| 242 |
argbind.dump_args(args, f"{save_path}/args.yml")
|
| 243 |
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
codec = loaded["codec"]
|
| 248 |
-
optimizer = loaded["optimizer"]
|
| 249 |
-
scheduler = loaded["scheduler"]
|
| 250 |
-
trainer_state = loaded["trainer_state"]
|
| 251 |
-
|
| 252 |
-
sample_rate = codec.sample_rate
|
| 253 |
|
| 254 |
-
#
|
| 255 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 256 |
|
| 257 |
-
# log a model summary w/ num params
|
| 258 |
-
if accel.local_rank == 0:
|
| 259 |
-
add_num_params_repr_hook(accel.unwrap(model))
|
| 260 |
-
with open(f"{save_path}/model.txt", "w") as f:
|
| 261 |
-
f.write(repr(accel.unwrap(model)))
|
| 262 |
|
| 263 |
-
# load the datasets
|
| 264 |
-
train_data, val_data, _ = build_datasets(args, sample_rate)
|
| 265 |
train_dataloader = accel.prepare_dataloader(
|
| 266 |
-
train_data,
|
| 267 |
-
start_idx=
|
| 268 |
num_workers=num_workers,
|
| 269 |
batch_size=batch_size,
|
| 270 |
-
collate_fn=train_data.collate,
|
| 271 |
)
|
| 272 |
val_dataloader = accel.prepare_dataloader(
|
| 273 |
-
val_data,
|
| 274 |
start_idx=0,
|
| 275 |
num_workers=num_workers,
|
| 276 |
batch_size=batch_size,
|
| 277 |
-
collate_fn=val_data.collate,
|
|
|
|
| 278 |
)
|
| 279 |
|
| 280 |
-
|
| 281 |
|
| 282 |
if fine_tune:
|
| 283 |
-
|
| 284 |
-
lora.mark_only_lora_as_trainable(model)
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
class Trainer(at.ml.BaseTrainer):
|
| 288 |
-
_last_grad_norm = 0.0
|
| 289 |
-
|
| 290 |
-
def _metrics(self, vn, z_hat, r, target, flat_mask, output):
|
| 291 |
-
for r_range in [(0, 0.5), (0.5, 1.0)]:
|
| 292 |
-
unmasked_target = target.masked_fill(flat_mask.bool(), IGNORE_INDEX)
|
| 293 |
-
masked_target = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)
|
| 294 |
-
|
| 295 |
-
assert target.shape[0] == r.shape[0]
|
| 296 |
-
# grab the indices of the r values that are in the range
|
| 297 |
-
r_idx = (r >= r_range[0]) & (r < r_range[1])
|
| 298 |
-
|
| 299 |
-
# grab the target and z_hat values that are in the range
|
| 300 |
-
r_unmasked_target = unmasked_target[r_idx]
|
| 301 |
-
r_masked_target = masked_target[r_idx]
|
| 302 |
-
r_z_hat = z_hat[r_idx]
|
| 303 |
-
|
| 304 |
-
for topk in (1, 25):
|
| 305 |
-
s, e = r_range
|
| 306 |
-
tag = f"accuracy-{s}-{e}/top{topk}"
|
| 307 |
-
|
| 308 |
-
output[f"{tag}/unmasked"] = accuracy(
|
| 309 |
-
preds=r_z_hat,
|
| 310 |
-
target=r_unmasked_target,
|
| 311 |
-
ignore_index=IGNORE_INDEX,
|
| 312 |
-
top_k=topk,
|
| 313 |
-
)
|
| 314 |
-
output[f"{tag}/masked"] = accuracy(
|
| 315 |
-
preds=r_z_hat,
|
| 316 |
-
target=r_masked_target,
|
| 317 |
-
ignore_index=IGNORE_INDEX,
|
| 318 |
-
top_k=topk,
|
| 319 |
-
)
|
| 320 |
-
|
| 321 |
-
def train_loop(self, engine, batch):
|
| 322 |
-
model.train()
|
| 323 |
-
batch = at.util.prepare_batch(batch, accel.device)
|
| 324 |
-
signal = apply_transform(train_data.transform, batch)
|
| 325 |
-
|
| 326 |
-
output = {}
|
| 327 |
-
vn = accel.unwrap(model)
|
| 328 |
-
with accel.autocast():
|
| 329 |
-
with torch.inference_mode():
|
| 330 |
-
codec.to(accel.device)
|
| 331 |
-
z = codec.encode(signal.samples, signal.sample_rate)["codes"]
|
| 332 |
-
z = z[:, : vn.n_codebooks, :]
|
| 333 |
-
|
| 334 |
-
n_batch = z.shape[0]
|
| 335 |
-
r = rng.draw(n_batch)[:, 0].to(accel.device)
|
| 336 |
-
|
| 337 |
-
mask = pmask.random(z, r)
|
| 338 |
-
mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
|
| 339 |
-
z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
|
| 340 |
-
|
| 341 |
-
z_mask_latent = vn.embedding.from_codes(z_mask, codec)
|
| 342 |
-
|
| 343 |
-
dtype = torch.bfloat16 if accel.amp else None
|
| 344 |
-
with accel.autocast(dtype=dtype):
|
| 345 |
-
z_hat = model(z_mask_latent, r)
|
| 346 |
-
|
| 347 |
-
target = codebook_flatten(
|
| 348 |
-
z[:, vn.n_conditioning_codebooks :, :],
|
| 349 |
-
)
|
| 350 |
-
|
| 351 |
-
flat_mask = codebook_flatten(
|
| 352 |
-
mask[:, vn.n_conditioning_codebooks :, :],
|
| 353 |
-
)
|
| 354 |
-
|
| 355 |
-
# replace target with ignore index for masked tokens
|
| 356 |
-
t_masked = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)
|
| 357 |
-
output["loss"] = criterion(z_hat, t_masked)
|
| 358 |
-
|
| 359 |
-
self._metrics(
|
| 360 |
-
vn=vn,
|
| 361 |
-
r=r,
|
| 362 |
-
z_hat=z_hat,
|
| 363 |
-
target=target,
|
| 364 |
-
flat_mask=flat_mask,
|
| 365 |
-
output=output,
|
| 366 |
-
)
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
accel.backward(output["loss"] / grad_acc_steps)
|
| 370 |
-
|
| 371 |
-
output["other/learning_rate"] = optimizer.param_groups[0]["lr"]
|
| 372 |
-
output["other/batch_size"] = z.shape[0]
|
| 373 |
-
|
| 374 |
-
if (
|
| 375 |
-
(engine.state.iteration % grad_acc_steps == 0)
|
| 376 |
-
or (engine.state.iteration % epoch_length == 0)
|
| 377 |
-
or (engine.state.iteration % epoch_length == 1)
|
| 378 |
-
): # (or we reached the end of the epoch)
|
| 379 |
-
accel.scaler.unscale_(optimizer)
|
| 380 |
-
output["other/grad_norm"] = torch.nn.utils.clip_grad_norm_(
|
| 381 |
-
model.parameters(), grad_clip_val
|
| 382 |
-
)
|
| 383 |
-
self._last_grad_norm = output["other/grad_norm"]
|
| 384 |
-
|
| 385 |
-
accel.step(optimizer)
|
| 386 |
-
optimizer.zero_grad()
|
| 387 |
-
|
| 388 |
-
scheduler.step()
|
| 389 |
-
accel.update()
|
| 390 |
-
else:
|
| 391 |
-
output["other/grad_norm"] = self._last_grad_norm
|
| 392 |
-
|
| 393 |
-
return {k: v for k, v in sorted(output.items())}
|
| 394 |
-
|
| 395 |
-
@torch.no_grad()
|
| 396 |
-
def val_loop(self, engine, batch):
|
| 397 |
-
model.eval()
|
| 398 |
-
codec.eval()
|
| 399 |
-
batch = at.util.prepare_batch(batch, accel.device)
|
| 400 |
-
signal = apply_transform(val_data.transform, batch)
|
| 401 |
-
|
| 402 |
-
vn = accel.unwrap(model)
|
| 403 |
-
z = codec.encode(signal.samples, signal.sample_rate)["codes"]
|
| 404 |
-
z = z[:, : vn.n_codebooks, :]
|
| 405 |
-
|
| 406 |
-
n_batch = z.shape[0]
|
| 407 |
-
r = rng.draw(n_batch)[:, 0].to(accel.device)
|
| 408 |
-
|
| 409 |
-
mask = pmask.random(z, r)
|
| 410 |
-
mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
|
| 411 |
-
z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
|
| 412 |
|
| 413 |
-
|
|
|
|
|
|
|
| 414 |
|
| 415 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 416 |
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
)
|
| 420 |
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
)
|
| 424 |
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
t_masked = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)
|
| 428 |
-
output["loss"] = criterion(z_hat, t_masked)
|
| 429 |
-
|
| 430 |
-
self._metrics(
|
| 431 |
-
vn=vn,
|
| 432 |
-
r=r,
|
| 433 |
-
z_hat=z_hat,
|
| 434 |
-
target=target,
|
| 435 |
-
flat_mask=flat_mask,
|
| 436 |
-
output=output,
|
| 437 |
)
|
| 438 |
|
| 439 |
-
|
|
|
|
| 440 |
|
| 441 |
-
|
| 442 |
-
|
| 443 |
-
|
| 444 |
-
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
if self.state.epoch % save_audio_epochs == 0:
|
| 449 |
-
self.save_samples()
|
| 450 |
-
|
| 451 |
-
tags = ["latest"]
|
| 452 |
-
loss_key = "loss/val" if "loss/val" in metadata["logs"] else "loss/train"
|
| 453 |
-
self.print(f"Saving to {str(Path('.').absolute())}")
|
| 454 |
-
|
| 455 |
-
if self.state.epoch in save_epochs:
|
| 456 |
-
tags.append(f"epoch={self.state.epoch}")
|
| 457 |
-
|
| 458 |
-
if self.is_best(engine, loss_key):
|
| 459 |
-
self.print(f"Best model so far")
|
| 460 |
-
tags.append("best")
|
| 461 |
-
|
| 462 |
-
if fine_tune:
|
| 463 |
-
for tag in tags:
|
| 464 |
-
# save the lora model
|
| 465 |
-
(Path(save_path) / tag).mkdir(parents=True, exist_ok=True)
|
| 466 |
-
torch.save(
|
| 467 |
-
lora.lora_state_dict(accel.unwrap(model)),
|
| 468 |
-
f"{save_path}/{tag}/lora.pth"
|
| 469 |
-
)
|
| 470 |
-
|
| 471 |
-
for tag in tags:
|
| 472 |
-
model_extra = {
|
| 473 |
-
"optimizer.pth": optimizer.state_dict(),
|
| 474 |
-
"scheduler.pth": scheduler.state_dict(),
|
| 475 |
-
"trainer.pth": {
|
| 476 |
-
"start_idx": self.state.iteration * batch_size,
|
| 477 |
-
"state_dict": self.state_dict(),
|
| 478 |
-
},
|
| 479 |
-
"metadata.pth": metadata,
|
| 480 |
-
}
|
| 481 |
-
|
| 482 |
-
accel.unwrap(model).metadata = metadata
|
| 483 |
-
accel.unwrap(model).save_to_folder(
|
| 484 |
-
f"{save_path}/{tag}", model_extra,
|
| 485 |
-
)
|
| 486 |
-
|
| 487 |
-
def save_sampled(self, z):
|
| 488 |
-
num_samples = z.shape[0]
|
| 489 |
-
|
| 490 |
-
for i in range(num_samples):
|
| 491 |
-
sampled = accel.unwrap(model).generate(
|
| 492 |
-
codec=codec,
|
| 493 |
-
time_steps=z.shape[-1],
|
| 494 |
-
start_tokens=z[i : i + 1],
|
| 495 |
-
)
|
| 496 |
-
sampled.cpu().write_audio_to_tb(
|
| 497 |
-
f"sampled/{i}",
|
| 498 |
-
self.writer,
|
| 499 |
-
step=self.state.epoch,
|
| 500 |
-
plot_fn=None,
|
| 501 |
-
)
|
| 502 |
-
|
| 503 |
-
|
| 504 |
-
def save_imputation(self, z: torch.Tensor):
|
| 505 |
-
n_prefix = int(z.shape[-1] * 0.25)
|
| 506 |
-
n_suffix = int(z.shape[-1] * 0.25)
|
| 507 |
-
|
| 508 |
-
vn = accel.unwrap(model)
|
| 509 |
-
|
| 510 |
-
mask = pmask.inpaint(z, n_prefix, n_suffix)
|
| 511 |
-
mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
|
| 512 |
-
z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
|
| 513 |
-
|
| 514 |
-
imputed_noisy = vn.to_signal(z_mask, codec)
|
| 515 |
-
imputed_true = vn.to_signal(z, codec)
|
| 516 |
-
|
| 517 |
-
imputed = []
|
| 518 |
-
for i in range(len(z)):
|
| 519 |
-
imputed.append(
|
| 520 |
-
vn.generate(
|
| 521 |
-
codec=codec,
|
| 522 |
-
time_steps=z.shape[-1],
|
| 523 |
-
start_tokens=z[i][None, ...],
|
| 524 |
-
mask=mask[i][None, ...],
|
| 525 |
-
)
|
| 526 |
-
)
|
| 527 |
-
imputed = AudioSignal.batch(imputed)
|
| 528 |
-
|
| 529 |
-
for i in range(len(val_idx)):
|
| 530 |
-
imputed_noisy[i].cpu().write_audio_to_tb(
|
| 531 |
-
f"imputed_noisy/{i}",
|
| 532 |
-
self.writer,
|
| 533 |
-
step=self.state.epoch,
|
| 534 |
-
plot_fn=None,
|
| 535 |
-
)
|
| 536 |
-
imputed[i].cpu().write_audio_to_tb(
|
| 537 |
-
f"imputed/{i}",
|
| 538 |
-
self.writer,
|
| 539 |
-
step=self.state.epoch,
|
| 540 |
-
plot_fn=None,
|
| 541 |
-
)
|
| 542 |
-
imputed_true[i].cpu().write_audio_to_tb(
|
| 543 |
-
f"imputed_true/{i}",
|
| 544 |
-
self.writer,
|
| 545 |
-
step=self.state.epoch,
|
| 546 |
-
plot_fn=None,
|
| 547 |
-
)
|
| 548 |
-
|
| 549 |
-
@torch.no_grad()
|
| 550 |
-
def save_samples(self):
|
| 551 |
-
model.eval()
|
| 552 |
-
codec.eval()
|
| 553 |
-
vn = accel.unwrap(model)
|
| 554 |
-
|
| 555 |
-
batch = [val_data[i] for i in val_idx]
|
| 556 |
-
batch = at.util.prepare_batch(val_data.collate(batch), accel.device)
|
| 557 |
-
|
| 558 |
-
signal = apply_transform(val_data.transform, batch)
|
| 559 |
-
|
| 560 |
-
z = codec.encode(signal.samples, signal.sample_rate)["codes"]
|
| 561 |
-
z = z[:, : vn.n_codebooks, :]
|
| 562 |
-
|
| 563 |
-
r = torch.linspace(0.1, 0.95, len(val_idx)).to(accel.device)
|
| 564 |
-
|
| 565 |
-
|
| 566 |
-
mask = pmask.random(z, r)
|
| 567 |
-
mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
|
| 568 |
-
z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
|
| 569 |
-
|
| 570 |
-
z_mask_latent = vn.embedding.from_codes(z_mask, codec)
|
| 571 |
-
|
| 572 |
-
z_hat = model(z_mask_latent, r)
|
| 573 |
|
| 574 |
-
|
| 575 |
-
|
| 576 |
-
z_pred = torch.cat([z[:, : vn.n_conditioning_codebooks, :], z_pred], dim=1)
|
| 577 |
|
| 578 |
-
|
| 579 |
-
|
| 580 |
-
masked = vn.to_signal(z_mask.squeeze(1), codec)
|
| 581 |
-
|
| 582 |
-
for i in range(generated.batch_size):
|
| 583 |
-
audio_dict = {
|
| 584 |
-
"original": signal[i],
|
| 585 |
-
"masked": masked[i],
|
| 586 |
-
"generated": generated[i],
|
| 587 |
-
"reconstructed": reconstructed[i],
|
| 588 |
-
}
|
| 589 |
-
for k, v in audio_dict.items():
|
| 590 |
-
v.cpu().write_audio_to_tb(
|
| 591 |
-
f"samples/_{i}.r={r[i]:0.2f}/{k}",
|
| 592 |
-
self.writer,
|
| 593 |
-
step=self.state.epoch,
|
| 594 |
-
plot_fn=None,
|
| 595 |
-
)
|
| 596 |
-
|
| 597 |
-
self.save_sampled(z)
|
| 598 |
-
self.save_imputation(z)
|
| 599 |
-
|
| 600 |
-
trainer = Trainer(writer=writer, quiet=quiet)
|
| 601 |
-
|
| 602 |
-
if trainer_state["state_dict"] is not None:
|
| 603 |
-
trainer.load_state_dict(trainer_state["state_dict"])
|
| 604 |
-
if hasattr(train_dataloader.sampler, "set_epoch"):
|
| 605 |
-
train_dataloader.sampler.set_epoch(trainer.trainer.state.epoch)
|
| 606 |
-
|
| 607 |
-
trainer.run(
|
| 608 |
-
train_dataloader,
|
| 609 |
-
val_dataloader,
|
| 610 |
-
num_epochs=max_epochs,
|
| 611 |
-
epoch_length=epoch_length,
|
| 612 |
-
detect_anomaly=detect_anomaly,
|
| 613 |
-
)
|
| 614 |
|
| 615 |
|
| 616 |
if __name__ == "__main__":
|
|
@@ -618,4 +666,6 @@ if __name__ == "__main__":
|
|
| 618 |
args["args.debug"] = int(os.getenv("LOCAL_RANK", 0)) == 0
|
| 619 |
with argbind.scope(args):
|
| 620 |
with Accelerator() as accel:
|
|
|
|
|
|
|
| 621 |
train(args, accel)
|
|
|
|
| 1 |
import os
|
| 2 |
+
import sys
|
|
|
|
| 3 |
import warnings
|
| 4 |
from pathlib import Path
|
| 5 |
from typing import Optional
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
|
| 8 |
import argbind
|
| 9 |
import audiotools as at
|
|
|
|
| 23 |
# from dac.model.dac import DAC
|
| 24 |
from lac.model.lac import LAC as DAC
|
| 25 |
|
| 26 |
+
from audiotools.ml.decorators import (
|
| 27 |
+
timer, Tracker, when
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
import loralib as lora
|
| 31 |
+
|
| 32 |
|
| 33 |
# Enable cudnn autotuner to speed up training
|
| 34 |
# (can be altered by the funcs.seed function)
|
|
|
|
| 91 |
)
|
| 92 |
with argbind.scope(args, "val"):
|
| 93 |
val_data = AudioDataset(AudioLoader(), sample_rate, transform=build_transform())
|
| 94 |
+
return train_data, val_data
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
|
| 96 |
|
| 97 |
def rand_float(shape, low, high, rng):
|
|
|
|
| 102 |
return rng.draw(shape)[:, 0] < p
|
| 103 |
|
| 104 |
|
| 105 |
+
def num_params_hook(o, p):
|
| 106 |
+
return o + f" {p/1e6:<.3f}M params."
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def add_num_params_repr_hook(model):
|
| 110 |
+
import numpy as np
|
| 111 |
+
from functools import partial
|
| 112 |
+
|
| 113 |
+
for n, m in model.named_modules():
|
| 114 |
+
o = m.extra_repr()
|
| 115 |
+
p = sum([np.prod(p.size()) for p in m.parameters()])
|
| 116 |
+
|
| 117 |
+
setattr(m, "extra_repr", partial(num_params_hook, o=o, p=p))
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def accuracy(
|
| 121 |
+
preds: torch.Tensor,
|
| 122 |
+
target: torch.Tensor,
|
| 123 |
+
top_k: int = 1,
|
| 124 |
+
ignore_index: Optional[int] = None,
|
| 125 |
+
) -> torch.Tensor:
|
| 126 |
+
# Flatten the predictions and targets to be of shape (batch_size * sequence_length, n_class)
|
| 127 |
+
preds = rearrange(preds, "b p s -> (b s) p")
|
| 128 |
+
target = rearrange(target, "b s -> (b s)")
|
| 129 |
+
|
| 130 |
+
# return torchmetrics.functional.accuracy(preds, target, task='multiclass', top_k=topk, num_classes=preds.shape[-1], ignore_index=ignore_index)
|
| 131 |
+
if ignore_index is not None:
|
| 132 |
+
# Create a mask for the ignored index
|
| 133 |
+
mask = target != ignore_index
|
| 134 |
+
# Apply the mask to the target and predictions
|
| 135 |
+
preds = preds[mask]
|
| 136 |
+
target = target[mask]
|
| 137 |
+
|
| 138 |
+
# Get the top-k predicted classes and their indices
|
| 139 |
+
_, pred_indices = torch.topk(preds, k=top_k, dim=-1)
|
| 140 |
+
|
| 141 |
+
# Determine if the true target is in the top-k predicted classes
|
| 142 |
+
correct = torch.sum(torch.eq(pred_indices, target.unsqueeze(1)), dim=1)
|
| 143 |
+
|
| 144 |
+
# Calculate the accuracy
|
| 145 |
+
accuracy = torch.mean(correct.float())
|
| 146 |
+
|
| 147 |
+
return accuracy
|
| 148 |
+
|
| 149 |
+
def _metrics(z_hat, r, target, flat_mask, output):
|
| 150 |
+
for r_range in [(0, 0.5), (0.5, 1.0)]:
|
| 151 |
+
unmasked_target = target.masked_fill(flat_mask.bool(), IGNORE_INDEX)
|
| 152 |
+
masked_target = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)
|
| 153 |
+
|
| 154 |
+
assert target.shape[0] == r.shape[0]
|
| 155 |
+
# grab the indices of the r values that are in the range
|
| 156 |
+
r_idx = (r >= r_range[0]) & (r < r_range[1])
|
| 157 |
+
|
| 158 |
+
# grab the target and z_hat values that are in the range
|
| 159 |
+
r_unmasked_target = unmasked_target[r_idx]
|
| 160 |
+
r_masked_target = masked_target[r_idx]
|
| 161 |
+
r_z_hat = z_hat[r_idx]
|
| 162 |
+
|
| 163 |
+
for topk in (1, 25):
|
| 164 |
+
s, e = r_range
|
| 165 |
+
tag = f"accuracy-{s}-{e}/top{topk}"
|
| 166 |
+
|
| 167 |
+
output[f"{tag}/unmasked"] = accuracy(
|
| 168 |
+
preds=r_z_hat,
|
| 169 |
+
target=r_unmasked_target,
|
| 170 |
+
ignore_index=IGNORE_INDEX,
|
| 171 |
+
top_k=topk,
|
| 172 |
+
)
|
| 173 |
+
output[f"{tag}/masked"] = accuracy(
|
| 174 |
+
preds=r_z_hat,
|
| 175 |
+
target=r_masked_target,
|
| 176 |
+
ignore_index=IGNORE_INDEX,
|
| 177 |
+
top_k=topk,
|
| 178 |
+
)
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
@dataclass
|
| 182 |
+
class State:
|
| 183 |
+
model: VampNet
|
| 184 |
+
codec: DAC
|
| 185 |
+
|
| 186 |
+
optimizer: AdamW
|
| 187 |
+
scheduler: NoamScheduler
|
| 188 |
+
criterion: CrossEntropyLoss
|
| 189 |
+
grad_clip_val: float
|
| 190 |
+
|
| 191 |
+
rng: torch.quasirandom.SobolEngine
|
| 192 |
+
|
| 193 |
+
train_data: AudioDataset
|
| 194 |
+
val_data: AudioDataset
|
| 195 |
+
|
| 196 |
+
tracker: Tracker
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
@timer()
|
| 200 |
+
def train_loop(state: State, batch: dict, accel: Accelerator):
|
| 201 |
+
state.model.train()
|
| 202 |
+
batch = at.util.prepare_batch(batch, accel.device)
|
| 203 |
+
signal = apply_transform(state.train_data.transform, batch)
|
| 204 |
+
|
| 205 |
+
output = {}
|
| 206 |
+
vn = accel.unwrap(state.model)
|
| 207 |
+
with accel.autocast():
|
| 208 |
+
with torch.inference_mode():
|
| 209 |
+
state.codec.to(accel.device)
|
| 210 |
+
z = state.codec.encode(signal.samples, signal.sample_rate)["codes"]
|
| 211 |
+
z = z[:, : vn.n_codebooks, :]
|
| 212 |
+
|
| 213 |
+
n_batch = z.shape[0]
|
| 214 |
+
r = state.rng.draw(n_batch)[:, 0].to(accel.device)
|
| 215 |
+
|
| 216 |
+
mask = pmask.random(z, r)
|
| 217 |
+
mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
|
| 218 |
+
z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
|
| 219 |
+
|
| 220 |
+
z_mask_latent = vn.embedding.from_codes(z_mask, state.codec)
|
| 221 |
+
|
| 222 |
+
dtype = torch.bfloat16 if accel.amp else None
|
| 223 |
+
with accel.autocast(dtype=dtype):
|
| 224 |
+
z_hat = state.model(z_mask_latent, r)
|
| 225 |
+
|
| 226 |
+
target = codebook_flatten(
|
| 227 |
+
z[:, vn.n_conditioning_codebooks :, :],
|
| 228 |
+
)
|
| 229 |
+
|
| 230 |
+
flat_mask = codebook_flatten(
|
| 231 |
+
mask[:, vn.n_conditioning_codebooks :, :],
|
| 232 |
+
)
|
| 233 |
+
|
| 234 |
+
# replace target with ignore index for masked tokens
|
| 235 |
+
t_masked = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)
|
| 236 |
+
output["loss"] = state.criterion(z_hat, t_masked)
|
| 237 |
+
|
| 238 |
+
_metrics(
|
| 239 |
+
r=r,
|
| 240 |
+
z_hat=z_hat,
|
| 241 |
+
target=target,
|
| 242 |
+
flat_mask=flat_mask,
|
| 243 |
+
output=output,
|
| 244 |
+
)
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
accel.backward(output["loss"])
|
| 248 |
+
|
| 249 |
+
output["other/learning_rate"] = state.optimizer.param_groups[0]["lr"]
|
| 250 |
+
output["other/batch_size"] = z.shape[0]
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
accel.scaler.unscale_(state.optimizer)
|
| 254 |
+
output["other/grad_norm"] = torch.nn.utils.clip_grad_norm_(
|
| 255 |
+
state.model.parameters(), state.grad_clip_val
|
| 256 |
+
)
|
| 257 |
+
|
| 258 |
+
accel.step(state.optimizer)
|
| 259 |
+
state.optimizer.zero_grad()
|
| 260 |
+
|
| 261 |
+
state.scheduler.step()
|
| 262 |
+
accel.update()
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
return {k: v for k, v in sorted(output.items())}
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
@timer()
|
| 269 |
+
@torch.no_grad()
|
| 270 |
+
def val_loop(state: State, batch: dict, accel: Accelerator):
|
| 271 |
+
state.model.eval()
|
| 272 |
+
state.codec.eval()
|
| 273 |
+
batch = at.util.prepare_batch(batch, accel.device)
|
| 274 |
+
signal = apply_transform(state.val_data.transform, batch)
|
| 275 |
+
|
| 276 |
+
vn = accel.unwrap(state.model)
|
| 277 |
+
z = state.codec.encode(signal.samples, signal.sample_rate)["codes"]
|
| 278 |
+
z = z[:, : vn.n_codebooks, :]
|
| 279 |
+
|
| 280 |
+
n_batch = z.shape[0]
|
| 281 |
+
r = state.rng.draw(n_batch)[:, 0].to(accel.device)
|
| 282 |
+
|
| 283 |
+
mask = pmask.random(z, r)
|
| 284 |
+
mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
|
| 285 |
+
z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
|
| 286 |
+
|
| 287 |
+
z_mask_latent = vn.embedding.from_codes(z_mask, state.codec)
|
| 288 |
+
|
| 289 |
+
z_hat = state.model(z_mask_latent, r)
|
| 290 |
+
|
| 291 |
+
target = codebook_flatten(
|
| 292 |
+
z[:, vn.n_conditioning_codebooks :, :],
|
| 293 |
+
)
|
| 294 |
+
|
| 295 |
+
flat_mask = codebook_flatten(
|
| 296 |
+
mask[:, vn.n_conditioning_codebooks :, :]
|
| 297 |
+
)
|
| 298 |
+
|
| 299 |
+
output = {}
|
| 300 |
+
# replace target with ignore index for masked tokens
|
| 301 |
+
t_masked = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)
|
| 302 |
+
output["loss"] = state.criterion(z_hat, t_masked)
|
| 303 |
+
|
| 304 |
+
_metrics(
|
| 305 |
+
r=r,
|
| 306 |
+
z_hat=z_hat,
|
| 307 |
+
target=target,
|
| 308 |
+
flat_mask=flat_mask,
|
| 309 |
+
output=output,
|
| 310 |
+
)
|
| 311 |
+
|
| 312 |
+
return output
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
def validate(state, val_dataloader, accel):
|
| 316 |
+
for batch in val_dataloader:
|
| 317 |
+
output = val_loop(state, batch, accel)
|
| 318 |
+
# Consolidate state dicts if using ZeroRedundancyOptimizer
|
| 319 |
+
if hasattr(state.optimizer, "consolidate_state_dict"):
|
| 320 |
+
state.optimizer.consolidate_state_dict()
|
| 321 |
+
return output
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
def checkpoint(state, save_iters, save_path, fine_tune):
|
| 325 |
+
if accel.local_rank != 0:
|
| 326 |
+
state.tracker.print(f"ERROR:Skipping checkpoint on rank {accel.local_rank}")
|
| 327 |
+
return
|
| 328 |
+
|
| 329 |
+
metadata = {"logs": dict(state.tracker.history)}
|
| 330 |
+
|
| 331 |
+
tags = ["latest"]
|
| 332 |
+
state.tracker.print(f"Saving to {str(Path('.').absolute())}")
|
| 333 |
+
|
| 334 |
+
if state.tracker.step in save_iters:
|
| 335 |
+
tags.append(f"{state.tracker.step // 1000}k")
|
| 336 |
+
|
| 337 |
+
if state.tracker.is_best("val", "loss"):
|
| 338 |
+
state.tracker.print(f"Best model so far")
|
| 339 |
+
tags.append("best")
|
| 340 |
+
|
| 341 |
+
if fine_tune:
|
| 342 |
+
for tag in tags:
|
| 343 |
+
# save the lora model
|
| 344 |
+
(Path(save_path) / tag).mkdir(parents=True, exist_ok=True)
|
| 345 |
+
torch.save(
|
| 346 |
+
lora.lora_state_dict(accel.unwrap(state.model)),
|
| 347 |
+
f"{save_path}/{tag}/lora.pth"
|
| 348 |
+
)
|
| 349 |
+
|
| 350 |
+
for tag in tags:
|
| 351 |
+
model_extra = {
|
| 352 |
+
"optimizer.pth": state.optimizer.state_dict(),
|
| 353 |
+
"scheduler.pth": state.scheduler.state_dict(),
|
| 354 |
+
"tracker.pth": state.tracker.state_dict(),
|
| 355 |
+
"metadata.pth": metadata,
|
| 356 |
+
}
|
| 357 |
+
|
| 358 |
+
accel.unwrap(state.model).metadata = metadata
|
| 359 |
+
accel.unwrap(state.model).save_to_folder(
|
| 360 |
+
f"{save_path}/{tag}", model_extra, package=False
|
| 361 |
+
)
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
def save_sampled(state, z, writer):
    """Generate one signal per batch item of ``z`` and log it as ``sampled/<i>``.

    Args:
        state: training state; ``model``, ``codec`` and ``tracker.step`` are read.
        z: indexable with ``.shape``; the first axis is iterated item by item
            and the last axis size is passed as ``time_steps``.
        writer: forwarded to ``write_audio_to_tb``.

    Note:
        ``accel`` is not a parameter; it is looked up as a module-level name.
    """
    for idx in range(z.shape[0]):
        # Slice rather than index so the leading axis is kept.
        start = z[idx : idx + 1]
        model = accel.unwrap(state.model)
        audio = model.generate(
            codec=state.codec,
            time_steps=z.shape[-1],
            start_tokens=start,
        )
        audio.cpu().write_audio_to_tb(
            f"sampled/{idx}",
            writer,
            step=state.tracker.step,
            plot_fn=None,
        )
|
| 379 |
+
|
| 380 |
+
|
| 381 |
+
def save_imputation(state, z, val_idx, writer):
    """Log inpainting examples: masked input, model output and unmasked reference.

    An inpaint mask is built with a prefix and a suffix that are each a
    quarter of the last axis of ``z``. For every item the masked codes, the
    generated result and the untouched codes are converted to signals and
    written as ``imputed_noisy/<i>``, ``imputed/<i>`` and ``imputed_true/<i>``.

    Args:
        state: training state; ``model``, ``codec`` and ``tracker.step`` are read.
        z: codes to mask and regenerate; ``len(z)`` items are generated.
        val_idx: only its length is used, as the number of items to log.
        writer: forwarded to ``write_audio_to_tb``.

    Note:
        ``accel`` is not a parameter; it is looked up as a module-level name.
    """
    n_steps = z.shape[-1]
    n_prefix = int(n_steps * 0.25)
    n_suffix = int(n_steps * 0.25)

    vn = accel.unwrap(state.model)

    mask = pmask.inpaint(z, n_prefix, n_suffix)
    mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
    z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)

    imputed_noisy = vn.to_signal(z_mask, state.codec)
    imputed_true = vn.to_signal(z, state.codec)

    # Generate item by item, then merge the results back into one batch.
    imputed = AudioSignal.batch(
        [
            vn.generate(
                codec=state.codec,
                time_steps=z.shape[-1],
                start_tokens=z[i][None, ...],
                mask=mask[i][None, ...],
            )
            for i in range(len(z))
        ]
    )

    # Per item the write order is: noisy, imputed, true.
    labelled = (
        ("imputed_noisy", imputed_noisy),
        ("imputed", imputed),
        ("imputed_true", imputed_true),
    )
    for i in range(len(val_idx)):
        for label, signals in labelled:
            signals[i].cpu().write_audio_to_tb(
                f"{label}/{i}",
                writer,
                step=state.tracker.step,
                plot_fn=None,
            )
|
| 425 |
+
|
| 426 |
+
|
| 427 |
+
@torch.no_grad()
def save_samples(state: State, val_idx: list, writer: SummaryWriter):
    """Log original / masked / generated / reconstructed audio for fixed validation items.

    Args:
        state: training state; ``model``, ``codec``, ``val_data`` and
            ``tracker.step`` are read.
        val_idx: indices into ``state.val_data``. It is iterated and passed to
            ``len()``, so it must be a sequence of indices, not a single int.
        writer: forwarded to ``write_audio_to_tb`` and to the helpers below.

    Side effects:
        Switches ``state.model`` and ``state.codec`` to eval mode and does not
        switch them back. Also calls ``save_sampled`` and ``save_imputation``
        with the encoded codes. ``accel`` is looked up as a module-level name.
    """
    state.model.eval()
    state.codec.eval()
    vn = accel.unwrap(state.model)

    batch = [state.val_data[i] for i in val_idx]
    batch = at.util.prepare_batch(state.val_data.collate(batch), accel.device)

    signal = apply_transform(state.val_data.transform, batch)

    # Encode, then keep only the first vn.n_codebooks entries along dim 1.
    z = state.codec.encode(signal.samples, signal.sample_rate)["codes"]
    z = z[:, : vn.n_codebooks, :]

    # One value per item, evenly spaced from 0.1 to 0.95; it is used both to
    # draw the random mask and as the second argument to the model.
    r = torch.linspace(0.1, 0.95, len(val_idx)).to(accel.device)

    mask = pmask.random(z, r)
    mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
    z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)

    z_mask_latent = vn.embedding.from_codes(z_mask, state.codec)

    z_hat = state.model(z_mask_latent, r)

    # Take the highest-scoring entry along dim 1, unflatten it into
    # vn.n_predict_codebooks codebooks, and put the conditioning codebooks of
    # the input back in front.
    z_pred = torch.softmax(z_hat, dim=1).argmax(dim=1)
    z_pred = codebook_unflatten(z_pred, n_c=vn.n_predict_codebooks)
    z_pred = torch.cat([z[:, : vn.n_conditioning_codebooks, :], z_pred], dim=1)

    generated = vn.to_signal(z_pred, state.codec)
    reconstructed = vn.to_signal(z, state.codec)
    masked = vn.to_signal(z_mask.squeeze(1), state.codec)

    for i in range(generated.batch_size):
        audio_dict = {
            "original": signal[i],
            "masked": masked[i],
            "generated": generated[i],
            "reconstructed": reconstructed[i],
        }
        for k, v in audio_dict.items():
            v.cpu().write_audio_to_tb(
                f"samples/_{i}.r={r[i]:0.2f}/{k}",
                writer,
                step=state.tracker.step,
                plot_fn=None,
            )

    save_sampled(state=state, z=z, writer=writer)
    save_imputation(state=state, z=z, val_idx=val_idx, writer=writer)
|
| 477 |
+
|
| 478 |
+
|
| 479 |
+
|
| 480 |
@argbind.bind(without_prefix=True)
|
| 481 |
def load(
|
| 482 |
args,
|
| 483 |
accel: at.ml.Accelerator,
|
| 484 |
+
tracker: Tracker,
|
| 485 |
save_path: str,
|
| 486 |
resume: bool = False,
|
| 487 |
tag: str = "latest",
|
| 488 |
load_weights: bool = False,
|
| 489 |
fine_tune_checkpoint: Optional[str] = None,
|
| 490 |
+
grad_clip_val: float = 5.0,
|
| 491 |
+
) -> State:
|
| 492 |
codec = DAC.load(args["codec_ckpt"], map_location="cpu")
|
| 493 |
codec.eval()
|
| 494 |
|
|
|
|
| 500 |
"map_location": "cpu",
|
| 501 |
"package": not load_weights,
|
| 502 |
}
|
| 503 |
+
tracker.print(f"Loading checkpoint from {kwargs['folder']}")
|
| 504 |
if (Path(kwargs["folder"]) / "vampnet").exists():
|
| 505 |
model, v_extra = VampNet.load_from_folder(**kwargs)
|
| 506 |
else:
|
|
|
|
| 527 |
scheduler = NoamScheduler(optimizer, d_model=accel.unwrap(model).embedding_dim)
|
| 528 |
scheduler.step()
|
| 529 |
|
|
|
|
|
|
|
| 530 |
if "optimizer.pth" in v_extra:
|
| 531 |
optimizer.load_state_dict(v_extra["optimizer.pth"])
|
|
|
|
| 532 |
scheduler.load_state_dict(v_extra["scheduler.pth"])
|
| 533 |
+
if "tracker.pth" in v_extra:
|
| 534 |
+
tracker.load_state_dict(v_extra["tracker.pth"])
|
| 535 |
+
|
| 536 |
+
criterion = CrossEntropyLoss()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 537 |
|
| 538 |
+
sample_rate = codec.sample_rate
|
|
|
|
| 539 |
|
| 540 |
+
# a better rng for sampling from our schedule
|
| 541 |
+
rng = torch.quasirandom.SobolEngine(1, scramble=True, seed=args["seed"])
|
| 542 |
|
| 543 |
+
# log a model summary w/ num params
|
| 544 |
+
if accel.local_rank == 0:
|
| 545 |
+
add_num_params_repr_hook(accel.unwrap(model))
|
| 546 |
+
with open(f"{save_path}/model.txt", "w") as f:
|
| 547 |
+
f.write(repr(accel.unwrap(model)))
|
| 548 |
|
| 549 |
+
# load the datasets
|
| 550 |
+
train_data, val_data = build_datasets(args, sample_rate)
|
| 551 |
+
|
| 552 |
+
return State(
|
| 553 |
+
tracker=tracker,
|
| 554 |
+
model=model,
|
| 555 |
+
codec=codec,
|
| 556 |
+
optimizer=optimizer,
|
| 557 |
+
scheduler=scheduler,
|
| 558 |
+
criterion=criterion,
|
| 559 |
+
rng=rng,
|
| 560 |
+
train_data=train_data,
|
| 561 |
+
val_data=val_data,
|
| 562 |
+
grad_clip_val=grad_clip_val,
|
| 563 |
+
)
|
| 564 |
|
| 565 |
|
| 566 |
@argbind.bind(without_prefix=True)
|
| 567 |
def train(
|
| 568 |
args,
|
| 569 |
accel: at.ml.Accelerator,
|
|
|
|
| 570 |
seed: int = 0,
|
| 571 |
+
codec_ckpt: str = None,
|
| 572 |
save_path: str = "ckpt",
|
| 573 |
+
num_iters: int = int(1000e6),
|
| 574 |
+
save_iters: list = [10000, 50000, 100000, 300000, 500000,],
|
| 575 |
+
sample_freq: int = 10000,
|
| 576 |
+
val_freq: int = 1000,
|
| 577 |
+
batch_size: int = 12,
|
|
|
|
| 578 |
val_idx: list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
|
| 579 |
num_workers: int = 10,
|
|
|
|
|
|
|
| 580 |
fine_tune: bool = False,
|
|
|
|
| 581 |
):
|
| 582 |
assert codec_ckpt is not None, "codec_ckpt is required"
|
| 583 |
|
|
|
|
| 589 |
writer = SummaryWriter(log_dir=f"{save_path}/logs/")
|
| 590 |
argbind.dump_args(args, f"{save_path}/args.yml")
|
| 591 |
|
| 592 |
+
tracker = Tracker(
|
| 593 |
+
writer=writer, log_file=f"{save_path}/log.txt", rank=accel.local_rank
|
| 594 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 595 |
|
| 596 |
+
# load the codec model
|
| 597 |
+
state: State = load(
|
| 598 |
+
args=args,
|
| 599 |
+
accel=accel,
|
| 600 |
+
tracker=tracker,
|
| 601 |
+
save_path=save_path)
|
| 602 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 603 |
|
|
|
|
|
|
|
| 604 |
train_dataloader = accel.prepare_dataloader(
|
| 605 |
+
state.train_data,
|
| 606 |
+
start_idx=state.tracker.step * batch_size,
|
| 607 |
num_workers=num_workers,
|
| 608 |
batch_size=batch_size,
|
| 609 |
+
collate_fn=state.train_data.collate,
|
| 610 |
)
|
| 611 |
val_dataloader = accel.prepare_dataloader(
|
| 612 |
+
state.val_data,
|
| 613 |
start_idx=0,
|
| 614 |
num_workers=num_workers,
|
| 615 |
batch_size=batch_size,
|
| 616 |
+
collate_fn=state.val_data.collate,
|
| 617 |
+
persistent_workers=True,
|
| 618 |
)
|
| 619 |
|
| 620 |
+
|
| 621 |
|
| 622 |
if fine_tune:
|
| 623 |
+
lora.mark_only_lora_as_trainable(state.model)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 624 |
|
| 625 |
+
# Wrap the functions so that they neatly track in TensorBoard + progress bars
|
| 626 |
+
# and only run when specific conditions are met.
|
| 627 |
+
global train_loop, val_loop, validate, save_samples, checkpoint
|
| 628 |
|
| 629 |
+
train_loop = tracker.log("train", "value", history=False)(
|
| 630 |
+
tracker.track("train", num_iters, completed=state.tracker.step)(train_loop)
|
| 631 |
+
)
|
| 632 |
+
val_loop = tracker.track("val", len(val_dataloader))(val_loop)
|
| 633 |
+
validate = tracker.log("val", "mean")(validate)
|
| 634 |
|
| 635 |
+
save_samples = when(lambda: accel.local_rank == 0)(save_samples)
|
| 636 |
+
checkpoint = when(lambda: accel.local_rank == 0)(checkpoint)
|
|
|
|
| 637 |
|
| 638 |
+
with tracker.live:
|
| 639 |
+
for tracker.step, batch in enumerate(train_dataloader, start=tracker.step):
|
| 640 |
+
train_loop(state, batch, accel)
|
| 641 |
|
| 642 |
+
last_iter = (
|
| 643 |
+
tracker.step == num_iters - 1 if num_iters is not None else False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 644 |
)
|
| 645 |
|
| 646 |
+
if tracker.step % sample_freq == 0 or last_iter:
|
| 647 |
+
save_samples(state, val_idx, writer)
|
| 648 |
|
| 649 |
+
if tracker.step % val_freq == 0 or last_iter:
|
| 650 |
+
validate(state, val_dataloader, accel)
|
| 651 |
+
checkpoint(
|
| 652 |
+
state=state,
|
| 653 |
+
save_iters=save_iters,
|
| 654 |
+
save_path=save_path,
|
| 655 |
+
fine_tune=fine_tune)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 656 |
|
| 657 |
+
# Reset validation progress bar, print summary since last validation.
|
| 658 |
+
tracker.done("val", f"Iteration {tracker.step}")
|
|
|
|
| 659 |
|
| 660 |
+
if last_iter:
|
| 661 |
+
break
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 662 |
|
| 663 |
|
| 664 |
if __name__ == "__main__":
|
|
|
|
| 666 |
args["args.debug"] = int(os.getenv("LOCAL_RANK", 0)) == 0
|
| 667 |
with argbind.scope(args):
|
| 668 |
with Accelerator() as accel:
|
| 669 |
+
if accel.local_rank != 0:
|
| 670 |
+
sys.tracebacklimit = 0
|
| 671 |
train(args, accel)
|
setup.py
CHANGED
|
@@ -31,7 +31,7 @@ setup(
|
|
| 31 |
"numpy==1.22",
|
| 32 |
"wavebeat @ git+https://github.com/hugofloresgarcia/wavebeat",
|
| 33 |
"lac @ git+https://github.com/hugofloresgarcia/lac.git",
|
| 34 |
-
"audiotools @ git+https://github.com/
|
| 35 |
"gradio",
|
| 36 |
"tensorboardX",
|
| 37 |
"loralib",
|
|
|
|
| 31 |
"numpy==1.22",
|
| 32 |
"wavebeat @ git+https://github.com/hugofloresgarcia/wavebeat",
|
| 33 |
"lac @ git+https://github.com/hugofloresgarcia/lac.git",
|
| 34 |
+
"descript-audiotools @ git+https://github.com/descriptinc/audiotools.git@0.7.2",
|
| 35 |
"gradio",
|
| 36 |
"tensorboardX",
|
| 37 |
"loralib",
|
vampnet/beats.py
CHANGED
|
@@ -9,6 +9,7 @@ from typing import Tuple
|
|
| 9 |
from typing import Union
|
| 10 |
|
| 11 |
import librosa
|
|
|
|
| 12 |
import numpy as np
|
| 13 |
from audiotools import AudioSignal
|
| 14 |
|
|
@@ -203,7 +204,7 @@ class WaveBeat(BeatTracker):
|
|
| 203 |
def __init__(self, ckpt_path: str = "checkpoints/wavebeat", device: str = "cpu"):
|
| 204 |
from wavebeat.dstcn import dsTCNModel
|
| 205 |
|
| 206 |
-
model = dsTCNModel.load_from_checkpoint(ckpt_path)
|
| 207 |
model.eval()
|
| 208 |
|
| 209 |
self.device = device
|
|
|
|
| 9 |
from typing import Union
|
| 10 |
|
| 11 |
import librosa
|
| 12 |
+
import torch
|
| 13 |
import numpy as np
|
| 14 |
from audiotools import AudioSignal
|
| 15 |
|
|
|
|
| 204 |
def __init__(self, ckpt_path: str = "checkpoints/wavebeat", device: str = "cpu"):
|
| 205 |
from wavebeat.dstcn import dsTCNModel
|
| 206 |
|
| 207 |
+
model = dsTCNModel.load_from_checkpoint(ckpt_path, map_location=torch.device(device))
|
| 208 |
model.eval()
|
| 209 |
|
| 210 |
self.device = device
|
vampnet/interface.py
CHANGED
|
@@ -22,6 +22,7 @@ def signal_concat(
|
|
| 22 |
|
| 23 |
return AudioSignal(audio_data, sample_rate=audio_signals[0].sample_rate)
|
| 24 |
|
|
|
|
| 25 |
def _load_model(
|
| 26 |
ckpt: str,
|
| 27 |
lora_ckpt: str = None,
|
|
@@ -64,7 +65,7 @@ class Interface(torch.nn.Module):
|
|
| 64 |
):
|
| 65 |
super().__init__()
|
| 66 |
assert codec_ckpt is not None, "must provide a codec checkpoint"
|
| 67 |
-
self.codec = DAC.load(
|
| 68 |
self.codec.eval()
|
| 69 |
self.codec.to(device)
|
| 70 |
|
|
@@ -275,34 +276,44 @@ class Interface(torch.nn.Module):
|
|
| 275 |
|
| 276 |
def coarse_to_fine(
|
| 277 |
self,
|
| 278 |
-
|
|
|
|
| 279 |
**kwargs
|
| 280 |
):
|
| 281 |
assert self.c2f is not None, "No coarse2fine model loaded"
|
| 282 |
-
length =
|
| 283 |
chunk_len = self.s2t(self.c2f.chunk_size_s)
|
| 284 |
-
n_chunks = math.ceil(
|
| 285 |
|
| 286 |
# zero pad to chunk_len
|
| 287 |
if length % chunk_len != 0:
|
| 288 |
pad_len = chunk_len - (length % chunk_len)
|
| 289 |
-
|
|
|
|
| 290 |
|
| 291 |
-
n_codebooks_to_append = self.c2f.n_codebooks -
|
| 292 |
if n_codebooks_to_append > 0:
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
torch.zeros(
|
| 296 |
], dim=1)
|
| 297 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 298 |
fine_z = []
|
| 299 |
for i in range(n_chunks):
|
| 300 |
-
chunk =
|
|
|
|
|
|
|
| 301 |
chunk = self.c2f.generate(
|
| 302 |
codec=self.codec,
|
| 303 |
time_steps=chunk_len,
|
| 304 |
start_tokens=chunk,
|
| 305 |
return_signal=False,
|
|
|
|
| 306 |
**kwargs
|
| 307 |
)
|
| 308 |
fine_z.append(chunk)
|
|
@@ -337,6 +348,12 @@ class Interface(torch.nn.Module):
|
|
| 337 |
**kwargs
|
| 338 |
)
|
| 339 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 340 |
if return_mask:
|
| 341 |
return c_vamp, cz_masked
|
| 342 |
|
|
@@ -352,17 +369,18 @@ if __name__ == "__main__":
|
|
| 352 |
at.util.seed(42)
|
| 353 |
|
| 354 |
interface = Interface(
|
| 355 |
-
coarse_ckpt="./models/
|
| 356 |
-
coarse2fine_ckpt="./models/
|
| 357 |
-
codec_ckpt="./models/
|
| 358 |
device="cuda",
|
| 359 |
wavebeat_ckpt="./models/wavebeat.pth"
|
| 360 |
)
|
| 361 |
|
| 362 |
|
| 363 |
-
sig = at.AudioSignal.
|
| 364 |
|
| 365 |
z = interface.encode(sig)
|
|
|
|
| 366 |
|
| 367 |
# mask = linear_random(z, 1.0)
|
| 368 |
# mask = mask_and(
|
|
@@ -374,13 +392,14 @@ if __name__ == "__main__":
|
|
| 374 |
# )
|
| 375 |
# )
|
| 376 |
|
| 377 |
-
mask = interface.make_beat_mask(
|
| 378 |
-
|
| 379 |
-
)
|
| 380 |
# mask = dropout(mask, 0.0)
|
| 381 |
# mask = codebook_unmask(mask, 0)
|
|
|
|
|
|
|
| 382 |
|
| 383 |
-
breakpoint()
|
| 384 |
zv, mask_z = interface.coarse_vamp(
|
| 385 |
z,
|
| 386 |
mask=mask,
|
|
@@ -389,16 +408,16 @@ if __name__ == "__main__":
|
|
| 389 |
return_mask=True,
|
| 390 |
gen_fn=interface.coarse.generate
|
| 391 |
)
|
|
|
|
| 392 |
|
| 393 |
use_coarse2fine = True
|
| 394 |
if use_coarse2fine:
|
| 395 |
-
zv = interface.coarse_to_fine(zv, temperature=0.8)
|
|
|
|
| 396 |
|
| 397 |
mask = interface.to_signal(mask_z).cpu()
|
| 398 |
|
| 399 |
sig = interface.to_signal(zv).cpu()
|
| 400 |
print("done")
|
| 401 |
|
| 402 |
-
sig.write("output3.wav")
|
| 403 |
-
mask.write("mask.wav")
|
| 404 |
|
|
|
|
| 22 |
|
| 23 |
return AudioSignal(audio_data, sample_rate=audio_signals[0].sample_rate)
|
| 24 |
|
| 25 |
+
|
| 26 |
def _load_model(
|
| 27 |
ckpt: str,
|
| 28 |
lora_ckpt: str = None,
|
|
|
|
| 65 |
):
|
| 66 |
super().__init__()
|
| 67 |
assert codec_ckpt is not None, "must provide a codec checkpoint"
|
| 68 |
+
self.codec = DAC.load(codec_ckpt)
|
| 69 |
self.codec.eval()
|
| 70 |
self.codec.to(device)
|
| 71 |
|
|
|
|
| 276 |
|
| 277 |
def coarse_to_fine(
|
| 278 |
self,
|
| 279 |
+
z: torch.Tensor,
|
| 280 |
+
mask: torch.Tensor = None,
|
| 281 |
**kwargs
|
| 282 |
):
|
| 283 |
assert self.c2f is not None, "No coarse2fine model loaded"
|
| 284 |
+
length = z.shape[-1]
|
| 285 |
chunk_len = self.s2t(self.c2f.chunk_size_s)
|
| 286 |
+
n_chunks = math.ceil(z.shape[-1] / chunk_len)
|
| 287 |
|
| 288 |
# zero pad to chunk_len
|
| 289 |
if length % chunk_len != 0:
|
| 290 |
pad_len = chunk_len - (length % chunk_len)
|
| 291 |
+
z = torch.nn.functional.pad(z, (0, pad_len))
|
| 292 |
+
mask = torch.nn.functional.pad(mask, (0, pad_len)) if mask is not None else None
|
| 293 |
|
| 294 |
+
n_codebooks_to_append = self.c2f.n_codebooks - z.shape[1]
|
| 295 |
if n_codebooks_to_append > 0:
|
| 296 |
+
z = torch.cat([
|
| 297 |
+
z,
|
| 298 |
+
torch.zeros(z.shape[0], n_codebooks_to_append, z.shape[-1]).long().to(self.device)
|
| 299 |
], dim=1)
|
| 300 |
|
| 301 |
+
# set the mask to 0 for all conditioning codebooks
|
| 302 |
+
if mask is not None:
|
| 303 |
+
mask = mask.clone()
|
| 304 |
+
mask[:, :self.c2f.n_conditioning_codebooks, :] = 0
|
| 305 |
+
|
| 306 |
fine_z = []
|
| 307 |
for i in range(n_chunks):
|
| 308 |
+
chunk = z[:, :, i * chunk_len : (i + 1) * chunk_len]
|
| 309 |
+
mask_chunk = mask[:, :, i * chunk_len : (i + 1) * chunk_len] if mask is not None else None
|
| 310 |
+
|
| 311 |
chunk = self.c2f.generate(
|
| 312 |
codec=self.codec,
|
| 313 |
time_steps=chunk_len,
|
| 314 |
start_tokens=chunk,
|
| 315 |
return_signal=False,
|
| 316 |
+
mask=mask_chunk,
|
| 317 |
**kwargs
|
| 318 |
)
|
| 319 |
fine_z.append(chunk)
|
|
|
|
| 348 |
**kwargs
|
| 349 |
)
|
| 350 |
|
| 351 |
+
# add the fine codes back in
|
| 352 |
+
c_vamp = torch.cat(
|
| 353 |
+
[c_vamp, z[:, self.coarse.n_codebooks :, :]],
|
| 354 |
+
dim=1
|
| 355 |
+
)
|
| 356 |
+
|
| 357 |
if return_mask:
|
| 358 |
return c_vamp, cz_masked
|
| 359 |
|
|
|
|
| 369 |
at.util.seed(42)
|
| 370 |
|
| 371 |
interface = Interface(
|
| 372 |
+
coarse_ckpt="./models/vampnet/coarse.pth",
|
| 373 |
+
coarse2fine_ckpt="./models/vampnet/c2f.pth",
|
| 374 |
+
codec_ckpt="./models/vampnet/codec.pth",
|
| 375 |
device="cuda",
|
| 376 |
wavebeat_ckpt="./models/wavebeat.pth"
|
| 377 |
)
|
| 378 |
|
| 379 |
|
| 380 |
+
sig = at.AudioSignal('assets/example.wav')
|
| 381 |
|
| 382 |
z = interface.encode(sig)
|
| 383 |
+
breakpoint()
|
| 384 |
|
| 385 |
# mask = linear_random(z, 1.0)
|
| 386 |
# mask = mask_and(
|
|
|
|
| 392 |
# )
|
| 393 |
# )
|
| 394 |
|
| 395 |
+
# mask = interface.make_beat_mask(
|
| 396 |
+
# sig, 0.0, 0.075
|
| 397 |
+
# )
|
| 398 |
# mask = dropout(mask, 0.0)
|
| 399 |
# mask = codebook_unmask(mask, 0)
|
| 400 |
+
|
| 401 |
+
mask = inpaint(z, n_prefix=100, n_suffix=100)
|
| 402 |
|
|
|
|
| 403 |
zv, mask_z = interface.coarse_vamp(
|
| 404 |
z,
|
| 405 |
mask=mask,
|
|
|
|
| 408 |
return_mask=True,
|
| 409 |
gen_fn=interface.coarse.generate
|
| 410 |
)
|
| 411 |
+
|
| 412 |
|
| 413 |
use_coarse2fine = True
|
| 414 |
if use_coarse2fine:
|
| 415 |
+
zv = interface.coarse_to_fine(zv, temperature=0.8, mask=mask)
|
| 416 |
+
breakpoint()
|
| 417 |
|
| 418 |
mask = interface.to_signal(mask_z).cpu()
|
| 419 |
|
| 420 |
sig = interface.to_signal(zv).cpu()
|
| 421 |
print("done")
|
| 422 |
|
|
|
|
|
|
|
| 423 |
|