xlr8 committed
Commit 5a36a74 · 1 Parent(s): 0be2076
Files changed (1)
  1. models.py +32 -29
models.py CHANGED
@@ -71,29 +71,36 @@ def sample_topk_topp(
     Apply top-k, then nucleus (top-p), then sample.
     Returns a tensor of shape (batch_size, 1).
     """
+    # scale + softmax
     scaled = logits / temperature
     probs = F.softmax(scaled, dim=-1)
 
-    # Top-k
+    # --- top-k ---
     if topk < probs.size(-1):
         topk_vals, topk_idx = torch.topk(probs, topk, dim=-1)
-        mask = torch.zeros_like(probs)
-        mask.scatter_(-1, topk_idx, topk_vals)
-        probs = mask
+        mask_k = torch.zeros_like(probs)
+        mask_k.scatter_(-1, topk_idx, topk_vals)
+        probs = mask_k
 
-    # Nucleus (top-p)
+    # --- top-p (nucleus) ---
     sorted_probs, sorted_idx = torch.sort(probs, descending=True, dim=-1)
     cumulative = torch.cumsum(sorted_probs, dim=-1)
     keep = cumulative <= top_p
-    keep[..., 0] = True  # always keep top token
+    keep[..., 0] = True  # always keep highest-prob
 
+    # cast mask to same dtype as sorted_probs
+    keep = keep.to(sorted_probs.dtype)
+
+    # build final probabilities in correct dtype
     probs_final = torch.zeros_like(probs)
-    probs_final.scatter_(-1, sorted_idx, sorted_probs * keep.float())
+    src = sorted_probs * keep  # same dtype
+    probs_final.scatter_(-1, sorted_idx, src)
 
+    # renormalize
     probs_final = probs_final / probs_final.sum(dim=-1, keepdim=True)
 
     # sample once per batch, keep that extra dim
-    return torch.multinomial(probs_final, num_samples=1)  # (batch, 1)
+    return torch.multinomial(probs_final, num_samples=1)  # shape (batch,1)
 
 
 @dataclass
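Editor's note: a minimal smoke test for the patched sampler, as a sketch rather than part of the commit. It assumes this file is importable as `models`; the batch/vocab sizes and sampling parameters are invented. The dtype changes above matter once the model runs in half precision, where the old `keep.float()` multiply promoted the scatter source to float32 and `scatter_` could then reject it against a float16 `probs_final`:

    import torch

    from models import sample_topk_topp  # assumes models.py is on the path

    logits = torch.randn(2, 1024)  # (batch, vocab); invented sizes
    sample = sample_topk_topp(logits, topk=50, top_p=0.9, temperature=0.8)
    print(sample.shape)  # torch.Size([2, 1]) -- one token id per batch row

With the new code every intermediate stays in the dtype of `probs`, so the same call works unchanged for float16/bfloat16 logits.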
@@ -116,7 +123,7 @@ class Model(
         super().__init__()
         self.config = config
 
-        # Text + audio backbone
+        # Text+audio backbone
         self.backbone, backbone_dim = _prepare_transformer(FLAVORS[config.backbone_flavor]())
         # Audio decoder
         self.decoder, decoder_dim = _prepare_transformer(FLAVORS[config.decoder_flavor]())
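Editor's note: the constructor uses a registry pattern here: `FLAVORS` maps a flavor name to a transformer constructor, and `_prepare_transformer` evidently returns the module together with its hidden width. A hypothetical stand-alone sketch of that pattern; the `_demo` names, flavor key, and sizes are all invented and are not this repo's API:

    import torch.nn as nn

    # Invented stand-ins for FLAVORS and _prepare_transformer.
    FLAVORS_DEMO = {
        "tiny": lambda: nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=64, nhead=4, batch_first=True),
            num_layers=2,
        ),
    }

    def _prepare_transformer_demo(model: nn.Module) -> tuple[nn.Module, int]:
        return model, 64  # the module plus its hidden width

    backbone, backbone_dim = _prepare_transformer_demo(FLAVORS_DEMO["tiny"]())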
@@ -162,32 +169,36 @@ class Model(
     ) -> torch.Tensor:
         dtype = next(self.parameters()).dtype
 
-        # Backbone forward
-        mask_bb = _index_causal_mask(self.backbone_causal_mask, input_pos)
+        # Backbone pass
+        bb_mask = _index_causal_mask(self.backbone_causal_mask, input_pos)
         embeds = self._embed_tokens(tokens)
-        h = self.backbone((embeds * tokens_mask.unsqueeze(-1)).sum(dim=2), input_pos=input_pos, mask=mask_bb).to(dtype=dtype)
+        h = self.backbone(
+            (embeds * tokens_mask.unsqueeze(-1)).sum(dim=2),
+            input_pos=input_pos,
+            mask=bb_mask,
+        ).to(dtype=dtype)
 
-        # Last hidden
+        # Last hidden state
         last_h = h[:, -1, :]  # (batch, hidden)
         last_h_unsq = last_h.unsqueeze(1)  # (batch,1,hidden)
 
-        # Codebook 0
-        c0_logits = self.codebook0_head(last_h)  # (batch, vocab)
+        # --- codebook 0 ---
+        c0_logits = self.codebook0_head(last_h)  # (batch, vocab)
         c0_sample = sample_topk_topp(c0_logits, topk, top_p, temperature)  # (batch,1)
         c0_embed = self._embed_audio(0, c0_sample.squeeze(-1)).unsqueeze(1)  # (batch,1,hidden)
 
-        # Prepare for decoder
+        # Prepare decoder input
         curr_h = torch.cat([last_h_unsq, c0_embed], dim=1)  # (batch,2,hidden)
         curr_sample = c0_sample.clone()  # (batch,1)
-        curr_pos = torch.arange(0, curr_h.size(1)).unsqueeze(0).to(tokens.device).long()  # (1,2)
+        curr_pos = torch.arange(0, curr_h.size(1)).unsqueeze(0).to(tokens.device).long()
 
-        # Remaining codebooks
+        # --- remaining codebooks ---
         self.decoder.reset_caches()
         for i in range(1, self.config.audio_num_codebooks):
-            mask_dec = _index_causal_mask(self.decoder_causal_mask, curr_pos)
-            dec_h = self.decoder(self.projection(curr_h), input_pos=curr_pos, mask=mask_dec).to(dtype=dtype)
+            dec_mask = _index_causal_mask(self.decoder_causal_mask, curr_pos)
+            dec_h = self.decoder(self.projection(curr_h), input_pos=curr_pos, mask=dec_mask).to(dtype=dtype)
 
-            ci_logits = torch.mm(dec_h[:, -1, :], self.audio_head[i - 1])  # (batch, vocab)
+            ci_logits = torch.mm(dec_h[:, -1, :], self.audio_head[i - 1])
             ci_sample = sample_topk_topp(ci_logits, topk, top_p, temperature)  # (batch,1)
             ci_embed = self._embed_audio(i, ci_sample.squeeze(-1)).unsqueeze(1)  # (batch,1,hidden)
 
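Editor's note: in the loop above, each codebook i >= 1 takes its logits from a per-codebook projection matrix stored in the stacked `self.audio_head` parameter, applied to the decoder's last hidden state with a plain `torch.mm`. A self-contained sketch of just that projection step; all sizes are invented:

    import torch

    batch, hidden, vocab, num_codebooks = 2, 16, 32, 8  # invented sizes
    # Stand-in for self.audio_head: one (hidden, vocab) matrix per codebook >= 1.
    audio_head = torch.randn(num_codebooks - 1, hidden, vocab)
    dec_h_last = torch.randn(batch, hidden)  # stand-in for dec_h[:, -1, :]

    for i in range(1, num_codebooks):
        ci_logits = torch.mm(dec_h_last, audio_head[i - 1])  # (batch, vocab)
        assert ci_logits.shape == (batch, vocab)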
@@ -202,18 +213,10 @@ class Model(
         self.decoder.reset_caches()
 
     def _embed_audio(self, codebook: int, tokens: torch.Tensor) -> torch.Tensor:
-        """
-        tokens: (batch,) token IDs for this codebook
-        returns: (batch, hidden)
-        """
         ids = tokens + codebook * self.config.audio_vocab_size
         return self.audio_embeddings(ids)
 
     def _embed_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
-        """
-        tokens: (batch, seq, codebooks+1)
-        returns: (batch, seq, codebooks+1, hidden)
-        """
         text_ids = tokens[:, :, -1]
         text_emb = self.text_embeddings(text_ids).unsqueeze(-2)
         audio_ids = tokens[:, :, :-1] + (
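Editor's note: `_embed_audio` relies on a single flat embedding table shared by all codebooks; `ids = tokens + codebook * audio_vocab_size` shifts each codebook's token ids into its own contiguous slice of that table. A small illustration of the offset trick, with invented vocab size, codebook count, and hidden width:

    import torch
    import torch.nn as nn

    audio_vocab_size, num_codebooks, hidden = 10, 3, 4  # invented sizes
    # One flat table holding num_codebooks contiguous vocab slices.
    audio_embeddings = nn.Embedding(audio_vocab_size * num_codebooks, hidden)

    tokens = torch.tensor([7, 7])  # same raw id, embedded under two codebooks
    e0 = audio_embeddings(tokens + 0 * audio_vocab_size)  # codebook-0 slice
    e1 = audio_embeddings(tokens + 1 * audio_vocab_size)  # codebook-1 slice
    assert not torch.allclose(e0, e1)  # different rows, so different embeddings

One table instead of num_codebooks separate nn.Embedding modules keeps the parameter a single tensor, which is why a plain integer offset is enough to select the right slice.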
 
71
  Apply top-k, then nucleus (top-p), then sample.
72
  Returns a tensor of shape (batch_size, 1).
73
  """
74
+ # scale + softmax
75
  scaled = logits / temperature
76
  probs = F.softmax(scaled, dim=-1)
77
 
78
+ # --- top-k ---
79
  if topk < probs.size(-1):
80
  topk_vals, topk_idx = torch.topk(probs, topk, dim=-1)
81
+ mask_k = torch.zeros_like(probs)
82
+ mask_k.scatter_(-1, topk_idx, topk_vals)
83
+ probs = mask_k
84
 
85
+ # --- top-p (nucleus) ---
86
  sorted_probs, sorted_idx = torch.sort(probs, descending=True, dim=-1)
87
  cumulative = torch.cumsum(sorted_probs, dim=-1)
88
  keep = cumulative <= top_p
89
+ keep[..., 0] = True # always keep highest-prob
90
 
91
+ # cast mask to same dtype as sorted_probs
92
+ keep = keep.to(sorted_probs.dtype)
93
+
94
+ # build final probabilities in correct dtype
95
  probs_final = torch.zeros_like(probs)
96
+ src = sorted_probs * keep # same dtype
97
+ probs_final.scatter_(-1, sorted_idx, src)
98
 
99
+ # renormalize
100
  probs_final = probs_final / probs_final.sum(dim=-1, keepdim=True)
101
 
102
  # sample once per batch, keep that extra dim
103
+ return torch.multinomial(probs_final, num_samples=1) # shape (batch,1)
104
 
105
 
106
  @dataclass
 
123
  super().__init__()
124
  self.config = config
125
 
126
+ # Text+audio backbone
127
  self.backbone, backbone_dim = _prepare_transformer(FLAVORS[config.backbone_flavor]())
128
  # Audio decoder
129
  self.decoder, decoder_dim = _prepare_transformer(FLAVORS[config.decoder_flavor]())
 
169
  ) -> torch.Tensor:
170
  dtype = next(self.parameters()).dtype
171
 
172
+ # Backbone pass
173
+ bb_mask = _index_causal_mask(self.backbone_causal_mask, input_pos)
174
  embeds = self._embed_tokens(tokens)
175
+ h = self.backbone(
176
+ (embeds * tokens_mask.unsqueeze(-1)).sum(dim=2),
177
+ input_pos=input_pos,
178
+ mask=bb_mask,
179
+ ).to(dtype=dtype)
180
 
181
+ # Last hidden state
182
  last_h = h[:, -1, :] # (batch, hidden)
183
  last_h_unsq = last_h.unsqueeze(1) # (batch,1,hidden)
184
 
185
+ # --- codebook 0 ---
186
+ c0_logits = self.codebook0_head(last_h) # (batch, vocab)
187
  c0_sample = sample_topk_topp(c0_logits, topk, top_p, temperature) # (batch,1)
188
  c0_embed = self._embed_audio(0, c0_sample.squeeze(-1)).unsqueeze(1) # (batch,1,hidden)
189
 
190
+ # Prepare decoder input
191
  curr_h = torch.cat([last_h_unsq, c0_embed], dim=1) # (batch,2,hidden)
192
  curr_sample = c0_sample.clone() # (batch,1)
193
+ curr_pos = torch.arange(0, curr_h.size(1)).unsqueeze(0).to(tokens.device).long()
194
 
195
+ # --- remaining codebooks ---
196
  self.decoder.reset_caches()
197
  for i in range(1, self.config.audio_num_codebooks):
198
+ dec_mask = _index_causal_mask(self.decoder_causal_mask, curr_pos)
199
+ dec_h = self.decoder(self.projection(curr_h), input_pos=curr_pos, mask=dec_mask).to(dtype=dtype)
200
 
201
+ ci_logits = torch.mm(dec_h[:, -1, :], self.audio_head[i - 1])
202
  ci_sample = sample_topk_topp(ci_logits, topk, top_p, temperature) # (batch,1)
203
  ci_embed = self._embed_audio(i, ci_sample.squeeze(-1)).unsqueeze(1) # (batch,1,hidden)
204
 
 
213
  self.decoder.reset_caches()
214
 
215
  def _embed_audio(self, codebook: int, tokens: torch.Tensor) -> torch.Tensor:
 
 
 
 
216
  ids = tokens + codebook * self.config.audio_vocab_size
217
  return self.audio_embeddings(ids)
218
 
219
  def _embed_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
 
 
 
 
220
  text_ids = tokens[:, :, -1]
221
  text_emb = self.text_embeddings(text_ids).unsqueeze(-2)
222
  audio_ids = tokens[:, :, :-1] + (