"""FastAPI + Gradio service wrapping OmniAvatar-14B avatar-video inference.

Audio comes either from text (ElevenLabs TTS, with a local fallback client) or
from a caller-supplied audio URL; an optional reference image URL may be given.
Inference runs ``scripts/inference.py`` in a subprocess and the resulting video
is served from ``/outputs``.
"""
import asyncio
import io
import json
import logging
import os
import subprocess
import sys
import tempfile
import threading
import time
from pathlib import Path
from typing import Optional
from urllib.parse import urlparse

import aiohttp
import gradio as gr
import requests
import torch
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from PIL import Image
from pydantic import BaseModel, HttpUrl

# Load environment variables (e.g. ELEVENLABS_API_KEY) from a .env file
load_dotenv()

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Fallback TTS client. Kept after load_dotenv() and the logging setup, as in the
# original layout, in case the module reads env vars or logging config on import.
from robust_tts_client import RobustTTSClient  # noqa: E402

DEFAULT_VOICE_ID = "21m00Tcm4TlvDq8ikWAM"
OUTPUT_DIR = "./outputs"
# Public host used to build video URLs; defaults to the Hugging Face Space.
PUBLIC_BASE_URL = os.getenv(
    "PUBLIC_BASE_URL", "https://bravedims-ai-avatar-chat.hf.space"
).rstrip("/")
MODEL_PATHS = (
    "./pretrained_models/Wan2.1-T2V-14B",
    "./pretrained_models/OmniAvatar-14B",
    "./pretrained_models/wav2vec2-base-960h",
)
VIDEO_EXTENSIONS = (".mp4", ".avi")
AUDIO_EXTENSIONS = (".mp3", ".wav", ".m4a", ".ogg", ".aac")
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".webp", ".bmp", ".gif")

app = FastAPI(title="OmniAvatar-14B API with ElevenLabs", version="1.0.0")

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# StaticFiles refuses to start when its directory is missing, so create it first.
os.makedirs(OUTPUT_DIR, exist_ok=True)
# Mount static files for serving generated videos
app.mount("/outputs", StaticFiles(directory=OUTPUT_DIR), name="outputs")


def get_video_url(output_path: str) -> str:
    """Map a generated video's local path to its public ``/outputs`` URL.

    Only the file name is kept because ``OUTPUT_DIR`` is mounted at ``/outputs``.
    The host is ``PUBLIC_BASE_URL`` (overridable via the environment).
    """
    filename = Path(output_path).name
    video_url = f"{PUBLIC_BASE_URL}/outputs/{filename}"
    logger.info(f"Generated video URL: {video_url}")
    return video_url


def _remove_quietly(path: Optional[str]) -> None:
    """Best-effort delete of a temporary file; failures are logged, never raised."""
    if not path:
        return
    try:
        os.unlink(path)
    except FileNotFoundError:
        pass
    except OSError as e:
        logger.warning(f"Could not remove temporary file {path}: {e}")


def _snapshot_videos(directory: str) -> dict[str, float]:
    """Return ``{file name: mtime}`` for video files directly inside *directory*."""
    if not os.path.isdir(directory):
        return {}
    with os.scandir(directory) as entries:
        return {
            entry.name: entry.stat().st_mtime
            for entry in entries
            if entry.name.endswith(VIDEO_EXTENSIONS)
        }


# Pydantic models for request/response
class GenerateRequest(BaseModel):
    prompt: str
    text_to_speech: Optional[str] = None  # Text to convert to speech
    elevenlabs_audio_url: Optional[HttpUrl] = None  # Direct audio URL
    voice_id: Optional[str] = DEFAULT_VOICE_ID  # Default ElevenLabs voice
    image_url: Optional[HttpUrl] = None
    guidance_scale: float = 5.0
    audio_scale: float = 3.0
    num_steps: int = 30
    sp_size: int = 1  # passed to torch.distributed.run as --nproc_per_node
    tea_cache_l1_thresh: Optional[float] = None


class GenerateResponse(BaseModel):
    message: str
    output_path: str
    processing_time: float
    audio_generated: bool = False


