File size: 4,873 Bytes
a9286c4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
import re
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
# The script only writes an SVG file, so select the non-interactive SVG backend.
mpl.use(backend="svg")
def parse_table(filepath):
    """Parse the first Markdown pipe table found in a file.

    Each line is stripped before inspection. The first line that both starts
    and ends with ``|`` begins the table. Parsing stops at the first later
    line that is blank or is not a table row, so a following table is not
    merged into this one.

    Args:
        filepath: Path of the Markdown file (read as UTF-8).

    Returns:
        A ``(header, body)`` tuple. ``header`` is the list of cell strings of
        the first table row. ``body`` holds one list of cell strings per
        following row. The ``|---|`` separator row is NOT removed; for a
        well-formed table it is ``body[0]``. Both are empty if no table is
        found.
    """
    with open(filepath, "r", encoding="utf-8") as f:
        content = f.read()

    header = []
    body = []
    in_table = False
    for raw_line in content.split("\n"):
        line = raw_line.strip()
        is_row = line.startswith("|") and line.endswith("|")
        if not in_table:
            # Skip everything (including blank lines) until the table starts.
            if not is_row:
                continue
            in_table = True
        elif not is_row:
            # In Markdown a blank line or any non-row line ends the table.
            break
        # Split only the interior between the outer pipes so empty cells are
        # kept; dropping them would misalign the row with the header.
        row = [cell.strip() for cell in line[1:-1].split("|")]
        if header:
            body.append(row)
        else:
            header = row
    return header, body
def parse_data(models_info, cmap=None):
    """Compute per-cell background colours and best/worst column indices.

    Each row is handled independently. Cells that parse as a finite number
    (after removing Markdown escaped asterisks, i.e. the two characters
    backslash and ``*``) are min-max normalised within the row and mapped
    through ``cmap``. Every other cell is coloured ``'white'``. If all numeric
    cells of a row are equal, they all get ``cmap(0.0)``.

    Args:
        models_info: Rows of cell strings, one list per model.
        cmap: Optional callable mapping a float in [0, 1] to a colour. When
            omitted, the module-level ``cmap`` (created in the ``__main__``
            section) is used if it exists; otherwise the ``"RdYlGn_r"``
            matplotlib colormap.

    Returns:
        ``(colors, min_list, max_list)``. ``colors`` has one list of colours
        per row. ``min_list`` and ``max_list`` hold, per row, the index of the
        first smallest / largest numeric cell, or ``-1`` when the row has no
        numeric cell.
    """
    import math

    if cmap is None:
        # Historically this function read the module-level ``cmap`` defined
        # under ``__main__``; keep honouring it, but stay usable on import.
        cmap = globals().get("cmap")
    if cmap is None:
        cmap = mpl.colormaps.get_cmap("RdYlGn_r")

    min_list = []
    max_list = []
    colors = []
    for model in models_info:
        # Markdown rows mark some values with an escaped asterisk ("\*").
        data = [cell.replace("\\*", "") for cell in model]

        # Parse each cell once; None marks a non-numeric cell.
        values = []
        for cell in data:
            try:
                number = float(cell)
            except ValueError:
                values.append(None)
            else:
                values.append(number if math.isfinite(number) else None)

        numeric = [i for i, v in enumerate(values) if v is not None]
        if not numeric:
            min_list.append(-1)
            max_list.append(-1)
            colors.append(['white'] * len(data))
            continue

        # min()/max() return the first index among ties, matching the
        # strict comparisons used previously.
        min_idx = min(numeric, key=lambda i: values[i])
        max_idx = max(numeric, key=lambda i: values[i])
        min_list.append(min_idx)
        max_list.append(max_idx)

        min_data = values[min_idx]
        span = values[max_idx] - min_data
        row_colors = []
        for v in values:
            if v is None:
                row_colors.append('white')
            elif span == 0:
                # All numeric cells are equal (e.g. a single value): colour
                # them as the best, so the highlighted text stays visible.
                row_colors.append(cmap(0.0))
            else:
                row_colors.append(cmap((v - min_data) / span))
        colors.append(row_colors)
    return colors, min_list, max_list
