HemanM committed
Commit e80297e · verified · 1 Parent(s): a8e16cb

Update generate.py

Files changed (1)
  1. generate.py +18 -12
generate.py CHANGED
@@ -1,14 +1,14 @@
 import torch
 import torch.nn.functional as F
-from transformers import GPT2Tokenizer
 from evo_decoder import EvoDecoder
-from search_utils import web_search
+from transformers import GPT2Tokenizer
 
-# 🔧 Load model and tokenizer
+# Device
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+# ✅ Load tokenizer and model
 tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
 tokenizer.pad_token = tokenizer.eos_token
-
 model = EvoDecoder(
     vocab_size=tokenizer.vocab_size,
     d_model=256,
@@ -17,18 +17,24 @@ model = EvoDecoder(
     dim_feedforward=512
 ).to(device)
 
+# ✅ Load trained weights
 model.load_state_dict(torch.load("evo_decoder.pt", map_location=device))
 model.eval()
 
+# ✅ Response Generator
 @torch.no_grad()
-def generate_response(question, context="", use_rag=False, temperature=1.0):
-    if not context and use_rag:
-        context = web_search(question)
+def generate_response(prompt, max_length=128, temperature=1.0, external_context=""):
+    model.eval()
+
+    # ✅ Force prompt into SQuAD-style format Evo was trained on
+    if external_context:
+        full_prompt = f"Context: {external_context}\nQuestion: {prompt}\nAnswer:"
+    else:
+        full_prompt = f"Question: {prompt}\nAnswer:"
 
-    prompt = f"Context: {context}\nQuestion: {question}\nAnswer:"
-    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
+    input_ids = tokenizer.encode(full_prompt, return_tensors="pt").to(device)
 
-    for _ in range(128):
+    for _ in range(max_length):
         logits = model(input_ids)
         logits = logits[:, -1, :] / temperature
         probs = F.softmax(logits, dim=-1)
@@ -38,5 +44,5 @@ def generate_response(question, context="", use_rag=False, temperature=1.0):
         if next_token.item() == tokenizer.eos_token_id:
             break
 
-    output = tokenizer.decode(input_ids[0], skip_special_tokens=True)
-    return output[len(prompt):].strip()
+    output = tokenizer.decode(input_ids.squeeze(), skip_special_tokens=True)
+    return output[len(full_prompt):].strip()
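Note: the diff's hunk context elides the sampling step between the softmax and the EOS check (file lines 35-37 on the old side, 41-43 on the new). A minimal sketch of what that step presumably contains, assuming standard multinomial sampling; this is a hypothetical reconstruction for readability, not part of the commit:

        # Hypothetical reconstruction of the elided sampling step (assumes multinomial sampling):
        next_token = torch.multinomial(probs, num_samples=1)    # draw one token id from the distribution
        input_ids = torch.cat([input_ids, next_token], dim=1)   # append it so the next step sees the longer sequence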
 
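With this change, callers pass the raw question as prompt and supply any retrieved passage through external_context, instead of relying on the removed web_search hook. A minimal usage sketch, assuming evo_decoder.pt is on disk and this file is importable as generate:

from generate import generate_response

# No context: the prompt is wrapped as "Question: ...\nAnswer:"
print(generate_response("What is the capital of France?"))

# With retrieved context: wrapped as "Context: ...\nQuestion: ...\nAnswer:"
passage = "Paris is the capital and most populous city of France."
print(generate_response("What is the capital of France?", external_context=passage, temperature=0.8))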