Xalphinions committed (verified)
Commit fb889d2 · Parent(s): 945fdb4

Upload folder using huggingface_hub

Files changed (1):
  1. app.py +148 -288
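
The commit message refers to the huggingface_hub folder-upload API. For reference only — the folder path and repo ID below are placeholders, not values recorded in this commit — such an upload looks roughly like this:

    from huggingface_hub import HfApi

    api = HfApi()  # uses the token from `huggingface-cli login` by default
    api.upload_folder(
        folder_path="./app",                    # hypothetical local folder
        repo_id="Xalphinions/your-space-name",  # hypothetical Space repo ID
        repo_type="space",
        commit_message="Upload folder using huggingface_hub",
    )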
app.py CHANGED
@@ -6,14 +6,14 @@ import gradio as gr
 import torchaudio
 import torchvision
 
-# Import Gradio Spaces GPU decorator
-try:
-    from gradio import spaces
-    HAS_SPACES = True
-    print("\033[92mINFO\033[0m: Gradio Spaces detected, GPU acceleration will be enabled")
-except ImportError:
-    HAS_SPACES = False
-    print("\033[93mWARN\033[0m: gradio.spaces not available, running without GPU optimization")
 
 # Add parent directory to path to import preprocess functions
 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
@@ -78,291 +78,151 @@ def app_process_audio_data(waveform, sample_rate):
 # Similarly for images, but let's import the original one
 from preprocess import process_image_data
 
-# Apply GPU decorator directly to the function if available
-if HAS_SPACES:
     # Using the decorator directly on the function definition
-    @spaces.GPU
-    def predict_sweetness(audio, image, model_path):
-        """Function with GPU acceleration"""
-        try:
-            # Now check CUDA availability inside the GPU-decorated function
-            if torch.cuda.is_available():
-                device = torch.device("cuda")
-                print(f"\033[92mINFO\033[0m: CUDA is available. Using device: {device}")
-            else:
-                device = torch.device("cpu")
-                print(f"\033[92mINFO\033[0m: CUDA is not available. Using device: {device}")
-
-            # Load model inside the function to ensure it's on the correct device
-            model = WatermelonModel().to(device)
-            model.load_state_dict(torch.load(model_path, map_location=device))
-            model.eval()
-            print(f"\033[92mINFO\033[0m: Loaded model from {model_path}")
-
-            # Debug information about input types
-            print(f"\033[92mDEBUG\033[0m: Audio input type: {type(audio)}")
-            print(f"\033[92mDEBUG\033[0m: Audio input shape/length: {len(audio)}")
-            print(f"\033[92mDEBUG\033[0m: Image input type: {type(image)}")
-            if isinstance(image, np.ndarray):
-                print(f"\033[92mDEBUG\033[0m: Image input shape: {image.shape}")
-
-            # Handle different audio input formats
-            if isinstance(audio, tuple) and len(audio) == 2:
-                # Standard Gradio format: (sample_rate, audio_data)
-                sample_rate, audio_data = audio
-                print(f"\033[92mDEBUG\033[0m: Audio sample rate: {sample_rate}")
-                print(f"\033[92mDEBUG\033[0m: Audio data shape: {audio_data.shape}")
-            elif isinstance(audio, tuple) and len(audio) > 2:
-                # Sometimes Gradio returns (sample_rate, audio_data, other_info...)
-                sample_rate, audio_data = audio[0], audio[-1]
-                print(f"\033[92mDEBUG\033[0m: Audio sample rate: {sample_rate}")
-                print(f"\033[92mDEBUG\033[0m: Audio data shape: {audio_data.shape}")
-            elif isinstance(audio, str):
-                # Direct path to audio file
-                audio_data, sample_rate = torchaudio.load(audio)
-                print(f"\033[92mDEBUG\033[0m: Loaded audio from path with shape: {audio_data.shape}")
-            else:
-                return f"Error: Unsupported audio format. Got {type(audio)}"
-
-            # Create a temporary file path for the audio and image
-            temp_dir = "temp"
-            os.makedirs(temp_dir, exist_ok=True)
-
-            temp_audio_path = os.path.join(temp_dir, "temp_audio.wav")
-            temp_image_path = os.path.join(temp_dir, "temp_image.jpg")
-
-            # Import necessary libraries
-            from PIL import Image
-
-            # Audio handling - direct processing from the data in memory
-            if isinstance(audio_data, np.ndarray):
-                # Convert numpy array to tensor
-                print(f"\033[92mDEBUG\033[0m: Converting numpy audio with shape {audio_data.shape} to tensor")
-                audio_tensor = torch.tensor(audio_data).float()
-
-                # Handle different audio dimensions
-                if audio_data.ndim == 1:
-                    # Single channel audio
-                    audio_tensor = audio_tensor.unsqueeze(0)
-                elif audio_data.ndim == 2:
-                    # Ensure channels are first dimension
-                    if audio_data.shape[0] > audio_data.shape[1]:
-                        # More rows than columns, probably (samples, channels)
-                        audio_tensor = torch.tensor(audio_data.T).float()
-            else:
-                # Already a tensor
-                audio_tensor = audio_data.float()
-
-            print(f"\033[92mDEBUG\033[0m: Audio tensor shape before processing: {audio_tensor.shape}")
-
-            # Skip saving/loading and process directly
-            mfcc = app_process_audio_data(audio_tensor, sample_rate)
-            print(f"\033[92mDEBUG\033[0m: MFCC tensor shape after processing: {mfcc.shape if mfcc is not None else None}")
-
-            # Image handling
-            if isinstance(image, np.ndarray):
-                print(f"\033[92mDEBUG\033[0m: Converting numpy image with shape {image.shape} to PIL")
-                pil_image = Image.fromarray(image)
-                pil_image.save(temp_image_path)
-                print(f"\033[92mDEBUG\033[0m: Saved image to {temp_image_path}")
-            elif isinstance(image, str):
-                # If image is already a path
-                temp_image_path = image
-                print(f"\033[92mDEBUG\033[0m: Using provided image path: {temp_image_path}")
-            else:
-                return f"Error: Unsupported image format. Got {type(image)}"
-
-            # Process image
-            print(f"\033[92mDEBUG\033[0m: Loading and preprocessing image from {temp_image_path}")
-            image_tensor = torchvision.io.read_image(temp_image_path)
-            print(f"\033[92mDEBUG\033[0m: Loaded image shape: {image_tensor.shape}")
-            image_tensor = image_tensor.float()
-            processed_image = process_image_data(image_tensor)
-            print(f"\033[92mDEBUG\033[0m: Processed image shape: {processed_image.shape if processed_image is not None else None}")
-
-            # Add batch dimension for inference and move to device
-            if mfcc is not None:
-                mfcc = mfcc.unsqueeze(0).to(device)
-                print(f"\033[92mDEBUG\033[0m: Final MFCC shape with batch dimension: {mfcc.shape}")
-
-            if processed_image is not None:
-                processed_image = processed_image.unsqueeze(0).to(device)
-                print(f"\033[92mDEBUG\033[0m: Final image shape with batch dimension: {processed_image.shape}")
-
-            # Run inference
-            print(f"\033[92mDEBUG\033[0m: Running inference on device: {device}")
-            if mfcc is not None and processed_image is not None:
-                with torch.no_grad():
-                    sweetness = model(mfcc, processed_image)
-                    print(f"\033[92mDEBUG\033[0m: Prediction successful: {sweetness.item()}")
             else:
-                return "Error: Failed to process inputs. Please check the debug logs."
-
-            # Format the result
-            if sweetness is not None:
-                result = f"Predicted Sweetness: {sweetness.item():.2f}/13"

-                # Add a qualitative description
-                if sweetness.item() < 9:
-                    result += "\n\nThis watermelon is not very sweet. You might want to choose another one."
-                elif sweetness.item() < 10:
-                    result += "\n\nThis watermelon has moderate sweetness."
-                elif sweetness.item() < 11:
-                    result += "\n\nThis watermelon is sweet! A good choice."
-                else:
-                    result += "\n\nThis watermelon is very sweet! Excellent choice!"
-
-                return result
-            else:
-                return "Error: Could not predict sweetness. Please try again with different inputs."
-
-        except Exception as e:
-            import traceback
-            error_msg = f"Error: {str(e)}\n\n"
-            error_msg += traceback.format_exc()
-            print(f"\033[91mERR!\033[0m: {error_msg}")
-            return error_msg

     print("\033[92mINFO\033[0m: GPU-accelerated prediction function created with @spaces.GPU decorator")
-else:
-    # Regular version without GPU decorator for non-Spaces environments
-    def predict_sweetness(audio, image, model_path):
-        """Predict sweetness of a watermelon from audio and image input"""
-        try:
-            # Check for device - will be CPU in this case
-            device = torch.device("cpu")
-            print(f"\033[92mINFO\033[0m: Using device: {device}")
-
-            # Load model inside the function
-            model = WatermelonModel().to(device)
-            model.load_state_dict(torch.load(model_path, map_location=device))
-            model.eval()
-            print(f"\033[92mINFO\033[0m: Loaded model from {model_path}")
-
-            # Rest of function identical - processing code
-            # Debug information about input types
-            print(f"\033[92mDEBUG\033[0m: Audio input type: {type(audio)}")
-            print(f"\033[92mDEBUG\033[0m: Audio input shape/length: {len(audio)}")
-            print(f"\033[92mDEBUG\033[0m: Image input type: {type(image)}")
-            if isinstance(image, np.ndarray):
-                print(f"\033[92mDEBUG\033[0m: Image input shape: {image.shape}")
-
-            # Handle different audio input formats
-            if isinstance(audio, tuple) and len(audio) == 2:
-                # Standard Gradio format: (sample_rate, audio_data)
-                sample_rate, audio_data = audio
-                print(f"\033[92mDEBUG\033[0m: Audio sample rate: {sample_rate}")
-                print(f"\033[92mDEBUG\033[0m: Audio data shape: {audio_data.shape}")
-            elif isinstance(audio, tuple) and len(audio) > 2:
-                # Sometimes Gradio returns (sample_rate, audio_data, other_info...)
-                sample_rate, audio_data = audio[0], audio[-1]
-                print(f"\033[92mDEBUG\033[0m: Audio sample rate: {sample_rate}")
-                print(f"\033[92mDEBUG\033[0m: Audio data shape: {audio_data.shape}")
-            elif isinstance(audio, str):
-                # Direct path to audio file
-                audio_data, sample_rate = torchaudio.load(audio)
-                print(f"\033[92mDEBUG\033[0m: Loaded audio from path with shape: {audio_data.shape}")
-            else:
-                return f"Error: Unsupported audio format. Got {type(audio)}"
-
-            # Create a temporary file path for the audio and image
-            temp_dir = "temp"
-            os.makedirs(temp_dir, exist_ok=True)
-
-            temp_audio_path = os.path.join(temp_dir, "temp_audio.wav")
-            temp_image_path = os.path.join(temp_dir, "temp_image.jpg")
-
-            # Import necessary libraries
-            from PIL import Image
-
-            # Audio handling - direct processing from the data in memory
-            if isinstance(audio_data, np.ndarray):
-                # Convert numpy array to tensor
-                print(f"\033[92mDEBUG\033[0m: Converting numpy audio with shape {audio_data.shape} to tensor")
-                audio_tensor = torch.tensor(audio_data).float()
-
-                # Handle different audio dimensions
-                if audio_data.ndim == 1:
-                    # Single channel audio
-                    audio_tensor = audio_tensor.unsqueeze(0)
-                elif audio_data.ndim == 2:
-                    # Ensure channels are first dimension
-                    if audio_data.shape[0] > audio_data.shape[1]:
-                        # More rows than columns, probably (samples, channels)
-                        audio_tensor = torch.tensor(audio_data.T).float()
-            else:
-                # Already a tensor
-                audio_tensor = audio_data.float()
-
-            print(f"\033[92mDEBUG\033[0m: Audio tensor shape before processing: {audio_tensor.shape}")
-
-            # Skip saving/loading and process directly
-            mfcc = app_process_audio_data(audio_tensor, sample_rate)
-            print(f"\033[92mDEBUG\033[0m: MFCC tensor shape after processing: {mfcc.shape if mfcc is not None else None}")
-
-            # Image handling
-            if isinstance(image, np.ndarray):
-                print(f"\033[92mDEBUG\033[0m: Converting numpy image with shape {image.shape} to PIL")
-                pil_image = Image.fromarray(image)
-                pil_image.save(temp_image_path)
-                print(f"\033[92mDEBUG\033[0m: Saved image to {temp_image_path}")
-            elif isinstance(image, str):
-                # If image is already a path
-                temp_image_path = image
-                print(f"\033[92mDEBUG\033[0m: Using provided image path: {temp_image_path}")
-            else:
-                return f"Error: Unsupported image format. Got {type(image)}"
-
-            # Process image
-            print(f"\033[92mDEBUG\033[0m: Loading and preprocessing image from {temp_image_path}")
-            image_tensor = torchvision.io.read_image(temp_image_path)
-            print(f"\033[92mDEBUG\033[0m: Loaded image shape: {image_tensor.shape}")
-            image_tensor = image_tensor.float()
-            processed_image = process_image_data(image_tensor)
-            print(f"\033[92mDEBUG\033[0m: Processed image shape: {processed_image.shape if processed_image is not None else None}")
-
-            # Add batch dimension for inference and move to device
-            if mfcc is not None:
-                mfcc = mfcc.unsqueeze(0).to(device)
-                print(f"\033[92mDEBUG\033[0m: Final MFCC shape with batch dimension: {mfcc.shape}")
-
-            if processed_image is not None:
-                processed_image = processed_image.unsqueeze(0).to(device)
-                print(f"\033[92mDEBUG\033[0m: Final image shape with batch dimension: {processed_image.shape}")
-
-            # Run inference
-            print(f"\033[92mDEBUG\033[0m: Running inference on device: {device}")
-            if mfcc is not None and processed_image is not None:
-                with torch.no_grad():
-                    sweetness = model(mfcc, processed_image)
-                    print(f"\033[92mDEBUG\033[0m: Prediction successful: {sweetness.item()}")
-            else:
-                return "Error: Failed to process inputs. Please check the debug logs."
-
-            # Format the result
-            if sweetness is not None:
-                result = f"Predicted Sweetness: {sweetness.item():.2f}/13"
-
-                # Add a qualitative description
-                if sweetness.item() < 9:
-                    result += "\n\nThis watermelon is not very sweet. You might want to choose another one."
-                elif sweetness.item() < 10:
-                    result += "\n\nThis watermelon has moderate sweetness."
-                elif sweetness.item() < 11:
-                    result += "\n\nThis watermelon is sweet! A good choice."
-                else:
-                    result += "\n\nThis watermelon is very sweet! Excellent choice!"
-
-                return result
-            else:
-                return "Error: Could not predict sweetness. Please try again with different inputs."
-
-        except Exception as e:
-            import traceback
-            error_msg = f"Error: {str(e)}\n\n"
-            error_msg += traceback.format_exc()
-            print(f"\033[91mERR!\033[0m: {error_msg}")
-            return error_msg

 def create_app(model_path):
     """Create and launch the Gradio interface"""
 
 import torchaudio
 import torchvision
 
+# # Import Gradio Spaces GPU decorator
+# try:
+#     from gradio import spaces
+#     HAS_SPACES = True
+#     print("\033[92mINFO\033[0m: Gradio Spaces detected, GPU acceleration will be enabled")
+# except ImportError:
+#     HAS_SPACES = False
+#     print("\033[93mWARN\033[0m: gradio.spaces not available, running without GPU optimization")
 
 # Add parent directory to path to import preprocess functions
 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 
 # Similarly for images, but let's import the original one
 from preprocess import process_image_data
 
 # Using the decorator directly on the function definition
+@spaces.GPU
+def predict_sweetness(audio, image, model_path):
+    """Function with GPU acceleration"""
+    try:
+        # Now check CUDA availability inside the GPU-decorated function
+        if torch.cuda.is_available():
+            device = torch.device("cuda")
+            print(f"\033[92mINFO\033[0m: CUDA is available. Using device: {device}")
+        else:
+            device = torch.device("cpu")
+            print(f"\033[92mINFO\033[0m: CUDA is not available. Using device: {device}")
+
+        # Load model inside the function to ensure it's on the correct device
+        model = WatermelonModel().to(device)
+        model.load_state_dict(torch.load(model_path, map_location=device))
+        model.eval()
+        print(f"\033[92mINFO\033[0m: Loaded model from {model_path}")
+
+        # Debug information about input types
+        print(f"\033[92mDEBUG\033[0m: Audio input type: {type(audio)}")
+        print(f"\033[92mDEBUG\033[0m: Audio input shape/length: {len(audio)}")
+        print(f"\033[92mDEBUG\033[0m: Image input type: {type(image)}")
+        if isinstance(image, np.ndarray):
+            print(f"\033[92mDEBUG\033[0m: Image input shape: {image.shape}")
+
+        # Handle different audio input formats
+        if isinstance(audio, tuple) and len(audio) == 2:
+            # Standard Gradio format: (sample_rate, audio_data)
+            sample_rate, audio_data = audio
+            print(f"\033[92mDEBUG\033[0m: Audio sample rate: {sample_rate}")
+            print(f"\033[92mDEBUG\033[0m: Audio data shape: {audio_data.shape}")
+        elif isinstance(audio, tuple) and len(audio) > 2:
+            # Sometimes Gradio returns (sample_rate, audio_data, other_info...)
+            sample_rate, audio_data = audio[0], audio[-1]
+            print(f"\033[92mDEBUG\033[0m: Audio sample rate: {sample_rate}")
+            print(f"\033[92mDEBUG\033[0m: Audio data shape: {audio_data.shape}")
+        elif isinstance(audio, str):
+            # Direct path to audio file
+            audio_data, sample_rate = torchaudio.load(audio)
+            print(f"\033[92mDEBUG\033[0m: Loaded audio from path with shape: {audio_data.shape}")
+        else:
+            return f"Error: Unsupported audio format. Got {type(audio)}"
+
+        # Create a temporary file path for the audio and image
+        temp_dir = "temp"
+        os.makedirs(temp_dir, exist_ok=True)
+
+        temp_audio_path = os.path.join(temp_dir, "temp_audio.wav")
+        temp_image_path = os.path.join(temp_dir, "temp_image.jpg")
+
+        # Import necessary libraries
+        from PIL import Image
+
+        # Audio handling - direct processing from the data in memory
+        if isinstance(audio_data, np.ndarray):
+            # Convert numpy array to tensor
+            print(f"\033[92mDEBUG\033[0m: Converting numpy audio with shape {audio_data.shape} to tensor")
+            audio_tensor = torch.tensor(audio_data).float()
+
+            # Handle different audio dimensions
+            if audio_data.ndim == 1:
+                # Single channel audio
+                audio_tensor = audio_tensor.unsqueeze(0)
+            elif audio_data.ndim == 2:
+                # Ensure channels are first dimension
+                if audio_data.shape[0] > audio_data.shape[1]:
+                    # More rows than columns, probably (samples, channels)
+                    audio_tensor = torch.tensor(audio_data.T).float()
+        else:
+            # Already a tensor
+            audio_tensor = audio_data.float()
+
+        print(f"\033[92mDEBUG\033[0m: Audio tensor shape before processing: {audio_tensor.shape}")
+
+        # Skip saving/loading and process directly
+        mfcc = app_process_audio_data(audio_tensor, sample_rate)
+        print(f"\033[92mDEBUG\033[0m: MFCC tensor shape after processing: {mfcc.shape if mfcc is not None else None}")
+
+        # Image handling
+        if isinstance(image, np.ndarray):
+            print(f"\033[92mDEBUG\033[0m: Converting numpy image with shape {image.shape} to PIL")
+            pil_image = Image.fromarray(image)
+            pil_image.save(temp_image_path)
+            print(f"\033[92mDEBUG\033[0m: Saved image to {temp_image_path}")
+        elif isinstance(image, str):
+            # If image is already a path
+            temp_image_path = image
+            print(f"\033[92mDEBUG\033[0m: Using provided image path: {temp_image_path}")
+        else:
+            return f"Error: Unsupported image format. Got {type(image)}"
+
+        # Process image
+        print(f"\033[92mDEBUG\033[0m: Loading and preprocessing image from {temp_image_path}")
+        image_tensor = torchvision.io.read_image(temp_image_path)
+        print(f"\033[92mDEBUG\033[0m: Loaded image shape: {image_tensor.shape}")
+        image_tensor = image_tensor.float()
+        processed_image = process_image_data(image_tensor)
+        print(f"\033[92mDEBUG\033[0m: Processed image shape: {processed_image.shape if processed_image is not None else None}")
+
+        # Add batch dimension for inference and move to device
+        if mfcc is not None:
+            mfcc = mfcc.unsqueeze(0).to(device)
+            print(f"\033[92mDEBUG\033[0m: Final MFCC shape with batch dimension: {mfcc.shape}")
+
+        if processed_image is not None:
+            processed_image = processed_image.unsqueeze(0).to(device)
+            print(f"\033[92mDEBUG\033[0m: Final image shape with batch dimension: {processed_image.shape}")
+
+        # Run inference
+        print(f"\033[92mDEBUG\033[0m: Running inference on device: {device}")
+        if mfcc is not None and processed_image is not None:
+            with torch.no_grad():
+                sweetness = model(mfcc, processed_image)
+                print(f"\033[92mDEBUG\033[0m: Prediction successful: {sweetness.item()}")
+        else:
+            return "Error: Failed to process inputs. Please check the debug logs."
+
+        # Format the result
+        if sweetness is not None:
+            result = f"Predicted Sweetness: {sweetness.item():.2f}/13"
+
+            # Add a qualitative description
+            if sweetness.item() < 9:
+                result += "\n\nThis watermelon is not very sweet. You might want to choose another one."
+            elif sweetness.item() < 10:
+                result += "\n\nThis watermelon has moderate sweetness."
+            elif sweetness.item() < 11:
+                result += "\n\nThis watermelon is sweet! A good choice."
             else:
+                result += "\n\nThis watermelon is very sweet! Excellent choice!"

+            return result
+        else:
+            return "Error: Could not predict sweetness. Please try again with different inputs."
+
+    except Exception as e:
+        import traceback
+        error_msg = f"Error: {str(e)}\n\n"
+        error_msg += traceback.format_exc()
+        print(f"\033[91mERR!\033[0m: {error_msg}")
+        return error_msg

 print("\033[92mINFO\033[0m: GPU-accelerated prediction function created with @spaces.GPU decorator")
+

 def create_app(model_path):
     """Create and launch the Gradio interface"""