Spaces:
Running
Running
Developer
commited on
Commit
·
335b2c9
1
Parent(s):
0ee3594
Clean recursive model detection with proper file counting
Browse files- Rewrote load_model function with clean recursive file detection
- Uses os.walk to count files in all subdirectories recursively
- Logs file counts for debugging (Found X video files, Y audio files)
- Requires >10 files in each directory (reasonable threshold)
- Checks both downloaded_models and traditional paths
- Should properly detect the 82 video files and 41 audio files reported by model-status
- app_main.py +37 -63
app_main.py
CHANGED
@@ -295,77 +295,50 @@ class OmniAvatarAPI:
|
|
295 |
logger.info("Initialized with robust TTS system")
|
296 |
|
297 |
def load_model(self):
|
298 |
-
"""Load
|
299 |
try:
|
300 |
-
# Check if models
|
301 |
-
# Check both traditional and downloaded model paths
|
302 |
downloaded_video = "./downloaded_models/video"
|
303 |
downloaded_audio = "./downloaded_models/audio"
|
304 |
-
|
305 |
-
# Check downloaded models first
|
306 |
if os.path.exists(downloaded_video) and os.path.exists(downloaded_audio):
|
307 |
-
|
308 |
-
|
309 |
-
|
310 |
-
|
311 |
-
|
312 |
-
|
313 |
-
|
314 |
-
|
315 |
-
|
316 |
-
|
317 |
-
|
318 |
-
|
319 |
-
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
|
325 |
-
|
326 |
-
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
-
|
331 |
-
# Set as loaded but in limited mode
|
332 |
-
self.model_loaded = False # Video generation disabled
|
333 |
-
return True # But app can still run
|
334 |
-
else:
|
335 |
self.model_loaded = True
|
336 |
-
logger.info("SUCCESS:
|
337 |
return True
|
338 |
-
|
339 |
-
|
340 |
-
logger.
|
341 |
-
logger.info("TIP: Continuing in TTS-only mode")
|
342 |
self.model_loaded = False
|
343 |
-
return True #
|
344 |
-
|
345 |
-
async def download_file(self, url: str, suffix: str = "") -> str:
|
346 |
-
"""Download file from URL and save to temporary location"""
|
347 |
-
try:
|
348 |
-
async with aiohttp.ClientSession() as session:
|
349 |
-
async with session.get(str(url)) as response:
|
350 |
-
if response.status != 200:
|
351 |
-
raise HTTPException(status_code=400, detail=f"Failed to download file from URL: {url}")
|
352 |
-
|
353 |
-
content = await response.read()
|
354 |
-
|
355 |
-
# Create temporary file
|
356 |
-
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
|
357 |
-
temp_file.write(content)
|
358 |
-
temp_file.close()
|
359 |
-
|
360 |
-
return temp_file.name
|
361 |
-
|
362 |
-
except aiohttp.ClientError as e:
|
363 |
-
logger.error(f"Network error downloading {url}: {e}")
|
364 |
-
raise HTTPException(status_code=400, detail=f"Network error downloading file: {e}")
|
365 |
except Exception as e:
|
366 |
-
logger.error(f"Error
|
367 |
-
|
368 |
-
|
369 |
def validate_audio_url(self, url: str) -> bool:
|
370 |
"""Validate if URL is likely an audio file"""
|
371 |
try:
|
@@ -1085,5 +1058,6 @@ if __name__ == "__main__":
|
|
1085 |
|
1086 |
|
1087 |
|
|
|
1088 |
|
1089 |
|
|
|
295 |
logger.info("Initialized with robust TTS system")
|
296 |
|
297 |
def load_model(self):
|
298 |
+
"""Load and detect OmniAvatar models"""
|
299 |
try:
|
300 |
+
# Check if models were downloaded via /download-models
|
|
|
301 |
downloaded_video = "./downloaded_models/video"
|
302 |
downloaded_audio = "./downloaded_models/audio"
|
303 |
+
|
|
|
304 |
if os.path.exists(downloaded_video) and os.path.exists(downloaded_audio):
|
305 |
+
# Count files recursively in all subdirectories
|
306 |
+
video_files = 0
|
307 |
+
for root, dirs, files in os.walk(downloaded_video):
|
308 |
+
video_files += len(files)
|
309 |
+
|
310 |
+
audio_files = 0
|
311 |
+
for root, dirs, files in os.walk(downloaded_audio):
|
312 |
+
audio_files += len(files)
|
313 |
+
|
314 |
+
logger.info(f"Found {video_files} video files, {audio_files} audio files")
|
315 |
+
|
316 |
+
if video_files > 10 and audio_files > 10:
|
317 |
+
self.model_loaded = True
|
318 |
+
logger.info(f"? SUCCESS: Downloaded models loaded - Video: {video_files}, Audio: {audio_files}")
|
319 |
+
return True
|
320 |
+
|
321 |
+
# Fallback: Check traditional OmniAvatar paths
|
322 |
+
traditional_paths = [
|
323 |
+
"./pretrained_models/Wan2.1-T2V-14B",
|
324 |
+
"./pretrained_models/OmniAvatar-14B",
|
325 |
+
"./pretrained_models/wav2vec2-base-960h"
|
326 |
+
]
|
327 |
+
|
328 |
+
if all(os.path.exists(path) for path in traditional_paths):
|
|
|
|
|
|
|
|
|
329 |
self.model_loaded = True
|
330 |
+
logger.info("? SUCCESS: Traditional OmniAvatar models found")
|
331 |
return True
|
332 |
+
|
333 |
+
# No models found
|
334 |
+
logger.warning("?? WARNING: No models found (neither downloaded nor traditional)")
|
|
|
335 |
self.model_loaded = False
|
336 |
+
return True # App can still run in TTS-only mode
|
337 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
338 |
except Exception as e:
|
339 |
+
logger.error(f"? Error in model detection: {e}")
|
340 |
+
self.model_loaded = False
|
341 |
+
return True
|
342 |
def validate_audio_url(self, url: str) -> bool:
|
343 |
"""Validate if URL is likely an audio file"""
|
344 |
try:
|
|
|
1058 |
|
1059 |
|
1060 |
|
1061 |
+
|
1062 |
|
1063 |
|