HemanM committed on
Commit ccae0a9 · verified · 1 Parent(s): e7984f7

Update generate.py

Files changed (1)
  generate.py +21 -20
generate.py CHANGED
@@ -1,28 +1,29 @@
  import torch
- import torch.nn.functional as F
+ from transformers import GPT2Tokenizer
  from evo_model import EvoDecoderModel

- def load_model(vocab_size, model_path="evo_decoder.pt", device="cpu"):
-     model = EvoDecoderModel(vocab_size)
-     model.load_state_dict(torch.load(model_path, map_location=device))
-     model.to(device)
-     model.eval()
-     return model
-
- def generate_text(model, tokenizer, prompt, max_new_tokens=50, temperature=1.0, device="cpu"):
-     input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
-
-     for _ in range(max_new_tokens):
-         with torch.no_grad():
-             logits = model(input_ids)
-             next_token_logits = logits[:, -1, :] / temperature
-             probs = F.softmax(next_token_logits, dim=-1)
-             next_token_id = torch.multinomial(probs, num_samples=1)
-
-         input_ids = torch.cat([input_ids, next_token_id], dim=-1)
-
-         if next_token_id.item() == tokenizer.eos_token_id:
-             break
-
-     output_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
-     return output_text
+ # Load tokenizer
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+ tokenizer.pad_token = tokenizer.eos_token
+
+ # Model configuration (must match evo_model.py and trained weights)
+ vocab_size = tokenizer.vocab_size
+ model = EvoDecoderModel(vocab_size=vocab_size, d_model=256, nhead=4, num_layers=3)
+ model.load_state_dict(torch.load("evo_decoder.pt", map_location=torch.device("cpu")))
+ model.eval()
+
+ def generate_response(prompt, max_length=100):
+     input_ids = tokenizer.encode(prompt, return_tensors="pt")
+     generated = input_ids.clone()
+
+     with torch.no_grad():
+         for _ in range(max_length):
+             output = model(generated, memory=None)
+             next_token_logits = output[:, -1, :]
+             next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(0)
+             generated = torch.cat((generated, next_token), dim=1)
+
+             if next_token.item() == tokenizer.eos_token_id:
+                 break
+
+     return tokenizer.decode(generated[0], skip_special_tokens=True)
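
A minimal usage sketch of the updated script, assuming evo_model.py and the trained evo_decoder.pt checkpoint sit alongside generate.py; the __main__ guard and the example prompt below are illustrative additions, not part of the commit:

if __name__ == "__main__":
    # Hypothetical prompt; generate_response() greedily decodes (argmax) up to max_length new tokens
    prompt = "The theory of evolution states that"
    print(generate_response(prompt, max_length=50))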