Alvaro committed on
Commit
2aed0aa
·
1 Parent(s): bf7e729

Add new ML models and CLI tools for prediction

Browse files

Introduces several new machine learning models (XGBoost, LightGBM, SVC, RandomForest, BernoulliNB) to the prediction pipeline by refactoring model logic into a shared base class. Adds requirements for new dependencies. Implements two new CLI scripts: save_model.py for training and saving models, and predict_new.py for predicting outcomes of hypothetical fights using saved models. Improves preprocessing robustness for date parsing.

requirements.txt CHANGED
@@ -1,4 +1,8 @@
1
  requests
2
  beautifulsoup4
3
  pandas
4
- scikit-learn
 
 
 
 
 
1
  requests
2
  beautifulsoup4
3
  pandas
4
+ scikit-learn
5
+ lazypredict
6
+ tqdm
7
+ xgboost
8
+ lightgbm
src/predict/main.py CHANGED
@@ -1,6 +1,14 @@
1
  import argparse
2
- from .models import EloBaselineModel, LogisticRegressionModel
3
  from .pipeline import PredictionPipeline
 
 
 
 
 
 
 
 
 
4
 
5
  def main():
6
  """
@@ -22,6 +30,11 @@ def main():
22
  models_to_run = [
23
  EloBaselineModel(),
24
  LogisticRegressionModel(),
 
 
 
 
 
25
  ]
26
  # --- End of Model Definition ---
27
 
 
1
  import argparse
 
2
  from .pipeline import PredictionPipeline
3
+ from .models import (
4
+ EloBaselineModel,
5
+ LogisticRegressionModel,
6
+ XGBoostModel,
7
+ SVCModel,
8
+ RandomForestModel,
9
+ BernoulliNBModel,
10
+ LGBMModel
11
+ )
12
 
13
  def main():
14
  """
 
30
  models_to_run = [
31
  EloBaselineModel(),
32
  LogisticRegressionModel(),
33
+ XGBoostModel(),
34
+ SVCModel(),
35
+ RandomForestModel(),
36
+ BernoulliNBModel(),
37
+ LGBMModel(),
38
  ]
39
  # --- End of Model Definition ---
40
 
src/predict/models.py CHANGED
@@ -4,6 +4,11 @@ import os
4
  from ..analysis.elo import process_fights_for_elo, INITIAL_ELO
5
  import pandas as pd
6
  from sklearn.linear_model import LogisticRegression
 
 
 
 
 
7
  from ..config import FIGHTERS_CSV_PATH
8
  from .preprocess import preprocess_for_ml, _get_fighter_history_stats, _calculate_age
9
 
@@ -63,21 +68,24 @@ class EloBaselineModel(BaseModel):
63
  print(f"Warning: Could not find ELO for fighter {e}. Skipping prediction.")
64
  return None
65
 
66
- class LogisticRegressionModel(BaseModel):
67
  """
68
- A model that uses logistic regression to predict fight outcomes based on differential features.
 
69
  """
70
- def __init__(self):
71
- self.model = LogisticRegression(solver='liblinear', random_state=42)
 
 
72
  self.fighters_df = None
73
  self.fighter_histories = {}
74
 
75
  def train(self, train_fights):
76
  """
77
- Trains the logistic regression model by preprocessing the training data
78
- and fitting the model.
79
  """
80
- print("Training LogisticRegressionModel...")
81
 
82
  # 1. Prepare data for prediction-time feature generation
83
  self.fighters_df = pd.read_csv(FIGHTERS_CSV_PATH)
@@ -87,17 +95,16 @@ class LogisticRegressionModel(BaseModel):
87
  if col in self.fighters_df.columns:
88
  self.fighters_df[col] = pd.to_numeric(self.fighters_df[col], errors='coerce')
89
 
90
- # 2. Pre-calculate fighter histories for efficient lookup during prediction
91
  train_fights_with_dates = []
92
  for fight in train_fights:
93
  fight['date_obj'] = pd.to_datetime(fight['event_date'])
94
  train_fights_with_dates.append(fight)
95
-
96
  for fighter_name in self.fighters_df.index:
97
  history = [f for f in train_fights_with_dates if fighter_name in (f['fighter_1'], f['fighter_2'])]
98
  self.fighter_histories[fighter_name] = sorted(history, key=lambda x: x['date_obj'])
99
 
100
- # 3. Preprocess training data and fit the model
101
  X_train, y_train, _ = preprocess_for_ml(train_fights, FIGHTERS_CSV_PATH)
102
  print(f"Fitting model on {X_train.shape[0]} samples...")
103
  self.model.fit(X_train, y_train)
