HemanM commited on
Commit
b0fd119
·
verified ·
1 Parent(s): 0be77ef

Update plot.py

Browse files
Files changed (1) hide show
  1. plot.py +18 -20
plot.py CHANGED
@@ -1,33 +1,31 @@
1
  # plot.py
2
  import matplotlib.pyplot as plt
3
  import numpy as np
4
- import io
5
  from PIL import Image
6
 
7
  def plot_radar_chart(config):
8
- # Traits to plot
9
- traits = ["layers", "attention_heads", "ffn_dim", "dropout"]
10
- values = [
11
- config["layers"],
12
- config["attention_heads"],
13
- config["ffn_dim"] / 512, # Normalize
14
- config["dropout"] * 10, # Normalize
15
- ]
16
- # Repeat first value to close the loop
17
  values += values[:1]
18
- angles = np.linspace(0, 2 * np.pi, len(traits), endpoint=False).tolist()
19
  angles += angles[:1]
20
 
21
- fig, ax = plt.subplots(figsize=(5, 5), subplot_kw=dict(polar=True))
22
- ax.plot(angles, values, "o-", linewidth=2)
23
  ax.fill(angles, values, alpha=0.25)
24
- ax.set_yticklabels([])
25
- ax.set_xticks(angles[:-1])
26
- ax.set_xticklabels(traits)
27
- ax.set_title("Final Generation Trait Radar", fontsize=14)
28
 
29
- buf = io.BytesIO()
 
 
 
 
 
30
  plt.savefig(buf, format="png")
 
31
  buf.seek(0)
32
- image = Image.open(buf)
33
- return image
 
1
  # plot.py
2
  import matplotlib.pyplot as plt
3
  import numpy as np
4
+ from io import BytesIO
5
  from PIL import Image
6
 
7
  def plot_radar_chart(config):
8
+ labels = list(config.keys())
9
+ values = list(config.values())
10
+
11
+ # Convert boolean to numeric for plotting
12
+ values = [1 if v is True else 0 if v is False else v for v in values]
13
+
14
+ angles = np.linspace(0, 2 * np.pi, len(labels), endpoint=False).tolist()
 
 
15
  values += values[:1]
 
16
  angles += angles[:1]
17
 
18
+ fig, ax = plt.subplots(figsize=(6, 6), subplot_kw={'polar': True})
19
+ ax.plot(angles, values, 'o-', linewidth=2)
20
  ax.fill(angles, values, alpha=0.25)
 
 
 
 
21
 
22
+ ax.set_thetagrids(np.degrees(angles[:-1]), labels)
23
+ ax.set_title("Trait Radar Chart", fontsize=14)
24
+ ax.grid(True)
25
+
26
+ buf = BytesIO()
27
+ plt.tight_layout()
28
  plt.savefig(buf, format="png")
29
+ plt.close(fig)
30
  buf.seek(0)
31
+ return Image.open(buf)