HemanM commited on
Commit
848a355
·
verified ·
1 Parent(s): cfe4306

Update plot.py

Browse files
Files changed (1) hide show
  1. plot.py +23 -16
plot.py CHANGED
@@ -1,26 +1,33 @@
 
1
  import matplotlib.pyplot as plt
2
  import numpy as np
3
- import tempfile
 
4
 
5
  def plot_radar_chart(config):
6
- labels = list(config.keys())
7
- values = list(config.values())
8
-
9
- # Normalize boolean to int
10
- values = [int(v) if isinstance(v, bool) else v for v in values]
11
-
12
- angles = np.linspace(0, 2 * np.pi, len(labels), endpoint=False).tolist()
 
 
13
  values += values[:1]
 
14
  angles += angles[:1]
15
 
16
  fig, ax = plt.subplots(figsize=(5, 5), subplot_kw=dict(polar=True))
17
- ax.plot(angles, values, 'o-', linewidth=2)
18
  ax.fill(angles, values, alpha=0.25)
19
- ax.set_thetagrids(np.degrees(angles[:-1]), labels)
20
- ax.set_title("Trait Radar", va='bottom')
21
- ax.grid(True)
 
22
 
23
- file_path = tempfile.NamedTemporaryFile(suffix=".png", delete=False).name
24
- plt.savefig(file_path, bbox_inches="tight")
25
- plt.close()
26
- return file_path
 
 
1
+ # plot.py
2
  import matplotlib.pyplot as plt
3
  import numpy as np
4
+ import io
5
+ from PIL import Image
6
 
7
  def plot_radar_chart(config):
8
+ # Traits to plot
9
+ traits = ["layers", "attention_heads", "ffn_dim", "dropout"]
10
+ values = [
11
+ config["layers"],
12
+ config["attention_heads"],
13
+ config["ffn_dim"] / 512, # Normalize
14
+ config["dropout"] * 10, # Normalize
15
+ ]
16
+ # Repeat first value to close the loop
17
  values += values[:1]
18
+ angles = np.linspace(0, 2 * np.pi, len(traits), endpoint=False).tolist()
19
  angles += angles[:1]
20
 
21
  fig, ax = plt.subplots(figsize=(5, 5), subplot_kw=dict(polar=True))
22
+ ax.plot(angles, values, "o-", linewidth=2)
23
  ax.fill(angles, values, alpha=0.25)
24
+ ax.set_yticklabels([])
25
+ ax.set_xticks(angles[:-1])
26
+ ax.set_xticklabels(traits)
27
+ ax.set_title("Final Generation Trait Radar", fontsize=14)
28
 
29
+ buf = io.BytesIO()
30
+ plt.savefig(buf, format="png")
31
+ buf.seek(0)
32
+ image = Image.open(buf)
33
+ return image