HemanM's picture
Update app.py
10054c3 verified
raw
history blame
3.62 kB
# app.py
import gradio as gr
from evo_transformer import EvoTransformer
import pandas as pd
import matplotlib.pyplot as plt
import tempfile
import os
import json
# === Initialize Model ===
# Single module-level instance; every call to evolve_and_display uses it.
model = EvoTransformer()

# === Create Interface Components ===
# Components are built here and placed later via .render() inside gr.Blocks.

# Upper bound of the generations slider AND the number of history textboxes.
# Keeping them in one constant stops the two from drifting apart (the
# history outputs must line up with what evolve_and_display returns).
MAX_GENERATIONS = 10

generations = gr.Slider(
    1, MAX_GENERATIONS, value=3, step=1, label="Number of Generations"
)
evolve_btn = gr.Button("🧬 Evolve Architecture")
accuracy_out = gr.Textbox(label="Simulated Accuracy")
params_out = gr.Textbox(label="Estimated Parameters")
tabbox = gr.Textbox(label="Current Config Summary")

# Dynamic display of evolution history: one textbox per possible generation.
history_display = [
    gr.Textbox(label=f"Gen {i+1} Config") for i in range(MAX_GENERATIONS)
]

# Download outputs (gr.File components, despite the *_btn names).
csv_btn = gr.File(label="Download CSV History")
json_btn = gr.File(label="Download JSON History")
# === Helper: Create Evolution Radar Plot ===
def plot_radar(history):
    """Render the last generation's traits as a radar (polar) chart PNG.

    Args:
        history: Sequence of per-generation config dicts (as returned by
            ``model.get_history()``). Only the last entry is plotted; a trait
            missing from it is drawn as 0 and boolean traits as 0/1.

    Returns:
        Path to a PNG in the system temp directory, or ``None`` when
        ``history`` is empty. The file is created with ``delete=False`` so it
        survives for Gradio to serve; nothing here removes it.
    """
    import numpy as np

    if not history:
        # Nothing to plot; gr.Image accepts None as "no image".
        return None

    traits = ["layers", "attention_heads", "ffn_dim", "dropout", "memory"]
    last_gen = history[-1]
    values = [last_gen.get(t, 0) for t in traits]
    # Booleans cannot be drawn as a radius directly; map them to 0/1.
    values = [int(v) if isinstance(v, bool) else v for v in values]
    angles = np.linspace(0, 2 * np.pi, len(traits), endpoint=False).tolist()
    # Repeat the first point so the polygon closes.
    values += values[:1]
    angles += angles[:1]

    # Use the figure object explicitly rather than pyplot's implicit
    # "current figure", which another concurrent request could change.
    fig, ax = plt.subplots(figsize=(5, 5), subplot_kw=dict(polar=True))
    try:
        ax.plot(angles, values, "o-", linewidth=2)
        ax.fill(angles, values, alpha=0.25)
        ax.set_thetagrids([a * 180 / np.pi for a in angles[:-1]], traits)
        ax.set_title("Last Gen Trait Radar", fontsize=14)
        fig.tight_layout()
        # Create and immediately close the temp file, then write by path.
        # This avoids leaking the handle and works on Windows, where an open
        # NamedTemporaryFile cannot be reopened by name.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            png_path = tmp.name
        fig.savefig(png_path)
    finally:
        plt.close(fig)
    return png_path
# Output image component for the radar chart; it receives the PNG path
# that evolve_and_display gets from plot_radar.
radar_plot = gr.Image(label="Final Generation Trait Radar")
# === Main Evolution Logic ===
def _temp_path(suffix):
    """Create an empty, already-closed temp file with *suffix*; return its path."""
    # delete=False keeps the file on disk so Gradio can serve it for download.
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        return tmp.name


def _export_history(history):
    """Write *history* to new CSV and JSON temp files; return (csv_path, json_path)."""
    csv_path = _temp_path(".csv")
    json_path = _temp_path(".json")
    pd.DataFrame(history).to_csv(csv_path, index=False)
    with open(json_path, "w", encoding="utf-8") as jf:
        json.dump(history, jf, indent=2)
    return csv_path, json_path


def _history_texts(history, n_boxes):
    """Return exactly *n_boxes* JSON strings for the history textboxes.

    Gradio requires one return value per output component, so the most
    recent ``n_boxes`` entries are shown and any remaining slots are "".
    """
    recent = history[max(len(history) - n_boxes, 0):]
    texts = [json.dumps(h, indent=2) for h in recent]
    texts += [""] * (n_boxes - len(texts))
    return texts


def evolve_and_display(gens):
    """Run the evolution and build every value shown in the UI.

    Args:
        gens: Number of generations from the slider. It is coerced to int
            because Gradio sliders may deliver the value as a float.

    Returns:
        Tuple in the order of the ``outputs`` wired to ``evolve_btn.click``:
        accuracy string, params string, last-config summary, one JSON string
        per history textbox, radar PNG path (or None), CSV path, JSON path.
    """
    model.evolve(generations=int(gens))
    history = model.get_history()
    eval_result = model.evaluate()

    # Summary of the latest generation (empty if there is no history).
    last_gen = history[-1] if history else {}
    summary = "\n".join(f"{k}: {v}" for k, v in last_gen.items())

    # History can have more entries than there are textboxes (e.g. if the
    # shared model accumulates history across clicks; not visible here).
    # Size the list to match the component count exactly.
    history_txt = _history_texts(history, len(history_display))

    csv_path, json_path = _export_history(history)

    radar_img = plot_radar(history) if history else None
    return (
        f"{eval_result['accuracy']*100:.2f}%",
        f"{eval_result['params']:.2f}M params",
        summary,
        *history_txt,
        radar_img,
        csv_path,
        json_path,
    )
# === Interface Layout ===
# The pre-built components are grouped into horizontal rows, rendered in
# order from top to bottom.
with gr.Blocks(title="EvoTransformer Demo") as demo:
    gr.Markdown("# 🧬 EvoTransformer – Evolving Transformer Architectures")
    gr.Markdown("Simulate trait mutation and adaptive architecture generation.")

    for row_components in (
        (generations, evolve_btn),
        (accuracy_out, params_out),
        (radar_plot, tabbox),
        history_display,
        (csv_btn, json_btn),
    ):
        with gr.Row():
            for component in row_components:
                component.render()

    # The outputs order must mirror the tuple returned by evolve_and_display.
    evolve_btn.click(
        fn=evolve_and_display,
        inputs=[generations],
        outputs=[
            accuracy_out,
            params_out,
            tabbox,
            *history_display,
            radar_plot,
            csv_btn,
            json_btn,
        ],
    )

# === Launch App ===
if __name__ == "__main__":
    demo.launch()