Update generate.py
generate.py  CHANGED  (+6 -10)
@@ -1,8 +1,7 @@
-# generate.py —
+# generate.py — Generates responses from EvoDecoderModel with Top-k sampling
 import torch
 from transformers import AutoTokenizer
 from evo_model import EvoDecoderModel
-from web_search import web_search  # 🔍 Import RAG utility
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
@@ -15,11 +14,7 @@ model.load_state_dict(torch.load("evo_decoder_model.pt", map_location=device))
 model.eval()
 
 def generate_response(prompt, max_length=100, top_k=40):
-
-    context = web_search(prompt)
-
-    # Step 2: Prepend context to the prompt
-    input_text = f"[CONTEXT]\n{context}\n\nUser: {prompt}\nAssistant:"
+    input_text = f"User: {prompt}\nAssistant:"
     input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)
 
     for _ in range(max_length):
@@ -29,12 +24,13 @@ def generate_response(prompt, max_length=100, top_k=40):
 
         # Top-k sampling
        top_k_probs, top_k_indices = torch.topk(next_token_logits, top_k)
-        probs = torch.softmax(top_k_probs, dim=-1)
-
+        probs = torch.softmax(top_k_probs.squeeze(0), dim=-1)  # Flatten
+        sampled_index = torch.multinomial(probs, 1).item()
+        next_token = top_k_indices[0, sampled_index]
 
         input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)
 
-        # Stop if EOS token
+        # Stop if EOS token
         if next_token.item() == tokenizer.eos_token_id:
             break
 
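For reference, the added lines follow the standard top-k sampling recipe: keep the top_k highest logits, renormalize them with a softmax, draw one position with torch.multinomial, and map that position back to a vocabulary id. Below is a minimal, self-contained sketch of that single step; the sample_top_k helper and the toy logits are illustrative only, not part of generate.py, and a batch size of 1 is assumed.

import torch

def sample_top_k(next_token_logits: torch.Tensor, top_k: int = 40) -> int:
    """Sample one token id from the top_k highest-scoring logits.

    Assumes next_token_logits has shape (1, vocab_size).
    """
    top_k_logits, top_k_indices = torch.topk(next_token_logits, top_k)  # both (1, top_k)
    probs = torch.softmax(top_k_logits.squeeze(0), dim=-1)              # (top_k,), sums to 1
    sampled_index = torch.multinomial(probs, 1).item()                  # position within the top-k set
    return top_k_indices[0, sampled_index].item()                       # vocabulary id

# Toy usage with random logits over a 1000-token vocabulary
logits = torch.randn(1, 1000)
print(sample_top_k(logits, top_k=40))

Squeezing the top-k logits down to a 1-D probability vector before torch.multinomial keeps the sampled index a plain scalar, which makes the lookup back into top_k_indices straightforward.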