xlr8 committed
Commit 0be2076 · Parent(s): b50cb0b
Files changed (1): models.py +36 -62
models.py CHANGED
@@ -2,6 +2,7 @@ from dataclasses import dataclass
 
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 import torchtune
 from huggingface_hub import PyTorchModelHubMixin
 from torchtune.models import llama3_2
@@ -67,35 +68,32 @@ def sample_topk_topp(
     temperature: float,
 ) -> torch.Tensor:
     """
-    Returns a tensor of shape (batch_size, 1) of sampled token indices,
-    applying first top-k, then nucleus (top-p), then multinomial sampling.
+    Apply top-k, then nucleus (top-p), then sample.
+    Returns a tensor of shape (batch_size, 1).
     """
-    # scale and softmax
     scaled = logits / temperature
-    probs = torch.softmax(scaled, dim=-1)
+    probs = F.softmax(scaled, dim=-1)
 
-    # apply top-k mask
+    # Top-k
     if topk < probs.size(-1):
         topk_vals, topk_idx = torch.topk(probs, topk, dim=-1)
         mask = torch.zeros_like(probs)
         mask.scatter_(-1, topk_idx, topk_vals)
         probs = mask
 
-    # apply top-p (nucleus)
+    # Nucleus (top-p)
     sorted_probs, sorted_idx = torch.sort(probs, descending=True, dim=-1)
     cumulative = torch.cumsum(sorted_probs, dim=-1)
-    keep_mask = cumulative <= top_p
-    keep_mask[..., 0] = True  # always keep the top token
+    keep = cumulative <= top_p
+    keep[..., 0] = True  # always keep top token
 
     probs_final = torch.zeros_like(probs)
-    probs_final.scatter_(-1, sorted_idx, sorted_probs * keep_mask.float())
+    probs_final.scatter_(-1, sorted_idx, sorted_probs * keep.float())
 
-    # renormalize
     probs_final = probs_final / probs_final.sum(dim=-1, keepdim=True)
 
-    # sample once per batch, keep that extra dim!
-    sample = torch.multinomial(probs_final, num_samples=1)  # (batch_size, 1)
-    return sample
+    # sample once per batch, keep that extra dim
+    return torch.multinomial(probs_final, num_samples=1)  # (batch, 1)
 
 
 @dataclass
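
A quick shape check for the rewritten sampler; a minimal sketch assuming only this diff's sample_topk_topp, with made-up batch and vocab sizes:

import torch
import torch.nn.functional as F  # sample_topk_topp now depends on F.softmax

logits = torch.randn(2, 1024)  # (batch, vocab)
sample = sample_topk_topp(logits, topk=50, top_p=0.9, temperature=0.8)
assert sample.shape == (2, 1)  # extra dim kept for the multinomial draw

One subtlety carried over unchanged from the old version: the cumulative sum for top-p runs over the top-k-masked probabilities before renormalization, so when the total top-k mass is below top_p the nucleus cutoff prunes nothing further.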
@@ -118,9 +116,9 @@ class Model(
         super().__init__()
         self.config = config
 
-        # backbone (text+audio embedding)
+        # Text + audio backbone
         self.backbone, backbone_dim = _prepare_transformer(FLAVORS[config.backbone_flavor]())
-        # decoder (only audio codebooks)
+        # Audio decoder
         self.decoder, decoder_dim = _prepare_transformer(FLAVORS[config.decoder_flavor]())
 
         self.text_embeddings = nn.Embedding(config.text_vocab_size, backbone_dim)
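
For orientation, the constructor only reads a handful of config fields. A hypothetical instantiation sketch; the dataclass name, the FLAVORS keys, and all values below are assumptions, not shown in this diff:

config = ModelArgs(                  # name assumed for the @dataclass above
    backbone_flavor="llama-1B",      # must be a key of FLAVORS; key assumed
    decoder_flavor="llama-100M",     # likewise assumed
    text_vocab_size=128_256,         # illustrative values
    audio_vocab_size=2051,
    audio_num_codebooks=32,
)
model = Model(config)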
@@ -155,59 +153,46 @@ class Model(
155
  @torch.inference_mode()
156
  def generate_frame(
157
  self,
158
- tokens: torch.Tensor,
159
- tokens_mask: torch.Tensor,
160
- input_pos: torch.Tensor,
161
  temperature: float,
162
  topk: int,
163
  top_p: float,
164
  ) -> torch.Tensor:
165
- """
166
- tokens: (batch, seq, codebooks+1)
167
- tokens_mask: (batch, seq, codebooks+1)
168
- input_pos: (batch, seq)
169
- Returns:
170
- Tensor of shape (batch, codebooks) containing one new token per codebook.
171
- """
172
  dtype = next(self.parameters()).dtype
173
 
174
- assert self.backbone.caches_are_enabled(), "backbone caches are not enabled"
175
- # build backbone mask from causal mask + positions
176
- bb_mask = _index_causal_mask(self.backbone_causal_mask, input_pos)
177
-
178
- # embed and encode
179
  embeds = self._embed_tokens(tokens)
180
- h = self.backbone(
181
- (embeds * tokens_mask.unsqueeze(-1)).sum(dim=2),
182
- input_pos=input_pos,
183
- mask=bb_mask,
184
- ).to(dtype=dtype)
185
 
186
- # Take last hidden state
187
  last_h = h[:, -1, :] # (batch, hidden)
188
- last_h_unsq = last_h.unsqueeze(1) # (batch, 1, hidden)
189
 
190
- # ==== CODEBOOK 0 ====
191
- c0_logits = self.codebook0_head(last_h) # (batch, vocab)
192
  c0_sample = sample_topk_topp(c0_logits, topk, top_p, temperature) # (batch,1)
193
- c0_embed = self._embed_audio(0, c0_sample.squeeze(-1)).unsqueeze(1) # (batch,1,hidden)
194
 
195
  # Prepare for decoder
196
  curr_h = torch.cat([last_h_unsq, c0_embed], dim=1) # (batch,2,hidden)
197
  curr_sample = c0_sample.clone() # (batch,1)
198
  curr_pos = torch.arange(0, curr_h.size(1)).unsqueeze(0).to(tokens.device).long() # (1,2)
199
 
200
- # ==== Remaining codebooks ====
201
  self.decoder.reset_caches()
202
  for i in range(1, self.config.audio_num_codebooks):
203
- dec_mask = _index_causal_mask(self.decoder_causal_mask, curr_pos)
204
- dec_h = self.decoder(self.projection(curr_h), input_pos=curr_pos, mask=dec_mask).to(dtype=dtype)
 
205
  ci_logits = torch.mm(dec_h[:, -1, :], self.audio_head[i - 1]) # (batch, vocab)
206
  ci_sample = sample_topk_topp(ci_logits, topk, top_p, temperature) # (batch,1)
207
  ci_embed = self._embed_audio(i, ci_sample.squeeze(-1)).unsqueeze(1) # (batch,1,hidden)
208
 
209
  curr_h = ci_embed
210
- curr_sample = torch.cat([curr_sample, ci_sample], dim=1) # (batch, i+1)
211
  curr_pos = curr_pos[:, -1:] + 1
212
 
213
  return curr_sample # (batch, audio_num_codebooks)
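
The calling convention follows the shape comments now inlined in the signature. A hedged driver sketch, reusing the model from the sketch above; real inputs are omitted, and generate_frame still requires the backbone KV caches to be enabled beforehand, which the deleted assert used to document:

import torch

B, S = 1, 10
K = model.config.audio_num_codebooks
tokens = torch.zeros(B, S, K + 1, dtype=torch.long)         # K codebook columns + 1 text column
tokens_mask = torch.ones(B, S, K + 1)
input_pos = torch.arange(S, dtype=torch.long).unsqueeze(0)  # (batch, seq)

frame = model.generate_frame(tokens, tokens_mask, input_pos,
                             temperature=0.9, topk=50, top_p=0.95)
# frame: (B, K), one new token per codebook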
@@ -218,7 +203,7 @@ class Model(
 
     def _embed_audio(self, codebook: int, tokens: torch.Tensor) -> torch.Tensor:
         """
-        tokens: (batch,) of token IDs for this codebook
+        tokens: (batch,) token IDs for this codebook
         returns: (batch, hidden)
         """
         ids = tokens + codebook * self.config.audio_vocab_size
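
The offset arithmetic above packs every codebook into one flat audio_embeddings table. A worked example, with audio_vocab_size = 2051 assumed purely for illustration: token 7 of codebook 2 maps to row 7 + 2 * 2051 = 4109, so each codebook owns a disjoint slice of the table. _embed_tokens below applies the same offsets in vectorized form via torch.arange.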
@@ -229,26 +214,15 @@ class Model(
         tokens: (batch, seq, codebooks+1)
         returns: (batch, seq, codebooks+1, hidden)
         """
-        # text part (last index of 33)
         text_ids = tokens[:, :, -1]
-        text_emb = self.text_embeddings(text_ids).unsqueeze(-2)  # (batch, seq, 1, hidden)
-
-        # audio codebooks
+        text_emb = self.text_embeddings(text_ids).unsqueeze(-2)
         audio_ids = tokens[:, :, :-1] + (
             self.config.audio_vocab_size * torch.arange(self.config.audio_num_codebooks, device=tokens.device)
-        )  # (batch, seq, codebooks)
+        )
         audio_emb = (
-            self.audio_embeddings(audio_ids.reshape(-1)).reshape(
-                tokens.size(0), tokens.size(1), self.config.audio_num_codebooks, -1
-            )
-        )  # (batch, seq, codebooks, hidden)
+            self.audio_embeddings(audio_ids.reshape(-1))
+            .reshape(tokens.size(0), tokens.size(1), self.config.audio_num_codebooks, -1)
+        )
+        return torch.cat([audio_emb, text_emb], dim=2)
 
-        return torch.cat([audio_emb, text_emb], dim=2)  # (batch, seq, codebooks+1, hidden)
 
-    @classmethod
-    def from_pretrained(cls, repo_id: str):
-        # load args & state from HF repo, e.g. sesame/csm-1b or your fine-tuned xlr8harder model
-        config = cls._load_config(repo_id)  # uses PyTorchModelHubMixin behind the scenes
-        model = cls(config)
-        model.load_state_dict(model._load_state_dict(repo_id), strict=False)
-        return model
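
Dropping the hand-rolled from_pretrained is sound: cls._load_config and model._load_state_dict are defined nowhere in this file, and PyTorchModelHubMixin provides its own from_pretrained. Assuming the mixin sits among Model's bases (it is imported above) and is wired to the config, loading reduces to:

model = Model.from_pretrained("sesame/csm-1b")  # repo id taken from the deleted comment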