"""Gradio demo that evolves a Transformer architecture and visualises the result."""

import io
import tempfile
from pathlib import Path

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.figure import Figure

from evo_transformer import EvoTransformer

# Upper bound of the "Generations" slider; also the number of history
# textboxes laid out in the UI (one per slot).
MAX_GENERATIONS = 10

# Global instance (resettable): replaced by a freshly evolved model on every
# "Evolve Now" click.
evo = EvoTransformer()


# === Visualization Functions ===
def plot_radar(config):
    """Draw the final architecture's traits as a radar (polar) chart.

    Args:
        config: Mapping with the keys "layers", "attention_heads", "ffn_dim",
            "dropout" and "memory" (as read from
            ``EvoTransformer.get_final_config()``). Dropout is plotted as a
            percentage; memory is coerced with ``int()``.

    Returns:
        A ``matplotlib.figure.Figure`` suitable for a ``gr.Plot`` output.
    """
    labels = ["Layers", "Attention Heads", "FFN Dim", "Dropout", "Memory"]
    values = [
        config["layers"],
        config["attention_heads"],
        config["ffn_dim"],
        int(config["dropout"] * 100),  # dropout fraction -> percentage
        int(config["memory"]),  # NOTE(review): presumably a flag -> 0/1; confirm
    ]
    # NOTE(review): all axes share one radial scale, so a large "FFN Dim"
    # will dwarf the other traits; consider per-axis normalisation.
    angles = np.linspace(0, 2 * np.pi, len(labels), endpoint=False).tolist()
    # Repeat the first point so the plotted polygon is closed.
    values += values[:1]
    angles += angles[:1]

    # Build the Figure directly rather than via pyplot: it is never registered
    # with pyplot's global figure manager, so figures do not accumulate across
    # clicks and concurrent requests do not share pyplot state.
    fig = Figure(figsize=(5, 5))
    ax = fig.add_subplot(111, polar=True)
    ax.plot(angles, values, "o-", linewidth=2)
    ax.fill(angles, values, alpha=0.25)
    ax.set_thetagrids(np.degrees(angles[:-1]), labels)
    ax.set_title("Final Architecture (Radar Chart)")
    return fig


def _history_updates(df, slots=MAX_GENERATIONS):
    """Return exactly `slots` textbox updates: one shown per history row, the rest hidden."""
    texts = [
        ", ".join(f"{column}: {value}" for column, value in row.items())
        # Rows beyond the number of UI slots cannot be displayed.
        for _, row in df.head(slots).iterrows()
    ]
    updates = [gr.update(value=text, visible=True) for text in texts]
    # Build a distinct dict per unused slot (no shared mutable object).
    updates.extend(
        gr.update(value="", visible=False) for _ in range(slots - len(updates))
    )
    return updates


def _write_downloads(model, df):
    """Write the history as CSV and JSON into a fresh temp dir; return both paths.

    gr.File outputs expect file paths. A per-call directory keeps concurrent
    runs from overwriting each other while preserving the download names.
    """
    out_dir = Path(tempfile.mkdtemp(prefix="evo_history_"))
    csv_path = out_dir / "evo_history.csv"
    json_path = out_dir / "evo_history.json"
    df.to_csv(csv_path, index=False)
    json_path.write_text(model.get_history_json(), encoding="utf-8")
    return str(csv_path), str(json_path)


def evolve_and_display(generations):
    """Evolve a fresh EvoTransformer and build every UI output for the run.

    Args:
        generations: Value of the "Generations" slider. It is cast to ``int``
            because Gradio sliders deliver numbers that may be floats.

    Returns:
        A flat tuple in the order of the click handler's ``outputs``:
        accuracy, params, tab-visibility update, one update per history
        textbox (``MAX_GENERATIONS`` of them), radar figure, CSV file path,
        JSON file path.
    """
    global evo
    # Work on a local instance so the shared global is never observed in a
    # half-evolved state; publish it once evolution has finished.
    model = EvoTransformer()  # Reset model
    model.evolve(int(generations))
    evo = model

    df = model.get_history_df()
    final_config = model.get_final_config()
    accuracy, params = model.evaluate()
    fig = plot_radar(final_config)
    csv_path, json_path = _write_downloads(model, df)

    return (
        accuracy,
        params,
        gr.update(visible=True),
        *_history_updates(df),
        fig,
        csv_path,
        json_path,
    )


# === Gradio UI ===
with gr.Blocks(title="EvoTransformer Live Demo") as demo:
    gr.Markdown(
        "🚀 **EvoTransformer Live Demo**\n\n"
        "This demo evolves a Transformer architecture and displays how traits change over generations."
    )

    with gr.Row():
        # step=1: without it Gradio derives a fractional step from the range.
        generations = gr.Slider(
            1, MAX_GENERATIONS, value=5, step=1, label="Generations"
        )
        evolve_btn = gr.Button("Evolve Now 🚀")

    with gr.Row():
        accuracy_out = gr.Number(label="Estimated Accuracy", value=0)
        params_out = gr.Number(label="Estimated Params (M)", value=0)

    # Hidden until the first evolution run completes.
    tabbox = gr.Tabs(visible=False)
    with tabbox:
        with gr.Tab(label="Evolution History"):
            history_display = [
                gr.Textbox(label=str(i + 1), interactive=False)
                for i in range(MAX_GENERATIONS)
            ]
        with gr.Tab(label="Radar View"):
            radar_plot = gr.Plot()

    with gr.Row():
        csv_btn = gr.File(label="Download CSV")
        json_btn = gr.File(label="Download JSON")

    # Gradio needs a flat list of output components, so the history
    # textboxes are unpacked in place.
    evolve_btn.click(
        evolve_and_display,
        inputs=[generations],
        outputs=[
            accuracy_out,
            params_out,
            tabbox,
            *history_display,
            radar_plot,
            csv_btn,
            json_btn,
        ],
    )

if __name__ == "__main__":
    demo.launch()