npb_data_app / plotting.py
patrickramos's picture
Add general pitch classification
d1369a2
raw
history blame
13.3 kB
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib import transforms
from matplotlib.colors import LinearSegmentedColormap
import polars as pl
from pyfonts import load_google_font
from scipy.stats import gaussian_kde
import numpy as np
from types import SimpleNamespace
from datetime import date
from data import data_df
from convert import ball_kind_code_to_color, get_text_color_from_color
from stats import filter_data_by_date_and_game_kind, compute_team_games, compute_pitch_stats
# Select the non-interactive Agg backend: cards are rendered to image buffers,
# never shown in a GUI window.
mpl.use(backend='Agg')
def get_pitcher_stats(id, lr=None, game_kind=None, start_date=None, end_date=None, min_ip=1, min_pitches=1, pitch_class_type='specific'):
    """Compute card statistics for one pitcher, optionally split by batter side.

    Qualification and league percentiles are computed over every pitcher in
    the filtered data; only at the end is the result narrowed to ``id``.

    Args:
        id: Pitcher identifier, matched against the ``pitId`` column.
        lr: If not None, keep only rows whose ``batLR`` equals this value
            (callers in this module pass ``'l'`` / ``'r'``). The filter is
            applied after ``qualified`` is computed, so qualification is based
            on all batters.
        game_kind, start_date, end_date: Forwarded to
            ``filter_data_by_date_and_game_kind``.
        min_ip: Minimum ``IP`` value to be ``qualified``, or the string
            ``'qualified'`` to require ``IP >= games``.
        min_pitches: Forwarded to ``compute_pitch_stats``.
        pitch_class_type: Forwarded to ``compute_pitch_stats``.

    Returns:
        SimpleNamespace with:
          * ``pitcher_stats``: rows for ``id`` with K%, BB%, CSW%, GB%, FB%,
            LD%, ``qualified`` and ``<stat>_pctl`` for CSW%, K%, BB%, GB%.
          * ``pitch_stats``: ``compute_pitch_stats`` output filtered to ``id``.
          * ``pitch_shapes``: per-pitch type codes, speed and x/y location.
    """
    # Rows whose pitch type code is '-' are excluded from everything below.
    source_data = data_df.filter(pl.col('ballKind_code') != '-')
    source_data = filter_data_by_date_and_game_kind(source_data, start_date=start_date, end_date=end_date, game_kind=game_kind)
    source_data = (
        compute_team_games(source_data)
        .with_columns(
            # `then`/`otherwise` strings are parsed as column names: take
            # 'home_games' if the pitcher's first row has a half_inning ending
            # in '1', otherwise 'visitor_games'.
            pl.when(pl.col('half_inning').str.ends_with('1')).then('home_games').otherwise('visitor_games').first().over('pitId').alias('games'),
            # "IP" is the number of distinct inning_code values per pitcher
            # (presumably one per game-inning appeared in -- TODO confirm).
            pl.col('inning_code').unique().len().over('pitId').alias('IP')
        )
    )

    if min_ip == 'qualified':
        source_data = source_data.with_columns((pl.col('IP') >= pl.col('games')).alias('qualified'))
    else:
        source_data = source_data.with_columns((pl.col('IP') >= min_ip).alias('qualified'))

    if lr is not None:
        source_data = source_data.filter(pl.col('batLR') == lr)

    pitch_stats = compute_pitch_stats(source_data, player_type='pitcher', pitch_class_type=pitch_class_type, min_pitches=min_pitches).filter(pl.col('pitId') == id)

    pitch_shapes = (
        source_data
        .filter(
            (pl.col('pitId') == id) &
            pl.col('x').is_not_null() &
            pl.col('y').is_not_null() &
            (pl.col('ballSpeed') > 0)
        )
        [['pitId', 'general_ballKind_code', 'ballKind_code', 'ballSpeed', 'x', 'y']]
    )

    # Batted-ball type codes combined into GB% / FB% / LD% below.
    batted_ball_types = ['G', 'F', 'B', 'P', 'L']
    pitcher_stats = (
        source_data
        .group_by('pitId')
        .agg(
            pl.col('pitcher_name').first(),
            (pl.when(pl.col('presult').str.contains('strikeout')).then(1).otherwise(0).sum() / pl.col('pa_code').unique().len()).alias('K%'),
            (pl.when(pl.col('presult') == 'Walk').then(1).otherwise(0).sum() / pl.col('pa_code').unique().len()).alias('BB%'),
            (pl.col('csw').sum() / pl.col('pitch').sum()).alias('CSW%'),
            # Share of each batted-ball type among this pitcher's batted balls.
            pl.col('aux_bresult').struct.field('batType').drop_nulls().value_counts(normalize=True),
            pl.first('qualified')
        )
        .explode('batType')
        .unnest('batType')
        .pivot(on='batType', values='proportion')
        .fill_null(0)
    )
    # pivot only creates columns for types that occur somewhere in the
    # filtered data; add absent ones as 0 so the sums below cannot fail
    # (e.g. short date ranges or small game kinds).
    missing_types = [bat_type for bat_type in batted_ball_types if bat_type not in pitcher_stats.columns]
    if missing_types:
        pitcher_stats = pitcher_stats.with_columns(pl.lit(0.0).alias(bat_type) for bat_type in missing_types)
    pitcher_stats = (
        pitcher_stats
        .with_columns(
            (pl.col('G') + pl.col('B')).alias('GB%'),
            (pl.col('F') + pl.col('P')).alias('FB%'),
            pl.col('L').alias('LD%'),
        )
        .drop(*batted_ball_types)
        .with_columns(
            # Percentile among qualified pitchers only (others get null);
            # BB% is ranked descending so a low walk rate scores high.
            (pl.when(pl.col('qualified')).then(pl.col(stat)).rank(descending=(stat == 'BB%'))/pl.when(pl.col('qualified')).then(pl.col(stat)).count()).alias(f'{stat}_pctl')
            for stat in ['CSW%', 'K%', 'BB%', 'GB%']
        )
        .filter(pl.col('pitId') == id)
    )
    return SimpleNamespace(pitcher_stats=pitcher_stats, pitch_stats=pitch_stats, pitch_shapes=pitch_shapes)
