xlr8 committed · be19290
Parent(s): da443d2

bugfix in top_p sampling when top_p tokens < top_k
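The failure mode behind this fix: with a peaked distribution, the cumulative top-p cutoff can keep fewer tokens than topk, which the previous masking logic mishandled. A minimal illustration (not part of the commit; the distribution is hypothetical):

import torch

# Hypothetical peaked distribution: with topp = 0.9 the nucleus keeps only
# 3 of 5 tokens, i.e. fewer tokens than a typical topk such as 50.
probs = torch.tensor([0.5, 0.25, 0.125, 0.0625, 0.0625])
cumulative = torch.cumsum(probs, dim=-1)  # [0.5000, 0.7500, 0.8750, 0.9375, 1.0000]
nucleus = cumulative <= 0.9               # [True, True, True, False, False]
print(int(nucleus.sum()))                 # 3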
models.py CHANGED

@@ -63,28 +63,39 @@ def _multinomial_sample_one_no_sync(probs):
     q = torch.empty_like(probs).exponential_(1)
     return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int)
 
-def sample_topk_topp(logits: torch.Tensor, topk: int, top_p: float, temperature: float):
+def sample_topk_topp(logits, topk=50, topp=0.9, temperature=1.0):
+    if temperature <= 0:
+        raise ValueError("Temperature must be > 0")
+
     logits = logits / temperature
-    probs = torch.nn.functional.softmax(logits, dim=-1)
+    probs = torch.softmax(logits, dim=-1)
+
+    # Clamp topk to not exceed the vocab size
+    vocab_size = probs.shape[-1]
+    topk = min(topk, vocab_size)
+
+    # Get topk indices and probabilities
+    topk_probs, topk_indices = torch.topk(probs, topk, dim=-1)
+
+    # Compute cumulative probabilities for nucleus sampling
+    sorted_probs, sorted_indices = torch.sort(topk_probs, descending=True, dim=-1)
     cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
 
-    probs[sorted_indices[sorted_mask]] = 0.0
+    # Mask out tokens beyond topp threshold
+    topp_mask = cumulative_probs <= topp
+    # Always keep at least one token
+    topp_mask[..., 0] = True
 
+    # Apply mask and renormalize
+    masked_probs = sorted_probs * topp_mask
+    masked_probs = masked_probs / masked_probs.sum(dim=-1, keepdim=True)
+
+    # Sample from masked distribution
+    sample_idx = torch.multinomial(masked_probs, num_samples=1)
+
+    # Map back to the original vocab indices
+    chosen_index = topk_indices.gather(-1, sorted_indices.gather(-1, sample_idx)).squeeze(-1)
+    return chosen_index
 
 @dataclass
 class ModelArgs:
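As a quick sanity check on the patched sampler, a minimal usage sketch (not part of the commit; batch size, vocab size, and seed are illustrative):

import torch

torch.manual_seed(0)
logits = torch.randn(2, 128)  # hypothetical (batch=2, vocab=128) logits

# One sampled token id per batch row, drawn from the renormalized
# intersection of the top-50 and top-p (0.9) sets.
token_ids = sample_topk_topp(logits, topk=50, topp=0.9, temperature=0.8)
print(token_ids.shape)        # torch.Size([2]); each id indexes the full vocab

The new code renormalizes within the top-k nucleus instead of zeroing entries of the full distribution in place, which keeps the sampler well-defined when the top-p set is smaller than topk.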