HemanM committed on
Commit 7dc4300 · verified · 1 Parent(s): 0bd71c9

Update generate.py

Files changed (1)
  1. generate.py +21 -15
generate.py CHANGED
@@ -1,22 +1,28 @@
 import torch
-from torch.nn import functional as F
+import torch.nn.functional as F
+from evo_model import EvoDecoderModel
 
-def generate_text(model, tokenizer, prompt, max_length=100, temperature=1.0):
+def load_model(vocab_size, model_path="evo_decoder.pt", device="cpu"):
+    model = EvoDecoderModel(vocab_size)
+    model.load_state_dict(torch.load(model_path, map_location=device))
+    model.to(device)
     model.eval()
-    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
+    return model
 
-    generated = input_ids
-    memory = torch.zeros((1, input_ids.size(1), model.config.d_model)).to(model.device)
+def generate_text(model, tokenizer, prompt, max_new_tokens=50, temperature=1.0, device="cpu"):
+    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
 
-    with torch.no_grad():
-        for _ in range(max_length):
-            outputs = model(generated, memory)
-            next_token_logits = outputs[:, -1, :] / temperature
-            probs = F.softmax(next_token_logits, dim=-1)
-            next_token = torch.multinomial(probs, num_samples=1)
-            generated = torch.cat((generated, next_token), dim=1)
+    for _ in range(max_new_tokens):
+        with torch.no_grad():
+            logits = model(input_ids)
+            next_token_logits = logits[:, -1, :] / temperature
+            probs = F.softmax(next_token_logits, dim=-1)
+            next_token_id = torch.multinomial(probs, num_samples=1)
 
-            if next_token.item() == tokenizer.eos_token_id:
-                break
+        input_ids = torch.cat([input_ids, next_token_id], dim=-1)
 
-    return tokenizer.decode(generated[0], skip_special_tokens=True)
+        if next_token_id.item() == tokenizer.eos_token_id:
+            break
+
+    output_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
+    return output_text
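
Usage sketch (not part of the commit): wiring the new load_model / generate_text pair together end to end. generate.py never names a tokenizer, so the Hugging Face "gpt2" tokenizer below is an assumption — any tokenizer exposing encode, decode, vocab_size, and eos_token_id that matches the checkpoint's training vocabulary would work in its place.

import torch
from transformers import AutoTokenizer  # assumed tokenizer source; not pinned by this repo

from generate import load_model, generate_text

# "gpt2" is a placeholder; the real checkpoint must share this vocabulary.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
device = "cuda" if torch.cuda.is_available() else "cpu"

# Loads evo_decoder.pt (the default model_path) and puts the model in eval mode.
model = load_model(vocab_size=tokenizer.vocab_size, device=device)
text = generate_text(model, tokenizer, "Once upon a time",
                     max_new_tokens=50, temperature=0.8, device=device)
print(text)

Temperature acts as the usual softmax scaler here: values below 1.0 sharpen the distribution toward greedy decoding, while values above 1.0 flatten it toward uniform sampling.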