def get_card_data(id, **kwargs):
    """Collect overall and left/right-split data for one pitcher's card.

    Calls ``get_pitcher_stats`` three times -- for all batters, ``lr='l'`` and
    ``lr='r'`` -- forwarding the same ``kwargs`` (which therefore must not
    contain ``lr``).

    Returns:
        SimpleNamespace with:
          * ``pitcher_stats``: the overall row, with split columns suffixed
            ``_left`` / ``_right`` (null if that split has no row).
          * ``pitch_stats``: one row per ``ballKind_code`` present in any of
            the three results, split columns suffixed ``_left`` / ``_right``,
            nulls filled with 0.
          * ``both_pitch_shapes``, ``left_pitch_shapes``, ``right_pitch_shapes``.
    """
    both = get_pitcher_stats(id, **kwargs)
    left = get_pitcher_stats(id, 'l', **kwargs)
    right = get_pitcher_stats(id, 'r', **kwargs)

    # Left joins keep the overall row even when a handedness split is empty;
    # an inner join would drop it and break the card's `.item()` calls.
    pitcher_stats = (
        both.pitcher_stats
        .join(left.pitcher_stats, on='pitId', how='left', suffix='_left')
        .join(right.pitcher_stats, on='pitId', how='left', suffix='_right')
    )
    # coalesce=True merges the join keys into a single, never-null
    # `ballKind_code` column (polars does not coalesce full joins by default).
    pitch_stats = (
        both.pitch_stats
        .join(left.pitch_stats, on='ballKind_code', how='full', coalesce=True, suffix='_left')
        .join(right.pitch_stats, on='ballKind_code', how='full', coalesce=True, suffix='_right')
        .fill_null(0)
    )
    return SimpleNamespace(
        pitcher_stats=pitcher_stats,
        pitch_stats=pitch_stats,
        both_pitch_shapes=both.pitch_shapes,
        left_pitch_shapes=left.pitch_shapes,
        right_pitch_shapes=right.pitch_shapes
    )
