Spaces:
Sleeping
Sleeping
Commit
·
1ff07dc
1
Parent(s):
d8d3e30
Add new loss to umx
Browse files- remfx/models.py +2 -2
remfx/models.py
CHANGED
|
@@ -237,7 +237,7 @@ class OpenUnmixModel(torch.nn.Module):
|
|
| 237 |
X = spectrogram(x, self.window, self.n_fft, self.hop_length, self.alpha)
|
| 238 |
Y = self.model(X)
|
| 239 |
sep_out = self.separator(x).squeeze(1)
|
| 240 |
-
loss = self.mrstftloss(sep_out, target) + self.l1loss(sep_out, target)
|
| 241 |
|
| 242 |
return loss, sep_out
|
| 243 |
|
|
@@ -258,7 +258,7 @@ class DemucsModel(torch.nn.Module):
|
|
| 258 |
def forward(self, batch):
|
| 259 |
x, target, label = batch
|
| 260 |
output = self.model(x).squeeze(1)
|
| 261 |
-
loss = self.mrstftloss(output, target) + self.l1loss(output, target)
|
| 262 |
return loss, output
|
| 263 |
|
| 264 |
def sample(self, x: Tensor) -> Tensor:
|
|
|
|
| 237 |
X = spectrogram(x, self.window, self.n_fft, self.hop_length, self.alpha)
|
| 238 |
Y = self.model(X)
|
| 239 |
sep_out = self.separator(x).squeeze(1)
|
| 240 |
+
loss = self.mrstftloss(sep_out, target) + self.l1loss(sep_out, target) * 100
|
| 241 |
|
| 242 |
return loss, sep_out
|
| 243 |
|
|
|
|
| 258 |
def forward(self, batch):
|
| 259 |
x, target, label = batch
|
| 260 |
output = self.model(x).squeeze(1)
|
| 261 |
+
loss = self.mrstftloss(output, target) + self.l1loss(output, target) * 100
|
| 262 |
return loss, output
|
| 263 |
|
| 264 |
def sample(self, x: Tensor) -> Tensor:
|