class ElevenLabsClient:
    """ElevenLabs text-to-speech client with a ``RobustTTSClient`` fallback."""

    def __init__(self, api_key: Optional[str] = None):
        # Credentials come only from the argument or the environment; never
        # embed an API key in source code.
        self.api_key = api_key or os.getenv("ELEVENLABS_API_KEY")
        self.base_url = "https://api.elevenlabs.io/v1"
        # Initialize fallback TTS client
        self.fallback_tts = RobustTTSClient()

    async def text_to_speech(self, text: str, voice_id: str = DEFAULT_VOICE_ID) -> str:
        """Convert *text* to speech and return a path to the generated audio file.

        ElevenLabs is tried first; on any failure the fallback client is used.

        Raises:
            HTTPException: 500 if both ElevenLabs and the fallback fail.
        """
        logger.info(f"Generating speech from text: {text[:50]}...")
        logger.info(f"Voice ID: {voice_id}")

        try:
            return await self._elevenlabs_tts(text, voice_id)
        except Exception as e:
            logger.warning(f"ElevenLabs TTS failed: {e}")
            logger.info("Falling back to robust TTS client...")
            try:
                return await self.fallback_tts.text_to_speech(text, voice_id)
            except Exception as fallback_error:
                logger.error(f"Fallback TTS also failed: {fallback_error}")
                raise HTTPException(
                    status_code=500,
                    detail=f"All TTS methods failed. ElevenLabs: {e}, Fallback: {fallback_error}",
                ) from fallback_error

    async def _elevenlabs_tts(self, text: str, voice_id: str) -> str:
        """Call the ElevenLabs API and save the returned audio to a temporary .mp3 file.

        Raises:
            RuntimeError: missing API key, non-200 response, or empty audio body.
        """
        if not self.api_key:
            # Fail fast so text_to_speech() moves straight to the fallback client.
            raise RuntimeError("ElevenLabs API key not configured (set ELEVENLABS_API_KEY)")

        url = f"{self.base_url}/text-to-speech/{voice_id}"
        headers = {
            "Accept": "audio/mpeg",
            "Content-Type": "application/json",
            "xi-api-key": self.api_key,
        }
        data = {
            "text": text,
            "model_id": "eleven_monolingual_v1",
            "voice_settings": {
                "stability": 0.5,
                "similarity_boost": 0.5,
            },
        }

        logger.info(f"Calling ElevenLabs API: {url}")

        timeout = aiohttp.ClientTimeout(total=30)  # 30 second timeout
        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.post(url, headers=headers, json=data) as response:
                logger.info(f"ElevenLabs response status: {response.status}")

                if response.status != 200:
                    error_text = await response.text()
                    logger.error(f"ElevenLabs API error: {response.status} - {error_text}")
                    if response.status == 401:
                        raise RuntimeError("ElevenLabs authentication failed. Please check API key.")
                    if response.status == 429:
                        raise RuntimeError("ElevenLabs rate limit exceeded. Please try again later.")
                    if response.status == 422:
                        raise RuntimeError(f"ElevenLabs request validation failed: {error_text}")
                    raise RuntimeError(f"ElevenLabs API error: {response.status} - {error_text}")

                audio_content = await response.read()

        if not audio_content:
            raise RuntimeError("ElevenLabs returned empty audio content")

        logger.info(f"Received {len(audio_content)} bytes of audio from ElevenLabs")

        # delete=False: the caller owns the file and removes it after inference.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as temp_file:
            temp_file.write(audio_content)

        logger.info(f"Generated speech audio: {temp_file.name}")
        return temp_file.name


