# Hugging Face Spaces page header (Space status: Sleeping)
import gradio as gr | |
from evo_transformer import EvoTransformer | |
from plot import plot_radar_chart | |
import tempfile | |
import json | |
import csv | |
from PIL import Image | |
# Single EvoTransformer instance shared by every button click; evolve_model()
# calls evo.reset() before each run so a previous run does not leak into the next.
evo = EvoTransformer()

# Number of per-generation config strings evolve_model() always returns.
# NOTE(review): must equal the number of "Gen N Config" textboxes wired into
# the UI's outputs list (currently 10).
_GEN_SLOTS = 10


def _format_summary(config):
    """Return the "key: value" summary text for the traits shown in the UI.

    Args:
        config: One generation's config dict; must contain the keys
            ``layers``, ``attention_heads``, ``ffn_dim``, ``dropout`` and
            ``memory``.

    Returns:
        The five traits, one ``"key: value"`` pair per line.
    """
    keys = ("layers", "attention_heads", "ffn_dim", "dropout", "memory")
    return "\n".join(f"{key}: {config[key]}" for key in keys)


def _simulate_metrics(config):
    """Return display strings for the simulated accuracy and parameter count.

    These are toy formulas, not measurements:
    accuracy = layers * 1.23 + dropout * 10 (shown with a "%" suffix) and
    params = layers * ffn_dim * attention_heads / 1000 (shown as "M params").

    Args:
        config: One generation's config dict with numeric ``layers``,
            ``dropout``, ``ffn_dim`` and ``attention_heads`` values.

    Returns:
        ``(accuracy_str, params_str)``, both formatted to two decimals.
    """
    accuracy = config["layers"] * 1.23 + config["dropout"] * 10
    params = config["layers"] * config["ffn_dim"] * config["attention_heads"] / 1000
    return f"{accuracy:.2f}%", f"{params:.2f}M params"


def _export_history(history):
    """Write *history* to temporary CSV and JSON files and return their paths.

    Files are created with ``delete=False`` so they still exist when Gradio
    serves them for download; nothing in this module removes them later.

    Args:
        history: Non-empty list of per-generation config dicts.

    Returns:
        ``(csv_path, json_path)``.
    """
    # CSV columns: the final generation's keys first (the previous column
    # order), then any keys that only appear in earlier generations, so an
    # uneven history cannot make DictWriter raise ValueError. Missing cells
    # are written as "" (DictWriter's default restval).
    fieldnames = list(
        dict.fromkeys([*history[-1], *(key for cfg in history for key in cfg)])
    )

    # `with` guarantees the handles are closed even if serialization fails.
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".csv", newline="", encoding="utf-8", delete=False
    ) as csv_file:
        writer = csv.DictWriter(csv_file, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(history)

    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".json", encoding="utf-8", delete=False
    ) as json_file:
        json.dump(history, json_file)

    return csv_file.name, json_file.name


def evolve_model(generations):
    """Run a fresh evolution and return the values for every UI output.

    Args:
        generations: Number of generations to evolve. Gradio sliders can
            deliver whole numbers as floats (e.g. ``3.0``), so the value is
            cast to ``int`` before being handed to the evolver.

    Returns:
        A 17-tuple in the order the UI's ``outputs`` list expects:
        accuracy string, parameter string, config summary, radar chart,
        ``None`` for the unused diagram placeholder, ten per-generation
        config strings (padded with ``""``), then the CSV and JSON paths.

    Raises:
        gr.Error: if the evolution produced no history to report.
    """
    evo.reset()  # discard state from any previous run
    evo.run_evolution(int(generations))

    history = evo.history
    if not history:
        # Surface a readable message in the UI instead of an IndexError.
        raise gr.Error("Evolution produced no generations to report.")
    final_config = history[-1]

    accuracy_str, params_str = _simulate_metrics(final_config)
    summary = _format_summary(final_config)
    radar_img = plot_radar_chart(final_config)

    gen_configs = [f"Gen {i} Config: {cfg}" for i, cfg in enumerate(history, start=1)]
    # The UI has exactly _GEN_SLOTS textboxes: pad short histories with "" and
    # drop anything past the last slot (as the previous fixed indexing did).
    gen_configs = (gen_configs + [""] * _GEN_SLOTS)[:_GEN_SLOTS]

    csv_path, json_path = _export_history(history)

    return (
        accuracy_str,
        params_str,
        summary,
        radar_img,
        None,  # diagram placeholder; the UI never fills this image
        *gen_configs,
        csv_path,
        json_path,
    )
# Gradio UI: input controls, result panels, download files, and the wiring
# from the evolve button to evolve_model().
with gr.Blocks(title="🧬 EvoTransformer – Evolving Transformer Architectures") as demo:
    gr.Markdown(
        "## 🧬 EvoTransformer – Evolving Transformer Architectures\n"
        "Simulate trait mutation and adaptive architecture generation."
    )

    # Input controls.
    with gr.Row():
        generations_input = gr.Slider(
            label="Number of Generations",
            minimum=1,
            maximum=10,
            step=1,
            value=3,
        )
        evolve_btn = gr.Button("🧬 Evolve Architecture")

    # Headline results for the final generation.
    with gr.Row():
        accuracy_output = gr.Textbox(label="Simulated Accuracy")
        params_output = gr.Textbox(label="Estimated Parameters")
        summary_output = gr.Textbox(label="Current Config Summary")

    gr.Markdown("## 🧬 Evolution History")
    radar_output = gr.Image(label="Final Generation Trait Radar", type="pil")
    # Placeholder only: evolve_model() returns None for this slot.
    diagram_output = gr.Image(label="Illustrative Transformer Structure")

    with gr.Accordion("Downloadable Files", open=False):
        csv_output = gr.File(label="Download CSV History")
        json_output = gr.File(label="Download JSON History")

    # One textbox per generation; the count must match the number of
    # per-generation strings evolve_model() returns (10).
    gen_outputs = [gr.Textbox(label=f"Gen {n} Config") for n in range(1, 11)]

    # Component order must match evolve_model()'s return tuple exactly.
    result_components = [
        accuracy_output,
        params_output,
        summary_output,
        radar_output,
        diagram_output,
    ]
    result_components.extend(gen_outputs)
    result_components.extend([csv_output, json_output])

    evolve_btn.click(
        fn=evolve_model,
        inputs=[generations_input],
        outputs=result_components,
    )


if __name__ == "__main__":
    demo.launch()