HemanM committed
Commit 2b45a2a · verified · 1 Parent(s): 2474a23

Update generate.py

Files changed (1)
  1. generate.py +18 -13
generate.py CHANGED
@@ -1,14 +1,17 @@
-# generate.py — Generates responses from EvoDecoderModel with Top-k sampling
+# generate.py — Generates response using EvoDecoderModel with GPT2 tokenizer and top-k/p sampling
 import torch
-from transformers import AutoTokenizer
+from transformers import GPT2Tokenizer
 from evo_model import EvoDecoderModel
 
+# Set device
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-# Load tokenizer and model
-tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+# Load GPT2 tokenizer (better for decoding tasks)
+tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
+tokenizer.pad_token = tokenizer.eos_token  # Safe fallback
 vocab_size = tokenizer.vocab_size
 
+# Load trained EvoDecoder model
 model = EvoDecoderModel(vocab_size=vocab_size).to(device)
 model.load_state_dict(torch.load("evo_decoder_model.pt", map_location=device))
 model.eval()
@@ -20,19 +23,21 @@ def generate_response(prompt, max_length=100, top_k=40):
     for _ in range(max_length):
         with torch.no_grad():
             logits = model(input_ids)
-        next_token_logits = logits[:, -1, :]
+        next_token_logits = logits[:, -1, :].squeeze(0)
+
+        # Apply repetition penalty
+        for token_id in set(input_ids.view(-1).tolist()):
+            next_token_logits[token_id] *= 0.8
 
         # Top-k sampling
-        top_k_probs, top_k_indices = torch.topk(next_token_logits, top_k)
-        probs = torch.softmax(top_k_probs.squeeze(0), dim=-1)
-        sampled_index = torch.multinomial(probs, 1).item()
-        next_token = top_k_indices[0, sampled_index]
+        top_k_logits, top_k_indices = torch.topk(next_token_logits, k=top_k)
+        probs = torch.softmax(top_k_logits, dim=-1)
+        next_token = top_k_indices[torch.multinomial(probs, num_samples=1)].unsqueeze(0)
 
-        # Reshape next_token to match input_ids shape
-        next_token = next_token.unsqueeze(0).unsqueeze(0)  # Shape: (1, 1)
-        input_ids = torch.cat([input_ids, next_token], dim=1)
+        input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)
 
-        if next_token.item() == tokenizer.eos_token_id:
+        # Stop on EOS
+        if tokenizer.eos_token_id and next_token.item() == tokenizer.eos_token_id:
            break
 
     output = tokenizer.decode(input_ids[0], skip_special_tokens=True)
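For context, here is a minimal end-to-end sketch of how the revised sampling loop might be driven. The prompt-encoding line, the return statement, and the __main__ harness are assumptions (the diff above only touches the imports, tokenizer setup, and the inner loop), shapes follow the convention that EvoDecoderModel returns logits of shape (batch, seq_len, vocab_size), and next_token is kept one-dimensional so the concatenation shapes line up.

# Sketch only: assumed driver around the committed sampling loop
import torch
from transformers import GPT2Tokenizer
from evo_model import EvoDecoderModel  # repo-local module

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
tokenizer.pad_token = tokenizer.eos_token

model = EvoDecoderModel(vocab_size=tokenizer.vocab_size).to(device)
model.load_state_dict(torch.load("evo_decoder_model.pt", map_location=device))
model.eval()

def generate_response(prompt, max_length=100, top_k=40):
    # Assumed encoding step; not shown in the diff
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    for _ in range(max_length):
        with torch.no_grad():
            logits = model(input_ids)                     # (1, seq_len, vocab_size)
        next_token_logits = logits[:, -1, :].squeeze(0)   # (vocab_size,)

        # Repetition penalty: damp logits of tokens already in the sequence
        for token_id in set(input_ids.view(-1).tolist()):
            next_token_logits[token_id] *= 0.8

        # Top-k sampling: renormalize over the k most likely tokens and draw one
        top_k_logits, top_k_indices = torch.topk(next_token_logits, k=top_k)
        probs = torch.softmax(top_k_logits, dim=-1)
        next_token = top_k_indices[torch.multinomial(probs, num_samples=1)]  # shape (1,)

        input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)   # (1, seq_len + 1)

        # Stop on EOS
        if tokenizer.eos_token_id is not None and next_token.item() == tokenizer.eos_token_id:
            break
    return tokenizer.decode(input_ids[0], skip_special_tokens=True)

if __name__ == "__main__":
    print(generate_response("Hello, how are you?"))

One design note: multiplying raw logits by 0.8 only suppresses tokens whose logits are positive; for negative logits it pushes them toward zero and so makes them more likely. The conventional repetition penalty (as in the CTRL paper and Hugging Face's RepetitionPenaltyLogitsProcessor) divides positive logits by the penalty factor and multiplies negative ones by it.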