HemanM committed on
Commit 00ea8bb · verified · 1 Parent(s): 558d45c

Update generate.py

Files changed (1)
  1. generate.py +18 -35
generate.py CHANGED
@@ -1,47 +1,30 @@
- # generate.py — EvoDecoder response generation with optional DuckDuckGo RAG
+ # generate.py
  import torch
- from transformers import AutoTokenizer
+ from transformers import BertTokenizer
  from evo_model import EvoDecoderModel
- from search_utils import web_search

  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

- # Load tokenizer and model
- tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
- vocab_size = tokenizer.vocab_size
+ # Load tokenizer
+ tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

- model = EvoDecoderModel(vocab_size=vocab_size).to(device)
+ # Initialize model architecture
+ vocab_size = tokenizer.vocab_size
+ model = EvoDecoderModel(vocab_size=vocab_size)
  model.load_state_dict(torch.load("evo_decoder_model.pt", map_location=device))
+ model.to(device)
  model.eval()

- def generate_response(prompt, use_web=False, max_length=100, top_k=40):
-     # Augment with web context if enabled
-     context = ""
-     if use_web:
-         web_context = web_search(prompt)
-         context += f"Relevant Info: {web_context}\n"
-
-     input_text = context + f"User: {prompt}\nAssistant:"
-     input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)
-
-     # Truncate to avoid exceeding model's positional encoding limit
-     if input_ids.size(1) > 512:
-         input_ids = input_ids[:, -512:]
-
-     for _ in range(max_length):
-         with torch.no_grad():
-             logits = model(input_ids)
-             next_token_logits = logits[:, -1, :]
-
-             # Top-k sampling
-             top_k_probs, top_k_indices = torch.topk(next_token_logits, top_k)
-             probs = torch.softmax(top_k_probs, dim=-1)
-             next_token = top_k_indices[0, torch.multinomial(probs, 1).item()].unsqueeze(0).unsqueeze(0)
+ def generate_response(prompt, max_length=128, use_web=False):
+     with torch.no_grad():
+         input_ids = tokenizer(prompt, return_tensors="pt", padding=False, truncation=True).input_ids.to(device)
+         input_ids = input_ids[:, :128]  # ✅ clip to trained length

-         input_ids = torch.cat([input_ids, next_token], dim=1)
+         logits = model(input_ids)
+         next_token_logits = logits[:, -1, :]  # take last token's logits
+         predicted_id = torch.argmax(next_token_logits, dim=-1)

-         if next_token.item() == tokenizer.eos_token_id:
-             break
+         output_ids = torch.cat([input_ids, predicted_id.unsqueeze(0)], dim=1)
+         decoded = tokenizer.decode(output_ids[0], skip_special_tokens=True)

-     output = tokenizer.decode(input_ids[0], skip_special_tokens=True)
-     return output.split("Assistant:")[-1].strip()
+     return decoded
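A minimal usage sketch of the updated generate_response (assumptions: evo_model.py and the evo_decoder_model.pt checkpoint sit next to generate.py, and the script name and prompt string below are illustrative, not part of the commit). Note that the new version performs a single forward pass and greedily appends one argmax token, so the decoded output is essentially the prompt plus one predicted token:

# usage_example.py (hypothetical helper script, not part of this commit)
from generate import generate_response  # importing also loads the tokenizer and checkpoint

if __name__ == "__main__":
    prompt = "What does the EvoDecoder model do?"
    reply = generate_response(prompt)  # greedy: argmax over the last token's logits
    print(reply)  # decoded prompt followed by the single predicted token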