# Apnea-Detector / models / respfusion.py
# AbdullahNasir's picture
# Added modularity
# c0953f1
import torch
import torch.nn as nn
import xgboost as xgb
import numpy as np
class RespFusion(nn.Module):
    """Ensemble that fuses the predictions of four base models.

    The four base models (stored as ``tft``, ``tcn``, ``ets`` and ``bilstm``)
    are each called with the same input ``x``. Their outputs are converted to
    NumPy arrays and combined according to ``strategy``:

    * ``'stacking'``: the outputs are column-stacked into a feature matrix
      and passed to an XGBoost regressor (the "meta-learner").
    * ``'voting'``: element-wise mean of the four outputs.
    * ``'weighted_average'``: weighted mean using ``self.weights``.
    * ``'simple_average'``: unweighted mean of the four outputs.

    The strategy is only checked when ``forward`` is called; an unknown
    strategy raises ``ValueError`` there, not in the constructor.
    """

    def __init__(self, tft_model, tcn_model, ets_model, bilstm_model, meta_learner_path=None, weights=None, strategy='stacking',):
        """Store the base models and set up the meta-learner.

        Args:
            tft_model: Callable base model; stored as ``self.tft``.
            tcn_model: Callable base model; stored as ``self.tcn``.
            ets_model: Callable base model; stored as ``self.ets``.
            bilstm_model: Callable base model; stored as ``self.bilstm``.
            meta_learner_path: Optional path handed to
                ``XGBRegressor.load_model`` to restore a trained
                meta-learner. When ``None`` the meta-learner starts unfitted.
            weights: Optional sequence of four weights used by the
                ``'weighted_average'`` strategy, in the order
                (tft, tcn, ets, bilstm). Defaults to equal weights.
            strategy: One of ``'stacking'``, ``'voting'``,
                ``'weighted_average'`` or ``'simple_average'``.

        Raises:
            AssertionError: If ``strategy`` is ``'weighted_average'`` and
                ``weights`` does not contain exactly four entries.
        """
        super().__init__()
        self.tft = tft_model
        self.tcn = tcn_model
        self.ets = ets_model
        self.bilstm = bilstm_model
        self.strategy = strategy  # 'stacking' or other strategies

        # Initialize XGBoost meta-learner
        self.meta_learner = xgb.XGBRegressor()  # Or XGBClassifier for classification

        # Load the meta-learner if a path is provided
        if meta_learner_path is not None:
            self.meta_learner.load_model(meta_learner_path)
            print(f"Meta-learner loaded from {meta_learner_path}")

        # Storage for stacking training data (one entry per collected batch)
        self.stacking_features = []
        self.stacking_targets = []

        # Always define the weights so that switching ``self.strategy`` to
        # 'weighted_average' after construction does not hit a missing
        # attribute. The length check is only enforced for the strategy that
        # actually consumes the weights.
        if weights is None:
            weights = [1.0, 1.0, 1.0, 1.0]
        elif strategy == 'weighted_average':
            assert len(weights) == 4, "Weights must match the number of models."
        self.weights = weights

    def _base_outputs(self, x):
        """Run every base model on ``x`` and return the outputs as NumPy arrays.

        Returns:
            A 4-tuple ``(tft, tcn, ets, bilstm)`` of NumPy arrays.
        """
        # The outputs are detached right away, so no autograd graph is needed;
        # ``no_grad`` avoids building one.
        with torch.no_grad():
            return tuple(
                model(x).detach().cpu().numpy()
                for model in (self.tft, self.tcn, self.ets, self.bilstm)
            )

    def forward(self, x):
        """Combine the base-model predictions for ``x``.

        Args:
            x: Input passed unchanged to each base model. It must expose a
                ``device`` attribute for the ``'stacking'`` and ``'voting'``
                strategies.

        Returns:
            For ``'stacking'`` and ``'voting'``: a float ``torch.Tensor`` on
            ``x.device``. For ``'weighted_average'`` and
            ``'simple_average'``: a NumPy array.
            NOTE(review): the mixed return types are kept so existing callers
            of the two averaging strategies keep receiving what they did
            before; consider unifying on tensors once callers are checked.

        Raises:
            ValueError: If ``self.strategy`` is not a supported strategy.
        """
        # Get predictions from each base model
        tft_output, tcn_output, ets_output, bilstm_output = self._base_outputs(x)

        if self.strategy == 'stacking':
            # Combine outputs into features for the meta-learner
            features = np.column_stack((tft_output, tcn_output, ets_output, bilstm_output))
            # During inference, use the meta-learner to make predictions
            ensemble_output = self.meta_learner.predict(features)
            return torch.tensor(ensemble_output).to(x.device).float()

        if self.strategy == 'voting':
            # Soft voting: element-wise mean over the four models. The outputs
            # are NumPy arrays at this point, so the stacking must be done
            # with NumPy (torch.stack only accepts tensors).
            ensemble_output = np.mean(
                np.stack([tft_output, tcn_output, ets_output, bilstm_output], axis=0),
                axis=0,
            )
            return torch.tensor(ensemble_output).to(x.device).float()

        if self.strategy == 'weighted_average':
            # Weighted average of outputs, normalised by the total weight
            ensemble_output = (
                self.weights[0] * tft_output +
                self.weights[1] * tcn_output +
                self.weights[2] * ets_output +
                self.weights[3] * bilstm_output
            ) / sum(self.weights)
            return ensemble_output

        if self.strategy == 'simple_average':
            # Simple average of outputs
            ensemble_output = (tft_output + tcn_output + ets_output + bilstm_output) / 4
            return ensemble_output

        raise ValueError(f"Invalid strategy: {self.strategy}. Currently supports only 'stacking', 'voting', 'weighted_average', and 'simple_average'.")

    def collect_stacking_data(self, x, y):
        """Collect base model outputs and corresponding targets for meta-learner training.

        Args:
            x: Input passed to each base model.
            y: Targets for ``x``; must support ``.detach().cpu().numpy()``.
        """
        # Stack features and store
        features = np.column_stack(self._base_outputs(x))
        self.stacking_features.append(features)
        self.stacking_targets.append(y.detach().cpu().numpy())

    def train_meta_learner(self, save_path=None):
        """Train the XGBoost meta-learner on collected data and save the model.

        Args:
            save_path: Optional path; when truthy the fitted meta-learner is
                written there with ``save_model``.

        Raises:
            ValueError: If no data has been collected with
                ``collect_stacking_data``.
        """
        if not self.stacking_features:
            # np.vstack([]) would raise ValueError too, but with a message
            # that does not point at the actual cause.
            raise ValueError(
                "No stacking data collected; call collect_stacking_data() "
                "before train_meta_learner()."
            )

        # Concatenate all collected features and targets
        X = np.vstack(self.stacking_features)
        y = np.concatenate(self.stacking_targets)

        # Train the XGBoost model
        self.meta_learner.fit(X, y)
        print("Meta-learner trained successfully!")

        # Save the trained meta-learner
        if save_path:
            self.meta_learner.save_model(save_path)
            print(f"Meta-learner saved to {save_path}")