fix no_ref_audio in cfm.py
Browse files - src/f5_tts/model/cfm.py +3 -4
src/f5_tts/model/cfm.py
CHANGED
@@ -142,6 +142,9 @@ class CFM(nn.Module):
|
|
142 |
test_cond = F.pad(cond, (0, 0, cond_seq_len, max_duration - 2 * cond_seq_len), value=0.0)
|
143 |
|
144 |
cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value=0.0)
|
|
|
|
|
|
|
145 |
cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False)
|
146 |
cond_mask = cond_mask.unsqueeze(-1)
|
147 |
step_cond = torch.where(
|
@@ -153,10 +156,6 @@ class CFM(nn.Module):
|
|
153 |
else: # save memory and speed up, as single inference need no mask currently
|
154 |
mask = None
|
155 |
|
156 |
-
# test for no ref audio
|
157 |
-
if no_ref_audio:
|
158 |
-
cond = torch.zeros_like(cond)
|
159 |
-
|
160 |
# neural ode
|
161 |
|
162 |
def fn(t, x):
|
|
|
142 |
test_cond = F.pad(cond, (0, 0, cond_seq_len, max_duration - 2 * cond_seq_len), value=0.0)
|
143 |
|
144 |
cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value=0.0)
|
145 |
+
if no_ref_audio:
|
146 |
+
cond = torch.zeros_like(cond)
|
147 |
+
|
148 |
cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False)
|
149 |
cond_mask = cond_mask.unsqueeze(-1)
|
150 |
step_cond = torch.where(
|
|
|
156 |
else: # save memory and speed up, as single inference need no mask currently
|
157 |
mask = None
|
158 |
|
|
|
|
|
|
|
|
|
159 |
# neural ode
|
160 |
|
161 |
def fn(t, x):
|