HemanM committed
Commit f718bd4 · verified · 1 Parent(s): 64f4483

Update generate.py

Files changed (1)
  1. generate.py +6 -5
generate.py CHANGED
```diff
@@ -2,10 +2,10 @@ import torch
 import torch.nn.functional as F
 from transformers import GPT2Tokenizer
 from evo_decoder import EvoDecoder
-from search_utils import web_search  # Optional RAG fallback
+from search_utils import web_search
 
+# 🔧 Load model and tokenizer
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
 tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
 tokenizer.pad_token = tokenizer.eos_token
 
@@ -21,19 +21,20 @@ model.load_state_dict(torch.load("evo_decoder.pt", map_location=device))
 model.eval()
 
 @torch.no_grad()
-def generate_response(question, context="", use_rag=False, temperature=1.0, max_length=100):
-    if use_rag and not context:
+def generate_response(question, context="", use_rag=False, temperature=1.0):
+    if not context and use_rag:
         context = web_search(question)
 
     prompt = f"Context: {context}\nQuestion: {question}\nAnswer:"
     input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
 
-    for _ in range(max_length):
+    for _ in range(128):
         logits = model(input_ids)
         logits = logits[:, -1, :] / temperature
         probs = F.softmax(logits, dim=-1)
         next_token = torch.multinomial(probs, num_samples=1)
         input_ids = torch.cat((input_ids, next_token), dim=1)
+
         if next_token.item() == tokenizer.eos_token_id:
             break
 
```
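For readers reconstructing the behavior from the hunks: `generate_response` is a plain temperature-sampling loop. Each step divides the last-position logits by `temperature`, softmaxes them into a probability distribution, draws one token with `torch.multinomial`, appends it to the sequence, and stops at end-of-text or after 128 new tokens. Below is a minimal, self-contained sketch of the same loop. It is an illustration, not this repo's code: it substitutes the public `gpt2` checkpoint for `EvoDecoder` (whose `evo_decoder.pt` weights and `search_utils.web_search` helper are repo-specific), and it keeps a `max_new_tokens` parameter where the commit hardcodes `range(128)`.

```python
import torch
import torch.nn.functional as F
from transformers import GPT2LMHeadModel, GPT2Tokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
# Stand-in for EvoDecoder: any causal LM exposing (batch, seq, vocab) logits.
# Note HF models return an output object, hence `.logits` below, whereas the
# repo's EvoDecoder appears to return the logits tensor directly.
model = GPT2LMHeadModel.from_pretrained("gpt2").to(device).eval()

@torch.no_grad()
def sample(prompt, temperature=1.0, max_new_tokens=128):
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    for _ in range(max_new_tokens):
        logits = model(input_ids).logits         # (1, seq_len, vocab_size)
        logits = logits[:, -1, :] / temperature  # last position, rescaled
        probs = F.softmax(logits, dim=-1)        # logits -> sampling distribution
        next_token = torch.multinomial(probs, num_samples=1)
        input_ids = torch.cat((input_ids, next_token), dim=1)
        if next_token.item() == tokenizer.eos_token_id:
            break                                # stop at end-of-text
    return tokenizer.decode(input_ids[0], skip_special_tokens=True)

print(sample("Context: \nQuestion: What does temperature control?\nAnswer:", 0.8))
```

One trade-off worth noting: dropping the old `max_length` argument in favor of a hardcoded `range(128)` removes the caller's per-call cap on generation length; the sketch keeps that knob to show the alternative.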