HemanM commited on
Commit
10054c3
·
verified ·
1 Parent(s): eeda69b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +97 -66
app.py CHANGED
@@ -1,94 +1,125 @@
 
 
1
  import gradio as gr
2
- import matplotlib.pyplot as plt
3
- import pandas as pd
4
- import numpy as np
5
- import io
6
  from evo_transformer import EvoTransformer
 
 
 
 
 
7
 
8
- # Global instance (resettable)
9
- evo = EvoTransformer()
10
-
11
- # === Visualization Functions ===
12
- def plot_radar(config):
13
- labels = ["Layers", "Attention Heads", "FFN Dim", "Dropout", "Memory"]
14
- values = [
15
- config["layers"],
16
- config["attention_heads"],
17
- config["ffn_dim"],
18
- int(config["dropout"] * 100),
19
- int(config["memory"])
20
- ]
21
- angles = np.linspace(0, 2 * np.pi, len(labels), endpoint=False).tolist()
22
- values += values[:1]
23
- angles += angles[:1]
24
 
25
- fig, ax = plt.subplots(figsize=(5, 5), subplot_kw=dict(polar=True))
26
- ax.plot(angles, values, "o-", linewidth=2)
27
- ax.fill(angles, values, alpha=0.25)
28
- ax.set_thetagrids(np.degrees(angles[:-1]), labels)
29
- ax.set_title("Final Architecture (Radar Chart)")
30
- return fig
31
 
32
- def evolve_and_display(generations):
33
- global evo
34
- evo = EvoTransformer() # Reset model
35
- evo.evolve(generations)
36
 
37
- df = evo.get_history_df()
38
- final_config = evo.get_final_config()
39
- accuracy, params = evo.evaluate()
40
 
41
- fig = plot_radar(final_config)
 
42
 
43
- json_file = io.BytesIO()
44
- json_file.write(evo.get_history_json().encode("utf-8"))
45
- json_file.seek(0)
46
 
