Spaces:
Sleeping
Sleeping
Commit
·
0e3a05d
1
Parent(s):
0fbacb2
Add target cropping if outputs are different length
Browse files
- remfx/models.py +7 -20
- remfx/tcn.py +0 -2
remfx/models.py
CHANGED
|
@@ -13,6 +13,7 @@ from remfx.utils import FADLoss, spectrogram
|
|
| 13 |
from remfx.dptnet import DPTNet_base
|
| 14 |
from remfx.dcunet import RefineSpectrogramUnet
|
| 15 |
from remfx.tcn import TCN
|
|
|
|
| 16 |
|
| 17 |
|
| 18 |
class RemFX(pl.LightningModule):
|
|
@@ -223,21 +224,14 @@ class DCUNetModel(nn.Module):
|
|
| 223 |
def forward(self, batch):
|
| 224 |
x, target = batch
|
| 225 |
output = self.model(x.squeeze(1)).unsqueeze(1) # B x 1 x T
|
| 226 |
-
# Pad or crop to match target
|
| 227 |
-
if output.shape[-1] > target.shape[-1]:
|
| 228 |
-
output = output[:, : target.shape[-1]]
|
| 229 |
-
elif output.shape[-1] < target.shape[-1]:
|
| 230 |
-
output = F.pad(output, (0, target.shape[-1] - output.shape[-1]))
|
| 231 |
loss = self.mrstftloss(output, target) + self.l1loss(output, target) * 100
|
| 232 |
return loss, output
|
| 233 |
|
| 234 |
def sample(self, x: Tensor) -> Tensor:
|
| 235 |
output = self.model(x.squeeze(1)).unsqueeze(1) # B x 1 x T
|
| 236 |
-
# Pad or crop to match target
|
| 237 |
-
if output.shape[-1] > x.shape[-1]:
|
| 238 |
-
output = output[:, : x.shape[-1]]
|
| 239 |
-
elif output.shape[-1] < x.shape[-1]:
|
| 240 |
-
output = F.pad(output, (0, x.shape[-1] - output.shape[-1]))
|
| 241 |
return output
|
| 242 |
|
| 243 |
|
|
@@ -253,21 +247,14 @@ class TCNModel(nn.Module):
|
|
| 253 |
def forward(self, batch):
|
| 254 |
x, target = batch
|
| 255 |
output = self.model(x) # B x 1 x T
|
| 256 |
-
# Pad or crop to match target
|
| 257 |
-
if output.shape[-1] > x.shape[-1]:
|
| 258 |
-
output = output[:, : x.shape[-1]]
|
| 259 |
-
elif output.shape[-1] < x.shape[-1]:
|
| 260 |
-
output = F.pad(output, (0, x.shape[-1] - output.shape[-1]))
|
| 261 |
loss = self.mrstftloss(output, target) + self.l1loss(output, target) * 100
|
| 262 |
return loss, output
|
| 263 |
|
| 264 |
def sample(self, x: Tensor) -> Tensor:
|
| 265 |
output = self.model(x) # B x 1 x T
|
| 266 |
-
# Pad or crop to match target
|
| 267 |
-
if output.shape[-1] > x.shape[-1]:
|
| 268 |
-
output = output[:, : x.shape[-1]]
|
| 269 |
-
elif output.shape[-1] < x.shape[-1]:
|
| 270 |
-
output = F.pad(output, (0, x.shape[-1] - output.shape[-1]))
|
| 271 |
return output
|
| 272 |
|
| 273 |
|
|
|
|
| 13 |
from remfx.dptnet import DPTNet_base
|
| 14 |
from remfx.dcunet import RefineSpectrogramUnet
|
| 15 |
from remfx.tcn import TCN
|
| 16 |
+
from remfx.utils import causal_crop
|
| 17 |
|
| 18 |
|
| 19 |
class RemFX(pl.LightningModule):
|
|
|
|
| 224 |
def forward(self, batch):
|
| 225 |
x, target = batch
|
| 226 |
output = self.model(x.squeeze(1)).unsqueeze(1) # B x 1 x T
|
| 227 |
+
# Crop target to match output
|
| 228 |
+
if output.shape[-1] < target.shape[-1]:
|
| 229 |
+
target = causal_crop(target, output.shape[-1])
|
|
|
|
|
|
|
| 230 |
loss = self.mrstftloss(output, target) + self.l1loss(output, target) * 100
|
| 231 |
return loss, output
|
| 232 |
|
| 233 |
def sample(self, x: Tensor) -> Tensor:
|
| 234 |
output = self.model(x.squeeze(1)).unsqueeze(1) # B x 1 x T
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 235 |
return output
|
| 236 |
|
| 237 |
|
|
|
|
| 247 |
def forward(self, batch):
|
| 248 |
x, target = batch
|
| 249 |
output = self.model(x) # B x 1 x T
|
| 250 |
+
# Crop target to match output
|
| 251 |
+
if output.shape[-1] < target.shape[-1]:
|
| 252 |
+
target = causal_crop(target, output.shape[-1])
|
|
|
|
|
|
|
| 253 |
loss = self.mrstftloss(output, target) + self.l1loss(output, target) * 100
|
| 254 |
return loss, output
|
| 255 |
|
| 256 |
def sample(self, x: Tensor) -> Tensor:
|
| 257 |
output = self.model(x) # B x 1 x T
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 258 |
return output
|
| 259 |
|
| 260 |
|
remfx/tcn.py
CHANGED
|
@@ -25,8 +25,6 @@ class TCNBlock(nn.Module):
|
|
| 25 |
self.stride = stride
|
| 26 |
|
| 27 |
self.crop_fn = crop_fn
|
| 28 |
-
# Assumes stride of 1
|
| 29 |
-
padding = (kernel_size + (kernel_size - 1) * (dilation - 1) - 1) // 2
|
| 30 |
self.conv1 = nn.Conv1d(
|
| 31 |
in_ch,
|
| 32 |
out_ch,
|
|
|
|
| 25 |
self.stride = stride
|
| 26 |
|
| 27 |
self.crop_fn = crop_fn
|
|
|
|
|
|
|
| 28 |
self.conv1 = nn.Conv1d(
|
| 29 |
in_ch,
|
| 30 |
out_ch,
|