HemanM commited on
Commit
312dbba
Β·
verified Β·
1 Parent(s): 08db431

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +178 -0
app.py CHANGED
@@ -1,3 +1,181 @@
1
  import os
 
 
 
 
 
 
 
 
 
 
 
2
  import openai
 
 
 
3
  openai.api_key = os.getenv("OPENAI_API_KEY")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.optim as optim
5
+ from torch.utils.data import DataLoader
6
+ from datasets import load_dataset
7
+ from transformers import AutoTokenizer, get_scheduler
8
+ import gradio as gr
9
+ import matplotlib.pyplot as plt
10
+ import numpy as np
11
+ import io
12
+ from PIL import Image
13
  import openai
14
+ import time
15
+
16
+ # βœ… Secure API key
17
  openai.api_key = os.getenv("OPENAI_API_KEY")
18
+
19
# Device config: prefer GPU when CUDA is available; everything is moved there explicitly.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load PIQA (physical commonsense QA: one goal, two candidate solutions, binary label).
dataset = load_dataset("piqa")
# BERT WordPiece tokenizer; its 30522-token vocab matches the model's embedding table below.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
25
+
26
def tokenize_choices(example):
    """Tokenize both candidate solutions for a single PIQA example.

    Each choice is encoded as "<goal> <solution>", truncated/padded to a
    fixed 128 tokens, and returned next to the gold label.
    """
    def encode(solution):
        # Join the goal with one candidate and encode to a fixed-length id sequence.
        encoded = tokenizer(
            example["goal"] + " " + solution,
            truncation=True,
            padding="max_length",
            max_length=128,
        )
        return encoded["input_ids"]

    return {
        "input_ids_0": encode(example["sol1"]),
        "input_ids_1": encode(example["sol2"]),
        "label": example["label"],
    }
34
+
35
# Pre-tokenize every split once; mapped columns are added alongside the originals.
dataset = dataset.map(tokenize_choices)
# Fixed 200-example validation slice, formatted as torch tensors for the DataLoader.
val_dataset = dataset["validation"].select(range(200)).with_format("torch")
37
+
38
+ # βœ… EvoTransformer architecture
39
+ class EvoTransformer(nn.Module):
40
+ def __init__(self):
41
+ super().__init__()
42
+ self.embedding = nn.Embedding(30522, 384)
43
+ encoder_layer = nn.TransformerEncoderLayer(d_model=384, nhead=6, dim_feedforward=1024, batch_first=True)
44
+ self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
45
+ self.classifier = nn.Sequential(
46
+ nn.Linear(384, 128),
47
+ nn.ReLU(),
48
+ nn.Linear(128, 1)
49
+ )
50
+
51
+ def forward(self, input_ids):
52
+ x = self.embedding(input_ids)
53
+ x = self.encoder(x)
54
+ return self.classifier(x[:, 0, :]).squeeze(-1)
55
+
56
# βœ… GPT-3.5 call
def gpt35_answer(prompt):
    """Send a single-user-message prompt to gpt-3.5-turbo.

    Returns the stripped completion text, or an "[Error: ...]" string on
    any failure so callers never have to handle exceptions themselves.
    """
    request = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": 20,
        "temperature": 0,
    }
    try:
        response = openai.ChatCompletion.create(**request)
        return response['choices'][0]['message']['content'].strip()
    except Exception as e:  # deliberate best-effort: surface the error as text
        return f"[Error: {e}]"