def plot_arsenal(ax, pitches):
    """Draw one colored, labeled marker per pitch code in a horizontal row.

    Args:
        ax: Matplotlib axes to draw on.
        pitches: Iterable of pitch codes, drawn left to right in order. Codes
            missing from ``ball_kind_code_to_color`` are drawn in 'C0'.
    """
    pitches = list(pitches)
    # Keep the usual 11-slot spacing, but widen the axis for larger arsenals
    # so no marker falls outside the visible range.
    ax.set_xlim(0, max(11, len(pitches)))
    colors = [ball_kind_code_to_color.get(pitch, 'C0') for pitch in pitches]
    # Marker centers sit at 0.5, 1.5, ... so each pitch occupies one slot.
    centers = np.arange(len(pitches)) + 0.5
    ax.scatter(centers, np.zeros(len(pitches)), c=colors, s=170)
    for center, pitch, color in zip(centers, pitches, colors):
        ax.text(x=center, y=0, s=pitch, horizontalalignment='center', verticalalignment='center', font=font, color=get_text_color_from_color(color))
def plot_usage(ax, usages):
    """Draw pitch usage as one stacked horizontal bar.

    Args:
        ax: Matplotlib axes to draw on.
        usages: Two-column frame whose rows are (pitch code, usage fraction),
            stacked left to right in row order. Segments wider than 10% get a
            percentage label.
    """
    left = 0
    height = 0.8
    for pitch, usage in usages.iter_rows():
        # Fall back to 'C0' for unknown codes (same fallback the other plotting
        # helpers use) instead of raising KeyError.
        color = ball_kind_code_to_color.get(pitch, 'C0')
        ax.barh(0, usage, height=height, left=left, color=color)
        if usage > 0.1:
            ax.text(left+usage/2, 0, f'{usage:.0%}', horizontalalignment='center', verticalalignment='center', size=8, font=font, color=get_text_color_from_color(color))
        left += usage
    ax.set_xlim(0, 1)
    # Extra headroom above the bar (presumably for the panel caption).
    ax.set_ylim(-height/2, height/2*2.75)
# Evaluation grid for the location KDEs: integer x in [-100, 100] and integer
# y in [0, 250], both inclusive. X and Y have shape (251, 201).
x_range = np.arange(-100, 101)
y_range = np.arange(0, 251)
X, Y = np.meshgrid(x_range, y_range)
def fit_pred_kde(data, grid=None):
    """Fit a Gaussian KDE to 2-D points and evaluate it on a grid.

    Args:
        data: Array of shape (2, n_points) -- row 0 is x, row 1 is y -- as
            accepted by ``scipy.stats.gaussian_kde``.
        grid: Optional ``(grid_x, grid_y)`` pair of equally shaped coordinate
            arrays (e.g. from ``np.meshgrid``). Defaults to the module-level
            ``X`` / ``Y`` grid.

    Returns:
        Array shaped like ``grid_x`` holding the density at each grid point.
    """
    grid_x, grid_y = (X, Y) if grid is None else grid
    kde = gaussian_kde(data)
    points = np.vstack((np.ravel(grid_x), np.ravel(grid_y)))
    return kde(points).reshape(np.shape(grid_x))
