HemanM committed · verified
Commit 8469bea · 1 Parent(s): 4ffcc65

Update generate.py

Files changed (1)
generate.py  +20 -23
generate.py CHANGED
@@ -1,43 +1,40 @@
-# generate.py — Generates response using EvoDecoderModel with GPT2 tokenizer and top-k/p sampling
+# generate.py — Generates EvoDecoder responses with optional live web context
 import torch
-from transformers import GPT2Tokenizer
+from transformers import AutoTokenizer
 from evo_model import EvoDecoderModel
+from search_utils import web_search
 
-# Set device
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-# Load GPT2 tokenizer (better for decoding tasks)
-tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
-tokenizer.pad_token = tokenizer.eos_token  # Safe fallback
+# Load tokenizer and model
+tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
 vocab_size = tokenizer.vocab_size
 
-# Load trained EvoDecoder model
 model = EvoDecoderModel(vocab_size=vocab_size).to(device)
 model.load_state_dict(torch.load("evo_decoder_model.pt", map_location=device))
 model.eval()
 
-def generate_response(prompt, max_length=100, top_k=40):
-    input_text = f"User: {prompt}\nAssistant:"
-    input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)
+def generate_response(prompt, use_web=False, max_length=100, top_k=40):
+    if use_web:
+        context = web_search(prompt)
+        prompt = f"Relevant Info: {context}\nUser: {prompt}\nAssistant:"
+    else:
+        prompt = f"User: {prompt}\nAssistant:"
+
+    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
 
     for _ in range(max_length):
         with torch.no_grad():
             logits = model(input_ids)
-        next_token_logits = logits[:, -1, :].squeeze(0)
-
-        # Apply repetition penalty
-        for token_id in set(input_ids.view(-1).tolist()):
-            next_token_logits[token_id] *= 0.8
-
-        # Top-k sampling
-        top_k_logits, top_k_indices = torch.topk(next_token_logits, k=top_k)
-        probs = torch.softmax(top_k_logits, dim=-1)
-        next_token = top_k_indices[torch.multinomial(probs, num_samples=1)].unsqueeze(0)
+        next_token_logits = logits[:, -1, :]
+        top_k_probs, top_k_indices = torch.topk(next_token_logits, top_k)
+        probs = torch.softmax(top_k_probs, dim=-1)
+        next_token = top_k_indices[0, torch.multinomial(probs, 1)]
 
-        input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)
+        next_token = next_token.unsqueeze(0).unsqueeze(0)  # (1,1)
+        input_ids = torch.cat([input_ids, next_token], dim=1)
 
-        # Stop on EOS
-        if tokenizer.eos_token_id and next_token.item() == tokenizer.eos_token_id:
+        if next_token.item() == tokenizer.eos_token_id:
             break
 
     output = tokenizer.decode(input_ids[0], skip_special_tokens=True)
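Two notes on the new sampling loop, outside the commit itself. With a single prompt, next_token_logits is (1, vocab_size), so torch.multinomial(probs, 1) returns a (1, 1) tensor; indexing top_k_indices[0, ...] then already yields a 2-D result, and the two extra unsqueeze calls leave a 4-D tensor that torch.cat cannot append to the 2-D input_ids. Also, bert-base-uncased defines no EOS token, so tokenizer.eos_token_id is None and the stop check never fires (and its vocabulary is 30522 tokens rather than GPT-2's 50257, so evo_decoder_model.pt must have been trained against the new vocabulary for load_state_dict to succeed). Below is a minimal sketch, not part of this commit, of the loop with explicit shapes and a guarded EOS check; it reuses the model, tokenizer, and device set up in generate.py.

# Sketch only (not in this commit): sampling loop with explicit shapes and a
# guarded EOS check. Assumes model, tokenizer, and device from generate.py.
def generate_response_sketch(prompt, max_length=100, top_k=40):
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)    # (1, seq_len)

    for _ in range(max_length):
        with torch.no_grad():
            logits = model(input_ids)                                       # (1, seq_len, vocab_size)
        next_token_logits = logits[:, -1, :]                                # (1, vocab_size)

        top_k_logits, top_k_indices = torch.topk(next_token_logits, top_k)  # each (1, top_k)
        probs = torch.softmax(top_k_logits, dim=-1)                         # (1, top_k)
        sampled = torch.multinomial(probs, num_samples=1)                   # (1, 1) position within the top-k slice
        next_token = top_k_indices.gather(-1, sampled)                      # (1, 1) vocabulary id

        input_ids = torch.cat([input_ids, next_token], dim=1)               # (1, seq_len + 1)

        # bert-base-uncased has no EOS token (eos_token_id is None), so guard the comparison.
        if tokenizer.eos_token_id is not None and next_token.item() == tokenizer.eos_token_id:
            break

    return tokenizer.decode(input_ids[0], skip_special_tokens=True)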
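The commit also adds a dependency on search_utils.web_search, whose implementation is not shown in this diff. The stub below is a hypothetical placeholder based only on the contract implied by generate.py: take the prompt string, return a short plain-text context snippet. The requests library and the DuckDuckGo Instant Answer endpoint are illustrative assumptions, not the repo's actual choice.

# search_utils.py — hypothetical sketch of the web_search helper imported by generate.py.
import requests

def web_search(query: str, max_chars: int = 500) -> str:
    try:
        resp = requests.get(
            "https://api.duckduckgo.com/",
            params={"q": query, "format": "json", "no_html": 1},
            timeout=5,
        )
        resp.raise_for_status()
        data = resp.json()
        # Prefer the abstract text; fall back to an empty string if nothing is returned.
        return (data.get("AbstractText") or "")[:max_chars]
    except requests.RequestException:
        return ""

With some such helper in place, the new flag is exercised as generate_response("your question", use_web=True), which prepends the retrieved snippet to the prompt before decoding.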