npb_data_app / data.py
patrickramos's picture
Remove debug lines from data.py
15d84ea
import polars as pl
import os
from tqdm.auto import tqdm
import pykakasi
from huggingface_hub import snapshot_download
import numpy as np
from string import ascii_letters
from convert import (
aux_global_id_to_code, presult,
team_name_short,
ball_kind, ball_kind_code, general_ball_kind, general_ball_kind_code, lr,
game_kind
)
# Download (or reuse) the dataset snapshot; only the parquet tables used below are fetched.
# DATA_PATH is the directory snapshot_download returns for local_dir='./files'.
DATA_PATH = snapshot_download(
    repo_id='Ramos-Ramos/npb_data_app',
    repo_type='dataset',
    cache_dir='./.cache',
    local_dir='./files',
    allow_patterns=[
        # per-season tables (one sub-directory per season)
        '*/pbp_data.parquet',
        '*/pbp_text.parquet',
        '*/pbp_aux.parquet',
        '*/schedule.parquet',
        '*/aux_schedule.parquet',
        # player tables at the repository root
        'players.parquet',
        'players_translated.parquet',
        'players_translated_manual.parquet',
    ],
)
SEASONS = [2021, 2022, 2023, 2024, 2025]

# Per-season parquet files, in load order, with the pl.concat strategy used to stack them.
# pbp_aux is stacked diagonally with dtype relaxation, presumably because its schema
# differs between seasons (TODO confirm); the other tables are stacked vertically.
_SEASON_TABLES = (
    ('pbp_data.parquet', 'vertical'),
    ('pbp_text.parquet', 'vertical'),
    ('pbp_aux.parquet', 'diagonal_relaxed'),
    ('schedule.parquet', 'vertical'),
    ('aux_schedule.parquet', 'vertical'),
)

# Read every season first and concatenate each table once at the end, instead of
# re-concatenating a growing frame on every iteration (quadratic copying).
_season_frames = [[] for _ in _SEASON_TABLES]
for season in tqdm(SEASONS):
    season_dir = os.path.join(DATA_PATH, str(season))
    for frames, (filename, _) in zip(_season_frames, _SEASON_TABLES):
        frames.append(pl.read_parquet(os.path.join(season_dir, filename)))

# An empty season list yields empty frames, as the original accumulator did.
data_df, text_df, aux_df, sched_df, aux_sched_df = [
    pl.concat(frames, how=how) if frames else pl.DataFrame()
    for frames, (_, how) in zip(_season_frames, _SEASON_TABLES)
]
del _season_frames  # release the per-season copies now that they are concatenated
def select_name(names):
    '''
    Pick one display name out of several candidates for the same player.

    The first name containing an ASCII letter wins outright
    (ex. R. マルティネス > マルティネス); otherwise the shortest name wins,
    ties going to the earliest candidate (ex. 大勢 > 翁田 大勢).
    Names with ASCII letters help differentiate between foreign players,
    while shorter names are more accurate for players going by shorter names.

    Null entries are skipped. Returns None when there is no usable candidate
    (the previous implementation raised on empty input).
    '''
    shortest = None
    for name in names:
        if name is None:
            continue
        if any(char in ascii_letters for char in name):
            return name
        # strict '<' keeps the first of equally short names, like np.argmin did
        if shortest is None or len(name) < len(shortest):
            shortest = name
    return shortest
