Update generate.py
generate.py (+4 -3)
```diff
@@ -24,13 +24,14 @@ def generate_response(prompt, max_length=100, top_k=40):
 
         # Top-k sampling
         top_k_probs, top_k_indices = torch.topk(next_token_logits, top_k)
-        probs = torch.softmax(top_k_probs.squeeze(0), dim=-1)
+        probs = torch.softmax(top_k_probs.squeeze(0), dim=-1)
         sampled_index = torch.multinomial(probs, 1).item()
         next_token = top_k_indices[0, sampled_index]
 
-
+        # Reshape next_token to match input_ids shape
+        next_token = next_token.unsqueeze(0).unsqueeze(0)  # Shape: (1, 1)
+        input_ids = torch.cat([input_ids, next_token], dim=1)
 
-        # Stop if EOS token
         if next_token.item() == tokenizer.eos_token_id:
             break
 
```
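For context, here is a minimal sketch of how the patched hunk might sit inside a full top-k sampling loop. The model and tokenizer setup (GPT-2 via Hugging Face transformers), the loop bounds, and the final decode step are illustrative assumptions, not part of the commit; only the sampling, reshape, and concatenation lines come from the diff itself.

```python
# Minimal sketch of a top-k generation loop around the patched hunk.
# ASSUMPTION: a Hugging Face causal LM and tokenizer (GPT-2 here);
# everything outside the patched lines is illustrative, not from the commit.
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.eval()

def generate_response(prompt, max_length=100, top_k=40):
    input_ids = tokenizer.encode(prompt, return_tensors="pt")  # shape: (1, seq_len)
    with torch.no_grad():
        for _ in range(max_length):
            logits = model(input_ids).logits
            next_token_logits = logits[:, -1, :]  # shape: (1, vocab_size)

            # Top-k sampling (as in the diff)
            top_k_probs, top_k_indices = torch.topk(next_token_logits, top_k)
            probs = torch.softmax(top_k_probs.squeeze(0), dim=-1)
            sampled_index = torch.multinomial(probs, 1).item()
            next_token = top_k_indices[0, sampled_index]  # 0-dim tensor

            # Reshape next_token to match input_ids shape (the fix)
            next_token = next_token.unsqueeze(0).unsqueeze(0)  # Shape: (1, 1)
            input_ids = torch.cat([input_ids, next_token], dim=1)

            if next_token.item() == tokenizer.eos_token_id:
                break
    return tokenizer.decode(input_ids[0], skip_special_tokens=True)
```

Without the reshape, next_token is a zero-dimensional tensor and torch.cat along dim=1 fails against the (1, seq_len) input_ids; unsqueezing it twice produces the (1, 1) shape the concatenation expects.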