Alvaro committed
Commit e012a04 · 1 Parent(s): c81156a

Refactor prediction pipeline and modularize models


Replaces the monolithic predict.py with a modular prediction pipeline. Adds main.py to orchestrate model evaluation, models.py for model abstractions and the ELO baseline, and pipeline.py for data loading, splitting, evaluation, and reporting. Updates process_fights_for_elo to accept either a file path or pre-loaded data for improved flexibility.
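
A minimal sketch of how the new entry point is meant to be used (this mirrors src/predict/main.py added below; running it as a module, e.g. python -m src.predict.main, assumes src and its subdirectories are importable packages):

    from src.predict.models import EloBaselineModel
    from src.predict.pipeline import PredictionPipeline

    # Evaluate one or more BaseModel implementations on the held-out events.
    pipeline = PredictionPipeline(models=[EloBaselineModel()])
    pipeline.run(detailed_report=False)  # False prints the summary table, True the per-fight report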

src/analysis/elo.py CHANGED
@@ -30,23 +30,32 @@ def update_elo_draw(elo1, elo2):
 
     return elo1 + change1, elo2 + change2
 
-def process_fights_for_elo(fights_csv_path=FIGHTS_CSV_PATH):
+def process_fights_for_elo(fights_data=FIGHTS_CSV_PATH):
     """
-    Processes all fights chronologically to calculate final ELO scores for all fighters.
+    Processes fights chronologically to calculate ELO scores.
+    Accepts either a CSV file path or a pre-loaded list of fights.
     """
-    if not os.path.exists(fights_csv_path):
-        print(f"Error: Fights data file not found at '{fights_csv_path}'.")
-        print("Please run the scraping pipeline first using 'src/scrape/main.py'.")
-        return None
-
-    with open(fights_csv_path, 'r', encoding='utf-8') as f:
-        fights = list(csv.DictReader(f))
-
-    # Sort fights by date to process them in chronological order
+    fights = []
+    if isinstance(fights_data, str):
+        # If a string is passed, treat it as a file path
+        if not os.path.exists(fights_data):
+            print(f"Error: Fights data file not found at '{fights_data}'.")
+            return None
+        with open(fights_data, 'r', encoding='utf-8') as f:
+            fights = list(csv.DictReader(f))
+    elif isinstance(fights_data, list):
+        # If a list is passed, use it directly
+        fights = fights_data
+    else:
+        print(f"Error: Invalid data type passed to process_fights_for_elo: {type(fights_data)}")
+        return None
+
+    # Sort fights by date to process them in chronological order.
+    # This is crucial if loading from a file and a good safeguard if a list is passed.
     try:
         fights.sort(key=lambda x: datetime.strptime(x['event_date'], '%B %d, %Y'))
     except (ValueError, KeyError) as e:
-        print(f"Error sorting fights by date. Make sure 'event_date' exists and is in 'Month Day, Year' format. Error: {e}")
+        print(f"Error sorting fights by date: {e}")
         return None
 
     elos = {}
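
With this change, process_fights_for_elo can be fed either the CSV on disk or rows that are already in memory (as the new PredictionPipeline does with its training split), and its return value is consumed in models.py as a fighter-name-to-ELO mapping. A small usage sketch; the in-memory rows are hypothetical and only illustrate the expected 'Month Day, Year' date format:

    from src.analysis.elo import process_fights_for_elo

    # 1) Previous behaviour: read and sort the fights CSV from its default path.
    elos_from_file = process_fights_for_elo()

    # 2) New behaviour: pass pre-loaded fight dicts (e.g. a chronological training split).
    train_rows = [
        {'event_date': 'January 20, 2024', 'fighter_1': 'A', 'fighter_2': 'B', 'winner': 'A'},
        {'event_date': 'March 09, 2024', 'fighter_1': 'B', 'fighter_2': 'C', 'winner': 'C'},
    ]
    elos_from_rows = process_fights_for_elo(train_rows)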
src/predict/main.py ADDED
@@ -0,0 +1,30 @@
+from .models import EloBaselineModel
+from .pipeline import PredictionPipeline
+
+def main():
+    """
+    Sets up the models and runs the prediction pipeline.
+    This is where you can add new models to compare them.
+    """
+    print("--- Initializing Machine Learning Prediction Pipeline ---")
+
+    # 1. Initialize the models you want to test
+    elo_model = EloBaselineModel()
+
+    # Add other models here to compare them, e.g.:
+    # logistic_model = LogisticRegressionModel()
+
+    # 2. Create a list of the models to evaluate
+    models_to_run = [
+        elo_model,
+        # logistic_model
+    ]
+
+    # 3. Initialize and run the pipeline
+    pipeline = PredictionPipeline(models=models_to_run)
+
+    # Set detailed_report=False for a summary, or True for a full detailed report
+    pipeline.run(detailed_report=False)
+
+if __name__ == '__main__':
+    main()
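
Note: because main.py, models.py, and pipeline.py use relative imports (from .models ..., from ..analysis.elo ...), the entry point presumably has to be run as a module from the project root (e.g. python -m src.predict.main) rather than as a standalone script.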
src/predict/models.py ADDED
@@ -0,0 +1,56 @@
+from abc import ABC, abstractmethod
+import sys
+import os
+from ..analysis.elo import process_fights_for_elo, INITIAL_ELO
+
+class BaseModel(ABC):
+    """
+    Abstract base class for all prediction models.
+    Ensures that every model has a standard interface for training and prediction.
+    """
+    @abstractmethod
+    def train(self, train_fights):
+        """
+        Trains or prepares the model using historical fight data.
+
+        :param train_fights: A list of historical fight data dictionaries.
+        """
+        pass
+
+    @abstractmethod
+    def predict(self, fighter1_name, fighter2_name):
+        """
+        Predicts the winner of a single fight.
+
+        :param fighter1_name: The name of the first fighter.
+        :param fighter2_name: The name of the second fighter.
+        :return: The name of the predicted winning fighter.
+        """
+        pass
+
+class EloBaselineModel(BaseModel):
+    """
+    A baseline prediction model that predicts the winner based on the higher ELO rating.
+    """
+    def __init__(self):
+        self.historical_elos = {}
+
+    def train(self, train_fights):
+        """
+        Calculates the ELO ratings for all fighters based on historical data.
+        These ratings are then stored to be used for predictions.
+        """
+        print("Training ELO Baseline Model...")
+        self.historical_elos = process_fights_for_elo(train_fights)
+        print("ELO Model training complete.")
+
+    def predict(self, fighter1_name, fighter2_name):
+        """
+        Predicts the winner based on which fighter has the higher historical ELO.
+        If a fighter has no ELO rating, the default initial ELO is used.
+        """
+        elo1 = self.historical_elos.get(fighter1_name, INITIAL_ELO)
+        elo2 = self.historical_elos.get(fighter2_name, INITIAL_ELO)
+
+        # Return the name of the fighter with the higher ELO
+        return fighter1_name if elo1 > elo2 else fighter2_name
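
The BaseModel ABC is the extension point that the "Add other models here" comment in main.py refers to. A hypothetical sketch of a second baseline (not part of this commit) that would plug into the same pipeline:

    import random
    from src.predict.models import BaseModel

    class RandomBaselineModel(BaseModel):
        """Hypothetical sanity-check baseline: picks either fighter at random."""

        def train(self, train_fights):
            # Nothing to learn; the interface still requires the method.
            pass

        def predict(self, fighter1_name, fighter2_name):
            return random.choice([fighter1_name, fighter2_name])

Appending an instance to models_to_run in src/predict/main.py would have it trained, evaluated, and reported alongside EloBaselineModel.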
src/predict/pipeline.py ADDED
@@ -0,0 +1,111 @@
+import csv
+import os
+import sys
+from datetime import datetime
+from collections import OrderedDict
+from ..scrape.config import FIGHTS_CSV_PATH
+from .models import BaseModel
+
+class PredictionPipeline:
+    """
+    Orchestrates the model training, evaluation, and reporting pipeline.
+    """
+    def __init__(self, models):
+        if not all(isinstance(m, BaseModel) for m in models):
+            raise TypeError("All models must be instances of BaseModel.")
+        self.models = models
+        self.train_fights = []
+        self.test_fights = []
+        self.results = {}
+
+    def _load_and_split_data(self, num_test_events=10):
+        """Loads and splits the data into chronological training and testing sets."""
+        print("\n--- Loading and Splitting Data ---")
+        if not os.path.exists(FIGHTS_CSV_PATH):
+            raise FileNotFoundError(f"Fights data not found at '{FIGHTS_CSV_PATH}'.")
+
+        with open(FIGHTS_CSV_PATH, 'r', encoding='utf-8') as f:
+            fights = list(csv.DictReader(f))
+
+        fights.sort(key=lambda x: datetime.strptime(x['event_date'], '%B %d, %Y'))
+
+        all_events = list(OrderedDict.fromkeys(f['event_name'] for f in fights))
+        if len(all_events) < num_test_events:
+            print(f"Warning: Fewer than {num_test_events} events found. Adjusting test set size.")
+            num_test_events = len(all_events)
+
+        test_event_names = all_events[-num_test_events:]
+        self.train_fights = [f for f in fights if f['event_name'] not in test_event_names]
+        self.test_fights = [f for f in fights if f['event_name'] in test_event_names]
+        print(f"Data loaded. {len(self.train_fights)} training fights, {len(self.test_fights)} testing fights.")
+        print(f"Testing on the last {num_test_events} events.")
+
+    def run(self, detailed_report=True):
+        """Executes the full pipeline: load, train, evaluate, and report."""
+        self._load_and_split_data()
+
+        eval_fights = [f for f in self.test_fights if f['winner'] not in ["Draw", "NC", ""]]
+        if not eval_fights:
+            print("No fights with definitive outcomes in the test set. Aborting.")
+            return
+
+        for model in self.models:
+            model_name = model.__class__.__name__
+            print(f"\n--- Evaluating Model: {model_name} ---")
+
+            model.train(self.train_fights)
+
+            correct_predictions = 0
+            predictions = []
+
+            for fight in eval_fights:
+                f1_name, f2_name = fight['fighter_1'], fight['fighter_2']
+                actual_winner = fight['winner']
+                predicted_winner = model.predict(f1_name, f2_name)
+
+                is_correct = (predicted_winner == actual_winner)
+                if is_correct:
+                    correct_predictions += 1
+
+                predictions.append({
+                    'fight': f"{f1_name} vs. {f2_name}",
+                    'predicted_winner': predicted_winner,
+                    'actual_winner': actual_winner,
+                    'is_correct': is_correct
+                })
+
+            accuracy = (correct_predictions / len(eval_fights)) * 100
+            self.results[model_name] = {
+                'accuracy': accuracy,
+                'predictions': predictions,
+                'total_fights': len(eval_fights)
+            }
+
+        if detailed_report:
+            self._report_detailed_results()
+        else:
+            self._report_summary()
+
+    def _report_summary(self):
+        """Prints a concise summary of model performance."""
+        print("\n\n--- Prediction Pipeline Summary ---")
+        print(f"{'Model':<25} | {'Accuracy':<10} | {'Fights Evaluated':<20}")
+        print("-" * 65)
+        for model_name, result in self.results.items():
+            print(f"{model_name:<25} | {result['accuracy']:<9.2f}% | {result['total_fights']:<20}")
+        print("-" * 65)
+
+    def _report_detailed_results(self):
+        """Prints a summary and detailed report of the model evaluations."""
+        print("\n\n--- Prediction Pipeline Finished: Detailed Report ---")
+        for model_name, result in self.results.items():
+            print(f"\n--- Model: {model_name} ---")
+            print(f" Overall Accuracy: {result['accuracy']:.2f}%")
+            print(" Detailed Predictions:")
+            for p in result['predictions']:
+                status = "CORRECT" if p['is_correct'] else "INCORRECT"
+                print(f" - Fight: {p['fight']}")
+                print(f"   -> Predicted: {p['predicted_winner']}")
+                print(f"   -> Actual: {p['actual_winner']}")
+                print(f"   -> Result: {status}")
+            print("------------------------" + "-" * len(model_name))
src/predict/predict.py DELETED
@@ -1,95 +0,0 @@
-import csv
-import os
-import sys
-from datetime import datetime
-from ..scrape.config import FIGHTS_CSV_PATH, FIGHTERS_CSV_PATH
-
-def load_fighters_data():
-    """Loads fighter data, including ELO scores, into a dictionary."""
-    if not os.path.exists(FIGHTERS_CSV_PATH):
-        print(f"Error: Fighter data not found at '{FIGHTERS_CSV_PATH}'.")
-        print("Please run the ELO analysis first ('python -m src.analysis.elo').")
-        return None
-
-    fighters = {}
-    with open(FIGHTERS_CSV_PATH, 'r', encoding='utf-8') as f:
-        reader = csv.DictReader(f)
-        for row in reader:
-            full_name = f"{row['first_name']} {row['last_name']}".strip()
-            fighters[full_name] = {'elo': float(row.get('elo', 1500))}  # Default ELO if missing
-    return fighters
-
-def load_fights_data():
-    """Loads fight data and sorts it chronologically."""
-    if not os.path.exists(FIGHTS_CSV_PATH):
-        print(f"Error: Fights data not found at '{FIGHTS_CSV_PATH}'.")
-        return None
-
-    with open(FIGHTS_CSV_PATH, 'r', encoding='utf-8') as f:
-        fights = list(csv.DictReader(f))
-
-    # Sort fights chronologically to ensure a proper train/test split later
-    fights.sort(key=lambda x: datetime.strptime(x['event_date'], '%B %d, %Y'))
-    return fights
-
-def run_elo_baseline_model(fights, fighters):
-    """
-    Runs a simple baseline prediction model where the fighter with the higher ELO is predicted to win.
-    """
-    correct_predictions = 0
-    total_predictions = 0
-
-    for fight in fights:
-        fighter1_name = fight['fighter_1']
-        fighter2_name = fight['fighter_2']
-        actual_winner = fight['winner']
-
-        # Skip fights that are draws or no contests
-        if actual_winner in ["Draw", "NC", ""]:
-            continue
-
-        fighter1 = fighters.get(fighter1_name)
-        fighter2 = fighters.get(fighter2_name)
-
-        if not fighter1 or not fighter2:
-            continue  # Skip if fighter data is missing
-
-        elo1 = fighter1.get('elo', 1500)
-        elo2 = fighter2.get('elo', 1500)
-
-        # Predict winner based on higher ELO
-        predicted_winner = fighter1_name if elo1 > elo2 else fighter2_name
-
-        if predicted_winner == actual_winner:
-            correct_predictions += 1
-
-        total_predictions += 1
-
-    accuracy = (correct_predictions / total_predictions) * 100 if total_predictions > 0 else 0
-    return accuracy, total_predictions
-
-def main():
-    """
-    Main function to run the prediction pipeline.
-    """
-    print("--- Starting ML Prediction Pipeline ---")
-
-    # Load data
-    fighters_data = load_fighters_data()
-    fights_data = load_fights_data()
-
-    if not fighters_data or not fights_data:
-        print("Aborting pipeline due to missing data.")
-        return
-
-    # Run baseline model
-    print("\nRunning Baseline Model (Predicting winner by highest ELO)...")
-    accuracy, total_fights = run_elo_baseline_model(fights_data, fighters_data)
-
-    print("\n--- Baseline Model Evaluation ---")
-    print(f"Total Fights Evaluated: {total_fights}")
-    print(f"Model Accuracy: {accuracy:.2f}%")
-    print("---------------------------------")
-
-if __name__ == '__main__':
-    main()