Spaces · Runtime error
Commit 6448f47 · 1 Parent(s): e543fe8

Fix new dataset to work for remfx training

Files changed:
- README.md (+1 -2)
- remfx/datasets.py (+0 -4)
- remfx/models.py (+8 -32)
README.md CHANGED

```diff
@@ -9,10 +9,9 @@
 5. `pip install -e umx`
 
 ## Download [VocalSet Dataset](https://zenodo.org/record/1193957)
-1. `wget https://zenodo.org/record/
+1. `wget https://zenodo.org/record/1442513/files/VocalSet1-2.zip?download=1`
 2. `mv VocalSet.zip?download=1 VocalSet.zip`
 3. `unzip VocalSet.zip`
-4. Manually split singers into train, val, test directories
 
 # Training
 ## Steps
```
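For convenience, the three shell steps above can also be scripted. Below is a minimal Python sketch assuming the Zenodo URL from the updated README and extraction into the current directory; the manual train/val/test split step is gone from the README, and the `singer_splits` table in `remfx/datasets.py` suggests the split is now handled in code.

```python
# Sketch: scripted equivalent of the README's wget / mv / unzip steps.
# The URL comes from the updated README; the destination paths are assumptions.
import urllib.request
import zipfile
from pathlib import Path

VOCALSET_URL = "https://zenodo.org/record/1442513/files/VocalSet1-2.zip?download=1"
archive = Path("VocalSet.zip")

if not archive.exists():
    urllib.request.urlretrieve(VOCALSET_URL, archive)  # same file wget would fetch

with zipfile.ZipFile(archive) as zf:
    zf.extractall(".")  # same result as `unzip VocalSet.zip`
```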
remfx/datasets.py CHANGED

```diff
@@ -19,7 +19,6 @@ from remfx.utils import create_sequential_chunks
 # https://zenodo.org/record/1193957 -> VocalSet
 
 ALL_EFFECTS = effects.Pedalboard_Effects
-print(ALL_EFFECTS)
 
 
 singer_splits = {
@@ -206,7 +205,6 @@ class VocalSet(Dataset):
         else:
             num_kept_effects = len(self.effects_to_keep)
         effect_indices = effect_indices[:num_kept_effects]
-        print(effect_indices)
 
         # Index in effect settings
         effect_names_to_apply = [self.effects_to_keep[i] for i in effect_indices]
@@ -249,8 +247,6 @@ class VocalSet(Dataset):
         for label_idx in dry_labels:
             dry_labels_tensor[label_idx] = 1.0
 
-        # effects_present = torch.sum(one_hot, dim=0).float()
-        print(dry_labels_tensor, wet_labels_tensor)
         # Normalize
         normalized_dry = self.normalize(dry)
         normalized_wet = self.normalize(wet)
```
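The surviving context lines show how `VocalSet` encodes which effects were applied: a multi-hot tensor indexed by effect, with `1.0` written at each applied effect's position (`dry_labels_tensor[label_idx] = 1.0`). Below is a minimal sketch of that encoding; the variable names are illustrative stand-ins, not the actual remfx names.

```python
# Sketch of the multi-hot effect labels implied by `dry_labels_tensor[label_idx] = 1.0`.
# `num_effects` and `applied_indices` are illustrative, e.g. num_effects = len(ALL_EFFECTS).
import torch

num_effects = 5           # total number of candidate effects
applied_indices = [0, 3]  # indices of effects applied to this clip

labels = torch.zeros(num_effects)
for idx in applied_indices:
    labels[idx] = 1.0     # mark each applied effect

print(labels)  # tensor([1., 0., 0., 1., 0.])
```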
remfx/models.py CHANGED

```diff
@@ -94,9 +94,9 @@ class RemFXModel(pl.LightningModule):
         return loss
 
     def common_step(self, batch, batch_idx, mode: str = "train"):
-
+        x, y, _, _ = batch
+        loss, output = self.model((x, y))
         self.log(f"{mode}_loss", loss)
-        x, y, label = batch
         # Metric logging
         with torch.no_grad():
             for metric in self.metrics:
```
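The rewritten `common_step` implies a contract between the new dataset and the wrapped model: each batch is a 4-tuple, of which only the first two elements (input and target audio) are forwarded as an `(x, y)` pair, and the wrapped model returns `(loss, output)`. A minimal sketch of that contract follows; everything except the unpacking pattern is an illustrative assumption.

```python
# Sketch of the batch/model contract implied by the updated common_step.
# The dataset yields (x, y, dry_labels, wet_labels); only x and y feed the model here.
import torch

class DummyEffectRemovalModel(torch.nn.Module):
    """Stand-in for self.model: consumes an (x, y) tuple, returns (loss, output)."""

    def forward(self, batch):
        x, target = batch                      # mirrors the updated forward() methods below
        output = x                             # identity "model", for illustration only
        loss = torch.nn.functional.l1_loss(output, target)
        return loss, output

model = DummyEffectRemovalModel()
x = torch.randn(2, 1, 16)                      # (batch, channels, time)
y = torch.randn(2, 1, 16)
dry_labels = torch.zeros(2, 5)                 # extra label tensors are ignored here
wet_labels = torch.ones(2, 5)

batch = (x, y, dry_labels, wet_labels)
x, y, _, _ = batch                             # same unpacking as common_step
loss, output = model((x, y))
print(loss.item(), output.shape)
```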
```diff
@@ -123,7 +123,7 @@ class RemFXModel(pl.LightningModule):
     def on_train_batch_start(self, batch, batch_idx):
         # Log initial audio
         if self.log_train_audio:
-            x, y,
+            x, y, _, _ = batch
             # Concat samples together for easier viewing in dashboard
             input_samples = rearrange(x, "b c t -> c (b t)").unsqueeze(0)
             target_samples = rearrange(y, "b c t -> c (b t)").unsqueeze(0)
```
```diff
@@ -145,7 +145,7 @@ class RemFXModel(pl.LightningModule):
         self.log_train_audio = False
 
     def on_validation_batch_start(self, batch, batch_idx, dataloader_idx):
-        x, target,
+        x, target, _, _ = batch
         # Log Input Metrics
         for metric in self.metrics:
             # SISDR returns negative values, so negate them
```
```diff
@@ -189,7 +189,7 @@ class RemFXModel(pl.LightningModule):
     def on_test_batch_start(self, batch, batch_idx, dataloader_idx):
         self.on_validation_batch_start(batch, batch_idx, dataloader_idx)
         # Log FAD
-        x, target,
+        x, target, _, _ = batch
         self.log(
             "Input_FAD",
             self.metrics["FAD"](x, target),
```
```diff
@@ -237,7 +237,7 @@ class OpenUnmixModel(torch.nn.Module):
         self.l1loss = torch.nn.L1Loss()
 
     def forward(self, batch):
-        x, target
+        x, target = batch
         X = spectrogram(x, self.window, self.n_fft, self.hop_length, self.alpha)
         Y = self.model(X)
         sep_out = self.separator(x).squeeze(1)
```
```diff
@@ -260,7 +260,7 @@ class DemucsModel(torch.nn.Module):
         self.l1loss = torch.nn.L1Loss()
 
     def forward(self, batch):
-        x, target
+        x, target = batch
         output = self.model(x).squeeze(1)
         loss = self.mrstftloss(output, target) + self.l1loss(output, target) * 100
         return loss, output
```
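For context on `DemucsModel.forward` above: the training loss is a spectral term plus an L1 waveform term weighted by 100. The sketch below mirrors that weighting with a single-resolution STFT stand-in for `self.mrstftloss` (the project presumably uses a multi-resolution STFT loss); all values are illustrative.

```python
# Sketch of the Demucs-style loss weighting: spectral term + waveform L1 term * 100.
# `simple_stft_loss` is a single-resolution stand-in, not the loss remfx actually uses.
import torch
import torch.nn.functional as F

def simple_stft_loss(output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """L1 distance between magnitude STFTs (illustrative only)."""
    window = torch.hann_window(1024)
    O = torch.stft(output.squeeze(1), n_fft=1024, window=window, return_complex=True).abs()
    T = torch.stft(target.squeeze(1), n_fft=1024, window=window, return_complex=True).abs()
    return F.l1_loss(O, T)

l1loss = torch.nn.L1Loss()
output = torch.randn(2, 1, 4096)
target = torch.randn(2, 1, 4096)

# Same structure as DemucsModel.forward: spectral loss + waveform L1 * 100
loss = simple_stft_loss(output, target) + l1loss(output, target) * 100
print(loss.item())
```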
```diff
@@ -275,7 +275,7 @@ class DiffusionGenerationModel(nn.Module):
         self.model = DiffusionModel(in_channels=n_channels)
 
     def forward(self, batch):
-        x, target
+        x, target = batch
         sampled_out = self.model.sample(x)
         return self.model(x), sampled_out
 
```
```diff
@@ -481,30 +481,6 @@ class Cnn14(nn.Module):
         return clipwise_output
 
 
-def spectrogram(
-    x: torch.Tensor,
-    window: torch.Tensor,
-    n_fft: int,
-    hop_length: int,
-    alpha: float,
-) -> torch.Tensor:
-    bs, chs, samp = x.size()
-    x = x.view(bs * chs, -1)  # move channels onto batch dim
-
-    X = torch.stft(
-        x,
-        n_fft=n_fft,
-        hop_length=hop_length,
-        window=window,
-        return_complex=True,
-    )
-
-    # move channels back
-    X = X.view(bs, chs, X.shape[-2], X.shape[-1])
-
-    return torch.pow(X.abs() + 1e-8, alpha)
-
-
 class FXClassifier(pl.LightningModule):
     def __init__(
         self,
```
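The deleted `spectrogram` helper computed an alpha-compressed magnitude spectrogram, `|STFT(x)| ** alpha`, folding the channel dimension into the batch for the STFT and unfolding it afterwards. Since `OpenUnmixModel.forward` above still calls `spectrogram(...)`, a definition presumably survives elsewhere in the file and the removed copy was redundant. For reference, here is a self-contained sketch of the same computation; the parameter values are illustrative.

```python
# Sketch reproducing what the removed helper computed: |STFT(x)| ** alpha,
# with channels folded onto the batch dim for the STFT and restored afterwards.
import torch

def compressed_spectrogram(x, window, n_fft, hop_length, alpha):
    bs, chs, samp = x.size()
    x = x.view(bs * chs, -1)                       # fold channels into batch
    X = torch.stft(x, n_fft=n_fft, hop_length=hop_length,
                   window=window, return_complex=True)
    X = X.view(bs, chs, X.shape[-2], X.shape[-1])  # unfold channels
    return torch.pow(X.abs() + 1e-8, alpha)        # compressed magnitude

x = torch.randn(2, 1, 8192)
spec = compressed_spectrogram(x, torch.hann_window(2048), 2048, 512, alpha=0.3)
print(spec.shape)  # (batch, channels, freq bins, frames)
```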