File size: 2,825 Bytes
2ad7d0e
76d9193
 
eeda69b
ce7f881
 
2ad7d0e
eeda69b
 
ce7f881
eeda69b
 
 
 
 
 
 
 
 
 
 
ce7f881
 
 
eeda69b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ce7f881
118680e
eeda69b
 
 
 
ce7f881
eeda69b
 
118680e
2ad7d0e
eeda69b
 
 
 
 
 
ce7f881
 
eeda69b
 
ce7f881
 
eeda69b
 
2ad7d0e
eeda69b
 
 
 
 
 
76d9193
ce7f881
eeda69b
 
76d9193
eeda69b
 
 
 
 
2ad7d0e
eeda69b
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import io
import tempfile
from pathlib import Path

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from evo_transformer import EvoTransformer

# Global instance (resettable): evolve_and_display() rebinds this name to a
# brand-new EvoTransformer on every run, so after the first click it holds the
# most recently evolved model; before that it is a freshly constructed one.
# NOTE(review): as a module global it is presumably shared by every session of
# the demo — confirm whether concurrent clicks are expected.
evo: EvoTransformer = EvoTransformer()

# === Visualization Functions ===
def plot_radar(config):
    """Draw the evolved architecture's traits as a closed polygon on a polar chart.

    Args:
        config: Mapping with the keys ``"layers"``, ``"attention_heads"``,
            ``"ffn_dim"``, ``"dropout"`` and ``"memory"``.

    Returns:
        The matplotlib ``Figure`` holding the radar chart. It has already been
        detached from pyplot's figure registry (see the end of the body).
    """
    labels = ["Layers", "Attention Heads", "FFN Dim", "Dropout", "Memory"]
    values = [
        config["layers"],
        config["attention_heads"],
        config["ffn_dim"],
        # Dropout is a fraction; plot it as a percentage. round() instead of a
        # bare int() because e.g. 0.29 * 100 == 28.999999999999996 would
        # otherwise truncate to 28.
        int(round(config["dropout"] * 100)),
        int(config["memory"]),
    ]
    # NOTE(review): all traits share one radial axis, so a large FFN dim will
    # visually dwarf the others; consider normalizing if that matters.
    angles = np.linspace(0, 2 * np.pi, len(labels), endpoint=False).tolist()
    # Repeat the first point so the plotted polygon closes on itself.
    values += values[:1]
    angles += angles[:1]

    fig, ax = plt.subplots(figsize=(5, 5), subplot_kw=dict(polar=True))
    ax.plot(angles, values, "o-", linewidth=2)
    ax.fill(angles, values, alpha=0.25)
    ax.set_thetagrids(np.degrees(angles[:-1]), labels)
    ax.set_title("Final Architecture (Radar Chart)")
    # A new figure is created per call in a long-running server; without this
    # pyplot keeps every one of them alive. The returned Figure object remains
    # valid for rendering (savefig) after being closed.
    plt.close(fig)
    return fig

# Upper bound of the generations slider; also the number of history textboxes.
MAX_GENERATIONS = 10


def _write_history_files(df, history_json):
    """Write the evolution history to CSV and JSON files; return both paths as str.

    ``gr.File`` outputs take file paths rather than in-memory buffers, so the
    data is written into a fresh temporary directory per call. That keeps the
    user-facing file names while avoiding clashes between requests.
    """
    out_dir = Path(tempfile.mkdtemp(prefix="evo_history_"))
    csv_path = out_dir / "evo_history.csv"
    json_path = out_dir / "evo_history.json"
    df.to_csv(csv_path, index=False)
    json_path.write_text(history_json, encoding="utf-8")
    return str(csv_path), str(json_path)


def evolve_and_display(generations):
    """Evolve a fresh EvoTransformer and build every value the UI displays.

    Rebinds the module-level ``evo`` to a new ``EvoTransformer`` so each click
    starts from scratch, evolves it, then gathers the results.

    Args:
        generations: Number of generations to evolve. Gradio sliders deliver
            numbers that may be floats, so the value is converted with ``int()``.

    Returns:
        A flat tuple matching the ``outputs`` list wired up in the UI below:
        accuracy, parameter count, a visibility update for the results tabs,
        one string per history textbox (``MAX_GENERATIONS`` of them, blank when
        unused), the radar figure, and file paths for the CSV and JSON downloads.
    """
    global evo
    evo = EvoTransformer()  # Reset model
    evo.evolve(int(generations))

    df = evo.get_history_df()
    final_config = evo.get_final_config()
    accuracy, params = evo.evaluate()

    fig = plot_radar(final_config)

    # Gradio needs exactly one value per output component, so pad with "" (or
    # truncate) to the number of history textboxes.
    # NOTE(review): if the history frame can hold more rows than slots (e.g.
    # an extra initial row — not visible here), the trailing rows are dropped.
    history = [str(row) for _, row in df.iterrows()][:MAX_GENERATIONS]
    history += [""] * (MAX_GENERATIONS - len(history))

    csv_path, json_path = _write_history_files(df, evo.get_history_json())

    return (
        accuracy,
        params,
        gr.update(visible=True),
        *history,
        fig,
        csv_path,
        json_path,
    )


# === Gradio UI ===
with gr.Blocks(title="EvoTransformer Live Demo") as demo:
    gr.Markdown(
        "🚀 **EvoTransformer Live Demo**\n\n"
        "This demo evolves a Transformer architecture and displays how traits change over generations."
    )

    with gr.Row():
        # step=1: without an explicit step Gradio derives a fractional one for
        # this range, which would let non-integer generation counts through.
        generations = gr.Slider(
            1, MAX_GENERATIONS, value=5, step=1, label="Generations"
        )
        evolve_btn = gr.Button("Evolve Now 🚀")

    with gr.Row():
        accuracy_out = gr.Number(label="Estimated Accuracy", value=0)
        params_out = gr.Number(label="Estimated Params (M)", value=0)

    # Hidden until the first evolution run reveals it.
    tabbox = gr.Tabs(visible=False)
    with tabbox:
        with gr.Tab(label="Evolution History"):
            history_display = [
                gr.Textbox(label=str(i + 1), interactive=False)
                for i in range(MAX_GENERATIONS)
            ]
        with gr.Tab(label="Radar View"):
            radar_plot = gr.Plot()

    with gr.Row():
        csv_btn = gr.File(label="Download CSV")
        json_btn = gr.File(label="Download JSON")

    evolve_btn.click(
        evolve_and_display,
        inputs=[generations],
        # outputs must be a flat list of components: unpack the history boxes
        # so each one receives one of the strings evolve_and_display returns.
        outputs=[
            accuracy_out,
            params_out,
            tabbox,
            *history_display,
            radar_plot,
            csv_btn,
            json_btn,
        ],
    )

if __name__ == "__main__":
    # Start the Gradio server only when this file is run as a script, not when
    # it is imported.
    demo.launch()