47
- csv_file = io.BytesIO()
48
- df.to_csv(csv_file, index=False)
49
- csv_file.seek(0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
  return (
52
- accuracy,
53
- params,
54
- gr.Tabs.update(visible=True),
55
- [gr.Textbox.update(value=str(row)) for _, row in df.iterrows()],
56
- fig,
57
- (csv_file, "evo_history.csv"),
58
- (json_file, "evo_history.json"),
59
  )
60
 
61
- # === Gradio UI ===
62
- with gr.Blocks(title="EvoTransformer Live Demo") as demo:
63
- gr.Markdown(
64
- "🚀 **EvoTransformer Live Demo**\n\n"
65
- "This demo evolves a Transformer architecture and displays how traits change over generations."
66
- )
 
67
 
68
  with gr.Row():
69
- generations = gr.Slider(1, 10, value=5, label="Generations")
70
- evolve_btn = gr.Button("Evolve Now 🚀")
71
 
72
  with gr.Row():
73
- accuracy_out = gr.Number(label="Estimated Accuracy", value=0)
74
- params_out = gr.Number(label="Estimated Params (M)", value=0)
75
 
76
- tabbox = gr.Tabs(visible=False)
77
- with tabbox:
78
- with gr.Tab(label="Evolution History"):
79
- history_display = [gr.Textbox(label=str(i+1), interactive=False) for i in range(10)]
80
- with gr.Tab(label="Radar View"):
81
- radar_plot = gr.Plot()
82
 
83
  with gr.Row():
84
- csv_btn = gr.File(label="Download CSV")
85
- json_btn = gr.File(label="Download JSON")
86
 
87
  evolve_btn.click(
88
  evolve_and_display,
89
  inputs=[generations],
90
- outputs=[accuracy_out, params_out, tabbox, history_display, radar_plot, csv_btn, json_btn]
91
  )
92
 
 
93
  if __name__ == "__main__":
94
  demo.launch()
 
1
+ # app.py
2
+
3
  import gradio as gr
 
 
 
 
4
  from evo_transformer import EvoTransformer
5
+ import pandas as pd
6
+ import matplotlib.pyplot as plt
7
+ import tempfile
8
+ import os
9
+ import json
10
 
11
+ # === Initialize Model ===
12
+ model = EvoTransformer()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
+ # === Create Interface Components ===
15
+ generations = gr.Slider(1, 10, value=3, step=1, label="Number of Generations")
16
+
17
+ evolve_btn = gr.Button("🧬 Evolve Architecture")
 
 
18
 
19
+ accuracy_out = gr.Textbox(label="Simulated Accuracy")
20
+ params_out = gr.Textbox(label="Estimated Parameters")
 
 
21
 
22
+ tabbox = gr.Textbox(label="Current Config Summary")
 
 
23
 
24
+ # Dynamic display of evolution history
25
+ history_display = [gr.Textbox(label=f"Gen {i+1} Config") for i in range(10)]
26
 
27
+ # Download buttons
28
+ csv_btn = gr.File(label="Download CSV History")
29
+ json_btn = gr.File(label="Download JSON History")
30
 
31
+ # === Helper: Create Evolution Radar Plot ===
32
+ def plot_radar(history):
33
+ import numpy as np
34
+
35
+ traits = ["layers", "attention_heads", "ffn_dim", "dropout", "memory"]
36
+ N = len(traits)
37
+ values = [history[-1].get(t, 0) for t in traits]
38
+ values = [int(v) if isinstance(v, bool) else v for v in values]
39
+
40
+ angles = np.linspace(0, 2 * np.pi, N, endpoint=False).tolist()
41
+ values += values[:1]
42
+ angles += angles[:1]
43
+
44
+ fig, ax = plt.subplots(figsize=(5, 5), subplot_kw=dict(polar=True))
45
+ ax.plot(angles, values, "o-", linewidth=2)
46
+ ax.fill(angles, values, alpha=0.25)
47
+ ax.set_thetagrids([a * 180 / np.pi for a in angles[:-1]], traits)
48
+ ax.set_title("Last Gen Trait Radar", fontsize=14)
49
+ plt.tight_layout()
50
+
51
+ tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
52
+ plt.savefig(tmp.name)
53
+ plt.close()
54
+ return tmp.name
55
+
56
+ radar_plot = gr.Image(label="Final Generation Trait Radar")
57
+
58
+ # === Main Evolution Logic ===
59
+ def evolve_and_display(gens):
60
+ model.evolve(generations=gens)
61
+ history = model.get_history()
62
+ eval_result = model.evaluate()
63
+
64
+ # Format summary
65
+ summary = "\n".join([f"{k}: {v}" for k, v in history[-1].items()])
66
+
67
+ # Fill up to 10 generations of history for display
68
+ history_txt = [json.dumps(h, indent=2) for h in history]
69
+ while len(history_txt) < 10:
70
+ history_txt.append("")
71
+
72
+ # Generate CSV + JSON
73
+ df = pd.DataFrame(history)
74
+ csv_path = tempfile.NamedTemporaryFile(suffix=".csv", delete=False).name
75
+ json_path = tempfile.NamedTemporaryFile(suffix=".json", delete=False).name
76
+ df.to_csv(csv_path, index=False)
77
+ with open(json_path, "w") as jf:
78
+ json.dump(history, jf, indent=2)
79
+
80
+ # Plot radar
81
+ radar_img = plot_radar(history)
82
 
83
  return (
84
+ f"{eval_result['accuracy']*100:.2f}%",
85
+ f"{eval_result['params']:.2f}M params",
86
+ summary,
87
+ *history_txt,
88
+ radar_img,
89
+ csv_path,
90
+ json_path,
91
  )
92
 
93
+ # === Interface Layout ===
94
+ with gr.Blocks(title="EvoTransformer Demo") as demo:
95
+ gr.Markdown("# 🧬 EvoTransformer – Evolving Transformer Architectures")
96
+ gr.Markdown("Simulate trait mutation and adaptive architecture generation.")
97
+ with gr.Row():
98
+ generations.render()
99
+ evolve_btn.render()
100
 
101
  with gr.Row():
102
+ accuracy_out.render()
103
+ params_out.render()
104
 
105
  with gr.Row():
106
+ radar_plot.render()
107
+ tabbox.render()
108
 
109
+ with gr.Row():
110
+ for box in history_display:
111
+ box.render()
 
 
 
112
 
113
  with gr.Row():
114
+ csv_btn.render()
115
+ json_btn.render()
116
 
117
  evolve_btn.click(
118
  evolve_and_display,
119
  inputs=[generations],
120
+ outputs=[accuracy_out, params_out, tabbox, *history_display, radar_plot, csv_btn, json_btn],
121
  )
122
 
123
+ # === Launch App ===
124
  if __name__ == "__main__":
125
  demo.launch()