HemanM commited on
Commit
68934cc
·
verified ·
1 Parent(s): 5f10669

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +89 -54
app.py CHANGED
@@ -1,79 +1,114 @@
1
  import gradio as gr
2
- from evo_transformer import EvoTransformer
3
- from plots import plot_radar_chart
4
- from diagrams import get_transformer_diagram
5
  import pandas as pd
6
  import json
7
- import tempfile
 
 
 
8
 
9
- et = EvoTransformer()
 
10
 
11
- def run_evolution(generations):
12
- et.reset()
13
- et.evolve(generations)
 
 
 
 
14
 
15
- final_eval = et.evaluate()
16
- csv_path = tempfile.NamedTemporaryFile(delete=False, suffix=".csv").name
17
- json_path = tempfile.NamedTemporaryFile(delete=False, suffix=".json").name
 
 
 
 
18
 
19
- df = pd.DataFrame(et.get_history())
20
- df.to_csv(csv_path, index=False)
21
- with open(json_path, "w") as f:
22
- json.dump(et.get_history(), f)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
- radar_plot = plot_radar_chart(et.config)
25
- diagram_path = get_transformer_diagram(et.config)
 
 
26
 
27
- history_outputs = [gr.Textbox(label=f"Gen {i+1} Config", value=json.dumps(cfg, indent=2), lines=4) for i, cfg in enumerate(et.get_history())]
 
 
28
 
29
  return (
30
- f"{final_eval['accuracy']*100:.2f}%",
31
- f"{final_eval['params']:.2f}M params",
32
- json.dumps(et.config, indent=2),
33
- radar_plot,
34
- diagram_path,
35
- history_outputs,
36
- csv_path,
37
  json_path
38
  )
39
 
40
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
41
- gr.Markdown("## 🧬 EvoTransformer Evolving Transformer Architectures")
42
- gr.Markdown("Simulate trait mutation and adaptive architecture generation.")
43
 
44
- with gr.Row():
45
- generations_slider = gr.Slider(1, 10, value=3, label="Number of Generations", step=1)
46
- evolve_btn = gr.Button("🧬 Evolve Architecture", variant="primary")
 
 
47
 
48
  with gr.Row():
49
- accuracy_output = gr.Textbox(label="Simulated Accuracy")
50
- param_output = gr.Textbox(label="Estimated Parameters")
51
- current_config = gr.Textbox(label="Current Config Summary", lines=5)
52
 
53
- with gr.Column():
54
- gr.Markdown("## 🧬 Evolution History")
55
- radar_output = gr.Image(label="Final Generation Trait Radar", height=400)
56
- diagram_output = gr.Image(label="Illustrative Transformer Structure", height=300)
57
- history_group = gr.Group()
58
 
59
- with gr.Row():
60
- csv_download = gr.File(label="Download CSV History")
61
- json_download = gr.File(label="Download JSON History")
 
 
 
 
 
 
 
 
62
 
63
  evolve_btn.click(
64
- fn=run_evolution,
65
- inputs=[generations_slider],
66
  outputs=[
67
  accuracy_output,
68
- param_output,
69
- current_config,
70
- radar_output,
71
- diagram_output,
72
- history_group,
73
- csv_download,
74
- json_download,
75
- ],
76
  )
77
 
78
- if __name__ == "__main__":
79
- demo.launch()
 
1
  import gradio as gr
 
 
 
2
  import pandas as pd
3
  import json
4
+ import os
5
+ from evo_transformer import EvoTransformer
6
+ from plot import plot_radar_chart
7
+ from diagram import plot_architecture_diagram
8
 
9
+ # === Initialize ===
10
+ evo = EvoTransformer()
11
 
12
+ # === Gradio Output Placeholders ===
13
+ accuracy_output = gr.Text(label="Simulated Accuracy")
14
+ params_output = gr.Text(label="Estimated Parameters")
15
+ current_config_output = gr.Text(label="Current Config Summary")
16
+ radar_image = gr.Image(label="Final Generation Trait Radar")
17
+ diagram_image = gr.Image(label="Illustrative Transformer Structure")
18
+ history_outputs = []
19
 
20
+ # === Core Evolution Function ===
21
+ def evolve_model(generations):
22
+ evo.__init__() # Reset EvoTransformer
23
+ evo.evolve(generations)
24
+ score = evo.evaluate()
25
+ radar_path = "radar_chart.png"
26
+ diagram_path = "architecture_diagram.png"
27
 
28
+ # Plot radar chart
29
+ try:
30
+ plot_radar_chart(evo.config, radar_path)
31
+ except Exception as e:
32
+ print("Radar chart error:", e)
33
+ radar_path = None
34
+
35
+ # Plot architecture diagram
36
+ diagram_bytes = None
37
+ try:
38
+ plot_architecture_diagram(evo.config)
39
+ if os.path.exists(diagram_path):
40
+ with open(diagram_path, "rb") as f:
41
+ diagram_bytes = f.read()
42
+ except Exception as e:
43
+ print("Diagram plot error:", e)
44
+
45
+ # History cards
46
+ history = evo.get_history()
47
+ history_cards = [f"Gen {i+1} Config: {h}" for i, h in enumerate(history)]
48
 
49
+ # History downloadables
50
+ df = pd.DataFrame(history)
51
+ df_csv_path = "evo_history.csv"
52
+ df.to_csv(df_csv_path, index=False)
53
 
54
+ json_path = "evo_history.json"
55
+ with open(json_path, "w") as f:
56
+ json.dump(history, f, indent=2)
57
 
58
  return (
59
+ f"{score['accuracy']*100:.2f}%",
60
+ f"{score['params']:.2f}M params",
61
+ str(evo.config),
62
+ radar_path,
63
+ diagram_bytes,
64
+ *history_cards,
65
+ df_csv_path,
66
  json_path
67
  )
68
 
69
+ # === Gradio Interface ===
70
+ generations_input = gr.Slider(1, 10, value=3, step=1, label="Number of Generations")
71
+ evolve_btn = gr.Button("\U0001F9EC Evolve Architecture")
72
 
73
+ with gr.Blocks(title="EvoTransformer") as demo:
74
+ gr.Markdown("""
75
+ # 🧬 EvoTransformer Evolving Transformer Architectures
76
+ Simulate trait mutation and adaptive architecture generation.
77
+ """)
78
 
79
  with gr.Row():
80
+ generations_input.render()
81
+ evolve_btn.render()
 
82
 
83
+ accuracy_output.render()
84
+ params_output.render()
85
+ current_config_output.render()
 
 
86
 
87
+ gr.Markdown("## 🧬 Evolution History")
88
+ radar_image.render()
89
+ diagram_image.render()
90
+
91
+ with gr.Accordion("Downloadable Files", open=True):
92
+ csv_file = gr.File(label="Download CSV History")
93
+ json_file = gr.File(label="Download JSON History")
94
+
95
+ for _ in range(10):
96
+ card = gr.Textbox(label="", interactive=False)
97
+ history_outputs.append(card)
98
 
99
  evolve_btn.click(
100
+ fn=evolve_model,
101
+ inputs=[generations_input],
102
  outputs=[
103
  accuracy_output,
104
+ params_output,
105
+ current_config_output,
106
+ radar_image,
107
+ diagram_image,
108
+ *history_outputs,
109
+ csv_file,
110
+ json_file
111
+ ]
112
  )
113
 
114
+ demo.launch()