0.4.5: fix the extremely short case where the length of text_seq exceeds the length of audio_seq, causing a wrong cond_mask
Browse files
- pyproject.toml +1 -1
- src/f5_tts/model/cfm.py +3 -5
pyproject.toml
CHANGED
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|
4 |
|
5 |
[project]
|
6 |
name = "f5-tts"
|
7 |
-
version = "0.4.4"
|
8 |
description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
|
9 |
readme = "README.md"
|
10 |
license = {text = "MIT License"}
|
|
|
4 |
|
5 |
[project]
|
6 |
name = "f5-tts"
|
7 |
+
version = "0.4.5"
|
8 |
description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
|
9 |
readme = "README.md"
|
10 |
license = {text = "MIT License"}
|
src/f5_tts/model/cfm.py
CHANGED
@@ -120,10 +120,6 @@ class CFM(nn.Module):
|
|
120 |
text = list_str_to_tensor(text).to(device)
|
121 |
assert text.shape[0] == batch
|
122 |
|
123 |
-
if exists(text):
|
124 |
-
text_lens = (text != -1).sum(dim=-1)
|
125 |
-
lens = torch.maximum(text_lens, lens) # make sure lengths are at least those of the text characters
|
126 |
-
|
127 |
# duration
|
128 |
|
129 |
cond_mask = lens_to_mask(lens)
|
@@ -133,7 +129,9 @@ class CFM(nn.Module):
|
|
133 |
if isinstance(duration, int):
|
134 |
duration = torch.full((batch,), duration, device=device, dtype=torch.long)
|
135 |
|
136 |
-
duration = torch.maximum(lens + 1, duration)
|
|
|
|
|
137 |
duration = duration.clamp(max=max_duration)
|
138 |
max_duration = duration.amax()
|
139 |
|
|
|
120 |
text = list_str_to_tensor(text).to(device)
|
121 |
assert text.shape[0] == batch
|
122 |
|
|
|
|
|
|
|
|
|
123 |
# duration
|
124 |
|
125 |
cond_mask = lens_to_mask(lens)
|
|
|
129 |
if isinstance(duration, int):
|
130 |
duration = torch.full((batch,), duration, device=device, dtype=torch.long)
|
131 |
|
132 |
+
duration = torch.maximum(
|
133 |
+
torch.maximum((text != -1).sum(dim=-1), lens) + 1, duration
|
134 |
+
) # duration at least text/audio prompt length plus one token, so something is generated
|
135 |
duration = duration.clamp(max=max_duration)
|
136 |
max_duration = duration.amax()
|
137 |
|