Spaces:
Running
on
L4
Running
on
L4
Update app.py
Browse files
app.py
CHANGED
@@ -1,3 +1,7 @@
|
|
|
|
|
|
|
|
|
|
1 |
import argparse
|
2 |
import os
|
3 |
import tempfile
|
@@ -10,7 +14,7 @@ import librosa
|
|
10 |
import soundfile as sf
|
11 |
import torch
|
12 |
from pathlib import Path
|
13 |
-
from typing import Iterator, Dict, Any
|
14 |
|
15 |
# Clone and setup VibeVoice if not already present
|
16 |
vibevoice_dir = Path('./VibeVoice')
|
@@ -87,20 +91,6 @@ from transformers import set_seed
|
|
87 |
logging.set_verbosity_info()
|
88 |
logger = logging.get_logger(__name__)
|
89 |
|
90 |
-
# --- Helper function for audio conversion ---
|
91 |
-
def convert_to_16_bit_wav(data: np.ndarray | torch.Tensor) -> np.ndarray:
|
92 |
-
"""Convert audio data to 16-bit WAV format (numpy int16)."""
|
93 |
-
if torch.is_tensor(data):
|
94 |
-
data = data.detach().cpu().numpy()
|
95 |
-
|
96 |
-
data = np.array(data, dtype=np.float32) # Ensure float32 before scaling
|
97 |
-
|
98 |
-
# Normalize to -1 to 1 if necessary
|
99 |
-
if np.max(np.abs(data)) > 1.0:
|
100 |
-
data = data / np.max(np.abs(data))
|
101 |
-
|
102 |
-
data = (data * 32767).astype(np.int16)
|
103 |
-
return data
|
104 |
|
105 |
class VibeVoiceChat:
|
106 |
def __init__(self, model_path: str, device: str = "cuda", inference_steps: int = 5):
|
@@ -110,8 +100,7 @@ class VibeVoiceChat:
|
|
110 |
self.inference_steps = inference_steps
|
111 |
self.is_generating = False
|
112 |
self.stop_generation = False
|
113 |
-
self.current_streamer
|
114 |
-
self.complete_audio_buffer: List[np.ndarray] = [] # To store all generated audio for final download
|
115 |
|
116 |
# Check GPU availability and CUDA version
|
117 |
if torch.cuda.is_available():
|
@@ -119,8 +108,10 @@ class VibeVoiceChat:
|
|
119 |
print(f" Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
|
120 |
print(f" CUDA Version: {torch.version.cuda}")
|
121 |
print(f" PyTorch CUDA: {torch.cuda.is_available()}")
|
122 |
-
|
123 |
-
torch.
|
|
|
|
|
124 |
torch.backends.cudnn.allow_tf32 = True
|
125 |
else:
|
126 |
print("β No GPU detected, using CPU (generation will be VERY slow)")
|
@@ -178,13 +169,15 @@ class VibeVoiceChat:
|
|
178 |
load_time = time.time() - start_time
|
179 |
print(f"β Model loaded in {load_time:.2f} seconds")
|
180 |
|
|
|
181 |
if hasattr(self.model, 'device'):
|
182 |
print(f"Model device: {self.model.device}")
|
183 |
|
184 |
def setup_voice_presets(self):
|
185 |
"""Setup voice presets from the voices directory."""
|
186 |
-
voices_dir = os.path.join(os.path.dirname(
|
187 |
|
|
|
188 |
if not os.path.exists(voices_dir):
|
189 |
os.makedirs(voices_dir)
|
190 |
print(f"Created voices directory at {voices_dir}")
|
@@ -193,16 +186,19 @@ class VibeVoiceChat:
|
|
193 |
self.available_voices = {}
|
194 |
audio_extensions = ('.wav', '.mp3', '.flac', '.ogg', '.m4a', '.aac')
|
195 |
|
|
|
196 |
for file in os.listdir(voices_dir):
|
197 |
if file.lower().endswith(audio_extensions):
|
198 |
name = os.path.splitext(file)[0]
|
199 |
self.available_voices[name] = os.path.join(voices_dir, file)
|
200 |
|
|
|
201 |
self.available_voices = dict(sorted(self.available_voices.items()))
|
202 |
|
203 |
if not self.available_voices:
|
204 |
print(f"Warning: No voice files found in {voices_dir}")
|
205 |
print("Using default (zero) voice samples. Add audio files to the voices directory for better results.")
|
|
|
206 |
self.available_voices = {"Default": None}
|
207 |
else:
|
208 |
print(f"Found {len(self.available_voices)} voice presets: {', '.join(self.available_voices.keys())}")
|
@@ -210,13 +206,15 @@ class VibeVoiceChat:
|
|
210 |
def read_audio(self, audio_path: str, target_sr: int = 24000) -> np.ndarray:
|
211 |
"""Read and preprocess audio file."""
|
212 |
try:
|
213 |
-
wav, sr =
|
|
|
|
|
214 |
if sr != target_sr:
|
215 |
-
wav = librosa.resample(
|
216 |
return wav
|
217 |
except Exception as e:
|
218 |
print(f"Error reading audio {audio_path}: {e}")
|
219 |
-
return np.zeros(
|
220 |
|
221 |
def format_script(self, message: str, num_speakers: int = 2) -> str:
|
222 |
"""Format input message into a script with speaker assignments."""
|
@@ -228,9 +226,11 @@ class VibeVoiceChat:
|
|
228 |
if not line:
|
229 |
continue
|
230 |
|
|
|
231 |
if line.startswith('Speaker ') and ':' in line:
|
232 |
formatted_lines.append(line)
|
233 |
else:
|
|
|
234 |
speaker_id = i % num_speakers
|
235 |
formatted_lines.append(f"Speaker {speaker_id}: {line}")
|
236 |
|
@@ -239,50 +239,59 @@ class VibeVoiceChat:
|
|
239 |
def generate_audio_stream(
|
240 |
self,
|
241 |
message: str,
|
|
|
242 |
voice_1: str,
|
243 |
voice_2: str,
|
244 |
num_speakers: int,
|
245 |
cfg_scale: float
|
246 |
-
) -> Iterator[tuple]:
|
247 |
-
"""
|
248 |
-
Generate audio stream from text input, implementing buffering for smoother streaming.
|
249 |
-
Yields (sample_rate, audio_chunk_numpy_int16) tuples as audio becomes available.
|
250 |
-
"""
|
251 |
try:
|
252 |
self.stop_generation = False
|
253 |
self.is_generating = True
|
254 |
-
self.complete_audio_buffer = [] # Reset buffer for new generation
|
255 |
|
|
|
256 |
if not message.strip():
|
257 |
-
self.is_generating = False
|
258 |
yield None
|
259 |
return
|
260 |
|
|
|
261 |
formatted_script = self.format_script(message, num_speakers)
|
|
|
|
|
262 |
|
|
|
|
|
|
|
|
|
263 |
selected_voices = []
|
264 |
if voice_1 and voice_1 != "Default":
|
265 |
selected_voices.append(voice_1)
|
266 |
if num_speakers > 1 and voice_2 and voice_2 != "Default":
|
267 |
selected_voices.append(voice_2)
|
268 |
|
|
|
269 |
voice_samples = []
|
270 |
-
target_sr = 24000
|
271 |
for i in range(num_speakers):
|
|
|
272 |
if i < len(selected_voices):
|
273 |
voice_name = selected_voices[i]
|
274 |
if voice_name in self.available_voices and self.available_voices[voice_name]:
|
275 |
-
audio_data = self.read_audio(self.available_voices[voice_name]
|
276 |
else:
|
277 |
-
audio_data = np.zeros(
|
278 |
else:
|
|
|
279 |
if selected_voices and selected_voices[0] in self.available_voices and self.available_voices[selected_voices[0]]:
|
280 |
-
audio_data = self.read_audio(self.available_voices[selected_voices[0]]
|
281 |
else:
|
282 |
-
audio_data = np.zeros(
|
283 |
|
284 |
voice_samples.append(audio_data)
|
285 |
|
|
|
|
|
|
|
286 |
inputs = self.processor(
|
287 |
text=[formatted_script],
|
288 |
voice_samples=[voice_samples],
|
@@ -291,9 +300,15 @@ class VibeVoiceChat:
|
|
291 |
return_attention_mask=True,
|
292 |
)
|
293 |
|
|
|
294 |
if self.device == "cuda":
|
295 |
inputs = {k: v.to(self.device) if torch.is_tensor(v) else v for k, v in inputs.items()}
|
|
|
|
|
|
|
|
|
296 |
|
|
|
297 |
audio_streamer = AudioStreamer(
|
298 |
batch_size=1,
|
299 |
stop_signal=None,
|
@@ -302,89 +317,86 @@ class VibeVoiceChat:
|
|
302 |
|
303 |
self.current_streamer = audio_streamer
|
304 |
|
|
|
305 |
generation_thread = threading.Thread(
|
306 |
target=self._generate_with_streamer,
|
307 |
args=(inputs, cfg_scale, audio_streamer)
|
308 |
)
|
309 |
generation_thread.start()
|
310 |
|
311 |
-
#
|
312 |
-
time.sleep(1
|
313 |
|
314 |
-
|
|
|
|
|
315 |
|
316 |
-
|
317 |
-
|
318 |
-
min_yield_interval_seconds = 1.0 # Yield at least every 1 second
|
319 |
-
min_chunk_size_samples = target_sr * 0.5 # At least 0.5 seconds of audio per chunk yielded to Gradio
|
320 |
-
last_yield_time = time.time()
|
321 |
|
322 |
-
for
|
323 |
if self.stop_generation:
|
324 |
audio_streamer.end()
|
325 |
break
|
326 |
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
-
|
331 |
-
|
|
|
|
|
332 |
else:
|
333 |
-
audio_np = np.array(
|
334 |
|
|
|
335 |
if len(audio_np.shape) > 1:
|
336 |
audio_np = audio_np.squeeze()
|
337 |
|
338 |
-
#
|
339 |
-
self.
|
340 |
-
|
341 |
-
# Append to pending chunks for streaming to Gradio
|
342 |
-
pending_chunks.append(audio_np)
|
343 |
-
current_pending_size = sum(len(c) for c in pending_chunks)
|
344 |
|
345 |
-
|
346 |
-
|
347 |
-
|
348 |
-
|
349 |
-
should_yield = True
|
350 |
-
elif (current_time - last_yield_time) >= min_yield_interval_seconds and pending_chunks:
|
351 |
-
should_yield = True
|
352 |
-
|
353 |
-
if should_yield:
|
354 |
-
combined_chunk = np.concatenate(pending_chunks)
|
355 |
-
yield (target_sr, convert_to_16_bit_wav(combined_chunk)) # Convert to int16 before yielding
|
356 |
-
pending_chunks = []
|
357 |
-
last_yield_time = current_time
|
358 |
|
359 |
-
#
|
360 |
-
|
361 |
-
combined_chunk = np.concatenate(pending_chunks)
|
362 |
-
yield (target_sr, convert_to_16_bit_wav(combined_chunk))
|
363 |
|
364 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
365 |
|
366 |
-
# Clean up
|
367 |
self.current_streamer = None
|
368 |
self.is_generating = False
|
369 |
|
370 |
except Exception as e:
|
371 |
-
print(f"Error in
|
372 |
import traceback
|
373 |
traceback.print_exc()
|
374 |
self.is_generating = False
|
375 |
self.current_streamer = None
|
376 |
-
self.complete_audio_buffer = [] # Clear buffer on error
|
377 |
yield None
|
378 |
|
379 |
-
def _generate_with_streamer(self, inputs
|
380 |
"""Helper method to run generation with streamer."""
|
381 |
try:
|
382 |
def check_stop():
|
383 |
return self.stop_generation
|
384 |
|
|
|
385 |
if self.device == "cuda" and torch.cuda.is_available():
|
386 |
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
|
387 |
-
self.model.generate(
|
388 |
**inputs,
|
389 |
max_new_tokens=None,
|
390 |
cfg_scale=cfg_scale,
|
@@ -396,7 +408,7 @@ class VibeVoiceChat:
|
|
396 |
refresh_negative=True,
|
397 |
)
|
398 |
else:
|
399 |
-
self.model.generate(
|
400 |
**inputs,
|
401 |
max_new_tokens=None,
|
402 |
cfg_scale=cfg_scale,
|
@@ -411,232 +423,91 @@ class VibeVoiceChat:
|
|
411 |
print(f"Error in generation thread: {e}")
|
412 |
import traceback
|
413 |
traceback.print_exc()
|
414 |
-
|
415 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
416 |
|
417 |
def stop_audio_generation(self):
|
418 |
-
"""
|
419 |
-
|
420 |
-
|
421 |
-
|
422 |
-
|
423 |
-
|
424 |
-
|
425 |
-
except Exception as e:
|
426 |
-
print(f"Error ending streamer: {e}")
|
427 |
-
self.is_generating = False
|
428 |
-
self.complete_audio_buffer = []
|
429 |
-
else:
|
430 |
-
print("No active generation to stop.")
|
431 |
|
432 |
|
433 |
def create_chat_interface(chat_instance: VibeVoiceChat):
|
434 |
-
"""Create a simplified Gradio ChatInterface for VibeVoice
|
435 |
|
|
|
436 |
voice_options = list(chat_instance.available_voices.keys())
|
437 |
if not voice_options:
|
438 |
voice_options = ["Default"]
|
439 |
|
440 |
default_voice_1 = voice_options[0] if len(voice_options) > 0 else "Default"
|
441 |
default_voice_2 = voice_options[1] if len(voice_options) > 1 else voice_options[0]
|
442 |
-
|
443 |
-
# Custom CSS for modern aesthetics
|
444 |
-
custom_css = """
|
445 |
-
.gradio-container {
|
446 |
-
font-family: 'Inter', sans-serif;
|
447 |
-
background: linear-gradient(135deg, #f0f2f5 0%, #e0e6ed 100%);
|
448 |
-
color: #333;
|
449 |
-
}
|
450 |
-
.main-header {
|
451 |
-
background: linear-gradient(45deg, #4A00E0 0%, #8E2DE2 100%);
|
452 |
-
padding: 20px 30px;
|
453 |
-
border-radius: 15px;
|
454 |
-
margin-bottom: 25px;
|
455 |
-
text-align: center;
|
456 |
-
box-shadow: 0 8px 25px rgba(0, 0, 0, 0.2);
|
457 |
-
}
|
458 |
-
.main-header h1 {
|
459 |
-
color: white;
|
460 |
-
font-size: 2.8em;
|
461 |
-
font-weight: 800;
|
462 |
-
margin: 0;
|
463 |
-
letter-spacing: -1px;
|
464 |
-
text-shadow: 0 3px 5px rgba(0,0,0,0.2);
|
465 |
-
}
|
466 |
-
.main-header p {
|
467 |
-
color: rgba(255,255,255,0.85);
|
468 |
-
font-size: 1.1em;
|
469 |
-
margin-top: 10px;
|
470 |
-
}
|
471 |
-
.settings-card, .generation-card {
|
472 |
-
background: rgba(255, 255, 255, 0.9);
|
473 |
-
border: 1px solid #dcdfe6;
|
474 |
-
border-radius: 12px;
|
475 |
-
padding: 20px;
|
476 |
-
box-shadow: 0 4px 15px rgba(0, 0, 0, 0.08);
|
477 |
-
transition: all 0.3s ease;
|
478 |
-
}
|
479 |
-
.settings-card:hover, .generation-card:hover {
|
480 |
-
box-shadow: 0 6px 20px rgba(0, 0, 0, 0.12);
|
481 |
-
transform: translateY(-2px);
|
482 |
-
}
|
483 |
-
.gradio-output {
|
484 |
-
border-radius: 10px;
|
485 |
-
background-color: #fcfcfc;
|
486 |
-
}
|
487 |
-
.gradio-button {
|
488 |
-
border-radius: 8px !important;
|
489 |
-
font-weight: 600;
|
490 |
-
padding: 10px 20px;
|
491 |
-
transition: all 0.2s ease-in-out;
|
492 |
-
}
|
493 |
-
.gradio-button.primary {
|
494 |
-
background: linear-gradient(45deg, #4CAF50 0%, #8BC34A 100%) !important;
|
495 |
-
color: white !important;
|
496 |
-
border: none !important;
|
497 |
-
}
|
498 |
-
.gradio-button.primary:hover {
|
499 |
-
opacity: 0.9;
|
500 |
-
transform: translateY(-1px);
|
501 |
-
}
|
502 |
-
.gradio-button.secondary {
|
503 |
-
background: linear-gradient(45deg, #FF5722 0%, #FFC107 100%) !important;
|
504 |
-
color: white !important;
|
505 |
-
border: none !important;
|
506 |
-
}
|
507 |
-
.gradio-button.secondary:hover {
|
508 |
-
opacity: 0.9;
|
509 |
-
transform: translateY(-1px);
|
510 |
-
}
|
511 |
-
.gradio-button.clear {
|
512 |
-
background: #90A4AE !important;
|
513 |
-
color: white !important;
|
514 |
-
border: none !important;
|
515 |
-
}
|
516 |
-
.gradio-button.clear:hover {
|
517 |
-
opacity: 0.9;
|
518 |
-
transform: translateY(-1px);
|
519 |
-
}
|
520 |
-
.gradio-input {
|
521 |
-
border-radius: 8px !important;
|
522 |
-
border: 1px solid #ced4da !important;
|
523 |
-
}
|
524 |
-
.gradio-label {
|
525 |
-
font-weight: 700;
|
526 |
-
color: #495057;
|
527 |
-
margin-bottom: 5px;
|
528 |
-
}
|
529 |
-
.chatbot {
|
530 |
-
border: 1px solid #e0e0e0 !important;
|
531 |
-
border-radius: 10px !important;
|
532 |
-
box-shadow: 0 2px 10px rgba(0,0,0,0.05);
|
533 |
-
}
|
534 |
-
.log-output {
|
535 |
-
font-family: 'JetBrains Mono', monospace;
|
536 |
-
background-color: #f8f9fa !important;
|
537 |
-
border-radius: 8px !important;
|
538 |
-
border: 1px solid #e9ecef !important;
|
539 |
-
color: #495057 !important;
|
540 |
-
min-height: 80px;
|
541 |
-
}
|
542 |
-
.audio-output {
|
543 |
-
border-radius: 10px !important;
|
544 |
-
border: 1px solid #e0e0e0 !important;
|
545 |
-
background-color: #f8f9fa !important;
|
546 |
-
}
|
547 |
-
"""
|
548 |
|
549 |
-
#
|
550 |
-
def
|
551 |
-
|
552 |
-
user_message_entry = {"role": "user", "content": message_text}
|
553 |
-
|
554 |
-
# Initial state: user message added, text input disabled, buttons updated, audio cleared/hidden
|
555 |
-
# This yield ensures immediate UI feedback
|
556 |
-
yield (
|
557 |
-
history + [user_message_entry], # Add user message to chatbot immediately
|
558 |
-
gr.update(value="", interactive=False), # Clear text input and disable
|
559 |
-
gr.update(value=None, visible=True), # Clear streaming audio, make visible
|
560 |
-
gr.update(value=None, visible=False), # Clear complete audio, hide
|
561 |
-
"ποΈ Starting audio generation...", # Initial log message
|
562 |
-
gr.update(interactive=False, value="Generating..."), # Disable submit button
|
563 |
-
gr.update(visible=True) # Show stop button
|
564 |
-
)
|
565 |
-
|
566 |
-
log_message = ""
|
567 |
-
generated_any_audio = False
|
568 |
|
569 |
-
#
|
570 |
-
|
571 |
-
|
572 |
-
|
|
|
|
|
|
|
|
|
573 |
|
574 |
-
|
575 |
-
|
576 |
-
|
577 |
-
|
578 |
-
|
579 |
|
580 |
-
|
581 |
-
|
582 |
-
|
583 |
-
|
584 |
-
|
585 |
-
|
586 |
-
|
587 |
-
|
588 |
-
|
589 |
-
|
590 |
-
|
591 |
-
|
592 |
-
|
593 |
-
|
594 |
else:
|
595 |
-
|
596 |
-
log_message = "β Error during audio generation."
|
597 |
-
break
|
598 |
-
|
599 |
-
# After generation (or stop/error), prepare final updates
|
600 |
-
final_chatbot_history = history + [user_message_entry]
|
601 |
-
final_streaming_audio_update = gr.update(value=None, visible=False) # Hide streaming audio
|
602 |
-
final_complete_audio_update = gr.update(value=None, visible=False) # Default to hidden
|
603 |
-
|
604 |
-
if chat_instance.stop_generation:
|
605 |
-
final_chatbot_history.append({"role": "assistant", "content": "π« Audio generation stopped."})
|
606 |
-
log_message = "π Generation stopped by user."
|
607 |
-
chat_instance.stop_generation = False # Reset flag for next run
|
608 |
-
elif generated_any_audio and chat_instance.complete_audio_buffer:
|
609 |
-
# Concatenate all collected audio chunks for the final downloadable audio
|
610 |
-
complete_audio_data_np = np.concatenate(chat_instance.complete_audio_buffer)
|
611 |
-
final_complete_audio_update = gr.update(value=(24000, convert_to_16_bit_wav(complete_audio_data_np)), visible=True)
|
612 |
-
final_chatbot_history.append({"role": "assistant", "content": "β
Audio generated successfully! Listen below and download."})
|
613 |
-
log_message = "β¨ Generation complete! See 'Complete Audio' below."
|
614 |
-
else:
|
615 |
-
final_chatbot_history.append({"role": "assistant", "content": "β Failed to generate audio."})
|
616 |
-
log_message = "β Generation failed or no audio produced."
|
617 |
|
618 |
-
|
619 |
-
|
620 |
-
|
621 |
-
|
622 |
-
|
623 |
-
|
624 |
-
|
625 |
-
|
626 |
-
|
627 |
-
)
|
628 |
-
|
629 |
-
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="purple"), fill_height=True, css=custom_css) as interface:
|
630 |
-
gr.HTML("""
|
631 |
-
<div class="main-header">
|
632 |
-
<h1>ποΈ VibeVoice Chat - Streamed Audio</h1>
|
633 |
-
<p>Generate natural dialogue audio with AI voices</p>
|
634 |
-
</div>
|
635 |
-
""")
|
636 |
|
637 |
with gr.Row():
|
638 |
-
with gr.Column(scale=1
|
639 |
-
gr.Markdown("###
|
640 |
|
641 |
voice_1 = gr.Dropdown(
|
642 |
choices=voice_options,
|
@@ -670,37 +541,140 @@ def create_chat_interface(chat_instance: VibeVoiceChat):
|
|
670 |
info="Guidance strength (higher = more adherence to text)"
|
671 |
)
|
672 |
|
673 |
-
with gr.Column(scale=2
|
674 |
chatbot = gr.Chatbot(
|
675 |
label="Conversation",
|
676 |
-
height=
|
677 |
type="messages",
|
678 |
-
elem_id="chatbot"
|
679 |
-
elem_classes="chatbot"
|
680 |
)
|
681 |
|
682 |
msg = gr.Textbox(
|
683 |
label="Message",
|
684 |
placeholder="Type your message or paste a script...",
|
685 |
-
lines=3
|
686 |
-
elem_classes="gradio-input"
|
687 |
)
|
688 |
|
689 |
-
# Log output for generation status
|
690 |
-
log_output = gr.Textbox(
|
691 |
-
label="Generation Log",
|
692 |
-
lines=2,
|
693 |
-
max_lines=5,
|
694 |
-
interactive=False,
|
695 |
-
value="Ready to generate audio.",
|
696 |
-
elem_classes="log-output"
|
697 |
-
)
|
698 |
-
|
699 |
-
# Streaming audio component
|
700 |
audio_output = gr.Audio(
|
701 |
-
label="
|
702 |
-
type="
|
703 |
-
streaming=True,
|
704 |
autoplay=True,
|
705 |
-
visible=
|
706 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
VibeVoice Simple Chat Interface - Streamlined Audio Generation Demo
|
3 |
+
"""
|
4 |
+
|
5 |
import argparse
|
6 |
import os
|
7 |
import tempfile
|
|
|
14 |
import soundfile as sf
|
15 |
import torch
|
16 |
from pathlib import Path
|
17 |
+
from typing import Iterator, Dict, Any
|
18 |
|
19 |
# Clone and setup VibeVoice if not already present
|
20 |
vibevoice_dir = Path('./VibeVoice')
|
|
|
91 |
logging.set_verbosity_info()
|
92 |
logger = logging.get_logger(__name__)
|
93 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
94 |
|
95 |
class VibeVoiceChat:
|
96 |
def __init__(self, model_path: str, device: str = "cuda", inference_steps: int = 5):
|
|
|
100 |
self.inference_steps = inference_steps
|
101 |
self.is_generating = False
|
102 |
self.stop_generation = False
|
103 |
+
self.current_streamer = None
|
|
|
104 |
|
105 |
# Check GPU availability and CUDA version
|
106 |
if torch.cuda.is_available():
|
|
|
108 |
print(f" Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
|
109 |
print(f" CUDA Version: {torch.version.cuda}")
|
110 |
print(f" PyTorch CUDA: {torch.cuda.is_available()}")
|
111 |
+
# Set memory fraction to avoid OOM
|
112 |
+
torch.cuda.set_per_process_memory_fraction(0.95)
|
113 |
+
# Enable TF32 for faster computation on Ampere GPUs
|
114 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
115 |
torch.backends.cudnn.allow_tf32 = True
|
116 |
else:
|
117 |
print("β No GPU detected, using CPU (generation will be VERY slow)")
|
|
|
169 |
load_time = time.time() - start_time
|
170 |
print(f"β Model loaded in {load_time:.2f} seconds")
|
171 |
|
172 |
+
# Print model device
|
173 |
if hasattr(self.model, 'device'):
|
174 |
print(f"Model device: {self.model.device}")
|
175 |
|
176 |
def setup_voice_presets(self):
|
177 |
"""Setup voice presets from the voices directory."""
|
178 |
+
voices_dir = os.path.join(os.path.dirname(__file__), "voices")
|
179 |
|
180 |
+
# Create voices directory if it doesn't exist
|
181 |
if not os.path.exists(voices_dir):
|
182 |
os.makedirs(voices_dir)
|
183 |
print(f"Created voices directory at {voices_dir}")
|
|
|
186 |
self.available_voices = {}
|
187 |
audio_extensions = ('.wav', '.mp3', '.flac', '.ogg', '.m4a', '.aac')
|
188 |
|
189 |
+
# Scan for audio files
|
190 |
for file in os.listdir(voices_dir):
|
191 |
if file.lower().endswith(audio_extensions):
|
192 |
name = os.path.splitext(file)[0]
|
193 |
self.available_voices[name] = os.path.join(voices_dir, file)
|
194 |
|
195 |
+
# Sort voices alphabetically
|
196 |
self.available_voices = dict(sorted(self.available_voices.items()))
|
197 |
|
198 |
if not self.available_voices:
|
199 |
print(f"Warning: No voice files found in {voices_dir}")
|
200 |
print("Using default (zero) voice samples. Add audio files to the voices directory for better results.")
|
201 |
+
# Add a default "None" option
|
202 |
self.available_voices = {"Default": None}
|
203 |
else:
|
204 |
print(f"Found {len(self.available_voices)} voice presets: {', '.join(self.available_voices.keys())}")
|
|
|
206 |
def read_audio(self, audio_path: str, target_sr: int = 24000) -> np.ndarray:
|
207 |
"""Read and preprocess audio file."""
|
208 |
try:
|
209 |
+
wav, sr = sf.read(audio_path)
|
210 |
+
if len(wav.shape) > 1:
|
211 |
+
wav = np.mean(wav, axis=1)
|
212 |
if sr != target_sr:
|
213 |
+
wav = librosa.resample(wav, orig_sr=sr, target_sr=target_sr)
|
214 |
return wav
|
215 |
except Exception as e:
|
216 |
print(f"Error reading audio {audio_path}: {e}")
|
217 |
+
return np.zeros(24000) # Return 1 second of silence as fallback
|
218 |
|
219 |
def format_script(self, message: str, num_speakers: int = 2) -> str:
|
220 |
"""Format input message into a script with speaker assignments."""
|
|
|
226 |
if not line:
|
227 |
continue
|
228 |
|
229 |
+
# Check if already formatted
|
230 |
if line.startswith('Speaker ') and ':' in line:
|
231 |
formatted_lines.append(line)
|
232 |
else:
|
233 |
+
# Auto-assign speakers in rotation
|
234 |
speaker_id = i % num_speakers
|
235 |
formatted_lines.append(f"Speaker {speaker_id}: {line}")
|
236 |
|
|
|
239 |
def generate_audio_stream(
|
240 |
self,
|
241 |
message: str,
|
242 |
+
history: list,
|
243 |
voice_1: str,
|
244 |
voice_2: str,
|
245 |
num_speakers: int,
|
246 |
cfg_scale: float
|
247 |
+
) -> Iterator[tuple]:
|
248 |
+
"""Generate audio stream from text input."""
|
|
|
|
|
|
|
249 |
try:
|
250 |
self.stop_generation = False
|
251 |
self.is_generating = True
|
|
|
252 |
|
253 |
+
# Validate inputs
|
254 |
if not message.strip():
|
|
|
255 |
yield None
|
256 |
return
|
257 |
|
258 |
+
# Format the script
|
259 |
formatted_script = self.format_script(message, num_speakers)
|
260 |
+
print(f"Formatted script:\n{formatted_script}")
|
261 |
+
print(f"Using device: {self.device}")
|
262 |
|
263 |
+
# Start timing
|
264 |
+
start_time = time.time()
|
265 |
+
|
266 |
+
# Select voices based on number of speakers
|
267 |
selected_voices = []
|
268 |
if voice_1 and voice_1 != "Default":
|
269 |
selected_voices.append(voice_1)
|
270 |
if num_speakers > 1 and voice_2 and voice_2 != "Default":
|
271 |
selected_voices.append(voice_2)
|
272 |
|
273 |
+
# Load voice samples
|
274 |
voice_samples = []
|
|
|
275 |
for i in range(num_speakers):
|
276 |
+
# Use the appropriate voice for each speaker
|
277 |
if i < len(selected_voices):
|
278 |
voice_name = selected_voices[i]
|
279 |
if voice_name in self.available_voices and self.available_voices[voice_name]:
|
280 |
+
audio_data = self.read_audio(self.available_voices[voice_name])
|
281 |
else:
|
282 |
+
audio_data = np.zeros(24000) # Default silence
|
283 |
else:
|
284 |
+
# Use first voice or default if not enough voices selected
|
285 |
if selected_voices and selected_voices[0] in self.available_voices and self.available_voices[selected_voices[0]]:
|
286 |
+
audio_data = self.read_audio(self.available_voices[selected_voices[0]])
|
287 |
else:
|
288 |
+
audio_data = np.zeros(24000) # Default silence
|
289 |
|
290 |
voice_samples.append(audio_data)
|
291 |
|
292 |
+
print(f"Loaded {len(voice_samples)} voice samples")
|
293 |
+
|
294 |
+
# Process inputs
|
295 |
inputs = self.processor(
|
296 |
text=[formatted_script],
|
297 |
voice_samples=[voice_samples],
|
|
|
300 |
return_attention_mask=True,
|
301 |
)
|
302 |
|
303 |
+
# Move to device and ensure correct dtype
|
304 |
if self.device == "cuda":
|
305 |
inputs = {k: v.to(self.device) if torch.is_tensor(v) else v for k, v in inputs.items()}
|
306 |
+
print(f"β Inputs moved to GPU")
|
307 |
+
# Check GPU memory
|
308 |
+
if torch.cuda.is_available():
|
309 |
+
print(f"GPU memory allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
|
310 |
|
311 |
+
# Create audio streamer
|
312 |
audio_streamer = AudioStreamer(
|
313 |
batch_size=1,
|
314 |
stop_signal=None,
|
|
|
317 |
|
318 |
self.current_streamer = audio_streamer
|
319 |
|
320 |
+
# Start generation in separate thread
|
321 |
generation_thread = threading.Thread(
|
322 |
target=self._generate_with_streamer,
|
323 |
args=(inputs, cfg_scale, audio_streamer)
|
324 |
)
|
325 |
generation_thread.start()
|
326 |
|
327 |
+
# Wait briefly for generation to start
|
328 |
+
time.sleep(1)
|
329 |
|
330 |
+
# Stream audio chunks
|
331 |
+
sample_rate = 24000
|
332 |
+
audio_stream = audio_streamer.get_stream(0)
|
333 |
|
334 |
+
all_audio_chunks = []
|
335 |
+
chunk_count = 0
|
|
|
|
|
|
|
336 |
|
337 |
+
for audio_chunk in audio_stream:
|
338 |
if self.stop_generation:
|
339 |
audio_streamer.end()
|
340 |
break
|
341 |
|
342 |
+
chunk_count += 1
|
343 |
+
|
344 |
+
# Convert to numpy
|
345 |
+
if torch.is_tensor(audio_chunk):
|
346 |
+
if audio_chunk.dtype == torch.bfloat16:
|
347 |
+
audio_chunk = audio_chunk.float()
|
348 |
+
audio_np = audio_chunk.cpu().numpy().astype(np.float32)
|
349 |
else:
|
350 |
+
audio_np = np.array(audio_chunk, dtype=np.float32)
|
351 |
|
352 |
+
# Ensure 1D
|
353 |
if len(audio_np.shape) > 1:
|
354 |
audio_np = audio_np.squeeze()
|
355 |
|
356 |
+
# Convert to 16-bit
|
357 |
+
audio_16bit = self.convert_to_16_bit_wav(audio_np)
|
358 |
+
all_audio_chunks.append(audio_16bit)
|
|
|
|
|
|
|
359 |
|
360 |
+
# Yield accumulated audio
|
361 |
+
if all_audio_chunks:
|
362 |
+
complete_audio = np.concatenate(all_audio_chunks)
|
363 |
+
yield (sample_rate, complete_audio)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
364 |
|
365 |
+
# Wait for generation to complete
|
366 |
+
generation_thread.join(timeout=5.0)
|
|
|
|
|
367 |
|
368 |
+
# Final yield with complete audio
|
369 |
+
if all_audio_chunks:
|
370 |
+
complete_audio = np.concatenate(all_audio_chunks)
|
371 |
+
generation_time = time.time() - start_time
|
372 |
+
audio_duration = len(complete_audio) / sample_rate
|
373 |
+
print(f"β Generation complete:")
|
374 |
+
print(f" Time taken: {generation_time:.2f} seconds")
|
375 |
+
print(f" Audio duration: {audio_duration:.2f} seconds")
|
376 |
+
print(f" Real-time factor: {audio_duration/generation_time:.2f}x")
|
377 |
+
yield (sample_rate, complete_audio)
|
378 |
|
|
|
379 |
self.current_streamer = None
|
380 |
self.is_generating = False
|
381 |
|
382 |
except Exception as e:
|
383 |
+
print(f"Error in generation: {e}")
|
384 |
import traceback
|
385 |
traceback.print_exc()
|
386 |
self.is_generating = False
|
387 |
self.current_streamer = None
|
|
|
388 |
yield None
|
389 |
|
390 |
+
def _generate_with_streamer(self, inputs, cfg_scale, audio_streamer):
|
391 |
"""Helper method to run generation with streamer."""
|
392 |
try:
|
393 |
def check_stop():
|
394 |
return self.stop_generation
|
395 |
|
396 |
+
# Use torch.cuda.amp for mixed precision if available
|
397 |
if self.device == "cuda" and torch.cuda.is_available():
|
398 |
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
|
399 |
+
outputs = self.model.generate(
|
400 |
**inputs,
|
401 |
max_new_tokens=None,
|
402 |
cfg_scale=cfg_scale,
|
|
|
408 |
refresh_negative=True,
|
409 |
)
|
410 |
else:
|
411 |
+
outputs = self.model.generate(
|
412 |
**inputs,
|
413 |
max_new_tokens=None,
|
414 |
cfg_scale=cfg_scale,
|
|
|
423 |
print(f"Error in generation thread: {e}")
|
424 |
import traceback
|
425 |
traceback.print_exc()
|
426 |
+
audio_streamer.end()
|
427 |
+
|
428 |
+
def convert_to_16_bit_wav(self, data):
|
429 |
+
"""Convert audio data to 16-bit WAV format."""
|
430 |
+
if torch.is_tensor(data):
|
431 |
+
data = data.detach().cpu().numpy()
|
432 |
+
|
433 |
+
data = np.array(data)
|
434 |
+
|
435 |
+
if np.max(np.abs(data)) > 1.0:
|
436 |
+
data = data / np.max(np.abs(data))
|
437 |
+
|
438 |
+
data = (data * 32767).astype(np.int16)
|
439 |
+
return data
|
440 |
|
441 |
def stop_audio_generation(self):
|
442 |
+
"""Stop the current audio generation."""
|
443 |
+
self.stop_generation = True
|
444 |
+
if self.current_streamer:
|
445 |
+
try:
|
446 |
+
self.current_streamer.end()
|
447 |
+
except:
|
448 |
+
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
449 |
|
450 |
|
451 |
def create_chat_interface(chat_instance: VibeVoiceChat):
|
452 |
+
"""Create a simplified Gradio ChatInterface for VibeVoice."""
|
453 |
|
454 |
+
# Get available voices
|
455 |
voice_options = list(chat_instance.available_voices.keys())
|
456 |
if not voice_options:
|
457 |
voice_options = ["Default"]
|
458 |
|
459 |
default_voice_1 = voice_options[0] if len(voice_options) > 0 else "Default"
|
460 |
default_voice_2 = voice_options[1] if len(voice_options) > 1 else voice_options[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
461 |
|
462 |
+
# Define the chat function that returns audio
|
463 |
+
def chat_fn(message: str, history: list, voice_1: str, voice_2: str, num_speakers: int, cfg_scale: float):
|
464 |
+
"""Process chat message and generate audio response."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
465 |
|
466 |
+
# Extract text from message
|
467 |
+
if isinstance(message, dict):
|
468 |
+
text = message.get("text", "")
|
469 |
+
else:
|
470 |
+
text = message
|
471 |
+
|
472 |
+
if not text.strip():
|
473 |
+
return ""
|
474 |
|
475 |
+
try:
|
476 |
+
# Generate audio stream
|
477 |
+
audio_generator = chat_instance.generate_audio_stream(
|
478 |
+
text, history, voice_1, voice_2, num_speakers, cfg_scale
|
479 |
+
)
|
480 |
|
481 |
+
# Collect all audio data
|
482 |
+
audio_data = None
|
483 |
+
for audio_chunk in audio_generator:
|
484 |
+
if audio_chunk is not None:
|
485 |
+
audio_data = audio_chunk
|
486 |
+
|
487 |
+
# Return audio file path or error message
|
488 |
+
if audio_data:
|
489 |
+
# Save audio to temporary file
|
490 |
+
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
|
491 |
+
sample_rate, audio_array = audio_data
|
492 |
+
sf.write(tmp_file.name, audio_array, sample_rate)
|
493 |
+
# Return the file path directly
|
494 |
+
return tmp_file.name
|
495 |
else:
|
496 |
+
return "Failed to generate audio"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
497 |
|
498 |
+
except Exception as e:
|
499 |
+
print(f"Error in chat_fn: {e}")
|
500 |
+
import traceback
|
501 |
+
traceback.print_exc()
|
502 |
+
return f"Error: {str(e)}"
|
503 |
+
|
504 |
+
# Create the interface using Blocks for more control
|
505 |
+
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="purple"), fill_height=True) as interface:
|
506 |
+
gr.Markdown("# ποΈ VibeVoice Chat\nGenerate natural dialogue audio with AI voices")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
507 |
|
508 |
with gr.Row():
|
509 |
+
with gr.Column(scale=1):
|
510 |
+
gr.Markdown("### Voice & Generation Settings")
|
511 |
|
512 |
voice_1 = gr.Dropdown(
|
513 |
choices=voice_options,
|
|
|
541 |
info="Guidance strength (higher = more adherence to text)"
|
542 |
)
|
543 |
|
544 |
+
with gr.Column(scale=2):
|
545 |
chatbot = gr.Chatbot(
|
546 |
label="Conversation",
|
547 |
+
height=400,
|
548 |
type="messages",
|
549 |
+
elem_id="chatbot"
|
|
|
550 |
)
|
551 |
|
552 |
msg = gr.Textbox(
|
553 |
label="Message",
|
554 |
placeholder="Type your message or paste a script...",
|
555 |
+
lines=3
|
|
|
556 |
)
|
557 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
558 |
audio_output = gr.Audio(
|
559 |
+
label="Generated Audio",
|
560 |
+
type="filepath",
|
|
|
561 |
autoplay=True,
|
562 |
+
visible=False
|
563 |
+
)
|
564 |
+
|
565 |
+
with gr.Row():
|
566 |
+
submit = gr.Button("π΅ Generate Audio", variant="primary")
|
567 |
+
clear = gr.Button("ποΈ Clear")
|
568 |
+
|
569 |
+
# Example messages
|
570 |
+
gr.Examples(
|
571 |
+
examples=[
|
572 |
+
"Hello! How are you doing today?",
|
573 |
+
"Speaker 0: Welcome to our podcast!\nSpeaker 1: Thanks for having me!",
|
574 |
+
"Tell me an interesting fact about space.",
|
575 |
+
"What's your favorite type of music and why?",
|
576 |
+
],
|
577 |
+
inputs=msg,
|
578 |
+
label="Example Messages"
|
579 |
+
)
|
580 |
+
|
581 |
+
# Set up event handlers
|
582 |
+
def process_and_display(message, history, voice_1, voice_2, num_speakers, cfg_scale):
|
583 |
+
"""Process message and update both chatbot and audio."""
|
584 |
+
# Add user message to history
|
585 |
+
history = history or []
|
586 |
+
history.append({"role": "user", "content": message})
|
587 |
+
|
588 |
+
# Generate audio
|
589 |
+
audio_path = chat_fn(message, history, voice_1, voice_2, num_speakers, cfg_scale)
|
590 |
+
|
591 |
+
# Add assistant response with audio
|
592 |
+
if audio_path and audio_path.endswith('.wav'):
|
593 |
+
history.append({"role": "assistant", "content": f"π΅ Audio generated successfully"})
|
594 |
+
return history, audio_path, gr.update(visible=True), ""
|
595 |
+
else:
|
596 |
+
history.append({"role": "assistant", "content": audio_path or "Failed to generate audio"})
|
597 |
+
return history, None, gr.update(visible=False), ""
|
598 |
+
|
599 |
+
submit.click(
|
600 |
+
fn=process_and_display,
|
601 |
+
inputs=[msg, chatbot, voice_1, voice_2, num_speakers, cfg_scale],
|
602 |
+
outputs=[chatbot, audio_output, audio_output, msg],
|
603 |
+
queue=True
|
604 |
+
)
|
605 |
+
|
606 |
+
msg.submit(
|
607 |
+
fn=process_and_display,
|
608 |
+
inputs=[msg, chatbot, voice_1, voice_2, num_speakers, cfg_scale],
|
609 |
+
outputs=[chatbot, audio_output, audio_output, msg],
|
610 |
+
queue=True
|
611 |
+
)
|
612 |
+
|
613 |
+
clear.click(lambda: ([], None, gr.update(visible=False)), outputs=[chatbot, audio_output, audio_output])
|
614 |
+
|
615 |
+
return interface
|
616 |
+
|
617 |
+
|
618 |
+
def parse_args():
|
619 |
+
parser = argparse.ArgumentParser(description="VibeVoice Chat Interface")
|
620 |
+
parser.add_argument(
|
621 |
+
"--model_path",
|
622 |
+
type=str,
|
623 |
+
default="microsoft/VibeVoice-1.5B",
|
624 |
+
help="Path to the VibeVoice model",
|
625 |
+
)
|
626 |
+
parser.add_argument(
|
627 |
+
"--device",
|
628 |
+
type=str,
|
629 |
+
default="cuda" if torch.cuda.is_available() else "cpu",
|
630 |
+
help="Device for inference",
|
631 |
+
)
|
632 |
+
parser.add_argument(
|
633 |
+
"--inference_steps",
|
634 |
+
type=int,
|
635 |
+
default=5,
|
636 |
+
help="Number of DDPM inference steps (lower = faster, higher = better quality)",
|
637 |
+
)
|
638 |
+
|
639 |
+
return parser.parse_args()
|
640 |
+
|
641 |
+
|
642 |
+
def main():
|
643 |
+
"""Main function to run the chat interface."""
|
644 |
+
args = parse_args()
|
645 |
+
|
646 |
+
set_seed(42)
|
647 |
+
|
648 |
+
print("ποΈ Initializing VibeVoice Chat Interface...")
|
649 |
+
|
650 |
+
# Initialize chat instance
|
651 |
+
chat_instance = VibeVoiceChat(
|
652 |
+
model_path=args.model_path,
|
653 |
+
device=args.device,
|
654 |
+
inference_steps=args.inference_steps
|
655 |
+
)
|
656 |
+
|
657 |
+
# Create interface
|
658 |
+
interface = create_chat_interface(chat_instance)
|
659 |
+
|
660 |
+
print(f"π Launching chat interface")
|
661 |
+
print(f"π Model: {args.model_path}")
|
662 |
+
print(f"π» Device: {chat_instance.device}")
|
663 |
+
print(f"π’ Inference steps: {args.inference_steps}")
|
664 |
+
print(f"π Available voices: {len(chat_instance.available_voices)}")
|
665 |
+
|
666 |
+
if chat_instance.device == "cpu":
|
667 |
+
print("\nβ οΈ WARNING: Running on CPU - generation will be VERY slow!")
|
668 |
+
print(" For faster generation, ensure you have:")
|
669 |
+
print(" 1. NVIDIA GPU with CUDA support")
|
670 |
+
print(" 2. PyTorch with CUDA installed: pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118")
|
671 |
+
|
672 |
+
# Launch the interface
|
673 |
+
interface.queue(max_size=10).launch(
|
674 |
+
show_error=True,
|
675 |
+
quiet=False,
|
676 |
+
)
|
677 |
+
|
678 |
+
|
679 |
+
if __name__ == "__main__":
|
680 |
+
main()
|