# Scraped from the Hugging Face file page: author HemanM, commit b9e3604
# ("Update app.py", verified), 2.98 kB — page links: raw, history, blame.
# app.py
import gradio as gr
from evo_transformer import EvoTransformer
import pandas as pd
import matplotlib.pyplot as plt
import tempfile
import os
import json
# === Initialize Model ===
# One module-level model instance shared by every click of the Evolve button
# (whether its state carries over between clicks depends on EvoTransformer —
# NOTE(review): confirm).
model = EvoTransformer()

# === Create Interface Components ===
# Components are created here, outside any gr.Blocks context, and placed into
# the page later by calling .render() inside the layout.
generations = gr.Slider(1, 10, value=3, step=1, label="Number of Generations")
evolve_btn = gr.Button("🧬 Evolve Architecture")
accuracy_out = gr.Textbox(label="Simulated Accuracy")
params_out = gr.Textbox(label="Estimated Parameters")
tabbox = gr.Textbox(label="Current Config Summary")
# Radar chart of the evolution history. Both the layout and the click
# handler's outputs reference this name, so it must be defined up front.
radar_plot = gr.Image(label="Evolution Radar Plot")
# Dynamic display of evolution history: one textbox per generation slot.
# evolve_and_display must return exactly one value per slot.
history_display = [gr.Textbox(label=f"Gen {i+1} Config") for i in range(10)]
# File components offering the exported history for download.
csv_btn = gr.File(label="Download CSV History")
json_btn = gr.File(label="Download JSON History")
# === Helper: Create Evolution Radar Plot ===
def _temp_path(suffix):
    """Create an empty temp file that survives closing and return its path."""
    # The `with` closes the handle right away instead of leaving it to the GC.
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        return tmp.name


def _is_number(value):
    """Return True for int/float values, excluding bool (an int subclass)."""
    return isinstance(value, (int, float)) and not isinstance(value, bool)


def plot_radar(history):
    """Draw a radar (polar) chart of the numeric traits in ``history``.

    Each history entry becomes one polygon. The axes are the numeric
    (non-bool) fields of the latest entry. Each axis is divided by the
    largest absolute value seen for that trait, so traits of very different
    magnitudes fit on one chart. A trait that is missing or non-numeric in an
    entry is plotted as 0.

    Args:
        history: sequence of per-generation config dicts, as returned by
            ``model.get_history()``.

    Returns:
        Path to a PNG file, or ``None`` when there is nothing numeric to plot.
    """
    import math

    if not history:
        return None
    traits = [k for k, v in history[-1].items() if _is_number(v)]
    if not traits:
        return None

    # Per-trait scale; fall back to 1.0 so all-zero traits don't divide by 0.
    scales = {
        k: max((abs(e.get(k)) for e in history if _is_number(e.get(k))), default=0.0) or 1.0
        for k in traits
    }
    step = 2 * math.pi / len(traits)
    angles = [i * step for i in range(len(traits))]
    angles.append(angles[0])  # repeat the first vertex to close each polygon

    fig, ax = plt.subplots(figsize=(5, 5), subplot_kw={"polar": True})
    try:
        for gen, entry in enumerate(history, start=1):
            values = [
                entry.get(k) / scales[k] if _is_number(entry.get(k)) else 0.0
                for k in traits
            ]
            values.append(values[0])
            ax.plot(angles, values, linewidth=1, label=f"Gen {gen}")
            ax.fill(angles, values, alpha=0.1)
        ax.set_xticks(angles[:-1])
        ax.set_xticklabels([str(k) for k in traits])
        ax.set_yticklabels([])
        ax.legend(loc="upper right", bbox_to_anchor=(1.3, 1.1), fontsize="small")
        png_path = _temp_path(".png")
        fig.savefig(png_path, bbox_inches="tight")
    finally:
        # pyplot keeps figures alive until closed; close even on failure.
        plt.close(fig)
    return png_path


def evolve_and_display(gens):
    """Evolve the shared model and build every value the UI displays.

    Args:
        gens: number of generations to evolve, taken from the slider. Its
            step is 1, so it is cast to ``int`` in case it arrives as a float.

    Returns:
        A tuple in the order of the click handler's outputs:
        - accuracy text
        - parameter-count text
        - summary of the latest config
        - one JSON string per history slot (``len(history_display)`` of them)
        - the radar-plot PNG path
        - the CSV path
        - the JSON path

        On any exception the error is printed with its traceback. Every text
        output is then "Error" and the image/file outputs are ``None``.
    """
    n_slots = len(history_display)
    try:
        model.evolve(generations=int(gens))
        history = model.get_history()
        eval_result = model.evaluate()

        # Human-readable view of the most recent configuration.
        summary = "\n".join(f"{k}: {v}" for k, v in history[-1].items()) if history else ""

        # Exactly one string per display slot, since the number of returned
        # values must match the outputs list. Keep the first n_slots entries
        # so they line up with the "Gen i" labels, and pad any empty slots.
        history_txt = [json.dumps(h, indent=2) for h in history[:n_slots]]
        history_txt += [""] * (n_slots - len(history_txt))

        # Save the full history as CSV and JSON for download.
        csv_path = _temp_path(".csv")
        json_path = _temp_path(".json")
        pd.DataFrame(history).to_csv(csv_path, index=False)
        with open(json_path, "w", encoding="utf-8") as jf:
            json.dump(history, jf, indent=2)

        radar_img = plot_radar(history)

        return (
            f"{eval_result['accuracy'] * 100:.2f}%",
            f"{eval_result['params']:.2f}M params",
            summary,
            *history_txt,
            radar_img,
            csv_path,
            json_path,
        )
    except Exception as e:
        # UI boundary: log the full traceback to the server console and fill
        # every output with a placeholder instead of failing the request.
        import traceback

        print("🚨 ERROR during evolution:", str(e))
        traceback.print_exc()
        return (
            "Error", "Error", "Error",
            *["Error"] * n_slots,
            None, None, None,
        )
# === Interface Layout ===
# Pre-built components are placed into the page row by row: each inner tuple
# below is one gr.Row, and its components are rendered left to right.
with gr.Blocks(title="EvoTransformer Demo") as demo:
    gr.Markdown("# 🧬 EvoTransformer – Evolving Transformer Architectures")
    gr.Markdown("Simulate trait mutation and adaptive architecture generation.")

    page_rows = (
        (generations, evolve_btn),
        (accuracy_out, params_out),
        (radar_plot, tabbox),
        tuple(history_display),
        (csv_btn, json_btn),
    )
    for row_items in page_rows:
        with gr.Row():
            for item in row_items:
                item.render()

    # The output order must match the tuple evolve_and_display returns.
    click_outputs = [accuracy_out, params_out, tabbox, *history_display, radar_plot, csv_btn, json_btn]
    evolve_btn.click(evolve_and_display, inputs=[generations], outputs=click_outputs)

# === Launch App ===
if __name__ == "__main__":
    # Start the Gradio server only when executed as a script, not on import.
    demo.launch()