HemanM committed
Commit b37c655 · verified · 1 Parent(s): 5eb2122

Update generate.py

Files changed (1):
  generate.py  +12 -23

generate.py CHANGED
@@ -1,14 +1,14 @@
 import torch
 import torch.nn.functional as F
-from evo_decoder import EvoDecoder
 from transformers import GPT2Tokenizer
+from evo_decoder import EvoDecoder
+from search_utils import web_search  # Optional RAG fallback
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-# Load tokenizer and model
 tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
 tokenizer.pad_token = tokenizer.eos_token
 
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
 model = EvoDecoder(
     vocab_size=tokenizer.vocab_size,
     d_model=256,
@@ -21,17 +21,12 @@ model.load_state_dict(torch.load("evo_decoder.pt", map_location=device))
 model.eval()
 
 @torch.no_grad()
-def generate_response(prompt, max_length=50, temperature=1.0, external_context=""):
-    """
-    Generate text using EvoDecoder with optional external RAG context.
-    """
-    model.eval()
-
-    # Combine external context with prompt
-    full_prompt = (external_context.strip() + "\n\n" + prompt.strip()) if external_context else prompt.strip()
-
-    # Truncate if input too long
-    input_ids = tokenizer.encode(full_prompt, return_tensors="pt", truncation=True, max_length=256).to(device)
+def generate_response(question, context="", use_rag=False, temperature=1.0, max_length=100):
+    if use_rag and not context:
+        context = web_search(question)
+
+    prompt = f"Context: {context}\nQuestion: {question}\nAnswer:"
+    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
 
     for _ in range(max_length):
         logits = model(input_ids)
@@ -39,14 +34,8 @@ def generate_response(prompt, max_length=50, temperature=1.0, external_context="
         probs = F.softmax(logits, dim=-1)
         next_token = torch.multinomial(probs, num_samples=1)
         input_ids = torch.cat((input_ids, next_token), dim=1)
-
-        # Break on EOS
        if next_token.item() == tokenizer.eos_token_id:
            break
 
-        # Prevent overflow
-        if input_ids.shape[1] >= 256:
-            break
-
-    output = tokenizer.decode(input_ids.squeeze(), skip_special_tokens=True)
-    return output[len(full_prompt):].strip()
+    output = tokenizer.decode(input_ids[0], skip_special_tokens=True)
+    return output[len(prompt):].strip()
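
The commit imports web_search from search_utils, but that module is not part of this change; the only contract visible here is "question in, context string out". A minimal placeholder sketch under that assumption (the real retrieval logic is not shown in this repository):

# search_utils.py -- hypothetical stub matching the import in generate.py.
# This commit does not include the module; a real version would query a
# search API and return concatenated snippets as the context string.
def web_search(query: str) -> str:
    """Return retrieved context for `query` as a single string."""
    # Placeholder: no network call, so generate.py stays runnable offline.
    return ""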
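The new signature keeps a temperature parameter, but the line where it is applied (old line 38 / new line 33) sits in the collapsed context between the last two hunks. A self-contained sketch of the usual form of that step, assuming the model returns logits of shape (batch, seq_len, vocab_size); the function name is illustrative, not from the commit:

import torch
import torch.nn.functional as F

def sample_next_token(logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
    # Assumed shape of the hidden step: keep only the last position's logits
    # and divide by temperature before the softmax/multinomial shown in the diff.
    scaled = logits[:, -1, :] / temperature         # (batch, vocab_size)
    probs = F.softmax(scaled, dim=-1)               # normalize to probabilities
    return torch.multinomial(probs, num_samples=1)  # one sampled id per batch row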
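A short usage sketch for the updated entry point; the question text and context string are illustrative, and the last call assumes search_utils is importable:

# No retrieval: answers from the prompt template alone.
print(generate_response("What is a transformer?", temperature=0.8))

# Caller-supplied context, retrieval skipped.
print(generate_response(
    "What is a transformer?",
    context="A transformer is a neural network architecture built on attention.",
))

# Empty context plus use_rag=True triggers the web_search fallback.
print(generate_response("What is a transformer?", use_rag=True))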