starkwj7 committed
Commit c47c011 · 1 Parent(s): 1d923b1

Batch cfg DiT forward

src/f5_tts/model/backbones/dit.py CHANGED
@@ -185,6 +185,20 @@ class DiT(nn.Module):
     def clear_cache(self):
         self.text_cond, self.text_uncond = None, None
 
+    def get_text_embed(self, text, seq_len, drop_text, cache):
+        if cache:
+            if drop_text:
+                if self.text_uncond is None:
+                    self.text_uncond = self.text_embed(text, seq_len, drop_text=True)
+                text_embed = self.text_uncond
+            else:
+                if self.text_cond is None:
+                    self.text_cond = self.text_embed(text, seq_len, drop_text=False)
+                text_embed = self.text_cond
+        else:
+            text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
+        return text_embed
+
     def forward(
         self,
         x: float["b n d"],  # nosied input audio  # noqa: F722
@@ -193,6 +207,7 @@ class DiT(nn.Module):
         time: float["b"] | float[""],  # time step  # noqa: F821 F722
         drop_audio_cond,  # cfg for cond audio
         drop_text,  # cfg for text
+        batch_cfg=False,  # batch cfg compute
         mask: bool["b n"] | None = None,  # noqa: F722
         cache=False,
     ):
@@ -202,18 +217,21 @@ class DiT(nn.Module):
 
         # t: conditioning time, text: text, x: noised audio + cond audio + text
         t = self.time_embed(time)
-        if cache:
-            if drop_text:
-                if self.text_uncond is None:
-                    self.text_uncond = self.text_embed(text, seq_len, drop_text=True)
-                text_embed = self.text_uncond
-            else:
-                if self.text_cond is None:
-                    self.text_cond = self.text_embed(text, seq_len, drop_text=False)
-                text_embed = self.text_cond
+        if batch_cfg:
+            text_embed_cond = self.get_text_embed(
+                text, seq_len, drop_text=False, cache=cache
+            )
+            text_embed_uncond = self.get_text_embed(
+                text, seq_len, drop_text=True, cache=cache
+            )
+            x_cond = self.input_embed(x, cond, text_embed_cond, drop_audio_cond=False)
+            x_uncond = self.input_embed(
+                x, cond, text_embed_uncond, drop_audio_cond=True
+            )
+            x = torch.cat((x_cond, x_uncond), dim=0)
         else:
-            text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
-        x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
+            text_embed = self.get_text_embed(text, seq_len, drop_text, cache)
+            x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
 
         rope = self.rotary_embed.forward_from_seq_len(seq_len)
 
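The refactor above pulls the cond/uncond text-embedding cache into `get_text_embed`, and the new `batch_cfg` flag lets a single `forward` evaluate both classifier-free-guidance branches by stacking them along the batch dimension. Below is a minimal sketch of that pattern using a stand-in `nn.TransformerEncoderLayer` in place of the real DiT blocks; everything in it is illustrative, not the F5-TTS API. It checks that one fused forward over a 2b batch matches two separate forwards, which holds because attention and the MLP act per sample.

import torch
import torch.nn as nn

torch.manual_seed(0)

b, n, d = 2, 16, 64
block = nn.TransformerEncoderLayer(d_model=d, nhead=4, batch_first=True)
block.eval()  # disable dropout so both paths are deterministic

x = torch.randn(b, n, d)            # noised audio features
text_cond = torch.randn(b, n, d)    # stands in for get_text_embed(..., drop_text=False)
text_uncond = torch.zeros(b, n, d)  # stands in for get_text_embed(..., drop_text=True)

with torch.no_grad():
    # batch_cfg path: embed both branches, run one forward over a 2b batch
    x_cond, x_uncond = x + text_cond, x + text_uncond  # stands in for input_embed
    out = block(torch.cat((x_cond, x_uncond), dim=0))
    pred, null_pred = torch.chunk(out, 2, dim=0)

    # old path: two separate forwards, one per CFG branch
    assert torch.allclose(pred, block(x_cond), atol=1e-5)
    assert torch.allclose(null_pred, block(x_uncond), atol=1e-5)

The fused call does the same arithmetic as two separate calls, but it halves the per-step kernel-launch and Python overhead and keeps the GPU better utilized at small batch sizes. Note that `get_text_embed` still honors `cache=True`, so both text embeddings are computed once and reused across sampling steps.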
src/f5_tts/model/cfm.py CHANGED
@@ -163,15 +163,31 @@ class CFM(nn.Module):
         # step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond))
 
         # predict flow
-        pred = self.transformer(
-            x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=False, drop_text=False, cache=True
-        )
         if cfg_strength < 1e-5:
+            pred = self.transformer(
+                x=x,
+                cond=step_cond,
+                text=text,
+                time=t,
+                mask=mask,
+                drop_audio_cond=False,
+                drop_text=False,
+                cache=True,
+            )
             return pred
 
-        null_pred = self.transformer(
-            x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=True, drop_text=True, cache=True
+        pred_and_null = self.transformer(
+            x=x,
+            cond=step_cond,
+            text=text,
+            time=t,
+            mask=mask,
+            drop_audio_cond=False,
+            drop_text=False,
+            batch_cfg=True,
+            cache=True,
         )
+        pred, null_pred = torch.chunk(pred_and_null, 2, dim=0)
         return pred + (pred - null_pred) * cfg_strength
 
     # noise input
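With `batch_cfg=True`, the sampling step above issues one transformer call instead of two; `torch.chunk` then splits the doubled batch back into the conditional and null predictions before the usual CFG extrapolation. A tiny numeric illustration of the split-and-combine step (values made up):

import torch

cfg_strength = 2.0
pred_and_null = torch.tensor([[1.0, 2.0],   # conditional prediction (first half of batch)
                              [0.5, 1.0]])  # null prediction (second half of batch)
pred, null_pred = torch.chunk(pred_and_null, 2, dim=0)
guided = pred + (pred - null_pred) * cfg_strength
print(guided)  # tensor([[2., 4.]]): pred pushed away from the null prediction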