import torch
import torch.nn.functional as F
from evo_decoder import EvoDecoder
from transformers import GPT2Tokenizer
# ✅ Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ✅ Load tokenizer and model
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
model = EvoDecoder(
    vocab_size=tokenizer.vocab_size,
    d_model=256,
    nhead=4,
    num_layers=3,
    dim_feedforward=512
).to(device)
# ✅ Load trained weights
model.load_state_dict(torch.load("evo_decoder.pt", map_location=device))
model.eval()
# ✅ Response Generator
@torch.no_grad()
def generate_response(prompt, max_length=128, temperature=1.0, external_context=""):
    model.eval()
    # ✅ Force prompt into SQuAD-style format Evo was trained on
    if external_context:
        full_prompt = f"Context: {external_context}\nQuestion: {prompt}\nAnswer:"
    else:
        full_prompt = f"Question: {prompt}\nAnswer:"
    input_ids = tokenizer.encode(full_prompt, return_tensors="pt").to(device)
    for _ in range(max_length):
        logits = model(input_ids)
        logits = logits[:, -1, :] / temperature
        probs = F.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        input_ids = torch.cat((input_ids, next_token), dim=1)
        if next_token.item() == tokenizer.eos_token_id:
            break
    output = tokenizer.decode(input_ids.squeeze(), skip_special_tokens=True)
    return output[len(full_prompt):].strip()
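
# Example usage (a minimal sketch; assumes the "evo_decoder.pt" checkpoint loaded
# above exists and that EvoDecoder returns per-token logits of shape
# [batch, seq_len, vocab_size], as the generation loop implies).
# The context and question strings below are illustrative only.
if __name__ == "__main__":
    context = "The Eiffel Tower is located in Paris and was completed in 1889."
    question = "Where is the Eiffel Tower located?"
    answer = generate_response(
        question,
        max_length=64,
        temperature=0.8,
        external_context=context,
    )
    print(f"Q: {question}\nA: {answer}")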