# load player dfs
# Paths are built from DATA_PATH (the directory snapshot_download populated) rather than
# a hard-coded 'files/' so the download location is defined in exactly one place.
# Names are NFKC-normalized and the katakana middle dot is replaced by a space so the
# sources can be matched against each other.
players_df = (
    pl.read_parquet(os.path.join(DATA_PATH, 'players.parquet'))
    .with_columns(pl.col('playerName').str.normalize('NFKC').str.replace_all('・', ' '))
    # one row per player: collapse that player's name variants with select_name
    .group_by('playerId').agg(pl.col('playerName').map_elements(select_name, return_dtype=pl.String))
)
translated_df = (
    pl.read_parquet(os.path.join(DATA_PATH, 'players_translated.parquet'))
    .with_columns(pl.col('name_jp').str.normalize('NFKC').str.replace_all('・', ' '))
)
manual_translated_df = pl.read_parquet(os.path.join(DATA_PATH, 'players_translated_manual.parquet'))
# names with no romanization are approximated by romanizing their kana reading
kks = pykakasi.kakasi()
translated_df = (
    translated_df
    # keep only the part of name_jp before any parenthesized suffix ("X (Y)" -> "X")
    .with_columns(
        pl.when(pl.col('name_jp').str.contains(r'\('))
        .then(pl.col('name_jp').str.extract(r'.*\(', 0).str.strip_chars_end(' ('))
        .otherwise(pl.col('name_jp'))
        .str.replace_all('・', ' ')
        .alias('name_jp')
    )
    .with_columns(pl.col('name_kana').str.normalize('NFKC').str.replace_all('・', ' '))
    # split a reading of the form "before (inside)" into its two parts; both are null
    # when name_kana contains no '('
    .with_columns(
        pl.col('name_kana').str.extract(r'\(.*\)', 0).str.strip_chars('()').alias('in_parentheses'),
        # strip the separating space as well (as done for name_jp above) so it cannot
        # leak into name_en as trailing whitespace
        pl.col('name_kana').str.extract(r'.*\(', 0).str.strip_chars_end(' (').alias('before_parentheses'),
    )
    # fill missing English names, in order of preference:
    #   1. the parenthesized part of the reading, when it contains an ASCII letter
    #   2. the part before the parenthesis, when the reading has one
    #   3. a Hepburn romanization of the whole kana reading via pykakasi
    .with_columns(
        pl.when(pl.col('name_en').is_not_null())
        .then(pl.col('name_en'))
        # vectorized equivalent of any(char in ascii_letters ...); null stays null
        .when(pl.col('in_parentheses').str.contains(r'[A-Za-z]'))
        .then(pl.col('in_parentheses'))
        .when(pl.col('in_parentheses').is_not_null() | pl.col('before_parentheses').is_not_null())
        .then(pl.col('before_parentheses'))
        .otherwise(
            pl.col('name_kana').map_elements(
                lambda name: ''.join(word['hepburn'].capitalize() for word in kks.convert(name)),
                return_dtype=pl.String,
            )
        )
        .alias('name_en')
    )
    .with_columns(pl.col('name_en').str.replace_all(',', '').str.to_titlecase())
)
# handle inconsistent kanji between sources
for old_char, new_char in (
    ('崎', '﨑'),
    ('高', '髙'),
    ('徳', '德'),
    ('濱', '濵'),
    ('瀬', '瀨'),
):
    # Names already present in translated_df['name_jp'] are left alone; every other
    # name has the first occurrence of old_char swapped for new_char (presumably the
    # glyph translated_df uses — TODO confirm). Swaps accumulate across iterations,
    # so one name can pick up several variants.
    players_df = players_df.with_columns(
        pl.when(pl.col('playerName').is_in(translated_df['name_jp']))
        .then(pl.col('playerName'))
        .otherwise(pl.col('playerName').str.replace(old_char, new_char))
    )
