liuhuadai commited on
Commit
3594636
·
verified ·
1 Parent(s): ca9745d

Update think_sound/models/mmdit.py

Browse files
Files changed (1) hide show
  1. think_sound/models/mmdit.py +2 -2
think_sound/models/mmdit.py CHANGED
@@ -213,8 +213,8 @@ class MMAudio(nn.Module):
213
  self._clip_seq_len,
214
  device=self.device)
215
 
216
- self.latent_rot = nn.Buffer(latent_rot, persistent=False)
217
- self.clip_rot = nn.Buffer(clip_rot, persistent=False)
218
 
219
  def update_seq_lengths(self, latent_seq_len: int, clip_seq_len: int, sync_seq_len: int) -> None:
220
  self._latent_seq_len = latent_seq_len
 
213
  self._clip_seq_len,
214
  device=self.device)
215
 
216
+ self.register_buffer("latent_rot", latent_rot, persistent=False)
217
+ self.register_buffer("clip_rot", clip_rot, persistent=False)
218
 
219
  def update_seq_lengths(self, latent_seq_len: int, clip_seq_len: int, sync_seq_len: int) -> None:
220
  self._latent_seq_len = latent_seq_len