File size: 985 Bytes
848a355
febaebe
 
848a355
 
febaebe
 
848a355
 
 
 
 
 
 
 
 
febaebe
848a355
febaebe
 
 
848a355
febaebe
848a355
 
 
 
febaebe
848a355
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
# plot.py
import matplotlib.pyplot as plt
import numpy as np
import io
from PIL import Image

def plot_radar_chart(config):
    """Render a polar "radar" chart of a model config's traits as a PIL image.

    Args:
        config: Mapping providing numeric values for the keys ``"layers"``,
            ``"attention_heads"``, ``"ffn_dim"`` and ``"dropout"``.

    Returns:
        A ``PIL.Image.Image`` decoded from a PNG rendering of the chart.

    Raises:
        KeyError: if any of the four trait keys is missing from ``config``.
    """
    # Traits to plot, one spoke per trait.
    traits = ["layers", "attention_heads", "ffn_dim", "dropout"]
    values = [
        config["layers"],
        config["attention_heads"],
        # Heuristic rescaling so the larger/smaller-magnitude traits sit on a
        # roughly similar scale to layer/head counts.
        config["ffn_dim"] / 512,
        config["dropout"] * 10,
    ]
    # Repeat the first point (value and angle) so the plotted polygon closes.
    values += values[:1]
    angles = np.linspace(0, 2 * np.pi, len(traits), endpoint=False).tolist()
    angles += angles[:1]

    fig, ax = plt.subplots(figsize=(5, 5), subplot_kw=dict(polar=True))
    buf = io.BytesIO()
    try:
        ax.plot(angles, values, "o-", linewidth=2)
        ax.fill(angles, values, alpha=0.25)
        ax.set_yticklabels([])
        ax.set_xticks(angles[:-1])
        ax.set_xticklabels(traits)
        ax.set_title("Final Generation Trait Radar", fontsize=14)
        # Save this specific figure rather than pyplot's "current" figure.
        fig.savefig(buf, format="png")
    finally:
        # Release the figure from pyplot's global registry; otherwise every
        # call leaks a figure.
        plt.close(fig)

    buf.seek(0)
    image = Image.open(buf)
    # Image.open is lazy; decode now so the returned image is fully loaded.
    image.load()
    return image