EvoTransformer-Demo / diagrams.py
HemanM's picture
Update diagrams.py
5f10669 verified
raw
history blame
920 Bytes
import matplotlib.pyplot as plt
def plot_architecture_diagram(config, output_path="architecture_diagram.png"):
    """Draw a one-row schematic of a layer configuration and save it as a PNG.

    Each layer is drawn as a marker on a horizontal line and labelled with its
    1-based index, the ``"attention_heads"`` value (suffix ``H``) and the
    ``"ffn_dim"`` value (suffix ``F``). When ``config["memory"]`` is truthy,
    the layer at index ``layers // 2`` additionally gets a ``Mem`` tag.

    The function is best-effort: any exception is printed and swallowed so a
    plotting failure does not propagate to the caller. It returns ``None``.

    Args:
        config: Mapping read via ``.get`` for ``"layers"`` (default 4),
            ``"attention_heads"`` (default 4), ``"ffn_dim"`` (default 1024)
            and ``"memory"`` (default False).
        output_path: Destination of the PNG. Defaults to
            ``"architecture_diagram.png"`` (relative to the current working
            directory), the location previously hard-coded here.
    """
    fig = None
    try:
        # Read the config before creating a figure so a bad config cannot
        # leave a figure behind. int() lets a float such as 4.0 (e.g. from a
        # UI control) work with range().
        layers = int(config.get("layers", 4))
        heads = config.get("attention_heads", 4)
        ffn = config.get("ffn_dim", 1024)
        mem = config.get("memory", False)

        labels = []
        for i in range(layers):
            lbl = f"L{i+1}\n{heads}H\n{ffn}F"
            if mem and i == layers // 2:
                lbl += "\nMem"
            labels.append(lbl)

        # Keep the compact 6x2 in size for typical depths, but widen the
        # figure for deep configs so neighbouring labels do not overlap.
        fig, ax = plt.subplots(figsize=(max(6, 0.6 * layers), 2))
        ax.plot(range(layers), [1] * layers, "o-", linewidth=2)
        for i, label in enumerate(labels):
            ax.text(i, 1.02, label, ha="center", fontsize=8)
        ax.set_ylim(0.9, 1.15)
        ax.axis("off")
        fig.savefig(output_path, bbox_inches="tight", dpi=150)
    except Exception as e:
        print("⚠️ Diagram plot error:", e)
    finally:
        # Close even when drawing/saving failed; pyplot otherwise keeps the
        # figure registered and repeated failures accumulate open figures.
        if fig is not None:
            plt.close(fig)