class OmniAvatarAPI:
    """Prepares inputs, runs OmniAvatar inference, and locates the output video."""

    def __init__(self):
        self.model_loaded = False
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.elevenlabs_client = ElevenLabsClient()
        # Serializes inference runs started by this process (FastAPI and Gradio
        # alike). A threading.Lock works regardless of which event loop awaits it.
        self._inference_lock = threading.Lock()
        logger.info(f"Using device: {self.device}")
        logger.info(f"ElevenLabs API Key configured: {'Yes' if self.elevenlabs_client.api_key else 'No'}")

    def load_model(self) -> bool:
        """Check that all pretrained model directories exist.

        Only presence is checked; weights are loaded by the inference subprocess.

        Returns:
            True (and sets ``model_loaded``) if every path exists, else False.
        """
        missing = [path for path in MODEL_PATHS if not os.path.exists(path)]
        for path in missing:
            logger.error(f"Model path not found: {path}")
        if missing:
            return False

        self.model_loaded = True
        logger.info("Models loaded successfully")
        return True

    async def download_file(self, url: str, suffix: str = "") -> str:
        """Download *url* into a temporary file (with *suffix*) and return its path.

        Raises:
            HTTPException: 400 for a non-200 response or a network error,
            500 for any other failure.
        """
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(str(url)) as response:
                    if response.status != 200:
                        raise HTTPException(status_code=400, detail=f"Failed to download file from URL: {url}")
                    content = await response.read()

            with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
                temp_file.write(content)
            return temp_file.name

        except HTTPException:
            # Keep the intended 400 instead of re-wrapping it as a 500 below.
            raise
        except aiohttp.ClientError as e:
            logger.error(f"Network error downloading {url}: {e}")
            raise HTTPException(status_code=400, detail=f"Network error downloading file: {e}") from e
        except Exception as e:
            logger.error(f"Error downloading file from {url}: {e}")
            raise HTTPException(status_code=500, detail=f"Error downloading file: {e}") from e

    def validate_audio_url(self, url: str) -> bool:
        """Heuristically check whether *url* points to audio.

        True if the path has a common audio extension, the host contains
        "elevenlabs", or the URL contains "audio" anywhere.
        """
        try:
            parsed = urlparse(url)
        except ValueError:
            return False
        is_audio_ext = parsed.path.lower().endswith(AUDIO_EXTENSIONS)
        is_elevenlabs = "elevenlabs" in parsed.netloc.lower()
        return is_audio_ext or is_elevenlabs or "audio" in url.lower()

    def validate_image_url(self, url: str) -> bool:
        """Heuristically check whether *url*'s path ends in a common image extension."""
        try:
            parsed = urlparse(url)
        except ValueError:
            return False
        return parsed.path.lower().endswith(IMAGE_EXTENSIONS)

    async def _prepare_audio(self, request: GenerateRequest) -> tuple[str, bool]:
        """Return ``(audio_path, audio_generated)`` from TTS text or an audio URL.

        Raises:
            HTTPException: 400 when neither audio source is provided.
        """
        if request.text_to_speech:
            logger.info(f"Generating speech from text: {request.text_to_speech[:50]}...")
            audio_path = await self.elevenlabs_client.text_to_speech(
                request.text_to_speech,
                request.voice_id or DEFAULT_VOICE_ID,
            )
            return audio_path, True

        if request.elevenlabs_audio_url:
            audio_url = str(request.elevenlabs_audio_url)
            logger.info(f"Downloading audio from URL: {request.elevenlabs_audio_url}")
            if not self.validate_audio_url(audio_url):
                logger.warning(f"Audio URL may not be valid: {request.elevenlabs_audio_url}")
            return await self.download_file(audio_url, ".mp3"), False

        raise HTTPException(
            status_code=400,
            detail="Either text_to_speech or elevenlabs_audio_url must be provided",
        )

    async def _prepare_image(self, request: GenerateRequest) -> Optional[str]:
        """Download the optional reference image; return its path or None."""
        if not request.image_url:
            return None
        image_url = str(request.image_url)
        logger.info(f"Downloading image from URL: {request.image_url}")
        if not self.validate_image_url(image_url):
            logger.warning(f"Image URL may not be valid: {request.image_url}")
        # Keep the URL's extension so downstream tools can infer the format.
        ext = os.path.splitext(urlparse(image_url).path)[1] or ".jpg"
        return await self.download_file(image_url, ext)

    @staticmethod
    def _write_input_file(prompt: str, image_path: Optional[str], audio_path: str) -> str:
        """Write the ``prompt@@image@@audio`` line read by scripts/inference.py; return its path."""
        with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as f:
            # An absent image leaves its field empty ("prompt@@@@audio").
            f.write(f"{prompt}@@{image_path or ''}@@{audio_path}")
            return f.name

    @staticmethod
    def _build_inference_command(request: GenerateRequest, input_path: str) -> list[str]:
        """Build the argv for the distributed inference run (no shell involved)."""
        cmd = [
            # The current interpreter, so the subprocess sees the same environment.
            sys.executable, "-m", "torch.distributed.run",
            "--standalone", f"--nproc_per_node={request.sp_size}",
            "scripts/inference.py",
            "--config", "configs/inference.yaml",
            "--input_file", input_path,
            "--guidance_scale", str(request.guidance_scale),
            "--audio_scale", str(request.audio_scale),
            "--num_steps", str(request.num_steps),
        ]
        # `is not None` so an explicit 0.0 threshold is still forwarded.
        if request.tea_cache_l1_thresh is not None:
            cmd.extend(["--tea_cache_l1_thresh", str(request.tea_cache_l1_thresh)])
        return cmd

    def _run_inference_and_find_output(self, cmd: list[str]) -> str:
        """Run inference (blocking) and return the path of the video it produced.

        Holds ``self._inference_lock`` so runs started by this service do not
        overlap, and so the "new or modified file" comparison below can only
        see this run's output.

        Raises:
            RuntimeError: non-zero exit status, or no new video in OUTPUT_DIR.
        """
        with self._inference_lock:
            before = _snapshot_videos(OUTPUT_DIR)
            result = subprocess.run(cmd, capture_output=True, text=True)
            if result.returncode != 0:
                logger.error(f"Inference failed: {result.stderr}")
                raise RuntimeError(f"Inference failed: {result.stderr}")
            after = _snapshot_videos(OUTPUT_DIR)

        # Only accept files this run created or rewrote, never a stale video.
        fresh = [name for name, mtime in after.items() if before.get(name) != mtime]
        if not fresh:
            raise RuntimeError("No output video generated")
        newest = max(fresh, key=after.__getitem__)
        return os.path.join(OUTPUT_DIR, newest)

    async def generate_avatar(self, request: GenerateRequest) -> tuple[str, float, bool]:
        """Generate an avatar video from a prompt plus TTS text or an audio URL.

        Returns:
            ``(output_path, processing_time_seconds, audio_generated)``, where
            ``audio_generated`` is True when the audio came from text-to-speech.

        Raises:
            HTTPException: HTTPExceptions raised while preparing inputs (e.g. the
            400 for a missing audio source) propagate unchanged; any other
            failure becomes a 500.
        """
        start_time = time.time()
        audio_path: Optional[str] = None
        image_path: Optional[str] = None
        input_path: Optional[str] = None

        try:
            audio_path, audio_generated = await self._prepare_audio(request)
            image_path = await self._prepare_image(request)
            input_path = self._write_input_file(request.prompt, image_path, audio_path)

            cmd = self._build_inference_command(request, input_path)
            logger.info(f"Running inference with command: {' '.join(cmd)}")

            # The subprocess blocks for the whole run; do it in a worker thread
            # so the event loop keeps serving other requests meanwhile.
            output_path = await asyncio.to_thread(self._run_inference_and_find_output, cmd)

            processing_time = time.time() - start_time
            return output_path, processing_time, audio_generated

        except HTTPException as e:
            logger.error(f"Generation error: {e.detail}")
            raise
        except Exception as e:
            logger.error(f"Generation error: {str(e)}")
            raise HTTPException(status_code=500, detail=str(e)) from e
        finally:
            # Temporary inputs are not needed after inference, success or failure.
            for path in (input_path, audio_path, image_path):
                _remove_quietly(path)


