Spaces:
Sleeping
Sleeping
Commit
·
b676040
1
Parent(s):
8f8de0d
Remove FAD logging on input data during train
Browse files- remfx/models.py +4 -1
remfx/models.py
CHANGED
|
@@ -127,6 +127,9 @@ class RemFXModel(pl.LightningModule):
|
|
| 127 |
negate = -1
|
| 128 |
else:
|
| 129 |
negate = 1
|
|
|
|
|
|
|
|
|
|
| 130 |
self.log(
|
| 131 |
f"Input_{metric}",
|
| 132 |
negate * self.metrics[metric](x, target),
|
|
@@ -215,7 +218,7 @@ class DemucsModel(torch.nn.Module):
|
|
| 215 |
self.model = HDemucs(**kwargs)
|
| 216 |
self.num_bins = kwargs["nfft"] // 2 + 1
|
| 217 |
self.mrstftloss = MultiResolutionSTFTLoss(
|
| 218 |
-
n_bins=self.num_bins, sample_rate=
|
| 219 |
)
|
| 220 |
self.l1loss = torch.nn.L1Loss()
|
| 221 |
|
|
|
|
| 127 |
negate = -1
|
| 128 |
else:
|
| 129 |
negate = 1
|
| 130 |
+
# Only Log FAD on test set
|
| 131 |
+
if metric == "FAD":
|
| 132 |
+
continue
|
| 133 |
self.log(
|
| 134 |
f"Input_{metric}",
|
| 135 |
negate * self.metrics[metric](x, target),
|
|
|
|
| 218 |
self.model = HDemucs(**kwargs)
|
| 219 |
self.num_bins = kwargs["nfft"] // 2 + 1
|
| 220 |
self.mrstftloss = MultiResolutionSTFTLoss(
|
| 221 |
+
n_bins=self.num_bins, sample_rate=sample_rate
|
| 222 |
)
|
| 223 |
self.l1loss = torch.nn.L1Loss()
|
| 224 |
|