import os
import torch
import torchaudio
import torchvision
import numpy as np
import json
from torch.utils.data import Dataset, DataLoader
import sys
from tqdm import tqdm

# Add parent directory to path to import the preprocess functions
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from preprocess import process_audio_data, process_image_data

# Import the WatermelonDataset and WatermelonModelModular from the evaluate_backbones.py file
from evaluate_backbones import WatermelonDataset, WatermelonModelModular, IMAGE_BACKBONES, AUDIO_BACKBONES

# Print library versions
print(f"\033[92mINFO\033[0m: PyTorch version: {torch.__version__}")
print(f"\033[92mINFO\033[0m: Torchaudio version: {torchaudio.__version__}")
print(f"\033[92mINFO\033[0m: Torchvision version: {torchvision.__version__}")

# Device selection
device = torch.device(
    "cuda" if torch.cuda.is_available() 
    else "mps" if torch.backends.mps.is_available() 
    else "cpu"
)
print(f"\033[92mINFO\033[0m: Using device: {device}")

# Define the top-performing models based on the previous evaluation
TOP_MODELS = [
    {"image_backbone": "efficientnet_b3", "audio_backbone": "transformer"},
    {"image_backbone": "efficientnet_b0", "audio_backbone": "transformer"},
    {"image_backbone": "resnet50", "audio_backbone": "transformer"}
]
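# Each entry is passed to WatermelonModelModular(image_backbone, audio_backbone) and also
# determines the expected checkpoint filename <model_dir>/<image_backbone>_<audio_backbone>_model.pt
# (see WatermelonMoEModel below).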

# Define class for the MoE model
class WatermelonMoEModel(torch.nn.Module):
    def __init__(self, model_configs, model_dir="models", weights=None):
        """
        Mixture of Experts model that combines multiple backbone models.
        
        Args:
            model_configs: List of dictionaries with 'image_backbone' and 'audio_backbone' keys
            model_dir: Directory where model checkpoints are stored
            weights: Optional list of weights for each model (None for equal weighting)
        """
        super(WatermelonMoEModel, self).__init__()
        self.models = torch.nn.ModuleList()  # register experts as submodules so .to()/.eval() propagate
        self.model_configs = model_configs
        
        # Load each model
        for config in model_configs:
            img_backbone = config["image_backbone"]
            audio_backbone = config["audio_backbone"]
            
            # Initialize model
            model = WatermelonModelModular(img_backbone, audio_backbone)
            
            # Load weights
            model_path = os.path.join(model_dir, f"{img_backbone}_{audio_backbone}_model.pt")
            if os.path.exists(model_path):
                print(f"\033[92mINFO\033[0m: Loading model {img_backbone}_{audio_backbone} from {model_path}")
                model.load_state_dict(torch.load(model_path, map_location=device))
            else:
                print(f"\033[91mERR!\033[0m: Model checkpoint not found at {model_path}")
                continue
                
            model.to(device)
            model.eval()  # Set to evaluation mode
            self.models.append(model)
        
        # Guard against an empty ensemble (e.g. no checkpoints found on disk)
        assert len(self.models) > 0, "No model checkpoints could be loaded"

        # Set model weights (uniform by default)
        if weights:
            assert len(weights) == len(self.models), "Number of weights must match number of models"
            self.weights = weights
        else:
            self.weights = [1.0 / len(self.models)] * len(self.models)
            
        print(f"\033[92mINFO\033[0m: Loaded {len(self.models)} models for MoE ensemble")
        print(f"\033[92mINFO\033[0m: Model weights: {self.weights}")

    def forward(self, mfcc, image):
        """
        Forward pass through the MoE model.
        Returns the weighted average of all model outputs.
        """
        outputs = []
        
        # Get outputs from each model
        with torch.no_grad():
            for i, model in enumerate(self.models):
                output = model(mfcc, image)
                print(f"DEBUG: Model {i} output: {output}")
                outputs.append(output * self.weights[i])
        
        # Return weighted average
        final_output = torch.sum(torch.stack(outputs), dim=0)
        print(f"DEBUG: Raw prediction: {final_output}")
        return final_output
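
# Illustrative sketch (comments only, not executed): how the ensemble is intended to be
# used and what the weighted average does. `mfcc_batch` and `image_batch` below are
# hypothetical placeholders, not defined in this file.
#
#   moe = WatermelonMoEModel(TOP_MODELS, model_dir="models", weights=[0.5, 0.3, 0.2])
#   prediction = moe(mfcc_batch, image_batch)
#
# With weights [0.5, 0.3, 0.2] and per-expert outputs [10.2, 10.6, 9.8],
# the MoE prediction is 0.5*10.2 + 0.3*10.6 + 0.2*9.8 = 10.24.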


def evaluate_moe_model(data_dir, model_dir="models", weights=None):
    """
    Evaluate the MoE model on the test set.
    """
    # Load dataset
    print(f"\033[92mINFO\033[0m: Loading dataset from {data_dir}")
    dataset = WatermelonDataset(data_dir)
    n_samples = len(dataset)

    # Split dataset
    train_size = int(0.7 * n_samples)
    val_size = int(0.2 * n_samples)
    test_size = n_samples - train_size - val_size
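    # Note: random_split is unseeded, so this test split is drawn fresh each run and will
    # not necessarily match the split used when the individual checkpoints were trained.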

    _, _, test_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size, test_size]
    )

    # Use a reasonable batch size
    batch_size = 8
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Initialize MoE model
    moe_model = WatermelonMoEModel(TOP_MODELS, model_dir, weights)
    moe_model.eval()

    # Evaluation metrics
    mae_criterion = torch.nn.L1Loss()
    mse_criterion = torch.nn.MSELoss()
    
    test_mae = 0.0
    test_mse = 0.0
    
    print(f"\033[92mINFO\033[0m: Evaluating MoE model on {len(test_dataset)} test samples")
    
    # Individual model predictions for analysis
    individual_predictions = {f"{config['image_backbone']}_{config['audio_backbone']}": [] 
                             for config in TOP_MODELS}
    true_labels = []
    moe_predictions = []
    
    # Evaluation loop
    test_iterator = tqdm(test_loader, desc="Testing MoE")
    
    with torch.no_grad():
        for i, (mfcc, image, label) in enumerate(test_iterator):
            try:
                mfcc, image, label = mfcc.to(device), image.to(device), label.to(device)
                
                # Store individual model outputs for analysis
                for j, model in enumerate(moe_model.models):
                    config = TOP_MODELS[j]
                    model_name = f"{config['image_backbone']}_{config['audio_backbone']}"
                    output = model(mfcc, image)
                    individual_predictions[model_name].extend(output.view(-1).cpu().numpy())
                    print(f"DEBUG: Model {j} output: {output}")
                
                # Get MoE prediction
                output = moe_model(mfcc, image)
                moe_predictions.extend(output.view(-1).cpu().numpy())
                print(f"DEBUG: MoE prediction: {output}")
                
                # Store true labels
                label = label.view(-1, 1).float()
                true_labels.extend(label.view(-1).cpu().numpy())
                
                # Calculate metrics
                mae = mae_criterion(output, label)
                mse = mse_criterion(output, label)
                
                test_mae += mae.item()
                test_mse += mse.item()
                
                test_iterator.set_postfix({"MAE": f"{mae.item():.4f}", "MSE": f"{mse.item():.4f}"})
                
                # Clean up memory
                if device.type == 'cuda':
                    del mfcc, image, label, output, mae, mse
                    torch.cuda.empty_cache()
                
            except Exception as e:
                print(f"\033[91mERR!\033[0m: Error in test batch {i}: {e}")
                if device.type == 'cuda':
                    torch.cuda.empty_cache()
                continue

    # Calculate average metrics
    avg_test_mae = test_mae / len(test_loader) if len(test_loader) > 0 else float('inf')
    avg_test_mse = test_mse / len(test_loader) if len(test_loader) > 0 else float('inf')
    
    print(f"\n\033[92mINFO\033[0m: === MoE Model Results ===")
    print(f"Test MAE: {avg_test_mae:.4f}")
    print(f"Test MSE: {avg_test_mse:.4f}")
    
    # Compare with individual models
    print(f"\n\033[92mINFO\033[0m: === Comparison with Individual Models ===")
    print(f"{'Model':<30} {'Test MAE':<15}")
    print("="*45)
    
    # Load previous results
    results_file = "backbone_evaluation_results.json"
    if os.path.exists(results_file):
        with open(results_file, 'r') as f:
            previous_results = json.load(f)
        
        # Filter results for our top models
        for config in TOP_MODELS:
            img_backbone = config["image_backbone"]
            audio_backbone = config["audio_backbone"]
            
            for result in previous_results:
                if result["image_backbone"] == img_backbone and result["audio_backbone"] == audio_backbone:
                    model_name = f"{img_backbone}_{audio_backbone}"
                    print(f"{model_name:<30} {result['test_mae']:<15.4f}")
    
    print(f"{'MoE (Ensemble)':<30} {avg_test_mae:<15.4f}")
    
    # Save results and predictions
    results = {
        "moe_test_mae": float(avg_test_mae),
        "moe_test_mse": float(avg_test_mse),
        "true_labels": [float(x) for x in true_labels],
        "moe_predictions": [float(x) for x in moe_predictions],
        "individual_predictions": {key: [float(x) for x in values] 
                                  for key, values in individual_predictions.items()}
    }
    
    with open("moe_evaluation_results.json", 'w') as f:
        json.dump(results, f, indent=4)
    
    print(f"\033[92mINFO\033[0m: Results saved to moe_evaluation_results.json")
    
    return avg_test_mae, avg_test_mse


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Test Mixture of Experts (MoE) Model for Watermelon Sweetness Prediction")
    parser.add_argument(
        "--data_dir", 
        type=str, 
        default="../cleaned", 
        help="Path to the cleaned dataset directory"
    )
    parser.add_argument(
        "--model_dir", 
        type=str, 
        default="models", 
        help="Directory containing model checkpoints"
    )
    parser.add_argument(
        "--weighting", 
        type=str, 
        choices=["uniform", "performance"], 
        default="uniform", 
        help="How to weight the models (uniform or based on performance)"
    )
    
    args = parser.parse_args()
    
    # Determine weights based on argument
    weights = None
    if args.weighting == "performance":
        # Weights inversely proportional to the MAE (better models get higher weights)
        # These are the MAE values from the provided results
        mae_values = [0.3635, 0.3765, 0.3959]  # efficientnet_b3+transformer, efficientnet_b0+transformer, resnet50+transformer
        
        # Convert to weights (inverse of MAE, normalized)
        inverse_mae = [1/mae for mae in mae_values]
        total = sum(inverse_mae)
        weights = [val/total for val in inverse_mae]
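        # Worked example with the MAE values above:
        #   inverse_mae ≈ [2.751, 2.656, 2.526], total ≈ 7.933
        #   weights     ≈ [0.347, 0.335, 0.318]
        # so the lowest-MAE expert (efficientnet_b3 + transformer) receives the largest weight.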
        
        print(f"\033[92mINFO\033[0m: Using performance-based weights: {weights}")
    else:
        print(f"\033[92mINFO\033[0m: Using uniform weights")
    
    # Evaluate the MoE model
    evaluate_moe_model(args.data_dir, args.model_dir, weights)
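
# Example invocations (assuming this script is saved as test_moe.py; adjust paths as needed):
#   python test_moe.py --data_dir ../cleaned --model_dir models
#   python test_moe.py --weighting performance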