# HemanM's picture
# Update app.py
# eeda69b verified
# raw
# history blame
# 2.83 kB
import io
import tempfile
from pathlib import Path

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.figure import Figure

from evo_transformer import EvoTransformer
# Global EvoTransformer instance shared by the UI callback. It is
# resettable: evolve_and_display() rebinds it to a freshly constructed
# instance at the start of every run.
evo = EvoTransformer()
# === Visualization Functions ===
def plot_radar(config):
    """Draw the evolved architecture's traits as a radar (polar) chart.

    Args:
        config: Mapping providing the keys ``"layers"``, ``"attention_heads"``,
            ``"ffn_dim"``, ``"dropout"`` and ``"memory"``. ``dropout`` is
            plotted multiplied by 100 (as a percentage) and ``memory`` is
            coerced with ``int()`` (presumably a bool flag -- confirm against
            EvoTransformer.get_final_config()).

    Returns:
        A ``matplotlib.figure.Figure`` containing the chart.

    Note:
        Values are plotted unnormalized on a single shared radial axis, so
        traits with larger magnitudes dominate the polygon's shape.
    """
    labels = ["Layers", "Attention Heads", "FFN Dim", "Dropout", "Memory"]
    values = [
        config["layers"],
        config["attention_heads"],
        config["ffn_dim"],
        int(config["dropout"] * 100),
        int(config["memory"]),
    ]
    angles = np.linspace(0, 2 * np.pi, len(labels), endpoint=False).tolist()
    # Repeat the first point so the plotted polygon closes on itself.
    values += values[:1]
    angles += angles[:1]

    # Build the Figure directly rather than through pyplot: pyplot keeps every
    # figure it creates alive until plt.close(), so a server that calls this
    # once per request would accumulate figures indefinitely.
    fig = Figure(figsize=(5, 5))
    ax = fig.add_subplot(projection="polar")
    ax.plot(angles, values, "o-", linewidth=2)
    ax.fill(angles, values, alpha=0.25)
    # Label only the original angles; the last one duplicates the first.
    ax.set_thetagrids(np.degrees(angles[:-1]), labels)
    ax.set_title("Final Architecture (Radar Chart)")
    return fig
# Upper bound of the "Generations" slider. The history tab reserves one
# textbox per generation up to this limit.
MAX_GENERATIONS = 10


def _history_updates(df, slots=MAX_GENERATIONS):
    """Return exactly ``slots`` Textbox updates built from ``df``'s rows.

    Gradio requires one return value per output component, while the number
    of history rows varies with the requested generations. The row updates
    are therefore truncated to ``slots``, and any unused boxes are cleared.
    """
    # NOTE(review): rows beyond ``slots`` are dropped. If the history holds
    # an extra initial row, the last row is lost at the maximum generation
    # count. Confirm the row count EvoTransformer.get_history_df() produces.
    updates = [gr.update(value=str(row)) for _, row in df.iterrows()][:slots]
    # Clear leftover text from a previous, longer run.
    updates.extend(gr.update(value="") for _ in range(slots - len(updates)))
    return updates


def _write_downloads(df, history_json):
    """Write the history to CSV and JSON files and return their paths.

    ``gr.File`` outputs take filesystem paths, not in-memory buffers. A fresh
    temporary directory is used per run, so the files keep their friendly
    download names without overwriting an earlier run's files.
    """
    # NOTE(review): these temp directories are never removed by this app.
    out_dir = Path(tempfile.mkdtemp(prefix="evo_history_"))
    csv_path = out_dir / "evo_history.csv"
    json_path = out_dir / "evo_history.json"
    json_path.write_text(history_json, encoding="utf-8")
    df.to_csv(csv_path, index=False)
    return str(csv_path), str(json_path)


def evolve_and_display(generations):
    """Run a fresh evolution and build every output shown by the demo UI.

    Args:
        generations: Number of generations to evolve, read from the slider.
            Slider values may arrive as floats, so the value is cast to int.

    Returns:
        A flat tuple matching the ``outputs`` list wired to ``evolve_btn``:
        accuracy, params, an update revealing the results tabs,
        ``MAX_GENERATIONS`` history Textbox updates, the radar figure, and
        the CSV and JSON download paths.
    """
    global evo
    evo = EvoTransformer()  # Reset so each click starts from scratch.
    evo.evolve(int(generations))
    df = evo.get_history_df()
    final_config = evo.get_final_config()
    accuracy, params = evo.evaluate()
    fig = plot_radar(final_config)
    csv_path, json_path = _write_downloads(df, evo.get_history_json())
    return (
        accuracy,
        params,
        gr.update(visible=True),
        *_history_updates(df),
        fig,
        csv_path,
        json_path,
    )


# === Gradio UI ===
with gr.Blocks(title="EvoTransformer Live Demo") as demo:
    gr.Markdown(
        "🚀 **EvoTransformer Live Demo**\n\n"
        "This demo evolves a Transformer architecture and displays how traits change over generations."
    )
    with gr.Row():
        # step=1 keeps the slider on whole generation counts.
        generations = gr.Slider(1, MAX_GENERATIONS, value=5, step=1, label="Generations")
        evolve_btn = gr.Button("Evolve Now 🚀")
    with gr.Row():
        accuracy_out = gr.Number(label="Estimated Accuracy", value=0)
        params_out = gr.Number(label="Estimated Params (M)", value=0)
    # Hidden until the first evolution run completes.
    tabbox = gr.Tabs(visible=False)
    with tabbox:
        with gr.Tab(label="Evolution History"):
            history_display = [
                gr.Textbox(label=str(i + 1), interactive=False)
                for i in range(MAX_GENERATIONS)
            ]
        with gr.Tab(label="Radar View"):
            radar_plot = gr.Plot()
    with gr.Row():
        csv_btn = gr.File(label="Download CSV")
        json_btn = gr.File(label="Download JSON")
    # Gradio needs a flat list of components, so the history textboxes are
    # unpacked into it. The order must match evolve_and_display's return.
    evolve_btn.click(
        evolve_and_display,
        inputs=[generations],
        outputs=[accuracy_out, params_out, tabbox, *history_display, radar_plot, csv_btn, json_btn],
    )
# Start the Gradio server only when this file is run as a script, not when
# it is imported.
if __name__ == "__main__":
    demo.launch()