Spaces:
Sleeping
Sleeping
| import plotly.express as px | |
| import plotly.graph_objects as go | |
| import plotly.colors as pc | |
| from scipy.stats import gaussian_kde | |
| import numpy as np | |
| import polars as pl | |
| import gradio as gr | |
| from math import ceil | |
| from translate import max_pitch_types, jp_pitch_to_en_pitch | |
| from data import ( | |
| df, | |
| # pitch_stats, rhb_pitch_stats,lhb_pitch_stats, | |
| # league_pitch_stats, rhb_league_pitch_stats, lhb_league_pitch_stats | |
| compute_pitch_stats, compute_league_pitch_stats | |
| ) | |
# Layout of the pitch-card grid: MAX_LOCS is the number of entries in the
# pitch-name translation table, and cards are arranged LOCS_PER_ROW to a row.
LOCS_PER_ROW = 4
MAX_LOCS = len(jp_pitch_to_en_pitch)
MAX_ROWS = ceil(MAX_LOCS / LOCS_PER_ROW)

# Placeholder text shown in place of a plot when the sample is too small.
# The multi-line variant uses an HTML line break for use inside Plotly figures.
INSUFFICIENT_PITCHES_MSG = 'No visualization: Not enough pitches thrown'
INSUFFICIENT_PITCHES_MSG_MULTI_LINE = 'No visualization:<br>Not enough pitches thrown'
| # GRADIO FUNCTIONS | |
| # def clone_if_dataframe(item): | |
| # if isinstance(item, pl.DataFrame): | |
| # # print(type(item)) | |
| # return item.clone() | |
| # else: | |
| # return item | |
| # | |
| # def clone_df(fn): | |
| # def _fn(*args, **kwargs): | |
| # args = [clone_if_dataframe(arg) for arg in args] | |
| # kwargs = {k: clone_if_dataframe(arg) for k, arg in kwargs.items()} | |
| # return fn(*args, **kwargs) | |
| # return _fn | |
| # | |
def copy_dataframe(df, num_copy_to):
    """Return a list holding ``num_copy_to`` clones of ``df``.

    Args:
        df: Any object exposing a ``clone()`` method (a Polars DataFrame in
            this app).
        num_copy_to: Number of clones to produce.

    Returns:
        A list with the result of ``df.clone()`` repeated ``num_copy_to``
        times; ``clone()`` is called once per element.
    """
    copies = []
    for _ in range(num_copy_to):
        copies.append(df.clone())
    return copies
| # location maps | |
def fit_pred_kde(data, X, Y):
    """Fit a Gaussian KDE to ``data`` and evaluate it on the ``X``/``Y`` grid.

    Args:
        data: Sample points in the layout ``gaussian_kde`` expects
            (one row per dimension, one column per observation).
        X: Array of x coordinates of the evaluation grid (e.g. from
            ``np.meshgrid``).
        Y: Array of y coordinates, same shape as ``X``.

    Returns:
        The estimated density at every grid point, with the same shape as
        ``X``.
    """
    estimator = gaussian_kde(data)
    # Flatten the grid into a (2, n_points) array, which is what the
    # estimator is called with, then fold the result back into grid shape.
    stacked = np.stack((X, Y))
    flat_points = stacked.reshape(2, -1)
    density = estimator(flat_points)
    return density.reshape(*X.shape)
# Side length of the square window that is plotted and over which the KDE
# is evaluated (centred on the origin).
plot_s = 256

# Outer rectangle drawn on the location plot (presumably the strike zone).
sz_h, sz_w = 200, 160

# Inner rectangle: the outer one shrunk by 40 at top and bottom and by 32 at
# left and right (presumably the "heart" of the zone).
h_h = sz_h - 2 * 40
h_w = sz_w - 2 * 32

