zkniu committed
Commit 0e60b0d · 1 Parent(s): d00a89e

update dit checkpoint_activations and fix #399 #400

src/f5_tts/configs/F5TTS_Base_train.yaml CHANGED
@@ -28,6 +28,7 @@ model:
     ff_mult: 2
     text_dim: 512
     conv_layers: 4
+    checkpoint_activations: False  # recompute activations and save memory for extra compute
   mel_spec:
     target_sample_rate: 24000
     n_mel_channels: 100
src/f5_tts/configs/F5TTS_Small_train.yaml CHANGED
@@ -28,6 +28,7 @@ model:
     ff_mult: 2
     text_dim: 512
     conv_layers: 4
+    checkpoint_activations: False  # recompute activations and save memory for extra compute
   mel_spec:
     target_sample_rate: 24000
     n_mel_channels: 100
src/f5_tts/model/backbones/dit.py CHANGED
@@ -105,6 +105,7 @@ class DiT(nn.Module):
         text_dim=None,
         conv_layers=0,
         long_skip_connection=False,
+        checkpoint_activations=False,
     ):
         super().__init__()
 
@@ -127,6 +128,17 @@ class DiT(nn.Module):
         self.norm_out = AdaLayerNormZero_Final(dim)  # final modulation
         self.proj_out = nn.Linear(dim, mel_dim)
 
+        self.checkpoint_activations = checkpoint_activations
+
+    def ckpt_wrapper(self, module):
+        """Code from https://github.com/chuanyangjin/fast-DiT/blob/1a8ecce58f346f877749f2dc67cdb190d295e4dc/models.py#L233-L237"""
+
+        def ckpt_forward(*inputs):
+            outputs = module(*inputs)
+            return outputs
+
+        return ckpt_forward
+
     def forward(
         self,
         x: float["b n d"],  # noised input audio  # noqa: F722
@@ -152,7 +164,10 @@ class DiT(nn.Module):
         residual = x
 
         for block in self.transformer_blocks:
-            x = block(x, t, mask=mask, rope=rope)
+            if self.checkpoint_activations:
+                x = torch.utils.checkpoint.checkpoint(self.ckpt_wrapper(block), x, t, mask, rope)
+            else:
+                x = block(x, t, mask=mask, rope=rope)
 
         if self.long_skip_connection is not None:
             x = self.long_skip_connection(torch.cat((x, residual), dim=-1))
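For readers unfamiliar with the pattern: ckpt_wrapper (borrowed from fast-DiT, per the docstring) re-exposes a block as a plain positional-argument callable, because the reentrant torch.utils.checkpoint.checkpoint implementation does not forward keyword arguments — which is also why mask and rope are passed positionally in the checkpointed branch. Below is a self-contained toy sketch of the same mechanism; nothing in it is F5-TTS code, and use_reentrant=False is the variant newer PyTorch recommends, while the commit relies on the historical default:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ToyBlock(nn.Module):
    """Stand-in for a transformer block with extra positional inputs."""

    def __init__(self, dim=64):
        super().__init__()
        self.ff = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim))

    def forward(self, x, scale):
        # `scale` plays the role of the extra args (t, mask, rope) in dit.py.
        return x + scale * self.ff(x)


def ckpt_wrapper(module):
    # Same trick as in the commit: wrap the module so checkpoint() can call
    # it with positional arguments only.
    def ckpt_forward(*inputs):
        return module(*inputs)

    return ckpt_forward


block = ToyBlock()
x = torch.randn(2, 128, 64, requires_grad=True)
scale = torch.tensor(0.5)

# Activations inside `block` are discarded after the forward pass and
# recomputed during backward: less memory, one extra forward per block.
y = checkpoint(ckpt_wrapper(block), x, scale, use_reentrant=False)
y.sum().backward()
assert x.grad is not None  # gradients still flow through the recomputation
```

With use_reentrant=False the wrapper is arguably unnecessary, since that code path also accepts keyword arguments, but keeping it matches the fast-DiT reference the commit cites.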