akhaliq HF Staff commited on
Commit
ad908e5
·
verified ·
1 Parent(s): 8b711a8

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +34 -16
app.py CHANGED
@@ -101,12 +101,19 @@ class VibeVoiceChat:
101
  def __init__(self, model_path: str, device: str = "cuda", inference_steps: int = 5):
102
  """Initialize the VibeVoice chat model."""
103
  self.model_path = model_path
104
- self.device = device
105
  self.inference_steps = inference_steps
106
  self.is_generating = False
107
  self.stop_generation = False
108
  self.current_streamer = None
109
 
 
 
 
 
 
 
 
110
  self.load_model()
111
  self.setup_voice_presets()
112
 
@@ -116,11 +123,20 @@ class VibeVoiceChat:
116
 
117
  self.processor = VibeVoiceProcessor.from_pretrained(self.model_path)
118
 
119
- self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
120
- self.model_path,
121
- torch_dtype=torch.bfloat16,
122
- device_map='cuda',
123
- )
 
 
 
 
 
 
 
 
 
124
  self.model.eval()
125
 
126
  # Configure noise scheduler
@@ -244,6 +260,10 @@ class VibeVoiceChat:
244
  return_attention_mask=True,
245
  )
246
 
 
 
 
 
247
  # Create audio streamer
248
  audio_streamer = AudioStreamer(
249
  batch_size=1,
@@ -297,6 +317,8 @@ class VibeVoiceChat:
297
 
298
  except Exception as e:
299
  print(f"Error in generation: {e}")
 
 
300
  self.is_generating = False
301
  self.current_streamer = None
302
  yield None
@@ -320,6 +342,8 @@ class VibeVoiceChat:
320
  )
321
  except Exception as e:
322
  print(f"Error in generation thread: {e}")
 
 
323
  audio_streamer.end()
324
 
325
  def convert_to_16_bit_wav(self, data):
@@ -385,6 +409,8 @@ def create_chat_interface(chat_instance: VibeVoiceChat):
385
  return gr.Audio(value=None)
386
  except Exception as e:
387
  print(f"Error in chat_fn: {e}")
 
 
388
  return gr.Audio(value=None)
389
 
390
  # Create additional inputs
@@ -419,21 +445,12 @@ def create_chat_interface(chat_instance: VibeVoiceChat):
419
  )
420
  ]
421
 
422
- # Example conversations - formatted as list of lists when additional_inputs are provided
423
- examples = [
424
- ["Welcome to our AI podcast! Today we're discussing the future of technology.", default_voice_1, default_voice_2, 2, 1.3],
425
- ["Speaker 0: What's your favorite programming language?\nSpeaker 1: I really enjoy Python for its simplicity.", default_voice_1, default_voice_2, 2, 1.3],
426
- ["Tell me an interesting fact about space exploration.", default_voice_1, default_voice_1, 1, 1.3],
427
- ["Speaker 0: How do you stay productive?\nSpeaker 1: I use the Pomodoro technique and take regular breaks.", default_voice_1, default_voice_2, 2, 1.3],
428
- ]
429
-
430
- # Create the ChatInterface
431
  interface = gr.ChatInterface(
432
  fn=chat_fn,
433
  type="messages",
434
  title="πŸŽ™οΈ VibeVoice Chat",
435
  description="Generate natural dialogue audio with AI voices. Type your message or paste a script!",
436
- examples=examples,
437
  additional_inputs=additional_inputs,
438
  additional_inputs_accordion=gr.Accordion(label="Voice & Generation Settings", open=True),
439
  submit_btn="🎡 Generate Audio",
@@ -514,6 +531,7 @@ def main():
514
 
515
  print(f"🚀 Launching chat interface on port {args.port}")
516
  print(f"📁 Model: {args.model_path}")

517
  print(f"🎭 Available voices: {len(chat_instance.available_voices)}")
518
 
519
  # Launch the interface
 
101
  def __init__(self, model_path: str, device: str = "cuda", inference_steps: int = 5):
102
  """Initialize the VibeVoice chat model."""
103
  self.model_path = model_path
104
+ self.device = device if torch.cuda.is_available() else "cpu"
105
  self.inference_steps = inference_steps
106
  self.is_generating = False
107
  self.stop_generation = False
108
  self.current_streamer = None
109
 
110
+ # Check GPU availability
111
+ if torch.cuda.is_available():
112
+ print(f"✓ GPU detected: {torch.cuda.get_device_name(0)}")
113
+ print(f" Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
114
+ else:
115
+ print("✗ No GPU detected, using CPU (generation will be slower)")
116
+
117
  self.load_model()
118
  self.setup_voice_presets()
119
 
 
123
 
124
  self.processor = VibeVoiceProcessor.from_pretrained(self.model_path)
125
 
126
+ if torch.cuda.is_available():
127
+ self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
128
+ self.model_path,
129
+ torch_dtype=torch.bfloat16,
130
+ device_map='cuda',
131
+ attn_implementation="flash_attention_2",
132
+ )
133
+ else:
134
+ self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
135
+ self.model_path,
136
+ torch_dtype=torch.float32,
137
+ device_map='cpu',
138
+ )
139
+
140
  self.model.eval()
141
 
142
  # Configure noise scheduler
 
260
  return_attention_mask=True,
261
  )
262
 
263
+ # Move to device
264
+ if self.device == "cuda":
265
+ inputs = {k: v.to(self.device) if torch.is_tensor(v) else v for k, v in inputs.items()}
266
+
267
  # Create audio streamer
268
  audio_streamer = AudioStreamer(
269
  batch_size=1,
 
317
 
318
  except Exception as e:
319
  print(f"Error in generation: {e}")
320
+ import traceback
321
+ traceback.print_exc()
322
  self.is_generating = False
323
  self.current_streamer = None
324
  yield None
 
342
  )
343
  except Exception as e:
344
  print(f"Error in generation thread: {e}")
345
+ import traceback
346
+ traceback.print_exc()
347
  audio_streamer.end()
348
 
349
  def convert_to_16_bit_wav(self, data):
 
409
  return gr.Audio(value=None)
410
  except Exception as e:
411
  print(f"Error in chat_fn: {e}")
412
+ import traceback
413
+ traceback.print_exc()
414
  return gr.Audio(value=None)
415
 
416
  # Create additional inputs
 
445
  )
446
  ]
447
 
448
+ # Create the ChatInterface without examples to avoid the error
 
 
 
 
 
 
 
 
449
  interface = gr.ChatInterface(
450
  fn=chat_fn,
451
  type="messages",
452
  title="πŸŽ™οΈ VibeVoice Chat",
453
  description="Generate natural dialogue audio with AI voices. Type your message or paste a script!",
 
454
  additional_inputs=additional_inputs,
455
  additional_inputs_accordion=gr.Accordion(label="Voice & Generation Settings", open=True),
456
  submit_btn="🎡 Generate Audio",
 
531
 
532
  print(f"🚀 Launching chat interface on port {args.port}")
533
  print(f"📁 Model: {args.model_path}")
534
+ print(f"💻 Device: {chat_instance.device}")
535
  print(f"🎭 Available voices: {len(chat_instance.available_voices)}")
536
 
537
  # Launch the interface