HemanM commited on
Commit
63d9bd3
Β·
verified Β·
1 Parent(s): 312dbba

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -14
app.py CHANGED
@@ -13,14 +13,17 @@ from PIL import Image
13
  import openai
14
  import time
15
 
16
- # βœ… Secure API key
17
  openai.api_key = os.getenv("OPENAI_API_KEY")
18
 
19
- # βœ… Device config
20
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
 
22
- # βœ… Load PIQA dataset
23
- dataset = load_dataset("piqa")
 
 
 
24
  tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
25
 
26
  def tokenize_choices(example):
@@ -35,7 +38,7 @@ def tokenize_choices(example):
35
  dataset = dataset.map(tokenize_choices)
36
  val_dataset = dataset["validation"].select(range(200)).with_format("torch")
37
 
38
- # βœ… EvoTransformer architecture
39
  class EvoTransformer(nn.Module):
40
  def __init__(self):
41
  super().__init__()
@@ -53,7 +56,7 @@ class EvoTransformer(nn.Module):
53
  x = self.encoder(x)
54
  return self.classifier(x[:, 0, :]).squeeze(-1)
55
 
56
- # βœ… GPT-3.5 call
57
  def gpt35_answer(prompt):
58
  try:
59
  response = openai.ChatCompletion.create(
@@ -66,7 +69,7 @@ def gpt35_answer(prompt):
66
  except Exception as e:
67
  return f"[Error: {e}]"
68
 
69
- # βœ… Evo vs GPT logic
70
  def train_and_demo(few_shot_size):
71
  start_time = time.time()
72
  model = EvoTransformer().to(device)
@@ -133,7 +136,7 @@ def train_and_demo(few_shot_size):
133
  buf.seek(0)
134
  img = Image.open(buf)
135
 
136
- # βœ… Sample output vs GPT-3.5
137
  output = ""
138
  for i in range(2):
139
  ex = dataset["validation"][i]
@@ -167,15 +170,15 @@ EvoTransformer v2.1 Configuration:
167
 
168
  return img, f"Best Accuracy: {best_val:.4f}", output.strip() + "\n\n" + architecture_info.strip()
169
 
170
- # βœ… Gradio app
171
  gr.Interface(
172
  fn=train_and_demo,
173
- inputs=gr.Slider(10, 500, step=10, value=50, label="Training Samples"),
174
  outputs=[
175
- gr.Image(label="Validation Accuracy Plot"),
176
  gr.Textbox(label="Best Accuracy"),
177
- gr.Textbox(label="Evo vs GPT-3.5: Predictions & Architecture")
178
  ],
179
- title="🧬 EvoTransformer v2.1 vs GPT-3.5",
180
- description="Train EvoTransformer on PIQA and benchmark it live against GPT-3.5."
181
  ).launch()
 
13
  import openai
14
  import time
15
 
16
+ # βœ… Secure OpenAI API key
17
  openai.api_key = os.getenv("OPENAI_API_KEY")
18
 
19
+ # βœ… Use CPU or GPU
20
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
 
22
+ # βœ… Load PIQA from Hugging Face JSON (safe for Spaces)
23
+ dataset = load_dataset("json", data_files={
24
+ "train": "https://huggingface.co/datasets/AI-Sweden/piqa-downsampled/resolve/main/train.json",
25
+ "validation": "https://huggingface.co/datasets/AI-Sweden/piqa-downsampled/resolve/main/validation.json"
26
+ })
27
  tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
28
 
29
  def tokenize_choices(example):
 
38
  dataset = dataset.map(tokenize_choices)
39
  val_dataset = dataset["validation"].select(range(200)).with_format("torch")
40
 
41
+ # βœ… EvoTransformer definition
42
  class EvoTransformer(nn.Module):
43
  def __init__(self):
44
  super().__init__()
 
56
  x = self.encoder(x)
57
  return self.classifier(x[:, 0, :]).squeeze(-1)
58
 
59
+ # βœ… GPT-3.5 response
60
  def gpt35_answer(prompt):
61
  try:
62
  response = openai.ChatCompletion.create(
 
69
  except Exception as e:
70
  return f"[Error: {e}]"
71
 
72
+ # βœ… Training + Evaluation function
73
  def train_and_demo(few_shot_size):
74
  start_time = time.time()
75
  model = EvoTransformer().to(device)
 
136
  buf.seek(0)
137
  img = Image.open(buf)
138
 
139
+ # βœ… Show comparison examples
140
  output = ""
141
  for i in range(2):
142
  ex = dataset["validation"][i]
 
170
 
171
  return img, f"Best Accuracy: {best_val:.4f}", output.strip() + "\n\n" + architecture_info.strip()
172
 
173
+ # βœ… Gradio Interface
174
  gr.Interface(
175
  fn=train_and_demo,
176
+ inputs=gr.Slider(10, 500, step=10, value=50, label="Number of Training Examples"),
177
  outputs=[
178
+ gr.Image(label="Accuracy Plot"),
179
  gr.Textbox(label="Best Accuracy"),
180
+ gr.Textbox(label="Evo vs GPT-3.5 Output")
181
  ],
182
+ title="🧬 EvoTransformer v2.1 Benchmark",
183
+ description="Train EvoTransformer on PIQA and compare its predictions to GPT-3.5."
184
  ).launch()