kevinwang676 committed
Commit 44d17ea · verified · 1 Parent(s): 7de3d91

Update usr/diff/shallow_diffusion_tts.py

Files changed (1)
  1. usr/diff/shallow_diffusion_tts.py +9 -5
usr/diff/shallow_diffusion_tts.py CHANGED
@@ -16,6 +16,7 @@ from modules.diffsinger_midi.fs2 import FastSpeech2MIDI
 from utils.hparams import hparams
 
 
+
 def exists(x):
     return x is not None
 
@@ -157,7 +158,7 @@ class GaussianDiffusion(nn.Module):
 
     @torch.no_grad()
     def p_sample(self, x, t, cond, clip_denoised=True, repeat_noise=False):
-        b, *_, device = *x.shape, "cuda"
+        b, *_, device = *x.shape, x.device
         model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, cond=cond, clip_denoised=clip_denoised)
         noise = noise_like(x.shape, device, repeat_noise)
         # no noise when t == 0
@@ -172,7 +173,10 @@ class GaussianDiffusion(nn.Module):
 
         def get_x_pred(x, noise_t, t):
             a_t = extract(self.alphas_cumprod, t, x.shape)
-            a_prev = extract(self.alphas_cumprod, torch.max(t-interval, torch.zeros_like(t)), x.shape)
+            if t[0] < interval:
+                a_prev = torch.ones_like(a_t)
+            else:
+                a_prev = extract(self.alphas_cumprod, torch.max(t-interval, torch.zeros_like(t)), x.shape)
             a_t_sq, a_prev_sq = a_t.sqrt(), a_prev.sqrt()
 
             x_delta = (a_prev - a_t) * ((1 / (a_t_sq * (a_t_sq + a_prev_sq))) * x - 1 / (a_t_sq * (((1 - a_prev) * a_t).sqrt() + ((1 - a_t) * a_prev).sqrt())) * noise_t)
@@ -228,7 +232,7 @@ class GaussianDiffusion(nn.Module):
 
     def forward(self, txt_tokens, mel2ph=None, spk_embed=None,
                 ref_mels=None, f0=None, uv=None, energy=None, infer=False, **kwargs):
-        b, *_, device = *txt_tokens.shape, "cuda"
+        b, *_, device = *txt_tokens.shape, txt_tokens.device
         ret = self.fs2(txt_tokens, mel2ph, spk_embed, ref_mels, f0, uv, energy,
                        skip_decoder=(not infer), infer=infer, **kwargs)
         cond = ret['decoder_inp'].transpose(1, 2)
@@ -287,7 +291,7 @@ class GaussianDiffusion(nn.Module):
 class OfflineGaussianDiffusion(GaussianDiffusion):
     def forward(self, txt_tokens, mel2ph=None, spk_embed=None,
                 ref_mels=None, f0=None, uv=None, energy=None, infer=False, **kwargs):
-        b, *_, device = *txt_tokens.shape, "cuda"
+        b, *_, device = *txt_tokens.shape, txt_tokens.device
 
         ret = self.fs2(txt_tokens, mel2ph, spk_embed, ref_mels, f0, uv, energy,
                        skip_decoder=True, infer=True, **kwargs)
@@ -316,4 +320,4 @@ class OfflineGaussianDiffusion(GaussianDiffusion):
                 x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond)
             x = x[:, 0].transpose(1, 2)
             ret['mel_out'] = self.denorm_spec(x)
-        return ret
+        return ret
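
The commit makes two behavioral changes: the sampling device is now taken from the input tensor instead of being hard-coded to "cuda", and the accelerated sampler's previous cumulative alpha is forced to 1 when the current step is within one interval of t == 0, rather than clamping the index into alphas_cumprod[0]. The following is a minimal standalone sketch of both changes; the extract helper and the toy alphas_cumprod schedule below are simplified stand-ins for illustration, not the repository's implementations.

# Standalone sketch (assumptions: toy schedule and simplified extract helper).
import torch

def extract(a, t, x_shape):
    # Simplified diffusion helper: gather a[t] and reshape so it broadcasts
    # against a batch of spectrogram-shaped tensors.
    b = t.shape[0]
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))

# 1) Device handling: unpack the batch size and take the device from the tensor
#    itself, so CPU-only inference also works (the old code hard-coded "cuda").
x = torch.randn(2, 1, 80, 100)           # (B, 1, mel_bins, frames), on CPU here
b, *_, device = *x.shape, x.device       # b == 2, device == x.device

# 2) Boundary guard in the accelerated sampler: when the current step t is
#    smaller than the sampling interval, the "previous" step lies before t == 0,
#    so the cumulative alpha is taken as 1 instead of clamping to alphas_cumprod[0].
alphas_cumprod = torch.linspace(0.999, 0.01, steps=100)   # toy schedule
interval = 10
t = torch.full((b,), 4, dtype=torch.long)                 # t < interval

a_t = extract(alphas_cumprod, t, x.shape)
if t[0] < interval:
    a_prev = torch.ones_like(a_t)        # previous step is x_0: cumprod == 1
else:
    a_prev = extract(alphas_cumprod, torch.max(t - interval, torch.zeros_like(t)), x.shape)

print(device, a_prev.flatten())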