File size: 3,498 Bytes
2ad7d0e
ce7f881
10054c3
 
6842aeb
 
 
2ad7d0e
10054c3
ce7f881
6842aeb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eeda69b
10054c3
b9e3604
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2ad7d0e
6842aeb
10054c3
 
ce7f881
 
6842aeb
 
ce7f881
 
6842aeb
 
2ad7d0e
6842aeb
 
 
 
 
 
76d9193
ce7f881
6842aeb
 
76d9193
eeda69b
 
 
6842aeb
 
 
 
 
 
 
 
 
eeda69b
2ad7d0e
6842aeb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import gradio as gr
from evo_transformer import EvoTransformer
import pandas as pd
import json
import tempfile
import matplotlib.pyplot as plt
import numpy as np

# Single EvoTransformer instance created at import time. Every click of the
# "Evolve" button goes through evolve_and_display, which calls evolve(),
# get_history() and evaluate() on this same object, so its state is shared by
# all users/sessions of the app.
# NOTE(review): whether evolve() appends to or resets the history is defined in
# evo_transformer — confirm; the UI only has 10 per-generation textboxes.
model = EvoTransformer()

def plot_radar(history):
    """Render the last generation's traits as a radar chart saved to a PNG file.

    Each trait is divided by a fixed reference value so the four traits share
    one radial axis: layers / 12, attention_heads / 12, ffn_dim / 2048 and
    dropout / 0.5. A plotted value of 1.0 means the trait equals its reference.

    Args:
        history: Sequence of per-generation config dicts. Only the last entry
            is plotted; it must contain the keys ``layers``,
            ``attention_heads``, ``ffn_dim`` and ``dropout``.

    Returns:
        Path of a PNG created with ``delete=False``. The file is left on disk
        for the caller (the Gradio Image component) to serve.

    Raises:
        ValueError: if ``history`` is empty.
    """
    if not history:
        raise ValueError("plot_radar() needs at least one generation in history")
    last = history[-1]

    # (trait key, reference value) pairs; each trait is plotted as value / reference.
    scales = [
        ("layers", 12),
        ("attention_heads", 12),
        ("ffn_dim", 2048),
        ("dropout", 0.5),
    ]
    labels = [name for name, _ in scales]
    values = [last[name] / ref for name, ref in scales]
    angles = np.linspace(0, 2 * np.pi, len(scales), endpoint=False).tolist()
    # Repeat the first point so the plotted polygon closes.
    values += values[:1]
    angles += angles[:1]

    fig, ax = plt.subplots(figsize=(5, 5), subplot_kw=dict(polar=True))
    try:
        ax.plot(angles, values, "o-", linewidth=2)
        ax.fill(angles, values, alpha=0.25)
        ax.set_thetagrids(np.degrees(angles[:-1]), labels)
        ax.set_title("Final Generation Trait Radar")
        ax.grid(True)

        # Reserve a unique path and close the handle right away (so the file
        # can be reopened on any OS), then save through the Figure object.
        # plt.savefig() would act on pyplot's process-global "current figure",
        # which is unsafe if handlers run concurrently.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            img_path = tmp.name
        fig.savefig(img_path)
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)
    return img_path

# Number of per-generation history values the handler returns. Gradio maps
# return values to outputs by position, so this count must stay fixed; the UI
# creates 10 "Gen N Config" textboxes.
_HISTORY_SLOTS = 10


def _format_history(history, slots=_HISTORY_SLOTS):
    """Return exactly ``slots`` strings: one indented JSON dump per generation, blank-padded.

    Generations beyond ``slots`` are dropped so the list always lines up with
    the fixed "Gen 1 ... Gen N" textboxes.
    """
    texts = [json.dumps(h, indent=2) for h in list(history)[:slots]]
    texts += [""] * (slots - len(texts))
    return texts


