Spaces:
Paused
Paused
Update usr/diff/shallow_diffusion_tts.py
Browse files
usr/diff/shallow_diffusion_tts.py
CHANGED
@@ -157,7 +157,7 @@ class GaussianDiffusion(nn.Module):
|
|
157 |
|
158 |
@torch.no_grad()
|
159 |
def p_sample(self, x, t, cond, clip_denoised=True, repeat_noise=False):
|
160 |
-
b, *_, device = *x.shape,
|
161 |
model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, cond=cond, clip_denoised=clip_denoised)
|
162 |
noise = noise_like(x.shape, device, repeat_noise)
|
163 |
# no noise when t == 0
|
@@ -172,10 +172,7 @@ class GaussianDiffusion(nn.Module):
|
|
172 |
|
173 |
def get_x_pred(x, noise_t, t):
|
174 |
a_t = extract(self.alphas_cumprod, t, x.shape)
|
175 |
-
|
176 |
-
a_prev = torch.ones_like(a_t)
|
177 |
-
else:
|
178 |
-
a_prev = extract(self.alphas_cumprod, torch.max(t-interval, torch.zeros_like(t)), x.shape)
|
179 |
a_t_sq, a_prev_sq = a_t.sqrt(), a_prev.sqrt()
|
180 |
|
181 |
x_delta = (a_prev - a_t) * ((1 / (a_t_sq * (a_t_sq + a_prev_sq))) * x - 1 / (a_t_sq * (((1 - a_prev) * a_t).sqrt() + ((1 - a_t) * a_prev).sqrt())) * noise_t)
|
@@ -231,7 +228,7 @@ class GaussianDiffusion(nn.Module):
|
|
231 |
|
232 |
def forward(self, txt_tokens, mel2ph=None, spk_embed=None,
|
233 |
ref_mels=None, f0=None, uv=None, energy=None, infer=False, **kwargs):
|
234 |
-
b, *_, device = *txt_tokens.shape,
|
235 |
ret = self.fs2(txt_tokens, mel2ph, spk_embed, ref_mels, f0, uv, energy,
|
236 |
skip_decoder=(not infer), infer=infer, **kwargs)
|
237 |
cond = ret['decoder_inp'].transpose(1, 2)
|
@@ -290,7 +287,7 @@ class GaussianDiffusion(nn.Module):
|
|
290 |
class OfflineGaussianDiffusion(GaussianDiffusion):
|
291 |
def forward(self, txt_tokens, mel2ph=None, spk_embed=None,
|
292 |
ref_mels=None, f0=None, uv=None, energy=None, infer=False, **kwargs):
|
293 |
-
b, *_, device = *txt_tokens.shape,
|
294 |
|
295 |
ret = self.fs2(txt_tokens, mel2ph, spk_embed, ref_mels, f0, uv, energy,
|
296 |
skip_decoder=True, infer=True, **kwargs)
|
@@ -319,4 +316,4 @@ class OfflineGaussianDiffusion(GaussianDiffusion):
|
|
319 |
x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond)
|
320 |
x = x[:, 0].transpose(1, 2)
|
321 |
ret['mel_out'] = self.denorm_spec(x)
|
322 |
-
return ret
|
|
|
157 |
|
158 |
@torch.no_grad()
|
159 |
def p_sample(self, x, t, cond, clip_denoised=True, repeat_noise=False):
|
160 |
+
b, *_, device = *x.shape, "cuda"
|
161 |
model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, cond=cond, clip_denoised=clip_denoised)
|
162 |
noise = noise_like(x.shape, device, repeat_noise)
|
163 |
# no noise when t == 0
|
|
|
172 |
|
173 |
def get_x_pred(x, noise_t, t):
|
174 |
a_t = extract(self.alphas_cumprod, t, x.shape)
|
175 |
+
a_prev = extract(self.alphas_cumprod, torch.max(t-interval, torch.zeros_like(t)), x.shape)
|
|
|
|
|
|
|
176 |
a_t_sq, a_prev_sq = a_t.sqrt(), a_prev.sqrt()
|
177 |
|
178 |
x_delta = (a_prev - a_t) * ((1 / (a_t_sq * (a_t_sq + a_prev_sq))) * x - 1 / (a_t_sq * (((1 - a_prev) * a_t).sqrt() + ((1 - a_t) * a_prev).sqrt())) * noise_t)
|
|
|
228 |
|
229 |
def forward(self, txt_tokens, mel2ph=None, spk_embed=None,
|
230 |
ref_mels=None, f0=None, uv=None, energy=None, infer=False, **kwargs):
|
231 |
+
b, *_, device = *txt_tokens.shape, "cuda"
|
232 |
ret = self.fs2(txt_tokens, mel2ph, spk_embed, ref_mels, f0, uv, energy,
|
233 |
skip_decoder=(not infer), infer=infer, **kwargs)
|
234 |
cond = ret['decoder_inp'].transpose(1, 2)
|
|
|
287 |
class OfflineGaussianDiffusion(GaussianDiffusion):
|
288 |
def forward(self, txt_tokens, mel2ph=None, spk_embed=None,
|
289 |
ref_mels=None, f0=None, uv=None, energy=None, infer=False, **kwargs):
|
290 |
+
b, *_, device = *txt_tokens.shape, "cuda"
|
291 |
|
292 |
ret = self.fs2(txt_tokens, mel2ph, spk_embed, ref_mels, f0, uv, energy,
|
293 |
skip_decoder=True, infer=True, **kwargs)
|
|
|
316 |
x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond)
|
317 |
x = x[:, 0].transpose(1, 2)
|
318 |
ret['mel_out'] = self.denorm_spec(x)
|
319 |
+
return ret
|