Commit 14ae0ea
Parent: a22b103
WIP: Initial pipeline scripts
Files changed:
- .gitignore +1 -0
- README.md +5 -2
- datasets.py +61 -0
- download_egfx.sh +21 -0
- egfx.ipynb +0 -0
- guitar_generation_test.ipynb +0 -0
- models.py +105 -0
- train.py +32 -0
.gitignore
CHANGED
@@ -4,3 +4,4 @@ wandb/
 *.egg-info/
 data/
 .DS_Store
+__pycache__/
README.md
CHANGED
@@ -1,4 +1,7 @@
 
-wget https://zenodo.org/record/7044411/files/Clean.zip?download=1 Clean.zip
+wget https://zenodo.org/record/7044411/files/Clean.zip?download=1 Clean.zip
 
-
+unzip Clean.zip
+
+python3 -m venv env
+pip install -e .
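Note: as written, wget treats a bare second argument as another URL to fetch, so the README line above does not actually save the file as Clean.zip; download_egfx.sh below uses the -O flag for this. A corrected form, quoting the URL so the shell does not interpret the query string, would be:

wget "https://zenodo.org/record/7044411/files/Clean.zip?download=1" -O Clean.zip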
datasets.py
ADDED
@@ -0,0 +1,61 @@
import torch
from torch.utils.data import Dataset
import torchaudio
import torchaudio.transforms as T
import torch.nn.functional as F
from pathlib import Path
from typing import List

# https://zenodo.org/record/7044411/

LENGTH = 2**18  # 262144 samples, ~12 seconds at 22.05 kHz
ORIG_SR = 48000


class GuitarFXDataset(Dataset):
    def __init__(
        self,
        root: str,
        sample_rate: int,
        length: int = LENGTH,
        effect_type: List[str] = None,
    ):
        self.length = length
        self.wet_files = []
        self.dry_files = []
        self.labels = []
        self.root = Path(root)
        if effect_type is None:
            # Compare directory names (not Path objects) when excluding "Clean"
            effect_type = [
                d.name for d in self.root.iterdir() if d.is_dir() and d.name != "Clean"
            ]
        for i, effect in enumerate(effect_type):
            for pickup in Path(self.root / effect).iterdir():
                wet = list(pickup.glob("*.wav"))
                self.wet_files += wet
                self.dry_files += list(self.root.glob(f"Clean/{pickup.name}/**/*.wav"))
                # Label only the files added this iteration, not the running total
                self.labels += [i] * len(wet)
        print(
            f"Found {len(self.wet_files)} wet files and {len(self.dry_files)} dry files"
        )
        self.resampler = T.Resample(ORIG_SR, sample_rate)

    def __len__(self):
        return len(self.dry_files)

    def __getitem__(self, idx):
        x, sr = torchaudio.load(self.wet_files[idx])
        y, sr = torchaudio.load(self.dry_files[idx])
        effect_label = self.labels[idx]

        resampled_x = self.resampler(x)
        resampled_y = self.resampler(y)
        # Pad or crop to length
        if resampled_x.shape[-1] < self.length:
            resampled_x = F.pad(resampled_x, (0, self.length - resampled_x.shape[-1]))
        elif resampled_x.shape[-1] > self.length:
            resampled_x = resampled_x[:, : self.length]
        if resampled_y.shape[-1] < self.length:
            resampled_y = F.pad(resampled_y, (0, self.length - resampled_y.shape[-1]))
        elif resampled_y.shape[-1] > self.length:
            resampled_y = resampled_y[:, : self.length]
        return (resampled_x, resampled_y, effect_label)
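For reference, a minimal usage sketch of GuitarFXDataset. The root path here is hypothetical and assumes the EGFx data is laid out as <root>/<EffectName>/<Pickup>/*.wav with dry takes under Clean/, as the globbing above expects:

from torch.utils.data import DataLoader
from datasets import GuitarFXDataset

dataset = GuitarFXDataset(root="data/egfx", sample_rate=22050, effect_type=["Phaser"])
wet, dry, label = dataset[0]  # two (channels, LENGTH) tensors and an int effect label
loader = DataLoader(dataset, batch_size=2)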
download_egfx.sh
ADDED
@@ -0,0 +1,21 @@
#!/bin/bash
mkdir -p data
cd data
mkdir -p egfx
cd egfx
wget https://zenodo.org/record/7044411/files/BluesDriver.zip?download=1 -O BluesDriver.zip
wget https://zenodo.org/record/7044411/files/Chorus.zip?download=1 -O Chorus.zip
wget https://zenodo.org/record/7044411/files/Clean.zip?download=1 -O Clean.zip
wget https://zenodo.org/record/7044411/files/Digital-Delay.zip?download=1 -O Digital-Delay.zip
wget https://zenodo.org/record/7044411/files/Flanger.zip?download=1 -O Flanger.zip
wget https://zenodo.org/record/7044411/files/Hall-Reverb.zip?download=1 -O Hall-Reverb.zip
wget https://zenodo.org/record/7044411/files/Phaser.zip?download=1 -O Phaser.zip
wget https://zenodo.org/record/7044411/files/Plate-Reverb.zip?download=1 -O Plate-Reverb.zip
wget https://zenodo.org/record/7044411/files/RAT.zip?download=1 -O RAT.zip
wget https://zenodo.org/record/7044411/files/Spring-Reverb.zip?download=1 -O Spring-Reverb.zip
wget https://zenodo.org/record/7044411/files/Sweep-Echo.zip?download=1 -O Sweep-Echo.zip
wget https://zenodo.org/record/7044411/files/TapeEcho.zip?download=1 -O TapeEcho.zip
wget https://zenodo.org/record/7044411/files/TubeScreamer.zip?download=1 -O TubeScreamer.zip
unzip \*.zip
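The escaped glob on the last line lets unzip expand the *.zip pattern itself rather than the shell, so the script works even though no archives exist when it starts. Run it from the repository root:

bash download_egfx.sh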
egfx.ipynb
ADDED
The diff for this file is too large to render. See raw diff.

guitar_generation_test.ipynb
ADDED
The diff for this file is too large to render. See raw diff.
models.py
ADDED
@@ -0,0 +1,105 @@
from audio_diffusion_pytorch import AudioDiffusionModel
import torch
from torch import Tensor
import pytorch_lightning as pl
from einops import rearrange
import wandb

SAMPLE_RATE = 22050  # From audio-diffusion-pytorch


class TCNWrapper(pl.LightningModule):
    # WIP: despite the name, this currently wraps AudioDiffusionModel too
    def __init__(self):
        super().__init__()
        self.model = AudioDiffusionModel(in_channels=1)

    def forward(self, x: torch.Tensor):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        loss = self.common_step(batch, batch_idx, mode="train")
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self.common_step(batch, batch_idx, mode="val")

    def common_step(self, batch, batch_idx, mode: str = "train"):
        x, target, label = batch
        loss = self(x)  # AudioDiffusionModel's forward returns the diffusion loss
        self.log(f"{mode}_loss", loss, on_step=True, on_epoch=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(
            self.parameters(), lr=1e-4, betas=(0.95, 0.999), eps=1e-6, weight_decay=1e-3
        )


class AudioDiffusionWrapper(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = AudioDiffusionModel(in_channels=1)

    def forward(self, x: torch.Tensor):
        return self.model(x)

    def sample(self, *args, **kwargs) -> Tensor:
        return self.model.sample(*args, **kwargs)

    def training_step(self, batch, batch_idx):
        loss = self.common_step(batch, batch_idx, mode="train")
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self.common_step(batch, batch_idx, mode="val")

    def common_step(self, batch, batch_idx, mode: str = "train"):
        x, target, label = batch
        loss = self(x)
        self.log(f"{mode}_loss", loss, on_step=True, on_epoch=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(
            self.parameters(), lr=1e-4, betas=(0.95, 0.999), eps=1e-6, weight_decay=1e-3
        )

    def on_validation_epoch_start(self):
        self.log_next = True

    def on_validation_batch_start(self, batch, batch_idx, dataloader_idx):
        x, target, label = batch
        if self.log_next:
            self.log_sample(x)
            self.log_next = False

    @torch.no_grad()
    def log_sample(self, batch, num_steps=10):
        # Get start diffusion noise
        noise = torch.randn(batch.shape, device=self.device)
        sampled = self.model.sample(
            noise=noise, num_steps=num_steps  # Suggested range: 2-50
        )
        # log_wandb_audio_batch is a module-level helper, not a method
        log_wandb_audio_batch(
            id="sample",
            samples=sampled,
            sampling_rate=SAMPLE_RATE,
            caption=f"Sampled in {num_steps} steps",
        )


def log_wandb_audio_batch(
    id: str, samples: Tensor, sampling_rate: int, caption: str = ""
):
    num_items = samples.shape[0]
    samples = rearrange(samples, "b c t -> b t c")
    for idx in range(num_items):
        wandb.log(
            {
                f"sample_{idx}_{id}": wandb.Audio(
                    samples[idx].cpu().numpy(),
                    caption=caption,
                    sample_rate=sampling_rate,
                )
            }
        )
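A minimal sketch of drawing audio from the wrapper outside of training, using the same sample(noise=..., num_steps=...) call that log_sample makes above. The checkpoint path is hypothetical; an untrained model runs too, it just produces noise:

import torch
from models import AudioDiffusionWrapper

model = AudioDiffusionWrapper()  # or .load_from_checkpoint("remfx.ckpt"), hypothetical path
noise = torch.randn(1, 1, 2**18)  # (batch, channels, samples), matching LENGTH in datasets.py
audio = model.sample(noise=noise, num_steps=10)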
train.py
ADDED
@@ -0,0 +1,32 @@
from pytorch_lightning.loggers import WandbLogger
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader
from datasets import GuitarFXDataset
from models import AudioDiffusionWrapper

SAMPLE_RATE = 22050
TRAIN_SPLIT = 0.8


def main():
    # wandb_logger = WandbLogger(project="RemFX", save_dir="./")
    trainer = pl.Trainer()  # logger=wandb_logger)
    guitfx = GuitarFXDataset(
        root="/Users/matthewrice/mir_datasets/egfxset",
        sample_rate=SAMPLE_RATE,
        effect_type=["Phaser"],
    )
    train_size = int(TRAIN_SPLIT * len(guitfx))
    val_size = len(guitfx) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(
        guitfx, [train_size, val_size]
    )
    train = DataLoader(train_dataset, batch_size=2)
    val = DataLoader(val_dataset, batch_size=2)
    model = AudioDiffusionWrapper()
    trainer.fit(model=model, train_dataloaders=train, val_dataloaders=val)


if __name__ == "__main__":
    main()
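To try the pipeline end to end, note the root path in main() is machine-specific; it would need to point at your local copy of the data, e.g. the data/egfx directory created by download_egfx.sh:

bash download_egfx.sh
python3 train.py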