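"""Gradio app for the watermelon sweetness predictor.

Loads a trained WatermelonModel checkpoint, then serves a web UI that takes
a tapping-sound audio clip and a photo of a watermelon and returns a
predicted sweetness score out of 10.
"""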
import os
import traceback

import cv2
import gradio as gr
import numpy as np
import torch
import torchaudio
import torchvision  # noqa: F401  (not used directly in this file)

from preprocess import process_audio_data, process_image_data
from train import WatermelonModel
from infer import infer

def load_model(model_path):
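    """Load a trained WatermelonModel checkpoint onto the best available device.

    Example:
        model = load_model("./models/model_15_20250405-033557.pt")

    Raises FileNotFoundError if the checkpoint does not exist; torch.load
    errors are re-raised after printing a diagnostic hint.
    """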
    global device
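    # Prefer CUDA, then Apple-silicon MPS, then fall back to CPU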
    device = torch.device(
        "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    )
    print(f"\033[92mINFO\033[0m: Using device: {device}")

    # Check if the file exists
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file not found at {model_path}")
    
    # Check if the file is empty or very small
    file_size = os.path.getsize(model_path)
    if file_size < 1000:  # Less than 1KB is suspiciously small for a model
        print(f"\033[93mWARNING\033[0m: Model file size is only {file_size} bytes, which is suspiciously small")
    
    try:
        model = WatermelonModel().to(device)
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.eval()
        print(f"\033[92mINFO\033[0m: Loaded model from {model_path}")
        return model
    except RuntimeError as e:
        if "failed finding central directory" in str(e):
            print(f"\033[91mERROR\033[0m: The model file at {model_path} appears to be corrupted.")
            print("This can happen if:")
            print("  1. The model saving process was interrupted")
            print("  2. The file was not properly downloaded")
            print("  3. The path points to a file that is not a valid PyTorch model")
            print(f"File size: {file_size} bytes")
        raise

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Watermelon sweetness predictor")
    parser.add_argument("--model_path", type=str, default="./models/model_15_20250405-033557.pt", help="Path to the trained model")
    args = parser.parse_args()

    model = load_model(args.model_path)

    def predict(audio, image):
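        """Predict watermelon sweetness from a Gradio audio/image pair.

        `audio` is typically a (sample_rate, numpy array) tuple from gr.Audio;
        `image` is an RGB numpy array from gr.Image. Returns a result string,
        or an error message describing what went wrong.
        """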
        try:
            # Debug audio input
            print(f"\033[92mDEBUG\033[0m: Audio input type: {type(audio)}")
            print(f"\033[92mDEBUG\033[0m: Audio input value: {audio}")
            
            # Handle different formats of audio input from Gradio
            if audio is None:
                return "Error: No audio provided. Please upload or record audio."
                
            if isinstance(audio, tuple) and len(audio) >= 2:
                sr, audio_data = audio[0], audio[-1]
                print(f"\033[92mDEBUG\033[0m: Audio format: sr={sr}, audio_data shape={audio_data.shape if hasattr(audio_data, 'shape') else 'no shape'}")
            elif isinstance(audio, tuple) and len(audio) == 1:
                # Handle single element tuple
                audio_data = audio[0]
                sr = 44100  # Assume default sample rate
                print(f"\033[92mDEBUG\033[0m: Single element audio tuple, using default sr={sr}")
            elif isinstance(audio, np.ndarray):
                # Handle direct numpy array
                audio_data = audio
                sr = 44100  # Assume default sample rate
                print(f"\033[92mDEBUG\033[0m: Audio is numpy array, using default sr={sr}")
            else:
                return f"Error: Unexpected audio format: {type(audio)}"
            
            # Ensure audio_data is correctly shaped
            if isinstance(audio_data, np.ndarray):
                # Make sure we have a 2D array
                if len(audio_data.shape) == 1:
                    audio_data = np.expand_dims(audio_data, axis=0)
                    print(f"\033[92mDEBUG\033[0m: Reshaped 1D audio to 2D: {audio_data.shape}")
                
                # Heuristic: if there are more rows than columns, assume the
                # layout is (samples, channels) and transpose to (channels, samples)
                if len(audio_data.shape) == 2 and audio_data.shape[0] > audio_data.shape[1]:
                    audio_data = np.transpose(audio_data)
                    print(f"\033[92mDEBUG\033[0m: Transposed audio shape to: {audio_data.shape}")
            
            # Convert to tensor
            audio_tensor = torch.tensor(audio_data).float()
            print(f"\033[92mDEBUG\033[0m: Audio tensor shape: {audio_tensor.shape}")
            
            # Process audio data and handle None case
            mfcc = process_audio_data(audio_tensor, sr)
            if mfcc is None:
                return "Error: Failed to process audio data. Make sure your audio contains a clear tapping sound."
                
            mfcc = mfcc.to(device)
            print(f"\033[92mDEBUG\033[0m: MFCC shape: {mfcc.shape}")
            
            # Debug image input
            print(f"\033[92mDEBUG\033[0m: Image input type: {type(image)}")
            print(f"\033[92mDEBUG\033[0m: Image shape: {image.shape if hasattr(image, 'shape') else 'No shape'}")
            
            # Process image data and handle None case
            if image is None:
                return "Error: No image provided. Please upload an image."
                
            # Handle different image formats
            if isinstance(image, np.ndarray):
                # Check if image is properly formatted (H, W, C) with 3 channels
                if len(image.shape) == 3 and image.shape[2] == 3:
                    # Convert to tensor with shape (C, H, W) as expected by PyTorch
                    img = torch.tensor(image).float().permute(2, 0, 1)
                    print(f"\033[92mDEBUG\033[0m: Converted image to tensor with shape: {img.shape}")
                elif len(image.shape) == 2:
                    # Grayscale image, expand to 3 channels
                    img = torch.tensor(image).float().unsqueeze(0).repeat(3, 1, 1)
                    print(f"\033[92mDEBUG\033[0m: Converted grayscale image to RGB tensor with shape: {img.shape}")
                else:
                    return f"Error: Unexpected image shape: {image.shape}. Expected RGB or grayscale image."
            else:
                return f"Error: Unexpected image format: {type(image)}. Expected numpy array."
            
            # Scale pixel values to [0, 1] if needed
            if img.max() > 1.0:
                img = img / 255.0
                print(f"\033[92mDEBUG\033[0m: Scaled image pixel values to range [0, 1]")
            
            # Get image dimensions and check if they're reasonable
            print(f"\033[92mDEBUG\033[0m: Final image tensor shape before processing: {img.shape}")
            
            # Process image
            try:
                img_processed = process_image_data(img)
                if img_processed is None:
                    return "Error: Failed to process image data. Make sure your image clearly shows a watermelon."
                    
                img_processed = img_processed.to(device)
                print(f"\033[92mDEBUG\033[0m: Processed image shape: {img_processed.shape}")
            except Exception as e:
                print(f"\033[91mERROR\033[0m: Image processing error: {str(e)}")
                return f"Error in image processing: {str(e)}"
            
            # Run inference
            try:
                # infer() expects file paths rather than in-memory tensors,
                # so write the inputs to temporary files first
                temp_dir = os.path.join(os.getcwd(), "temp")
                os.makedirs(temp_dir, exist_ok=True)
                
                # Save the audio to a temporary file if infer expects a file path
                temp_audio_path = os.path.join(temp_dir, "temp_audio.wav")
                if isinstance(audio, tuple) and len(audio) >= 2:
                    # If we have the original audio data and sample rate
                    audio_array = audio[-1]
                    sr = audio[0]
                    
                    # Check if the audio array is valid
                    if audio_array.size == 0:
                        return "Error: Audio data is empty. Please record a longer audio clip."
                    
                    # Get the duration of the audio
                    duration = audio_array.shape[-1] / sr
                    print(f"\033[92mDEBUG\033[0m: Audio duration: {duration:.2f} seconds")
                    
                    # Make sure audio has 2 dimensions (channels, samples) before padding
                    if len(audio_array.shape) == 1:
                        audio_array = np.expand_dims(audio_array, axis=0)
                    
                    # Check if we have at least 1 second of audio - but don't reject, just pad if needed
                    min_duration = 1.0  # minimum 1 second of audio
                    if duration < min_duration:
                        print(f"\033[93mWARNING\033[0m: Audio is shorter than {min_duration} seconds. Padding will be applied.")
                        # Calculate samples needed to reach minimum duration
                        samples_needed = int(min_duration * sr) - audio_array.shape[-1]
                        # Pad with zeros along the sample axis (a 1-D array here would
                        # make np.concatenate(axis=1) fail, hence the reshape above)
                        padding = np.zeros((audio_array.shape[0], samples_needed), dtype=audio_array.dtype)
                        audio_array = np.concatenate([audio_array, padding], axis=1)
                        print(f"\033[92mDEBUG\033[0m: Padded audio to shape: {audio_array.shape}")
                    
                    print(f"\033[92mDEBUG\033[0m: Audio array shape before saving: {audio_array.shape}, sr: {sr}")
                    
                    # Make sure it's in the right format for torchaudio.save
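                    # NOTE: Gradio commonly delivers audio as int16 PCM; torchaudio.save
                    # writes this float tensor as-is. If the model was trained on audio
                    # scaled to [-1, 1], dividing by 32768.0 here may be needed (an
                    # assumption; it depends on what process_audio_data expects).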
                    audio_tensor = torch.tensor(audio_array).float()
                    if audio_tensor.dim() == 1:
                        audio_tensor = audio_tensor.unsqueeze(0)
                    
                    torchaudio.save(temp_audio_path, audio_tensor, sr)
                    print(f"\033[92mDEBUG\033[0m: Saved temporary audio file to {temp_audio_path}")
                    
                    # Pre-check: run the audio through process_audio_data to confirm it is usable
                    test_mfcc = process_audio_data(audio_tensor, sr)
                    if test_mfcc is None:
                        return "Error: Unable to process the audio. Please try recording a different audio sample."
                    else:
                        print(f"\033[92mDEBUG\033[0m: Audio pre-check passed. MFCC shape: {test_mfcc.shape}")
                    
                    audio_path = temp_audio_path
                else:
                    # If we don't have a valid path, return an error
                    return "Error: Cannot process audio for inference. Invalid audio format."
                
                # Save the image to a temporary file if infer expects a file path
                temp_image_path = os.path.join(temp_dir, "temp_image.jpg")
                if isinstance(image, np.ndarray):
                    # cv2.imwrite expects BGR channel order for color images;
                    # grayscale (2-D) arrays can be written directly
                    if image.ndim == 3:
                        cv2.imwrite(temp_image_path, cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
                    else:
                        cv2.imwrite(temp_image_path, image)
                    print(f"\033[92mDEBUG\033[0m: Saved temporary image file to {temp_image_path}")
                    image_path = temp_image_path
                else:
                    # If we don't have a valid image, return an error
                    return "Error: Cannot process image for inference. Invalid image format."
                
                # Wrap infer() so that a failure falls back to manual processing
                def safe_infer(audio_path, image_path, model, device):
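                    """Try infer() first; if it raises, fall back to loading the
                    files manually, re-running preprocessing, and calling the
                    model directly."""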
                    try:
                        return infer(audio_path, image_path, model, device)
                    except Exception as e:
                        print(f"\033[91mERROR\033[0m: Error in infer function: {str(e)}")
                        # Try a more direct approach
                        try:
                            # Load audio and process
                            audio, sr = torchaudio.load(audio_path)
                            mfcc = process_audio_data(audio, sr)
                            if mfcc is None:
                                raise ValueError("Audio processing failed - MFCC is None")
                            mfcc = mfcc.to(device)
                            
                            # Load image and process
                            image = cv2.imread(image_path)
                            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                            image_tensor = torch.tensor(image).float().permute(2, 0, 1) / 255.0
                            img_processed = process_image_data(image_tensor)
                            if img_processed is None:
                                raise ValueError("Image processing failed - processed image is None")
                            img_processed = img_processed.to(device)
                            
                            # Run model inference
                            with torch.no_grad():
                                prediction = model(mfcc, img_processed)
                            return prediction
                        except Exception as e2:
                            print(f"\033[91mERROR\033[0m: Fallback inference also failed: {str(e2)}")
                            raise
                
                # Call our safer version
                print(f"\033[92mDEBUG\033[0m: Calling safe_infer with audio_path={audio_path}, image_path={image_path}")
                sweetness = safe_infer(audio_path, image_path, model, device)
                if sweetness is None:
                    return "Error: The model was unable to make a prediction. Please try with different inputs."
                    
                print(f"\033[92mDEBUG\033[0m: Inference result: {sweetness.item()}")
                return f"Predicted Sweetness: {sweetness.item():.2f}/10"
            except Exception as e:
                print(f"\033[91mERROR\033[0m: Inference failed: {str(e)}")
                print(f"\033[91mTraceback\033[0m: {traceback.format_exc()}")
                return f"Error during inference: {str(e)}"
            
        except Exception as e:
            print(f"\033[91mERROR\033[0m: Prediction failed: {str(e)}")
            print(f"\033[91mTraceback\033[0m: {traceback.format_exc()}")
            return f"Error processing input: {str(e)}"

    audio_input = gr.Audio(label="Upload or Record Audio")
    image_input = gr.Image(label="Upload or Capture Image")
    output = gr.Textbox(label="Predicted Sweetness")

    interface = gr.Interface(
        fn=predict,
        inputs=[audio_input, image_input],
        outputs=output,
        title="Watermelon Sweetness Predictor",
        description="Upload an audio file and an image to predict the sweetness of a watermelon."
    )

    try:
        interface.launch(share=True)  # Enable sharing to avoid localhost access issues
    except Exception as e:
        print(f"\033[91mERROR\033[0m: Failed to launch interface: {e}")
        print("\033[93mTIP\033[0m: If you're running in a remote environment or container, try setting additional parameters:")
        print("    interface.launch(server_name='0.0.0.0', share=True)")