Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,69 +1,94 @@
|
|
1 |
-
# app.py
|
2 |
import gradio as gr
|
3 |
import matplotlib.pyplot as plt
|
4 |
import pandas as pd
|
|
|
5 |
import io
|
6 |
from evo_transformer import EvoTransformer
|
7 |
|
8 |
-
|
9 |
-
|
10 |
-
evo.evolve(generations)
|
11 |
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
last_config["ffn_dim"]/100, last_config["dropout"]*10, int(last_config["memory"])*10]
|
24 |
-
|
25 |
-
fig, ax = plt.subplots(figsize=(6,6), subplot_kw=dict(polar=True))
|
26 |
-
angles = [n / float(len(traits)) * 2 * 3.14159 for n in range(len(traits))]
|
27 |
values += values[:1]
|
28 |
angles += angles[:1]
|
29 |
-
ax.plot(angles, values, linewidth=2)
|
30 |
-
ax.fill(angles, values, alpha=0.3)
|
31 |
-
ax.set_xticks(angles[:-1])
|
32 |
-
ax.set_xticklabels(traits)
|
33 |
-
ax.set_title("Final Architecture Traits", size=15)
|
34 |
|
35 |
-
|
36 |
-
|
37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
|
39 |
return (
|
40 |
-
|
41 |
-
|
42 |
-
|
|
|
43 |
fig,
|
44 |
-
("evo_history.csv"
|
45 |
-
("evo_history.json",
|
46 |
)
|
47 |
|
48 |
-
|
49 |
-
|
50 |
-
gr.Markdown(
|
|
|
|
|
|
|
51 |
|
52 |
with gr.Row():
|
53 |
-
generations = gr.Slider(1, 10, value=5,
|
54 |
-
|
55 |
|
56 |
with gr.Row():
|
57 |
-
|
58 |
-
|
59 |
|
60 |
-
|
61 |
-
|
|
|
|
|
|
|
|
|
62 |
|
63 |
with gr.Row():
|
64 |
-
|
65 |
-
|
66 |
|
67 |
-
|
|
|
|
|
|
|
|
|
68 |
|
69 |
-
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
import matplotlib.pyplot as plt
|
3 |
import pandas as pd
|
4 |
+
import numpy as np
|
5 |
import io
|
6 |
from evo_transformer import EvoTransformer
|
7 |
|
8 |
+
# Global instance (resettable)
|
9 |
+
evo = EvoTransformer()
|
|
|
10 |
|
11 |
+
# === Visualization Functions ===
|
12 |
+
def plot_radar(config):
|
13 |
+
labels = ["Layers", "Attention Heads", "FFN Dim", "Dropout", "Memory"]
|
14 |
+
values = [
|
15 |
+
config["layers"],
|
16 |
+
config["attention_heads"],
|
17 |
+
config["ffn_dim"],
|
18 |
+
int(config["dropout"] * 100),
|
19 |
+
int(config["memory"])
|
20 |
+
]
|
21 |
+
angles = np.linspace(0, 2 * np.pi, len(labels), endpoint=False).tolist()
|
|
|
|
|
|
|
|
|
22 |
values += values[:1]
|
23 |
angles += angles[:1]
|
|
|
|
|
|
|
|
|
|
|
24 |
|
25 |
+
fig, ax = plt.subplots(figsize=(5, 5), subplot_kw=dict(polar=True))
|
26 |
+
ax.plot(angles, values, "o-", linewidth=2)
|
27 |
+
ax.fill(angles, values, alpha=0.25)
|
28 |
+
ax.set_thetagrids(np.degrees(angles[:-1]), labels)
|
29 |
+
ax.set_title("Final Architecture (Radar Chart)")
|
30 |
+
return fig
|
31 |
+
|
32 |
+
def evolve_and_display(generations):
|
33 |
+
global evo
|
34 |
+
evo = EvoTransformer() # Reset model
|
35 |
+
evo.evolve(generations)
|
36 |
+
|
37 |
+
df = evo.get_history_df()
|
38 |
+
final_config = evo.get_final_config()
|
39 |
+
accuracy, params = evo.evaluate()
|
40 |
+
|
41 |
+
fig = plot_radar(final_config)
|
42 |
+
|
43 |
+
json_file = io.BytesIO()
|
44 |
+
json_file.write(evo.get_history_json().encode("utf-8"))
|
45 |
+
json_file.seek(0)
|
46 |
+
|
47 |
+
csv_file = io.BytesIO()
|
48 |
+
df.to_csv(csv_file, index=False)
|
49 |
+
csv_file.seek(0)
|
50 |
|
51 |
return (
|
52 |
+
accuracy,
|
53 |
+
params,
|
54 |
+
gr.Tabs.update(visible=True),
|
55 |
+
[gr.Textbox.update(value=str(row)) for _, row in df.iterrows()],
|
56 |
fig,
|
57 |
+
(csv_file, "evo_history.csv"),
|
58 |
+
(json_file, "evo_history.json"),
|
59 |
)
|
60 |
|
61 |
+
# === Gradio UI ===
|
62 |
+
with gr.Blocks(title="EvoTransformer Live Demo") as demo:
|
63 |
+
gr.Markdown(
|
64 |
+
"🚀 **EvoTransformer Live Demo**\n\n"
|
65 |
+
"This demo evolves a Transformer architecture and displays how traits change over generations."
|
66 |
+
)
|
67 |
|
68 |
with gr.Row():
|
69 |
+
generations = gr.Slider(1, 10, value=5, label="Generations")
|
70 |
+
evolve_btn = gr.Button("Evolve Now 🚀")
|
71 |
|
72 |
with gr.Row():
|
73 |
+
accuracy_out = gr.Number(label="Estimated Accuracy", value=0)
|
74 |
+
params_out = gr.Number(label="Estimated Params (M)", value=0)
|
75 |
|
76 |
+
tabbox = gr.Tabs(visible=False)
|
77 |
+
with tabbox:
|
78 |
+
with gr.Tab(label="Evolution History"):
|
79 |
+
history_display = [gr.Textbox(label=str(i+1), interactive=False) for i in range(10)]
|
80 |
+
with gr.Tab(label="Radar View"):
|
81 |
+
radar_plot = gr.Plot()
|
82 |
|
83 |
with gr.Row():
|
84 |
+
csv_btn = gr.File(label="Download CSV")
|
85 |
+
json_btn = gr.File(label="Download JSON")
|
86 |
|
87 |
+
evolve_btn.click(
|
88 |
+
evolve_and_display,
|
89 |
+
inputs=[generations],
|
90 |
+
outputs=[accuracy_out, params_out, tabbox, history_display, radar_plot, csv_btn, json_btn]
|
91 |
+
)
|
92 |
|
93 |
+
if __name__ == "__main__":
|
94 |
+
demo.launch()
|