Spaces:
Paused
Paused
Update usr/diff/shallow_diffusion_tts.py
Browse files
usr/diff/shallow_diffusion_tts.py
CHANGED
@@ -16,6 +16,7 @@ from modules.diffsinger_midi.fs2 import FastSpeech2MIDI
|
|
16 |
from utils.hparams import hparams
|
17 |
|
18 |
|
|
|
19 |
def exists(x):
|
20 |
return x is not None
|
21 |
|
@@ -157,7 +158,7 @@ class GaussianDiffusion(nn.Module):
|
|
157 |
|
158 |
@torch.no_grad()
|
159 |
def p_sample(self, x, t, cond, clip_denoised=True, repeat_noise=False):
|
160 |
-
b, *_, device = *x.shape,
|
161 |
model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, cond=cond, clip_denoised=clip_denoised)
|
162 |
noise = noise_like(x.shape, device, repeat_noise)
|
163 |
# no noise when t == 0
|
@@ -172,7 +173,10 @@ class GaussianDiffusion(nn.Module):
|
|
172 |
|
173 |
def get_x_pred(x, noise_t, t):
|
174 |
a_t = extract(self.alphas_cumprod, t, x.shape)
|
175 |
-
|
|
|
|
|
|
|
176 |
a_t_sq, a_prev_sq = a_t.sqrt(), a_prev.sqrt()
|
177 |
|
178 |
x_delta = (a_prev - a_t) * ((1 / (a_t_sq * (a_t_sq + a_prev_sq))) * x - 1 / (a_t_sq * (((1 - a_prev) * a_t).sqrt() + ((1 - a_t) * a_prev).sqrt())) * noise_t)
|
@@ -228,7 +232,7 @@ class GaussianDiffusion(nn.Module):
|
|
228 |
|
229 |
def forward(self, txt_tokens, mel2ph=None, spk_embed=None,
|
230 |
ref_mels=None, f0=None, uv=None, energy=None, infer=False, **kwargs):
|
231 |
-
b, *_, device = *txt_tokens.shape,
|
232 |
ret = self.fs2(txt_tokens, mel2ph, spk_embed, ref_mels, f0, uv, energy,
|
233 |
skip_decoder=(not infer), infer=infer, **kwargs)
|
234 |
cond = ret['decoder_inp'].transpose(1, 2)
|
@@ -287,7 +291,7 @@ class GaussianDiffusion(nn.Module):
|
|
287 |
class OfflineGaussianDiffusion(GaussianDiffusion):
|
288 |
def forward(self, txt_tokens, mel2ph=None, spk_embed=None,
|
289 |
ref_mels=None, f0=None, uv=None, energy=None, infer=False, **kwargs):
|
290 |
-
b, *_, device = *txt_tokens.shape,
|
291 |
|
292 |
ret = self.fs2(txt_tokens, mel2ph, spk_embed, ref_mels, f0, uv, energy,
|
293 |
skip_decoder=True, infer=True, **kwargs)
|
@@ -316,4 +320,4 @@ class OfflineGaussianDiffusion(GaussianDiffusion):
|
|
316 |
x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond)
|
317 |
x = x[:, 0].transpose(1, 2)
|
318 |
ret['mel_out'] = self.denorm_spec(x)
|
319 |
-
return ret
|
|
|
16 |
from utils.hparams import hparams
|
17 |
|
18 |
|
19 |
+
|
20 |
def exists(x):
|
21 |
return x is not None
|
22 |
|
|
|
158 |
|
159 |
@torch.no_grad()
|
160 |
def p_sample(self, x, t, cond, clip_denoised=True, repeat_noise=False):
|
161 |
+
b, *_, device = *x.shape, x.device
|
162 |
model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, cond=cond, clip_denoised=clip_denoised)
|
163 |
noise = noise_like(x.shape, device, repeat_noise)
|
164 |
# no noise when t == 0
|
|
|
173 |
|
174 |
def get_x_pred(x, noise_t, t):
|
175 |
a_t = extract(self.alphas_cumprod, t, x.shape)
|
176 |
+
if t[0] < interval:
|
177 |
+
a_prev = torch.ones_like(a_t)
|
178 |
+
else:
|
179 |
+
a_prev = extract(self.alphas_cumprod, torch.max(t-interval, torch.zeros_like(t)), x.shape)
|
180 |
a_t_sq, a_prev_sq = a_t.sqrt(), a_prev.sqrt()
|
181 |
|
182 |
x_delta = (a_prev - a_t) * ((1 / (a_t_sq * (a_t_sq + a_prev_sq))) * x - 1 / (a_t_sq * (((1 - a_prev) * a_t).sqrt() + ((1 - a_t) * a_prev).sqrt())) * noise_t)
|
|
|
232 |
|
233 |
def forward(self, txt_tokens, mel2ph=None, spk_embed=None,
|
234 |
ref_mels=None, f0=None, uv=None, energy=None, infer=False, **kwargs):
|
235 |
+
b, *_, device = *txt_tokens.shape, txt_tokens.device
|
236 |
ret = self.fs2(txt_tokens, mel2ph, spk_embed, ref_mels, f0, uv, energy,
|
237 |
skip_decoder=(not infer), infer=infer, **kwargs)
|
238 |
cond = ret['decoder_inp'].transpose(1, 2)
|
|
|
291 |
class OfflineGaussianDiffusion(GaussianDiffusion):
|
292 |
def forward(self, txt_tokens, mel2ph=None, spk_embed=None,
|
293 |
ref_mels=None, f0=None, uv=None, energy=None, infer=False, **kwargs):
|
294 |
+
b, *_, device = *txt_tokens.shape, txt_tokens.device
|
295 |
|
296 |
ret = self.fs2(txt_tokens, mel2ph, spk_embed, ref_mels, f0, uv, energy,
|
297 |
skip_decoder=True, infer=True, **kwargs)
|
|
|
320 |
x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond)
|
321 |
x = x[:, 0].transpose(1, 2)
|
322 |
ret['mel_out'] = self.denorm_spec(x)
|
323 |
+
return ret
|