# Page header from the hosted file view: author HemanM, commit 68934cc
# ("Update app.py", verified), file size 3.23 kB.
import gradio as gr
import pandas as pd
import json
import os
from evo_transformer import EvoTransformer
from plot import plot_radar_chart
from diagram import plot_architecture_diagram
# === Shared state and output components ===
# A single EvoTransformer lives at module level; evolve_model() starts it
# over on every click.
evo = EvoTransformer()

# Output components are built outside the Blocks context and placed into the
# layout later with .render().
accuracy_output = gr.Textbox(label="Simulated Accuracy")
params_output = gr.Textbox(label="Estimated Parameters")
current_config_output = gr.Textbox(label="Current Config Summary")
radar_image = gr.Image(label="Final Generation Trait Radar")
diagram_image = gr.Image(label="Illustrative Transformer Structure")

# Populated with the per-generation history textboxes while the layout is built.
history_outputs = []
# === Core Evolution Function ===
def evolve_model(generations):
    """Run a fresh evolution and build one value per UI output component.

    The module-level ``evo`` is replaced by a new ``EvoTransformer``, evolved
    for ``generations`` steps and evaluated. Artifacts (radar chart,
    architecture diagram, CSV and JSON history) are written to the current
    working directory under fixed file names.

    Args:
        generations: Number of generations to evolve (the slider value).

    Returns:
        A tuple in the same order as the evolve button's ``outputs`` list:
        accuracy text, parameter-count text, config summary, radar image path
        (or None), diagram image path (or None), exactly
        ``len(history_outputs)`` history-card strings, CSV path, JSON path.
    """
    global evo
    # Re-create instead of calling evo.__init__() on the live object, so no
    # attribute left over from a previous run can survive the reset.
    evo = EvoTransformer()
    # int(): be tolerant of the slider delivering a float such as 3.0.
    evo.evolve(int(generations))
    score = evo.evaluate()

    radar_path = "radar_chart.png"
    diagram_path = "architecture_diagram.png"

    # Plotting is best effort: a failure blanks that image but not the run.
    try:
        plot_radar_chart(evo.config, radar_path)
    except Exception as e:
        print("Radar chart error:", e)
        radar_path = None

    # gr.Image cannot display raw bytes, so hand it the file path instead.
    diagram_output = None
    try:
        # NOTE(review): relies on plot_architecture_diagram writing
        # diagram_path itself (the path is not passed in) - confirm.
        plot_architecture_diagram(evo.config)
        if os.path.exists(diagram_path):
            diagram_output = diagram_path
    except Exception as e:
        print("Diagram plot error:", e)

    history = evo.get_history()

    # Gradio needs exactly one return value per output component, and the UI
    # has a fixed number of card textboxes: truncate extra cards and pad
    # missing ones with empty strings.
    n_cards = len(history_outputs)
    history_cards = [f"Gen {i+1} Config: {h}" for i, h in enumerate(history)][:n_cards]
    history_cards += [""] * (n_cards - len(history_cards))

    # Downloadable copies of the full history.
    df_csv_path = "evo_history.csv"
    pd.DataFrame(history).to_csv(df_csv_path, index=False)
    json_path = "evo_history.json"
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(history, f, indent=2)

    return (
        f"{score['accuracy']*100:.2f}%",
        f"{score['params']:.2f}M params",
        str(evo.config),
        radar_path,
        diagram_output,
        *history_cards,
        df_csv_path,
        json_path,
    )
# === Gradio Interface ===
# Input controls are created up front and positioned inside the layout with
# .render(), the same way the output components are handled.
generations_input = gr.Slider(1, 10, value=3, step=1, label="Number of Generations")
evolve_btn = gr.Button("\U0001F9EC Evolve Architecture")

with gr.Blocks(title="EvoTransformer") as demo:
    gr.Markdown("""
# 🧬 EvoTransformer – Evolving Transformer Architectures
Simulate trait mutation and adaptive architecture generation.
""")

    # Generation count and trigger button side by side.
    with gr.Row():
        generations_input.render()
        evolve_btn.render()

    # Text results of the latest run.
    for text_component in (accuracy_output, params_output, current_config_output):
        text_component.render()

    gr.Markdown("## 🧬 Evolution History")
    radar_image.render()
    diagram_image.render()

    with gr.Accordion("Downloadable Files", open=True):
        csv_file = gr.File(label="Download CSV History")
        json_file = gr.File(label="Download JSON History")

    # One read-only textbox per history card; 10 matches the slider maximum.
    history_outputs.extend(
        [gr.Textbox(label="", interactive=False) for _ in range(10)]
    )

    # The outputs order must match the tuple returned by evolve_model.
    evolve_btn.click(
        fn=evolve_model,
        inputs=[generations_input],
        outputs=[
            accuracy_output,
            params_output,
            current_config_output,
            radar_image,
            diagram_image,
            *history_outputs,
            csv_file,
            json_file,
        ],
    )

demo.launch()