if __name__ == '__main__':
    hardware_info, models_info = parse_table("./README.md")
    # Must stay a module-level name: parse_data() looks up the global ``cmap``.
    cmap = mpl.colormaps.get_cmap("RdYlGn_r")
    if not models_info:
        raise SystemExit("No Markdown table found in ./README.md")
    # Drop the Markdown separator row (|---|---|) that follows the header.
    models_info.pop(0)
    # Replace Markdown links "[text](url)" by their text.
    link_pattern = re.compile(r'\[(.+?)]\(.+?\)')
    hardware_info = [link_pattern.sub(r'\1', cell) for cell in hardware_info]
    models_info = [[link_pattern.sub(r'\1', cell) for cell in row] for row in models_info]
    table_colors, min_list, max_list = parse_data(models_info)
    table_texts = [hardware_info] + models_info
    table_colors = [['white'] * len(hardware_info)] + table_colors

    # Three stacked axes: colour bar, table, footnote. Figure is 10 x 0.8 in
    # (1000 x 80 px at 100 dpi).
    fig, axs = plt.subplots(nrows=3, figsize=(10, 0.8))
    gradient = np.linspace(0, 1, 256)
    gradient = np.vstack((gradient, gradient))
    axs[0].imshow(gradient, aspect='auto', cmap=cmap)
    axs[0].text(-0.01, 0.5, "Faster", va='center', ha='right', fontsize=11, transform=axs[0].transAxes)
    axs[0].text(1.01, 0.5, "Slower", va='center', ha='left', fontsize=11, transform=axs[0].transAxes)

    # Build the table in the middle axes.
    table = axs[1].table(cellText=table_texts,
                         cellColours=table_colors,
                         cellLoc="left",
                         loc="upper left")
    # Shift the table axes down by its own height.
    table_pos = axs[1].get_position()
    axs[1].set_position([
        table_pos.x0,
        table_pos.y0 - table_pos.height,
        table_pos.width,
        table_pos.height,
    ])
    table.set_fontsize(11)
    table.auto_set_font_size(False)
    table.scale(1, 2)
    table.auto_set_column_width(list(range(len(table_texts[0]))))
    table.AXESPAD = 0  # cancel padding

    cells = table.get_celld()
    # Highlight the best (smallest) number of each data row. parse_data()
    # reports -1 for rows without a numeric cell; there is no such column.
    for row_idx, col_idx in enumerate(min_list, start=1):
        if col_idx < 0:
            continue
        cells[(row_idx, col_idx)].set_text_props(weight='bold', color='white')

    # Table height: sum of first-column cell heights.
    # Table width: sum of header-cell widths, plus 0.1 per column.
    table_height = sum(cells[(i, 0)].get_height() for i in range(len(table_texts)))
    table_width = sum(cells[(0, j)].get_width() + 0.1 for j in range(len(table_texts[0])))

    # Footnote for the "\*" markers, positioned relative to the table axes.
    # Raw string: same text as before, without the invalid "\*" escape warning.
    axs[2].text(0, -table_height - 0.8, r"\*: Models are quantized in per-channel mode, which run slower than per-tensor quantized models on NPU.", va='bottom', ha='left', fontsize=11, transform=axs[1].transAxes)

    # Hide frames and ticks on all axes.
    for ax in axs:
        ax.set_axis_off()
        ax.set_xticks([])
        ax.set_yticks([])
    # Reposition the colour bar horizontally from the computed table width.
    cm_pos = axs[0].get_position()
    axs[0].set_position([
        (table_width - 1) / 2,
        cm_pos.y0,
        cm_pos.width,
        cm_pos.height,
    ])
    # Emit text as SVG text elements; omit Date/Creator metadata entries.
    plt.rcParams['svg.fonttype'] = 'none'
    plt.savefig("./color_table.svg", format='svg', bbox_inches="tight", pad_inches=0, metadata={'Date': None, 'Creator': None})
|