def plot_loc(ax, locs):
    """Draw a pitch-location panel with a ~68% density region per pitch type.

    Args:
        ax: Matplotlib axes to draw on.
        locs: Frame with ``general_ballKind_code``, ``x`` and ``y`` columns.
            Coordinates are assumed to be on the scale of the module-level
            ``X`` / ``Y`` grid (TODO confirm units). Pitch types with 2 or
            fewer points, or with degenerate locations, are skipped.
    """
    ax.set_aspect('equal', adjustable='datalim')
    ax.set_ylim(-52, 252)
    # Background panels, a dotted outline, an inner ivory rectangle
    # (presumably the strike zone) and the home-plate polygon.
    ax.add_patch(plt.Rectangle((-100, 0), width=200, height=250, facecolor='darkgray', edgecolor='dimgray'))
    ax.add_patch(plt.Rectangle((-80, 25), width=160, height=200, facecolor='gainsboro', edgecolor='dimgray'))
    ax.add_patch(plt.Rectangle((-60, 50), width=120, height=150, fill=False, edgecolor='yellowgreen', linestyle=':'))
    ax.add_patch(plt.Rectangle((-40, 75), width=80, height=100, facecolor='ivory', edgecolor='darkgray'))
    ax.add_patch(plt.Polygon([(0, -10), (45, -30), (51, -50), (-51, -50), (-45, -30), (0, -10)], facecolor='snow', edgecolor='darkgray'))
    # Most-thrown pitch types first, so smaller groups are drawn later (on top).
    for (pitch,), _locs in locs.sort(pl.len().over('general_ballKind_code'), descending=True).group_by('general_ballKind_code', maintain_order=True):
        if len(_locs) <= 2:
            continue
        try:
            Z = fit_pred_kde(_locs[['x', 'y']].to_numpy().T)
        except np.linalg.LinAlgError:
            # Singular covariance (e.g. all points collinear): no usable KDE.
            continue
        total = Z.sum()
        if total <= 0:
            # All points so far off the grid that no mass lands on it.
            continue
        Z = Z / total
        # Threshold t: the cells with density >= t hold ~68% of the mass,
        # i.e. the ascending cumulative sum of cells below t is ~32%.
        sorted_Z = np.sort(Z.ravel())
        t = sorted_Z[np.argmin(np.abs(sorted_Z.cumsum() - (1-0.68)))]
        color = ball_kind_code_to_color.get(pitch, 'C0')
        ax.contourf(X, Y, Z, levels=[t, 1], colors=color, alpha=0.5)
        ax.contour(X, Y, Z, levels=[t], colors=color, alpha=0.75)
def plot_velo(ax, velos):
    """Draw one half-violin of ``ballSpeed`` per pitch type, labeled with its mean.

    Args:
        ax: Matplotlib axes to draw on.
        velos: Frame with ``general_ballKind_code`` and ``ballSpeed`` columns.
            Pitch types with a single pitch are skipped.
    """
    # x in data coordinates, y in axes coordinates: labels sit at mid-height.
    trans = transforms.blended_transform_factory(ax.transData, ax.transAxes)
    # maintain_order makes the draw order (and therefore overlap) deterministic.
    for (pitch,), _velos in velos.group_by('general_ballKind_code', maintain_order=True):
        if len(_velos) <= 1:
            continue
        color = ball_kind_code_to_color.get(pitch, 'C0')
        violin = ax.violinplot(_velos['ballSpeed'], orientation='horizontal', side='high', showextrema=False)
        for body in violin['bodies']:
            body.set_facecolor(color)
        mean = _velos['ballSpeed'].mean()
        ax.text(mean, 0.5, round(mean), horizontalalignment='center', verticalalignment='center', color='gray', alpha=0.75, font=font, transform=trans)
