update dit checkpoint_activations and fix #399 #400
src/f5_tts/configs/F5TTS_Base_train.yaml
CHANGED

```diff
@@ -28,6 +28,7 @@ model:
     ff_mult: 2
     text_dim: 512
     conv_layers: 4
+    checkpoint_activations: False # recompute activations and save memory for extra compute
   mel_spec:
     target_sample_rate: 24000
     n_mel_channels: 100
```
src/f5_tts/configs/F5TTS_Small_train.yaml
CHANGED

```diff
@@ -28,6 +28,7 @@ model:
     ff_mult: 2
     text_dim: 512
     conv_layers: 4
+    checkpoint_activations: False # recompute activations and save memory for extra compute
   mel_spec:
     target_sample_rate: 24000
     n_mel_channels: 100
```
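Both training configs gain the same flag under the model arch section. To enable activation checkpointing during training, set `checkpoint_activations: True`; the default `False` keeps the previous behavior (activations stored, no recomputation).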
src/f5_tts/model/backbones/dit.py
CHANGED

```diff
@@ -105,6 +105,7 @@ class DiT(nn.Module):
         text_dim=None,
         conv_layers=0,
         long_skip_connection=False,
+        checkpoint_activations=False,
     ):
         super().__init__()

@@ -127,6 +128,17 @@ class DiT(nn.Module):
         self.norm_out = AdaLayerNormZero_Final(dim)  # final modulation
         self.proj_out = nn.Linear(dim, mel_dim)

+        self.checkpoint_activations = checkpoint_activations
+
+    def ckpt_wrapper(self, module):
+        """Code from https://github.com/chuanyangjin/fast-DiT/blob/1a8ecce58f346f877749f2dc67cdb190d295e4dc/models.py#L233-L237"""
+
+        def ckpt_forward(*inputs):
+            outputs = module(*inputs)
+            return outputs
+
+        return ckpt_forward
+
     def forward(
         self,
         x: float["b n d"],  # nosied input audio  # noqa: F722

@@ -152,7 +164,10 @@ class DiT(nn.Module):
         residual = x

         for block in self.transformer_blocks:
-            x = block(x, t, mask=mask, rope=rope)
+            if self.checkpoint_activations:
+                x = torch.utils.checkpoint.checkpoint(self.ckpt_wrapper(block), x, t, mask, rope)
+            else:
+                x = block(x, t, mask=mask, rope=rope)

         if self.long_skip_connection is not None:
             x = self.long_skip_connection(torch.cat((x, residual), dim=-1))
```
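For context, here is a minimal self-contained sketch (not from the PR; `ToyBlock` and its `scale` argument are illustrative stand-ins) of the pattern the diff adopts from fast-DiT: `torch.utils.checkpoint.checkpoint` takes a plain callable plus positional arguments, which is why the block's keyword arguments `mask` and `rope` are passed positionally and why `ckpt_wrapper` hands it a closure over the module.

```python
import torch
import torch.nn as nn
import torch.utils.checkpoint


class ToyBlock(nn.Module):
    # stand-in for a DiT transformer block; takes an extra positional
    # argument the way the real block takes t / mask / rope
    def __init__(self, dim):
        super().__init__()
        self.ff = nn.Sequential(nn.Linear(dim, dim * 2), nn.GELU(), nn.Linear(dim * 2, dim))

    def forward(self, x, scale):
        return x + scale * self.ff(x)


def ckpt_wrapper(module):
    # same shape as the PR's DiT.ckpt_wrapper: checkpoint() wants a plain
    # callable, and all arguments must be passed positionally
    def ckpt_forward(*inputs):
        return module(*inputs)

    return ckpt_forward


block = ToyBlock(dim=64)
x = torch.randn(2, 16, 64, requires_grad=True)
scale = torch.tensor(0.5)

# activations inside `block` are not stored during the forward pass;
# they are recomputed when backward() reaches this segment
y = torch.utils.checkpoint.checkpoint(ckpt_wrapper(block), x, scale, use_reentrant=False)
y.sum().backward()
print(x.grad.shape)  # torch.Size([2, 16, 64])
```

The trade-off is roughly one extra forward pass per checkpointed block during backward, in exchange for not storing its intermediate activations. Note the PR calls `checkpoint` without `use_reentrant`, which newer PyTorch versions warn about; the sketch passes `use_reentrant=False` explicitly.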