# Source: Hugging Face Space by HemanM — app.py
# Commit c67165f ("Update app.py"), 2.57 kB
import gradio as gr
from evo_transformer import EvoTransformer
from plots import plot_radar_chart
from diagrams import get_transformer_diagram
import pandas as pd
import json
import tempfile
# Single module-level EvoTransformer reused by every request; run_evolution
# calls reset() first so each run starts from a clean state.
# NOTE(review): concurrent requests would share (and race on) this instance —
# this relies on Gradio serializing calls to the same event handler; confirm
# before raising the event's concurrency limit.
et = EvoTransformer()


def _format_history(history):
    """Render the per-generation configs as one readable text block.

    Args:
        history: Sequence of JSON-serializable configs, one per generation,
            as returned by ``et.get_history()``.

    Returns:
        str: One "Gen N Config" section per generation (N is 1-based), each
        followed by that config pretty-printed as JSON; sections are separated
        by a blank line. An empty history yields "".
    """
    return "\n\n".join(
        f"Gen {i} Config\n{json.dumps(cfg, indent=2)}"
        for i, cfg in enumerate(history, start=1)
    )


def _export_history(history):
    """Write the history to temporary CSV and JSON files for download.

    The files are created with ``delete=False`` so they survive this call and
    can be served by the ``gr.File`` outputs; the handles are closed by the
    ``with`` statements so no file descriptors leak.

    Args:
        history: Sequence of JSON-serializable configs (one per generation).

    Returns:
        tuple[str, str]: ``(csv_path, json_path)``.
    """
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".csv", delete=False, encoding="utf-8", newline=""
    ) as csv_file:
        # newline="" lets pandas control line endings when given a file handle.
        pd.DataFrame(history).to_csv(csv_file, index=False)
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".json", delete=False, encoding="utf-8"
    ) as json_file:
        json.dump(history, json_file)
    return csv_file.name, json_file.name


def run_evolution(generations):
    """Run a fresh evolution on the shared ``et`` and package results for the UI.

    Args:
        generations: Number of generations chosen on the slider. Gradio's
            Slider delivers a float, so it is cast to ``int`` before evolving.

    Returns:
        tuple: ``(accuracy text, parameter-count text, current config as
        pretty JSON, radar plot, transformer diagram, per-generation history
        text, CSV path, JSON path)`` — the order matches the ``outputs`` list
        wired to ``evolve_btn.click``.
    """
    et.reset()
    et.evolve(int(generations))
    final_eval = et.evaluate()

    # Fetch once; reused for the downloads and the on-screen history.
    history = et.get_history()
    csv_path, json_path = _export_history(history)

    radar_plot = plot_radar_chart(et.config)
    diagram_path = get_transformer_diagram(et.config)

    return (
        # presumably 'accuracy' is a fraction in [0, 1] — shown as a percentage
        f"{final_eval['accuracy'] * 100:.2f}%",
        f"{final_eval['params']:.2f}M params",
        json.dumps(et.config, indent=2),
        radar_plot,
        diagram_path,
        _format_history(history),
        csv_path,
        json_path,
    )


with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## 🧬 EvoTransformer – Evolving Transformer Architectures")
    gr.Markdown("Simulate trait mutation and adaptive architecture generation.")

    with gr.Row():
        generations_slider = gr.Slider(1, 10, value=3, label="Number of Generations", step=1)
        evolve_btn = gr.Button("🧬 Evolve Architecture", variant="primary")

    with gr.Row():
        accuracy_output = gr.Textbox(label="Simulated Accuracy")
        param_output = gr.Textbox(label="Estimated Parameters")
        current_config = gr.Textbox(label="Current Config Summary", lines=5)

    with gr.Column():
        gr.Markdown("## 🧬 Evolution History")
        # NOTE(review): gr.Image needs an image/array/filepath; if
        # plot_radar_chart returns a matplotlib Figure, use gr.Plot instead.
        radar_output = gr.Image(label="Final Generation Trait Radar", height=400)
        diagram_output = gr.Image(label="Illustrative Transformer Structure", height=300)
        # A value component is required here: the previous gr.Group() is a
        # layout container and cannot receive a handler's return value.
        history_output = gr.Textbox(label="Per-Generation Configs", lines=10)

    with gr.Row():
        csv_download = gr.File(label="Download CSV History")
        json_download = gr.File(label="Download JSON History")

    evolve_btn.click(
        fn=run_evolution,
        inputs=[generations_slider],
        outputs=[
            accuracy_output,
            param_output,
            current_config,
            radar_output,
            diagram_output,
            history_output,
            csv_download,
            json_download,
        ],
    )
def _main():
    """Start serving the Gradio ``demo`` app built above."""
    demo.launch()


# Only serve the UI when executed as a script, not when imported.
if __name__ == "__main__":
    _main()