HemanM commited on
Commit
d590322
·
verified ·
1 Parent(s): 81ef8c2

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +14 -12
inference.py CHANGED
@@ -1,6 +1,8 @@
1
  import torch
2
- from transformers import AutoTokenizer, OpenAIGPTLMHeadModel
3
  from evo_model import EvoTransformerV22
 
 
4
 
5
  # Load Evo model
6
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -12,7 +14,7 @@ evo_model.eval()
12
  # Load tokenizer
13
  tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
14
 
15
- # 🧠 Evo logic
16
  def get_evo_response(query, context):
17
  combined = query + " " + context
18
  inputs = tokenizer(combined, return_tensors="pt", truncation=True, padding="max_length", max_length=128)
@@ -20,24 +22,24 @@ def get_evo_response(query, context):
20
 
21
  with torch.no_grad():
22
  logits = evo_model(input_ids)
23
- pred = torch.argmax(logits, dim=1).item()
24
 
25
- return f"Evo suggests: Option {pred + 1}" # Assumes binary classification (0 or 1)
26
 
27
- # 🤖 GPT-3.5 comparison (optional)
28
- import openai
29
- openai.api_key = "sk-..." # Replace with your OpenAI API key
30
 
31
  def get_gpt_response(query, context):
32
  try:
33
  prompt = f"Context: {context}\n\nQuestion: {query}\n\nAnswer:"
34
- response = openai.ChatCompletion.create(
35
  model="gpt-3.5-turbo",
36
- messages=[{"role": "user", "content": prompt}],
 
 
37
  temperature=0.3
38
  )
39
- return response['choices'][0]['message']['content'].strip()
40
  except Exception as e:
41
  return f"Error from GPT: {e}"
42
-
43
- #
 
1
  import torch
2
+ from transformers import AutoTokenizer
3
  from evo_model import EvoTransformerV22
4
+ from openai import OpenAI
5
+ import os
6
 
7
  # Load Evo model
8
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
14
  # Load tokenizer
15
  tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
16
 
17
+ # 🧠 Evo logic (binary classification with sigmoid)
18
  def get_evo_response(query, context):
19
  combined = query + " " + context
20
  inputs = tokenizer(combined, return_tensors="pt", truncation=True, padding="max_length", max_length=128)
 
22
 
23
  with torch.no_grad():
24
  logits = evo_model(input_ids)
25
+ pred = int(torch.sigmoid(logits).item() > 0.5)
26
 
27
+ return f"Evo suggests: Option {pred + 1}"
28
 
29
+ # 🤖 GPT-3.5 comparison using openai>=1.0.0
30
+ openai_api_key = os.environ.get("OPENAI_API_KEY", "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") # Replace with real key or set via HF secrets
31
+ client = OpenAI(api_key=openai_api_key)
32
 
33
  def get_gpt_response(query, context):
34
  try:
35
  prompt = f"Context: {context}\n\nQuestion: {query}\n\nAnswer:"
36
+ response = client.chat.completions.create(
37
  model="gpt-3.5-turbo",
38
+ messages=[
39
+ {"role": "user", "content": prompt}
40
+ ],
41
  temperature=0.3
42
  )
43
+ return response.choices[0].message.content.strip()
44
  except Exception as e:
45
  return f"Error from GPT: {e}"