# Diverging colormap for percentile shading: 0 -> blue, 0.5 -> near-white,
# 1 -> red.
stat_cmap = LinearSegmentedColormap.from_list(
    'stat',
    ['dodgerblue', 'snow', 'crimson'],
)
def plot_pitch_stats(ax, stats, stat_names):
    """Draw a per-pitch-type stats table shaded by percentile.

    Column 0 holds the stat names; each further column is one pitch code with
    a colored header in row 0. Value cells of qualified pitch types are shaded
    with ``stat_cmap`` at ``<stat>_pctl``; unqualified ones are gray.

    Args:
        ax: Matplotlib axes to draw on.
        stats: Frame with ``ballKind_code``, ``qualified`` and, for each name
            in ``stat_names``, ``<name>`` and ``<name>_pctl`` columns; one row
            per pitch code, drawn in row order.
        stat_names: Stat columns to show, top to bottom (formatted as %).
    """
    ax.set_aspect('equal', adjustable='datalim')
    table = mpl.table.Table(ax)
    cell_height = 1/(len(stat_names) + 1)
    cell_width = 1/(len(stats) + 1)
    for row, stat_name in enumerate(stat_names, start=1):
        table.add_cell(row=row, col=0, width=cell_width, height=cell_height, text=stat_name, loc='center', fontproperties=font, edgecolor='white')
    for col, pitch_row in enumerate(stats.iter_rows(named=True), start=1):
        pitch = pitch_row['ballKind_code']
        color = ball_kind_code_to_color.get(pitch, 'C0')
        header = table.add_cell(row=0, col=col, width=cell_width, height=cell_height, text=pitch, loc='center', fontproperties=font, facecolor=color, edgecolor='white')
        header.get_text().set_color(get_text_color_from_color(color))
        qualified = pitch_row['qualified']
        for row, stat_name in enumerate(stat_names, start=1):
            stat = pitch_row[stat_name]
            stat_pctl = pitch_row[f'{stat_name}_pctl']
            # float(): a Colormap called with an int indexes its lookup table
            # instead of mapping the [0, 1] range.
            facecolor = stat_cmap(float(stat_pctl)) if qualified else 'gainsboro'
            cell = table.add_cell(row=row, col=col, width=cell_width, height=cell_height, text=f'{stat:.0%}', loc='center', fontproperties=font, facecolor=facecolor, edgecolor='white')
            if not qualified:
                cell.get_text().set_color('gray')
    ax.add_artist(table)
def plot_pitcher_stats(ax, stats, stat_names):
    """Draw a one-row table of pitcher-level stats as name/value cell pairs.

    Value cells are shaded with ``stat_cmap`` at ``<stat>_pctl`` when the
    pitcher is qualified, otherwise drawn gray.

    Args:
        ax: Matplotlib axes to draw on.
        stats: Single-row frame with ``qualified`` and, for each name in
            ``stat_names``, ``<name>`` and ``<name>_pctl`` columns (``.item()``
            raises if there is not exactly one row).
        stat_names: Stat columns to show, left to right (formatted as %).
    """
    ax.set_aspect('equal', adjustable='datalim')
    table = mpl.table.Table(ax)
    cell_height = 1
    cell_width = 1/(len(stat_names)*2)
    qualified = stats['qualified'].item()
    for i, stat_name in enumerate(stat_names):
        stat = stats[stat_name].item()
        stat_pctl = stats[f'{stat_name}_pctl'].item()
        table.add_cell(row=0, col=i*2, width=cell_width, height=cell_height, text=stat_name, loc='center', fontproperties=font, edgecolor='white')
        # float(): a Colormap called with an int indexes its lookup table
        # instead of mapping the [0, 1] range.
        facecolor = stat_cmap(float(stat_pctl)) if qualified else 'gainsboro'
        cell = table.add_cell(row=0, col=i*2+1, width=cell_width, height=cell_height, text=f'{stat:.0%}', loc='center', fontproperties=font, facecolor=facecolor, edgecolor='white')
        if not qualified:
            cell.get_text().set_color('gray')
    ax.add_artist(table)