# One KDE sample per unit step across the window, shared by every plot.
kde_range = np.arange(-plot_s / 2, plot_s / 2, 1)
X, Y = np.meshgrid(kde_range, kde_range)
def coordinatify(h, w):
    """Return the corner coordinates of an origin-centred ``w`` x ``h`` box.

    Args:
        h: Height of the rectangle.
        w: Width of the rectangle.

    Returns:
        A dict with keys ``x0``, ``y0``, ``x1``, ``y1`` (in that order),
        suitable for unpacking into a Plotly ``add_shape`` call.
    """
    left, bottom = -w / 2, -h / 2
    right, top = w / 2, h / 2
    return {'x0': left, 'y0': bottom, 'x1': right, 'y1': top}
# Colour scale for the density contours: Plotly's sequential OrRd palette,
# with a fully transparent stop prepended at 0 so the lowest level is not
# painted. The palette colours are spread over (0, 1] in equal steps.
colorscale = [
    [0, 'rgba(0, 0, 0, 0)'],
    *(
        [rank / len(pc.sequential.OrRd), color]
        for rank, color in enumerate(pc.sequential.OrRd, start=1)
    ),
]
| # @clone_df | |
def _finite_plate_points(frame, max_rows=None):
    """Return the usable pitch locations of ``frame`` as a ``(2, n)`` float array.

    Selects ``plate_x``/``plate_z``, optionally down-samples to ``max_rows``
    rows (fixed seed, so the sample is reproducible), and drops every row
    with a missing or non-finite coordinate so that NaNs never reach the KDE.
    """
    loc = frame.select(['plate_x', 'plate_z'])
    if max_rows is not None and len(loc) > max_rows:
        loc = loc.sample(max_rows, seed=0)
    points = loc.to_numpy().astype(float)
    return points[np.isfinite(points).all(axis=1)].T


def plot_loc(df, handedness, league_df=None, min_pitches=3, max_pitches=5000):
    """Build the pitch-location density figure for one set of pitches.

    Args:
        df: DataFrame of the pitches to plot; must contain ``plate_x`` and
            ``plate_z``.
        handedness: Batter-handedness selection. When it is ``'Both'`` the
            league outline starts hidden (``legendonly``); otherwise it is
            shown.
        league_df: Optional DataFrame of comparison pitches with the same
            columns. When given, a dashed outline at the 90th percentile of
            its density grid is added under the name ``'NPB'``.
        min_pitches: Minimum number of usable locations required before a
            density is fitted (applies to both ``df`` and ``league_df``).
        max_pitches: ``league_df`` is randomly down-sampled (seed 0) to at
            most this many rows before fitting.

    Returns:
        A ``plotly.graph_objects.Figure``. If ``df`` has fewer than
        ``min_pitches`` usable locations, the figure carries an explanatory
        annotation instead of the density and zone rectangles.
    """
    fig = go.Figure()

    # Rows with a null/NaN coordinate are discarded before counting, so the
    # min_pitches check reflects what the KDE will actually be fitted on.
    points = _finite_plate_points(df)
    if points.shape[1] >= min_pitches:
        Z = fit_pred_kde(points, X, Y)
        peak = Z.max()
        fig.add_shape(type='rect', **coordinatify(sz_h, sz_w), line_color='gray')
        fig.add_shape(type='rect', **coordinatify(h_h, h_w), line_color='dimgray')
        fig.add_trace(go.Contour(
            z=Z,
            x=kde_range,
            y=kde_range,
            colorscale=colorscale,
            zmin=1e-5,
            zmax=peak,
            # Five evenly spaced levels between (almost) zero and the peak.
            contours={'start': 1e-5, 'end': peak, 'size': peak / 5},
            showscale=False,
        ))
    else:
        fig.add_annotation(
            x=0,
            y=0,
            text=INSUFFICIENT_PITCHES_MSG_MULTI_LINE,
            showarrow=False,
        )

    if league_df is not None:
        league_points = _finite_plate_points(league_df, max_rows=max_pitches)
        if league_points.shape[1] >= min_pitches:
            league_Z = fit_pred_kde(league_points, X, Y)
            league_peak = league_Z.max()
            percentile = np.quantile(league_Z, 0.9)
            fig.add_trace(go.Contour(
                z=league_Z,
                x=kde_range,
                y=kde_range,
                # Fully transparent fill: only the dashed outline is visible.
                colorscale=[
                    [0, 'rgba(0, 0, 0, 0)'],
                    [1, 'rgba(0, 0, 0, 0)'],
                ],
                zmin=percentile,
                zmax=league_peak,
                # A single step from the percentile to the peak, i.e. one
                # contour line at the 90th-percentile level.
                contours={
                    'start': percentile,
                    'end': league_peak,
                    'size': league_peak - percentile,
                },
                line={'width': 2, 'color': 'black', 'dash': 'dash'},
                showlegend=True,
                showscale=False,
                visible=True if handedness != 'Both' else 'legendonly',
                name='NPB',
            ))

    fig.update_layout(
        xaxis=dict(range=[-plot_s/2, plot_s/2+1], showticklabels=False),
        # Lock the y axis to the x axis so the zone keeps its aspect ratio.
        yaxis=dict(range=[-plot_s/2, plot_s/2+1], scaleanchor='x', scaleratio=1, showticklabels=False),
        legend=dict(orientation='h', y=0, yanchor='top'),
    )
    return fig
