kevinwang676 committed on
Commit
8a184bc
·
verified ·
1 Parent(s): 1db33f9

Update usr/diff/shallow_diffusion_tts.py

Browse files
Files changed (1) hide show
  1. usr/diff/shallow_diffusion_tts.py +5 -8
usr/diff/shallow_diffusion_tts.py CHANGED
@@ -157,7 +157,7 @@ class GaussianDiffusion(nn.Module):
157
 
158
  @torch.no_grad()
159
  def p_sample(self, x, t, cond, clip_denoised=True, repeat_noise=False):
160
- b, *_, device = *x.shape, x.device
161
  model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, cond=cond, clip_denoised=clip_denoised)
162
  noise = noise_like(x.shape, device, repeat_noise)
163
  # no noise when t == 0
@@ -172,10 +172,7 @@ class GaussianDiffusion(nn.Module):
172
 
173
  def get_x_pred(x, noise_t, t):
174
  a_t = extract(self.alphas_cumprod, t, x.shape)
175
- if t[0] < interval:
176
- a_prev = torch.ones_like(a_t)
177
- else:
178
- a_prev = extract(self.alphas_cumprod, torch.max(t-interval, torch.zeros_like(t)), x.shape)
179
  a_t_sq, a_prev_sq = a_t.sqrt(), a_prev.sqrt()
180
 
181
  x_delta = (a_prev - a_t) * ((1 / (a_t_sq * (a_t_sq + a_prev_sq))) * x - 1 / (a_t_sq * (((1 - a_prev) * a_t).sqrt() + ((1 - a_t) * a_prev).sqrt())) * noise_t)
@@ -231,7 +228,7 @@ class GaussianDiffusion(nn.Module):
231
 
232
  def forward(self, txt_tokens, mel2ph=None, spk_embed=None,
233
  ref_mels=None, f0=None, uv=None, energy=None, infer=False, **kwargs):
234
- b, *_, device = *txt_tokens.shape, txt_tokens.device
235
  ret = self.fs2(txt_tokens, mel2ph, spk_embed, ref_mels, f0, uv, energy,
236
  skip_decoder=(not infer), infer=infer, **kwargs)
237
  cond = ret['decoder_inp'].transpose(1, 2)
@@ -290,7 +287,7 @@ class GaussianDiffusion(nn.Module):
290
  class OfflineGaussianDiffusion(GaussianDiffusion):
291
  def forward(self, txt_tokens, mel2ph=None, spk_embed=None,
292
  ref_mels=None, f0=None, uv=None, energy=None, infer=False, **kwargs):
293
- b, *_, device = *txt_tokens.shape, txt_tokens.device
294
 
295
  ret = self.fs2(txt_tokens, mel2ph, spk_embed, ref_mels, f0, uv, energy,
296
  skip_decoder=True, infer=True, **kwargs)
@@ -319,4 +316,4 @@ class OfflineGaussianDiffusion(GaussianDiffusion):
319
  x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond)
320
  x = x[:, 0].transpose(1, 2)
321
  ret['mel_out'] = self.denorm_spec(x)
322
- return ret
 
157
 
158
  @torch.no_grad()
159
  def p_sample(self, x, t, cond, clip_denoised=True, repeat_noise=False):
160
+ b, *_, device = *x.shape, "cuda"
161
  model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, cond=cond, clip_denoised=clip_denoised)
162
  noise = noise_like(x.shape, device, repeat_noise)
163
  # no noise when t == 0
 
172
 
173
  def get_x_pred(x, noise_t, t):
174
  a_t = extract(self.alphas_cumprod, t, x.shape)
175
+ a_prev = extract(self.alphas_cumprod, torch.max(t-interval, torch.zeros_like(t)), x.shape)
 
 
 
176
  a_t_sq, a_prev_sq = a_t.sqrt(), a_prev.sqrt()
177
 
178
  x_delta = (a_prev - a_t) * ((1 / (a_t_sq * (a_t_sq + a_prev_sq))) * x - 1 / (a_t_sq * (((1 - a_prev) * a_t).sqrt() + ((1 - a_t) * a_prev).sqrt())) * noise_t)
 
228
 
229
  def forward(self, txt_tokens, mel2ph=None, spk_embed=None,
230
  ref_mels=None, f0=None, uv=None, energy=None, infer=False, **kwargs):
231
+ b, *_, device = *txt_tokens.shape, "cuda"
232
  ret = self.fs2(txt_tokens, mel2ph, spk_embed, ref_mels, f0, uv, energy,
233
  skip_decoder=(not infer), infer=infer, **kwargs)
234
  cond = ret['decoder_inp'].transpose(1, 2)
 
287
  class OfflineGaussianDiffusion(GaussianDiffusion):
288
  def forward(self, txt_tokens, mel2ph=None, spk_embed=None,
289
  ref_mels=None, f0=None, uv=None, energy=None, infer=False, **kwargs):
290
+ b, *_, device = *txt_tokens.shape, "cuda"
291
 
292
  ret = self.fs2(txt_tokens, mel2ph, spk_embed, ref_mels, f0, uv, energy,
293
  skip_decoder=True, infer=True, **kwargs)
 
316
  x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond)
317
  x = x[:, 0].transpose(1, 2)
318
  ret['mel_out'] = self.denorm_spec(x)
319
+ return ret