Spaces:
Sleeping
Sleeping
Commit
·
7173f20
1
Parent(s):
7fc4de1
Re-sample effects if STFT too low
Browse files- remfx/datasets.py +49 -42
remfx/datasets.py
CHANGED
|
@@ -13,10 +13,10 @@ from typing import Any, List, Dict
|
|
| 13 |
from torch.utils.data import Dataset, DataLoader
|
| 14 |
from remfx.utils import select_random_chunk
|
| 15 |
import multiprocessing
|
|
|
|
| 16 |
|
| 17 |
|
| 18 |
-
|
| 19 |
-
|
| 20 |
ALL_EFFECTS = effect_lib.Pedalboard_Effects
|
| 21 |
# print(ALL_EFFECTS)
|
| 22 |
|
|
@@ -275,6 +275,7 @@ class EffectDataset(Dataset):
|
|
| 275 |
self.effects_to_keep = [] if effects_to_keep is None else effects_to_keep
|
| 276 |
self.effects_to_remove = [] if effects_to_remove is None else effects_to_remove
|
| 277 |
self.normalize = effect_lib.LoudnessNormalize(sample_rate, target_lufs_db=-20)
|
|
|
|
| 278 |
self.effects = effect_modules
|
| 279 |
self.shuffle_kept_effects = shuffle_kept_effects
|
| 280 |
self.shuffle_removed_effects = shuffle_removed_effects
|
|
@@ -438,46 +439,52 @@ class EffectDataset(Dataset):
|
|
| 438 |
# Index in effect settings
|
| 439 |
effect_names_to_apply = [self.effects_to_keep[i] for i in effect_indices]
|
| 440 |
effects_to_apply = [self.effects[i] for i in effect_names_to_apply]
|
| 441 |
-
#
|
| 442 |
-
|
| 443 |
-
|
| 444 |
-
#
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
|
| 458 |
-
|
| 459 |
-
|
| 460 |
-
|
| 461 |
-
|
| 462 |
-
|
| 463 |
-
|
| 464 |
-
|
| 465 |
-
#
|
| 466 |
-
|
| 467 |
-
|
| 468 |
-
|
| 469 |
-
|
| 470 |
-
|
| 471 |
-
|
| 472 |
-
|
| 473 |
-
|
| 474 |
-
|
| 475 |
-
|
| 476 |
-
|
| 477 |
-
|
| 478 |
-
|
| 479 |
-
|
| 480 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 481 |
return normalized_dry, normalized_wet, dry_labels_tensor, wet_labels_tensor
|
| 482 |
|
| 483 |
|
|
|
|
| 13 |
from torch.utils.data import Dataset, DataLoader
|
| 14 |
from remfx.utils import select_random_chunk
|
| 15 |
import multiprocessing
|
| 16 |
+
from auraloss.freq import MultiResolutionSTFTLoss
|
| 17 |
|
| 18 |
|
| 19 |
+
STFT_THRESH = 1e-3
|
|
|
|
| 20 |
ALL_EFFECTS = effect_lib.Pedalboard_Effects
|
| 21 |
# print(ALL_EFFECTS)
|
| 22 |
|
|
|
|
| 275 |
self.effects_to_keep = [] if effects_to_keep is None else effects_to_keep
|
| 276 |
self.effects_to_remove = [] if effects_to_remove is None else effects_to_remove
|
| 277 |
self.normalize = effect_lib.LoudnessNormalize(sample_rate, target_lufs_db=-20)
|
| 278 |
+
self.mrstft = MultiResolutionSTFTLoss(sample_rate=sample_rate)
|
| 279 |
self.effects = effect_modules
|
| 280 |
self.shuffle_kept_effects = shuffle_kept_effects
|
| 281 |
self.shuffle_removed_effects = shuffle_removed_effects
|
|
|
|
| 439 |
# Index in effect settings
|
| 440 |
effect_names_to_apply = [self.effects_to_keep[i] for i in effect_indices]
|
| 441 |
effects_to_apply = [self.effects[i] for i in effect_names_to_apply]
|
| 442 |
+
# stft comparison
|
| 443 |
+
stft = 0
|
| 444 |
+
while stft < STFT_THRESH:
|
| 445 |
+
# Apply
|
| 446 |
+
dry_labels = []
|
| 447 |
+
for effect in effects_to_apply:
|
| 448 |
+
# Normalize in-between effects
|
| 449 |
+
dry = self.normalize(effect(dry))
|
| 450 |
+
dry_labels.append(ALL_EFFECTS.index(type(effect)))
|
| 451 |
+
|
| 452 |
+
# Apply effects_to_remove
|
| 453 |
+
# Shuffle effects if specified
|
| 454 |
+
if self.shuffle_removed_effects:
|
| 455 |
+
effect_indices = torch.randperm(len(self.effects_to_remove))
|
| 456 |
+
else:
|
| 457 |
+
effect_indices = torch.arange(len(self.effects_to_remove))
|
| 458 |
+
wet = torch.clone(dry)
|
| 459 |
+
r1 = self.num_removed_effects[0]
|
| 460 |
+
r2 = self.num_removed_effects[1]
|
| 461 |
+
num_removed_effects = torch.round((r1 - r2) * torch.rand(1) + r2).int()
|
| 462 |
+
effect_indices = effect_indices[:num_removed_effects]
|
| 463 |
+
# Index in effect settings
|
| 464 |
+
effect_names_to_apply = [self.effects_to_remove[i] for i in effect_indices]
|
| 465 |
+
effects_to_apply = [self.effects[i] for i in effect_names_to_apply]
|
| 466 |
+
# Apply
|
| 467 |
+
wet_labels = []
|
| 468 |
+
for effect in effects_to_apply:
|
| 469 |
+
# Normalize in-between effects
|
| 470 |
+
wet = self.normalize(effect(wet))
|
| 471 |
+
wet_labels.append(ALL_EFFECTS.index(type(effect)))
|
| 472 |
+
|
| 473 |
+
wet_labels_tensor = torch.zeros(len(ALL_EFFECTS))
|
| 474 |
+
dry_labels_tensor = torch.zeros(len(ALL_EFFECTS))
|
| 475 |
+
|
| 476 |
+
for label_idx in wet_labels:
|
| 477 |
+
wet_labels_tensor[label_idx] = 1.0
|
| 478 |
+
|
| 479 |
+
for label_idx in dry_labels:
|
| 480 |
+
dry_labels_tensor[label_idx] = 1.0
|
| 481 |
+
|
| 482 |
+
# Normalize
|
| 483 |
+
normalized_dry = self.normalize(dry)
|
| 484 |
+
normalized_wet = self.normalize(wet)
|
| 485 |
+
|
| 486 |
+
# Check STFT, pick different effects if necessary
|
| 487 |
+
stft = self.mrstft(normalized_wet, normalized_dry)
|
| 488 |
return normalized_dry, normalized_wet, dry_labels_tensor, wet_labels_tensor
|
| 489 |
|
| 490 |
|