HemanM commited on
Commit
fee1f1e
·
verified ·
1 Parent(s): febaebe

Create diagrams.py

Browse files
Files changed (1) hide show
  1. diagrams.py +23 -0
diagrams.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ import tempfile
3
+
4
+ def get_transformer_diagram(config):
5
+ fig, ax = plt.subplots(figsize=(8, 2))
6
+ ax.axis("off")
7
+
8
+ layer_count = config["layers"]
9
+ memory = config.get("memory", False)
10
+
11
+ for i in range(layer_count):
12
+ ax.add_patch(plt.Rectangle((i * 1.5, 0.5), 1.2, 1, edgecolor='black', facecolor='skyblue'))
13
+ ax.text(i * 1.5 + 0.6, 1, f"Layer {i+1}", ha='center', va='center')
14
+
15
+ ax.text(layer_count * 1.5, 1, "→ Output", va='center', fontsize=12, weight='bold')
16
+
17
+ if memory:
18
+ ax.text(layer_count * 0.75, 1.6, "Memory Enabled", color='purple', fontsize=10, ha='center')
19
+
20
+ file_path = tempfile.NamedTemporaryFile(suffix=".png", delete=False).name
21
+ plt.savefig(file_path, bbox_inches="tight")
22
+ plt.close()
23
+ return file_path