SWivid committed
Commit 488d746 · 1 Parent(s): 572d786

0.4.5 fix extremely short case where the length of text_seq exceeds audio_seq, causing a wrong cond_mask

Files changed (2)
  1. pyproject.toml +1 -1
  2. src/f5_tts/model/cfm.py +3 -5
pyproject.toml CHANGED
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "f5-tts"
-version = "0.4.4"
+version = "0.4.5"
 description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
 readme = "README.md"
 license = {text = "MIT License"}
src/f5_tts/model/cfm.py CHANGED
@@ -120,10 +120,6 @@ class CFM(nn.Module):
                 text = list_str_to_tensor(text).to(device)
             assert text.shape[0] == batch
 
-        if exists(text):
-            text_lens = (text != -1).sum(dim=-1)
-            lens = torch.maximum(text_lens, lens)  # make sure lengths are at least those of the text characters
-
         # duration
 
         cond_mask = lens_to_mask(lens)
@@ -133,7 +129,9 @@ class CFM(nn.Module):
         if isinstance(duration, int):
             duration = torch.full((batch,), duration, device=device, dtype=torch.long)
 
-        duration = torch.maximum(lens + 1, duration)  # just add one token so something is generated
+        duration = torch.maximum(
+            torch.maximum((text != -1).sum(dim=-1), lens) + 1, duration
+        )  # duration at least text/audio prompt length plus one token, so something is generated
         duration = duration.clamp(max=max_duration)
         max_duration = duration.amax()
 
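
For readers skimming the diff: the old code bumped lens up to the text length before cond_mask = lens_to_mask(lens) was built, so when the text sequence was longer than a very short audio prompt, the conditioning mask marked frames beyond the real prompt as conditioning. The fix builds cond_mask from the audio length alone and only uses the text length to bound duration. Below is a minimal sketch of the difference; the lens_to_mask stand-in and the example tensors are assumptions for illustration, not the repo's actual helper.

import torch

def lens_to_mask(lens, length=None):
    # simplified stand-in (assumed semantics of the repo's lens_to_mask):
    # True for positions < lens, padded out to `length` columns
    length = int(lens.amax()) if length is None else length
    return torch.arange(length)[None, :] < lens[:, None]

audio_len = torch.tensor([3])           # very short audio prompt: 3 frames
text = torch.tensor([[5, 7, 9, 2, 4]])  # 5 text tokens, longer than the audio prompt
text_len = (text != -1).sum(dim=-1)     # tensor([5])

# old ordering: lens was raised to the text length *before* building cond_mask
old_lens = torch.maximum(text_len, audio_len)
print(lens_to_mask(old_lens, length=6))   # [[ True,  True,  True,  True,  True, False]] -> 5 frames treated as prompt (wrong)

# new ordering: cond_mask comes from the audio length only
print(lens_to_mask(audio_len, length=6))  # [[ True,  True,  True, False, False, False]] -> only the 3 real prompt frames

# the text length now only keeps duration from collapsing, mirroring the new torch.maximum call
requested = torch.tensor([4])             # hypothetical caller-requested duration
duration = torch.maximum(torch.maximum(text_len, audio_len) + 1, requested)
print(duration)                           # tensor([6]) -> at least prompt length + 1, so something is generated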