import os
import traceback

import cv2
import gradio as gr
import numpy as np
import torch
import torchaudio

from preprocess import process_audio_data, process_image_data
from train import WatermelonModel
from infer import infer

# Detect HuggingFace Spaces GPU support. The import lives inside the try so
# the except clause can actually fire; a no-op stand-in keeps @spaces.GPU
# valid when running outside of Spaces.
try:
    import spaces
    use_gpu_decorator = True
    print("\033[92mINFO\033[0m: HuggingFace Spaces GPU support detected")
except ImportError:
    use_gpu_decorator = False
    print("\033[93mWARNING\033[0m: HuggingFace Spaces GPU support not detected, running in standard mode")

    class spaces:  # minimal no-op stand-in
        @staticmethod
        def GPU(fn):
            return fn

# Global device variable
device = None

@spaces.GPU
def load_model(model_path):
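    """Load the trained WatermelonModel onto the best available device.

    Also sets the module-level `device` so later tensor transfers in
    predict_impl land on the same device as the model.
    """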
    global device
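    # Prefer CUDA, then Apple MPS, then CPU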
    device = torch.device(
        "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" 
    )
    print(f"\033[92mINFO\033[0m: Using device: {device}")

    # Check if the file exists
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file not found at {model_path}")
    
    # Check if the file is empty or very small
    file_size = os.path.getsize(model_path)
    if file_size < 1000:  # Less than 1KB is suspiciously small for a model
        print(f"\033[93mWARNING\033[0m: Model file size is only {file_size} bytes, which is suspiciously small")
    
    try:
        model = WatermelonModel().to(device)
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.eval()
        print(f"\033[92mINFO\033[0m: Loaded model from {model_path}")
        return model
    except RuntimeError as e:
        if "failed finding central directory" in str(e):
            print(f"\033[91mERROR\033[0m: The model file at {model_path} appears to be corrupted.")
            print("This can happen if:")
            print("  1. The model saving process was interrupted")
            print("  2. The file was not properly downloaded")
            print("  3. The path points to a file that is not a valid PyTorch model")
            print(f"File size: {file_size} bytes")
        raise

# Define the main prediction function
def predict_impl(audio, image, model):
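    """Run preprocessing and inference on raw Gradio inputs.

    `audio` may arrive as a (sample_rate, data) tuple or a bare numpy
    array; `image` is expected as an RGB numpy array. Returns a
    user-facing result or error string.
    """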
    try:
        # Debug audio input
        print(f"\033[92mDEBUG\033[0m: Audio input type: {type(audio)}")
        print(f"\033[92mDEBUG\033[0m: Audio input value: {audio}")
        
        # Handle different formats of audio input from Gradio
        if audio is None:
            return "Error: No audio provided. Please upload or record audio."
            
        if isinstance(audio, tuple) and len(audio) >= 2:
            sr, audio_data = audio[0], audio[-1]
            print(f"\033[92mDEBUG\033[0m: Audio format: sr={sr}, audio_data shape={audio_data.shape if hasattr(audio_data, 'shape') else 'no shape'}")
        elif isinstance(audio, tuple) and len(audio) == 1:
            # Handle single element tuple
            audio_data = audio[0]
            sr = 44100  # Assume default sample rate
            print(f"\033[92mDEBUG\033[0m: Single element audio tuple, using default sr={sr}")
        elif isinstance(audio, np.ndarray):
            # Handle direct numpy array
            audio_data = audio
            sr = 44100  # Assume default sample rate
            print(f"\033[92mDEBUG\033[0m: Audio is numpy array, using default sr={sr}")
        else:
            return f"Error: Unexpected audio format: {type(audio)}"
        
        # Ensure audio_data is correctly shaped
        if isinstance(audio_data, np.ndarray):
            # Make sure we have a 2D array
            if len(audio_data.shape) == 1:
                audio_data = np.expand_dims(audio_data, axis=0)
                print(f"\033[92mDEBUG\033[0m: Reshaped 1D audio to 2D: {audio_data.shape}")
            
            # If channels are the second dimension, transpose
            if len(audio_data.shape) == 2 and audio_data.shape[0] > audio_data.shape[1]:
                audio_data = np.transpose(audio_data)
                print(f"\033[92mDEBUG\033[0m: Transposed audio shape to: {audio_data.shape}")
        
        # Convert to tensor
        audio_tensor = torch.tensor(audio_data).float()
        print(f"\033[92mDEBUG\033[0m: Audio tensor shape: {audio_tensor.shape}")
        
        # Process audio data and handle None case
        mfcc = process_audio_data(audio_tensor, sr)
        if mfcc is None:
            return "Error: Failed to process audio data. Make sure your audio contains a clear tapping sound."
            
        mfcc = mfcc.to(device)
        print(f"\033[92mDEBUG\033[0m: MFCC shape: {mfcc.shape}")
        
        # Debug image input
        print(f"\033[92mDEBUG\033[0m: Image input type: {type(image)}")
        print(f"\033[92mDEBUG\033[0m: Image shape: {image.shape if hasattr(image, 'shape') else 'No shape'}")
        
        # Process image data and handle None case
        if image is None:
            return "Error: No image provided. Please upload an image."
            
        # Handle different image formats
        if isinstance(image, np.ndarray):
            # Check if image is properly formatted (H, W, C) with 3 channels
            if len(image.shape) == 3 and image.shape[2] == 3:
                # Convert to tensor with shape (C, H, W) as expected by PyTorch
                img = torch.tensor(image).float().permute(2, 0, 1)
                print(f"\033[92mDEBUG\033[0m: Converted image to tensor with shape: {img.shape}")
            elif len(image.shape) == 2:
                # Grayscale image, expand to 3 channels
                img = torch.tensor(image).float().unsqueeze(0).repeat(3, 1, 1)
                print(f"\033[92mDEBUG\033[0m: Converted grayscale image to RGB tensor with shape: {img.shape}")
            else:
                return f"Error: Unexpected image shape: {image.shape}. Expected RGB or grayscale image."
        else:
            return f"Error: Unexpected image format: {type(image)}. Expected numpy array."
        
        # Scale pixel values to [0, 1] if needed
        if img.max() > 1.0:
            img = img / 255.0
            print(f"\033[92mDEBUG\033[0m: Scaled image pixel values to range [0, 1]")
        
        # Get image dimensions and check if they're reasonable
        print(f"\033[92mDEBUG\033[0m: Final image tensor shape before processing: {img.shape}")
        
        # Process image
        try:
            img_processed = process_image_data(img)
            if img_processed is None:
                return "Error: Failed to process image data. Make sure your image clearly shows a watermelon."
                
            img_processed = img_processed.to(device)
            print(f"\033[92mDEBUG\033[0m: Processed image shape: {img_processed.shape}")
        except Exception as e:
            print(f"\033[91mERROR\033[0m: Image processing error: {str(e)}")
            return f"Error in image processing: {str(e)}"
        
        # Run inference
        try:
            # Based on the error, it seems infer() expects file paths, not tensors
            # Let's create temporary files for the processed data
            temp_dir = os.path.join(os.getcwd(), "temp")
            os.makedirs(temp_dir, exist_ok=True)
            
            # Save the audio to a temporary file if infer expects a file path
            temp_audio_path = os.path.join(temp_dir, "temp_audio.wav")
            # A tuple can never be a str, so the isinstance check suffices
            if isinstance(audio, tuple) and len(audio) >= 2:
                # If we have the original audio data and sample rate
                audio_array = audio[-1]
                sr = audio[0]
                
                # Check if the audio array is valid
                if audio_array.size == 0:
                    return "Error: Audio data is empty. Please record a longer audio clip."
                
                # Normalize the layout to (channels, samples) *before* padding;
                # the zero-padding below assumes a 2D array
                if audio_array.ndim == 1:
                    audio_array = np.expand_dims(audio_array, axis=0)
                elif audio_array.shape[0] > audio_array.shape[1]:
                    # Gradio delivers (samples, channels); transpose it
                    audio_array = np.transpose(audio_array)

                # Get the duration of the audio
                duration = audio_array.shape[-1] / sr
                print(f"\033[92mDEBUG\033[0m: Audio duration: {duration:.2f} seconds")

                # Don't reject short clips; pad them up to the minimum duration
                min_duration = 1.0  # minimum 1 second of audio
                if duration < min_duration:
                    print(f"\033[93mWARNING\033[0m: Audio is shorter than {min_duration} seconds. Padding will be applied.")
                    # Calculate samples needed to reach minimum duration
                    samples_needed = int(min_duration * sr) - audio_array.shape[-1]
                    # Pad with zeros along the sample axis
                    padding = np.zeros((audio_array.shape[0], samples_needed), dtype=audio_array.dtype)
                    audio_array = np.concatenate([audio_array, padding], axis=1)
                    print(f"\033[92mDEBUG\033[0m: Padded audio to shape: {audio_array.shape}")
                
                print(f"\033[92mDEBUG\033[0m: Audio array shape before saving: {audio_array.shape}, sr: {sr}")
                
                # torchaudio.save expects float samples in [-1, 1]; Gradio
                # records integer PCM, so rescale when the dtype is integral
                audio_tensor = torch.tensor(audio_array).float()
                if np.issubdtype(audio_array.dtype, np.integer):
                    audio_tensor = audio_tensor / np.iinfo(audio_array.dtype).max
                if audio_tensor.dim() == 1:
                    audio_tensor = audio_tensor.unsqueeze(0)
                
                torchaudio.save(temp_audio_path, audio_tensor, sr)
                print(f"\033[92mDEBUG\033[0m: Saved temporary audio file to {temp_audio_path}")
                
                # Let's also process the audio here to verify it works
                test_mfcc = process_audio_data(audio_tensor, sr)
                if test_mfcc is None:
                    return "Error: Unable to process the audio. Please try recording a different audio sample."
                else:
                    print(f"\033[92mDEBUG\033[0m: Audio pre-check passed. MFCC shape: {test_mfcc.shape}")
                
                audio_path = temp_audio_path
            else:
                # If we don't have a valid path, return an error
                return "Error: Cannot process audio for inference. Invalid audio format."
            
            # Save the image to a temporary file if infer expects a file path
            temp_image_path = os.path.join(temp_dir, "temp_image.jpg")
            if isinstance(image, np.ndarray):
                # cv2.imwrite expects BGR channel order, so convert from RGB
                cv2.imwrite(temp_image_path, cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
                print(f"\033[92mDEBUG\033[0m: Saved temporary image file to {temp_image_path}")
                image_path = temp_image_path
            else:
                # If we don't have a valid image, return an error
                return "Error: Cannot process image for inference. Invalid image format."
            
            # Create a modified version of infer that handles None returns
            def safe_infer(audio_path, image_path, model, device):
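                """Call infer(); if it fails, retry with a manual
                load -> preprocess -> predict pipeline."""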
                try:
                    return infer(audio_path, image_path, model, device)
                except Exception as e:
                    print(f"\033[91mERROR\033[0m: Error in infer function: {str(e)}")
                    # Try a more direct approach
                    try:
                        # Load audio and process
                        audio, sr = torchaudio.load(audio_path)
                        mfcc = process_audio_data(audio, sr)
                        if mfcc is None:
                            raise ValueError("Audio processing failed - MFCC is None")
                        mfcc = mfcc.to(device)
                        
                        # Load image and process (cv2.imread returns None on failure)
                        image = cv2.imread(image_path)
                        if image is None:
                            raise ValueError(f"Could not read image from {image_path}")
                        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                        image_tensor = torch.tensor(image).float().permute(2, 0, 1) / 255.0
                        img_processed = process_image_data(image_tensor)
                        if img_processed is None:
                            raise ValueError("Image processing failed - processed image is None")
                        img_processed = img_processed.to(device)
                        
                        # Run model inference
                        with torch.no_grad():
                            prediction = model(mfcc, img_processed)
                        return prediction
                    except Exception as e2:
                        print(f"\033[91mERROR\033[0m: Fallback inference also failed: {str(e2)}")
                        raise
            
            # Call our safer version
            print(f"\033[92mDEBUG\033[0m: Calling safe_infer with audio_path={audio_path}, image_path={image_path}")
            sweetness = safe_infer(audio_path, image_path, model, device)
            if sweetness is None:
                return "Error: The model was unable to make a prediction. Please try with different inputs."
                
            print(f"\033[92mDEBUG\033[0m: Inference result: {sweetness.item()}")
            return f"Predicted Sweetness: {sweetness.item():.2f}/10"
        except Exception as e:
            print(f"\033[91mERROR\033[0m: Inference failed: {str(e)}")
            print(f"\033[91mTraceback\033[0m: {traceback.format_exc()}")
            return f"Error during inference: {str(e)}"
        
    except Exception as e:
        print(f"\033[91mERROR\033[0m: Prediction failed: {str(e)}")
        print(f"\033[91mTraceback\033[0m: {traceback.format_exc()}")
        return f"Error processing input: {str(e)}"

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Watermelon sweetness predictor")
    parser.add_argument("--model_path", type=str, default="./models/model_15_20250405-033557.pt", help="Path to the trained model")
    args = parser.parse_args()

    
    # Create a wrapper for Gradio that loads the model inside the GPU
    # context; on ZeroGPU Spaces, CUDA is only usable inside functions
    # decorated with @spaces.GPU, which is why loading happens per call
    @spaces.GPU
    def predict(audio, image):
        model = load_model(args.model_path)
        return predict_impl(audio, image, model)

    if use_gpu_decorator:
        print("\033[92mINFO\033[0m: GPU acceleration enabled via @spaces.GPU decorator")

    # Set up Gradio interface
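    # With the default component types, gr.Audio hands the callback a
    # (sample_rate, numpy_data) tuple and gr.Image an RGB numpy array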
    audio_input = gr.Audio(label="Upload or Record Audio")
    image_input = gr.Image(label="Upload or Capture Image")
    output = gr.Textbox(label="Predicted Sweetness")

    interface = gr.Interface(
        fn=predict,
        inputs=[audio_input, image_input],
        outputs=output,
        title="Watermelon Sweetness Predictor",
        description="Upload an audio file and an image to predict the sweetness of a watermelon."
    )

    try:
        interface.launch()  # Launch the interface
    except Exception as e:
        print(f"\033[91mERROR\033[0m: Failed to launch interface: {e}")
        print("\033[93mTIP\033[0m: If you're running in a remote environment or container, try setting additional parameters:")
        print("    interface.launch(server_name='0.0.0.0', share=True)")