@@ -111,19 +118,19 @@ class LogisticRegressionModel(BaseModel):
111
  fight_date = pd.to_datetime(fight['event_date'])
112
 
113
  if f1_name not in self.fighters_df.index or f2_name not in self.fighters_df.index:
114
- print(f"Warning: Fighter not found in data. Skipping prediction for {f1_name} vs {f2_name}")
115
  return None
116
 
117
- # 1. Get base stats
118
- f1_stats, f2_stats = self.fighters_df.loc[f1_name], self.fighters_df.loc[f2_name]
119
  if isinstance(f1_stats, pd.DataFrame): f1_stats = f1_stats.iloc[0]
120
  if isinstance(f2_stats, pd.DataFrame): f2_stats = f2_stats.iloc[0]
121
 
122
- # 2. Get historical stats
123
- f1_hist_stats = _get_fighter_history_stats(f1_name, fight_date, self.fighter_histories.get(f1_name, []), self.fighters_df)
124
- f2_hist_stats = _get_fighter_history_stats(f2_name, fight_date, self.fighter_histories.get(f2_name, []), self.fighters_df)
 
125
 
126
- # 3. Create differential features
127
  f1_age = _calculate_age(f1_stats.get('dob'), fight['event_date'])
128
  f2_age = _calculate_age(f2_stats.get('dob'), fight['event_date'])
129
 
@@ -140,9 +147,37 @@ class LogisticRegressionModel(BaseModel):
140
  }
141
 
142
  feature_vector = pd.DataFrame([features]).fillna(0)
143
-
144
- # 4. Predict
145
- # The model predicts the probability of class '1' (a win for fighter_1)
146
  prediction = self.model.predict(feature_vector)[0]
147
-
148
- return f1_name if prediction == 1 else f2_name
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  from ..analysis.elo import process_fights_for_elo, INITIAL_ELO
5
  import pandas as pd
6
  from sklearn.linear_model import LogisticRegression
7
+ from sklearn.svm import SVC
8
+ from sklearn.naive_bayes import BernoulliNB
9
+ from sklearn.ensemble import RandomForestClassifier
10
+ from xgboost import XGBClassifier
11
+ from lightgbm import LGBMClassifier
12
  from ..config import FIGHTERS_CSV_PATH
13
  from .preprocess import preprocess_for_ml, _get_fighter_history_stats, _calculate_age
14
 
 
68
  print(f"Warning: Could not find ELO for fighter {e}. Skipping prediction.")
69
  return None
70
 
71
+ class BaseMLModel(BaseModel):
72
  """
73
+ An abstract base class for machine learning models that handles all common
74
+ data preparation, training, and prediction logic.
75
  """
76
+ def __init__(self, model):
77
+ if model is None:
78
+ raise ValueError("A model must be provided.")
79
+ self.model = model
80
  self.fighters_df = None
81
  self.fighter_histories = {}
82
 
83
  def train(self, train_fights):
84
  """
85
+ Trains the machine learning model. This involves loading fighter data,
86
+ pre-calculating histories, and fitting the model on the preprocessed data.
87
  """
88
+ print(f"--- Training {self.model.__class__.__name__} ---")
89
 
90
  # 1. Prepare data for prediction-time feature generation
91
  self.fighters_df = pd.read_csv(FIGHTERS_CSV_PATH)
 
95
  if col in self.fighters_df.columns:
96
  self.fighters_df[col] = pd.to_numeric(self.fighters_df[col], errors='coerce')
97
 
98
+ # 2. Pre-calculate fighter histories
99
  train_fights_with_dates = []
100
  for fight in train_fights:
101
  fight['date_obj'] = pd.to_datetime(fight['event_date'])
102
  train_fights_with_dates.append(fight)
 
103
  for fighter_name in self.fighters_df.index:
104
  history = [f for f in train_fights_with_dates if fighter_name in (f['fighter_1'], f['fighter_2'])]
105
  self.fighter_histories[fighter_name] = sorted(history, key=lambda x: x['date_obj'])
106
 
107
+ # 3. Preprocess and fit
108
  X_train, y_train, _ = preprocess_for_ml(train_fights, FIGHTERS_CSV_PATH)
109
  print(f"Fitting model on {X_train.shape[0]} samples...")
110
  self.model.fit(X_train, y_train)
 
118
  fight_date = pd.to_datetime(fight['event_date'])
119
 
120
  if f1_name not in self.fighters_df.index or f2_name not in self.fighters_df.index:
121
+ print(f"Warning: Fighter not found. Skipping prediction for {f1_name} vs {f2_name}")
122
  return None
123
 
124
+ f1_stats = self.fighters_df.loc[f1_name]
125
+ f2_stats = self.fighters_df.loc[f2_name]
126
  if isinstance(f1_stats, pd.DataFrame): f1_stats = f1_stats.iloc[0]
