HemanM's picture
Update app.py
6842aeb verified
raw
history blame
3.5 kB
import gradio as gr
from evo_transformer import EvoTransformer
import pandas as pd
import json
import tempfile
import matplotlib.pyplot as plt
import numpy as np
# Single module-level EvoTransformer instance shared by every UI request;
# evolve_and_display() calls evolve(), get_history() and evaluate() on it.
# NOTE(review): because the instance is shared, any internal state (e.g. its
# history) may persist across clicks and sessions — confirm against
# EvoTransformer's implementation.
model = EvoTransformer()
def plot_radar(history):
    """Render a radar (polar) chart of the last generation's traits to a PNG.

    Each plotted trait of ``history[-1]`` is divided by a fixed reference
    value so the four spokes share a comparable scale:
    ``layers / 12``, ``attention_heads / 12``, ``ffn_dim / 2048`` and
    ``dropout / 0.5``.

    Args:
        history: Sequence of per-generation config dicts; only the last entry
            is plotted. That entry must contain the keys ``layers``,
            ``attention_heads``, ``ffn_dim`` and ``dropout``.

    Returns:
        Path of a PNG file written to the system temp directory. The file is
        not deleted automatically.

    Raises:
        ValueError: If ``history`` is empty.
        KeyError: If the last entry lacks one of the plotted traits.
    """
    if not history:
        raise ValueError("plot_radar() needs a non-empty history")

    # Trait -> reference value used to normalise that spoke (insertion order
    # fixes the order of the spokes around the chart).
    scales = {
        "layers": 12,
        "attention_heads": 12,
        "ffn_dim": 2048,
        "dropout": 0.5,
    }
    traits = list(scales)
    last = history[-1]
    values = [last[trait] / scales[trait] for trait in traits]

    angles = np.linspace(0, 2 * np.pi, len(traits), endpoint=False).tolist()
    # Repeat the first point so the plotted polygon closes.
    values += values[:1]
    angles += angles[:1]

    fig, ax = plt.subplots(figsize=(5, 5), subplot_kw=dict(polar=True))
    try:
        ax.plot(angles, values, "o-", linewidth=2)
        ax.fill(angles, values, alpha=0.25)
        ax.set_thetagrids(np.degrees(angles[:-1]), traits)
        ax.set_title("Final Generation Trait Radar")
        ax.grid(True)

        # Only reserve a unique file name; the handle is closed before
        # matplotlib reopens the path (an open handle blocks reopening on
        # Windows).
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            img_path = tmp.name
        # Save *this* figure: plt.savefig() targets pyplot's global "current
        # figure", which another concurrently running request could change.
        fig.savefig(img_path)
    finally:
        # Release the figure even if plotting or saving fails, so figures
        # don't accumulate in pyplot's registry.
        plt.close(fig)
    return img_path
# Number of per-generation history textboxes the UI exposes; the handler
# must always return exactly this many history strings.
_HISTORY_SLOTS = 10


def _temp_file_path(suffix):
    """Reserve a unique, persistent temp file and return its path.

    The handle is closed immediately so writers can reopen the path on every
    platform; the file itself is kept (``delete=False``).
    """
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        return tmp.name


def evolve_and_display(gens):
    """Gradio click handler: evolve the shared model and build all UI outputs.

    Args:
        gens: Number of generations requested by the slider. It is converted
            with ``int()`` before being passed to ``model.evolve``.

    Returns:
        A tuple matching the ``outputs`` list wired to the button, in order:
        accuracy text, parameter-count text, a ``key: value`` summary of the
        last generation, ``_HISTORY_SLOTS`` pretty-printed JSON strings
        (blank-padded, extra generations dropped), radar PNG path, CSV path
        and JSON path. On any exception the traceback is printed, the text
        slots are filled with ``"Error"`` and the image/file slots with
        ``None``, so the number of outputs never changes.
    """
    try:
        # Gradio sliders may deliver floats (e.g. 3.0); evolve() presumably
        # expects an integer generation count — TODO confirm.
        model.evolve(generations=int(gens))
        history = model.get_history()
        eval_result = model.evaluate()

        summary = "\n".join(f"{k}: {v}" for k, v in history[-1].items())

        # Exactly one string per history textbox: drop generations beyond the
        # available slots (returning more values than wired outputs breaks
        # the Gradio event) and pad the remainder with blanks.
        history_txt = [
            json.dumps(h, indent=2) for h in history[:_HISTORY_SLOTS]
        ]
        history_txt += [""] * (_HISTORY_SLOTS - len(history_txt))

        # Save CSV and JSON downloads of the full history.
        csv_path = _temp_file_path(".csv")
        pd.DataFrame(history).to_csv(csv_path, index=False)
        json_path = _temp_file_path(".json")
        with open(json_path, "w", encoding="utf-8") as jf:
            json.dump(history, jf, indent=2)

        radar_img = plot_radar(history)
        return (
            # Presumably accuracy is a 0-1 fraction and params are in
            # millions — verify against EvoTransformer.evaluate().
            f"{eval_result['accuracy']*100:.2f}%",
            f"{eval_result['params']:.2f}M params",
            summary,
            *history_txt,
            radar_img,
            csv_path,
            json_path,
        )
    except Exception as e:
        # UI boundary: report the failure but keep the output shape intact.
        print("🚨 ERROR during evolution:", str(e))
        import traceback
        traceback.print_exc()
        return (
            "Error", "Error", "Error",
            *["Error"] * _HISTORY_SLOTS,
            None, None, None
        )
# ---- Gradio UI ---------------------------------------------------------------
# Layout: generation slider + button, headline metrics, config summary, a
# collapsible list of per-generation configs, the radar image and two
# download files.
with gr.Blocks() as demo:
    gr.Markdown("# 🧬 EvoTransformer – Evolving Transformer Architectures")
    gr.Markdown("Simulate trait mutation and adaptive architecture generation.")

    with gr.Row():
        generations = gr.Slider(
            minimum=1,
            maximum=10,
            value=3,
            step=1,
            label="Number of Generations",
        )
        evolve_btn = gr.Button("🧬 Evolve Architecture")

    with gr.Row():
        acc_out = gr.Text(label="Simulated Accuracy")
        param_out = gr.Text(label="Estimated Parameters")

    summary_out = gr.Textbox(label="Current Config Summary", lines=5)

    with gr.Accordion("🧬 Evolution History", open=False):
        # Ten slots labelled "Gen 1 Config" ... "Gen 10 Config"; the click
        # handler returns one string per slot.
        hist_outputs = [
            gr.Textbox(label=f"Gen {gen_no} Config", lines=4)
            for gen_no in range(1, 11)
        ]

    radar_plot = gr.Image(label="Final Generation Trait Radar")

    with gr.Row():
        csv_out = gr.File(label="Download CSV History")
        json_out = gr.File(label="Download JSON History")

    # Component order must match the tuple returned by evolve_and_display.
    evolve_btn.click(
        fn=evolve_and_display,
        inputs=[generations],
        outputs=[acc_out, param_out, summary_out]
        + hist_outputs
        + [radar_plot, csv_out, json_out],
    )

demo.launch()