HemanM commited on
Commit
0be77ef
·
verified ·
1 Parent(s): 69a8269

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -82
app.py CHANGED
@@ -1,113 +1,86 @@
 
1
  import gradio as gr
2
- import pandas as pd
3
  import json
4
- import os
5
- from plot import plot_radar_chart
6
- from diagram import show_transformer_diagram
7
  from evo_transformer import EvoTransformer
 
8
 
9
- # === Initialize ===
10
  evo = EvoTransformer()
11
 
12
- # === Gradio Output Placeholders ===
13
- accuracy_output = gr.Text(label="Simulated Accuracy")
14
- params_output = gr.Text(label="Estimated Parameters")
15
- current_config_output = gr.Text(label="Current Config Summary")
16
- radar_image = gr.Image(label="Final Generation Trait Radar")
17
- diagram_image = gr.Image(label="Illustrative Transformer Structure")
18
- history_outputs = []
19
-
20
- # === Core Evolution Function ===
21
- def evolve_model(generations):
22
- evo.__init__() # Reset EvoTransformer
23
- evo.evolve(generations)
24
- score = evo.evaluate()
25
- radar_path = "radar_chart.png"
26
- diagram_path = "architecture_diagram.png"
27
-
28
- # Plot radar chart
29
  try:
30
- plot_radar_chart(evo.config, radar_path)
31
  except Exception as e:
32
  print("Radar chart error:", e)
33
- radar_path = None
34
 
35
- # Plot architecture diagram
36
- diagram_bytes = None
37
- try:
38
- plot_architecture_diagram(evo.config)
39
- if os.path.exists(diagram_path):
40
- with open(diagram_path, "rb") as f:
41
- diagram_bytes = f.read()
42
- except Exception as e:
43
- print("Diagram plot error:", e)
44
 
45
- # History cards
46
- history = evo.get_history()
47
- history_cards = [f"Gen {i+1} Config: {h}" for i, h in enumerate(history)]
 
 
 
48
 
49
- # History downloadables
50
- df = pd.DataFrame(history)
51
- df_csv_path = "evo_history.csv"
52
- df.to_csv(df_csv_path, index=False)
53
-
54
- json_path = "evo_history.json"
55
- with open(json_path, "w") as f:
56
- json.dump(history, f, indent=2)
57
 
58
  return (
59
- f"{score['accuracy']*100:.2f}%",
60
- f"{score['params']:.2f}M params",
61
- str(evo.config),
62
- radar_path,
63
- diagram_bytes,
64
  *history_cards,
65
- df_csv_path,
66
- json_path
67
  )
68
 
69
- # === Gradio Interface ===
70
- generations_input = gr.Slider(1, 10, value=3, step=1, label="Number of Generations")
71
- evolve_btn = gr.Button("\U0001F9EC Evolve Architecture")
72
-
73
- with gr.Blocks(title="EvoTransformer") as demo:
74
- gr.Markdown("""
75
- # 🧬 EvoTransformer – Evolving Transformer Architectures
76
- Simulate trait mutation and adaptive architecture generation.
77
- """)
78
 
79
  with gr.Row():
80
- generations_input.render()
81
- evolve_btn.render()
82
 
83
- accuracy_output.render()
84
- params_output.render()
85
- current_config_output.render()
 
86
 
87
  gr.Markdown("## 🧬 Evolution History")
88
- radar_image.render()
89
- diagram_image.render()
90
 
91
- with gr.Accordion("Downloadable Files", open=True):
92
- csv_file = gr.File(label="Download CSV History")
93
- json_file = gr.File(label="Download JSON History")
 
94
 
95
- for _ in range(10):
96
- card = gr.Textbox(label="", interactive=False)
97
- history_outputs.append(card)
98
 
