HemanM committed on
Commit
3d17dd0
·
verified ·
1 Parent(s): a9b4cfb

Update generate.py

Browse files
Files changed (1) hide show
  1. generate.py +8 -4
generate.py CHANGED
@@ -3,18 +3,22 @@ import torch.nn.functional as F
3
  from evo_model import EvoDecoder
4
  from transformers import GPT2Tokenizer
5
 
6
- # Load tokenizer
7
  tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
8
 
9
- # Load trained model
10
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
- model = EvoDecoder(vocab_size=tokenizer.vocab_size, d_model=512, nhead=8, num_layers=6).to(device)
 
 
 
 
 
 
 
12
  model.load_state_dict(torch.load("evo_decoder.pt", map_location=device))
13
  model.eval()
14
 
15
  @torch.no_grad()
16
  def generate_response(prompt, max_length=50, temperature=1.0):
17
- model.eval()
18
  input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
19
 
20
  for _ in range(max_length):
 
3
  from evo_model import EvoDecoder
4
  from transformers import GPT2Tokenizer
5
 
 
6
# Tokenizer: reuse GPT-2's BPE vocabulary so token ids line up with the
# ids the decoder was trained on.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# Run on the GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the decoder and move it to the chosen device.
# NOTE(review): these hyperparameters (d_model=256, nhead=4, num_layers=3,
# dim_feedforward=1024) must match the configuration used to train
# "evo_decoder.pt", or load_state_dict below will fail — confirm.
model = EvoDecoder(
    vocab_size=tokenizer.vocab_size,
    d_model=256,
    nhead=4,
    num_layers=3,
    dim_feedforward=1024,
).to(device)

# Restore the trained weights (remapped onto the active device) and switch
# the module to inference mode.
checkpoint = torch.load("evo_decoder.pt", map_location=device)
model.load_state_dict(checkpoint)
model.eval()
19
 
20
  @torch.no_grad()
21
  def generate_response(prompt, max_length=50, temperature=1.0):
 
22
  input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
23
 
24
  for _ in range(max_length):