HemanM commited on
Commit
ce7f881
·
verified ·
1 Parent(s): 5490a10

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -30
app.py CHANGED
@@ -1,45 +1,69 @@
 
1
  import gradio as gr
2
- from evo_transformer import EvoTransformer
3
  import matplotlib.pyplot as plt
4
  import pandas as pd
5
- import seaborn as sns
 
6
 
7
- def simulate_evolution(generations):
8
- model = EvoTransformer()
9
- model.evolve(generations)
10
- history = model.get_history()
11
-
12
- # Evaluation of final model
13
- final_result = model.evaluate()
14
 
15
- # Visualization
16
- fig, ax = plt.subplots(figsize=(10, 5))
17
  df = pd.DataFrame(history)
18
- sns.lineplot(data=df.drop("memory", axis=1), markers=True, ax=ax)
19
- ax.set_title("EvoTransformer Trait Evolution")
20
- ax.set_ylabel("Value")
21
- ax.set_xlabel("Generation")
22
- plt.xticks(range(len(df)))
23
- plt.tight_layout()
 
 
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  return (
26
- final_result["accuracy"],
27
- round(final_result["params"], 2),
28
- fig
 
 
 
29
  )
30
 
31
- # UI
32
- with gr.Blocks(title="EvoTransformer: Adaptive Architecture Evolution") as demo:
33
- gr.Markdown("# 🧬 EvoTransformer Demo")
34
- gr.Markdown("Evolving architecture live — inspired by nature. Tune the number of generations below and visualize how traits change during evolution.")
 
 
 
 
 
 
 
35
 
36
- generations = gr.Slider(1, 20, value=5, step=1, label="Number of Generations")
37
- btn = gr.Button("Evolve EvoTransformer")
38
 
39
- accuracy = gr.Textbox(label="Estimated Accuracy", interactive=False)
40
- params = gr.Textbox(label="Estimated Parameter Count (M)", interactive=False)
41
- plot = gr.Plot(label="Architecture Trait Evolution")
42
 
43
- btn.click(fn=simulate_evolution, inputs=generations, outputs=[accuracy, params, plot])
44
 
45
  demo.launch()
 
1
+ # app.py
2
  import gradio as gr
 
3
  import matplotlib.pyplot as plt
4
  import pandas as pd
5
+ import io
6
+ from evo_transformer import EvoTransformer
7
 
8
+ def run_evolution(generations):
9
+ evo = EvoTransformer()
10
+ evo.evolve(generations)
11
+
12
+ history = evo.get_history()
13
+ final_eval = evo.evaluate()
 
14
 
 
 
15
  df = pd.DataFrame(history)
16
+ df_display = df.copy()
17
+ df_display["memory"] = df_display["memory"].apply(lambda x: "Enabled" if x else "Disabled")
18
+
19
+ # Radar chart
20
+ last_config = history[-1]
21
+ traits = ["layers", "attention_heads", "ffn_dim", "dropout", "memory"]
22
+ values = [last_config["layers"], last_config["attention_heads"],
23
+ last_config["ffn_dim"]/100, last_config["dropout"]*10, int(last_config["memory"])*10]
24
 
25
+ fig, ax = plt.subplots(figsize=(6,6), subplot_kw=dict(polar=True))
26
+ angles = [n / float(len(traits)) * 2 * 3.14159 for n in range(len(traits))]
27
+ values += values[:1]
28
+ angles += angles[:1]
29
+ ax.plot(angles, values, linewidth=2)
30
+ ax.fill(angles, values, alpha=0.3)
31
+ ax.set_xticks(angles[:-1])
32
+ ax.set_xticklabels(traits)
33
+ ax.set_title("Final Architecture Traits", size=15)
34
+
35
+ # File downloads
36
+ csv_file = io.StringIO(evo.export_csv())
37
+ json_file = io.StringIO(evo.export_json())
38
+
39
  return (
40
+ df_display,
41
+ final_eval["accuracy"],
42
+ final_eval["params"],
43
+ fig,
44
+ ("evo_history.csv", csv_file.getvalue()),
45
+ ("evo_history.json", json_file.getvalue())
46
  )
47
 
48
+ with gr.Blocks(title="EvoTransformer Demo") as demo:
49
+ gr.Markdown("# 🧬 EvoTransformer Live Demo")
50
+ gr.Markdown("This demo evolves a Transformer architecture and displays how traits change over generations.")
51
+
52
+ with gr.Row():
53
+ generations = gr.Slider(1, 10, value=5, step=1, label="Generations")
54
+ run_btn = gr.Button("Evolve Now 🚀")
55
+
56
+ with gr.Row():
57
+ acc = gr.Number(label="Estimated Accuracy")
58
+ params = gr.Number(label="Estimated Params (M)")
59
 
60
+ table = gr.Dataframe(label="Evolution History", wrap=True)
61
+ plot = gr.Plot(label="Final Architecture (Radar Chart)")
62
 
63
+ with gr.Row():
64
+ csv_dl = gr.File(label="Download CSV", file_types=[".csv"], interactive=True)
65
+ json_dl = gr.File(label="Download JSON", file_types=[".json"], interactive=True)
66
 
67
+ run_btn.click(fn=run_evolution, inputs=generations, outputs=[table, acc, params, plot, csv_dl, json_dl])
68
 
69
  demo.launch()