File size: 3,619 Bytes
10054c3
 
2ad7d0e
ce7f881
10054c3
 
 
 
 
2ad7d0e
10054c3
 
ce7f881
10054c3
 
 
 
eeda69b
10054c3
 
eeda69b
10054c3
eeda69b
10054c3
 
eeda69b
10054c3
 
 
eeda69b
10054c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ce7f881
118680e
10054c3
 
 
 
 
 
 
118680e
2ad7d0e
10054c3
 
 
 
 
 
 
ce7f881
 
10054c3
 
ce7f881
 
10054c3
 
2ad7d0e
10054c3
 
 
76d9193
ce7f881
10054c3
 
76d9193
eeda69b
 
 
10054c3
eeda69b
2ad7d0e
10054c3
eeda69b
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
# app.py

import gradio as gr
from evo_transformer import EvoTransformer
import pandas as pd
import matplotlib.pyplot as plt
import tempfile
import os
import json

# === Initialize Model ===
# One module-level model instance; evolve_and_display() drives it on each click.
model = EvoTransformer()

# === Create Interface Components ===
# Components are constructed here and attached to the page later through
# their .render() calls inside the gr.Blocks layout.
generations = gr.Slider(
    minimum=1,
    maximum=10,
    value=3,
    step=1,
    label="Number of Generations",
)

evolve_btn = gr.Button(value="🧬 Evolve Architecture")

# Headline metrics produced by model.evaluate().
accuracy_out = gr.Textbox(label="Simulated Accuracy")
params_out = gr.Textbox(label="Estimated Parameters")

# Key/value summary of the most recent configuration.
tabbox = gr.Textbox(label="Current Config Summary")

# Fixed pool of ten textboxes, one per displayed history entry.
history_display = [
    gr.Textbox(label=f"Gen {i+1} Config")
    for i in range(10)
]

# File outputs that expose the exported history for download.
csv_btn = gr.File(label="Download CSV History")
json_btn = gr.File(label="Download JSON History")

# === Helper: Create Evolution Radar Plot ===
def plot_radar(history):
    """Render a radar chart of the latest history entry and save it as a PNG.

    Args:
        history: Sequence of dict-like config snapshots. Only the last entry
            is plotted. Traits missing from it are drawn as 0, and an empty
            history yields an all-zero chart instead of raising.

    Returns:
        Filesystem path of the PNG image. The file is created with
        ``delete=False`` and is not removed by this function.
    """
    import numpy as np

    traits = ["layers", "attention_heads", "ffn_dim", "dropout", "memory"]
    latest = history[-1] if history else {}
    values = [latest.get(t, 0) for t in traits]
    # Booleans are plotted as 0/1.
    values = [int(v) if isinstance(v, bool) else v for v in values]

    angles = np.linspace(0, 2 * np.pi, len(traits), endpoint=False).tolist()
    # Repeat the first point so the polygon closes.
    values += values[:1]
    angles += angles[:1]

    # Reserve the output path and release the handle before matplotlib
    # writes to it, so no file descriptor stays open.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        image_path = tmp.name

    fig, ax = plt.subplots(figsize=(5, 5), subplot_kw=dict(polar=True))
    try:
        ax.plot(angles, values, "o-", linewidth=2)
        ax.fill(angles, values, alpha=0.25)
        ax.set_thetagrids([a * 180 / np.pi for a in angles[:-1]], traits)
        ax.set_title("Last Gen Trait Radar", fontsize=14)
        fig.tight_layout()
        fig.savefig(image_path)
    finally:
        # Close this specific figure (not "whatever is current") even when
        # drawing or saving fails, so figures do not pile up across calls.
        plt.close(fig)
    return image_path

# Image component that receives the PNG path returned by plot_radar().
radar_plot = gr.Image(label="Final Generation Trait Radar")

# === Main Evolution Logic ===
def evolve_and_display(gens):
    """Evolve the model and build every value shown by the interface.

    Args:
        gens: Number of generations to run. The slider may deliver this as a
            float (e.g. ``3.0``), so it is coerced to ``int`` before use.

    Returns:
        A tuple in the order of the click handler's ``outputs`` list:
        accuracy text, parameter-count text, summary of the latest config,
        exactly ten history strings (padded with ``""``), radar image path,
        CSV path and JSON path.
    """
    # Must equal the number of "Gen N Config" textboxes wired as outputs.
    max_slots = 10

    model.evolve(generations=int(gens))
    history = model.get_history()
    eval_result = model.evaluate()

    # Format summary of the most recent configuration.
    summary = "\n".join(f"{k}: {v}" for k, v in history[-1].items())

    # Box N always shows history[N-1]. Entries beyond the tenth are not
    # displayed (they would exceed the number of output components) but are
    # still present in the CSV/JSON exports below.
    history_txt = [json.dumps(h, indent=2) for h in list(history)[:max_slots]]
    history_txt += [""] * (max_slots - len(history_txt))

    # Generate CSV + JSON. Paths are reserved through context managers so
    # the temporary handles are closed deterministically before writing.
    df = pd.DataFrame(history)
    with tempfile.NamedTemporaryFile(suffix=".csv", delete=False) as csv_tmp:
        csv_path = csv_tmp.name
    with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as json_tmp:
        json_path = json_tmp.name
    df.to_csv(csv_path, index=False)
    with open(json_path, "w", encoding="utf-8") as jf:
        json.dump(history, jf, indent=2)

    # Plot radar
    radar_img = plot_radar(history)

    return (
        f"{eval_result['accuracy']*100:.2f}%",
        f"{eval_result['params']:.2f}M params",
        summary,
        *history_txt,
        radar_img,
        csv_path,
        json_path,
    )

# === Interface Layout ===
# The page is two Markdown headers followed by five rows; each row renders
# its pre-built components left to right in the order listed below.
with gr.Blocks(title="EvoTransformer Demo") as demo:
    gr.Markdown("# 🧬 EvoTransformer – Evolving Transformer Architectures")
    gr.Markdown("Simulate trait mutation and adaptive architecture generation.")

    _layout_rows = (
        (generations, evolve_btn),
        (accuracy_out, params_out),
        (radar_plot, tabbox),
        tuple(history_display),
        (csv_btn, json_btn),
    )
    for _row_components in _layout_rows:
        with gr.Row():
            for _component in _row_components:
                _component.render()

    # Output order must match the tuple returned by evolve_and_display.
    _click_outputs = [accuracy_out, params_out, tabbox]
    _click_outputs.extend(history_display)
    _click_outputs.extend([radar_plot, csv_btn, json_btn])

    evolve_btn.click(
        fn=evolve_and_display,
        inputs=[generations],
        outputs=_click_outputs,
    )

# === Launch App ===
# Launch the demo only when this file is executed as a script; importing the
# module builds `demo` without launching it.
if __name__ == "__main__":
    demo.launch()