# Typeface shared by the plotting helpers and the card builder in this module,
# which read it as a module-level global when they are called.
font = load_google_font(
    'Saira Extra Condensed',
    weight='medium',
)
def create_pitcher_overview_card(id, season, dpi=300):
    """Build the single-figure overview card for a pitcher's regular season.

    Args:
        id: Pitcher identifier forwarded to ``get_card_data``.
        season: Year; data is limited to Jan 1 - Dec 31 of it (regular season
            only) and the year is printed in the title.
        dpi: Figure resolution. The figure is 1080/300 x 1350/300 inches, so
            the default 300 yields a 1080x1350 pixel card.

    Returns:
        The matplotlib ``Figure``.
    """
    data = get_card_data(
        id,
        start_date=date(season, 1, 1),
        end_date=date(season, 12, 31),
        game_kind='Regular Season',
        min_pitches=100,
        pitch_class_type='general',
    )
    fig = plt.figure(figsize=(1080/300, 1350/300), dpi=dpi)
    # Rows: title, arsenal, usage, location, velo, pitch table, pitcher table, credits.
    grid = fig.add_gridspec(8, 6, height_ratios=[1, 1, 1.5, 6, 1, 3, 1, 0.5])

    panels = []

    def new_panel(spec):
        # Create a subplot and remember it for the final axis cleanup.
        panel = fig.add_subplot(spec)
        panels.append(panel)
        return panel

    # Shared styling for the small gray captions in the panels' top-left corner.
    caption_style = {'color': 'gray', 'font': font, 'size': 10}
    top_left = {'horizontalalignment': 'left', 'verticalalignment': 'top'}

    title_ax = new_panel(grid[0, :])
    title_ax.text(x=0, y=0, s=data.pitcher_stats['pitcher_name'].item().upper(), verticalalignment='baseline', font=font, size=20)
    title_ax.text(x=0.95, y=0, s=season, horizontalalignment='right', verticalalignment='baseline', font=font, size=20)
    title_ax.text(x=1, y=0.5, s='REG', horizontalalignment='right', verticalalignment='center', font=font, size=10, rotation='vertical')

    plot_arsenal(new_panel(grid[1, :]), data.pitch_stats['ballKind_code'])

    usage_panels = (
        (slice(None, 3), 'usage_left', 'LHH usage'),
        (slice(3, None), 'usage_right', 'RHH usage'),
    )
    for columns, usage_column, caption in usage_panels:
        usage_ax = new_panel(grid[2, columns])
        plot_usage(usage_ax, data.pitch_stats[['ballKind_code', usage_column]])
        usage_ax.text(0, 1, caption, linespacing=0.5, transform=usage_ax.transAxes, **top_left, **caption_style)

    location_panels = (
        (slice(None, 3), data.left_pitch_shapes, 'LHH\nloc'),
        (slice(3, None), data.right_pitch_shapes, 'RHH\nloc'),
    )
    for columns, shapes, caption in location_panels:
        loc_ax = new_panel(grid[3, columns])
        loc_ax.text(0, 1, caption, transform=loc_ax.transAxes, **top_left, **caption_style)
        plot_loc(loc_ax, shapes)

    velo_ax = new_panel(grid[4, :])
    plot_velo(velo_ax, data.both_pitch_shapes)
    velo_ax.text(0, 1, 'Velo', transform=velo_ax.transAxes, **top_left, **caption_style)

    plot_pitch_stats(new_panel(grid[5, :]), data.pitch_stats, ['CSW%', 'GB%'])
    plot_pitcher_stats(new_panel(grid[6, :]), data.pitcher_stats, ['CSW%', 'K%', 'BB%', 'GB%'])

    credits_ax = new_panel(grid[7, :])
    credits_ax.text(x=0, y=0.5, s='Data: SPAIA, Sanspo', verticalalignment='center', font=font, size=7)
    credits_ax.text(x=1, y=0.5, s='@yakyucosmo', horizontalalignment='right', verticalalignment='center', font=font, size=7)

    # Hide every frame, tick and tick label on the card.
    for panel in panels:
        panel.axis('off')
        panel.tick_params(axis='both', which='both', length=0, labelbottom=False, labelleft=False)
    return fig
# Example usage:
# fig = create_pitcher_overview_card('1600153', season=2023, dpi=300)
# plt.show()