kevinwang676 commited on
Commit
534cce5
·
verified ·
1 Parent(s): 05dd4cd

Update egs/visinger2/models.py

Browse files
Files changed (1) hide show
  1. egs/visinger2/models.py +14 -8
egs/visinger2/models.py CHANGED
@@ -573,22 +573,28 @@ class Generator_Noise(torch.nn.Module):
573
  def forward(self, x, mask):
574
  istft_x = x
575
  istft_x = self.istft_pre(istft_x)
576
-
 
577
  istft_x = self.net(istft_x) * mask
578
-
 
579
  amp = self.istft_amplitude(istft_x).unsqueeze(-1)
580
- phase = (torch.rand(amp.shape) * 2 * 3.14 - 3.14).to(amp)
581
-
 
582
  real = amp * torch.cos(phase)
583
  imag = amp * torch.sin(phase)
584
- #spec = torch.cat([real, imag], 3)
 
585
  spec = torch.complex(real, imag)
586
-
587
- istft_x = torch.istft(spec, self.fft_size, self.hop_size, self.win_size, self.window.to(amp), True, length=x.shape[2] * self.hop_size, return_complex=False)
588
-
 
589
  return istft_x.unsqueeze(1)
590
 
591
 
 
592
  class LayerNorm(nn.Module):
593
  def __init__(self, channels, eps=1e-5):
594
  super().__init__()
 
573
  def forward(self, x, mask):
574
  istft_x = x
575
  istft_x = self.istft_pre(istft_x)
576
+
577
+ # Apply mask
578
  istft_x = self.net(istft_x) * mask
579
+
580
+ # Compute amplitude and random phase
581
  amp = self.istft_amplitude(istft_x).unsqueeze(-1)
582
+ phase = (torch.rand(amp.shape) * 2 * 3.14 - 3.14).to(amp.device)
583
+
584
+ # Calculate real and imaginary parts
585
  real = amp * torch.cos(phase)
586
  imag = amp * torch.sin(phase)
587
+
588
+ # Create a complex tensor from real and imaginary parts
589
  spec = torch.complex(real, imag)
590
+
591
+ # Inverse short-time Fourier transform
592
+ istft_x = torch.istft(spec, self.fft_size, self.hop_size, self.win_size, self.window.to(spec.device), True, length=x.shape[2] * self.hop_size, return_complex=False)
593
+
594
  return istft_x.unsqueeze(1)
595
 
596
 
597
+
598
  class LayerNorm(nn.Module):
599
  def __init__(self, channels, eps=1e-5):
600
  super().__init__()