EvoTransformer-Demo / diagram.py
HemanM's picture
Update diagram.py
6e375cd verified
raw
history blame
464 Bytes
# diagram.py
import matplotlib.pyplot as plt
import tempfile
def show_transformer_diagram():
    """Render a one-line Transformer pipeline diagram to a temporary PNG.

    The image shows the stages Input, Embedding, Self-Attention, FFN and
    Output joined by arrows, drawn as centred text on an axis-less canvas.

    Returns:
        str: Path of the PNG file written. The file is created with
        ``delete=False``, so it persists after this call; removing it is
        the caller's responsibility.
    """
    # Join the stage names with an escaped U+2192 RIGHTWARDS ARROW so the
    # glyph cannot be mangled by a wrong source-file encoding (the previous
    # literal had been corrupted into the mojibake "β†’").
    stages = ("Input", "Embedding", "Self-Attention", "FFN", "Output")
    label = " \u2192 ".join(stages)

    fig, ax = plt.subplots(figsize=(6, 4))
    try:
        ax.text(0.5, 0.6, label,
                fontsize=12, ha="center", va="center", wrap=True)
        ax.axis("off")
        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmpfile:
            # Save this specific figure (not pyplot's implicit "current"
            # figure) directly into the already-open handle. The format is
            # passed explicitly because it is not taken from a file object.
            fig.savefig(tmpfile, format="png", bbox_inches="tight")
        return tmpfile.name
    finally:
        # pyplot keeps every figure alive until it is closed; without this,
        # each call would leak a figure (matplotlib warns after 20 are open).
        plt.close(fig)