import torch
import torch.nn.functional as F
from evo_decoder import EvoDecoder
from transformers import GPT2Tokenizer

# ✅ Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ✅ Load tokenizer and model
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
model = EvoDecoder(
    vocab_size=tokenizer.vocab_size,
    d_model=256,
    nhead=4,
    num_layers=3,
    dim_feedforward=512
).to(device)

# ✅ Load trained weights
model.load_state_dict(torch.load("evo_decoder.pt", map_location=device))
model.eval()

# ✅ Response Generator
@torch.no_grad()
def generate_response(prompt, max_length=128, temperature=1.0, external_context=""):
    model.eval()

    # ✅ Force prompt into SQuAD-style format Evo was trained on
    if external_context:
        full_prompt = f"Context: {external_context}\nQuestion: {prompt}\nAnswer:"
    else:
        full_prompt = f"Question: {prompt}\nAnswer:"

    input_ids = tokenizer.encode(full_prompt, return_tensors="pt").to(device)
    prompt_len = input_ids.shape[1]  # remember where the prompt ends

    # Autoregressive sampling: append one token at a time
    for _ in range(max_length):
        logits = model(input_ids)                 # (batch, seq_len, vocab_size)
        logits = logits[:, -1, :] / temperature   # keep only the last position
        probs = F.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        input_ids = torch.cat((input_ids, next_token), dim=1)

        # Stop as soon as the model emits the end-of-sequence token
        if next_token.item() == tokenizer.eos_token_id:
            break

    # Decode only the newly generated tokens so the prompt is not echoed back
    output = tokenizer.decode(input_ids[0, prompt_len:], skip_special_tokens=True)
    return output.strip()
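
# --- Usage sketch (not part of the original file; names and values are
# illustrative assumptions) ------------------------------------------------
# Minimal example of calling generate_response with an external context
# passage, mirroring the SQuAD-style "Context / Question / Answer" format
# the prompt builder above produces.
if __name__ == "__main__":
    answer = generate_response(
        "What is the capital of France?",
        max_length=64,
        temperature=0.8,
        external_context="France is a country in Western Europe. Its capital is Paris.",
    )
    print(answer)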