Alvaro committed · e012a04
Parent(s): c81156a
Refactor prediction pipeline and modularize models
Replaces the monolithic predict.py with a modular prediction pipeline. Adds main.py to orchestrate model evaluation, models.py for model abstractions and the ELO baseline, and pipeline.py for data loading, splitting, evaluation, and reporting. Updates process_fights_for_elo to accept either a file path or pre-loaded data for improved flexibility.
- src/analysis/elo.py +19 -10
- src/predict/main.py +30 -0
- src/predict/models.py +56 -0
- src/predict/pipeline.py +111 -0
- src/predict/predict.py +0 -95
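
For orientation, a minimal quick-start sketch (not part of the commit), assuming the repository root is the working directory and src is an importable package; the module-style invocation mirrors the 'python -m src.analysis.elo' hint in the old predict.py:

    # Run the full evaluation as a module so the relative imports resolve:
    #   python -m src.predict.main
    # Or drive it programmatically with a custom model list:
    from src.predict.models import EloBaselineModel
    from src.predict.pipeline import PredictionPipeline

    pipeline = PredictionPipeline(models=[EloBaselineModel()])
    pipeline.run(detailed_report=True)  # True prints per-fight predictions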
src/analysis/elo.py  CHANGED
@@ -30,23 +30,32 @@ def update_elo_draw(elo1, elo2):
 
     return elo1 + change1, elo2 + change2
 
-def process_fights_for_elo(
+def process_fights_for_elo(fights_data=FIGHTS_CSV_PATH):
     """
-    Processes
+    Processes fights chronologically to calculate ELO scores.
+    Accepts either a CSV file path or a pre-loaded list of fights.
     """
+    fights = []
+    if isinstance(fights_data, str):
+        # If a string is passed, treat it as a file path
+        if not os.path.exists(fights_data):
+            print(f"Error: Fights data file not found at '{fights_data}'.")
+            return None
+        with open(fights_data, 'r', encoding='utf-8') as f:
+            fights = list(csv.DictReader(f))
+    elif isinstance(fights_data, list):
+        # If a list is passed, use it directly
+        fights = fights_data
+    else:
+        print(f"Error: Invalid data type passed to process_fights_for_elo: {type(fights_data)}")
        return None
 
-    # Sort fights by date to process them in chronological order
+    # Sort fights by date to process them in chronological order.
+    # This is crucial if loading from a file and a good safeguard if a list is passed.
     try:
         fights.sort(key=lambda x: datetime.strptime(x['event_date'], '%B %d, %Y'))
     except (ValueError, KeyError) as e:
-        print(f"Error sorting fights by date
+        print(f"Error sorting fights by date: {e}")
         return None
 
     elos = {}
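
A short usage sketch of the reworked signature. The fighter-to-ELO dict return shape is inferred from how EloBaselineModel consumes the result; the row keys shown are only those visible in this hunk, and the full set of CSV columns the ELO update needs is not visible here:

    from src.analysis.elo import process_fights_for_elo
    from src.scrape.config import FIGHTS_CSV_PATH

    # 1) Path-based: the function opens and parses the CSV itself.
    elos = process_fights_for_elo(FIGHTS_CSV_PATH)

    # 2) Pre-loaded rows, e.g. just a training split (as the new pipeline does).
    #    Illustrative rows; real rows carry every column from the fights CSV.
    train_rows = [{'fighter_1': 'A', 'fighter_2': 'B', 'winner': 'A',
                   'event_date': 'January 1, 2020'}]
    elos = process_fights_for_elo(train_rows)

    if elos is not None:  # None signals a bad path, bad type, or a sort error
        print(elos.get('A'))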
src/predict/main.py  ADDED
@@ -0,0 +1,30 @@
+from .models import EloBaselineModel
+from .pipeline import PredictionPipeline
+
+def main():
+    """
+    Sets up the models and runs the prediction pipeline.
+    This is where you can add new models to compare them.
+    """
+    print("--- Initializing Machine Learning Prediction Pipeline ---")
+
+    # 1. Initialize the models you want to test
+    elo_model = EloBaselineModel()
+
+    # Add other models here to compare them, e.g.:
+    # logistic_model = LogisticRegressionModel()
+
+    # 2. Create a list of the models to evaluate
+    models_to_run = [
+        elo_model,
+        # logistic_model
+    ]
+
+    # 3. Initialize and run the pipeline
+    pipeline = PredictionPipeline(models=models_to_run)
+
+    # Set detailed_report=False for a summary, or True for a full detailed report
+    pipeline.run(detailed_report=False)
+
+if __name__ == '__main__':
+    main()
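
The numbered comments in main() mark the extension point. A hypothetical example of plugging in a second model; AlwaysFighter1Model is invented here purely to illustrate the BaseModel contract defined in models.py (shown next):

    from src.predict.models import BaseModel

    class AlwaysFighter1Model(BaseModel):
        """Toy baseline: always picks the first listed fighter."""
        def train(self, train_fights):
            pass  # nothing to fit for this baseline

        def predict(self, fighter1_name, fighter2_name):
            return fighter1_name

    # ...then register it in main():
    # models_to_run = [elo_model, AlwaysFighter1Model()]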
src/predict/models.py  ADDED
@@ -0,0 +1,56 @@
+from abc import ABC, abstractmethod
+import sys
+import os
+from ..analysis.elo import process_fights_for_elo, INITIAL_ELO
+
+class BaseModel(ABC):
+    """
+    Abstract base class for all prediction models.
+    Ensures that every model has a standard interface for training and prediction.
+    """
+    @abstractmethod
+    def train(self, train_fights):
+        """
+        Trains or prepares the model using historical fight data.
+
+        :param train_fights: A list of historical fight data dictionaries.
+        """
+        pass
+
+    @abstractmethod
+    def predict(self, fighter1_name, fighter2_name):
+        """
+        Predicts the winner of a single fight.
+
+        :param fighter1_name: The name of the first fighter.
+        :param fighter2_name: The name of the second fighter.
+        :return: The name of the predicted winning fighter.
+        """
+        pass
+
+class EloBaselineModel(BaseModel):
+    """
+    A baseline prediction model that predicts the winner based on the higher ELO rating.
+    """
+    def __init__(self):
+        self.historical_elos = {}
+
+    def train(self, train_fights):
+        """
+        Calculates the ELO ratings for all fighters based on historical data.
+        These ratings are then stored to be used for predictions.
+        """
+        print("Training ELO Baseline Model...")
+        self.historical_elos = process_fights_for_elo(train_fights)
+        print("ELO Model training complete.")
+
+    def predict(self, fighter1_name, fighter2_name):
+        """
+        Predicts the winner based on which fighter has the higher historical ELO.
+        If a fighter has no ELO rating, the default initial ELO is used.
+        """
+        elo1 = self.historical_elos.get(fighter1_name, INITIAL_ELO)
+        elo2 = self.historical_elos.get(fighter2_name, INITIAL_ELO)
+
+        # Return the name of the fighter with the higher ELO
+        return fighter1_name if elo1 > elo2 else fighter2_name
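
Two behaviors worth noting: predict breaks ties in favor of fighter2_name because the comparison is strict, and train stores whatever process_fights_for_elo returns, so a None (failed load or sort) would only surface later as an AttributeError inside predict. A minimal standalone usage sketch, with illustrative fight rows:

    from src.predict.models import EloBaselineModel

    train_fights = [{'fighter_1': 'A', 'fighter_2': 'B', 'winner': 'A',
                     'event_date': 'January 1, 2020'}]  # illustrative rows

    model = EloBaselineModel()
    model.train(train_fights)
    print(model.predict('A', 'B'))  # unknown names fall back to INITIAL_ELO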
src/predict/pipeline.py  ADDED
@@ -0,0 +1,111 @@
+import csv
+import os
+import sys
+from datetime import datetime
+from collections import OrderedDict
+from ..scrape.config import FIGHTS_CSV_PATH
+from .models import BaseModel
+
+class PredictionPipeline:
+    """
+    Orchestrates the model training, evaluation, and reporting pipeline.
+    """
+    def __init__(self, models):
+        if not all(isinstance(m, BaseModel) for m in models):
+            raise TypeError("All models must be instances of BaseModel.")
+        self.models = models
+        self.train_fights = []
+        self.test_fights = []
+        self.results = {}
+
+    def _load_and_split_data(self, num_test_events=10):
+        """Loads and splits the data into chronological training and testing sets."""
+        print("\n--- Loading and Splitting Data ---")
+        if not os.path.exists(FIGHTS_CSV_PATH):
+            raise FileNotFoundError(f"Fights data not found at '{FIGHTS_CSV_PATH}'.")
+
+        with open(FIGHTS_CSV_PATH, 'r', encoding='utf-8') as f:
+            fights = list(csv.DictReader(f))
+
+        fights.sort(key=lambda x: datetime.strptime(x['event_date'], '%B %d, %Y'))
+
+        all_events = list(OrderedDict.fromkeys(f['event_name'] for f in fights))
+        if len(all_events) < num_test_events:
+            print(f"Warning: Fewer than {num_test_events} events found. Adjusting test set size.")
+            num_test_events = len(all_events)
+
+        test_event_names = all_events[-num_test_events:]
+        self.train_fights = [f for f in fights if f['event_name'] not in test_event_names]
+        self.test_fights = [f for f in fights if f['event_name'] in test_event_names]
+        print(f"Data loaded. {len(self.train_fights)} training fights, {len(self.test_fights)} testing fights.")
+        print(f"Testing on the last {num_test_events} events.")
+
+    def run(self, detailed_report=True):
+        """Executes the full pipeline: load, train, evaluate, and report."""
+        self._load_and_split_data()
+
+        eval_fights = [f for f in self.test_fights if f['winner'] not in ["Draw", "NC", ""]]
+        if not eval_fights:
+            print("No fights with definitive outcomes in the test set. Aborting.")
+            return
+
+        for model in self.models:
+            model_name = model.__class__.__name__
+            print(f"\n--- Evaluating Model: {model_name} ---")
+
+            model.train(self.train_fights)
+
+            correct_predictions = 0
+            predictions = []
+
+            for fight in eval_fights:
+                f1_name, f2_name = fight['fighter_1'], fight['fighter_2']
+                actual_winner = fight['winner']
+                predicted_winner = model.predict(f1_name, f2_name)
+
+                is_correct = (predicted_winner == actual_winner)
+                if is_correct:
+                    correct_predictions += 1
+
+                predictions.append({
+                    'fight': f"{f1_name} vs. {f2_name}",
+                    'predicted_winner': predicted_winner,
+                    'actual_winner': actual_winner,
+                    'is_correct': is_correct
+                })
+
+            accuracy = (correct_predictions / len(eval_fights)) * 100
+            self.results[model_name] = {
+                'accuracy': accuracy,
+                'predictions': predictions,
+                'total_fights': len(eval_fights)
+            }
+
+        if detailed_report:
+            self._report_detailed_results()
+        else:
+            self._report_summary()
+
+    def _report_summary(self):
+        """Prints a concise summary of model performance."""
+        print("\n\n--- Prediction Pipeline Summary ---")
+        print(f"{'Model':<25} | {'Accuracy':<10} | {'Fights Evaluated':<20}")
+        print("-" * 65)
+        for model_name, result in self.results.items():
+            print(f"{model_name:<25} | {result['accuracy']:<9.2f}% | {result['total_fights']:<20}")
+        print("-" * 65)
+
+    def _report_detailed_results(self):
+        """Prints a summary and detailed report of the model evaluations."""
+        print("\n\n--- Prediction Pipeline Finished: Detailed Report ---")
+        for model_name, result in self.results.items():
+            print(f"\n--- Model: {model_name} ---")
+            print(f"  Overall Accuracy: {result['accuracy']:.2f}%")
+            print("  Detailed Predictions:")
+            for p in result['predictions']:
+                status = "CORRECT" if p['is_correct'] else "INCORRECT"
+                print(f"  - Fight: {p['fight']}")
+                print(f"    -> Predicted: {p['predicted_winner']}")
+                print(f"    -> Actual: {p['actual_winner']}")
+                print(f"    -> Result: {status}")
+            print("------------------------" + "-" * len(model_name))
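
The train/test split is by event rather than by row: the last num_test_events distinct event_name values (in date order) become the test set, so no event straddles the boundary. A tiny worked example of the OrderedDict.fromkeys dedup-and-slice, with illustrative values:

    from collections import OrderedDict

    fights = [{'event_name': 'Event 1'}, {'event_name': 'Event 1'},
              {'event_name': 'Event 2'}, {'event_name': 'Event 3'}]  # date-sorted
    all_events = list(OrderedDict.fromkeys(f['event_name'] for f in fights))
    # -> ['Event 1', 'Event 2', 'Event 3']
    test_event_names = all_events[-2:]  # num_test_events=2
    # -> every fight from 'Event 2' and 'Event 3' is test data; 'Event 1' trains.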
src/predict/predict.py  DELETED
@@ -1,95 +0,0 @@
-import csv
-import os
-import sys
-from datetime import datetime
-from ..scrape.config import FIGHTS_CSV_PATH, FIGHTERS_CSV_PATH
-
-def load_fighters_data():
-    """Loads fighter data, including ELO scores, into a dictionary."""
-    if not os.path.exists(FIGHTERS_CSV_PATH):
-        print(f"Error: Fighter data not found at '{FIGHTERS_CSV_PATH}'.")
-        print("Please run the ELO analysis first ('python -m src.analysis.elo').")
-        return None
-
-    fighters = {}
-    with open(FIGHTERS_CSV_PATH, 'r', encoding='utf-8') as f:
-        reader = csv.DictReader(f)
-        for row in reader:
-            full_name = f"{row['first_name']} {row['last_name']}".strip()
-            fighters[full_name] = {'elo': float(row.get('elo', 1500))}  # Default ELO if missing
-    return fighters
-
-def load_fights_data():
-    """Loads fight data and sorts it chronologically."""
-    if not os.path.exists(FIGHTS_CSV_PATH):
-        print(f"Error: Fights data not found at '{FIGHTS_CSV_PATH}'.")
-        return None
-
-    with open(FIGHTS_CSV_PATH, 'r', encoding='utf-8') as f:
-        fights = list(csv.DictReader(f))
-
-    # Sort fights chronologically to ensure a proper train/test split later
-    fights.sort(key=lambda x: datetime.strptime(x['event_date'], '%B %d, %Y'))
-    return fights
-
-def run_elo_baseline_model(fights, fighters):
-    """
-    Runs a simple baseline prediction model where the fighter with the higher ELO is predicted to win.
-    """
-    correct_predictions = 0
-    total_predictions = 0
-
-    for fight in fights:
-        fighter1_name = fight['fighter_1']
-        fighter2_name = fight['fighter_2']
-        actual_winner = fight['winner']
-
-        # Skip fights that are draws or no contests
-        if actual_winner in ["Draw", "NC", ""]:
-            continue
-
-        fighter1 = fighters.get(fighter1_name)
-        fighter2 = fighters.get(fighter2_name)
-
-        if not fighter1 or not fighter2:
-            continue  # Skip if fighter data is missing
-
-        elo1 = fighter1.get('elo', 1500)
-        elo2 = fighter2.get('elo', 1500)
-
-        # Predict winner based on higher ELO
-        predicted_winner = fighter1_name if elo1 > elo2 else fighter2_name
-
-        if predicted_winner == actual_winner:
-            correct_predictions += 1
-
-        total_predictions += 1
-
-    accuracy = (correct_predictions / total_predictions) * 100 if total_predictions > 0 else 0
-    return accuracy, total_predictions
-
-def main():
-    """
-    Main function to run the prediction pipeline.
-    """
-    print("--- Starting ML Prediction Pipeline ---")
-
-    # Load data
-    fighters_data = load_fighters_data()
-    fights_data = load_fights_data()
-
-    if not fighters_data or not fights_data:
-        print("Aborting pipeline due to missing data.")
-        return
-
-    # Run baseline model
-    print("\nRunning Baseline Model (Predicting winner by highest ELO)...")
-    accuracy, total_fights = run_elo_baseline_model(fights_data, fighters_data)
-
-    print("\n--- Baseline Model Evaluation ---")
-    print(f"Total Fights Evaluated: {total_fights}")
-    print(f"Model Accuracy: {accuracy:.2f}%")
-    print("---------------------------------")
-
-if __name__ == '__main__':
-    main()