def _save_history_files(history):
    """Write ``history`` to temporary CSV and JSON files; return ``(csv_path, json_path)``.

    Files are created with ``delete=False`` so they still exist when Gradio
    serves them after this function returns.
    """
    # newline="" is the documented way to hand a text handle to a CSV writer.
    with tempfile.NamedTemporaryFile(
        "w", suffix=".csv", delete=False, newline="", encoding="utf-8"
    ) as csv_file:
        pd.DataFrame(history).to_csv(csv_file, index=False)
    with tempfile.NamedTemporaryFile(
        "w", suffix=".json", delete=False, encoding="utf-8"
    ) as json_file:
        json.dump(history, json_file, indent=2)
    return csv_file.name, json_file.name


def evolve_and_display(gens):
    """Run ``gens`` generations of evolution and build every UI output value.

    Args:
        gens: Number of generations from the slider. It is cast to ``int``
            because Gradio sliders may deliver the value as a float (e.g. 3.0).

    Returns:
        A tuple in the positional order of the Evolve button's outputs:
        - accuracy text
        - parameter-count text
        - "key: value" summary of the last generation
        - ``_HISTORY_SLOTS`` per-generation JSON strings (truncated or
          blank-padded)
        - radar PNG path
        - CSV path
        - JSON path

        On any exception the traceback is printed and a tuple of the same
        length is returned, with "Error" in the text outputs and None in the
        image/file outputs, so the UI still updates.
    """
    try:
        model.evolve(generations=int(gens))
        history = model.get_history()
        eval_result = model.evaluate()

        # One "key: value" line per field of the most recent generation.
        summary = "\n".join(f"{k}: {v}" for k, v in history[-1].items())
        # Fixed-length so the output tuple never shifts past the history boxes.
        history_txt = _format_history(history)
        csv_path, json_path = _save_history_files(history)
        radar_img = plot_radar(history)

        return (
            f"{eval_result['accuracy']*100:.2f}%",
            f"{eval_result['params']:.2f}M params",
            summary,
            *history_txt,
            radar_img,
            csv_path,
            json_path,
        )
    except Exception as e:
        # Top-level UI boundary: report the failure but keep the output arity intact.
        print("🚨 ERROR during evolution:", str(e))
        import traceback
        traceback.print_exc()

        return (
            "Error", "Error", "Error",
            *["Error"] * _HISTORY_SLOTS,
            None, None, None,
        )

with gr.Blocks() as demo:
    # Page header.
    gr.Markdown("# 🧬 EvoTransformer – Evolving Transformer Architectures")
    gr.Markdown("Simulate trait mutation and adaptive architecture generation.")

    # Controls: how many generations to run, plus the trigger button.
    with gr.Row():
        generations = gr.Slider(
            1, 10, value=3, step=1, label="Number of Generations"
        )
        evolve_btn = gr.Button("🧬 Evolve Architecture")

    # Headline metrics produced from model.evaluate().
    with gr.Row():
        acc_out = gr.Text(label="Simulated Accuracy")
        param_out = gr.Text(label="Estimated Parameters")

    summary_out = gr.Textbox(label="Current Config Summary", lines=5)

    # One textbox per generation, labelled "Gen 1 Config" .. "Gen 10 Config".
    # evolve_and_display must return one value per box (it pads its history
    # list to 10 entries).
    hist_outputs = []
    with gr.Accordion("🧬 Evolution History", open=False):
        for gen_no in range(1, 11):
            hist_outputs.append(gr.Textbox(label=f"Gen {gen_no} Config", lines=4))

    radar_plot = gr.Image(label="Final Generation Trait Radar")

    with gr.Row():
        csv_out = gr.File(label="Download CSV History")
        json_out = gr.File(label="Download JSON History")

    # Gradio assigns return values by position: this order must mirror the
    # tuple returned by evolve_and_display.
    click_outputs = [acc_out, param_out, summary_out]
    click_outputs.extend(hist_outputs)
    click_outputs += [radar_plot, csv_out, json_out]

    evolve_btn.click(
        evolve_and_display,
        inputs=[generations],
        outputs=click_outputs,
    )

# Start the Gradio server for the UI defined above. This call sits at module
# top level with no `if __name__ == "__main__":` guard, so importing this
# module also launches the app.
demo.launch()