HemanM committed
Commit 6648cc8 · verified · 1 Parent(s): 5829fd2

Update generate.py

Files changed (1)
  1. generate.py (+6, -10)
generate.py CHANGED
@@ -1,8 +1,7 @@
-# generate.py — Evo + RAG (Web Search Integration)
+# generate.py — Generates responses from EvoDecoderModel with Top-k sampling
 import torch
 from transformers import AutoTokenizer
 from evo_model import EvoDecoderModel
-from web_search import web_search  # 🔍 Import RAG utility
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
@@ -15,11 +14,7 @@ model.load_state_dict(torch.load("evo_decoder_model.pt", map_location=device))
 model.eval()
 
 def generate_response(prompt, max_length=100, top_k=40):
-    # 🔍 Step 1: Get web search context
-    context = web_search(prompt)
-
-    # Step 2: Prepend context to the prompt
-    input_text = f"[CONTEXT]\n{context}\n\nUser: {prompt}\nAssistant:"
+    input_text = f"User: {prompt}\nAssistant:"
     input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)
 
     for _ in range(max_length):
@@ -29,12 +24,13 @@ def generate_response(prompt, max_length=100, top_k=40):
 
         # Top-k sampling
         top_k_probs, top_k_indices = torch.topk(next_token_logits, top_k)
-        probs = torch.softmax(top_k_probs, dim=-1)
-        next_token = top_k_indices[torch.multinomial(probs, 1)]
+        probs = torch.softmax(top_k_probs.squeeze(0), dim=-1)  # Flatten
+        sampled_index = torch.multinomial(probs, 1).item()
+        next_token = top_k_indices[0, sampled_index]
 
         input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)
 
-        # Stop if EOS token is predicted
+        # Stop if EOS token
         if next_token.item() == tokenizer.eos_token_id:
             break
 
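For reference, the updated sampling step can be exercised on its own. The sketch below is illustrative and not part of the commit: it assumes next_token_logits has shape [1, vocab_size], as implied by the .squeeze(0) and [0, sampled_index] indexing in the diff, and it uses random logits with a made-up vocabulary size.

# Illustrative sketch of the top-k sampling step after this change (hypothetical values).
import torch

vocab_size, top_k = 32000, 40                    # assumed, not taken from the repository
next_token_logits = torch.randn(1, vocab_size)   # stand-in for the model's last-position logits

top_k_probs, top_k_indices = torch.topk(next_token_logits, top_k)   # both shaped [1, top_k]
probs = torch.softmax(top_k_probs.squeeze(0), dim=-1)               # [top_k] after flattening the batch dim
sampled_index = torch.multinomial(probs, 1).item()                  # position within the top-k set
next_token = top_k_indices[0, sampled_index]                        # map back to a vocabulary id
print(int(next_token))

Under that shape assumption, the previous version sampled with torch.multinomial on a [1, top_k] tensor and used the resulting [1, 1] index directly on top_k_indices, which does not select a single vocabulary id; the new lines flatten the distribution first and index back into the top-k set explicitly.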