Upload folder using huggingface_hub

Files changed:
- .gitattributes +3 -0
- .nfs00000001a2244b30003726a6 +1 -0
- .nfs00000001a2b1089c003726a7 +1 -0
- __pycache__/evaluate_backbones.cpython-310.pyc +0 -0
- __pycache__/preprocess.cpython-310.pyc +0 -0
- app.py +139 -43
- app_local_backup.py +100 -47
- app_moe.py +439 -0
- backbone_evaluation_results.json +110 -0
- evaluate_backbones.py +670 -0
- models/.nfs00000001a1a17512003726ad +3 -0
- models/.nfs00000001a234d9cd003726ac +3 -0
- models/.nfs00000001a2a11ea9003726ae +3 -0
- models/efficientnet_b0_transformer_model.pt +3 -0
- models/efficientnet_b3_transformer_model.pt +3 -0
- models/resnet50_transformer_model.pt +3 -0
- moe_evaluation_results.json +801 -0
- templates/.nfs00000001a2893bde003726a5 +1 -0
- test_moe_model.py +276 -0
.gitattributes CHANGED
@@ -35,3 +35,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
 temp/temp_audio.wav filter=lfs diff=lfs merge=lfs -text
 temp/temp_image.jpg filter=lfs diff=lfs merge=lfs -text
+models/.nfs00000001a1a17512003726ad filter=lfs diff=lfs merge=lfs -text
+models/.nfs00000001a234d9cd003726ac filter=lfs diff=lfs merge=lfs -text
+models/.nfs00000001a2a11ea9003726ae filter=lfs diff=lfs merge=lfs -text
.nfs00000001a2244b30003726a6 ADDED
@@ -0,0 +1 @@
+
.nfs00000001a2b1089c003726a7 ADDED
@@ -0,0 +1 @@
+
__pycache__/evaluate_backbones.cpython-310.pyc ADDED
Binary file (16.9 kB)

__pycache__/preprocess.cpython-310.pyc ADDED
Binary file (1.27 kB)
app.py CHANGED
@@ -6,21 +6,82 @@ import gradio as gr
 import torchaudio
 import torchvision
 import spaces
-
-# # Import Gradio Spaces GPU decorator
-# try:
-#     from gradio import spaces
-#     HAS_SPACES = True
-#     print("\033[92mINFO\033[0m: Gradio Spaces detected, GPU acceleration will be enabled")
-# except ImportError:
-#     HAS_SPACES = False
-#     print("\033[93mWARN\033[0m: gradio.spaces not available, running without GPU optimization")
+import json

 # Add parent directory to path to import preprocess functions
 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

-# Import functions from infer_watermelon.py and train_watermelon for the model
-from train_watermelon import WatermelonModel
+# Import functions from preprocess and model definitions
+from preprocess import process_image_data
+from evaluate_backbones import WatermelonModelModular, IMAGE_BACKBONES, AUDIO_BACKBONES
+
+# Define the top-performing models based on evaluation
+TOP_MODELS = [
+    {"image_backbone": "efficientnet_b3", "audio_backbone": "transformer"},
+    {"image_backbone": "efficientnet_b0", "audio_backbone": "transformer"},
+    {"image_backbone": "resnet50", "audio_backbone": "transformer"}
+]
+
+# Define the MoE Model
+class WatermelonMoEModel(torch.nn.Module):
+    def __init__(self, model_configs, model_dir="models", weights=None):
+        """
+        Mixture of Experts model that combines multiple backbone models.
+
+        Args:
+            model_configs: List of dictionaries with 'image_backbone' and 'audio_backbone' keys
+            model_dir: Directory where model checkpoints are stored
+            weights: Optional list of weights for each model (None for equal weighting)
+        """
+        super(WatermelonMoEModel, self).__init__()
+        self.models = []
+        self.model_configs = model_configs
+
+        # Load each model
+        for config in model_configs:
+            img_backbone = config["image_backbone"]
+            audio_backbone = config["audio_backbone"]
+
+            # Initialize model
+            model = WatermelonModelModular(img_backbone, audio_backbone)
+
+            # Load weights
+            model_path = os.path.join(model_dir, f"{img_backbone}_{audio_backbone}_model.pt")
+            if os.path.exists(model_path):
+                print(f"\033[92mINFO\033[0m: Loading model {img_backbone}_{audio_backbone} from {model_path}")
+                model.load_state_dict(torch.load(model_path, map_location='cpu'))
+            else:
+                print(f"\033[91mERR!\033[0m: Model checkpoint not found at {model_path}")
+                continue
+
+            model.eval()  # Set to evaluation mode
+            self.models.append(model)
+
+        # Set model weights (uniform by default)
+        if weights:
+            assert len(weights) == len(self.models), "Number of weights must match number of models"
+            self.weights = weights
+        else:
+            self.weights = [1.0 / len(self.models)] * len(self.models)
+
+        print(f"\033[92mINFO\033[0m: Loaded {len(self.models)} models for MoE ensemble")
+        print(f"\033[92mINFO\033[0m: Model weights: {self.weights}")
+
+    def forward(self, mfcc, image):
+        """
+        Forward pass through the MoE model.
+        Returns the weighted average of all model outputs.
+        """
+        outputs = []
+
+        # Get outputs from each model
+        with torch.no_grad():
+            for i, model in enumerate(self.models):
+                output = model(mfcc, image)
+                outputs.append(output * self.weights[i])
+
+        # Return weighted average
+        return torch.sum(torch.stack(outputs), dim=0)

 # Modified version of process_audio_data specifically for the app to handle various tensor shapes
 def app_process_audio_data(waveform, sample_rate):
@@ -76,15 +137,12 @@ def app_process_audio_data(waveform, sample_rate):
         print(traceback.format_exc())
         return None

-# Similarly for images, but let's import the original one
-from preprocess import process_image_data
-
-# Using the decorator directly on the function definition
+# Using the decorator for GPU acceleration
 @spaces.GPU
-def predict_sugar_content(audio, image, model_path):
-    """Function with GPU acceleration to predict watermelon sugar content in Brix"""
+def predict_sugar_content(audio, image, model_dir="models", weights=None):
+    """Function with GPU acceleration to predict watermelon sugar content in Brix using MoE model"""
     try:
-        # Now check CUDA availability inside the GPU-decorated function
+        # Check CUDA availability inside the GPU-decorated function
         if torch.cuda.is_available():
             device = torch.device("cuda")
             print(f"\033[92mINFO\033[0m: CUDA is available. Using device: {device}")
@@ -92,11 +150,11 @@ def predict_sugar_content(audio, image, model_path):
             device = torch.device("cpu")
             print(f"\033[92mINFO\033[0m: CUDA is not available. Using device: {device}")

-        # Load model inside the function to ensure it's on the correct device
-        model = WatermelonModel().to(device)
-        model.load_state_dict(torch.load(model_path, map_location=device))
-        model.eval()
-        print(f"\033[92mINFO\033[0m: Loaded model from {model_path}")
+        # Load MoE model
+        moe_model = WatermelonMoEModel(TOP_MODELS, model_dir, weights)
+        moe_model.to(device)
+        moe_model.eval()
+        print(f"\033[92mINFO\033[0m: Loaded MoE model with {len(moe_model.models)} backbone models")

         # Debug information about input types
         print(f"\033[92mDEBUG\033[0m: Audio input type: {type(audio)}")
@@ -188,11 +246,11 @@ def predict_sugar_content(audio, image, model_path):
             processed_image = processed_image.unsqueeze(0).to(device)
             print(f"\033[92mDEBUG\033[0m: Final image shape with batch dimension: {processed_image.shape}")

-        # Run inference
-        print(f"\033[92mDEBUG\033[0m: Running inference on device: {device}")
+        # Run inference with MoE model
+        print(f"\033[92mDEBUG\033[0m: Running inference with MoE model on device: {device}")
         if mfcc is not None and processed_image is not None:
             with torch.no_grad():
-                brix_value = model(mfcc, processed_image)
+                brix_value = moe_model(mfcc, processed_image)
                 print(f"\033[92mDEBUG\033[0m: Prediction successful: {brix_value.item()}")
         else:
             return "Error: Failed to process inputs. Please check the debug logs."
@@ -204,6 +262,12 @@ def predict_sugar_content(audio, image, model_path):
             # Create a header with the numerical result
             result = f"🍉 Predicted Sugar Content: {brix_score:.1f}° Brix 🍉\n\n"

+            # Add extra info about the MoE model
+            result += "Using Ensemble of Top-3 Models:\n"
+            result += "- EfficientNet-B3 + Transformer\n"
+            result += "- EfficientNet-B0 + Transformer\n"
+            result += "- ResNet-50 + Transformer\n\n"
+
             # Add Brix scale visualization
             result += "Sugar Content Scale (in °Brix):\n"
             result += "──────────────────────────────────\n"
@@ -257,22 +321,27 @@ def predict_sugar_content(audio, image, model_path):
         error_msg += traceback.format_exc()
         print(f"\033[91mERR!\033[0m: {error_msg}")
         return error_msg
-
-print("\033[92mINFO\033[0m: GPU-accelerated prediction function created with @spaces.GPU decorator")
-

-def create_app(model_path):
+def create_app(model_dir="models", weights=None):
     """Create and launch the Gradio interface"""
     # Define the prediction function with model path
     def predict_fn(audio, image):
-        return predict_sugar_content(audio, image, model_path)
+        return predict_sugar_content(audio, image, model_dir, weights)

     # Create Gradio interface
-    with gr.Blocks(title="Watermelon Sugar Content Predictor", theme=gr.themes.Soft()) as interface:
-        gr.Markdown("# 🍉 Watermelon Sugar Content Predictor")
+    with gr.Blocks(title="Watermelon Sugar Content Predictor (MoE)", theme=gr.themes.Soft()) as interface:
+        gr.Markdown("# 🍉 Watermelon Sugar Content Predictor (Ensemble Model)")
         gr.Markdown("""
         This app predicts the sugar content (in °Brix) of a watermelon based on its sound and appearance.

+        ## What's New
+        This version uses a Mixture of Experts (MoE) ensemble model that combines the three best-performing models:
+        - EfficientNet-B3 + Transformer
+        - EfficientNet-B0 + Transformer
+        - ResNet-50 + Transformer
+
+        The ensemble approach provides more accurate predictions than any single model!
+
         ## Instructions:
         1. Upload or record an audio of tapping the watermelon
         2. Upload or capture an image of the watermelon
@@ -286,7 +355,7 @@ def create_app(model_path):
                 submit_btn = gr.Button("Predict Sugar Content", variant="primary")

             with gr.Column():
-                output = gr.Textbox(label="Prediction Results", lines=12)
+                output = gr.Textbox(label="Prediction Results", lines=15)

         submit_btn.click(
             fn=predict_fn,
@@ -302,6 +371,11 @@ def create_app(model_path):
         ## About Brix Measurement
         Brix (°Bx) is a measurement of sugar content in a solution. For watermelons, higher Brix values indicate sweeter fruit.
         The average ripe watermelon has a Brix value between 9-11°.
+
+        ## About the Mixture of Experts Model
+        This app uses a Mixture of Experts (MoE) model that combines predictions from multiple neural networks.
+        Our testing shows the ensemble approach achieves a Mean Absolute Error (MAE) of ~0.22, which is significantly
+        better than any individual model (best individual model: ~0.36 MAE).
         """)

     return interface
@@ -309,12 +383,12 @@ def create_app(model_path):
 if __name__ == "__main__":
     import argparse

-    parser = argparse.ArgumentParser(description="Watermelon Sugar Content Prediction App")
+    parser = argparse.ArgumentParser(description="Watermelon Sugar Content Prediction App (MoE)")
     parser.add_argument(
-        "--model_path",
+        "--model_dir",
         type=str,
-        default="models
-        help="
+        default="models",
+        help="Directory containing the model checkpoints"
     )
     parser.add_argument(
         "--share",
@@ -326,18 +400,40 @@ if __name__ == "__main__":
         action="store_true",
         help="Enable verbose debug output"
     )
+    parser.add_argument(
+        "--weighting",
+        type=str,
+        choices=["uniform", "performance"],
+        default="uniform",
+        help="How to weight the models (uniform or based on performance)"
+    )

     args = parser.parse_args()

     if args.debug:
         print(f"\033[92mINFO\033[0m: Debug mode enabled")

-    # Check if model exists
-    if not os.path.exists(args.model_path):
-        print(f"\033[91mERR!\033[0m: Model not found at {args.model_path}")
-        print("\033[92mINFO\033[0m: Please train a model first or provide a valid model path")
+    # Check if model directory exists
+    if not os.path.exists(args.model_dir):
+        print(f"\033[91mERR!\033[0m: Model directory not found at {args.model_dir}")
         sys.exit(1)

+    # Determine weights based on argument
+    weights = None
+    if args.weighting == "performance":
+        # Weights inversely proportional to the MAE (better models get higher weights)
+        # These are the MAE values from the evaluation results
+        mae_values = [0.3635, 0.3765, 0.3959]  # efficientnet_b3+transformer, efficientnet_b0+transformer, resnet50+transformer
+
+        # Convert to weights (inverse of MAE, normalized)
+        inverse_mae = [1/mae for mae in mae_values]
+        total = sum(inverse_mae)
+        weights = [val/total for val in inverse_mae]
+
+        print(f"\033[92mINFO\033[0m: Using performance-based weights: {weights}")
+    else:
+        print(f"\033[92mINFO\033[0m: Using uniform weights")
+
     # Create and launch the app
-    app = create_app(args.model_path)
+    app = create_app(args.model_dir, weights)
     app.launch(share=args.share)
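Note on the weighting logic added above: with --weighting performance, each expert's prediction is scaled by a weight proportional to the inverse of its test MAE, and the weights are normalized to sum to 1 before the MoE forward pass takes the weighted sum. A minimal standalone sketch of that computation (same MAE values as in the diff; the printed weights are approximate):

# Sketch of the performance-based weighting shown in the diff above.
mae_values = [0.3635, 0.3765, 0.3959]  # efficientnet_b3, efficientnet_b0, resnet50 (each + transformer)

inverse_mae = [1 / mae for mae in mae_values]          # lower MAE -> larger score
weights = [v / sum(inverse_mae) for v in inverse_mae]  # normalize so the weights sum to 1

print(weights)  # roughly [0.347, 0.335, 0.318]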
app_local_backup.py CHANGED
@@ -5,12 +5,22 @@ import numpy as np
 import gradio as gr
 import torchaudio
 import torchvision
+import spaces
+
+# # Import Gradio Spaces GPU decorator
+# try:
+#     from gradio import spaces
+#     HAS_SPACES = True
+#     print("\033[92mINFO\033[0m: Gradio Spaces detected, GPU acceleration will be enabled")
+# except ImportError:
+#     HAS_SPACES = False
+#     print("\033[93mWARN\033[0m: gradio.spaces not available, running without GPU optimization")

 # Add parent directory to path to import preprocess functions
 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

-# Import functions from infer_watermelon.py
-from
+# Import functions from infer_watermelon.py and train_watermelon for the model
+from train_watermelon import WatermelonModel

 # Modified version of process_audio_data specifically for the app to handle various tensor shapes
 def app_process_audio_data(waveform, sample_rate):
@@ -69,14 +79,25 @@ def app_process_audio_data(waveform, sample_rate):
 # Similarly for images, but let's import the original one
 from preprocess import process_image_data

-
-
-
-
-
-def predict_sweetness(audio, image, model, device):
-    """Predict sweetness of a watermelon from audio and image input"""
+# Using the decorator directly on the function definition
+@spaces.GPU
+def predict_sugar_content(audio, image, model_path):
+    """Function with GPU acceleration to predict watermelon sugar content in Brix"""
     try:
+        # Now check CUDA availability inside the GPU-decorated function
+        if torch.cuda.is_available():
+            device = torch.device("cuda")
+            print(f"\033[92mINFO\033[0m: CUDA is available. Using device: {device}")
+        else:
+            device = torch.device("cpu")
+            print(f"\033[92mINFO\033[0m: CUDA is not available. Using device: {device}")
+
+        # Load model inside the function to ensure it's on the correct device
+        model = WatermelonModel().to(device)
+        model.load_state_dict(torch.load(model_path, map_location=device))
+        model.eval()
+        print(f"\033[92mINFO\033[0m: Loaded model from {model_path}")
+
         # Debug information about input types
         print(f"\033[92mDEBUG\033[0m: Audio input type: {type(audio)}")
         print(f"\033[92mDEBUG\033[0m: Audio input shape/length: {len(audio)}")
@@ -97,7 +118,6 @@ def predict_sweetness(audio, image, model, device):
             print(f"\033[92mDEBUG\033[0m: Audio data shape: {audio_data.shape}")
         elif isinstance(audio, str):
             # Direct path to audio file
-            import torchaudio
             audio_data, sample_rate = torchaudio.load(audio)
             print(f"\033[92mDEBUG\033[0m: Loaded audio from path with shape: {audio_data.shape}")
         else:
@@ -111,9 +131,6 @@ def predict_sweetness(audio, image, model, device):
         temp_image_path = os.path.join(temp_dir, "temp_image.jpg")

         # Import necessary libraries
-        import torchaudio
-        import torchvision
-        import torchvision.transforms.functional as F
         from PIL import Image

         # Audio handling - direct processing from the data in memory
@@ -162,7 +179,7 @@ def predict_sweetness(audio, image, model, device):
         processed_image = process_image_data(image_tensor)
         print(f"\033[92mDEBUG\033[0m: Processed image shape: {processed_image.shape if processed_image is not None else None}")

-        # Add batch dimension for inference
+        # Add batch dimension for inference and move to device
         if mfcc is not None:
             mfcc = mfcc.unsqueeze(0).to(device)
             print(f"\033[92mDEBUG\033[0m: Final MFCC shape with batch dimension: {mfcc.shape}")
@@ -172,31 +189,67 @@ def predict_sweetness(audio, image, model, device):
             print(f"\033[92mDEBUG\033[0m: Final image shape with batch dimension: {processed_image.shape}")

         # Run inference
-        print(f"\033[92mDEBUG\033[0m: Running inference")
+        print(f"\033[92mDEBUG\033[0m: Running inference on device: {device}")
         if mfcc is not None and processed_image is not None:
             with torch.no_grad():
-
-                print(f"\033[92mDEBUG\033[0m: Prediction successful: {
+                brix_value = model(mfcc, processed_image)
+                print(f"\033[92mDEBUG\033[0m: Prediction successful: {brix_value.item()}")
         else:
             return "Error: Failed to process inputs. Please check the debug logs."

-        # Format the result
-        if
-
+        # Format the result with a range display
+        if brix_value is not None:
+            brix_score = brix_value.item()
+
+            # Create a header with the numerical result
+            result = f"🍉 Predicted Sugar Content: {brix_score:.1f}° Brix 🍉\n\n"
+
+            # Add Brix scale visualization
+            result += "Sugar Content Scale (in °Brix):\n"
+            result += "──────────────────────────────────\n"
+
+            # Create the scale display with Brix ranges
+            scale_ranges = [
+                (0, 8, "Low Sugar (< 8° Brix)"),
+                (8, 9, "Mild Sweetness (8-9° Brix)"),
+                (9, 10, "Medium Sweetness (9-10° Brix)"),
+                (10, 11, "Sweet (10-11° Brix)"),
+                (11, 13, "Very Sweet (11-13° Brix)")
+            ]

-        #
-
-
-
-
-
-
+            # Find which category the prediction falls into
+            user_category = None
+            for min_val, max_val, category_name in scale_ranges:
+                if min_val <= brix_score < max_val:
+                    user_category = category_name
+                    break
+            if brix_score >= scale_ranges[-1][0]:  # Handle edge case
+                user_category = scale_ranges[-1][2]
+
+            # Display the scale with the user's result highlighted
+            for min_val, max_val, category_name in scale_ranges:
+                if category_name == user_category:
+                    result += f"▶ {min_val}-{max_val}: {category_name} ◀ (YOUR WATERMELON)\n"
+                else:
+                    result += f"  {min_val}-{max_val}: {category_name}\n"
+
+            result += "──────────────────────────────────\n\n"
+
+            # Add assessment of the watermelon's sugar content
+            if brix_score < 8:
+                result += "Assessment: This watermelon has low sugar content. It may taste bland or slightly bitter."
+            elif brix_score < 9:
+                result += "Assessment: This watermelon has mild sweetness. Acceptable flavor but not very sweet."
+            elif brix_score < 10:
+                result += "Assessment: This watermelon has moderate sugar content. It should have pleasant sweetness."
+            elif brix_score < 11:
+                result += "Assessment: This watermelon has good sugar content! It should be sweet and juicy."
             else:
-                result += "
+                result += "Assessment: This watermelon has excellent sugar content! Perfect choice for maximum sweetness and flavor."

             return result
         else:
-            return "Error: Could not predict
+            return "Error: Could not predict sugar content. Please try again with different inputs."

     except Exception as e:
         import traceback
@@ -204,36 +257,36 @@ def predict_sweetness(audio, image, model, device):
         error_msg += traceback.format_exc()
         print(f"\033[91mERR!\033[0m: {error_msg}")
         return error_msg
+
+print("\033[92mINFO\033[0m: GPU-accelerated prediction function created with @spaces.GPU decorator")
+

 def create_app(model_path):
     """Create and launch the Gradio interface"""
-    #
-    model, device = init_model(model_path)
-
-    # Define the prediction function with model and device
+    # Define the prediction function with model path
     def predict_fn(audio, image):
-        return
+        return predict_sugar_content(audio, image, model_path)

     # Create Gradio interface
-    with gr.Blocks(title="Watermelon
-        gr.Markdown("# 🍉 Watermelon
+    with gr.Blocks(title="Watermelon Sugar Content Predictor", theme=gr.themes.Soft()) as interface:
+        gr.Markdown("# 🍉 Watermelon Sugar Content Predictor")
         gr.Markdown("""
-        This app predicts the
+        This app predicts the sugar content (in °Brix) of a watermelon based on its sound and appearance.

         ## Instructions:
         1. Upload or record an audio of tapping the watermelon
         2. Upload or capture an image of the watermelon
-        3. Click '
+        3. Click 'Predict' to get the sugar content estimation
         """)

         with gr.Row():
             with gr.Column():
                 audio_input = gr.Audio(label="Upload or Record Audio", type="numpy")
                 image_input = gr.Image(label="Upload or Capture Image")
-                submit_btn = gr.Button("Predict
+                submit_btn = gr.Button("Predict Sugar Content", variant="primary")

             with gr.Column():
-                output = gr.Textbox(label="Prediction Results", lines=
+                output = gr.Textbox(label="Prediction Results", lines=12)

         submit_btn.click(
             fn=predict_fn,
@@ -242,13 +295,13 @@ def create_app(model_path):
         )

         gr.Markdown("""
-        ##
-
-
-        - Audio analysis using MFCC features and LSTM neural network
-        - Image analysis using ResNet-50 convolutional neural network
+        ## Tips for best results
+        - For audio: Tap the watermelon with your knuckle and record the sound
+        - For image: Take a clear photo of the whole watermelon in good lighting

-
+        ## About Brix Measurement
+        Brix (°Bx) is a measurement of sugar content in a solution. For watermelons, higher Brix values indicate sweeter fruit.
+        The average ripe watermelon has a Brix value between 9-11°.
         """)

     return interface
@@ -256,7 +309,7 @@ def create_app(model_path):
 if __name__ == "__main__":
     import argparse

-    parser = argparse.ArgumentParser(description="Watermelon
+    parser = argparse.ArgumentParser(description="Watermelon Sugar Content Prediction App")
     parser.add_argument(
         "--model_path",
         type=str,
app_moe.py
ADDED
|
@@ -0,0 +1,439 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import torch
|
| 4 |
+
import numpy as np
|
| 5 |
+
import gradio as gr
|
| 6 |
+
import torchaudio
|
| 7 |
+
import torchvision
|
| 8 |
+
import spaces
|
| 9 |
+
import json
|
| 10 |
+
|
| 11 |
+
# Add parent directory to path to import preprocess functions
|
| 12 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 13 |
+
|
| 14 |
+
# Import functions from preprocess and model definitions
|
| 15 |
+
from preprocess import process_image_data
|
| 16 |
+
from evaluate_backbones import WatermelonModelModular, IMAGE_BACKBONES, AUDIO_BACKBONES
|
| 17 |
+
|
| 18 |
+
# Define the top-performing models based on evaluation
|
| 19 |
+
TOP_MODELS = [
|
| 20 |
+
{"image_backbone": "efficientnet_b3", "audio_backbone": "transformer"},
|
| 21 |
+
{"image_backbone": "efficientnet_b0", "audio_backbone": "transformer"},
|
| 22 |
+
{"image_backbone": "resnet50", "audio_backbone": "transformer"}
|
| 23 |
+
]
|
| 24 |
+
|
| 25 |
+
# Define the MoE Model
|
| 26 |
+
class WatermelonMoEModel(torch.nn.Module):
|
| 27 |
+
def __init__(self, model_configs, model_dir="models", weights=None):
|
| 28 |
+
"""
|
| 29 |
+
Mixture of Experts model that combines multiple backbone models.
|
| 30 |
+
|
| 31 |
+
Args:
|
| 32 |
+
model_configs: List of dictionaries with 'image_backbone' and 'audio_backbone' keys
|
| 33 |
+
model_dir: Directory where model checkpoints are stored
|
| 34 |
+
weights: Optional list of weights for each model (None for equal weighting)
|
| 35 |
+
"""
|
| 36 |
+
super(WatermelonMoEModel, self).__init__()
|
| 37 |
+
self.models = []
|
| 38 |
+
self.model_configs = model_configs
|
| 39 |
+
|
| 40 |
+
# Load each model
|
| 41 |
+
for config in model_configs:
|
| 42 |
+
img_backbone = config["image_backbone"]
|
| 43 |
+
audio_backbone = config["audio_backbone"]
|
| 44 |
+
|
| 45 |
+
# Initialize model
|
| 46 |
+
model = WatermelonModelModular(img_backbone, audio_backbone)
|
| 47 |
+
|
| 48 |
+
# Load weights
|
| 49 |
+
model_path = os.path.join(model_dir, f"{img_backbone}_{audio_backbone}_model.pt")
|
| 50 |
+
if os.path.exists(model_path):
|
| 51 |
+
print(f"\033[92mINFO\033[0m: Loading model {img_backbone}_{audio_backbone} from {model_path}")
|
| 52 |
+
model.load_state_dict(torch.load(model_path, map_location='cpu'))
|
| 53 |
+
else:
|
| 54 |
+
print(f"\033[91mERR!\033[0m: Model checkpoint not found at {model_path}")
|
| 55 |
+
continue
|
| 56 |
+
|
| 57 |
+
model.eval() # Set to evaluation mode
|
| 58 |
+
self.models.append(model)
|
| 59 |
+
|
| 60 |
+
# Set model weights (uniform by default)
|
| 61 |
+
if weights:
|
| 62 |
+
assert len(weights) == len(self.models), "Number of weights must match number of models"
|
| 63 |
+
self.weights = weights
|
| 64 |
+
else:
|
| 65 |
+
self.weights = [1.0 / len(self.models)] * len(self.models)
|
| 66 |
+
|
| 67 |
+
print(f"\033[92mINFO\033[0m: Loaded {len(self.models)} models for MoE ensemble")
|
| 68 |
+
print(f"\033[92mINFO\033[0m: Model weights: {self.weights}")
|
| 69 |
+
|
| 70 |
+
def forward(self, mfcc, image):
|
| 71 |
+
"""
|
| 72 |
+
Forward pass through the MoE model.
|
| 73 |
+
Returns the weighted average of all model outputs.
|
| 74 |
+
"""
|
| 75 |
+
outputs = []
|
| 76 |
+
|
| 77 |
+
# Get outputs from each model
|
| 78 |
+
with torch.no_grad():
|
| 79 |
+
for i, model in enumerate(self.models):
|
| 80 |
+
output = model(mfcc, image)
|
| 81 |
+
outputs.append(output * self.weights[i])
|
| 82 |
+
|
| 83 |
+
# Return weighted average
|
| 84 |
+
return torch.sum(torch.stack(outputs), dim=0)
|
| 85 |
+
|
| 86 |
+
# Modified version of process_audio_data specifically for the app to handle various tensor shapes
|
| 87 |
+
def app_process_audio_data(waveform, sample_rate):
|
| 88 |
+
"""Modified version of process_audio_data for the app that handles different tensor dimensions"""
|
| 89 |
+
try:
|
| 90 |
+
print(f"\033[92mDEBUG\033[0m: Processing audio - Initial shape: {waveform.shape}, Sample rate: {sample_rate}")
|
| 91 |
+
|
| 92 |
+
# Handle different tensor dimensions
|
| 93 |
+
if waveform.dim() == 3:
|
| 94 |
+
print(f"\033[92mDEBUG\033[0m: Found 3D tensor, converting to 2D")
|
| 95 |
+
# For 3D tensor, take the first item (batch dimension)
|
| 96 |
+
waveform = waveform[0]
|
| 97 |
+
|
| 98 |
+
if waveform.dim() == 2:
|
| 99 |
+
# Use the first channel for stereo audio
|
| 100 |
+
waveform = waveform[0]
|
| 101 |
+
print(f"\033[92mDEBUG\033[0m: Using first channel, new shape: {waveform.shape}")
|
| 102 |
+
|
| 103 |
+
# Resample to 16kHz if needed
|
| 104 |
+
resample_rate = 16000
|
| 105 |
+
if sample_rate != resample_rate:
|
| 106 |
+
print(f"\033[92mDEBUG\033[0m: Resampling from {sample_rate}Hz to {resample_rate}Hz")
|
| 107 |
+
waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=resample_rate)(waveform)
|
| 108 |
+
|
| 109 |
+
# Ensure 3 seconds of audio
|
| 110 |
+
if waveform.size(0) < 3 * resample_rate:
|
| 111 |
+
print(f"\033[92mDEBUG\033[0m: Padding audio from {waveform.size(0)} to {3 * resample_rate} samples")
|
| 112 |
+
waveform = torch.nn.functional.pad(waveform, (0, 3 * resample_rate - waveform.size(0)))
|
| 113 |
+
else:
|
| 114 |
+
print(f"\033[92mDEBUG\033[0m: Trimming audio from {waveform.size(0)} to {3 * resample_rate} samples")
|
| 115 |
+
waveform = waveform[: 3 * resample_rate]
|
| 116 |
+
|
| 117 |
+
# Apply MFCC transformation
|
| 118 |
+
print(f"\033[92mDEBUG\033[0m: Applying MFCC transformation")
|
| 119 |
+
mfcc_transform = torchaudio.transforms.MFCC(
|
| 120 |
+
sample_rate=resample_rate,
|
| 121 |
+
n_mfcc=13,
|
| 122 |
+
melkwargs={
|
| 123 |
+
"n_fft": 256,
|
| 124 |
+
"win_length": 256,
|
| 125 |
+
"hop_length": 128,
|
| 126 |
+
"n_mels": 40,
|
| 127 |
+
}
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
mfcc = mfcc_transform(waveform)
|
| 131 |
+
print(f"\033[92mDEBUG\033[0m: MFCC output shape: {mfcc.shape}")
|
| 132 |
+
|
| 133 |
+
return mfcc
|
| 134 |
+
except Exception as e:
|
| 135 |
+
import traceback
|
| 136 |
+
print(f"\033[91mERR!\033[0m: Error in audio processing: {e}")
|
| 137 |
+
print(traceback.format_exc())
|
| 138 |
+
return None
|
| 139 |
+
|
| 140 |
+
# Using the decorator for GPU acceleration
|
| 141 |
+
@spaces.GPU
|
| 142 |
+
def predict_sugar_content(audio, image, model_dir="models", weights=None):
|
| 143 |
+
"""Function with GPU acceleration to predict watermelon sugar content in Brix using MoE model"""
|
| 144 |
+
try:
|
| 145 |
+
# Check CUDA availability inside the GPU-decorated function
|
| 146 |
+
if torch.cuda.is_available():
|
| 147 |
+
device = torch.device("cuda")
|
| 148 |
+
print(f"\033[92mINFO\033[0m: CUDA is available. Using device: {device}")
|
| 149 |
+
else:
|
| 150 |
+
device = torch.device("cpu")
|
| 151 |
+
print(f"\033[92mINFO\033[0m: CUDA is not available. Using device: {device}")
|
| 152 |
+
|
| 153 |
+
# Load MoE model
|
| 154 |
+
moe_model = WatermelonMoEModel(TOP_MODELS, model_dir, weights)
|
| 155 |
+
moe_model.to(device)
|
| 156 |
+
moe_model.eval()
|
| 157 |
+
print(f"\033[92mINFO\033[0m: Loaded MoE model with {len(moe_model.models)} backbone models")
|
| 158 |
+
|
| 159 |
+
# Debug information about input types
|
| 160 |
+
print(f"\033[92mDEBUG\033[0m: Audio input type: {type(audio)}")
|
| 161 |
+
print(f"\033[92mDEBUG\033[0m: Audio input shape/length: {len(audio)}")
|
| 162 |
+
print(f"\033[92mDEBUG\033[0m: Image input type: {type(image)}")
|
| 163 |
+
if isinstance(image, np.ndarray):
|
| 164 |
+
print(f"\033[92mDEBUG\033[0m: Image input shape: {image.shape}")
|
| 165 |
+
|
| 166 |
+
# Handle different audio input formats
|
| 167 |
+
if isinstance(audio, tuple) and len(audio) == 2:
|
| 168 |
+
# Standard Gradio format: (sample_rate, audio_data)
|
| 169 |
+
sample_rate, audio_data = audio
|
| 170 |
+
print(f"\033[92mDEBUG\033[0m: Audio sample rate: {sample_rate}")
|
| 171 |
+
print(f"\033[92mDEBUG\033[0m: Audio data shape: {audio_data.shape}")
|
| 172 |
+
elif isinstance(audio, tuple) and len(audio) > 2:
|
| 173 |
+
# Sometimes Gradio returns (sample_rate, audio_data, other_info...)
|
| 174 |
+
sample_rate, audio_data = audio[0], audio[-1]
|
| 175 |
+
print(f"\033[92mDEBUG\033[0m: Audio sample rate: {sample_rate}")
|
| 176 |
+
print(f"\033[92mDEBUG\033[0m: Audio data shape: {audio_data.shape}")
|
| 177 |
+
elif isinstance(audio, str):
|
| 178 |
+
# Direct path to audio file
|
| 179 |
+
audio_data, sample_rate = torchaudio.load(audio)
|
| 180 |
+
print(f"\033[92mDEBUG\033[0m: Loaded audio from path with shape: {audio_data.shape}")
|
| 181 |
+
else:
|
| 182 |
+
return f"Error: Unsupported audio format. Got {type(audio)}"
|
| 183 |
+
|
| 184 |
+
# Create a temporary file path for the audio and image
|
| 185 |
+
temp_dir = "temp"
|
| 186 |
+
os.makedirs(temp_dir, exist_ok=True)
|
| 187 |
+
|
| 188 |
+
temp_audio_path = os.path.join(temp_dir, "temp_audio.wav")
|
| 189 |
+
temp_image_path = os.path.join(temp_dir, "temp_image.jpg")
|
| 190 |
+
|
| 191 |
+
# Import necessary libraries
|
| 192 |
+
from PIL import Image
|
| 193 |
+
|
| 194 |
+
# Audio handling - direct processing from the data in memory
|
| 195 |
+
if isinstance(audio_data, np.ndarray):
|
| 196 |
+
# Convert numpy array to tensor
|
| 197 |
+
print(f"\033[92mDEBUG\033[0m: Converting numpy audio with shape {audio_data.shape} to tensor")
|
| 198 |
+
audio_tensor = torch.tensor(audio_data).float()
|
| 199 |
+
|
| 200 |
+
# Handle different audio dimensions
|
| 201 |
+
if audio_data.ndim == 1:
|
| 202 |
+
# Single channel audio
|
| 203 |
+
audio_tensor = audio_tensor.unsqueeze(0)
|
| 204 |
+
elif audio_data.ndim == 2:
|
| 205 |
+
# Ensure channels are first dimension
|
| 206 |
+
if audio_data.shape[0] > audio_data.shape[1]:
|
| 207 |
+
# More rows than columns, probably (samples, channels)
|
| 208 |
+
audio_tensor = torch.tensor(audio_data.T).float()
|
| 209 |
+
else:
|
| 210 |
+
# Already a tensor
|
| 211 |
+
audio_tensor = audio_data.float()
|
| 212 |
+
|
| 213 |
+
print(f"\033[92mDEBUG\033[0m: Audio tensor shape before processing: {audio_tensor.shape}")
|
| 214 |
+
|
| 215 |
+
# Skip saving/loading and process directly
|
| 216 |
+
mfcc = app_process_audio_data(audio_tensor, sample_rate)
|
| 217 |
+
print(f"\033[92mDEBUG\033[0m: MFCC tensor shape after processing: {mfcc.shape if mfcc is not None else None}")
|
| 218 |
+
|
| 219 |
+
# Image handling
|
| 220 |
+
if isinstance(image, np.ndarray):
|
| 221 |
+
print(f"\033[92mDEBUG\033[0m: Converting numpy image with shape {image.shape} to PIL")
|
| 222 |
+
pil_image = Image.fromarray(image)
|
| 223 |
+
pil_image.save(temp_image_path)
|
| 224 |
+
print(f"\033[92mDEBUG\033[0m: Saved image to {temp_image_path}")
|
| 225 |
+
elif isinstance(image, str):
|
| 226 |
+
# If image is already a path
|
| 227 |
+
temp_image_path = image
|
| 228 |
+
print(f"\033[92mDEBUG\033[0m: Using provided image path: {temp_image_path}")
|
| 229 |
+
else:
|
| 230 |
+
return f"Error: Unsupported image format. Got {type(image)}"
|
| 231 |
+
|
| 232 |
+
# Process image
|
| 233 |
+
print(f"\033[92mDEBUG\033[0m: Loading and preprocessing image from {temp_image_path}")
|
| 234 |
+
image_tensor = torchvision.io.read_image(temp_image_path)
|
| 235 |
+
print(f"\033[92mDEBUG\033[0m: Loaded image shape: {image_tensor.shape}")
|
| 236 |
+
image_tensor = image_tensor.float()
|
| 237 |
+
processed_image = process_image_data(image_tensor)
|
| 238 |
+
print(f"\033[92mDEBUG\033[0m: Processed image shape: {processed_image.shape if processed_image is not None else None}")
|
| 239 |
+
|
| 240 |
+
# Add batch dimension for inference and move to device
|
| 241 |
+
if mfcc is not None:
|
| 242 |
+
mfcc = mfcc.unsqueeze(0).to(device)
|
| 243 |
+
print(f"\033[92mDEBUG\033[0m: Final MFCC shape with batch dimension: {mfcc.shape}")
|
| 244 |
+
|
| 245 |
+
if processed_image is not None:
|
| 246 |
+
processed_image = processed_image.unsqueeze(0).to(device)
|
| 247 |
+
print(f"\033[92mDEBUG\033[0m: Final image shape with batch dimension: {processed_image.shape}")
|
| 248 |
+
|
| 249 |
+
# Run inference with MoE model
|
| 250 |
+
print(f"\033[92mDEBUG\033[0m: Running inference with MoE model on device: {device}")
|
| 251 |
+
if mfcc is not None and processed_image is not None:
|
| 252 |
+
with torch.no_grad():
|
| 253 |
+
brix_value = moe_model(mfcc, processed_image)
|
| 254 |
+
print(f"\033[92mDEBUG\033[0m: Prediction successful: {brix_value.item()}")
|
| 255 |
+
else:
|
| 256 |
+
return "Error: Failed to process inputs. Please check the debug logs."
|
| 257 |
+
|
| 258 |
+
# Format the result with a range display
|
| 259 |
+
if brix_value is not None:
|
| 260 |
+
brix_score = brix_value.item()
|
| 261 |
+
|
| 262 |
+
# Create a header with the numerical result
|
| 263 |
+
result = f"🍉 Predicted Sugar Content: {brix_score:.1f}° Brix 🍉\n\n"
|
| 264 |
+
|
| 265 |
+
# Add extra info about the MoE model
|
| 266 |
+
result += "Using Ensemble of Top-3 Models:\n"
|
| 267 |
+
result += "- EfficientNet-B3 + Transformer\n"
|
| 268 |
+
result += "- EfficientNet-B0 + Transformer\n"
|
| 269 |
+
result += "- ResNet-50 + Transformer\n\n"
|
| 270 |
+
|
| 271 |
+
# Add Brix scale visualization
|
| 272 |
+
result += "Sugar Content Scale (in °Brix):\n"
|
| 273 |
+
result += "──────────────────────────────────\n"
|
| 274 |
+
|
| 275 |
+
# Create the scale display with Brix ranges
|
| 276 |
+
scale_ranges = [
|
| 277 |
+
(0, 8, "Low Sugar (< 8° Brix)"),
|
| 278 |
+
(8, 9, "Mild Sweetness (8-9° Brix)"),
|
| 279 |
+
(9, 10, "Medium Sweetness (9-10° Brix)"),
|
| 280 |
+
(10, 11, "Sweet (10-11° Brix)"),
|
| 281 |
+
(11, 13, "Very Sweet (11-13° Brix)")
|
| 282 |
+
]
|
| 283 |
+
|
| 284 |
+
# Find which category the prediction falls into
|
| 285 |
+
user_category = None
|
| 286 |
+
for min_val, max_val, category_name in scale_ranges:
|
| 287 |
+
if min_val <= brix_score < max_val:
|
| 288 |
+
user_category = category_name
|
| 289 |
+
break
|
| 290 |
+
if brix_score >= scale_ranges[-1][0]: # Handle edge case
|
| 291 |
+
user_category = scale_ranges[-1][2]
|
| 292 |
+
|
| 293 |
+
# Display the scale with the user's result highlighted
|
| 294 |
+
for min_val, max_val, category_name in scale_ranges:
|
| 295 |
+
if category_name == user_category:
|
| 296 |
+
result += f"▶ {min_val}-{max_val}: {category_name} ◀ (YOUR WATERMELON)\n"
|
| 297 |
+
else:
|
| 298 |
+
result += f" {min_val}-{max_val}: {category_name}\n"
|
| 299 |
+
|
| 300 |
+
result += "──────────────────────────────────\n\n"
|
| 301 |
+
|
| 302 |
+
# Add assessment of the watermelon's sugar content
|
| 303 |
+
if brix_score < 8:
|
| 304 |
+
result += "Assessment: This watermelon has low sugar content. It may taste bland or slightly bitter."
|
| 305 |
+
elif brix_score < 9:
|
| 306 |
+
result += "Assessment: This watermelon has mild sweetness. Acceptable flavor but not very sweet."
|
| 307 |
+
elif brix_score < 10:
|
| 308 |
+
result += "Assessment: This watermelon has moderate sugar content. It should have pleasant sweetness."
|
| 309 |
+
elif brix_score < 11:
|
| 310 |
+
result += "Assessment: This watermelon has good sugar content! It should be sweet and juicy."
|
| 311 |
+
else:
|
| 312 |
+
result += "Assessment: This watermelon has excellent sugar content! Perfect choice for maximum sweetness and flavor."
|
| 313 |
+
|
| 314 |
+
return result
|
| 315 |
+
else:
|
| 316 |
+
return "Error: Could not predict sugar content. Please try again with different inputs."
|
| 317 |
+
|
| 318 |
+
except Exception as e:
|
| 319 |
+
import traceback
|
| 320 |
+
error_msg = f"Error: {str(e)}\n\n"
|
| 321 |
+
error_msg += traceback.format_exc()
|
| 322 |
+
print(f"\033[91mERR!\033[0m: {error_msg}")
|
| 323 |
+
return error_msg
|
| 324 |
+
|
| 325 |
+
def create_app(model_dir="models", weights=None):
|
| 326 |
+
"""Create and launch the Gradio interface"""
|
| 327 |
+
# Define the prediction function with model path
|
| 328 |
+
def predict_fn(audio, image):
|
| 329 |
+
return predict_sugar_content(audio, image, model_dir, weights)
|
| 330 |
+
|
| 331 |
+
# Create Gradio interface
|
| 332 |
+
with gr.Blocks(title="Watermelon Sugar Content Predictor (MoE)", theme=gr.themes.Soft()) as interface:
|
| 333 |
+
gr.Markdown("# 🍉 Watermelon Sugar Content Predictor (Ensemble Model)")
|
| 334 |
+
gr.Markdown("""
|
| 335 |
+
This app predicts the sugar content (in °Brix) of a watermelon based on its sound and appearance.
|
| 336 |
+
|
| 337 |
+
## What's New
|
| 338 |
+
This version uses a Mixture of Experts (MoE) ensemble model that combines the three best-performing models:
|
| 339 |
+
- EfficientNet-B3 + Transformer
|
| 340 |
+
- EfficientNet-B0 + Transformer
|
| 341 |
+
- ResNet-50 + Transformer
|
| 342 |
+
|
| 343 |
+
The ensemble approach provides more accurate predictions than any single model!
|
| 344 |
+
|
| 345 |
+
## Instructions:
|
| 346 |
+
                1. Upload or record an audio clip of tapping the watermelon
                2. Upload or capture an image of the watermelon
                3. Click 'Predict' to get the sugar content estimation
                """)

        with gr.Row():
            with gr.Column():
                audio_input = gr.Audio(label="Upload or Record Audio", type="numpy")
                image_input = gr.Image(label="Upload or Capture Image")
                submit_btn = gr.Button("Predict Sugar Content", variant="primary")

            with gr.Column():
                output = gr.Textbox(label="Prediction Results", lines=15)

        submit_btn.click(
            fn=predict_fn,
            inputs=[audio_input, image_input],
            outputs=output
        )

        gr.Markdown("""
        ## Tips for best results
        - For audio: Tap the watermelon with your knuckle and record the sound
        - For image: Take a clear photo of the whole watermelon in good lighting

        ## About Brix Measurement
        Brix (°Bx) is a measurement of sugar content in a solution. For watermelons, higher Brix values indicate sweeter fruit.
        The average ripe watermelon has a Brix value between 9° and 11°.

        ## About the Mixture of Experts Model
        This app uses a Mixture of Experts (MoE) model that combines predictions from multiple neural networks.
        Our testing shows the ensemble approach achieves a Mean Absolute Error (MAE) of ~0.22, which is significantly
        better than any individual model (best individual model: ~0.36 MAE).
        """)

    return interface


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Watermelon Sugar Content Prediction App (MoE)")
    parser.add_argument(
        "--model_dir",
        type=str,
        default="models",
        help="Directory containing the model checkpoints"
    )
    parser.add_argument(
        "--share",
        action="store_true",
        help="Create a shareable link for the app"
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="Enable verbose debug output"
    )
    parser.add_argument(
        "--weighting",
        type=str,
        choices=["uniform", "performance"],
        default="uniform",
        help="How to weight the models (uniform or based on performance)"
    )

    args = parser.parse_args()

    if args.debug:
        print(f"\033[92mINFO\033[0m: Debug mode enabled")

    # Check if model directory exists
    if not os.path.exists(args.model_dir):
        print(f"\033[91mERR!\033[0m: Model directory not found at {args.model_dir}")
        sys.exit(1)

    # Determine weights based on argument
    weights = None
    if args.weighting == "performance":
        # Weights inversely proportional to the MAE (better models get higher weights)
        # These are the MAE values from the evaluation results
        mae_values = [0.3635, 0.3765, 0.3959]  # efficientnet_b3+transformer, efficientnet_b0+transformer, resnet50+transformer

        # Convert to weights (inverse of MAE, normalized)
        inverse_mae = [1/mae for mae in mae_values]
        total = sum(inverse_mae)
        weights = [val/total for val in inverse_mae]

        print(f"\033[92mINFO\033[0m: Using performance-based weights: {weights}")
    else:
        print(f"\033[92mINFO\033[0m: Using uniform weights")

    # Create and launch the app
    app = create_app(args.model_dir, weights)
    app.launch(share=args.share)
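As a quick sanity check of the performance-based weighting above, the short sketch below (illustrative only, not part of the uploaded app) reproduces the inverse-MAE normalization for the three hard-coded expert MAEs:

# Illustrative only: recompute the performance-based expert weights used in app_moe.py.
mae_values = [0.3635, 0.3765, 0.3959]          # efficientnet_b3, efficientnet_b0, resnet50 (each + transformer)
inverse_mae = [1 / mae for mae in mae_values]  # lower MAE -> larger raw score
total = sum(inverse_mae)
weights = [val / total for val in inverse_mae]
print(weights)  # roughly [0.347, 0.335, 0.318]; the weights sum to 1 and favour the strongest expert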
backbone_evaluation_results.json
ADDED
@@ -0,0 +1,110 @@
[
    {
        "image_backbone": "efficientnet_b3",
        "audio_backbone": "transformer",
        "validation_mse": 0.21577325425086877,
        "validation_mae": 0.36228722945237773,
        "test_mse": 0.21746371760964395,
        "test_mae": 0.36353210285305976,
        "model_path": "test_models/efficientnet_b3_transformer_model.pt"
    },
    {
        "image_backbone": "efficientnet_b0",
        "audio_backbone": "transformer",
        "validation_mse": 0.24033201676912797,
        "validation_mae": 0.42209602166444826,
        "test_mse": 0.19470563121140003,
        "test_mae": 0.37649240642786025,
        "model_path": "test_models/efficientnet_b0_transformer_model.pt"
    },
    {
        "image_backbone": "resnet50",
        "audio_backbone": "transformer",
        "validation_mse": 0.22672857019381645,
        "validation_mae": 0.3926378931754675,
        "test_mse": 0.22427306957542897,
        "test_mae": 0.39585837423801423,
        "model_path": "test_models/resnet50_transformer_model.pt"
    },
    {
        "image_backbone": "resnet50",
        "audio_backbone": "bidirectional_lstm",
        "validation_mse": 0.2967155438203078,
        "validation_mae": 0.3850937023376807,
        "test_mse": 0.36476454623043536,
        "test_mae": 0.425818096101284,
        "model_path": "test_models/resnet50_bidirectional_lstm_model.pt"
    },
    {
        "image_backbone": "efficientnet_b0",
        "audio_backbone": "bidirectional_lstm",
        "validation_mse": 0.5120524473679371,
        "validation_mae": 0.5665570046657171,
        "test_mse": 0.5059382550418376,
        "test_mae": 0.555050653219223,
        "model_path": "test_models/efficientnet_b0_bidirectional_lstm_model.pt"
    },
    {
        "image_backbone": "efficientnet_b3",
        "audio_backbone": "bidirectional_lstm",
        "validation_mse": 0.8020018790012751,
        "validation_mae": 0.7953977386156718,
        "test_mse": 0.7042828559875488,
        "test_mae": 0.7441241115331649,
        "model_path": "test_models/efficientnet_b3_bidirectional_lstm_model.pt"
    },
    {
        "image_backbone": "efficientnet_b0",
        "audio_backbone": "gru",
        "validation_mse": 1.1340507984161377,
        "validation_mae": 0.8290961503982544,
        "test_mse": 0.9705999374389649,
        "test_mae": 0.7704607486724854,
        "model_path": "test_models/efficientnet_b0_gru_model.pt"
    },
    {
        "image_backbone": "efficientnet_b0",
        "audio_backbone": "lstm",
        "validation_mse": 2.787272185087204,
        "validation_mae": 1.5404645502567291,
        "test_mse": 2.901867628097534,
        "test_mae": 1.5843785762786866,
        "model_path": "test_models/efficientnet_b0_lstm_model.pt"
    },
    {
        "image_backbone": "resnet50",
        "audio_backbone": "gru",
        "validation_mse": 3.9335442543029786,
        "validation_mae": 1.8762320041656495,
        "test_mse": 3.72695152759552,
        "test_mae": 1.8381730556488036,
        "model_path": "test_models/resnet50_gru_model.pt"
    },
    {
        "image_backbone": "resnet50",
        "audio_backbone": "lstm",
        "validation_mse": 6.088638782501221,
        "validation_mae": 2.3887929677963258,
        "test_mse": 6.1847597599029545,
        "test_mae": 2.418113374710083,
        "model_path": "test_models/resnet50_lstm_model.pt"
    },
    {
        "image_backbone": "efficientnet_b3",
        "audio_backbone": "gru",
        "validation_mse": 104.58460273742676,
        "validation_mae": 10.183499813079834,
        "test_mse": 104.58482055664062,
        "test_mae": 10.180697345733643,
        "model_path": "test_models/efficientnet_b3_gru_model.pt"
    },
    {
        "image_backbone": "efficientnet_b3",
        "audio_backbone": "lstm",
        "validation_mse": 105.40057525634765,
        "validation_mae": 10.221695899963379,
        "test_mse": 105.17274551391601,
        "test_mae": 10.21053056716919,
        "model_path": "test_models/efficientnet_b3_lstm_model.pt"
    }
]
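This JSON is the artifact the MoE app and the evaluation script read back in. A minimal sketch of how it can be consumed to pick the best experts (illustrative only, not part of the upload; the top three entries happen to be the transformer-based checkpoints shipped under models/):

import json

# Illustrative only: load the backbone sweep results and list the best combinations by test MAE.
with open("backbone_evaluation_results.json") as f:
    results = json.load(f)

results.sort(key=lambda r: r["test_mae"])
for r in results[:3]:
    print(f"{r['image_backbone']} + {r['audio_backbone']}: test MAE {r['test_mae']:.4f}")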
evaluate_backbones.py
ADDED
@@ -0,0 +1,670 @@
import os
import torch
import torchaudio
import torchvision
import numpy as np
import time
import json
from torch.utils.data import Dataset, DataLoader
import sys
from tqdm import tqdm

# Add parent directory to path to import the preprocess functions
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from preprocess import process_audio_data, process_image_data

# Print library versions
print(f"\033[92mINFO\033[0m: PyTorch version: {torch.__version__}")
print(f"\033[92mINFO\033[0m: Torchaudio version: {torchaudio.__version__}")
print(f"\033[92mINFO\033[0m: Torchvision version: {torchvision.__version__}")

# Device selection
device = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)
print(f"\033[92mINFO\033[0m: Using device: {device}")

# Hyperparameters
batch_size = 16
epochs = 1  # Just one epoch for evaluation
learning_rate = 0.0001


class WatermelonDataset(Dataset):
    def __init__(self, data_dir):
        self.data_dir = data_dir
        self.samples = []

        # Walk through the directory structure
        for sweetness_dir in os.listdir(data_dir):
            sweetness = float(sweetness_dir)
            sweetness_path = os.path.join(data_dir, sweetness_dir)

            if os.path.isdir(sweetness_path):
                for id_dir in os.listdir(sweetness_path):
                    id_path = os.path.join(sweetness_path, id_dir)

                    if os.path.isdir(id_path):
                        audio_file = os.path.join(id_path, f"{id_dir}.wav")
                        image_file = os.path.join(id_path, f"{id_dir}.jpg")

                        if os.path.exists(audio_file) and os.path.exists(image_file):
                            self.samples.append((audio_file, image_file, sweetness))

        print(f"\033[92mINFO\033[0m: Loaded {len(self.samples)} samples from {data_dir}")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        audio_path, image_path, label = self.samples[idx]

        # Load and process audio
        try:
            waveform, sample_rate = torchaudio.load(audio_path)
            mfcc = process_audio_data(waveform, sample_rate)

            # Load and process image
            image = torchvision.io.read_image(image_path)
            image = image.float()
            processed_image = process_image_data(image)

            return mfcc, processed_image, torch.tensor(label).float()
        except Exception as e:
            print(f"\033[91mERR!\033[0m: Error processing sample {idx}: {e}")
            # Return a fallback sample or skip this sample
            # For simplicity, we'll return the first sample again
            if idx == 0:  # Prevent infinite recursion
                raise e
            return self.__getitem__(0)


# Define available backbone models
IMAGE_BACKBONES = {
    "resnet50": {
        "model": torchvision.models.resnet50,
        "weights": torchvision.models.ResNet50_Weights.DEFAULT,
        "output_dim": lambda model: model.fc.in_features
    },
    "efficientnet_b0": {
        "model": torchvision.models.efficientnet_b0,
        "weights": torchvision.models.EfficientNet_B0_Weights.DEFAULT,
        "output_dim": lambda model: model.classifier[1].in_features
    },
    "efficientnet_b3": {
        "model": torchvision.models.efficientnet_b3,
        "weights": torchvision.models.EfficientNet_B3_Weights.DEFAULT,
        "output_dim": lambda model: model.classifier[1].in_features
    }
}

AUDIO_BACKBONES = {
    "lstm": {
        "model": lambda input_size, hidden_size: torch.nn.LSTM(
            input_size=input_size, hidden_size=hidden_size, num_layers=2, batch_first=True
        ),
        "output_dim": lambda hidden_size: hidden_size
    },
    "gru": {
        "model": lambda input_size, hidden_size: torch.nn.GRU(
            input_size=input_size, hidden_size=hidden_size, num_layers=2, batch_first=True
        ),
        "output_dim": lambda hidden_size: hidden_size
    },
    "bidirectional_lstm": {
        "model": lambda input_size, hidden_size: torch.nn.LSTM(
            input_size=input_size, hidden_size=hidden_size, num_layers=2, batch_first=True, bidirectional=True
        ),
        "output_dim": lambda hidden_size: hidden_size * 2  # * 2 because bidirectional
    },
    "transformer": {
        "model": lambda input_size, hidden_size: torch.nn.TransformerEncoder(
            torch.nn.TransformerEncoderLayer(
                d_model=input_size, nhead=8, dim_feedforward=hidden_size, batch_first=True
            ),
            num_layers=2
        ),
        "output_dim": lambda hidden_size: 376  # Using input_size (mfcc dimensions)
    }
}


class WatermelonModelModular(torch.nn.Module):
    def __init__(self, image_backbone_name, audio_backbone_name, audio_hidden_size=128):
        super(WatermelonModelModular, self).__init__()

        # Audio backbone setup
        self.audio_backbone_name = audio_backbone_name
        self.audio_hidden_size = audio_hidden_size
        self.audio_input_size = 376  # From MFCC dimensions

        audio_config = AUDIO_BACKBONES[audio_backbone_name]
        self.audio_backbone = audio_config["model"](self.audio_input_size, self.audio_hidden_size)
        audio_output_dim = audio_config["output_dim"](self.audio_hidden_size)

        self.audio_fc = torch.nn.Linear(audio_output_dim, 128)

        # Image backbone setup
        self.image_backbone_name = image_backbone_name
        image_config = IMAGE_BACKBONES[image_backbone_name]

        self.image_backbone = image_config["model"](weights=image_config["weights"])

        # Replace final layer for all image backbones to get features
        if image_backbone_name.startswith("resnet"):
            self.image_output_dim = image_config["output_dim"](self.image_backbone)
            self.image_backbone.fc = torch.nn.Identity()
        elif image_backbone_name.startswith("efficientnet"):
            self.image_output_dim = image_config["output_dim"](self.image_backbone)
            self.image_backbone.classifier = torch.nn.Identity()
        elif image_backbone_name.startswith("convnext"):
            self.image_output_dim = image_config["output_dim"](self.image_backbone)
            self.image_backbone.classifier = torch.nn.Identity()
        elif image_backbone_name.startswith("swin"):
            self.image_output_dim = image_config["output_dim"](self.image_backbone)
            self.image_backbone.head = torch.nn.Identity()

        self.image_fc = torch.nn.Linear(self.image_output_dim, 128)

        # Fully connected layers for final prediction
        self.fc1 = torch.nn.Linear(256, 64)
        self.fc2 = torch.nn.Linear(64, 1)
        self.relu = torch.nn.ReLU()

    def forward(self, mfcc, image):
        # Audio backbone processing
        if self.audio_backbone_name == "lstm" or self.audio_backbone_name == "gru":
            audio_output, _ = self.audio_backbone(mfcc)
            audio_output = audio_output[:, -1, :]  # Use the output of the last time step
        elif self.audio_backbone_name == "bidirectional_lstm":
            audio_output, _ = self.audio_backbone(mfcc)
            audio_output = audio_output[:, -1, :]  # Use the output of the last time step
        elif self.audio_backbone_name == "transformer":
            audio_output = self.audio_backbone(mfcc)
            audio_output = audio_output.mean(dim=1)  # Average pooling over sequence length

        audio_output = self.audio_fc(audio_output)

        # Image backbone processing
        image_output = self.image_backbone(image)
        image_output = self.image_fc(image_output)

        # Concatenate audio and image outputs
        merged = torch.cat((audio_output, image_output), dim=1)

        # Fully connected layers
        output = self.relu(self.fc1(merged))
        output = self.fc2(output)

        return output


def evaluate_model(data_dir, image_backbone, audio_backbone, audio_hidden_size=128, save_model_dir=None):
    # Adjust batch size based on model complexity to avoid OOM errors
    adjusted_batch_size = batch_size

    # Models that typically require more memory get smaller batch sizes
    if image_backbone in ["swin_b", "convnext_base"] or audio_backbone in ["transformer", "bidirectional_lstm"]:
        adjusted_batch_size = max(4, batch_size // 2)  # At least batch size of 4, but reduce by half if needed
        print(f"\033[92mINFO\033[0m: Adjusted batch size to {adjusted_batch_size} for larger model")

    # Create dataset
    dataset = WatermelonDataset(data_dir)
    n_samples = len(dataset)

    # Split dataset
    train_size = int(0.7 * n_samples)
    val_size = int(0.2 * n_samples)
    test_size = n_samples - train_size - val_size

    train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size, test_size]
    )

    train_loader = DataLoader(train_dataset, batch_size=adjusted_batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=adjusted_batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=adjusted_batch_size, shuffle=False)

    # Initialize model
    model = WatermelonModelModular(image_backbone, audio_backbone, audio_hidden_size).to(device)

    # Loss function and optimizer
    criterion = torch.nn.MSELoss()
    mae_criterion = torch.nn.L1Loss()  # For MAE evaluation
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    print(f"\033[92mINFO\033[0m: Evaluating model with {image_backbone} (image) and {audio_backbone} (audio)")
    print(f"\033[92mINFO\033[0m: Training samples: {len(train_dataset)}")
    print(f"\033[92mINFO\033[0m: Validation samples: {len(val_dataset)}")
    print(f"\033[92mINFO\033[0m: Test samples: {len(test_dataset)}")
    print(f"\033[92mINFO\033[0m: Batch size: {adjusted_batch_size}")

    # Training loop
    print(f"\033[92mINFO\033[0m: Training for evaluation...")
    model.train()
    running_loss = 0.0

    # Wrap with tqdm for progress visualization
    train_iterator = tqdm(train_loader, desc="Training")

    for i, (mfcc, image, label) in enumerate(train_iterator):
        try:
            mfcc, image, label = mfcc.to(device), image.to(device), label.to(device)

            optimizer.zero_grad()
            output = model(mfcc, image)
            label = label.view(-1, 1).float()
            loss = criterion(output, label)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            train_iterator.set_postfix({"Loss": f"{loss.item():.4f}"})

            # Clear memory after each batch
            if device.type == 'cuda':
                del mfcc, image, label, output, loss
                torch.cuda.empty_cache()

        except Exception as e:
            print(f"\033[91mERR!\033[0m: Error in training batch {i}: {e}")
            # Clear memory in case of error
            if device.type == 'cuda':
                torch.cuda.empty_cache()
            continue

    # Validation phase
    print(f"\033[92mINFO\033[0m: Validating...")
    model.eval()
    val_loss = 0.0
    val_mae = 0.0

    val_iterator = tqdm(val_loader, desc="Validation")

    with torch.no_grad():
        for i, (mfcc, image, label) in enumerate(val_iterator):
            try:
                mfcc, image, label = mfcc.to(device), image.to(device), label.to(device)
                output = model(mfcc, image)
                label = label.view(-1, 1).float()

                # Calculate MSE loss
                loss = criterion(output, label)
                val_loss += loss.item()

                # Calculate MAE
                mae = mae_criterion(output, label)
                val_mae += mae.item()

                val_iterator.set_postfix({"MSE": f"{loss.item():.4f}", "MAE": f"{mae.item():.4f}"})

                # Clear memory after each batch
                if device.type == 'cuda':
                    del mfcc, image, label, output, loss, mae
                    torch.cuda.empty_cache()

            except Exception as e:
                print(f"\033[91mERR!\033[0m: Error in validation batch {i}: {e}")
                # Clear memory in case of error
                if device.type == 'cuda':
                    torch.cuda.empty_cache()
                continue

    avg_val_loss = val_loss / len(val_loader) if len(val_loader) > 0 else float('inf')
    avg_val_mae = val_mae / len(val_loader) if len(val_loader) > 0 else float('inf')

    # Test phase
    print(f"\033[92mINFO\033[0m: Testing...")
    model.eval()
    test_loss = 0.0
    test_mae = 0.0

    test_iterator = tqdm(test_loader, desc="Testing")

    with torch.no_grad():
        for i, (mfcc, image, label) in enumerate(test_iterator):
            try:
                mfcc, image, label = mfcc.to(device), image.to(device), label.to(device)
                output = model(mfcc, image)
                label = label.view(-1, 1).float()

                # Calculate MSE loss
                loss = criterion(output, label)
                test_loss += loss.item()

                # Calculate MAE
                mae = mae_criterion(output, label)
                test_mae += mae.item()

                test_iterator.set_postfix({"MSE": f"{loss.item():.4f}", "MAE": f"{mae.item():.4f}"})

                # Clear memory after each batch
                if device.type == 'cuda':
                    del mfcc, image, label, output, loss, mae
                    torch.cuda.empty_cache()

            except Exception as e:
                print(f"\033[91mERR!\033[0m: Error in test batch {i}: {e}")
                # Clear memory in case of error
                if device.type == 'cuda':
                    torch.cuda.empty_cache()
                continue

    avg_test_loss = test_loss / len(test_loader) if len(test_loader) > 0 else float('inf')
    avg_test_mae = test_mae / len(test_loader) if len(test_loader) > 0 else float('inf')

    results = {
        "image_backbone": image_backbone,
        "audio_backbone": audio_backbone,
        "validation_mse": avg_val_loss,
        "validation_mae": avg_val_mae,
        "test_mse": avg_test_loss,
        "test_mae": avg_test_mae
    }

    print(f"\033[92mINFO\033[0m: Evaluation Results:")
    print(f"Image Backbone: {image_backbone}")
    print(f"Audio Backbone: {audio_backbone}")
    print(f"Validation MSE: {avg_val_loss:.4f}")
    print(f"Validation MAE: {avg_val_mae:.4f}")
    print(f"Test MSE: {avg_test_loss:.4f}")
    print(f"Test MAE: {avg_test_mae:.4f}")

    # Save model if save_model_dir is provided
    if save_model_dir:
        os.makedirs(save_model_dir, exist_ok=True)
        model_filename = f"{image_backbone}_{audio_backbone}_model.pt"
        model_path = os.path.join(save_model_dir, model_filename)
        torch.save(model.state_dict(), model_path)
        print(f"\033[92mINFO\033[0m: Model saved to {model_path}")

        # Add model path to results
        results["model_path"] = model_path

    # Clean up memory before returning
    if device.type == 'cuda':
        del model, optimizer, criterion, mae_criterion
        torch.cuda.empty_cache()

    return results


def evaluate_all_combinations(data_dir, image_backbones=None, audio_backbones=None, save_model_dir="test_models", results_file="backbone_evaluation_results.json"):
    if image_backbones is None:
        image_backbones = list(IMAGE_BACKBONES.keys())

    if audio_backbones is None:
        audio_backbones = list(AUDIO_BACKBONES.keys())

    # Create directory for saving models
    if save_model_dir:
        os.makedirs(save_model_dir, exist_ok=True)

    # Load previous results if the file exists
    results = []
    evaluated_combinations = set()

    if os.path.exists(results_file):
        try:
            with open(results_file, 'r') as f:
                results = json.load(f)
            evaluated_combinations = {(r["image_backbone"], r["audio_backbone"]) for r in results}
            print(f"\033[92mINFO\033[0m: Loaded {len(results)} previous results from {results_file}")
        except Exception as e:
            print(f"\033[91mERR!\033[0m: Error loading previous results from {results_file}: {e}")
            results = []
            evaluated_combinations = set()
    else:
        print(f"\033[93mWARN\033[0m: Results file '{results_file}' does not exist. Starting with empty results.")

    # Create combinations to evaluate, skipping any that have already been evaluated
    combinations = [(img, aud) for img in image_backbones for aud in audio_backbones
                    if (img, aud) not in evaluated_combinations]

    if len(combinations) < len(image_backbones) * len(audio_backbones):
        print(f"\033[92mINFO\033[0m: Skipping {len(evaluated_combinations)} already evaluated combinations")

    print(f"\033[92mINFO\033[0m: Will evaluate {len(combinations)} combinations")

    for image_backbone, audio_backbone in combinations:
        print(f"\033[92mINFO\033[0m: Evaluating {image_backbone} + {audio_backbone}")
        try:
            # Clean GPU memory before each model evaluation
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                print(f"\033[92mINFO\033[0m: CUDA memory cleared before evaluation")
                # Print memory usage for debugging
                print(f"\033[92mINFO\033[0m: CUDA memory allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
                print(f"\033[92mINFO\033[0m: CUDA memory reserved: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")

            result = evaluate_model(data_dir, image_backbone, audio_backbone, save_model_dir=save_model_dir)
            results.append(result)

            # Save results after each evaluation
            save_results(results, results_file)
            print(f"\033[92mINFO\033[0m: Updated results saved to {results_file}")

            # Force garbage collection to free memory
            import gc
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                print(f"\033[92mINFO\033[0m: CUDA memory cleared after evaluation")
                # Print memory usage for debugging
                print(f"\033[92mINFO\033[0m: CUDA memory allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
                print(f"\033[92mINFO\033[0m: CUDA memory reserved: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")

        except Exception as e:
            print(f"\033[91mERR!\033[0m: Error evaluating {image_backbone} + {audio_backbone}: {e}")
            print(f"\033[91mERR!\033[0m: To continue from this point, use --start_from={image_backbone}:{audio_backbone}")

            # Force garbage collection to free memory even if there's an error
            import gc
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                print(f"\033[92mINFO\033[0m: CUDA memory cleared after error")

            continue

    # Sort results by test MAE (ascending)
    results.sort(key=lambda x: x["test_mae"])

    # Save final sorted results
    save_results(results, results_file)

    print("\n\033[92mINFO\033[0m: === FINAL RESULTS (Sorted by Test MAE) ===")
    print(f"{'Image Backbone':<20} {'Audio Backbone':<20} {'Val MAE':<10} {'Test MAE':<10}")
    print("="*60)

    for result in results:
        print(f"{result['image_backbone']:<20} {result['audio_backbone']:<20} {result['validation_mae']:<10.4f} {result['test_mae']:<10.4f}")

    return results


def save_results(results, filename="backbone_evaluation_results.json"):
    """Save evaluation results to a JSON file."""
    with open(filename, 'w') as f:
        json.dump(results, f, indent=4)
    print(f"\033[92mINFO\033[0m: Results saved to {filename}")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Evaluate Different Backbones for Watermelon Sweetness Prediction")
    parser.add_argument(
        "--data_dir",
        type=str,
        default="../cleaned",
        help="Path to the cleaned dataset directory"
    )
    parser.add_argument(
        "--image_backbone",
        type=str,
        default=None,
        help="Specific image backbone to evaluate (leave empty to evaluate all available)"
    )
    parser.add_argument(
        "--audio_backbone",
        type=str,
        default=None,
        help="Specific audio backbone to evaluate (leave empty to evaluate all available)"
    )
    parser.add_argument(
        "--evaluate_all",
        action="store_true",
        help="Evaluate all combinations of backbones"
    )
    parser.add_argument(
        "--start_from",
        type=str,
        default=None,
        help="Start evaluation from a specific combination, format: 'image_backbone:audio_backbone'"
    )
    parser.add_argument(
        "--prioritize_efficient",
        action="store_true",
        help="Prioritize more efficient models first to avoid memory issues"
    )
    parser.add_argument(
        "--results_file",
        type=str,
        default="backbone_evaluation_results.json",
        help="File to save the evaluation results"
    )
    parser.add_argument(
        "--load_previous_results",
        action="store_true",
        help="Load previous results from results_file if it exists"
    )
    parser.add_argument(
        "--model_dir",
        type=str,
        default="test_models",
        help="Directory to save model checkpoints"
    )

    args = parser.parse_args()

    # Create model directory if it doesn't exist
    if args.model_dir:
        os.makedirs(args.model_dir, exist_ok=True)

    print(f"\033[92mINFO\033[0m: === Available Image Backbones ===")
    for name in IMAGE_BACKBONES.keys():
        print(f"- {name}")

    print(f"\033[92mINFO\033[0m: === Available Audio Backbones ===")
    for name in AUDIO_BACKBONES.keys():
        print(f"- {name}")

    if args.evaluate_all:
        evaluate_all_combinations(args.data_dir, results_file=args.results_file, save_model_dir=args.model_dir)
    elif args.image_backbone and args.audio_backbone:
        result = evaluate_model(args.data_dir, args.image_backbone, args.audio_backbone, save_model_dir=args.model_dir)
        save_results([result], args.results_file)
    else:
        # Define a default set of backbones to evaluate if not specified
        if args.prioritize_efficient:
            # Start with less memory-intensive models
            image_backbones = ["resnet50", "efficientnet_b0", "resnet101", "efficientnet_b3", "convnext_base", "swin_b"]
            audio_backbones = ["lstm", "gru", "bidirectional_lstm", "transformer"]
        else:
            # Default selection focusing on better performance models
            image_backbones = ["resnet101", "efficientnet_b3", "swin_b"]
            audio_backbones = ["lstm", "bidirectional_lstm", "transformer"]

        # Create all combinations
        combinations = [(img, aud) for img in image_backbones for aud in audio_backbones]

        # Load previous results if requested and file exists
        previous_results = []
        previous_combinations = set()
        if args.load_previous_results:
            try:
                if os.path.exists(args.results_file):
                    with open(args.results_file, 'r') as f:
                        previous_results = json.load(f)
                    previous_combinations = {(r["image_backbone"], r["audio_backbone"]) for r in previous_results}
                    print(f"\033[92mINFO\033[0m: Loaded {len(previous_results)} previous results")
                else:
                    print(f"\033[93mWARN\033[0m: Results file '{args.results_file}' does not exist. Starting with empty results.")
            except Exception as e:
                print(f"\033[91mERR!\033[0m: Error loading previous results: {e}")
                previous_results = []
                previous_combinations = set()

        # If starting from a specific point
        if args.start_from:
            try:
                start_img, start_aud = args.start_from.split(':')
                start_idx = combinations.index((start_img, start_aud))
                combinations = combinations[start_idx:]
                print(f"\033[92mINFO\033[0m: Starting from combination: {start_img} (image) + {start_aud} (audio)")
            except (ValueError, IndexError):
                print(f"\033[91mERR!\033[0m: Invalid start_from format or combination not found. Format should be 'image_backbone:audio_backbone'")
                print(f"\033[91mERR!\033[0m: Continuing with all combinations.")

        # Skip combinations that have already been evaluated
        if previous_combinations:
            original_count = len(combinations)
            combinations = [(img, aud) for img, aud in combinations if (img, aud) not in previous_combinations]
            print(f"\033[92mINFO\033[0m: Skipping {original_count - len(combinations)} already evaluated combinations")

        # Evaluate each combination
        results = previous_results.copy()

        for img_backbone, audio_backbone in combinations:
            print(f"\033[92mINFO\033[0m: Evaluating {img_backbone} + {audio_backbone}")
            try:
                # Clean GPU memory before each model evaluation
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                    print(f"\033[92mINFO\033[0m: CUDA memory cleared before evaluation")
                    print(f"\033[92mINFO\033[0m: CUDA memory allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
                    print(f"\033[92mINFO\033[0m: CUDA memory reserved: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")

                result = evaluate_model(args.data_dir, img_backbone, audio_backbone, save_model_dir=args.model_dir)
                results.append(result)

                # Save results after each evaluation
                save_results(results, args.results_file)

                # Force garbage collection to free memory
                import gc
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                    print(f"\033[92mINFO\033[0m: CUDA memory cleared after evaluation")
                    print(f"\033[92mINFO\033[0m: CUDA memory allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
                    print(f"\033[92mINFO\033[0m: CUDA memory reserved: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")

            except Exception as e:
                print(f"\033[91mERR!\033[0m: Error evaluating {img_backbone} + {audio_backbone}: {e}")
                print(f"\033[91mERR!\033[0m: To continue from this point later, use --start_from={img_backbone}:{audio_backbone}")

                # Force garbage collection to free memory even if there's an error
                import gc
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                    print(f"\033[92mINFO\033[0m: CUDA memory cleared after error")

                continue

        # Sort results by test MAE (ascending)
        results.sort(key=lambda x: x["test_mae"])

        # Save final sorted results
        save_results(results, args.results_file)

        print("\n\033[92mINFO\033[0m: === FINAL RESULTS (Sorted by Test MAE) ===")
        print(f"{'Image Backbone':<20} {'Audio Backbone':<20} {'Val MAE':<10} {'Test MAE':<10}")
        print("="*60)

        for result in results:
            print(f"{result['image_backbone']:<20} {result['audio_backbone']:<20} {result['validation_mae']:<10.4f} {result['test_mae']:<10.4f}")
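WatermelonModelModular expects an MFCC sequence with 376 features per time step plus an ImageNet-style image batch. The smoke test below is a hedged sketch (the sequence length 100 and the 224x224 image size are assumptions inferred from audio_input_size = 376 and the torchvision backbones; it is not part of the evaluation script):

import torch
from evaluate_backbones import WatermelonModelModular

# Illustrative only: forward a dummy batch through the modular model to check tensor shapes.
model = WatermelonModelModular("efficientnet_b3", "transformer").eval()
mfcc = torch.randn(2, 100, 376)      # (batch, time steps, MFCC features); 376 matches audio_input_size
image = torch.randn(2, 3, 224, 224)  # (batch, channels, height, width) for the torchvision backbone
with torch.no_grad():
    out = model(mfcc, image)
print(out.shape)  # expected: torch.Size([2, 1]) -- one Brix prediction per sample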
models/.nfs00000001a1a17512003726ad
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:02999bd33592de717dc1ec8054dc570193074c3f25a7283b3daa580b727b7134
size 96095572
models/.nfs00000001a234d9cd003726ac
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5df632222fa87e09e635f90e5cce14bdd9fd34b442bf18daaf13e54dedfed132
size 96095572
models/.nfs00000001a2a11ea9003726ae
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:80f999a1540c42ed74491692aa66c3b5a6171f972bdf47c9d52556fe1673c8dd
size 96095572
models/efficientnet_b0_transformer_model.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:eec8d23f6454198e147db3ff31e497a0fed8cc0fa690f58e2576e9190ca54aa7
size 22597034
models/efficientnet_b3_transformer_model.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:da70bf6bef70cfa3795e566fd58523a9b41b01c151fb37fd3b255262c2b47451
size 49751930
models/resnet50_transformer_model.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cec4fe964defc58fea1f6c26c714c27680a4aa81b131795e8cbeadb6e7be9bd5
size 101004668
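The three .pt entries above are Git LFS pointer files for the transformer-based expert checkpoints; only the pointer text lives in the repository and the actual weights are fetched by LFS. A minimal sketch of restoring one of them (it assumes evaluate_backbones.py is importable and that each checkpoint is a plain state_dict, as saved by torch.save(model.state_dict(), ...) in that script):

import torch
from evaluate_backbones import WatermelonModelModular

# Illustrative only: restore one expert checkpoint on CPU for inference.
model = WatermelonModelModular("resnet50", "transformer")
state_dict = torch.load("models/resnet50_transformer_model.pt", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()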
moe_evaluation_results.json
ADDED
@@ -0,0 +1,801 @@
| 1 |
+
{
|
| 2 |
+
"moe_test_mae": 0.19680618420243262,
|
| 3 |
+
"moe_test_mse": 0.05606407420709729,
|
| 4 |
+
"true_labels": [
|
| 5 |
+
10.5,
|
| 6 |
+
9.399999618530273,
|
| 7 |
+
11.600000381469727,
|
| 8 |
+
8.699999809265137,
|
| 9 |
+
10.399999618530273,
|
| 10 |
+
10.800000190734863,
|
| 11 |
+
11.600000381469727,
|
| 12 |
+
10.5,
|
| 13 |
+
11.600000381469727,
|
| 14 |
+
11.100000381469727,
|
| 15 |
+
10.399999618530273,
|
| 16 |
+
10.5,
|
| 17 |
+
11.0,
|
| 18 |
+
10.5,
|
| 19 |
+
10.899999618530273,
|
| 20 |
+
10.5,
|
| 21 |
+
11.100000381469727,
|
| 22 |
+
9.600000381469727,
|
| 23 |
+
12.699999809265137,
|
| 24 |
+
10.0,
|
| 25 |
+
10.300000190734863,
|
| 26 |
+
10.399999618530273,
|
| 27 |
+
9.399999618530273,
|
| 28 |
+
10.800000190734863,
|
| 29 |
+
10.0,
|
| 30 |
+
11.600000381469727,
|
| 31 |
+
10.0,
|
| 32 |
+
10.399999618530273,
|
| 33 |
+
9.399999618530273,
|
| 34 |
+
10.399999618530273,
|
| 35 |
+
10.300000190734863,
|
| 36 |
+
9.399999618530273,
|
| 37 |
+
10.899999618530273,
|
| 38 |
+
9.0,
|
| 39 |
+
10.300000190734863,
|
| 40 |
+
10.899999618530273,
|
| 41 |
+
11.0,
|
| 42 |
+
12.699999809265137,
|
| 43 |
+
10.399999618530273,
|
| 44 |
+
9.600000381469727,
|
| 45 |
+
8.699999809265137,
|
| 46 |
+
10.199999809265137,
|
| 47 |
+
10.300000190734863,
|
| 48 |
+
11.600000381469727,
|
| 49 |
+
9.0,
|
| 50 |
+
9.0,
|
| 51 |
+
11.0,
|
| 52 |
+
8.699999809265137,
|
| 53 |
+
9.699999809265137,
|
| 54 |
+
10.399999618530273,
|
| 55 |
+
10.0,
|
| 56 |
+
11.600000381469727,
|
| 57 |
+
9.399999618530273,
|
| 58 |
+
9.0,
|
| 59 |
+
10.300000190734863,
|
| 60 |
+
10.5,
|
| 61 |
+
10.399999618530273,
|
| 62 |
+
11.0,
|
| 63 |
+
10.899999618530273,
|
| 64 |
+
9.399999618530273,
|
| 65 |
+
8.699999809265137,
|
| 66 |
+
10.300000190734863,
|
| 67 |
+
9.699999809265137,
|
| 68 |
+
10.300000190734863,
|
| 69 |
+
9.399999618530273,
|
| 70 |
+
10.300000190734863,
|
| 71 |
+
9.399999618530273,
|
| 72 |
+
10.0,
|
| 73 |
+
10.399999618530273,
|
| 74 |
+
10.199999809265137,
|
| 75 |
+
11.0,
|
| 76 |
+
12.699999809265137,
|
| 77 |
+
12.699999809265137,
|
| 78 |
+
10.0,
|
| 79 |
+
11.0,
|
| 80 |
+
9.0,
|
| 81 |
+
10.0,
|
| 82 |
+
10.5,
|
| 83 |
+
11.600000381469727,
|
| 84 |
+
9.399999618530273,
|
| 85 |
+
10.0,
|
| 86 |
+
11.0,
|
| 87 |
+
11.100000381469727,
|
| 88 |
+
10.899999618530273,
|
| 89 |
+
9.399999618530273,
|
| 90 |
+
10.300000190734863,
|
| 91 |
+
9.399999618530273,
|
| 92 |
+
8.699999809265137,
|
| 93 |
+
10.0,
|
| 94 |
+
12.699999809265137,
|
| 95 |
+
12.699999809265137,
|
| 96 |
+
9.699999809265137,
|
| 97 |
+
9.399999618530273,
|
| 98 |
+
11.0,
|
| 99 |
+
9.399999618530273,
|
| 100 |
+
9.0,
|
| 101 |
+
11.100000381469727,
|
| 102 |
+
10.300000190734863,
|
| 103 |
+
10.300000190734863,
|
| 104 |
+
10.300000190734863,
|
| 105 |
+
10.0,
|
| 106 |
+
9.399999618530273,
|
| 107 |
+
9.399999618530273,
|
| 108 |
+
10.899999618530273,
|
| 109 |
+
11.0,
|
| 110 |
+
9.699999809265137,
|
| 111 |
+
12.699999809265137,
|
| 112 |
+
10.5,
|
| 113 |
+
11.0,
|
| 114 |
+
10.899999618530273,
|
| 115 |
+
12.699999809265137,
|
| 116 |
+
10.899999618530273,
|
| 117 |
+
11.0,
|
| 118 |
+
10.300000190734863,
|
| 119 |
+
11.0,
|
| 120 |
+
9.699999809265137,
|
| 121 |
+
10.300000190734863,
|
| 122 |
+
10.300000190734863,
|
| 123 |
+
10.199999809265137,
|
| 124 |
+
10.199999809265137,
|
| 125 |
+
10.899999618530273,
|
| 126 |
+
10.5,
|
| 127 |
+
11.0,
|
| 128 |
+
8.699999809265137,
|
| 129 |
+
9.699999809265137,
|
| 130 |
+
12.699999809265137,
|
| 131 |
+
11.600000381469727,
|
| 132 |
+
10.899999618530273,
|
| 133 |
+
11.0,
|
| 134 |
+
9.399999618530273,
|
| 135 |
+
10.300000190734863,
|
| 136 |
+
12.699999809265137,
|
| 137 |
+
10.199999809265137,
|
| 138 |
+
10.199999809265137,
|
| 139 |
+
10.800000190734863,
|
| 140 |
+
8.699999809265137,
|
| 141 |
+
9.0,
|
| 142 |
+
11.0,
|
| 143 |
+
9.399999618530273,
|
| 144 |
+
10.800000190734863,
|
| 145 |
+
11.100000381469727,
|
| 146 |
+
11.100000381469727,
|
| 147 |
+
10.199999809265137,
|
| 148 |
+
9.399999618530273,
|
| 149 |
+
10.199999809265137,
|
| 150 |
+
10.199999809265137,
|
| 151 |
+
9.399999618530273,
|
| 152 |
+
10.899999618530273,
|
| 153 |
+
10.199999809265137,
|
| 154 |
+
11.100000381469727,
|
| 155 |
+
11.600000381469727,
|
| 156 |
+
8.699999809265137,
|
| 157 |
+
11.600000381469727,
|
| 158 |
+
10.199999809265137,
|
| 159 |
+
9.399999618530273,
|
| 160 |
+
9.699999809265137,
|
| 161 |
+
9.399999618530273
|
| 162 |
+
],
|
| 163 |
+
"moe_predictions": [
|
| 164 |
+
10.906482696533203,
|
| 165 |
+
9.413387298583984,
|
| 166 |
+
11.58445930480957,
|
| 167 |
+
8.627098083496094,
|
| 168 |
+
10.55517578125,
|
| 169 |
+
10.969362258911133,
|
| 170 |
+
11.596641540527344,
|
| 171 |
+
10.598587036132812,
|
| 172 |
+
11.712945938110352,
|
| 173 |
+
11.415390968322754,
|
| 174 |
+
10.500967979431152,
|
| 175 |
+
10.939116477966309,
|
| 176 |
+
11.23089599609375,
|
| 177 |
+
10.928877830505371,
|
| 178 |
+
11.180931091308594,
|
| 179 |
+
10.805574417114258,
|
| 180 |
+
11.44560432434082,
|
| 181 |
+
9.797750473022461,
|
| 182 |
+
12.00424575805664,
|
| 183 |
+
9.924805641174316,
|
| 184 |
+
10.419149398803711,
|
| 185 |
+
10.459878921508789,
|
| 186 |
+
9.774242401123047,
|
| 187 |
+
10.985288619995117,
|
| 188 |
+
10.047812461853027,
|
| 189 |
+
11.745304107666016,
|
| 190 |
+
10.191004753112793,
|
| 191 |
+
10.527164459228516,
|
| 192 |
+
9.581968307495117,
|
| 193 |
+
10.483012199401855,
|
| 194 |
+
10.368606567382812,
|
| 195 |
+
9.450727462768555,
|
| 196 |
+
11.197010040283203,
|
| 197 |
+
9.173027038574219,
|
| 198 |
+
10.50676441192627,
|
| 199 |
+
11.195816040039062,
|
| 200 |
+
11.227279663085938,
|
| 201 |
+
13.106525421142578,
|
| 202 |
+
10.4664945602417,
|
| 203 |
+
9.891031265258789,
|
| 204 |
+
8.75540542602539,
|
| 205 |
+
10.572815895080566,
|
| 206 |
+
10.214585304260254,
|
| 207 |
+
12.000329971313477,
|
| 208 |
+
8.887301445007324,
|
| 209 |
+
8.929031372070312,
|
| 210 |
+
11.054266929626465,
|
| 211 |
+
8.85447883605957,
|
| 212 |
+
9.515145301818848,
|
| 213 |
+
10.480228424072266,
|
| 214 |
+
10.193933486938477,
|
| 215 |
+
11.7305908203125,
|
| 216 |
+
9.437666893005371,
|
| 217 |
+
9.13387680053711,
|
| 218 |
+
10.629348754882812,
|
| 219 |
+
10.703892707824707,
|
| 220 |
+
10.539461135864258,
|
| 221 |
+
11.135326385498047,
|
| 222 |
+
11.19705867767334,
|
| 223 |
+
9.558942794799805,
|
| 224 |
+
8.898516654968262,
|
| 225 |
+
10.628425598144531,
|
| 226 |
+
9.657480239868164,
|
| 227 |
+
10.513351440429688,
|
| 228 |
+
9.459192276000977,
|
| 229 |
+
(… remainder of moe_evaluation_results.json: the rest of the "moe_predictions" array, elided …)
    ],
    "individual_predictions": {
        "efficientnet_b3_transformer": [ (157 per-sample predictions, elided) ],
        "efficientnet_b0_transformer": [ (157 per-sample predictions, elided) ],
        "resnet50_transformer": [ (157 per-sample predictions, elided) ]
    }
}
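For anyone consuming this results file downstream, here is a minimal sketch (not part of the committed files) of how the stored predictions can be re-read to recompute per-model error. It only assumes the key layout visible above (true_labels, moe_predictions, and one array per expert under individual_predictions) and that the JSON sits in the working directory.

# Sketch: recompute MAE from moe_evaluation_results.json (assumed key layout as above).
import json

with open("moe_evaluation_results.json") as f:
    results = json.load(f)

labels = results["true_labels"]

def mae(predictions):
    # Mean absolute error between stored per-sample predictions and the true labels.
    return sum(abs(p - y) for p, y in zip(predictions, labels)) / len(labels)

print(f"MoE ensemble MAE: {mae(results['moe_predictions']):.4f}")
for name, preds in results["individual_predictions"].items():
    print(f"{name} MAE: {mae(preds):.4f}")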
templates/.nfs00000001a2893bde003726a5
ADDED
@@ -0,0 +1 @@
+
test_moe_model.py
ADDED
@@ -0,0 +1,276 @@
import os
import torch
import torchaudio
import torchvision
import numpy as np
import json
from torch.utils.data import Dataset, DataLoader
import sys
from tqdm import tqdm

# Add parent directory to path to import the preprocess functions
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from preprocess import process_audio_data, process_image_data

# Import the WatermelonDataset and WatermelonModelModular from the evaluate_backbones.py file
from evaluate_backbones import WatermelonDataset, WatermelonModelModular, IMAGE_BACKBONES, AUDIO_BACKBONES

# Print library versions
print(f"\033[92mINFO\033[0m: PyTorch version: {torch.__version__}")
print(f"\033[92mINFO\033[0m: Torchaudio version: {torchaudio.__version__}")
print(f"\033[92mINFO\033[0m: Torchvision version: {torchvision.__version__}")

# Device selection
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
print(f"\033[92mINFO\033[0m: Using device: {device}")

# Define the top-performing models based on the previous evaluation
TOP_MODELS = [
    {"image_backbone": "efficientnet_b3", "audio_backbone": "transformer"},
    {"image_backbone": "efficientnet_b0", "audio_backbone": "transformer"},
    {"image_backbone": "resnet50", "audio_backbone": "transformer"}
]

# Define class for the MoE model
class WatermelonMoEModel(torch.nn.Module):
    def __init__(self, model_configs, model_dir="test_models", weights=None):
        """
        Mixture of Experts model that combines multiple backbone models.

        Args:
            model_configs: List of dictionaries with 'image_backbone' and 'audio_backbone' keys
            model_dir: Directory where model checkpoints are stored
            weights: Optional list of weights for each model (None for equal weighting)
        """
        super(WatermelonMoEModel, self).__init__()
        self.models = []
        self.model_configs = model_configs

        # Load each model
        for config in model_configs:
            img_backbone = config["image_backbone"]
            audio_backbone = config["audio_backbone"]

            # Initialize model
            model = WatermelonModelModular(img_backbone, audio_backbone)

            # Load weights
            model_path = os.path.join(model_dir, f"{img_backbone}_{audio_backbone}_model.pt")
            if os.path.exists(model_path):
                print(f"\033[92mINFO\033[0m: Loading model {img_backbone}_{audio_backbone} from {model_path}")
                model.load_state_dict(torch.load(model_path, map_location=device))
            else:
                print(f"\033[91mERR!\033[0m: Model checkpoint not found at {model_path}")
                continue

            model.to(device)
            model.eval()  # Set to evaluation mode
            self.models.append(model)

        # Set model weights (uniform by default)
        if weights:
            assert len(weights) == len(self.models), "Number of weights must match number of models"
            self.weights = weights
        else:
            self.weights = [1.0 / len(self.models)] * len(self.models)

        print(f"\033[92mINFO\033[0m: Loaded {len(self.models)} models for MoE ensemble")
        print(f"\033[92mINFO\033[0m: Model weights: {self.weights}")

    def forward(self, mfcc, image):
        """
        Forward pass through the MoE model.
        Returns the weighted average of all model outputs.
        """
        outputs = []

        # Get outputs from each model
        with torch.no_grad():
            for i, model in enumerate(self.models):
                output = model(mfcc, image)
                outputs.append(output * self.weights[i])

        # Return weighted average
        return torch.sum(torch.stack(outputs), dim=0)


def evaluate_moe_model(data_dir, model_dir="test_models", weights=None):
    """
    Evaluate the MoE model on the test set.
    """
    # Load dataset
    print(f"\033[92mINFO\033[0m: Loading dataset from {data_dir}")
    dataset = WatermelonDataset(data_dir)
    n_samples = len(dataset)

    # Split dataset
    train_size = int(0.7 * n_samples)
    val_size = int(0.2 * n_samples)
    test_size = n_samples - train_size - val_size

    _, _, test_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size, test_size]
    )

    # Use a reasonable batch size
    batch_size = 8
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Initialize MoE model
    moe_model = WatermelonMoEModel(TOP_MODELS, model_dir, weights)
    moe_model.eval()

    # Evaluation metrics
    mae_criterion = torch.nn.L1Loss()
    mse_criterion = torch.nn.MSELoss()

    test_mae = 0.0
    test_mse = 0.0

    print(f"\033[92mINFO\033[0m: Evaluating MoE model on {len(test_dataset)} test samples")

    # Individual model predictions for analysis
    individual_predictions = {f"{config['image_backbone']}_{config['audio_backbone']}": []
                              for config in TOP_MODELS}
    true_labels = []
    moe_predictions = []

    # Evaluation loop
    test_iterator = tqdm(test_loader, desc="Testing MoE")

    with torch.no_grad():
        for i, (mfcc, image, label) in enumerate(test_iterator):
            try:
                mfcc, image, label = mfcc.to(device), image.to(device), label.to(device)

                # Store individual model outputs for analysis
                for j, model in enumerate(moe_model.models):
                    config = TOP_MODELS[j]
                    model_name = f"{config['image_backbone']}_{config['audio_backbone']}"
                    output = model(mfcc, image)
                    individual_predictions[model_name].extend(output.view(-1).cpu().numpy())

                # Get MoE prediction
                output = moe_model(mfcc, image)
                moe_predictions.extend(output.view(-1).cpu().numpy())

                # Store true labels
                label = label.view(-1, 1).float()
                true_labels.extend(label.view(-1).cpu().numpy())

                # Calculate metrics
                mae = mae_criterion(output, label)
                mse = mse_criterion(output, label)

                test_mae += mae.item()
                test_mse += mse.item()

                test_iterator.set_postfix({"MAE": f"{mae.item():.4f}", "MSE": f"{mse.item():.4f}"})

                # Clean up memory
                if device.type == 'cuda':
                    del mfcc, image, label, output, mae, mse
                    torch.cuda.empty_cache()

            except Exception as e:
                print(f"\033[91mERR!\033[0m: Error in test batch {i}: {e}")
                if device.type == 'cuda':
                    torch.cuda.empty_cache()
                continue

    # Calculate average metrics
    avg_test_mae = test_mae / len(test_loader) if len(test_loader) > 0 else float('inf')
    avg_test_mse = test_mse / len(test_loader) if len(test_loader) > 0 else float('inf')

    print(f"\n\033[92mINFO\033[0m: === MoE Model Results ===")
    print(f"Test MAE: {avg_test_mae:.4f}")
    print(f"Test MSE: {avg_test_mse:.4f}")

    # Compare with individual models
    print(f"\n\033[92mINFO\033[0m: === Comparison with Individual Models ===")
    print(f"{'Model':<30} {'Test MAE':<15}")
    print("="*45)

    # Load previous results
    results_file = "backbone_evaluation_results.json"
    if os.path.exists(results_file):
        with open(results_file, 'r') as f:
            previous_results = json.load(f)

        # Filter results for our top models
        for config in TOP_MODELS:
            img_backbone = config["image_backbone"]
            audio_backbone = config["audio_backbone"]

            for result in previous_results:
                if result["image_backbone"] == img_backbone and result["audio_backbone"] == audio_backbone:
                    print(f"{img_backbone + '_' + audio_backbone:<30} {result['test_mae']:<15.4f}")

    print(f"{'MoE (Ensemble)':<30} {avg_test_mae:<15.4f}")

    # Save results and predictions
    results = {
        "moe_test_mae": float(avg_test_mae),
        "moe_test_mse": float(avg_test_mse),
        "true_labels": [float(x) for x in true_labels],
        "moe_predictions": [float(x) for x in moe_predictions],
        "individual_predictions": {key: [float(x) for x in values]
                                   for key, values in individual_predictions.items()}
    }

    with open("moe_evaluation_results.json", 'w') as f:
        json.dump(results, f, indent=4)

    print(f"\033[92mINFO\033[0m: Results saved to moe_evaluation_results.json")

    return avg_test_mae, avg_test_mse


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Test Mixture of Experts (MoE) Model for Watermelon Sweetness Prediction")
    parser.add_argument(
        "--data_dir",
        type=str,
        default="../cleaned",
        help="Path to the cleaned dataset directory"
    )
    parser.add_argument(
        "--model_dir",
        type=str,
        default="test_models",
        help="Directory containing model checkpoints"
    )
    parser.add_argument(
        "--weighting",
        type=str,
        choices=["uniform", "performance"],
        default="uniform",
        help="How to weight the models (uniform or based on performance)"
    )

    args = parser.parse_args()

    # Determine weights based on argument
    weights = None
    if args.weighting == "performance":
        # Weights inversely proportional to the MAE (better models get higher weights)
        # These are the MAE values from the provided results
        mae_values = [0.3635, 0.3765, 0.3959]  # efficientnet_b3+transformer, efficientnet_b0+transformer, resnet50+transformer

        # Convert to weights (inverse of MAE, normalized)
        inverse_mae = [1/mae for mae in mae_values]
        total = sum(inverse_mae)
        weights = [val/total for val in inverse_mae]

        print(f"\033[92mINFO\033[0m: Using performance-based weights: {weights}")
    else:
        print(f"\033[92mINFO\033[0m: Using uniform weights")

    # Evaluate the MoE model
    evaluate_moe_model(args.data_dir, args.model_dir, weights)
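For context, the script above is invoked as, for example, python test_moe_model.py --weighting performance --model_dir test_models, and its performance-based branch reduces to a small calculation. The sketch below is illustrative only (it is not part of the committed file) and simply reproduces the arithmetic of the weighting code with the hard-coded MAE values.

# Sketch: what the performance-based weighting in test_moe_model.py computes.
# With test MAEs [0.3635, 0.3765, 0.3959], the normalized inverse-MAE weights
# come out to roughly [0.347, 0.335, 0.318], so the best expert
# (efficientnet_b3 + transformer) is up-weighted only slightly over the other two.
mae_values = [0.3635, 0.3765, 0.3959]
inverse_mae = [1 / m for m in mae_values]
weights = [v / sum(inverse_mae) for v in inverse_mae]
print([round(w, 3) for w in weights])  # expected: [0.347, 0.335, 0.318]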