68
+
69
# βœ… Evo vs GPT logic
def train_and_demo(few_shot_size):
    """Train a fresh EvoTransformer on `few_shot_size` PIQA examples and
    benchmark it live against GPT-3.5 on two validation questions.

    Args:
        few_shot_size: number of training examples taken from the PIQA
            train split (the Gradio slider value).

    Returns:
        Tuple of (PIL.Image accuracy plot, best-accuracy string, text
        report with per-question Evo vs GPT-3.5 predictions plus an
        architecture summary).
    """
    start_time = time.time()
    model = EvoTransformer().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=5e-5)

    train_set = dataset["train"].select(range(few_shot_size)).with_format("torch")
    train_loader = DataLoader(train_set, batch_size=8, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=32)

    # Linear decay over at most 3 epochs' worth of optimizer steps.
    scheduler = get_scheduler("linear", optimizer=optimizer,
                              num_warmup_steps=0, num_training_steps=3 * len(train_loader))

    best_val = 0
    accs = []
    patience = 2      # epochs without improvement before early stopping
    early_stop = 0

    for epoch in range(3):
        model.train()
        for batch in train_loader:
            optimizer.zero_grad()
            x0 = batch["input_ids_0"].to(device)
            x1 = batch["input_ids_1"].to(device)
            labels = batch["label"].to(device)
            # Score each choice independently, then treat the pair as 2-way logits.
            logits = torch.stack([model(x0), model(x1)], dim=1)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
            scheduler.step()

        # Validation accuracy after each epoch.
        model.eval()
        correct = 0
        with torch.no_grad():
            for batch in val_loader:
                x0 = batch["input_ids_0"].to(device)
                x1 = batch["input_ids_1"].to(device)
                labels = batch["label"].to(device)
                logits = torch.stack([model(x0), model(x1)], dim=1)
                preds = torch.argmax(logits, dim=1)
                correct += (preds == labels).sum().item()
        acc = correct / len(val_dataset)
        accs.append(acc)
        if acc > best_val:
            best_val = acc
            early_stop = 0
        else:
            early_stop += 1
        if early_stop >= patience:
            break

    # βœ… Accuracy plot, rendered to a PIL image for Gradio.
    fig, ax = plt.subplots()
    ax.plot(accs, marker='o')
    ax.set_title(f"Validation Accuracy ({few_shot_size} examples)")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Accuracy")
    buf = io.BytesIO()
    fig.savefig(buf, format='png')
    plt.close(fig)  # fix: release the figure so repeated Gradio calls don't leak memory
    buf.seek(0)
    img = Image.open(buf)

    # βœ… Head-to-head sample predictions vs GPT-3.5.
    output = ""
    with torch.no_grad():  # fix: inference only — no autograd graph needed here
        for i in range(2):
            ex = dataset["validation"][i]
            goal = ex["goal"]
            sol1 = ex["sol1"]
            sol2 = ex["sol2"]

            x0 = torch.tensor([ex["input_ids_0"]]).to(device)
            x1 = torch.tensor([ex["input_ids_1"]]).to(device)
            pred_evo = 0 if model(x0).item() > model(x1).item() else 1
            correct_evo = "βœ…" if pred_evo == ex["label"] else "❌"

            gpt_prompt = f"Q: {goal}\nA) {sol1}\nB) {sol2}\nWhich is more appropriate? Answer with A or B only."
            gpt_out = gpt35_answer(gpt_prompt)
            # fix: guard against an empty reply — gpt_out[0] raised IndexError on "".
            pred_gpt = gpt_out[:1].upper() or "?"
            correct_gpt = "βœ…" if (pred_gpt == 'A' and ex["label"] == 0) or (pred_gpt == 'B' and ex["label"] == 1) else "❌"

            output += f"Q: {goal}\nA) {sol1}\nB) {sol2}\n\nEvoTransformer: {'A' if pred_evo==0 else 'B'} {correct_evo}\nGPT-3.5: {pred_gpt} {correct_gpt}\n\n"

    architecture_info = f"""
EvoTransformer v2.1 Configuration:
- Embedding Dim: 384
- Transformer Layers: 6
- Attention Heads: 6
- Feedforward Size: 1024
- Parameters: ~13M
- Training Time: {time.time() - start_time:.2f}s
"""

    return img, f"Best Accuracy: {best_val:.4f}", output.strip() + "\n\n" + architecture_info.strip()
169
+
170
# βœ… Gradio app: one slider in, accuracy plot plus two text panels out.
demo = gr.Interface(
    fn=train_and_demo,
    inputs=gr.Slider(10, 500, step=10, value=50, label="Training Samples"),
    outputs=[
        gr.Image(label="Validation Accuracy Plot"),
        gr.Textbox(label="Best Accuracy"),
        gr.Textbox(label="Evo vs GPT-3.5: Predictions & Architecture"),
    ],
    title="🧬 EvoTransformer v2.1 vs GPT-3.5",
    description="Train EvoTransformer on PIQA and benchmark it live against GPT-3.5.",
)

demo.launch()