127
  if isinstance(f2_stats, pd.DataFrame): f2_stats = f2_stats.iloc[0]
128
 
129
+ f1_hist = self.fighter_histories.get(f1_name, [])
130
+ f2_hist = self.fighter_histories.get(f2_name, [])
131
+ f1_hist_stats = _get_fighter_history_stats(f1_name, fight_date, f1_hist, self.fighters_df)
132
+ f2_hist_stats = _get_fighter_history_stats(f2_name, fight_date, f2_hist, self.fighters_df)
133
 
 
134
  f1_age = _calculate_age(f1_stats.get('dob'), fight['event_date'])
135
  f2_age = _calculate_age(f2_stats.get('dob'), fight['event_date'])
136
 
 
147
  }
148
 
149
  feature_vector = pd.DataFrame([features]).fillna(0)
 
 
 
150
  prediction = self.model.predict(feature_vector)[0]
151
+ return f1_name if prediction == 1 else f2_name
152
+
153
class LogisticRegressionModel(BaseMLModel):
    """A thin wrapper for scikit-learn's LogisticRegression."""
    def __init__(self):
        # Keep the solver and seed that the pre-refactor LogisticRegressionModel
        # used, so results stay reproducible and comparable with earlier runs.
        super().__init__(model=LogisticRegression(solver='liblinear', random_state=42))
157
+
158
class XGBoostModel(BaseMLModel):
    """A thin wrapper for XGBoost's XGBClassifier."""
    def __init__(self):
        # `use_label_encoder` is intentionally not passed: recent XGBoost
        # releases deprecate the flag and only emit an "unused parameter"
        # warning for it.
        model = XGBClassifier(eval_metric='logloss', random_state=42)
        super().__init__(model=model)
163
+
164
class SVCModel(BaseMLModel):
    """Plugs scikit-learn's Support Vector Classifier into the shared BaseMLModel logic."""
    def __init__(self):
        # probability=True is kept because some reports need it, even though
        # it makes training slower.
        classifier = SVC(probability=True, random_state=42)
        super().__init__(model=classifier)
169
+
170
class RandomForestModel(BaseMLModel):
    """Plugs scikit-learn's RandomForestClassifier into the shared BaseMLModel logic."""
    def __init__(self):
        # Fixed seed so repeated runs build the same forest.
        classifier = RandomForestClassifier(random_state=42)
        super().__init__(model=classifier)
174
+
175
class BernoulliNBModel(BaseMLModel):
    """Plugs scikit-learn's Bernoulli Naive Bayes classifier into the shared BaseMLModel logic."""
    def __init__(self):
        # Default estimator settings; BaseMLModel stores it as self.model.
        classifier = BernoulliNB()
        super().__init__(model=classifier)
179
+
180
class LGBMModel(BaseMLModel):
    """Plugs LightGBM's LGBMClassifier into the shared BaseMLModel logic."""
    def __init__(self):
        # Fixed seed for reproducible training runs.
        classifier = LGBMClassifier(random_state=42)
        super().__init__(model=classifier)
src/predict/predict_new.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import joblib
4
+ from datetime import datetime
5
+
6
+ from ..config import OUTPUT_DIR
7
+
8
def predict_new_fight(fighter1_name, fighter2_name, model_path):
    """
    Loads a trained model and predicts the outcome of a new, hypothetical fight.

    :param fighter1_name: Full name of the first fighter.
    :param fighter2_name: Full name of the second fighter.
    :param model_path: Path to a model object previously saved with joblib.
    :return: Whatever the loaded model's ``predict`` returns: the predicted
        winner's name, or None when no prediction could be made.
    :raises FileNotFoundError: If no file exists at ``model_path``.
    """
    print("--- Predicting New Fight ---")

    # 1. Load the trained model
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file not found at '{model_path}'. Please train and save a model first.")

    print(f"Loading model from {model_path}...")
    # NOTE: joblib.load unpickles the file and can therefore execute arbitrary
    # code. Only load model files that come from a trusted source.
    model = joblib.load(model_path)
    # ML wrappers keep their estimator in `.model`; other saved models (for
    # example an ELO baseline) may not, so fall back to the object itself.
    estimator = getattr(model, 'model', model)
    print(f"Model '{estimator.__class__.__name__}' loaded.")

    # 2. Create the fight dictionary for prediction
    # The predict method requires a dictionary with specific keys.
    # We use today's date as a placeholder for the event date.
    # Other keys like 'winner', 'method', etc., are not needed for prediction.
    fight = {
        'fighter_1': fighter1_name,
        'fighter_2': fighter2_name,
        'event_date': datetime.now().strftime('%B %d, %Y'),
    }

    # 3. Make the prediction
    print(f"\nPredicting winner for: {fighter1_name} vs. {fighter2_name}")
    predicted_winner = model.predict(fight)

    # predict() signals "no prediction" with None, so test for that explicitly
    # rather than relying on truthiness.
    if predicted_winner is not None:
        print(f"\n---> Predicted Winner: {predicted_winner} <---")
    else:
        print("\nCould not make a prediction. One of the fighters may not be in the dataset.")

    return predicted_winner
