HemanM committed on
Commit d7a4aba · verified · 1 Parent(s): 5259900

Update generate.py

Files changed (1)
generate.py +17 -10
generate.py CHANGED
@@ -1,24 +1,31 @@
 # generate.py
 import torch
-from transformers import BertTokenizer
+from transformers import AutoTokenizer
 from evo_model import EvoDecoderModel
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
+# Load tokenizer and model
+tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
 vocab_size = tokenizer.vocab_size
 
-model = EvoDecoderModel(vocab_size=vocab_size)
+model = EvoDecoderModel(vocab_size=vocab_size).to(device)
 model.load_state_dict(torch.load("evo_decoder_model.pt", map_location=device))
-model.to(device)
 model.eval()
 
-def generate_response(prompt, max_length=128, use_web=False):
+def generate_response(prompt, use_web=False):
+    # Tokenize
+    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=128)
+    input_ids = inputs["input_ids"].to(device)
+
+    # Predict
     with torch.no_grad():
-        input_ids = tokenizer(prompt, return_tensors="pt").input_ids[:, :128].to(device)
         logits = model(input_ids)
-        next_token_logits = logits[:, -1, :]
-        next_token_id = torch.argmax(next_token_logits, dim=-1)
 
-        full_ids = torch.cat([input_ids, next_token_id.unsqueeze(0)], dim=1)
-        return tokenizer.decode(full_ids[0], skip_special_tokens=True)
+    # Take last token's logits and get predicted token
+    next_token_logits = logits[0, -1]  # shape: (vocab_size,)
+    predicted_token_id = torch.argmax(next_token_logits).item()
+
+    # Decode to word
+    predicted_token = tokenizer.decode([predicted_token_id])
+    return predicted_token
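
After this change, generate_response performs a single greedy step: it argmaxes the last position's logits and returns only the one decoded token predicted after the prompt, rather than decoding the prompt plus an appended token as before. A minimal usage sketch, assuming this repo's generate.py is importable and evo_decoder_model.pt sits next to it (the prompt string is illustrative):

# Hypothetical caller; importing generate.py loads the tokenizer,
# model weights, and device at module import time.
from generate import generate_response

prompt = "The quick brown fox"
next_token = generate_response(prompt)  # one decoded token (word or subword piece)
print(next_token)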