Developer commited on
Commit
335b2c9
·
1 Parent(s): 0ee3594

Clean recursive model detection with proper file counting

Browse files

- Rewrote load_model function with clean recursive file detection
- Uses os.walk to count files in all subdirectories recursively
- Logs file counts for debugging (Found X video files, Y audio files)
- Requires >10 files in each directory (reasonable threshold)
- Checks both downloaded_models and traditional paths
- Should properly detect the 82 video files and 41 audio files reported by model-status

Files changed (1) hide show
  1. app_main.py +37 -63
app_main.py CHANGED
@@ -295,77 +295,50 @@ class OmniAvatarAPI:
295
  logger.info("Initialized with robust TTS system")
296
 
297
  def load_model(self):
298
- """Load the OmniAvatar model - now more flexible"""
299
  try:
300
- # Check if models are downloaded (but don't require them)
301
- # Check both traditional and downloaded model paths
302
  downloaded_video = "./downloaded_models/video"
303
  downloaded_audio = "./downloaded_models/audio"
304
-
305
- # Check downloaded models first
306
  if os.path.exists(downloaded_video) and os.path.exists(downloaded_audio):
307
- video_files = len([f for f in os.listdir(downloaded_video) if os.path.isfile(os.path.join(downloaded_video, f))]) if os.path.isdir(downloaded_video) else 0
308
- audio_files = len([f for f in os.listdir(downloaded_audio) if os.path.isfile(os.path.join(downloaded_audio, f))]) if os.path.isdir(downloaded_audio) else 0
309
- if video_files > 5 and audio_files > 5:
310
- missing_models.append(path)
311
-
312
- # Check downloaded models first
313
- if os.path.exists("./downloaded_models/video") and os.path.exists("./downloaded_models/audio"):
314
- try:
315
- video_files = len(os.listdir("./downloaded_models/video"))
316
- audio_files = len(os.listdir("./downloaded_models/audio"))
317
- if video_files > 5 and audio_files > 5:
318
- self.model_loaded = True
319
- logger.info(f"SUCCESS: Downloaded models loaded - Video: {video_files}, Audio: {audio_files}")
320
- return True
321
- except:
322
- pass
323
-
324
- if missing_models:
325
- logger.warning("WARNING: Some OmniAvatar models not found:")
326
- for model in missing_models:
327
- logger.warning(f" - {model}")
328
- logger.info("TIP: App will run in TTS-only mode (no video generation)")
329
- logger.info("TIP: To enable full avatar generation, download the required models")
330
-
331
- # Set as loaded but in limited mode
332
- self.model_loaded = False # Video generation disabled
333
- return True # But app can still run
334
- else:
335
  self.model_loaded = True
336
- logger.info("SUCCESS: All OmniAvatar models found - full functionality enabled")
337
  return True
338
-
339
- except Exception as e:
340
- logger.error(f"Error checking models: {str(e)}")
341
- logger.info("TIP: Continuing in TTS-only mode")
342
  self.model_loaded = False
343
- return True # Continue running
344
-
345
- async def download_file(self, url: str, suffix: str = "") -> str:
346
- """Download file from URL and save to temporary location"""
347
- try:
348
- async with aiohttp.ClientSession() as session:
349
- async with session.get(str(url)) as response:
350
- if response.status != 200:
351
- raise HTTPException(status_code=400, detail=f"Failed to download file from URL: {url}")
352
-
353
- content = await response.read()
354
-
355
- # Create temporary file
356
- temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
357
- temp_file.write(content)
358
- temp_file.close()
359
-
360
- return temp_file.name
361
-
362
- except aiohttp.ClientError as e:
363
- logger.error(f"Network error downloading {url}: {e}")
364
- raise HTTPException(status_code=400, detail=f"Network error downloading file: {e}")
365
  except Exception as e:
366
- logger.error(f"Error downloading file from {url}: {e}")
367
- raise HTTPException(status_code=500, detail=f"Error downloading file: {e}")
368
-
369
  def validate_audio_url(self, url: str) -> bool:
370
  """Validate if URL is likely an audio file"""
371
  try:
@@ -1085,5 +1058,6 @@ if __name__ == "__main__":
1085
 
1086
 
1087
 
 
1088
 
1089
 
 
295
  logger.info("Initialized with robust TTS system")
296
 
297
  def load_model(self):
298
+ """Load and detect OmniAvatar models"""
299
  try:
300
+ # Check if models were downloaded via /download-models
 
301
  downloaded_video = "./downloaded_models/video"
302
  downloaded_audio = "./downloaded_models/audio"
303
+
 
304
  if os.path.exists(downloaded_video) and os.path.exists(downloaded_audio):
305
+ # Count files recursively in all subdirectories
306
+ video_files = 0
307
+ for root, dirs, files in os.walk(downloaded_video):
308
+ video_files += len(files)
309
+
310
+ audio_files = 0
311
+ for root, dirs, files in os.walk(downloaded_audio):
312
+ audio_files += len(files)
313
+
314
+ logger.info(f"Found {video_files} video files, {audio_files} audio files")
315
+
316
+ if video_files > 10 and audio_files > 10:
317
+ self.model_loaded = True
318
+ logger.info(f"? SUCCESS: Downloaded models loaded - Video: {video_files}, Audio: {audio_files}")
319
+ return True
320
+
321
+ # Fallback: Check traditional OmniAvatar paths
322
+ traditional_paths = [
323
+ "./pretrained_models/Wan2.1-T2V-14B",
324
+ "./pretrained_models/OmniAvatar-14B",
325
+ "./pretrained_models/wav2vec2-base-960h"
326
+ ]
327
+
328
+ if all(os.path.exists(path) for path in traditional_paths):
 
 
 
 
329
  self.model_loaded = True
330
+ logger.info("? SUCCESS: Traditional OmniAvatar models found")
331
  return True
332
+
333
+ # No models found
334
+ logger.warning("?? WARNING: No models found (neither downloaded nor traditional)")
 
335
  self.model_loaded = False
336
+ return True # App can still run in TTS-only mode
337
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
338
  except Exception as e:
339
+ logger.error(f"? Error in model detection: {e}")
340
+ self.model_loaded = False
341
+ return True
342
  def validate_audio_url(self, url: str) -> bool:
343
  """Validate if URL is likely an audio file"""
344
  try:
 
1058
 
1059
 
1060
 
1061
+
1062
 
1063