HemanM committed on
Commit
74435ef
·
verified ·
1 Parent(s): 2f5aba0

Update generate.py

Browse files
Files changed (1) hide show
  1. generate.py +4 -3
generate.py CHANGED
@@ -24,13 +24,14 @@ def generate_response(prompt, max_length=100, top_k=40):
24
 
25
  # Top-k sampling
26
  top_k_probs, top_k_indices = torch.topk(next_token_logits, top_k)
27
- probs = torch.softmax(top_k_probs.squeeze(0), dim=-1) # Flatten
28
  sampled_index = torch.multinomial(probs, 1).item()
29
  next_token = top_k_indices[0, sampled_index]
30
 
31
- input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)
 
 
32
 
33
- # Stop if EOS token
34
  if next_token.item() == tokenizer.eos_token_id:
35
  break
36
 
 
24
 
25
  # Top-k sampling
26
  top_k_probs, top_k_indices = torch.topk(next_token_logits, top_k)
27
+ probs = torch.softmax(top_k_probs.squeeze(0), dim=-1)
28
  sampled_index = torch.multinomial(probs, 1).item()
29
  next_token = top_k_indices[0, sampled_index]
30
 
31
+ # Reshape next_token to match input_ids shape
32
+ next_token = next_token.unsqueeze(0).unsqueeze(0) # Shape: (1, 1)
33
+ input_ids = torch.cat([input_ids, next_token], dim=1)
34
 
 
35
  if next_token.item() == tokenizer.eos_token_id:
36
  break
37