| # velo distribution | |
| # @clone_df | |
def plot_velo(df=None, player=None, velos=None, pitch_type=None, pitch_name=None, min_pitches=2):
    """Build a one-sided violin plot of a velocity distribution.

    The velocities come either directly from ``velos`` or are looked up in
    ``df`` for ``player`` and one pitch (identified by ``pitch_type`` or
    ``pitch_name``).

    Args:
        df: DataFrame with ``name``, ``release_speed`` and the chosen pitch
            column; required when ``velos`` is not given.
        player: Value matched against the ``name`` column.
        velos: Pre-selected velocities (must support ``len()`` and
            ``.median()``).
        pitch_type: Value matched against the ``pitch_type`` column.
        pitch_name: Value matched against the ``pitch_name`` column.
        min_pitches: Minimum number of velocities needed to draw the violin.

    Returns:
        A ``plotly.graph_objects.Figure``; when there are fewer than
        ``min_pitches`` velocities it shows an explanatory annotation over a
        fixed 125-170 x range instead.

    Raises:
        AssertionError: If the argument combination is invalid.
    """
    assert (velos is None) != (player is None), 'exactly one of `player` or `velos` must be specified'
    if velos is None and player is not None:
        assert (pitch_type is None) != (pitch_name is None), 'exactly one of `pitch_type` or `pitch_name` must be specified'
        assert df is not None, '`df` must be provided if `velos` not provided'
        if pitch_type:
            pitch_col, pitch_val = 'pitch_type', pitch_type
        else:
            pitch_col, pitch_val = 'pitch_name', pitch_name
        is_match = (pl.col('name') == player) & (pl.col(pitch_col) == pitch_val)
        velos = df.filter(is_match)['release_speed']

    fig = go.Figure()
    if len(velos) < min_pitches:
        # Too few pitches: centre the message in the default axis window.
        fig.add_annotation(
            x=(170+125)/2,
            y=0.3/2,
            text=INSUFFICIENT_PITCHES_MSG_MULTI_LINE,
            showarrow=False,
        )
        x_range = [125, 170]
    else:
        violin = go.Violin(
            x=velos,
            side='positive',
            hoveron='points',
            points=False,
            meanline_visible=True,
            name='Velocity Distribution',
        )
        fig = fig.add_trace(violin)
        # Show a 50-unit window centred on the median velocity.
        median = velos.median()
        x_range = [median-25, median+25]

    y_ticks = np.linspace(0, 0.3, 3)
    x_axis = dict(title='Velocity', range=x_range, scaleratio=2)
    y_axis = dict(
        title='Frequency',
        range=[0, 0.3],
        scaleanchor='x',
        scaleratio=1,
        tickvals=y_ticks,
        ticktext=y_ticks,
    )
    fig.update_layout(
        xaxis=x_axis,
        yaxis=y_axis,
        autosize=True,
        modebar_remove=['zoom', 'autoScale', 'resetScale'],
    )
    return fig
