# Spaces status: Sleeping
import io
import os
import tempfile

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.figure import Figure

from evo_transformer import EvoTransformer
# Module-level EvoTransformer instance. evolve_and_display rebinds this name
# to a freshly constructed EvoTransformer on every run ("resettable").
evo: EvoTransformer = EvoTransformer()
# === Visualization Functions === | |
def plot_radar(config): | |
labels = ["Layers", "Attention Heads", "FFN Dim", "Dropout", "Memory"] | |
values = [ | |
config["layers"], | |
config["attention_heads"], | |
config["ffn_dim"], | |
int(config["dropout"] * 100), | |
int(config["memory"]) | |
] | |
angles = np.linspace(0, 2 * np.pi, len(labels), endpoint=False).tolist() | |
values += values[:1] | |
angles += angles[:1] | |
fig, ax = plt.subplots(figsize=(5, 5), subplot_kw=dict(polar=True)) | |
ax.plot(angles, values, "o-", linewidth=2) | |
ax.fill(angles, values, alpha=0.25) | |
ax.set_thetagrids(np.degrees(angles[:-1]), labels) | |
ax.set_title("Final Architecture (Radar Chart)") | |
return fig | |
# Number of Textboxes in the "Evolution History" tab (matches the slider max).
HISTORY_SLOTS = 10


def _write_download(text, filename):
    """Write *text* to ``filename`` inside a fresh temp directory; return the path.

    gr.File outputs need a file path on disk, not an in-memory buffer. A new
    directory per call keeps the user-facing filename stable while preventing
    concurrent requests from overwriting each other's files.
    NOTE(review): these directories are never removed; add cleanup if disk
    usage matters for long-running deployments.
    """
    path = os.path.join(tempfile.mkdtemp(prefix="evo_"), filename)
    # newline="" stops text mode from re-translating line endings that
    # pandas may already have written as os.linesep.
    with open(path, "w", encoding="utf-8", newline="") as fh:
        fh.write(text)
    return path


def evolve_and_display(generations):
    """Run a fresh evolution and build every UI output for it.

    Args:
        generations: value from the Generations slider; coerced to ``int``
            because Gradio sliders deliver numbers that may be floats.

    Returns:
        A flat tuple in the order of ``evolve_btn.click(outputs=...)``:
        accuracy, params, an update making the tabs visible, exactly
        ``HISTORY_SLOTS`` Textbox updates (one history row each, blank when
        there are fewer rows), the radar figure, the CSV path and the JSON path.
    """
    global evo
    # Compute on a local instance so concurrent requests cannot swap the model
    # out from under each other, then publish it so the module-level ``evo``
    # still refers to the most recent run.
    model = EvoTransformer()
    model.evolve(int(generations))
    df = model.get_history_df()
    final_config = model.get_final_config()
    accuracy, params = model.evaluate()
    evo = model

    # Gradio needs exactly one value per output component, so pad (or trim)
    # the history rows to the fixed number of Textboxes. The CSV/JSON
    # downloads always contain the full history.
    rows = [str(row) for _, row in df.iterrows()][:HISTORY_SLOTS]
    history_updates = [gr.update(value=text) for text in rows]
    history_updates += [gr.update(value="") for _ in range(HISTORY_SLOTS - len(rows))]

    csv_path = _write_download(df.to_csv(index=False), "evo_history.csv")
    json_path = _write_download(model.get_history_json(), "evo_history.json")

    return (
        accuracy,
        params,
        gr.update(visible=True),  # generic update; Tabs.update has no `visible`
        *history_updates,
        plot_radar(final_config),
        csv_path,
        json_path,
    )


# === Gradio UI ===
# NOTE(review): "π" in the strings below looks like a mis-decoded emoji
# (mojibake); restore the intended character once confirmed.
with gr.Blocks(title="EvoTransformer Live Demo") as demo:
    gr.Markdown(
        "π **EvoTransformer Live Demo**\n\n"
        "This demo evolves a Transformer architecture and displays how traits change over generations."
    )
    with gr.Row():
        # step=1: without it Gradio derives a fractional step from the range,
        # allowing non-integer generation counts.
        generations = gr.Slider(1, 10, value=5, step=1, label="Generations")
        evolve_btn = gr.Button("Evolve Now π")
    with gr.Row():
        accuracy_out = gr.Number(label="Estimated Accuracy", value=0)
        params_out = gr.Number(label="Estimated Params (M)", value=0)
    with gr.Tabs(visible=False) as tabbox:
        with gr.Tab(label="Evolution History"):
            history_display = [
                gr.Textbox(label=str(i + 1), interactive=False) for i in range(HISTORY_SLOTS)
            ]
        with gr.Tab(label="Radar View"):
            radar_plot = gr.Plot()
    with gr.Row():
        csv_btn = gr.File(label="Download CSV")
        json_btn = gr.File(label="Download JSON")
    # Outputs must be a flat list of components, so the Textbox list is
    # unpacked in place; the order matches evolve_and_display's return tuple.
    evolve_btn.click(
        evolve_and_display,
        inputs=[generations],
        outputs=[
            accuracy_out,
            params_out,
            tabbox,
            *history_display,
            radar_plot,
            csv_btn,
            json_btn,
        ],
    )

if __name__ == "__main__":
    demo.launch()