Yushen CHEN commited on
Commit
2e44356
·
2 Parent(s): ecd723b dceb380

Merge pull request #859 from ZhikangNiu/main

Browse files

fix #858 and pass use_reentrant explicitly in checkpoint_activation mode

Files changed (1) hide show
  1. src/f5_tts/model/backbones/dit.py +3 -1
src/f5_tts/model/backbones/dit.py CHANGED
@@ -219,7 +219,9 @@ class DiT(nn.Module):
219
 
220
  for block in self.transformer_blocks:
221
  if self.checkpoint_activations:
222
- x = torch.utils.checkpoint.checkpoint(self.ckpt_wrapper(block), x, t, mask, rope)
 
 
223
  else:
224
  x = block(x, t, mask=mask, rope=rope)
225
 
 
219
 
220
  for block in self.transformer_blocks:
221
  if self.checkpoint_activations:
222
+ # if you have question, please check: https://pytorch.org/docs/stable/checkpoint.html#torch.utils.checkpoint.checkpoint
223
+ # After PyTorch 2.4, we must pass the use_reentrant explicitly
224
+ x = torch.utils.checkpoint.checkpoint(self.ckpt_wrapper(block), x, t, mask, rope, use_reentrant=False)
225
  else:
226
  x = block(x, t, mask=mask, rope=rope)
227