# EvoTransformer-v2.1 / inference.py
# (Hugging Face page header preserved from scrape: HemanM — "Update inference.py",
#  commit 854864a verified, raw / history / blame, 1.15 kB)
import torch
from transformers import AutoTokenizer
from evo_model import EvoTransformerForClassification
# Load tokenizer and model once at import time so generate_response() can reuse them.
# Tokenizer: standard BERT uncased vocabulary pulled from the Hugging Face hub.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# Model: project-local EvoTransformer classifier loaded from the "trained_model"
# directory — presumably a fine-tuned checkpoint on disk; confirm the path exists.
model = EvoTransformerForClassification.from_pretrained("trained_model")
# Switch to inference mode (disables dropout etc.) for deterministic scoring.
model.eval()
def generate_response(goal, option1, option2):
    """Suggest which of two candidate options better matches the goal.

    Each option is scored independently: the classifier is run on the prompt
    "Goal: <goal>\nOption: <option>" and the probability assigned to class
    index 1 is taken as the option's alignment score.

    Args:
        goal: Natural-language goal statement.
        option1: First candidate option.
        option2: Second candidate option.

    Returns:
        A human-readable suggestion string naming the higher-scoring option,
        or a tie message when both scores are exactly equal.
    """

    def _aligned_prob(option):
        # Score one prompt; returns P(class 1) as a Python float.
        text = f"Goal: {goal}\nOption: {option}"
        encoded = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
        with torch.no_grad():
            logits = model(**encoded)["logits"]
        return torch.softmax(logits, dim=1)[0][1].item()

    score1 = _aligned_prob(option1)
    score2 = _aligned_prob(option2)

    if score1 > score2:
        return "βœ… Option 1 is more aligned with the goal."
    if score2 > score1:
        return "βœ… Option 2 is more aligned with the goal."
    # Exact float tie — rare in practice, preserved from the original contract.
    return "βš–οΈ Both options are equally likely."