File size: 912 Bytes
2e57b59
 
fd77f55
2e57b59
67209e5
2e57b59
 
67209e5
 
fd77f55
67209e5
2e57b59
 
 
 
 
 
 
fd77f55
2e57b59
67209e5
2e57b59
67209e5
2e57b59
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torch
from init_model import get_tokenizer, get_base_model
from model import EvoTransformerArabic

# Load the tokenizer and the base encoder (Arabic BERT) used to embed inputs.
tokenizer = get_tokenizer()
bert = get_base_model()
# Put the encoder in inference mode too, matching model.eval() below. If
# get_base_model() hands back a module still in training mode, any active
# dropout would make evo_suggest's scores change from call to call.
bert.eval()

# Load the Evo model (it scores BERT [CLS] embeddings) and its trained weights
# onto the CPU.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
# On torch >= 1.13, weights_only=True would restrict loading to tensors.
model = EvoTransformerArabic()
model.load_state_dict(torch.load("trained_model.pt", map_location=torch.device("cpu")))
model.eval()

def evo_suggest(question, option1, option2):
    """Return whichever of two answer options the Evo model prefers.

    Each option is appended to the question (separated by one space). The
    combined text is encoded with the module-level tokenizer and BERT encoder,
    and the [CLS] embedding is scored by the Evo model. The option with the
    higher class-1 probability wins.

    Args:
        question: Question text.
        option1: First candidate answer.
        option2: Second candidate answer.

    Returns:
        ``option1`` if its class-1 probability is strictly greater than
        ``option2``'s, otherwise ``option2`` (ties go to ``option2``).
    """
    inputs = [question + " " + option1, question + " " + option2]
    scores = []

    with torch.no_grad():
        for text in inputs:
            # NOTE(review): the option comes *after* the question. With the
            # tokenizer's default right-side truncation, a question longer than
            # ~max_length tokens cuts the option off entirely. Both candidates
            # then encode identically and the result is a tie (-> option2).
            encoded = tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=128)
            cls_embedding = bert(**encoded).last_hidden_state[:, 0, :]  # [CLS] token
            logits = model(cls_embedding)
            # Compare normalised probabilities, not the bare class-1 logit. Logits
            # are only meaningful relative to one another within a single example
            # (softmax is shift-invariant), so raw values from two separate
            # forward passes are not comparable as "confidence".
            probs = torch.softmax(logits, dim=-1)
            scores.append(probs[0, 1].item())

    return option1 if scores[0] > scores[1] else option2