Commit 8c7215c · 1 Parent(s): c47c011 · committed by SWivid

correct imple., minor fixes
src/f5_tts/model/backbones/dit.py CHANGED
@@ -182,10 +182,16 @@ class DiT(nn.Module):
 
         return ckpt_forward
 
-    def clear_cache(self):
-        self.text_cond, self.text_uncond = None, None
-
-    def get_text_embed(self, text, seq_len, drop_text, cache):
+    def get_input_embed(
+        self,
+        x,  # b n d
+        cond,  # b n d
+        text,  # b nt
+        drop_audio_cond: bool = False,
+        drop_text: bool = False,
+        cache: bool = True,
+    ):
+        seq_len = x.shape[1]
         if cache:
             if drop_text:
                 if self.text_uncond is None:
@@ -197,7 +203,13 @@ class DiT(nn.Module):
                 text_embed = self.text_cond
         else:
             text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
-        return text_embed
+
+        x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
+
+        return x
+
+    def clear_cache(self):
+        self.text_cond, self.text_uncond = None, None
 
     def forward(
         self,
@@ -205,11 +217,11 @@
         cond: float["b n d"],  # masked cond audio  # noqa: F722
         text: int["b nt"],  # text  # noqa: F722
         time: float["b"] | float[""],  # time step  # noqa: F821 F722
-        drop_audio_cond,  # cfg for cond audio
-        drop_text,  # cfg for text
-        batch_cfg=False,  # batch cfg compute
         mask: bool["b n"] | None = None,  # noqa: F722
-        cache=False,
+        drop_audio_cond: bool = False,  # cfg for cond audio
+        drop_text: bool = False,  # cfg for text
+        cfg_infer: bool = False,  # cfg inference, pack cond & uncond forward
+        cache: bool = False,
     ):
         batch, seq_len = x.shape[0], x.shape[1]
         if time.ndim == 0:
@@ -217,21 +229,14 @@
 
         # t: conditioning time, text: text, x: noised audio + cond audio + text
         t = self.time_embed(time)
-        if batch_cfg:
-            text_embed_cond = self.get_text_embed(
-                text, seq_len, drop_text=False, cache=cache
-            )
-            text_embed_uncond = self.get_text_embed(
-                text, seq_len, drop_text=True, cache=cache
-            )
-            x_cond = self.input_embed(x, cond, text_embed_cond, drop_audio_cond=False)
-            x_uncond = self.input_embed(
-                x, cond, text_embed_uncond, drop_audio_cond=True
-            )
+        if cfg_infer:  # pack cond & uncond forward: b n d -> 2b n d
+            x_cond = self.get_input_embed(x, cond, text, drop_audio_cond=False, drop_text=False, cache=cache)
+            x_uncond = self.get_input_embed(x, cond, text, drop_audio_cond=True, drop_text=True, cache=cache)
             x = torch.cat((x_cond, x_uncond), dim=0)
+            t = torch.cat((t, t), dim=0)
+            mask = torch.cat((mask, mask), dim=0) if mask is not None else None
         else:
-            text_embed = self.get_text_embed(text, seq_len, drop_text, cache)
-            x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
+            x = self.get_input_embed(x, cond, text, drop_audio_cond=drop_audio_cond, drop_text=drop_text, cache=cache)
 
         rope = self.rotary_embed.forward_from_seq_len(seq_len)
 
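Note: the hunks above fold the separate get_text_embed/input_embed calls into a single get_input_embed and rename batch_cfg to cfg_infer, so the backbone itself packs the conditional and unconditional branches into one batched forward. A toy sketch of that packing pattern, using made-up names (packed_cfg_step, toy_net) rather than the actual DiT forward:

import torch

# Toy sketch (not F5-TTS code): concatenate cond and uncond inputs along the batch
# dim, run the network once, then split the output back for the CFG combination.
def packed_cfg_step(net, x, cond_embed, uncond_embed, cfg_strength):
    x2 = torch.cat((x, x), dim=0)                      # b n d -> 2b n d
    c2 = torch.cat((cond_embed, uncond_embed), dim=0)  # conditioning for both branches
    out = net(x2, c2)                                  # one forward covers cond and uncond
    pred, null_pred = torch.chunk(out, 2, dim=0)       # split back into two b n d tensors
    return pred + (pred - null_pred) * cfg_strength

toy_net = lambda x, c: x + c                           # stand-in for the backbone
x = torch.randn(2, 8, 4)
guided = packed_cfg_step(toy_net, x, torch.randn(2, 8, 4), torch.zeros(2, 8, 4), cfg_strength=2.0)
print(guided.shape)  # torch.Size([2, 8, 4])
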
src/f5_tts/model/backbones/mmdit.py CHANGED
@@ -141,6 +141,30 @@ class MMDiT(nn.Module):
         nn.init.constant_(self.proj_out.weight, 0)
         nn.init.constant_(self.proj_out.bias, 0)
 
+    def get_input_embed(
+        self,
+        x,  # b n d
+        cond,  # b n d
+        text,  # b nt
+        drop_audio_cond: bool = False,
+        drop_text: bool = False,
+        cache: bool = True,
+    ):
+        if cache:
+            if drop_text:
+                if self.text_uncond is None:
+                    self.text_uncond = self.text_embed(text, drop_text=True)
+                c = self.text_uncond
+            else:
+                if self.text_cond is None:
+                    self.text_cond = self.text_embed(text, drop_text=False)
+                c = self.text_cond
+        else:
+            c = self.text_embed(text, drop_text=drop_text)
+        x = self.audio_embed(x, cond, drop_audio_cond=drop_audio_cond)
+
+        return x, c
+
     def clear_cache(self):
         self.text_cond, self.text_uncond = None, None
 
@@ -150,10 +174,11 @@
         cond: float["b n d"],  # masked cond audio  # noqa: F722
         text: int["b nt"],  # text  # noqa: F722
         time: float["b"] | float[""],  # time step  # noqa: F821 F722
-        drop_audio_cond,  # cfg for cond audio
-        drop_text,  # cfg for text
         mask: bool["b n"] | None = None,  # noqa: F722
-        cache=False,
+        drop_audio_cond: bool = False,  # cfg for cond audio
+        drop_text: bool = False,  # cfg for text
+        cfg_infer: bool = False,  # cfg inference, pack cond & uncond forward
+        cache: bool = False,
     ):
         batch = x.shape[0]
         if time.ndim == 0:
@@ -161,18 +186,17 @@
 
         # t: conditioning (time), c: context (text + masked cond audio), x: noised input audio
         t = self.time_embed(time)
-        if cache:
-            if drop_text:
-                if self.text_uncond is None:
-                    self.text_uncond = self.text_embed(text, drop_text=True)
-                c = self.text_uncond
-            else:
-                if self.text_cond is None:
-                    self.text_cond = self.text_embed(text, drop_text=False)
-                c = self.text_cond
+        if cfg_infer:  # pack cond & uncond forward: b n d -> 2b n d
+            x_cond, c_cond = self.get_input_embed(x, cond, text, drop_audio_cond=False, drop_text=False, cache=cache)
+            x_uncond, c_uncond = self.get_input_embed(x, cond, text, drop_audio_cond=True, drop_text=True, cache=cache)
+            x = torch.cat((x_cond, x_uncond), dim=0)
+            c = torch.cat((c_cond, c_uncond), dim=0)
+            t = torch.cat((t, t), dim=0)
+            mask = torch.cat((mask, mask), dim=0) if mask is not None else None
         else:
-            c = self.text_embed(text, drop_text=drop_text)
-        x = self.audio_embed(x, cond, drop_audio_cond=drop_audio_cond)
+            x, c = self.get_input_embed(
+                x, cond, text, drop_audio_cond=drop_audio_cond, drop_text=drop_text, cache=cache
+            )
 
         seq_len = x.shape[1]
         text_len = text.shape[1]
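
Note: get_input_embed keeps the same per-utterance text-embedding cache as before (text_cond / text_uncond filled on first use, reset by clear_cache), since the text embedding does not depend on the ODE time step. A standalone sketch of that caching behaviour, with a made-up CachedTextEmbed class standing in for the module:

import torch

class CachedTextEmbed:
    """Illustrative cache wrapper (not the F5-TTS module)."""

    def __init__(self, embed_fn):
        self.embed_fn = embed_fn  # any callable: (text, drop_text) -> embedding
        self.text_cond = None
        self.text_uncond = None

    def get(self, text, drop_text, cache=True):
        if not cache:
            return self.embed_fn(text, drop_text=drop_text)
        if drop_text:
            if self.text_uncond is None:  # computed once, reused at every ODE step
                self.text_uncond = self.embed_fn(text, drop_text=True)
            return self.text_uncond
        if self.text_cond is None:
            self.text_cond = self.embed_fn(text, drop_text=False)
        return self.text_cond

    def clear_cache(self):  # call between utterances, as the backbones do
        self.text_cond, self.text_uncond = None, None

embed = CachedTextEmbed(lambda text, drop_text: torch.zeros_like(text) if drop_text else text + 1)
tokens = torch.zeros(2, 5, dtype=torch.long)
a = embed.get(tokens, drop_text=False)
b = embed.get(tokens, drop_text=False)
assert a is b  # second call hits the cache
embed.clear_cache()
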
src/f5_tts/model/backbones/unett.py CHANGED
@@ -178,6 +178,32 @@ class UNetT(nn.Module):
         self.norm_out = RMSNorm(dim)
         self.proj_out = nn.Linear(dim, mel_dim)
 
+    def get_input_embed(
+        self,
+        x,  # b n d
+        cond,  # b n d
+        text,  # b nt
+        drop_audio_cond: bool = False,
+        drop_text: bool = False,
+        cache: bool = True,
+    ):
+        seq_len = x.shape[1]
+        if cache:
+            if drop_text:
+                if self.text_uncond is None:
+                    self.text_uncond = self.text_embed(text, seq_len, drop_text=True)
+                text_embed = self.text_uncond
+            else:
+                if self.text_cond is None:
+                    self.text_cond = self.text_embed(text, seq_len, drop_text=False)
+                text_embed = self.text_cond
+        else:
+            text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
+
+        x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
+
+        return x
+
     def clear_cache(self):
         self.text_cond, self.text_uncond = None, None
 
@@ -187,10 +213,11 @@
         cond: float["b n d"],  # masked cond audio  # noqa: F722
         text: int["b nt"],  # text  # noqa: F722
         time: float["b"] | float[""],  # time step  # noqa: F821 F722
-        drop_audio_cond,  # cfg for cond audio
-        drop_text,  # cfg for text
         mask: bool["b n"] | None = None,  # noqa: F722
-        cache=False,
+        drop_audio_cond: bool = False,  # cfg for cond audio
+        drop_text: bool = False,  # cfg for text
+        cfg_infer: bool = False,  # cfg inference, pack cond & uncond forward
+        cache: bool = False,
     ):
         batch, seq_len = x.shape[0], x.shape[1]
         if time.ndim == 0:
@@ -198,18 +225,14 @@
 
         # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
         t = self.time_embed(time)
-        if cache:
-            if drop_text:
-                if self.text_uncond is None:
-                    self.text_uncond = self.text_embed(text, seq_len, drop_text=True)
-                text_embed = self.text_uncond
-            else:
-                if self.text_cond is None:
-                    self.text_cond = self.text_embed(text, seq_len, drop_text=False)
-                text_embed = self.text_cond
+        if cfg_infer:  # pack cond & uncond forward: b n d -> 2b n d
+            x_cond = self.get_input_embed(x, cond, text, drop_audio_cond=False, drop_text=False, cache=cache)
+            x_uncond = self.get_input_embed(x, cond, text, drop_audio_cond=True, drop_text=True, cache=cache)
+            x = torch.cat((x_cond, x_uncond), dim=0)
+            t = torch.cat((t, t), dim=0)
+            mask = torch.cat((mask, mask), dim=0) if mask is not None else None
         else:
-            text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
-            x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
+            x = self.get_input_embed(x, cond, text, drop_audio_cond=drop_audio_cond, drop_text=drop_text, cache=cache)
 
         # postfix time t to input x, [b n d] -> [b n+1 d]
         x = torch.cat([t.unsqueeze(1), x], dim=1)  # pack t to x
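
Note: the added t = torch.cat((t, t)) and mask duplication matter because packing cond and uncond doubles the batch, so every per-sample tensor fed alongside x must be doubled as well. A shape-only illustration with dummy tensors (not F5-TTS objects):

import torch

b, n, d = 2, 8, 4
x_cond = torch.randn(b, n, d)
x_uncond = torch.randn(b, n, d)
t = torch.randn(b, d)                       # per-sample time embedding
mask = torch.ones(b, n, dtype=torch.bool)   # per-sample padding mask

x = torch.cat((x_cond, x_uncond), dim=0)    # 2b n d
t = torch.cat((t, t), dim=0)                # 2b d, matches the doubled batch
mask = torch.cat((mask, mask), dim=0)       # 2b n
assert x.shape[0] == t.shape[0] == mask.shape[0] == 2 * b
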
src/f5_tts/model/cfm.py CHANGED
@@ -162,7 +162,7 @@ class CFM(nn.Module):
             # at each step, conditioning is fixed
             # step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond))
 
-            # predict flow
+            # predict flow (cond)
             if cfg_strength < 1e-5:
                 pred = self.transformer(
                     x=x,
@@ -176,18 +176,17 @@
                 )
                 return pred
 
-            pred_and_null = self.transformer(
+            # predict flow (cond and uncond), for classifier-free guidance
+            pred_cfg = self.transformer(
                 x=x,
                 cond=step_cond,
                 text=text,
                 time=t,
                 mask=mask,
-                drop_audio_cond=False,
-                drop_text=False,
-                batch_cfg=True,
+                cfg_infer=True,
                 cache=True,
             )
-            pred, null_pred = torch.chunk(pred_and_null, 2, dim=0)
+            pred, null_pred = torch.chunk(pred_cfg, 2, dim=0)
             return pred + (pred - null_pred) * cfg_strength
 
         # noise input
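
Note: the final return is the usual classifier-free guidance combination; with w = cfg_strength it equals (1 + w) * pred - w * null_pred. A quick numerical check of that equivalence with dummy tensors (illustration only, not code from this repo):

import torch

w = 2.0                             # cfg_strength
pred = torch.randn(2, 8, 4)         # conditional flow prediction
null_pred = torch.randn(2, 8, 4)    # unconditional (null) flow prediction

guided_a = pred + (pred - null_pred) * w   # form used in cfm.py
guided_b = (1 + w) * pred - w * null_pred  # equivalent rearrangement
assert torch.allclose(guided_a, guided_b, atol=1e-5)
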