HemanM committed (verified)
Commit b28676e · 1 Parent(s): 32221da

Update generate.py

Files changed (1): generate.py (+20 -18)
generate.py CHANGED
@@ -1,29 +1,31 @@
 import torch
+import torch.nn.functional as F
+from evo_model import EvoDecoder
 from transformers import GPT2Tokenizer
-from evo_model import EvoDecoderModel
 
 # Load tokenizer
 tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
-tokenizer.pad_token = tokenizer.eos_token
 
-# Model configuration (must match evo_model.py and trained weights)
-vocab_size = tokenizer.vocab_size
-model = EvoDecoderModel(vocab_size=vocab_size, d_model=256, nhead=4, num_layers=3)
-model.load_state_dict(torch.load("evo_decoder.pt", map_location=torch.device("cpu")))
+# Load trained model
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+model = EvoDecoder(vocab_size=tokenizer.vocab_size, d_model=512, nhead=8, num_layers=6).to(device)
+model.load_state_dict(torch.load("evo_decoder.pt", map_location=device))
 model.eval()
 
-def generate_response(prompt, max_length=100):
-    input_ids = tokenizer.encode(prompt, return_tensors="pt")
-    generated = input_ids.clone()
+@torch.no_grad()
+def generate_response(prompt, max_length=50, temperature=1.0):
+    model.eval()
+    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
 
-    with torch.no_grad():
-        for _ in range(max_length):
-            output = model(generated, memory=None)
-            next_token_logits = output[:, -1, :]
-            next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(0)
-            generated = torch.cat((generated, next_token), dim=1)
+    for _ in range(max_length):
+        logits = model(input_ids)
+        logits = logits[:, -1, :] / temperature
+        probs = F.softmax(logits, dim=-1)
+        next_token = torch.multinomial(probs, num_samples=1)
+        input_ids = torch.cat((input_ids, next_token), dim=1)
 
-            if next_token.item() == tokenizer.eos_token_id:
-                break
+        if next_token.item() == tokenizer.eos_token_id:
+            break
 
-    return tokenizer.decode(generated[0], skip_special_tokens=True)
+    output = tokenizer.decode(input_ids.squeeze(), skip_special_tokens=True)
+    return output[len(prompt):].strip()
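In short, the commit swaps greedy argmax decoding for temperature-scaled multinomial sampling, moves inference onto the GPU when one is available, and strips the prompt from the decoded output. A minimal sketch of the new sampling step is below; the EvoDecoder model and checkpoint are replaced with a toy logits tensor, since only the decoding logic is being illustrated, and sample_next_token is a hypothetical helper name, not part of the repo.

import torch
import torch.nn.functional as F

def sample_next_token(logits, temperature=1.0):
    """Pick the next token id from a (1, vocab_size) row of logits.

    temperature < 1.0 sharpens the distribution (approaching the greedy
    argmax the old generate.py used); temperature > 1.0 flattens it.
    """
    if temperature <= 0:
        # Degenerate case: fall back to greedy decoding.
        return torch.argmax(logits, dim=-1, keepdim=True)
    probs = F.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1)  # shape (1, 1)

# Toy demonstration with a fake 5-token vocabulary.
logits = torch.tensor([[2.0, 1.0, 0.5, 0.1, -1.0]])
greedy = sample_next_token(logits, temperature=0)     # always token 0
sampled = sample_next_token(logits, temperature=1.0)  # stochastic
print(greedy.item(), sampled.item())

Assuming evo_model.py and the evo_decoder.pt checkpoint are present, the updated function would be called along the lines of generate_response("Hello, how are you?", max_length=50, temperature=0.8); lower temperature values make the output more deterministic.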