| # @clone_df | |
def plot_velo_summary(df, league_df, player):
    """Build stacked horizontal violins comparing velocities per pitch.

    For every pitch name found in ``df`` (least-thrown first) a gray league
    violin is drawn from ``league_df``, followed by either the violin for
    ``df`` or, when fewer than two velocities are available, a text
    placeholder.

    Args:
        df: DataFrame of the pitches to summarise, with ``pitch_name`` and
            ``release_speed``. It is not filtered by ``player`` here, so the
            caller is expected to pass an already player-specific frame.
        league_df: DataFrame of comparison pitches with the same columns.
        player: Currently unused; kept so existing callers keep working.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    min_pitches = 2
    has_velo = pl.col('release_speed').is_not_null()
    player_df = df.filter(has_velo)
    league_df = league_df.filter(has_velo)
    pitch_counts = player_df['pitch_name'].value_counts().sort('count')
    n_pitch_types = len(pitch_counts)

    # Fall back to a 130-160 window when there is no velocity data at all.
    if len(player_df):
        min_velo = player_df['release_speed'].min()
        max_velo = player_df['release_speed'].max()
    else:
        min_velo, max_velo = 130, 160
    velo_center = (min_velo + max_velo) / 2

    fig = go.Figure()
    # Styling shared by the league violin and the per-pitch violin.
    violin_style = dict(
        side='positive',
        orientation='h',
        meanline_visible=True,
        points=False,
    )
    for i, (pitch_name, count) in enumerate(pitch_counts.iter_rows()):
        is_pitch = pl.col('pitch_name') == pitch_name
        velos = player_df.filter(is_pitch)['release_speed']
        league_velos = league_df.filter(is_pitch)['release_speed']

        fig.add_trace(go.Violin(
            x=league_velos,
            y=[pitch_name]*len(league_velos),
            line_color='gray',
            legendgroup='NPB',
            legendrank=1,
            # All league violins share one legend entry, shown once.
            showlegend=i == 0,
            name='NPB',
            **violin_style,
        ))

        # The most-thrown pitch (last in the loop) gets the lowest rank.
        legend_rank = n_pitch_types - i
        if count >= min_pitches:
            fig.add_trace(go.Violin(
                x=velos,
                y=[pitch_name]*len(velos),
                legendgroup=pitch_name,
                legendrank=legend_rank,
                name=pitch_name,
                **violin_style,
            ))
        else:
            fig.add_trace(go.Scatter(
                x=[velo_center],
                y=[pitch_name],
                text=[INSUFFICIENT_PITCHES_MSG],
                textposition='top center',
                # Suppress the hover label for the placeholder; the previous
                # `hovertext=False` does not disable hovering and would be
                # rendered as the text "False".
                hoverinfo='skip',
                mode="lines+text",
                legendgroup=pitch_name,
                legendrank=legend_rank,
                name=pitch_name,
            ))

    fig.update_xaxes(title='Velocity', range=[min_velo - 2, max_velo + 2])
    fig.update_yaxes(range=[0, n_pitch_types-0.25], visible=False)
    fig.update_layout(
        violingap=0,
        violingroupgap=0,
        legend=dict(orientation='h', y=-0.15, yanchor='top'),
        modebar_remove=['zoom', 'select2d', 'lasso2d', 'pan', 'autoScale'],
        dragmode=False,
    )
    return fig
def update_dfs(player, handedness, start_date, end_date, df):
    """Filter ``df`` for the current selection and recompute pitch stats.

    Args:
        player: Value matched against the ``name`` column.
        handedness: ``'Both'``, ``'Right'`` or ``'Left'``; selects batter
            side via the ``stand`` column (``'R'``/``'L'``).
        start_date: Inclusive lower bound on ``game_date``.
        end_date: Inclusive upper bound on ``game_date``.
        df: DataFrame of all pitches.

    Returns:
        A 4-tuple ``(player_df, league_df, pitch_stats, league_pitch_stats)``.
        ``league_df`` applies only the handedness and date filters (so it
        still includes the selected player's pitches); ``player_df`` further
        restricts it to ``player``. The stats are the results of
        ``compute_pitch_stats`` and ``compute_league_pitch_stats`` on those
        frames.

    Raises:
        ValueError: If ``handedness`` is not one of the supported values.
    """
    # Validate first: previously an unexpected value left the handedness
    # filter unassigned and surfaced later as an UnboundLocalError.
    stand_by_handedness = {'Right': 'R', 'Left': 'L'}
    if handedness != 'Both' and handedness not in stand_by_handedness:
        raise ValueError(
            f"handedness must be 'Both', 'Right' or 'Left', got {handedness!r}"
        )

    if handedness == 'Both':
        handedness_filter = pl.col('stand').is_in(['R', 'L'])
    else:
        handedness_filter = pl.col('stand') == stand_by_handedness[handedness]

    date_filter = (pl.col('game_date') >= start_date) & (pl.col('game_date') <= end_date)
    non_player_filter = handedness_filter & date_filter
    player_filter = pl.col('name') == player

    _df = df.filter(player_filter & non_player_filter)
    _league_df = df.filter(non_player_filter)
    return (
        _df,
        _league_df,
        compute_pitch_stats(_df),
        compute_league_pitch_stats(_league_df),
    )
def create_set_download_file_fn(filepath):
    """Create a callback that dumps a DataFrame to ``filepath`` as CSV.

    Args:
        filepath: Destination path handed to ``write_csv``.

    Returns:
        A function ``set_download_file(df)`` that writes ``df`` to
        ``filepath`` and returns ``filepath``.
    """
    def set_download_file(frame):
        """Write ``frame`` to the captured path and return that path."""
        frame.write_csv(filepath)
        return filepath

    return set_download_file
def preview_df(df):
    """Return the leading rows of ``df`` (whatever ``df.head()`` yields)."""
    preview = df.head()
    return preview
| # @clone_df | |
def plot_usage(df, player):
    """Build a pie chart of how often each pitch appears in ``df``.

    Args:
        df: DataFrame with a ``pitch_name`` column; one row per pitch.
        player: Name interpolated into the hover text.

    Returns:
        A Plotly figure whose slices are labelled with their percentage to
        one decimal place.
    """
    slice_label = '%{percent:.1%}'
    # The second part is deliberately NOT an f-string: its braces are Plotly
    # template placeholders, not Python format fields.
    hover_label = f'<b>{player}</b><br>' + 'threw a <b>%{label}</b><br><b>%{percent:.1%}</b> of the time (<b>%{value}</b> pitches)'

    pitch_names = df.select('pitch_name')
    fig = px.pie(pitch_names, names='pitch_name')
    fig.update_traces(texttemplate=slice_label, hovertemplate=hover_label)
    return fig
| # @clone_df | |
def plot_pitch_cards(df, league_df, pitch_stats, handedness):
    """Produce the Gradio updates that populate the per-pitch cards.

    One card is filled in for every pitch name in ``df`` (most-thrown first)
    with its heading, ``Whiff%``/``CSW%`` stats, velocity violin and location
    plot; the remaining card slots, up to ``max_pitch_types``, are hidden.

    Args:
        df: DataFrame of the pitches to show.
        league_df: DataFrame of comparison pitches, passed to ``plot_loc``.
        pitch_stats: DataFrame with ``pitch_name``, ``Whiff%`` and ``CSW%``.
        handedness: Forwarded to ``plot_loc``.

    Returns:
        A flat list of ``gr.update`` objects: ``MAX_ROWS`` row updates,
        followed by the group, name, info, velocity and location updates
        (each section padded to ``max_pitch_types`` entries).
    """
    pitch_counts = df['pitch_name'].value_counts().sort('count', descending=True)

    # Show only as many rows as are needed to hold the cards in use.
    n_visible_rows = ceil(len(pitch_counts) / LOCS_PER_ROW)
    pitch_rows = [gr.update(visible=True) for _ in range(n_visible_rows)]
    pitch_rows += [gr.update(visible=False) for _ in range(n_visible_rows, MAX_ROWS)]

    pitch_groups, pitch_names, pitch_infos = [], [], []
    pitch_velos, pitch_locs = [], []

    for pitch_name, _count in pitch_counts.iter_rows():
        is_pitch = pl.col('pitch_name') == pitch_name
        pitch_df = df.filter(is_pitch)

        stats = pitch_stats.filter(is_pitch).select(['Whiff%', 'CSW%'])
        velos = pitch_df.filter(pl.col('release_speed').is_not_null())['release_speed']
        velo_fig = plot_velo(velos=velos)
        loc_fig = plot_loc(
            df=pitch_df,
            handedness=handedness,
            league_df=league_df.filter(is_pitch),
        )

        pitch_groups.append(gr.update(visible=True))
        pitch_names.append(gr.update(value=f'### {pitch_name}', visible=True))
        pitch_infos.append(gr.update(value=stats, visible=True))
        pitch_velos.append(gr.update(value=velo_fig, visible=True))
        pitch_locs.append(gr.update(value=loc_fig, label='Pitch location', visible=True))

    # Hide and clear every card slot that was not used above.
    n_unused = max_pitch_types - len(pitch_names)
    pitch_groups += [gr.update(visible=False) for _ in range(n_unused)]
    for section in (pitch_names, pitch_infos, pitch_velos, pitch_locs):
        section += [gr.update(value=None, visible=False) for _ in range(n_unused)]

    return pitch_rows + pitch_groups + pitch_names + pitch_infos + pitch_velos + pitch_locs
| # @clone_df | |
def update_velo_stats(pitch_stats, league_pitch_stats):
    """Build the velocity comparison table (player vs. league) per pitch.

    Args:
        pitch_stats: DataFrame with ``pitch_name``, ``Velocity (KPH)``,
            ``Velocity (MPH)`` and ``Count``.
        league_pitch_stats: DataFrame with ``pitch_name``,
            ``Velocity (KPH)`` and ``Velocity (MPH)``.

    Returns:
        A DataFrame with columns ``Pitch``, ``Avg. Velo (KPH)``,
        ``Avg. Velo (MPH)``, ``League Avg. Velo (KPH)`` and
        ``League Avg. Velo (MPH)``, containing only pitches present in both
        inputs and ordered by the player's pitch count, highest first.
    """
    player_velos = pitch_stats.select(
        pl.col('pitch_name').alias('Pitch'),
        pl.col('Velocity (KPH)').alias('Avg. Velo (KPH)'),
        pl.col('Velocity (MPH)').alias('Avg. Velo (MPH)'),
        pl.col('Count'),
    )
    league_velos = league_pitch_stats.select(
        pl.col('pitch_name').alias('Pitch'),
        pl.col('Velocity (KPH)').alias('League Avg. Velo (KPH)'),
        # Spelling fixed from 'Leauge' so both league headers match.
        pl.col('Velocity (MPH)').alias('League Avg. Velo (MPH)'),
    )
    return (
        player_velos
        .join(league_velos, on='Pitch', how='inner')
        # Count is only needed for ordering, not for display.
        .sort('Count', descending=True)
        .drop('Count')
    )