Spaces:
Sleeping
Sleeping
Commit
·
133e1dc
1
Parent(s):
7173f20
Add shuffling effect order, all effects present for chain_inference to cfg
Browse files
cfg/exp/chain_inference.yaml
CHANGED
|
@@ -63,4 +63,6 @@ inference_effects_ordering:
|
|
| 63 |
- "RandomPedalboardReverb"
|
| 64 |
- "RandomPedalboardChorus"
|
| 65 |
- "RandomPedalboardDelay"
|
| 66 |
-
num_bins: 1025
|
|
|
|
|
|
|
|
|
| 63 |
- "RandomPedalboardReverb"
|
| 64 |
- "RandomPedalboardChorus"
|
| 65 |
- "RandomPedalboardDelay"
|
| 66 |
+
num_bins: 1025
|
| 67 |
+
inference_effects_shuffle: False
|
| 68 |
+
inference_use_all_effect_models: False
|
cfg/exp/chain_inference_aug.yaml
CHANGED
|
@@ -63,4 +63,6 @@ inference_effects_ordering:
|
|
| 63 |
- "RandomPedalboardReverb"
|
| 64 |
- "RandomPedalboardChorus"
|
| 65 |
- "RandomPedalboardDelay"
|
| 66 |
-
num_bins: 1025
|
|
|
|
|
|
|
|
|
| 63 |
- "RandomPedalboardReverb"
|
| 64 |
- "RandomPedalboardChorus"
|
| 65 |
- "RandomPedalboardDelay"
|
| 66 |
+
num_bins: 1025
|
| 67 |
+
inference_effects_shuffle: False
|
| 68 |
+
inference_use_all_effect_models: False
|
cfg/exp/chain_inference_aug_classifier.yaml
CHANGED
|
@@ -82,4 +82,6 @@ inference_effects_ordering:
|
|
| 82 |
- "RandomPedalboardReverb"
|
| 83 |
- "RandomPedalboardChorus"
|
| 84 |
- "RandomPedalboardDelay"
|
| 85 |
-
num_bins: 1025
|
|
|
|
|
|
|
|
|
| 82 |
- "RandomPedalboardReverb"
|
| 83 |
- "RandomPedalboardChorus"
|
| 84 |
- "RandomPedalboardDelay"
|
| 85 |
+
num_bins: 1025
|
| 86 |
+
inference_effects_shuffle: False
|
| 87 |
+
inference_use_all_effect_models: False
|
cfg/exp/chain_inference_custom.yaml
CHANGED
|
@@ -68,4 +68,6 @@ inference_effects_ordering:
|
|
| 68 |
- "RandomPedalboardReverb"
|
| 69 |
- "RandomPedalboardChorus"
|
| 70 |
- "RandomPedalboardDelay"
|
| 71 |
-
num_bins: 1025
|
|
|
|
|
|
|
|
|
| 68 |
- "RandomPedalboardReverb"
|
| 69 |
- "RandomPedalboardChorus"
|
| 70 |
- "RandomPedalboardDelay"
|
| 71 |
+
num_bins: 1025
|
| 72 |
+
inference_effects_shuffle: False
|
| 73 |
+
inference_use_all_effect_models: False
|
remfx/models.py
CHANGED
|
@@ -16,12 +16,22 @@ from remfx.callbacks import log_wandb_audio_batch
|
|
| 16 |
from einops import rearrange
|
| 17 |
from remfx import effects
|
| 18 |
import asteroid
|
|
|
|
| 19 |
|
| 20 |
ALL_EFFECTS = effects.Pedalboard_Effects
|
| 21 |
|
| 22 |
|
| 23 |
class RemFXChainInference(pl.LightningModule):
|
| 24 |
-
def __init__(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
super().__init__()
|
| 26 |
self.model = models
|
| 27 |
self.mrstftloss = MultiResolutionSTFTLoss(
|
|
@@ -37,7 +47,9 @@ class RemFXChainInference(pl.LightningModule):
|
|
| 37 |
self.sample_rate = sample_rate
|
| 38 |
self.effect_order = effect_order
|
| 39 |
self.classifier = classifier
|
|
|
|
| 40 |
self.output_str = "IN_SISDR,OUT_SISDR,IN_STFT,OUT_STFT\n"
|
|
|
|
| 41 |
|
| 42 |
def forward(self, batch, batch_idx, order=None):
|
| 43 |
x, y, _, rem_fx_labels = batch
|
|
@@ -46,36 +58,45 @@ class RemFXChainInference(pl.LightningModule):
|
|
| 46 |
effects_order = order
|
| 47 |
else:
|
| 48 |
effects_order = self.effect_order
|
| 49 |
-
|
| 50 |
# Use classifier labels
|
| 51 |
if self.classifier:
|
| 52 |
threshold = 0.5
|
| 53 |
with torch.no_grad():
|
| 54 |
labels = torch.sigmoid(self.classifier(x))
|
| 55 |
rem_fx_labels = torch.where(labels > threshold, 1.0, 0.0)
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
output = []
|
| 62 |
-
input_samples = rearrange(x, "b c t -> c (b t)").unsqueeze(0)
|
| 63 |
-
target_samples = rearrange(y, "b c t -> c (b t)").unsqueeze(0)
|
| 64 |
-
|
| 65 |
-
log_wandb_audio_batch(
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
)
|
| 72 |
-
log_wandb_audio_batch(
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
)
|
| 79 |
with torch.no_grad():
|
| 80 |
for i, (elem, effects_list) in enumerate(zip(x, effects_present)):
|
| 81 |
elem = elem.unsqueeze(0) # Add batch dim
|
|
@@ -111,7 +132,6 @@ class RemFXChainInference(pl.LightningModule):
|
|
| 111 |
# )
|
| 112 |
output.append(elem.squeeze(0))
|
| 113 |
output = torch.stack(output)
|
| 114 |
-
output_samples = rearrange(output, "b c t -> c (b t)").unsqueeze(0)
|
| 115 |
|
| 116 |
# log_wandb_audio_batch(
|
| 117 |
# logger=self.logger,
|
|
@@ -125,8 +145,9 @@ class RemFXChainInference(pl.LightningModule):
|
|
| 125 |
|
| 126 |
def test_step(self, batch, batch_idx):
|
| 127 |
x, y, _, _ = batch # x, y = (B, C, T), (B, C, T)
|
| 128 |
-
|
| 129 |
-
|
|
|
|
| 130 |
loss, output = self.forward(batch, batch_idx, order=self.effect_order)
|
| 131 |
# Crop target to match output
|
| 132 |
if output.shape[-1] < y.shape[-1]:
|
|
|
|
| 16 |
from einops import rearrange
|
| 17 |
from remfx import effects
|
| 18 |
import asteroid
|
| 19 |
+
import random
|
| 20 |
|
| 21 |
ALL_EFFECTS = effects.Pedalboard_Effects
|
| 22 |
|
| 23 |
|
| 24 |
class RemFXChainInference(pl.LightningModule):
|
| 25 |
+
def __init__(
|
| 26 |
+
self,
|
| 27 |
+
models,
|
| 28 |
+
sample_rate,
|
| 29 |
+
num_bins,
|
| 30 |
+
effect_order,
|
| 31 |
+
classifier=None,
|
| 32 |
+
shuffle_effect_order=False,
|
| 33 |
+
use_all_effect_models=False,
|
| 34 |
+
):
|
| 35 |
super().__init__()
|
| 36 |
self.model = models
|
| 37 |
self.mrstftloss = MultiResolutionSTFTLoss(
|
|
|
|
| 47 |
self.sample_rate = sample_rate
|
| 48 |
self.effect_order = effect_order
|
| 49 |
self.classifier = classifier
|
| 50 |
+
self.shuffle_effect_order = shuffle_effect_order
|
| 51 |
self.output_str = "IN_SISDR,OUT_SISDR,IN_STFT,OUT_STFT\n"
|
| 52 |
+
self.use_all_effect_models = use_all_effect_models
|
| 53 |
|
| 54 |
def forward(self, batch, batch_idx, order=None):
|
| 55 |
x, y, _, rem_fx_labels = batch
|
|
|
|
| 58 |
effects_order = order
|
| 59 |
else:
|
| 60 |
effects_order = self.effect_order
|
| 61 |
+
|
| 62 |
# Use classifier labels
|
| 63 |
if self.classifier:
|
| 64 |
threshold = 0.5
|
| 65 |
with torch.no_grad():
|
| 66 |
labels = torch.sigmoid(self.classifier(x))
|
| 67 |
rem_fx_labels = torch.where(labels > threshold, 1.0, 0.0)
|
| 68 |
+
if self.use_all_effect_models:
|
| 69 |
+
effects_present = [
|
| 70 |
+
[ALL_EFFECTS[i] for i, effect in enumerate(effect_label) if effect]
|
| 71 |
+
for effect_label in rem_fx_labels
|
| 72 |
+
]
|
| 73 |
+
else:
|
| 74 |
+
effects_present = [
|
| 75 |
+
[
|
| 76 |
+
ALL_EFFECTS[i]
|
| 77 |
+
for i, effect in enumerate(effect_label)
|
| 78 |
+
if effect == 1.0
|
| 79 |
+
]
|
| 80 |
+
for effect_label in rem_fx_labels
|
| 81 |
+
]
|
| 82 |
output = []
|
| 83 |
+
# input_samples = rearrange(x, "b c t -> c (b t)").unsqueeze(0)
|
| 84 |
+
# target_samples = rearrange(y, "b c t -> c (b t)").unsqueeze(0)
|
| 85 |
+
|
| 86 |
+
# log_wandb_audio_batch(
|
| 87 |
+
# logger=self.logger,
|
| 88 |
+
# id="input_effected_audio",
|
| 89 |
+
# samples=input_samples.cpu(),
|
| 90 |
+
# sampling_rate=self.sample_rate,
|
| 91 |
+
# caption="Input Data",
|
| 92 |
+
# )
|
| 93 |
+
# log_wandb_audio_batch(
|
| 94 |
+
# logger=self.logger,
|
| 95 |
+
# id="target_audio",
|
| 96 |
+
# samples=target_samples.cpu(),
|
| 97 |
+
# sampling_rate=self.sample_rate,
|
| 98 |
+
# caption="Target Data",
|
| 99 |
+
# )
|
| 100 |
with torch.no_grad():
|
| 101 |
for i, (elem, effects_list) in enumerate(zip(x, effects_present)):
|
| 102 |
elem = elem.unsqueeze(0) # Add batch dim
|
|
|
|
| 132 |
# )
|
| 133 |
output.append(elem.squeeze(0))
|
| 134 |
output = torch.stack(output)
|
|
|
|
| 135 |
|
| 136 |
# log_wandb_audio_batch(
|
| 137 |
# logger=self.logger,
|
|
|
|
| 145 |
|
| 146 |
def test_step(self, batch, batch_idx):
|
| 147 |
x, y, _, _ = batch # x, y = (B, C, T), (B, C, T)
|
| 148 |
+
if self.shuffle_effect_order:
|
| 149 |
+
# Random order
|
| 150 |
+
random.shuffle(self.effect_order)
|
| 151 |
loss, output = self.forward(batch, batch_idx, order=self.effect_order)
|
| 152 |
# Crop target to match output
|
| 153 |
if output.shape[-1] < y.shape[-1]:
|
scripts/chain_inference.py
CHANGED
|
@@ -65,6 +65,8 @@ def main(cfg: DictConfig):
|
|
| 65 |
num_bins=cfg.num_bins,
|
| 66 |
effect_order=cfg.inference_effects_ordering,
|
| 67 |
classifier=classifier,
|
|
|
|
|
|
|
| 68 |
)
|
| 69 |
trainer.test(model=inference_model, datamodule=datamodule)
|
| 70 |
|
|
|
|
| 65 |
num_bins=cfg.num_bins,
|
| 66 |
effect_order=cfg.inference_effects_ordering,
|
| 67 |
classifier=classifier,
|
| 68 |
+
shuffle_effect_order=cfg.inference_effects_shuffle,
|
| 69 |
+
use_all_effect_models=cfg.inference_use_all_effect_models,
|
| 70 |
)
|
| 71 |
trainer.test(model=inference_model, datamodule=datamodule)
|
| 72 |
|