HemanM's picture
Update plot.py
848a355 verified
raw
history blame
985 Bytes
# plot.py
import matplotlib.pyplot as plt
import numpy as np
import io
from PIL import Image
def plot_radar_chart(config):
    """Render a radar (spider) chart of a model configuration as a PIL image.

    Four traits are plotted: ``layers`` and ``attention_heads`` as-is,
    ``ffn_dim`` divided by 512 and ``dropout`` multiplied by 10, so the
    values land on roughly comparable scales on the shared radial axis.

    Args:
        config: Mapping with numeric values for the keys ``"layers"``,
            ``"attention_heads"``, ``"ffn_dim"`` and ``"dropout"``.

    Returns:
        A ``PIL.Image.Image`` containing the chart rendered as PNG.

    Raises:
        KeyError: If ``config`` is missing one of the required keys.
    """
    # Traits to plot
    traits = ["layers", "attention_heads", "ffn_dim", "dropout"]
    values = [
        config["layers"],
        config["attention_heads"],
        config["ffn_dim"] / 512,  # Normalize
        config["dropout"] * 10,  # Normalize
    ]
    # Repeat the first point so the plotted polygon closes on itself.
    values += values[:1]
    angles = np.linspace(0, 2 * np.pi, len(traits), endpoint=False).tolist()
    angles += angles[:1]

    fig, ax = plt.subplots(figsize=(5, 5), subplot_kw=dict(polar=True))
    try:
        ax.plot(angles, values, "o-", linewidth=2)
        ax.fill(angles, values, alpha=0.25)
        ax.set_yticklabels([])
        ax.set_xticks(angles[:-1])
        ax.set_xticklabels(traits)
        ax.set_title("Final Generation Trait Radar", fontsize=14)

        buf = io.BytesIO()
        # Save this specific figure rather than pyplot's "current" figure,
        # which could belong to another caller.
        fig.savefig(buf, format="png")
    finally:
        # pyplot keeps every figure alive until it is closed; without this,
        # each call leaks a figure (matplotlib warns after 20 are open).
        plt.close(fig)

    buf.seek(0)
    image = Image.open(buf)
    # Image.open is lazy; decode now so errors surface here, not in the caller.
    image.load()
    return image