# Initialize API
omni_api = OmniAvatarAPI()


@app.on_event("startup")
async def startup_event():
    """Load model on startup"""
    success = omni_api.load_model()
    if not success:
        logger.warning("Model loading failed on startup")


@app.get("/health")
async def health_check():
    """Health check endpoint"""
    return {
        "status": "healthy",
        "model_loaded": omni_api.model_loaded,
        "device": omni_api.device,
        "supports_elevenlabs": True,
        "supports_image_urls": True,
        "supports_text_to_speech": True,
        "elevenlabs_api_configured": bool(omni_api.elevenlabs_client.api_key),
        "fallback_tts_available": True,
    }


@app.post("/generate", response_model=GenerateResponse)
async def generate_avatar(request: GenerateRequest):
    """Generate avatar video from prompt, text/audio, and optional image URL"""
    if not omni_api.model_loaded:
        raise HTTPException(status_code=503, detail="Model not loaded")

    logger.info(f"Generating avatar with prompt: {request.prompt}")
    if request.text_to_speech:
        logger.info(f"Text to speech: {request.text_to_speech[:100]}...")
        logger.info(f"Voice ID: {request.voice_id}")
    if request.elevenlabs_audio_url:
        logger.info(f"Audio URL: {request.elevenlabs_audio_url}")
    if request.image_url:
        logger.info(f"Image URL: {request.image_url}")

    try:
        output_path, processing_time, audio_generated = await omni_api.generate_avatar(request)
        return GenerateResponse(
            message="Avatar generation completed successfully",
            output_path=get_video_url(output_path),
            processing_time=processing_time,
            audio_generated=audio_generated,
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Unexpected error: {e}")
        raise HTTPException(status_code=500, detail=f"Unexpected error: {e}") from e


# Enhanced Gradio interface with text-to-speech option
def gradio_generate(prompt, text_to_speech, audio_url, image_url, voice_id, guidance_scale, audio_scale, num_steps):
    """Gradio wrapper around ``omni_api.generate_avatar``.

    Returns the local path of the generated video for the ``gr.Video`` output.
    Failures are raised as ``gr.Error`` so Gradio shows them to the user; an
    error string would be misread by the Video component as a file path.
    """
    if not omni_api.model_loaded:
        raise gr.Error("Error: Model not loaded")

    request_data = {
        "prompt": prompt,
        "guidance_scale": guidance_scale,
        "audio_scale": audio_scale,
        "num_steps": int(num_steps),
    }

    # Add audio source
    if text_to_speech and text_to_speech.strip():
        request_data["text_to_speech"] = text_to_speech
        request_data["voice_id"] = voice_id or DEFAULT_VOICE_ID
    elif audio_url and audio_url.strip():
        request_data["elevenlabs_audio_url"] = audio_url
    else:
        raise gr.Error("Error: Please provide either text to speech or audio URL")

    if image_url and image_url.strip():
        request_data["image_url"] = image_url

    try:
        request = GenerateRequest(**request_data)
        # Gradio runs synchronous handlers in a worker thread with no running
        # event loop; asyncio.run() drives the coroutine and always closes its loop.
        output_path, _processing_time, _audio_generated = asyncio.run(
            omni_api.generate_avatar(request)
        )
    except HTTPException as e:
        logger.error(f"Gradio generation error: {e.detail}")
        raise gr.Error(f"Error: {e.detail}") from e
    except Exception as e:
        logger.error(f"Gradio generation error: {e}")
        raise gr.Error(f"Error: {str(e)}") from e

    return output_path


GRADIO_DESCRIPTION = """
Generate avatar videos with lip-sync from text prompts and speech.

**Features:**
- ✅ **Text-to-Speech**: Enter text to generate speech automatically
- ✅ **ElevenLabs Integration**: High-quality voice synthesis
- ✅ **Fallback TTS**: Robust backup system if ElevenLabs fails
- ✅ **Audio URL Support**: Use pre-generated audio files
- ✅ **Image URL Support**: Reference images for character appearance
- ✅ **Customizable Parameters**: Fine-tune generation quality

**Usage:**
1. Enter a character description in the prompt
2. **Either** enter text for speech generation **OR** provide an audio URL
3. Optionally add a reference image URL
4. Choose voice and adjust parameters
5. Generate your avatar video!

**Tips:**
- Use guidance scale 4-6 for best prompt following
- Increase audio scale for better lip-sync
- Clear, descriptive prompts work best
- If ElevenLabs fails, fallback TTS will be used automatically
"""

# Updated Gradio interface with text-to-speech support
iface = gr.Interface(
    fn=gradio_generate,
    inputs=[
        gr.Textbox(
            label="Prompt",
            placeholder="Describe the character behavior (e.g., 'A friendly person explaining a concept')",
            lines=2,
        ),
        gr.Textbox(
            label="Text to Speech",
            placeholder="Enter text to convert to speech using ElevenLabs",
            lines=3,
            info="This will be converted to speech automatically",
        ),
        gr.Textbox(
            label="OR Audio URL",
            placeholder="https://api.elevenlabs.io/v1/text-to-speech/...",
            info="Direct URL to audio file (alternative to text-to-speech)",
        ),
        gr.Textbox(
            label="Image URL (Optional)",
            placeholder="https://example.com/image.jpg",
            info="Direct URL to reference image (JPG, PNG, etc.)",
        ),
        gr.Dropdown(
            choices=[DEFAULT_VOICE_ID, "pNInz6obpgDQGcFmaJgB", "EXAVITQu4vr4xnSDxMaL"],
            value=DEFAULT_VOICE_ID,
            label="ElevenLabs Voice ID",
            info="Choose voice for text-to-speech",
        ),
        gr.Slider(minimum=1, maximum=10, value=5.0, label="Guidance Scale", info="4-6 recommended"),
        gr.Slider(minimum=1, maximum=10, value=3.0, label="Audio Scale", info="Higher values = better lip-sync"),
        gr.Slider(minimum=10, maximum=100, value=30, step=1, label="Number of Steps", info="20-50 recommended"),
    ],
    outputs=gr.Video(label="Generated Avatar Video"),
    title="🎭 OmniAvatar-14B with ElevenLabs TTS (+ Fallback)",
    description=GRADIO_DESCRIPTION,
    examples=[
        [
            "A professional teacher explaining a mathematical concept with clear gestures",
            "Hello students! Today we're going to learn about calculus and how derivatives work in real life.",
            "",
            "",
            DEFAULT_VOICE_ID,
            5.0,
            3.5,
            30,
        ],
        [
            "A friendly presenter speaking confidently to an audience",
            "Welcome everyone to our presentation on artificial intelligence and its applications!",
            "",
            "",
            "pNInz6obpgDQGcFmaJgB",
            5.5,
            4.0,
            35,
        ],
    ],
)

# Mount Gradio app
app = gr.mount_gradio_app(app, iface, path="/gradio")

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)