File size: 4,287 Bytes
c0953f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import torch
import torch.nn as nn
import xgboost as xgb
import numpy as np

class RespFusion(nn.Module):
    """Ensemble of four base forecasters (TFT, TCN, ETS, BiLSTM).

    Fuses the base-model predictions with one of four strategies:

    - ``'stacking'``: feed the concatenated base predictions to an XGBoost
      meta-learner (train it via :meth:`collect_stacking_data` +
      :meth:`train_meta_learner`).
    - ``'voting'``: soft vote, i.e. the element-wise mean of the predictions.
    - ``'weighted_average'``: weighted mean using ``self.weights``.
    - ``'simple_average'``: unweighted mean.

    All strategies return a ``torch.Tensor`` (detached from the base models'
    graphs), so the ensemble itself is not differentiable end-to-end.
    """

    def __init__(self, tft_model, tcn_model, ets_model, bilstm_model,
                 meta_learner_path=None, weights=None, strategy='stacking'):
        """
        Args:
            tft_model, tcn_model, ets_model, bilstm_model: callables mapping
                an input batch ``x`` to a prediction tensor. All four are
                assumed to emit predictions of compatible shape — TODO confirm
                against the trainer.
            meta_learner_path: optional path to a previously saved XGBoost
                meta-learner (loaded via ``load_model``).
            weights: per-model weights, length 4; only used (and only
                validated) when ``strategy == 'weighted_average'``.
            strategy: one of ``'stacking'``, ``'voting'``,
                ``'weighted_average'``, ``'simple_average'``.
        """
        super().__init__()
        self.tft = tft_model
        self.tcn = tcn_model
        self.ets = ets_model
        self.bilstm = bilstm_model
        self.strategy = strategy  # fusion strategy, see class docstring

        # XGBoost meta-learner for the 'stacking' strategy.
        # (Use XGBClassifier instead for classification tasks.)
        self.meta_learner = xgb.XGBRegressor()

        # Load a pre-trained meta-learner if a path is provided.
        if meta_learner_path is not None:
            self.meta_learner.load_model(meta_learner_path)
            print(f"Meta-learner loaded from {meta_learner_path}")

        # Accumulators for meta-learner training data (lists of numpy arrays).
        self.stacking_features = []
        self.stacking_targets = []

        # Weights are only meaningful for 'weighted_average'; default to
        # equal weighting when none are supplied.
        if strategy == 'weighted_average':
            if weights is None:
                self.weights = [1.0, 1.0, 1.0, 1.0]
            else:
                assert len(weights) == 4, "Weights must match the number of models."
                self.weights = weights

    def _base_predictions(self, x):
        """Run all four base models on ``x``; return detached tensors."""
        return [model(x).detach()
                for model in (self.tft, self.tcn, self.ets, self.bilstm)]

    def forward(self, x):
        """Fuse the base-model predictions according to ``self.strategy``.

        Returns:
            torch.Tensor: the ensembled prediction.

        Raises:
            ValueError: if ``self.strategy`` is not a supported strategy.
        """
        preds = self._base_predictions(x)

        if self.strategy == 'stacking':
            # The meta-learner consumes a 2-D numpy feature matrix with one
            # column block per base model.
            features = np.column_stack([p.cpu().numpy() for p in preds])
            ensemble_output = self.meta_learner.predict(features)
            return torch.tensor(ensemble_output).to(x.device).float()

        if self.strategy == 'voting':
            # Soft voting: element-wise mean. (Fixed: the original converted
            # predictions to numpy first, which made torch.stack raise.)
            return torch.mean(torch.stack(preds, dim=0), dim=0)

        if self.strategy == 'weighted_average':
            # Normalized weighted mean of the base predictions.
            weighted = sum(w * p for w, p in zip(self.weights, preds))
            return weighted / sum(self.weights)

        if self.strategy == 'simple_average':
            return sum(preds) / len(preds)

        raise ValueError(f"Invalid strategy: {self.strategy}. Currently supports only 'stacking', 'voting', 'weighted_average', and 'simple_average'.")

    def collect_stacking_data(self, x, y):
        """Collect base model outputs and corresponding targets for meta-learner training."""
        preds = self._base_predictions(x)
        # One feature column block per base model, one row per sample.
        features = np.column_stack([p.cpu().numpy() for p in preds])
        self.stacking_features.append(features)
        self.stacking_targets.append(y.detach().cpu().numpy())

    def train_meta_learner(self, save_path=None):
        """Train the XGBoost meta-learner on collected data and save the model.

        Raises:
            RuntimeError: if no data was collected via
                :meth:`collect_stacking_data` beforehand.
        """
        if not self.stacking_features:
            raise RuntimeError(
                "No stacking data collected; call collect_stacking_data() first."
            )

        # Concatenate all collected mini-batches into one training set.
        X = np.vstack(self.stacking_features)
        y = np.concatenate(self.stacking_targets)

        self.meta_learner.fit(X, y)
        print("Meta-learner trained successfully!")

        # Persist the trained meta-learner if requested.
        if save_path:
            self.meta_learner.save_model(save_path)
            print(f"Meta-learner saved to {save_path}")