99
  evolve_btn.click(
100
  fn=evolve_model,
101
- inputs=[generations_input],
102
  outputs=[
103
- accuracy_output,
104
- params_output,
105
- current_config_output,
106
- radar_image,
107
- diagram_image,
108
- *history_outputs,
109
- csv_file,
110
- json_file
111
  ]
112
  )
113
 
 
1
+ # app.py
2
  import gradio as gr
 
3
  import json
4
+ import csv
5
+ import io
 
6
  from evo_transformer import EvoTransformer
7
+ from plot import plot_radar_chart
8
 
9
+ # Global instance
10
  evo = EvoTransformer()
11
 
12
+ def evolve_model(num_generations):
13
+ evo.reset() # reset before evolution
14
+ evo.evolve(num_generations)
15
+ final_config = evo.config
16
+ history = evo.get_history()
17
+ evaluation = evo.evaluate()
18
+
19
+ # Radar Chart
 
 
 
 
 
 
 
 
 
20
  try:
21
+ radar_img = plot_radar_chart(final_config)
22
  except Exception as e:
23
  print("Radar chart error:", e)
24
+ radar_img = None
25
 
26
+ # Prepare history display
27
+ history_cards = []
28
+ for i, h in enumerate(history):
29
+ text = f"Gen {i + 1} Config: {h}"
30
+ history_cards.append(text)
 
 
 
 
31
 
32
+ # Prepare CSV and JSON
33
+ csv_file = io.StringIO()
34
+ csv_writer = csv.DictWriter(csv_file, fieldnames=history[0].keys())
35
+ csv_writer.writeheader()
36
+ csv_writer.writerows(history)
37
+ csv_bytes = io.BytesIO(csv_file.getvalue().encode())
38
 
39
+ json_file = io.StringIO(json.dumps(history, indent=2))
40
+ json_bytes = io.BytesIO(json_file.getvalue().encode())
 
 
 
 
 
 
41
 
42
  return (
43
+ f"{evaluation['accuracy']:.2f}%",
44
+ f"{evaluation['params']:.2f}M params",
45
+ str(final_config),
46
+ radar_img,
47
+ None, # placeholder for future diagram
48
  *history_cards,
49
+ ("evo_history.csv", csv_bytes),
50
+ ("evo_history.json", json_bytes)
51
  )
52
 
53
+ with gr.Blocks(title="🧬 EvoTransformer Evolving Transformer Architectures") as demo:
54
+ gr.Markdown("## 🧬 EvoTransformer Evolving Transformer Architectures\nSimulate trait mutation and adaptive architecture generation.")
 
 
 
 
 
 
 
55
 
56
  with gr.Row():
57
+ num_generations = gr.Slider(minimum=1, maximum=10, value=3, label="Number of Generations", step=1)
58
+ evolve_btn = gr.Button("🧬 Evolve Architecture")
59
 
60
+ with gr.Row():
61
+ accuracy_text = gr.Textbox(label="Simulated Accuracy")
62
+ param_text = gr.Textbox(label="Estimated Parameters")
63
+ config_text = gr.Textbox(label="Current Config Summary")
64
 
65
  gr.Markdown("## 🧬 Evolution History")
 
 
66
 
67
+ radar_output = gr.Image(label="Final Generation Trait Radar", type="pil", interactive=False)
68
+ diagram_output = gr.Image(label="Illustrative Transformer Structure", visible=False)
69
+
70
+ history_boxes = [gr.Textbox(label=f"Gen {i+1} Config") for i in range(10)]
71
 
72
+ with gr.Accordion("Downloadable Files", open=True):
73
+ csv_out = gr.File(label="Download CSV History")
74
+ json_out = gr.File(label="Download JSON History")
75
 
76
  evolve_btn.click(
77
  fn=evolve_model,
78
+ inputs=[num_generations],
79
  outputs=[
80
+ accuracy_text, param_text, config_text,
81
+ radar_output, diagram_output,
82
+ *history_boxes,
83
+ csv_out, json_out
 
 
 
 
84
  ]
85
  )
86