HemanM committed
Commit ca6258b · verified · 1 Parent(s): 5a8590e

Update generate.py

Files changed (1)
  1. generate.py +12 -31
generate.py CHANGED
@@ -1,14 +1,12 @@
-# generate.py — Generates responses from EvoDecoder with optional web-based RAG
-
+# generate.py — EvoDecoder response generation with optional DuckDuckGo RAG
 import torch
 from transformers import AutoTokenizer
 from evo_model import EvoDecoderModel
-from search_utils import web_search  # Make sure this file exists with a working `web_search()` function
+from search_utils import web_search
 
-# Set device
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-# Load tokenizer and EvoDecoder model
+# Load tokenizer and model
 tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
 vocab_size = tokenizer.vocab_size
 
@@ -16,47 +14,30 @@ model = EvoDecoderModel(vocab_size=vocab_size).to(device)
 model.load_state_dict(torch.load("evo_decoder_model.pt", map_location=device))
 model.eval()
 
-def generate_response(prompt, max_length=100, top_k=40, use_web=False):
-    """
-    Generates a response using EvoDecoder with optional web-enhanced context (RAG).
-
-    Args:
-        prompt (str): User input prompt.
-        max_length (int): Maximum number of tokens to generate.
-        top_k (int): Top-k sampling for diversity.
-        use_web (bool): Whether to augment prompt using live search.
-
-    Returns:
-        str: The generated assistant response.
-    """
-
-    # Add RAG-based context if enabled
+def generate_response(prompt, use_web=False, max_length=100, top_k=40):
+    # Augment with web context if enabled
+    context = ""
     if use_web:
         web_context = web_search(prompt)
-        input_text = f"User: {prompt}\n\nContext: {web_context}\n\nAssistant:"
-    else:
-        input_text = f"User: {prompt}\nAssistant:"
+        context += f"Relevant Info: {web_context}\n"
 
-    # Tokenize input prompt
+    input_text = context + f"User: {prompt}\nAssistant:"
    input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)
 
-    # Generate tokens autoregressively
     for _ in range(max_length):
         with torch.no_grad():
             logits = model(input_ids)
             next_token_logits = logits[:, -1, :]
+
+        # Top-k sampling
         top_k_probs, top_k_indices = torch.topk(next_token_logits, top_k)
         probs = torch.softmax(top_k_probs, dim=-1)
-        next_token = top_k_indices[0, torch.multinomial(probs, 1)]
+        next_token = top_k_indices[0, torch.multinomial(probs, 1).item()].unsqueeze(0).unsqueeze(0)
 
-        next_token = next_token.unsqueeze(0).unsqueeze(1)  # (1, 1)
         input_ids = torch.cat([input_ids, next_token], dim=1)
 
         if next_token.item() == tokenizer.eos_token_id:
             break
 
-    # Decode and return assistant's response only
     output = tokenizer.decode(input_ids[0], skip_special_tokens=True)
-    if "Assistant:" in output:
-        return output.split("Assistant:")[-1].strip()
-    return output.strip()
+    return output.split("Assistant:")[-1].strip()
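
The functional change in this commit is how next_token is assembled during top-k sampling. A short shape walk-through of the new line, assuming batch size 1 (the vocab size and k values below are illustrative):

import torch

logits = torch.randn(1, 30522)                        # (batch=1, vocab), e.g. BERT's vocab size
top_k_probs, top_k_indices = torch.topk(logits, 40)   # both (1, 40)
probs = torch.softmax(top_k_probs, dim=-1)            # (1, 40), rows sum to 1

idx = torch.multinomial(probs, 1).item()              # python int in [0, 40)
next_token = top_k_indices[0, idx].unsqueeze(0).unsqueeze(0)  # scalar -> (1,) -> (1, 1)
assert next_token.shape == (1, 1)                     # ready for torch.cat along dim=1

Indexing with a plain int yields a scalar tensor directly, which removes the extra dimension the old version had to squeeze into shape in a second step.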
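
The diff imports web_search from search_utils, but that file is not part of this commit. A minimal sketch of what it might contain, assuming the duckduckgo_search package (consistent with the "DuckDuckGo RAG" note in the new file header); the repo's actual implementation may differ:

# search_utils.py (hypothetical sketch, not the repo's actual file)
# Assumes: pip install duckduckgo-search
from duckduckgo_search import DDGS

def web_search(query, max_results=3):
    # Join the snippet text of the top results into one context string.
    with DDGS() as ddgs:
        results = ddgs.text(query, max_results=max_results)
    return " ".join(r.get("body", "") for r in results)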
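
For reference, one possible way to call the updated function (the prompt is illustrative, and this assumes evo_decoder_model.pt and search_utils.py sit next to generate.py):

if __name__ == "__main__":
    question = "What is an evolutionary transformer?"  # hypothetical prompt
    print(generate_response(question))                 # plain generation
    print(generate_response(question, use_web=True))   # DuckDuckGo-augmented (RAG) generation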