40
+
41
if __name__ == '__main__':
    # Command-line entry point: two positional fighter names plus an optional
    # path to the saved model file.
    default_model_path = os.path.join(OUTPUT_DIR, 'trained_model.joblib')

    cli = argparse.ArgumentParser(description="Predict the outcome of a new UFC fight.")
    cli.add_argument('fighter1', type=str, help="The full name of the first fighter (e.g., 'Jon Jones').")
    cli.add_argument('fighter2', type=str, help="The full name of the second fighter (e.g., 'Stipe Miocic').")
    cli.add_argument('--model_path', type=str, default=default_model_path, help="Path to the saved model file.")
    options = cli.parse_args()

    predict_new_fight(options.fighter1, options.fighter2, options.model_path)
src/predict/preprocess.py CHANGED
@@ -122,7 +122,12 @@ def preprocess_for_ml(fights_to_process, fighters_csv_path):
122
  # 2. Pre-calculate fighter histories to speed up lookups
123
  # And convert date strings to datetime objects once
124
  for fight in fights_to_process:
125
- fight['date_obj'] = datetime.strptime(fight['event_date'], '%B %d, %Y')
 
 
 
 
 
126
 
127
  fighter_histories = {}
128
  for fighter_name in fighters_prepared.index:
 
122
  # 2. Pre-calculate fighter histories to speed up lookups
123
  # And convert date strings to datetime objects once
124
  for fight in fights_to_process:
125
+ try:
126
+ # This will work if event_date is a string
127
+ fight['date_obj'] = datetime.strptime(fight['event_date'], '%B %d, %Y')
128
+ except TypeError:
129
+ # This will be triggered if it's already a date-like object (e.g., Timestamp)
130
+ fight['date_obj'] = fight['event_date']
131
 
132
  fighter_histories = {}
133
  for fighter_name in fighters_prepared.index:
src/predict/save_model.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import joblib
4
+ import pandas as pd
5
+
6
+ from ..config import FIGHTS_CSV_PATH, OUTPUT_DIR
7
+ import src.predict.models as models
8
+
9
def save_model(model_name):
    """
    Trains a specified model on the entire dataset and saves it to a file.

    The trained model object is written with joblib to
    ``trained_model.joblib`` inside OUTPUT_DIR. If ``model_name`` does not
    name a model class, an error is printed and nothing is trained or saved.

    :param model_name: The name of the model class to train (e.g., 'XGBoostModel').
    :raises FileNotFoundError: If the fights CSV does not exist at FIGHTS_CSV_PATH.
    """
    print(f"--- Training and Saving Model: {model_name} ---")

    # 1. Get the model class from the models module
    ModelClass = getattr(models, model_name, None)
    if ModelClass is None:
        print(f"Error: Model '{model_name}' not found in src/predict/models.py")
        return
    # getattr also resolves imported modules, helper functions and constants,
    # so make sure the name really refers to a model class before calling it.
    if not (isinstance(ModelClass, type) and issubclass(ModelClass, models.BaseModel)):
        print(f"Error: '{model_name}' is not a model class in src/predict/models.py")
        return

    model = ModelClass()

    # 2. Load all available fights for training
    if not os.path.exists(FIGHTS_CSV_PATH):
        raise FileNotFoundError(f"Fights data not found at '{FIGHTS_CSV_PATH}'.")

    all_fights = pd.read_csv(FIGHTS_CSV_PATH).to_dict('records')
    print(f"Training model on all {len(all_fights)} available fights...")

    # 3. Train the model
    model.train(all_fights)

    # 4. Save the entire trained model object
    # Create the output directory first so a fresh checkout does not fail
    # only after the (potentially slow) training step has finished.
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    save_path = os.path.join(OUTPUT_DIR, 'trained_model.joblib')
    joblib.dump(model, save_path)

    print(f"\nModel saved successfully to {save_path}")
41
+
42
if __name__ == '__main__':
    # Command-line entry point: choose which model class to train and persist.
    cli = argparse.ArgumentParser(description="Train and save a prediction model.")
    cli.add_argument('--model', type=str, default='XGBoostModel', help="The name of the model class to train and save.")
    options = cli.parse_args()

    save_model(options.model)