bravedims commited on
Commit
bd1f2b1
·
1 Parent(s): 763d982

Deploy OmniAvatar-14B with ElevenLabs TTS integration to Hugging Face Spaces

Browse files
Dockerfile ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Read the doc: https://huggingface.co/docs/hub/spaces-sdks-docker
2
+ # Use NVIDIA PyTorch base image for GPU support
3
+ FROM pytorch/pytorch:2.1.0-cuda12.1-cudnn8-devel
4
+
5
+ # Create user as required by HF Spaces
6
+ RUN useradd -m -u 1000 user
7
+
8
+ # Install system dependencies
9
+ RUN apt-get update && apt-get install -y \
10
+ git \
11
+ wget \
12
+ curl \
13
+ libgl1-mesa-glx \
14
+ libglib2.0-0 \
15
+ libsm6 \
16
+ libxext6 \
17
+ libxrender-dev \
18
+ libgomp1 \
19
+ libgoogle-perftools4 \
20
+ libtcmalloc-minimal4 \
21
+ ffmpeg \
22
+ && apt-get clean \
23
+ && rm -rf /var/lib/apt/lists/*
24
+
25
+ # Switch to user
26
+ USER user
27
+
28
+ # Set environment variables for user
29
+ ENV PATH="/home/user/.local/bin:$PATH"
30
+ ENV PYTHONPATH=/app
31
+ ENV GRADIO_SERVER_NAME=0.0.0.0
32
+ ENV GRADIO_SERVER_PORT=7860
33
+
34
+ # Set working directory
35
+ WORKDIR /app
36
+
37
+ # Copy requirements and install Python dependencies
38
+ COPY --chown=user ./requirements.txt requirements.txt
39
+ RUN pip install --no-cache-dir --upgrade -r requirements.txt
40
+
41
+ # Copy application code
42
+ COPY --chown=user . /app
43
+
44
+ # Create necessary directories
45
+ RUN mkdir -p pretrained_models outputs
46
+
47
+ # Expose port (required by HF Spaces to be 7860)
48
+ EXPOSE 7860
49
+
50
+ # Start the application
51
+ CMD ["python", "app.py"]
README.md CHANGED
@@ -1,11 +1,71 @@
1
- ---
2
  title: AI Avatar Chat
3
- emoji: 🐢
4
- colorFrom: green
5
- colorTo: yellow
6
  sdk: docker
7
  pinned: false
8
  license: apache-2.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  ---
10
 
11
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
+ ---
2
  title: AI Avatar Chat
3
+ emoji: 🎭
4
+ colorFrom: purple
5
+ colorTo: pink
6
  sdk: docker
7
  pinned: false
8
  license: apache-2.0
9
+ suggested_hardware: t4-medium
10
+ suggested_storage: medium
11
+ ---
12
+
13
+ # 🎭 OmniAvatar-14B with ElevenLabs TTS
14
+
15
+ An advanced AI avatar generation system that creates realistic talking avatars from text prompts and speech. This space combines the power of OmniAvatar-14B with ElevenLabs text-to-speech for seamless avatar creation.
16
+
17
+ ## ✨ Features
18
+
19
+ - **🎯 Text-to-Avatar Generation**: Generate avatars from descriptive text prompts
20
+ - **🗣️ ElevenLabs Integration**: High-quality text-to-speech synthesis
21
+ - **🎵 Audio URL Support**: Use pre-generated audio files
22
+ - **🖼️ Image Reference Support**: Guide avatar appearance with reference images
23
+ - **⚡ Real-time Processing**: Fast generation with GPU acceleration
24
+ - **🎨 Customizable Parameters**: Fine-tune generation quality and lip-sync
25
+
26
+ ## 🚀 How to Use
27
+
28
+ 1. **Enter a Prompt**: Describe the character's behavior and appearance
29
+ 2. **Choose Audio Source**:
30
+ - Enter text for automatic speech generation
31
+ - OR provide a direct audio URL
32
+ 3. **Optional**: Add a reference image URL
33
+ 4. **Customize**: Adjust voice, guidance scale, and generation parameters
34
+ 5. **Generate**: Create your avatar video!
35
+
36
+ ## 🛠️ Parameters
37
+
38
+ - **Guidance Scale** (4-6 recommended): Controls how closely the model follows your prompt
39
+ - **Audio Scale** (3-5 recommended): Higher values improve lip-sync accuracy
40
+ - **Number of Steps** (20-50 recommended): More steps = higher quality, longer processing time
41
+
42
+ ## 📝 Example Prompts
43
+
44
+ - "A professional teacher explaining a mathematical concept with clear gestures"
45
+ - "A friendly presenter speaking confidently to an audience"
46
+ - "A news anchor delivering the morning headlines with professional demeanor"
47
+
48
+ ## 🔧 Technical Details
49
+
50
+ - **Model**: OmniAvatar-14B for video generation
51
+ - **TTS**: ElevenLabs API for high-quality speech synthesis
52
+ - **Framework**: FastAPI + Gradio interface
53
+ - **GPU**: Optimized for T4 and higher
54
+
55
+ ## 🎮 API Endpoints
56
+
57
+ - `GET /health` - Check system status
58
+ - `POST /generate` - Generate avatar video
59
+ - `/gradio` - Interactive web interface
60
+
61
+ ## 🔐 Environment Variables
62
+
63
+ The space uses ElevenLabs for text-to-speech. For optimal performance, configure your ElevenLabs API key as a secret.
64
+
65
+ ## 📄 License
66
+
67
+ Apache 2.0 - See LICENSE file for details
68
+
69
  ---
70
 
71
+ *Powered by OmniAvatar-14B and ElevenLabs TTS*
app.py ADDED
@@ -0,0 +1,482 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import tempfile
4
+ import gradio as gr
5
+ from fastapi import FastAPI, HTTPException
6
+ from fastapi.middleware.cors import CORSMiddleware
7
+ from pydantic import BaseModel, HttpUrl
8
+ import subprocess
9
+ import json
10
+ from pathlib import Path
11
+ import logging
12
+ import requests
13
+ from urllib.parse import urlparse
14
+ from PIL import Image
15
+ import io
16
+ from typing import Optional
17
+ import aiohttp
18
+ import asyncio
19
+ from dotenv import load_dotenv
20
+
21
+ # Load environment variables
22
+ load_dotenv()
23
+
24
+ # Set up logging
25
+ logging.basicConfig(level=logging.INFO)
26
+ logger = logging.getLogger(__name__)
27
+
28
+ app = FastAPI(title="OmniAvatar-14B API with ElevenLabs", version="1.0.0")
29
+
30
+ # Add CORS middleware
31
+ app.add_middleware(
32
+ CORSMiddleware,
33
+ allow_origins=["*"],
34
+ allow_credentials=True,
35
+ allow_methods=["*"],
36
+ allow_headers=["*"],
37
+ )
38
+
39
+ # Pydantic models for request/response
40
+ class GenerateRequest(BaseModel):
41
+ prompt: str
42
+ text_to_speech: Optional[str] = None # Text to convert to speech
43
+ elevenlabs_audio_url: Optional[HttpUrl] = None # Direct audio URL
44
+ voice_id: Optional[str] = "21m00Tcm4TlvDq8ikWAM" # Default ElevenLabs voice
45
+ image_url: Optional[HttpUrl] = None
46
+ guidance_scale: float = 5.0
47
+ audio_scale: float = 3.0
48
+ num_steps: int = 30
49
+ sp_size: int = 1
50
+ tea_cache_l1_thresh: Optional[float] = None
51
+
52
+ class GenerateResponse(BaseModel):
53
+ message: str
54
+ output_path: str
55
+ processing_time: float
56
+ audio_generated: bool = False
57
+
58
+ class ElevenLabsClient:
59
+ def __init__(self, api_key: str = None):
60
+ self.api_key = api_key or os.getenv("ELEVENLABS_API_KEY", "sk_c7a0b115cd48fc026226158c5ac87755b063c802ad892de6")
61
+ self.base_url = "https://api.elevenlabs.io/v1"
62
+
63
+ async def text_to_speech(self, text: str, voice_id: str = "21m00Tcm4TlvDq8ikWAM") -> str:
64
+ """Convert text to speech using ElevenLabs and return temporary file path"""
65
+ url = f"{self.base_url}/text-to-speech/{voice_id}"
66
+
67
+ headers = {
68
+ "Accept": "audio/mpeg",
69
+ "Content-Type": "application/json",
70
+ "xi-api-key": self.api_key
71
+ }
72
+
73
+ data = {
74
+ "text": text,
75
+ "model_id": "eleven_monolingual_v1",
76
+ "voice_settings": {
77
+ "stability": 0.5,
78
+ "similarity_boost": 0.5
79
+ }
80
+ }
81
+
82
+ try:
83
+ async with aiohttp.ClientSession() as session:
84
+ async with session.post(url, headers=headers, json=data) as response:
85
+ if response.status != 200:
86
+ error_text = await response.text()
87
+ raise HTTPException(
88
+ status_code=400,
89
+ detail=f"ElevenLabs API error: {response.status} - {error_text}"
90
+ )
91
+
92
+ audio_content = await response.read()
93
+
94
+ # Save to temporary file
95
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp3')
96
+ temp_file.write(audio_content)
97
+ temp_file.close()
98
+
99
+ logger.info(f"Generated speech audio: {temp_file.name}")
100
+ return temp_file.name
101
+
102
+ except aiohttp.ClientError as e:
103
+ logger.error(f"Network error calling ElevenLabs: {e}")
104
+ raise HTTPException(status_code=400, detail=f"Network error calling ElevenLabs: {e}")
105
+ except Exception as e:
106
+ logger.error(f"Error generating speech: {e}")
107
+ raise HTTPException(status_code=500, detail=f"Error generating speech: {e}")
108
+
109
+ class OmniAvatarAPI:
110
+ def __init__(self):
111
+ self.model_loaded = False
112
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
113
+ self.elevenlabs_client = ElevenLabsClient()
114
+ logger.info(f"Using device: {self.device}")
115
+ logger.info(f"ElevenLabs API Key configured: {'Yes' if self.elevenlabs_client.api_key else 'No'}")
116
+
117
+ def load_model(self):
118
+ """Load the OmniAvatar model"""
119
+ try:
120
+ # Check if models are downloaded
121
+ model_paths = [
122
+ "./pretrained_models/Wan2.1-T2V-14B",
123
+ "./pretrained_models/OmniAvatar-14B",
124
+ "./pretrained_models/wav2vec2-base-960h"
125
+ ]
126
+
127
+ for path in model_paths:
128
+ if not os.path.exists(path):
129
+ logger.error(f"Model path not found: {path}")
130
+ return False
131
+
132
+ self.model_loaded = True
133
+ logger.info("Models loaded successfully")
134
+ return True
135
+
136
+ except Exception as e:
137
+ logger.error(f"Error loading model: {str(e)}")
138
+ return False
139
+
140
+ async def download_file(self, url: str, suffix: str = "") -> str:
141
+ """Download file from URL and save to temporary location"""
142
+ try:
143
+ async with aiohttp.ClientSession() as session:
144
+ async with session.get(str(url)) as response:
145
+ if response.status != 200:
146
+ raise HTTPException(status_code=400, detail=f"Failed to download file from URL: {url}")
147
+
148
+ content = await response.read()
149
+
150
+ # Create temporary file
151
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
152
+ temp_file.write(content)
153
+ temp_file.close()
154
+
155
+ return temp_file.name
156
+
157
+ except aiohttp.ClientError as e:
158
+ logger.error(f"Network error downloading {url}: {e}")
159
+ raise HTTPException(status_code=400, detail=f"Network error downloading file: {e}")
160
+ except Exception as e:
161
+ logger.error(f"Error downloading file from {url}: {e}")
162
+ raise HTTPException(status_code=500, detail=f"Error downloading file: {e}")
163
+
164
+ def validate_audio_url(self, url: str) -> bool:
165
+ """Validate if URL is likely an audio file"""
166
+ try:
167
+ parsed = urlparse(url)
168
+ # Check for common audio file extensions or ElevenLabs patterns
169
+ audio_extensions = ['.mp3', '.wav', '.m4a', '.ogg', '.aac']
170
+ is_audio_ext = any(parsed.path.lower().endswith(ext) for ext in audio_extensions)
171
+ is_elevenlabs = 'elevenlabs' in parsed.netloc.lower()
172
+
173
+ return is_audio_ext or is_elevenlabs or 'audio' in url.lower()
174
+ except:
175
+ return False
176
+
177
+ def validate_image_url(self, url: str) -> bool:
178
+ """Validate if URL is likely an image file"""
179
+ try:
180
+ parsed = urlparse(url)
181
+ image_extensions = ['.jpg', '.jpeg', '.png', '.webp', '.bmp', '.gif']
182
+ return any(parsed.path.lower().endswith(ext) for ext in image_extensions)
183
+ except:
184
+ return False
185
+
186
+ async def generate_avatar(self, request: GenerateRequest) -> tuple[str, float, bool]:
187
+ """Generate avatar video from prompt and audio/text"""
188
+ import time
189
+ start_time = time.time()
190
+ audio_generated = False
191
+
192
+ try:
193
+ # Determine audio source
194
+ audio_path = None
195
+
196
+ if request.text_to_speech:
197
+ # Generate speech from text using ElevenLabs
198
+ logger.info(f"Generating speech from text: {request.text_to_speech[:50]}...")
199
+ audio_path = await self.elevenlabs_client.text_to_speech(
200
+ request.text_to_speech,
201
+ request.voice_id or "21m00Tcm4TlvDq8ikWAM"
202
+ )
203
+ audio_generated = True
204
+
205
+ elif request.elevenlabs_audio_url:
206
+ # Download audio from provided URL
207
+ logger.info(f"Downloading audio from URL: {request.elevenlabs_audio_url}")
208
+ if not self.validate_audio_url(str(request.elevenlabs_audio_url)):
209
+ logger.warning(f"Audio URL may not be valid: {request.elevenlabs_audio_url}")
210
+
211
+ audio_path = await self.download_file(str(request.elevenlabs_audio_url), ".mp3")
212
+
213
+ else:
214
+ raise HTTPException(
215
+ status_code=400,
216
+ detail="Either text_to_speech or elevenlabs_audio_url must be provided"
217
+ )
218
+
219
+ # Download image if provided
220
+ image_path = None
221
+ if request.image_url:
222
+ logger.info(f"Downloading image from URL: {request.image_url}")
223
+ if not self.validate_image_url(str(request.image_url)):
224
+ logger.warning(f"Image URL may not be valid: {request.image_url}")
225
+
226
+ # Determine image extension from URL or default to .jpg
227
+ parsed = urlparse(str(request.image_url))
228
+ ext = os.path.splitext(parsed.path)[1] or ".jpg"
229
+ image_path = await self.download_file(str(request.image_url), ext)
230
+
231
+ # Create temporary input file for inference
232
+ with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f:
233
+ if image_path:
234
+ input_line = f"{request.prompt}@@{image_path}@@{audio_path}"
235
+ else:
236
+ input_line = f"{request.prompt}@@@@{audio_path}"
237
+ f.write(input_line)
238
+ temp_input_file = f.name
239
+
240
+ # Prepare inference command
241
+ cmd = [
242
+ "python", "-m", "torch.distributed.run",
243
+ "--standalone", f"--nproc_per_node={request.sp_size}",
244
+ "scripts/inference.py",
245
+ "--config", "configs/inference.yaml",
246
+ "--input_file", temp_input_file,
247
+ "--guidance_scale", str(request.guidance_scale),
248
+ "--audio_scale", str(request.audio_scale),
249
+ "--num_steps", str(request.num_steps)
250
+ ]
251
+
252
+ if request.tea_cache_l1_thresh:
253
+ cmd.extend(["--tea_cache_l1_thresh", str(request.tea_cache_l1_thresh)])
254
+
255
+ logger.info(f"Running inference with command: {' '.join(cmd)}")
256
+
257
+ # Run inference
258
+ result = subprocess.run(cmd, capture_output=True, text=True)
259
+
260
+ # Clean up temporary files
261
+ os.unlink(temp_input_file)
262
+ os.unlink(audio_path)
263
+ if image_path:
264
+ os.unlink(image_path)
265
+
266
+ if result.returncode != 0:
267
+ logger.error(f"Inference failed: {result.stderr}")
268
+ raise Exception(f"Inference failed: {result.stderr}")
269
+
270
+ # Find output video file
271
+ output_dir = "./outputs"
272
+ if os.path.exists(output_dir):
273
+ video_files = [f for f in os.listdir(output_dir) if f.endswith(('.mp4', '.avi'))]
274
+ if video_files:
275
+ # Return the most recent video file
276
+ video_files.sort(key=lambda x: os.path.getmtime(os.path.join(output_dir, x)), reverse=True)
277
+ output_path = os.path.join(output_dir, video_files[0])
278
+ processing_time = time.time() - start_time
279
+ return output_path, processing_time, audio_generated
280
+
281
+ raise Exception("No output video generated")
282
+
283
+ except Exception as e:
284
+ # Clean up any temporary files in case of error
285
+ try:
286
+ if 'audio_path' in locals() and audio_path and os.path.exists(audio_path):
287
+ os.unlink(audio_path)
288
+ if 'image_path' in locals() and image_path and os.path.exists(image_path):
289
+ os.unlink(image_path)
290
+ if 'temp_input_file' in locals() and os.path.exists(temp_input_file):
291
+ os.unlink(temp_input_file)
292
+ except:
293
+ pass
294
+
295
+ logger.error(f"Generation error: {str(e)}")
296
+ raise HTTPException(status_code=500, detail=str(e))
297
+
298
+ # Initialize API
299
+ omni_api = OmniAvatarAPI()
300
+
301
+ @app.on_event("startup")
302
+ async def startup_event():
303
+ """Load model on startup"""
304
+ success = omni_api.load_model()
305
+ if not success:
306
+ logger.warning("Model loading failed on startup")
307
+
308
+ @app.get("/health")
309
+ async def health_check():
310
+ """Health check endpoint"""
311
+ return {
312
+ "status": "healthy",
313
+ "model_loaded": omni_api.model_loaded,
314
+ "device": omni_api.device,
315
+ "supports_elevenlabs": True,
316
+ "supports_image_urls": True,
317
+ "supports_text_to_speech": True,
318
+ "elevenlabs_api_configured": bool(omni_api.elevenlabs_client.api_key)
319
+ }
320
+
321
+ @app.post("/generate", response_model=GenerateResponse)
322
+ async def generate_avatar(request: GenerateRequest):
323
+ """Generate avatar video from prompt, text/audio, and optional image URL"""
324
+
325
+ if not omni_api.model_loaded:
326
+ raise HTTPException(status_code=503, detail="Model not loaded")
327
+
328
+ logger.info(f"Generating avatar with prompt: {request.prompt}")
329
+ if request.text_to_speech:
330
+ logger.info(f"Text to speech: {request.text_to_speech[:100]}...")
331
+ logger.info(f"Voice ID: {request.voice_id}")
332
+ if request.elevenlabs_audio_url:
333
+ logger.info(f"Audio URL: {request.elevenlabs_audio_url}")
334
+ if request.image_url:
335
+ logger.info(f"Image URL: {request.image_url}")
336
+
337
+ try:
338
+ output_path, processing_time, audio_generated = await omni_api.generate_avatar(request)
339
+
340
+ return GenerateResponse(
341
+ message="Avatar generation completed successfully",
342
+ output_path=output_path,
343
+ processing_time=processing_time,
344
+ audio_generated=audio_generated
345
+ )
346
+
347
+ except HTTPException:
348
+ raise
349
+ except Exception as e:
350
+ logger.error(f"Unexpected error: {e}")
351
+ raise HTTPException(status_code=500, detail=f"Unexpected error: {e}")
352
+
353
+ # Enhanced Gradio interface with text-to-speech option
354
+ def gradio_generate(prompt, text_to_speech, audio_url, image_url, voice_id, guidance_scale, audio_scale, num_steps):
355
+ """Gradio interface wrapper with text-to-speech support"""
356
+ if not omni_api.model_loaded:
357
+ return "Error: Model not loaded"
358
+
359
+ try:
360
+ # Create request object
361
+ request_data = {
362
+ "prompt": prompt,
363
+ "guidance_scale": guidance_scale,
364
+ "audio_scale": audio_scale,
365
+ "num_steps": int(num_steps)
366
+ }
367
+
368
+ # Add audio source
369
+ if text_to_speech and text_to_speech.strip():
370
+ request_data["text_to_speech"] = text_to_speech
371
+ request_data["voice_id"] = voice_id or "21m00Tcm4TlvDq8ikWAM"
372
+ elif audio_url and audio_url.strip():
373
+ request_data["elevenlabs_audio_url"] = audio_url
374
+ else:
375
+ return "Error: Please provide either text to speech or audio URL"
376
+
377
+ if image_url and image_url.strip():
378
+ request_data["image_url"] = image_url
379
+
380
+ request = GenerateRequest(**request_data)
381
+
382
+ # Run async function in sync context
383
+ loop = asyncio.new_event_loop()
384
+ asyncio.set_event_loop(loop)
385
+ output_path, processing_time, audio_generated = loop.run_until_complete(omni_api.generate_avatar(request))
386
+ loop.close()
387
+
388
+ return output_path
389
+
390
+ except Exception as e:
391
+ logger.error(f"Gradio generation error: {e}")
392
+ return f"Error: {str(e)}"
393
+
394
+ # Updated Gradio interface with text-to-speech support
395
+ iface = gr.Interface(
396
+ fn=gradio_generate,
397
+ inputs=[
398
+ gr.Textbox(
399
+ label="Prompt",
400
+ placeholder="Describe the character behavior (e.g., 'A friendly person explaining a concept')",
401
+ lines=2
402
+ ),
403
+ gr.Textbox(
404
+ label="Text to Speech",
405
+ placeholder="Enter text to convert to speech using ElevenLabs",
406
+ lines=3,
407
+ info="This will be converted to speech automatically"
408
+ ),
409
+ gr.Textbox(
410
+ label="OR Audio URL",
411
+ placeholder="https://api.elevenlabs.io/v1/text-to-speech/...",
412
+ info="Direct URL to audio file (alternative to text-to-speech)"
413
+ ),
414
+ gr.Textbox(
415
+ label="Image URL (Optional)",
416
+ placeholder="https://example.com/image.jpg",
417
+ info="Direct URL to reference image (JPG, PNG, etc.)"
418
+ ),
419
+ gr.Dropdown(
420
+ choices=["21m00Tcm4TlvDq8ikWAM", "pNInz6obpgDQGcFmaJgB", "EXAVITQu4vr4xnSDxMaL"],
421
+ value="21m00Tcm4TlvDq8ikWAM",
422
+ label="ElevenLabs Voice ID",
423
+ info="Choose voice for text-to-speech"
424
+ ),
425
+ gr.Slider(minimum=1, maximum=10, value=5.0, label="Guidance Scale", info="4-6 recommended"),
426
+ gr.Slider(minimum=1, maximum=10, value=3.0, label="Audio Scale", info="Higher values = better lip-sync"),
427
+ gr.Slider(minimum=10, maximum=100, value=30, step=1, label="Number of Steps", info="20-50 recommended")
428
+ ],
429
+ outputs=gr.Video(label="Generated Avatar Video"),
430
+ title="🎭 OmniAvatar-14B with ElevenLabs TTS",
431
+ description="""
432
+ Generate avatar videos with lip-sync from text prompts and speech.
433
+
434
+ **Features:**
435
+ - ✅ **Text-to-Speech**: Enter text to generate speech automatically
436
+ - ✅ **ElevenLabs Integration**: High-quality voice synthesis
437
+ - ✅ **Audio URL Support**: Use pre-generated audio files
438
+ - ✅ **Image URL Support**: Reference images for character appearance
439
+ - ✅ **Customizable Parameters**: Fine-tune generation quality
440
+
441
+ **Usage:**
442
+ 1. Enter a character description in the prompt
443
+ 2. **Either** enter text for speech generation **OR** provide an audio URL
444
+ 3. Optionally add a reference image URL
445
+ 4. Choose voice and adjust parameters
446
+ 5. Generate your avatar video!
447
+
448
+ **Tips:**
449
+ - Use guidance scale 4-6 for best prompt following
450
+ - Increase audio scale for better lip-sync
451
+ - Clear, descriptive prompts work best
452
+ """,
453
+ examples=[
454
+ [
455
+ "A professional teacher explaining a mathematical concept with clear gestures",
456
+ "Hello students! Today we're going to learn about calculus and how derivatives work in real life.",
457
+ "",
458
+ "https://example.com/teacher.jpg",
459
+ "21m00Tcm4TlvDq8ikWAM",
460
+ 5.0,
461
+ 3.5,
462
+ 30
463
+ ],
464
+ [
465
+ "A friendly presenter speaking confidently to an audience",
466
+ "Welcome everyone to our presentation on artificial intelligence and its applications!",
467
+ "",
468
+ "",
469
+ "pNInz6obpgDQGcFmaJgB",
470
+ 5.5,
471
+ 4.0,
472
+ 35
473
+ ]
474
+ ]
475
+ )
476
+
477
+ # Mount Gradio app
478
+ app = gr.mount_gradio_app(app, iface, path="/gradio")
479
+
480
+ if __name__ == "__main__":
481
+ import uvicorn
482
+ uvicorn.run(app, host="0.0.0.0", port=7860)
configs/inference.yaml ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # OmniAvatar-14B Inference Configuration
2
+
3
+ model:
4
+ base_model_path: "./pretrained_models/Wan2.1-T2V-14B"
5
+ lora_path: "./pretrained_models/OmniAvatar-14B"
6
+ audio_encoder_path: "./pretrained_models/wav2vec2-base-960h"
7
+
8
+ inference:
9
+ guidance_scale: 5.0
10
+ audio_scale: 3.0
11
+ num_inference_steps: 30
12
+ height: 480
13
+ width: 480
14
+ fps: 24
15
+ duration: 5.0
16
+
17
+ hardware:
18
+ device: "cuda"
19
+ mixed_precision: "fp16"
20
+ enable_xformers: true
21
+ enable_flash_attention: true
22
+
23
+ output:
24
+ output_dir: "./outputs"
25
+ format: "mp4"
26
+ codec: "h264"
27
+ bitrate: "5M"
28
+
29
+ tea_cache:
30
+ enabled: false
31
+ l1_thresh: 0.14
32
+
33
+ multi_gpu:
34
+ enabled: false
35
+ sp_size: 1
download_models.sh ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ echo "Downloading OmniAvatar-14B models..."
4
+
5
+ # Create directories
6
+ mkdir -p pretrained_models
7
+
8
+ # Install huggingface-hub if not already installed
9
+ pip install "huggingface_hub[cli]"
10
+
11
+ # Download models
12
+ echo "Downloading Wan2.1-T2V-14B..."
13
+ huggingface-cli download Wan-AI/Wan2.1-T2V-14B --local-dir ./pretrained_models/Wan2.1-T2V-14B
14
+
15
+ echo "Downloading wav2vec2-base-960h..."
16
+ huggingface-cli download facebook/wav2vec2-base-960h --local-dir ./pretrained_models/wav2vec2-base-960h
17
+
18
+ echo "Downloading OmniAvatar-14B..."
19
+ huggingface-cli download OmniAvatar/OmniAvatar-14B --local-dir ./pretrained_models/OmniAvatar-14B
20
+
21
+ echo "Model download completed!"
elevenlabs_integration.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ ElevenLabs + OmniAvatar Integration Example
4
+ """
5
+
6
+ import requests
7
+ import json
8
+ import os
9
+ from typing import Optional
10
+
11
+ class ElevenLabsOmniAvatarClient:
12
+ def __init__(self, elevenlabs_api_key: str, omni_avatar_base_url: str = "http://localhost:7860"):
13
+ self.elevenlabs_api_key = elevenlabs_api_key
14
+ self.omni_avatar_base_url = omni_avatar_base_url
15
+ self.elevenlabs_base_url = "https://api.elevenlabs.io/v1"
16
+
17
+ def text_to_speech_url(self, text: str, voice_id: str, model_id: str = "eleven_monolingual_v1") -> str:
18
+ """
19
+ Generate speech from text using ElevenLabs and return the audio URL
20
+
21
+ Args:
22
+ text: Text to convert to speech
23
+ voice_id: ElevenLabs voice ID
24
+ model_id: ElevenLabs model ID
25
+
26
+ Returns:
27
+ URL to the generated audio file
28
+ """
29
+ url = f"{self.elevenlabs_base_url}/text-to-speech/{voice_id}"
30
+
31
+ headers = {
32
+ "Accept": "audio/mpeg",
33
+ "Content-Type": "application/json",
34
+ "xi-api-key": self.elevenlabs_api_key
35
+ }
36
+
37
+ data = {
38
+ "text": text,
39
+ "model_id": model_id,
40
+ "voice_settings": {
41
+ "stability": 0.5,
42
+ "similarity_boost": 0.5
43
+ }
44
+ }
45
+
46
+ # Generate audio
47
+ response = requests.post(url, json=data, headers=headers)
48
+
49
+ if response.status_code != 200:
50
+ raise Exception(f"ElevenLabs API error: {response.status_code} - {response.text}")
51
+
52
+ # Save audio to temporary file and return a URL
53
+ # In practice, you might upload this to a CDN or file server
54
+ # For this example, we'll assume you have a way to serve the file
55
+
56
+ # This is a placeholder - in real implementation, you would:
57
+ # 1. Save the audio file
58
+ # 2. Upload to a file server or CDN
59
+ # 3. Return the public URL
60
+
61
+ return f"{self.elevenlabs_base_url}/text-to-speech/{voice_id}?text={text}&model_id={model_id}"
62
+
63
+ def generate_avatar(self,
64
+ prompt: str,
65
+ speech_text: str,
66
+ voice_id: str,
67
+ image_url: Optional[str] = None,
68
+ guidance_scale: float = 5.0,
69
+ audio_scale: float = 3.5,
70
+ num_steps: int = 30) -> dict:
71
+ """
72
+ Generate avatar video using ElevenLabs audio and OmniAvatar
73
+
74
+ Args:
75
+ prompt: Description of character behavior
76
+ speech_text: Text to be spoken (sent to ElevenLabs)
77
+ voice_id: ElevenLabs voice ID
78
+ image_url: Optional reference image URL
79
+ guidance_scale: Prompt guidance scale
80
+ audio_scale: Audio guidance scale
81
+ num_steps: Number of inference steps
82
+
83
+ Returns:
84
+ Generation result with video path and metadata
85
+ """
86
+
87
+ try:
88
+ # Step 1: Generate audio URL from ElevenLabs
89
+ print(f"🎤 Generating speech with ElevenLabs...")
90
+ print(f"Text: {speech_text}")
91
+ print(f"Voice ID: {voice_id}")
92
+
93
+ # Get audio URL from ElevenLabs
94
+ elevenlabs_audio_url = self.text_to_speech_url(speech_text, voice_id)
95
+
96
+ # Step 2: Generate avatar with OmniAvatar
97
+ print(f"🎭 Generating avatar with OmniAvatar...")
98
+ print(f"Prompt: {prompt}")
99
+
100
+ avatar_data = {
101
+ "prompt": prompt,
102
+ "elevenlabs_audio_url": elevenlabs_audio_url,
103
+ "guidance_scale": guidance_scale,
104
+ "audio_scale": audio_scale,
105
+ "num_steps": num_steps
106
+ }
107
+
108
+ if image_url:
109
+ avatar_data["image_url"] = image_url
110
+ print(f"Image URL: {image_url}")
111
+
112
+ response = requests.post(f"{self.omni_avatar_base_url}/generate", json=avatar_data)
113
+
114
+ if response.status_code != 200:
115
+ raise Exception(f"OmniAvatar API error: {response.status_code} - {response.text}")
116
+
117
+ result = response.json()
118
+
119
+ print(f"✅ Avatar generated successfully!")
120
+ print(f"Output: {result['output_path']}")
121
+ print(f"Processing time: {result['processing_time']:.2f}s")
122
+
123
+ return result
124
+
125
+ except Exception as e:
126
+ print(f"❌ Error generating avatar: {e}")
127
+ raise
128
+
129
+ def main():
130
+ """Example usage"""
131
+
132
+ # Configuration
133
+ ELEVENLABS_API_KEY = os.getenv("ELEVENLABS_API_KEY", "your-elevenlabs-api-key")
134
+ OMNI_AVATAR_URL = os.getenv("OMNI_AVATAR_URL", "http://localhost:7860")
135
+
136
+ if ELEVENLABS_API_KEY == "your-elevenlabs-api-key":
137
+ print("⚠�� Please set your ELEVENLABS_API_KEY environment variable")
138
+ print("Example: export ELEVENLABS_API_KEY='your-actual-api-key'")
139
+ return
140
+
141
+ # Initialize client
142
+ client = ElevenLabsOmniAvatarClient(ELEVENLABS_API_KEY, OMNI_AVATAR_URL)
143
+
144
+ # Example 1: Basic avatar generation
145
+ print("=== Example 1: Basic Avatar Generation ===")
146
+ try:
147
+ result = client.generate_avatar(
148
+ prompt="A friendly teacher explaining a concept with clear hand gestures",
149
+ speech_text="Hello! Today we're going to learn about artificial intelligence and how it works.",
150
+ voice_id="21m00Tcm4TlvDq8ikWAM", # Replace with your voice ID
151
+ guidance_scale=5.0,
152
+ audio_scale=4.0,
153
+ num_steps=30
154
+ )
155
+ print(f"Video saved to: {result['output_path']}")
156
+ except Exception as e:
157
+ print(f"Example 1 failed: {e}")
158
+
159
+ # Example 2: Avatar with reference image
160
+ print("\n=== Example 2: Avatar with Reference Image ===")
161
+ try:
162
+ result = client.generate_avatar(
163
+ prompt="A professional presenter speaking confidently to an audience",
164
+ speech_text="Welcome to our presentation on the future of technology.",
165
+ voice_id="21m00Tcm4TlvDq8ikWAM", # Replace with your voice ID
166
+ image_url="https://example.com/professional-headshot.jpg", # Replace with actual image
167
+ guidance_scale=5.5,
168
+ audio_scale=3.5,
169
+ num_steps=35
170
+ )
171
+ print(f"Video with reference image saved to: {result['output_path']}")
172
+ except Exception as e:
173
+ print(f"Example 2 failed: {e}")
174
+
175
+ print("\n🎉 Integration examples completed!")
176
+ print("\nTo use this script:")
177
+ print("1. Set your ElevenLabs API key: export ELEVENLABS_API_KEY='your-key'")
178
+ print("2. Start OmniAvatar API: python app.py")
179
+ print("3. Run this script: python elevenlabs_integration.py")
180
+
181
+ if __name__ == "__main__":
182
+ main()
examples/infer_samples.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ A young person speaking confidently@@@@./examples/sample_audio.wav
2
+ A teacher explaining a concept@@./examples/teacher.jpg@@./examples/lesson_audio.wav
3
+ An animated character telling a story@@@@./examples/story_audio.wav
requirements.txt ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Core web framework dependencies
2
+ fastapi==0.104.1
3
+ uvicorn[standard]==0.24.0
4
+ gradio==4.7.1
5
+
6
+ # PyTorch ecosystem (pre-installed in base image)
7
+ torch>=2.0.0
8
+ torchvision>=0.15.0
9
+ torchaudio>=2.0.0
10
+
11
+ # ML/AI libraries
12
+ transformers>=4.21.0
13
+ diffusers>=0.21.0
14
+ accelerate>=0.21.0
15
+ xformers>=0.0.20
16
+
17
+ # Media processing
18
+ opencv-python-headless>=4.8.0
19
+ librosa>=0.10.0
20
+ soundfile>=0.12.0
21
+ pillow>=9.5.0
22
+
23
+ # Scientific computing
24
+ numpy>=1.21.0
25
+ scipy>=1.9.0
26
+ einops>=0.6.0
27
+
28
+ # Configuration and training
29
+ omegaconf>=2.3.0
30
+ pytorch-lightning>=2.0.0
31
+ torchmetrics>=1.0.0
32
+
33
+ # API and networking
34
+ pydantic>=2.4.0
35
+ aiohttp>=3.8.0
36
+ aiofiles
37
+ python-dotenv>=1.0.0
38
+
39
+ # Attention optimization (optional, may fail on some systems)
40
+ # flash-attn>=2.3.0
41
+
42
+ # Additional dependencies for HF Spaces
43
+ huggingface-hub>=0.17.0
44
+ safetensors>=0.4.0
scripts/inference.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import yaml
3
+ import torch
4
+ import os
5
+ import sys
6
+ from pathlib import Path
7
+ import logging
8
+
9
+ logging.basicConfig(level=logging.INFO)
10
+ logger = logging.getLogger(__name__)
11
+
12
+ def parse_args():
13
+ parser = argparse.ArgumentParser(description="OmniAvatar-14B Inference")
14
+ parser.add_argument("--config", type=str, required=True, help="Path to config file")
15
+ parser.add_argument("--input_file", type=str, required=True, help="Path to input samples file")
16
+ parser.add_argument("--guidance_scale", type=float, default=5.0, help="Guidance scale")
17
+ parser.add_argument("--audio_scale", type=float, default=3.0, help="Audio guidance scale")
18
+ parser.add_argument("--num_steps", type=int, default=30, help="Number of inference steps")
19
+ parser.add_argument("--sp_size", type=int, default=1, help="Multi-GPU size")
20
+ parser.add_argument("--tea_cache_l1_thresh", type=float, default=None, help="TeaCache threshold")
21
+ return parser.parse_args()
22
+
23
+ def load_config(config_path):
24
+ with open(config_path, 'r') as f:
25
+ return yaml.safe_load(f)
26
+
27
+ def process_input_file(input_file):
28
+ """Parse input file with format: prompt@@image_path@@audio_path"""
29
+ samples = []
30
+ with open(input_file, 'r') as f:
31
+ for line in f:
32
+ line = line.strip()
33
+ if line:
34
+ parts = line.split('@@')
35
+ if len(parts) >= 3:
36
+ prompt = parts[0]
37
+ image_path = parts[1] if parts[1] else None
38
+ audio_path = parts[2]
39
+ samples.append({
40
+ 'prompt': prompt,
41
+ 'image_path': image_path,
42
+ 'audio_path': audio_path
43
+ })
44
+ return samples
45
+
46
+ def main():
47
+ args = parse_args()
48
+
49
+ # Load configuration
50
+ config = load_config(args.config)
51
+
52
+ # Process input samples
53
+ samples = process_input_file(args.input_file)
54
+
55
+ logger.info(f"Processing {len(samples)} samples")
56
+
57
+ # Create output directory
58
+ output_dir = Path(config['output']['output_dir'])
59
+ output_dir.mkdir(exist_ok=True)
60
+
61
+ # This is a placeholder - actual inference would require the OmniAvatar model implementation
62
+ logger.info("Note: This is a placeholder inference script.")
63
+ logger.info("Actual implementation would require:")
64
+ logger.info("1. Loading the OmniAvatar model")
65
+ logger.info("2. Processing audio with wav2vec2")
66
+ logger.info("3. Running video generation pipeline")
67
+ logger.info("4. Saving output videos")
68
+
69
+ for i, sample in enumerate(samples):
70
+ logger.info(f"Sample {i+1}: {sample['prompt']}")
71
+ logger.info(f" Audio: {sample['audio_path']}")
72
+ logger.info(f" Image: {sample['image_path']}")
73
+
74
+ logger.info("Inference completed successfully!")
75
+
76
+ if __name__ == "__main__":
77
+ main()