Update app.py
app.py CHANGED
@@ -13,14 +13,17 @@ from PIL import Image
 import openai
 import time
 
-# ✅ Secure API key
+# ✅ Secure OpenAI API key
 openai.api_key = os.getenv("OPENAI_API_KEY")
 
-# ✅
+# ✅ Use CPU or GPU
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-# ✅ Load PIQA
-dataset = load_dataset("piqa")
+# ✅ Load PIQA from Hugging Face JSON (safe for Spaces)
+dataset = load_dataset("json", data_files={
+    "train": "https://huggingface.co/datasets/AI-Sweden/piqa-downsampled/resolve/main/train.json",
+    "validation": "https://huggingface.co/datasets/AI-Sweden/piqa-downsampled/resolve/main/validation.json"
+})
 tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
 
 def tokenize_choices(example):
@@ -35,7 +38,7 @@ def tokenize_choices(example):
 dataset = dataset.map(tokenize_choices)
 val_dataset = dataset["validation"].select(range(200)).with_format("torch")
 
-# ✅ EvoTransformer
+# ✅ EvoTransformer definition
 class EvoTransformer(nn.Module):
     def __init__(self):
         super().__init__()
@@ -53,7 +56,7 @@ class EvoTransformer(nn.Module):
         x = self.encoder(x)
         return self.classifier(x[:, 0, :]).squeeze(-1)
 
-# ✅ GPT-3.5
+# ✅ GPT-3.5 response
 def gpt35_answer(prompt):
     try:
         response = openai.ChatCompletion.create(
@@ -66,7 +69,7 @@ def gpt35_answer(prompt):
     except Exception as e:
         return f"[Error: {e}]"
 
-# ✅
+# ✅ Training + Evaluation function
 def train_and_demo(few_shot_size):
     start_time = time.time()
     model = EvoTransformer().to(device)
@@ -133,7 +136,7 @@ def train_and_demo(few_shot_size):
     buf.seek(0)
     img = Image.open(buf)
 
-    # ✅
+    # ✅ Show comparison examples
    output = ""
     for i in range(2):
         ex = dataset["validation"][i]
@@ -167,15 +170,15 @@ EvoTransformer v2.1 Configuration:
 
     return img, f"Best Accuracy: {best_val:.4f}", output.strip() + "\n\n" + architecture_info.strip()
 
-# ✅ Gradio
+# ✅ Gradio Interface
 gr.Interface(
     fn=train_and_demo,
-    inputs=gr.Slider(10, 500, step=10, value=50, label="Training …
+    inputs=gr.Slider(10, 500, step=10, value=50, label="Number of Training Examples"),
     outputs=[
-        gr.Image(label="…
+        gr.Image(label="Accuracy Plot"),
         gr.Textbox(label="Best Accuracy"),
-        gr.Textbox(label="Evo vs GPT-3.5 …
+        gr.Textbox(label="Evo vs GPT-3.5 Output")
     ],
-    title="🧬 EvoTransformer v2.1 …
-    description="Train EvoTransformer on PIQA and …
+    title="🧬 EvoTransformer v2.1 Benchmark",
+    description="Train EvoTransformer on PIQA and compare its predictions to GPT-3.5."
 ).launch()
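
The substantive change in this commit is the PIQA loading path: the old direct hub load (apparently a plain load_dataset("piqa") call) is replaced by reading pre-exported JSON splits, which the new comment notes is safe for Spaces. The remaining edits only expand comments and widget labels. Below is a minimal standalone sketch of the new loading path; it assumes the datasets library is installed and that the AI-Sweden/piqa-downsampled JSON files referenced in the diff are still reachable at those URLs.

# Standalone check of the new PIQA loading path (sketch, not part of app.py).
# Assumes: `datasets` is installed and the JSON files below are still hosted at these URLs.
from datasets import load_dataset

dataset = load_dataset("json", data_files={
    "train": "https://huggingface.co/datasets/AI-Sweden/piqa-downsampled/resolve/main/train.json",
    "validation": "https://huggingface.co/datasets/AI-Sweden/piqa-downsampled/resolve/main/validation.json",
})

print(dataset)                   # expect a DatasetDict with "train" and "validation" splits
print(dataset["validation"][0])  # one example; PIQA-style fields (goal, sol1, sol2, label) are assumed here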