Fix sample timesteps
Browse files- src/f5_tts/model/cfm.py +1 -1
src/f5_tts/model/cfm.py
CHANGED
@@ -193,7 +193,7 @@ class CFM(nn.Module):
|
|
193 |
y0 = (1 - t_start) * y0 + t_start * test_cond
|
194 |
steps = int(steps * (1 - t_start))
|
195 |
|
196 |
-
t = torch.linspace(t_start, 1, steps, device=self.device, dtype=step_cond.dtype)
|
197 |
if sway_sampling_coef is not None:
|
198 |
t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
|
199 |
|
|
|
193 |
y0 = (1 - t_start) * y0 + t_start * test_cond
|
194 |
steps = int(steps * (1 - t_start))
|
195 |
|
196 |
+
t = torch.linspace(t_start, 1, steps + 1, device=self.device, dtype=step_cond.dtype)
|
197 |
if sway_sampling_coef is not None:
|
198 |
t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
|
199 |
|