# merge player dfs: attach manual translations by playerId, then automatic translations
# by Japanese name, with manual translations taking precedence
players_df = (
    players_df
    .join(
        manual_translated_df.rename({'name_en': 'name_en_manual'}),
        on='playerId',
        how='left',
    )
    .join(
        translated_df.select(
            # For names containing '.', strip leading/trailing ASCII letters and '.'
            # (e.g. initials) unless the name already equals a playerName that occurs
            # exactly once in players_df.
            # NOTE(review): spaces are not in the strip set, so "R. X" becomes " X" — confirm.
            pl.when(
                pl.col('name_jp').str.contains(r'\.')
                & pl.col('name_jp').is_in(
                    players_df.filter(pl.len().over('playerName') == 1)['playerName']
                ).not_()
            )
            .then(pl.col('name_jp').str.strip_chars(ascii_letters + '.'))
            .otherwise(pl.col('name_jp'))
            .alias('name_jp'),
            pl.col('name_en'),
        ),
        left_on='playerName',
        right_on='name_jp',
        how='left',
    )
    .with_columns(name_en=pl.coalesce('name_en_manual', 'name_en'))
    # remove duplicates from names with multiple matches in other dataframes
    # NOTE(review): only identical rows collapse; several distinct name_en matches still
    # leave several rows for one playerId
    .unique()
    .drop('name_en_manual', 'name_jp')
)
def _aux_inning_code():
    '''
    Expression for the aux_df half-inning key "YYYYMMDD_visitor_home_IIT".

    II is the zero-padded inning and T is the tob_code column ('1' for 'Top', else '2').
    '''
    return (
        pl.col('gameDate') + '_'
        + pl.col('visitor') + '_'
        + pl.col('home') + '_'
        + pl.col('inning').str.zfill(2)
        + pl.col('tob_code')
    )


def _aux_pa_code():
    '''Expression for the aux_df plate-appearance key: the half-inning key plus a 2-digit PA number.'''
    return _aux_inning_code() + pl.col('new_pa_count').cast(pl.String).str.zfill(2)


