HemanM commited on
Commit
eeda69b
·
verified ·
1 Parent(s): 2b38854

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -45
app.py CHANGED
@@ -1,69 +1,94 @@
1
- # app.py
2
  import gradio as gr
3
  import matplotlib.pyplot as plt
4
  import pandas as pd
 
5
  import io
6
  from evo_transformer import EvoTransformer
7
 
8
- def run_evolution(generations):
9
- evo = EvoTransformer()
10
- evo.evolve(generations)
11
 
12
- history = evo.get_history()
13
- final_eval = evo.evaluate()
14
-
15
- df = pd.DataFrame(history)
16
- df_display = df.copy()
17
- df_display["memory"] = df_display["memory"].apply(lambda x: "Enabled" if x else "Disabled")
18
-
19
- # Radar chart
20
- last_config = history[-1]
21
- traits = ["layers", "attention_heads", "ffn_dim", "dropout", "memory"]
22
- values = [last_config["layers"], last_config["attention_heads"],
23
- last_config["ffn_dim"]/100, last_config["dropout"]*10, int(last_config["memory"])*10]
24
-
25
- fig, ax = plt.subplots(figsize=(6,6), subplot_kw=dict(polar=True))
26
- angles = [n / float(len(traits)) * 2 * 3.14159 for n in range(len(traits))]
27
  values += values[:1]
28
  angles += angles[:1]
29
- ax.plot(angles, values, linewidth=2)
30
- ax.fill(angles, values, alpha=0.3)
31
- ax.set_xticks(angles[:-1])
32
- ax.set_xticklabels(traits)
33
- ax.set_title("Final Architecture Traits", size=15)
34
 
35
- # File downloads
36
- csv_file = io.StringIO(evo.export_csv())
37
- json_file = io.StringIO(evo.export_json())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
  return (
40
- df_display,
41
- final_eval["accuracy"],
42
- final_eval["params"],
 
43
  fig,
44
- ("evo_history.csv", csv_file.getvalue()),
45
- ("evo_history.json", json_file.getvalue())
46
  )
47
 
48
- with gr.Blocks(title="EvoTransformer Demo") as demo:
49
- gr.Markdown("# 🧬 EvoTransformer Live Demo")
50
- gr.Markdown("This demo evolves a Transformer architecture and displays how traits change over generations.")
 
 
 
51
 
52
  with gr.Row():
53
- generations = gr.Slider(1, 10, value=5, step=1, label="Generations")
54
- run_btn = gr.Button("Evolve Now 🚀")
55
 
56
  with gr.Row():
57
- acc = gr.Number(label="Estimated Accuracy")
58
- params = gr.Number(label="Estimated Params (M)")
59
 
60
- table = gr.Dataframe(label="Evolution History", wrap=True)
61
- plot = gr.Plot(label="Final Architecture (Radar Chart)")
 
 
 
 
62
 
63
  with gr.Row():
64
- csv_dl = gr.File(label="Download CSV", file_types=[".csv"], interactive=True)
65
- json_dl = gr.File(label="Download JSON", file_types=[".json"], interactive=True)
66
 
67
- run_btn.click(fn=run_evolution, inputs=generations, outputs=[table, acc, params, plot, csv_dl, json_dl])
 
 
 
 
68
 
69
- demo.launch()
 
 
 
1
  import gradio as gr
2
  import matplotlib.pyplot as plt
3
  import pandas as pd
4
+ import numpy as np
5
  import io
6
  from evo_transformer import EvoTransformer
7
 
8
+ # Global instance (resettable)
9
+ evo = EvoTransformer()
 
10
 
11
+ # === Visualization Functions ===
12
+ def plot_radar(config):
13
+ labels = ["Layers", "Attention Heads", "FFN Dim", "Dropout", "Memory"]
14
+ values = [
15
+ config["layers"],
16
+ config["attention_heads"],
17
+ config["ffn_dim"],
18
+ int(config["dropout"] * 100),
19
+ int(config["memory"])
20
+ ]
21
+ angles = np.linspace(0, 2 * np.pi, len(labels), endpoint=False).tolist()
 
 
 
 
22
  values += values[:1]
23
  angles += angles[:1]
 
 
 
 
 
24
 
25
+ fig, ax = plt.subplots(figsize=(5, 5), subplot_kw=dict(polar=True))
26
+ ax.plot(angles, values, "o-", linewidth=2)
27
+ ax.fill(angles, values, alpha=0.25)
28
+ ax.set_thetagrids(np.degrees(angles[:-1]), labels)
29
+ ax.set_title("Final Architecture (Radar Chart)")
30
+ return fig
31
+
32
+ def evolve_and_display(generations):
33
+ global evo
34
+ evo = EvoTransformer() # Reset model
35
+ evo.evolve(generations)
36
+
37
+ df = evo.get_history_df()
38
+ final_config = evo.get_final_config()
39
+ accuracy, params = evo.evaluate()
40
+
41
+ fig = plot_radar(final_config)
42
+
43
+ json_file = io.BytesIO()
44
+ json_file.write(evo.get_history_json().encode("utf-8"))
45
+ json_file.seek(0)
46
+
47
+ csv_file = io.BytesIO()
48
+ df.to_csv(csv_file, index=False)
49
+ csv_file.seek(0)
50
 
51
  return (
52
+ accuracy,
53
+ params,
54
+ gr.Tabs.update(visible=True),
55
+ [gr.Textbox.update(value=str(row)) for _, row in df.iterrows()],
56
  fig,
57
+ (csv_file, "evo_history.csv"),
58
+ (json_file, "evo_history.json"),
59
  )
60
 
61
+ # === Gradio UI ===
62
+ with gr.Blocks(title="EvoTransformer Live Demo") as demo:
63
+ gr.Markdown(
64
+ "🚀 **EvoTransformer Live Demo**\n\n"
65
+ "This demo evolves a Transformer architecture and displays how traits change over generations."
66
+ )
67
 
68
  with gr.Row():
69
+ generations = gr.Slider(1, 10, value=5, label="Generations")
70
+ evolve_btn = gr.Button("Evolve Now 🚀")
71
 
72
  with gr.Row():
73
+ accuracy_out = gr.Number(label="Estimated Accuracy", value=0)
74
+ params_out = gr.Number(label="Estimated Params (M)", value=0)
75
 
76
+ tabbox = gr.Tabs(visible=False)
77
+ with tabbox:
78
+ with gr.Tab(label="Evolution History"):
79
+ history_display = [gr.Textbox(label=str(i+1), interactive=False) for i in range(10)]
80
+ with gr.Tab(label="Radar View"):
81
+ radar_plot = gr.Plot()
82
 
83
  with gr.Row():
84
+ csv_btn = gr.File(label="Download CSV")
85
+ json_btn = gr.File(label="Download JSON")
86
 
87
+ evolve_btn.click(
88
+ evolve_and_display,
89
+ inputs=[generations],
90
+ outputs=[accuracy_out, params_out, tabbox, history_display, radar_plot, csv_btn, json_btn]
91
+ )
92
 
93
+ if __name__ == "__main__":
94
+ demo.launch()