HemanM committed on
Commit f7ab7a4 · verified · 1 Parent(s): 1070f67

Update generate.py

Files changed (1)
  1. generate.py +16 -29
generate.py CHANGED
@@ -1,35 +1,22 @@
- # generate.py
  import torch
- from transformers import AutoTokenizer
- from evo_model import EvoDecoderModel
-
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
- # Load tokenizer and model
- tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
- vocab_size = tokenizer.vocab_size
-
- model = EvoDecoderModel(vocab_size=vocab_size).to(device)
- model.load_state_dict(torch.load("evo_decoder_model.pt", map_location=device))
- model.eval()
-
- def generate_response(prompt, max_new_tokens=50, use_web=False):
-     inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=128)
-     input_ids = inputs["input_ids"].to(device)
-
-     for _ in range(max_new_tokens):
-         with torch.no_grad():
-             logits = model(input_ids)
-
-         next_token_logits = logits[:, -1, :]  # shape (B, vocab_size)
-         next_token_id = torch.argmax(next_token_logits, dim=-1).unsqueeze(0)  # shape (1, 1)
-
-         # Append to input
-         input_ids = torch.cat([input_ids, next_token_id], dim=1)
-
-         # Stop if EOS token
-         if next_token_id.item() in tokenizer.all_special_ids:
-             break
-
-     output_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
-     return output_text[len(prompt):].strip()
  import torch
+ from torch.nn import functional as F
+
+ def generate_text(model, tokenizer, prompt, max_length=100, temperature=1.0):
+     model.eval()
+     input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
+
+     generated = input_ids
+     memory = torch.zeros((1, input_ids.size(1), model.config.d_model)).to(model.device)
+
+     with torch.no_grad():
+         for _ in range(max_length):
+             outputs = model(generated, memory)
+             next_token_logits = outputs[:, -1, :] / temperature
+             probs = F.softmax(next_token_logits, dim=-1)
+             next_token = torch.multinomial(probs, num_samples=1)
+             generated = torch.cat((generated, next_token), dim=1)
+
+             if next_token.item() == tokenizer.eos_token_id:
+                 break
+
+     return tokenizer.decode(generated[0], skip_special_tokens=True)
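
The rewrite replaces the module-level greedy argmax loop with a standalone generate_text helper that samples each token from a temperature-scaled softmax via torch.multinomial and passes a zero-initialized memory tensor to the model on every step. Below is a minimal caller sketch, not part of the commit: it assumes the EvoDecoderModel, checkpoint path, and bert-base-uncased tokenizer from the previous revision are still in use, and that the model exposes .device and .config.d_model and accepts (input_ids, memory) in forward; none of this is confirmed by the diff itself.

# Usage sketch (hypothetical; see assumptions above).
import torch
from transformers import AutoTokenizer
from evo_model import EvoDecoderModel
from generate import generate_text

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = EvoDecoderModel(vocab_size=tokenizer.vocab_size).to(device)
model.load_state_dict(torch.load("evo_decoder_model.pt", map_location=device))

# Lower temperature -> sharper, more deterministic sampling.
# Note: bert-base-uncased defines no eos_token_id, so the early-stop check in
# generate_text would not fire and generation runs the full max_length steps.
print(generate_text(model, tokenizer, "Explain how decoder-only generation works.",
                    max_length=50, temperature=0.8))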