aux_df = (
    aux_df
    .filter(pl.col('type') != 'RUNNER')
    .join(aux_sched_df[['gameGlobalId', 'gameDate']], on='gameGlobalId')
    .with_columns(
        pl.col('gameDate').str.to_date().dt.strftime('%Y%m%d'),
        pl.col('home').struct.field('globalId').replace_strict(aux_global_id_to_code).alias('home'),
        pl.col('visitor').struct.field('globalId').replace_strict(aux_global_id_to_code).alias('visitor'),
        pl.when(pl.col('tob') == 'Top').then(pl.lit('1')).otherwise(pl.lit('2')).alias('tob_code'),
    )
    # drop rows with code 98 and id 1 (presumably non-pitch placeholder events — TODO confirm);
    # a filter on pitch.count > 0 was considered as an alternative but is not applied
    .filter(~((pl.col('code') == 98) & (pl.col('id') == 1)))
    # a pitch count of 1 starts a new plate appearance within the half inning
    .with_columns(
        (pl.col('pitch').struct.field('count') == 1).cum_sum().over(['gameGlobalId', 'inning', 'tob']).alias('pa_count')
    )
    # flag whole plate appearances containing any of these codes (intentional walks)
    .with_columns(
        pl.col('code').is_in([6402, 6404, 6406, 6405]).any().over(['gameGlobalId', 'inning', 'tob', 'pa_count']).alias('ibb')
    )
    # renumber plate appearances skipping IBBs: IBB rows get a null count and do not advance it
    .with_columns(
        pl.when(~pl.col('ibb')).then(pl.col('pitch').struct.field('count') == 1).cum_sum().over(['gameGlobalId', 'inning', 'tob']).alias('new_pa_count')
    )
    .with_columns(
        pl.len().over(['gameGlobalId', 'inning', 'tob', 'new_pa_count']).alias('pa_pitches'),
        pl.max('new_pa_count').over(['gameGlobalId', 'inning', 'tob']).alias('inning_pas')
    )
    # join keys shared with data_df; null for IBB rows since new_pa_count is null there
    .with_columns(
        (_aux_pa_code() + '_' + pl.col('pitch').struct.field('count').cast(pl.String)).alias('universal_code'),
        _aux_inning_code().alias('inning_code'),
        _aux_pa_code().alias('pa_code'),
    )
)
data_df = (
    data_df
    .with_columns(
        *[
            pl.col(col).cast(pl.Int32)
            for col in ['gameId', 'ballKind', 'ballSpeed', 'x', 'y', 'presult', 'bresult', 'battedX', 'battedY']
        ],
        pl.col('UpdatedAt').str.to_datetime(),
        # fiveDigitSerialNumber = 3-char half-inning id + 2-char batter number
        pl.col('fiveDigitSerialNumber').str.slice(offset=0, length=3).alias('half_inning'),
        pl.col('fiveDigitSerialNumber').str.slice(offset=3, length=2).alias('batter'),
    )
    .with_columns(
        # pitches per plate appearance = rows whose presult is not 0
        (~pl.col('presult').is_in([0])).sum().over(['gameId', 'fiveDigitSerialNumber']).alias('pa_pitches'),
        # presult 139 anywhere in the PA flags it (presumably the intentional-walk code)
        pl.col('presult').is_in([139]).any().over(['gameId', 'fiveDigitSerialNumber']).alias('ibb')
    )
    .filter(
        (pl.col('pa_pitches') > 0)
    )
    # null out the batter number of IBB plate appearances so the dense rank below skips them
    .with_columns(
        pl.when(~pl.col('ibb')).then(pl.col('batter')).alias('batter')
    )
    # renumber batters within each half inning, ignoring IBBs (null for IBB rows)
    .with_columns(
        pl.when(~pl.col('ibb')).then(pl.col('batter').rank('dense')).over(['gameId', 'half_inning']).cast(pl.String).str.zfill(2).alias('new_batter')
    )
    .with_columns(
        (pl.col('half_inning') + pl.col('new_batter')).alias('newFiveDigitSerialNumber')
    )
    # plate appearances per half inning (max of the zero-padded batter number)
    .with_columns(pl.max('new_batter').cast(pl.Int32).over(['gameId', pl.col('newFiveDigitSerialNumber').str.slice(offset=0, length=3)]).alias('inning_pas'))
    .join(
        (
            sched_df[['GameID', 'HomeTeamNameES', 'VisitorTeamNameES']]
            .rename({'GameID': 'gameId'})
            .with_columns(
                pl.col('HomeTeamNameES').replace_strict(team_name_short).alias('home_team_name_short'),
                pl.col('VisitorTeamNameES').replace_strict(team_name_short).alias('visitor_team_name_short')
            )
        ),
        on='gameId'
    )
    # NOTE(review): the date comes from UpdatedAt; if a record can be updated on a later day
    # than the game itself, the keys below will not match aux_df's gameDate-based keys — confirm
    .with_columns(pl.col('UpdatedAt').dt.strftime('%Y%m%d').alias('date'))
    # join keys mirroring aux_df's universal_code / inning_code / pa_code
    .with_columns(
        (
            pl.col('date') + '_' + pl.col('VisitorTeamNameES') + '_' + pl.col('HomeTeamNameES')
            + '_' + pl.col('newFiveDigitSerialNumber') + '_' + pl.col('atBatBallCount')
        ).alias('universal_code'),
        (
            pl.col('date') + '_' + pl.col('VisitorTeamNameES') + '_' + pl.col('HomeTeamNameES')
            + '_' + pl.col('newFiveDigitSerialNumber').str.slice(offset=0, length=3)
        ).alias('inning_code'),
        (
            pl.col('date') + '_' + pl.col('VisitorTeamNameES') + '_' + pl.col('HomeTeamNameES')
            + '_' + pl.col('newFiveDigitSerialNumber')
        ).alias('pa_code'),
    )
    .join(
        (
            aux_df.filter(~pl.col('ibb'))[['universal_code', 'battingResult', 'inning_pas', 'pa_pitches']]
            .rename({'battingResult': 'aux_bresult', 'inning_pas': 'aux_inning_pas', 'pa_pitches': 'aux_pa_pitches'})
        ),
        on='universal_code',
        how='left'
    )
    .join(
        players_df.rename({'name_en': 'pitcher_name'}), left_on='pitId', right_on='playerId', how='left'
    )
    .join(
        text_df[['GameID', 'GameKindID']].with_columns(
            pl.col('GameID').cast(pl.Int32),
            pl.col('GameKindID').cast(pl.Int32),
        ).unique(),
        how='left',
        left_on='gameId',
        right_on='GameID'
    )
    .with_columns(pl.col('GameKindID').replace_strict(game_kind).alias('GameKindName'))
    .with_columns(
        # trust the aux batting result only when PA counts and pitch counts agree
        pl.when((pl.col('inning_pas') == pl.col('aux_inning_pas')) & (pl.col('pa_pitches') == pl.col('aux_pa_pitches')))
        .then('aux_bresult')
        .alias('aux_bresult'),
        pl.col('x').add(-100).mul(-1),  # x -> 100 - x
        pl.col('y').neg().add(250),  # y -> 250 - y
        pl.col('presult').alias('presult_id'),
        pl.col('ballKind').replace_strict(ball_kind),
        pl.col('ballKind').replace_strict(ball_kind_code).alias('ballKind_code'),
        pl.col('ballKind').replace_strict(general_ball_kind).alias('general_ballKind'),
        pl.col('ballKind').replace_strict(general_ball_kind_code).alias('general_ballKind_code'),
        pl.col('batLR').replace_strict(lr),
        pl.col('pitLR').replace_strict(lr),
        pl.col('date').str.to_date('%Y%m%d'),
        pl.when(pl.col('GameKindName').str.contains('Regular Season') | (pl.col('GameKindName') == 'Interleague'))
        .then(pl.lit('Regular Season'))
        .when(~pl.col('GameKindName').is_in(['Spring Training', 'All-Star Game']))
        .then(pl.lit('Postseason'))
        .otherwise('GameKindName')
        .alias('coarse_game_kind'),
        # half_inning ending in '1' presumably marks the top half, when the home team pitches.
        # The suffix must be a string: str.ends_with expects a String suffix, not an int.
        pl.when(pl.col('half_inning').str.ends_with('1')).then('HomeTeamNameES').otherwise('VisitorTeamNameES').alias('pitcher_team'),
        pl.when(pl.col('half_inning').str.ends_with('1')).then('home_team_name_short').otherwise('visitor_team_name_short').alias('pitcher_team_name_short')
    )
    .with_columns(
        pl.col('presult_id').replace_strict(presult).alias('presult')
    )
    .with_columns(
        pl.col('presult').is_in(['None', 'Balk', 'Batter interference', 'Catcher interference', 'Pitcher delay', 'Intentional walk', 'Unknown']).not_().alias('pitch'),
        pl.col('presult').is_in(['Swinging strike', 'Swinging strikeout']).alias('whiff'),
    )
    .with_columns(
        (pl.col('pitch') & pl.col('presult').is_in(['Hit by pitch', 'Sacrifice bunt', 'Sacrifice fly', 'Looking strike', 'Ball', 'Walk', 'Looking strikeout', 'Sacrifice hit error', 'Sacrifice fly error', "Sacrifice fielder's choice", 'Bunt strikeout']).not_()).alias('swing'),
        (pl.col('whiff') | pl.col('presult').is_in(['Looking strike', 'Uncaught third strike', 'Looking strikeout'])).alias('csw')
    )
    # location buckets on the transformed x/y coordinates (bounds inclusive)
    .with_columns((pl.col('x').is_between(-60, 60) & pl.col('y').is_between(50, 50+150)).alias('zone'))
    .with_columns((pl.col('x').is_between(-40, 40) & pl.col('y').is_between(75, 75+100)).alias('heart'))
    .with_columns((pl.col('x').is_between(-80, 80) & pl.col('y').is_between(25, 25+200) & ~pl.col('heart')).alias('shadow'))
    .with_columns((pl.col('x').is_between(-100, 101) & pl.col('y').is_between(0, 0+251) & ~pl.col('heart') & ~pl.col('shadow')).alias('chase'))
)
if __name__ == '__main__':
    # Running this module directly builds every dataframe above and then stops in the
    # debugger via breakpoint(), so they can be inspected interactively.
    breakpoint()