# Spaces: Sleeping
# app.py
import json
import os
import tempfile

import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd

from evo_transformer import EvoTransformer
# === Initialize Model ===
# A single module-level EvoTransformer instance is used by the callbacks below.
model = EvoTransformer()

# === Create Interface Components ===
# Components are constructed here and placed into the page later via .render().
generations = gr.Slider(
    minimum=1,
    maximum=10,
    value=3,
    step=1,
    label="Number of Generations",
)
evolve_btn = gr.Button("🧬 Evolve Architecture")
accuracy_out = gr.Textbox(label="Simulated Accuracy")
params_out = gr.Textbox(label="Estimated Parameters")
tabbox = gr.Textbox(label="Current Config Summary")

# Dynamic display of evolution history: ten textboxes labelled Gen 1 .. Gen 10.
history_display = [
    gr.Textbox(label=f"Gen {i+1} Config")
    for i in range(10)
]

# Download buttons
csv_btn = gr.File(label="Download CSV History")
json_btn = gr.File(label="Download JSON History")
# === Helper: Create Evolution Radar Plot ===
def plot_radar(history):
    """Render the traits of the most recent generation as a radar chart PNG.

    Args:
        history: Non-empty sequence of per-generation config mappings (each
            must support ``.get``). Only the last entry is plotted.

    Returns:
        Path of a PNG file in the system temp directory. The file is created
        with ``delete=False`` and is not removed by this function.

    Raises:
        IndexError: If ``history`` is empty.
    """
    import numpy as np

    traits = ["layers", "attention_heads", "ffn_dim", "dropout", "memory"]
    latest = history[-1]

    # Traits missing from the config plot as 0; booleans are drawn as 0/1.
    values = [latest.get(trait, 0) for trait in traits]
    values = [int(v) if isinstance(v, bool) else v for v in values]
    angles = np.linspace(0, 2 * np.pi, len(traits), endpoint=False).tolist()

    # Repeat the first point so the plotted outline closes on itself.
    values += values[:1]
    angles += angles[:1]

    fig, ax = plt.subplots(figsize=(5, 5), subplot_kw=dict(polar=True))
    try:
        ax.plot(angles, values, "o-", linewidth=2)
        ax.fill(angles, values, alpha=0.25)
        ax.set_thetagrids([a * 180 / np.pi for a in angles[:-1]], traits)
        ax.set_title("Last Gen Trait Radar", fontsize=14)
        # Work on this figure object rather than pyplot's global "current
        # figure", so another callback creating a figure cannot interfere.
        fig.tight_layout()

        # Reserve the path and close the handle before saving: the descriptor
        # is not leaked, and the file is not held open while matplotlib
        # writes to it (which fails on Windows).
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            image_path = tmp.name
        fig.savefig(image_path)
    finally:
        # Always release the figure, even if drawing or saving raised.
        plt.close(fig)
    return image_path
# Image component that shows the PNG produced by plot_radar().
radar_plot = gr.Image(
    label="Final Generation Trait Radar",
)
# === Main Evolution Logic ===
def evolve_and_display(gens):
    """Evolve the model and build one value per output component.

    Args:
        gens: Number of generations to run, as delivered by the slider.

    Returns:
        A tuple of: accuracy text, parameter-count text, summary of the
        latest config, exactly ten history strings (padded with ``""``),
        radar image path, CSV path, JSON path.

    Raises:
        IndexError: If the model reports an empty history.
    """
    # Must equal the number of "Gen N Config" textboxes wired as outputs.
    max_history_boxes = 10

    # Gradio may hand the slider value over as a float; the generation count
    # is coerced to an integer before use.
    model.evolve(generations=int(gens))
    history = model.get_history()
    eval_result = model.evaluate()

    # Format summary
    summary = "\n".join(f"{k}: {v}" for k, v in history[-1].items())

    # Exactly ten strings must be returned for the history boxes. Keep the
    # first ten so the "Gen N" labels stay accurate, and pad when there are
    # fewer. NOTE(review): history may grow past ten if the shared model
    # accumulates across clicks; the full history is still written to the
    # CSV/JSON files below.
    history_txt = [json.dumps(h, indent=2) for h in history][:max_history_boxes]
    history_txt += [""] * (max_history_boxes - len(history_txt))

    # Generate CSV + JSON. The temp handles are closed right away so nothing
    # is leaked; only the reserved paths are reused.
    df = pd.DataFrame(history)
    with tempfile.NamedTemporaryFile(suffix=".csv", delete=False) as csv_tmp:
        csv_path = csv_tmp.name
    with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as json_tmp:
        json_path = json_tmp.name
    df.to_csv(csv_path, index=False)
    with open(json_path, "w") as jf:
        json.dump(history, jf, indent=2)

    # Plot radar
    radar_img = plot_radar(history)

    return (
        f"{eval_result['accuracy']*100:.2f}%",
        f"{eval_result['params']:.2f}M params",
        summary,
        *history_txt,
        radar_img,
        csv_path,
        json_path,
    )
# === Interface Layout ===
with gr.Blocks(title="EvoTransformer Demo") as demo:
    gr.Markdown("# 🧬 EvoTransformer – Evolving Transformer Architectures")
    gr.Markdown("Simulate trait mutation and adaptive architecture generation.")

    # One entry per horizontal row, listed top to bottom.
    row_specs = (
        (generations, evolve_btn),
        (accuracy_out, params_out),
        (radar_plot, tabbox),
        history_display,
        (csv_btn, json_btn),
    )
    for row_members in row_specs:
        with gr.Row():
            for box in row_members:
                box.render()

    # Output order must match the tuple returned by evolve_and_display.
    click_outputs = [
        accuracy_out,
        params_out,
        tabbox,
        *history_display,
        radar_plot,
        csv_btn,
        json_btn,
    ]
    evolve_btn.click(
        fn=evolve_and_display,
        inputs=[generations],
        outputs=click_outputs,
    )
# === Launch App ===
# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()