xlr8 committed on
Commit be19290 · Parent(s): da443d2

bugfix in top_p sampling when top_p tokens < top_k

Files changed (1)
  1. models.py +25 -14
models.py CHANGED
@@ -63,28 +63,39 @@ def _multinomial_sample_one_no_sync(probs):
     q = torch.empty_like(probs).exponential_(1)
     return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int)
 
-def sample_topk_topp(logits: torch.Tensor, topk: int, top_p: float, temperature: float):
+def sample_topk_topp(logits, topk=50, topp=0.9, temperature=1.0):
+    if temperature <= 0:
+        raise ValueError("Temperature must be > 0")
+
     logits = logits / temperature
-    logits = torch.nn.functional.log_softmax(logits, dim=-1)
-    probs = torch.nn.functional.softmax(logits, dim=-1)
+    probs = torch.softmax(logits, dim=-1)
 
-    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
+    # Clamp topk to not exceed the vocab size
+    vocab_size = probs.shape[-1]
+    topk = min(topk, vocab_size)
+
+    # Get topk indices and probabilities
+    topk_probs, topk_indices = torch.topk(probs, topk, dim=-1)
+
+    # Compute cumulative probabilities for nucleus sampling
+    sorted_probs, sorted_indices = torch.sort(topk_probs, descending=True, dim=-1)
     cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
 
-    if top_p < 1.0:
-        sorted_mask = cumulative_probs > top_p
-        sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
-        sorted_mask[..., 0] = 0
-        probs[sorted_indices[sorted_mask]] = 0.0
-
-    if topk < probs.shape[-1]:
-        topk_thresh = torch.topk(probs, topk)[0][..., -1, None]
-        probs = torch.where(probs < topk_thresh, 0.0, probs)
-
-    probs = probs / probs.sum(dim=-1, keepdim=True)
-    return _multinomial_sample_one_no_sync(probs)
+    # Mask out tokens beyond topp threshold
+    topp_mask = cumulative_probs <= topp
+    # Always keep at least one token
+    topp_mask[..., 0] = True
+
+    # Apply mask and renormalize
+    masked_probs = sorted_probs * topp_mask
+    masked_probs = masked_probs / masked_probs.sum(dim=-1, keepdim=True)
+
+    # Sample from masked distribution
+    sample_idx = torch.multinomial(masked_probs, num_samples=1)
+
+    # Map back to the original vocab indices (sorted_indices points into the
+    # topk subset, so route the draw through topk_indices to recover vocab ids)
+    topk_pos = sorted_indices.gather(-1, sample_idx)
+    chosen_index = topk_indices.gather(-1, topk_pos).squeeze(-1)
+    return chosen_index
 
 @dataclass
 class ModelArgs:
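
For reference, a minimal sketch of how the rewritten sampler behaves in the case the commit message describes, where the top_p nucleus contains fewer tokens than topk. The import path, the seed, and the concrete logit values are illustrative assumptions, not part of the commit:

import torch

from models import sample_topk_topp  # assumed import path for the file changed above

# One token carries nearly all of the probability mass, so with topp=0.9 the
# nucleus is a single token -- far fewer than the topk=50 default.
logits = torch.full((1, 1000), -10.0)
logits[0, 42] = 10.0

torch.manual_seed(0)  # arbitrary seed for a reproducible draw
token = sample_topk_topp(logits, topk=50, topp=0.9, temperature=1.0)
print(token)  # expected: tensor([42]) -- a vocab id, not a position within the topk subset

Masking inside the topk subset and then mapping the sampled position back through topk_indices is what keeps the return value a vocabulary index even when top_p keeps fewer than topk tokens.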