mpasila committed
Commit fca1c74 · verified · 1 Parent(s): 355d056

Upload app.py

Files changed (1)
  1. app.py +533 -649
app.py CHANGED
@@ -1,649 +1,533 @@
- """
- Gradio UI for Text-to-Speech using HiggsAudioServeEngine
- """
-
- import argparse
- import base64
- import os
- import uuid
- import json
- from typing import Optional
- import gradio as gr
- from loguru import logger
- import numpy as np
- import time
- from functools import lru_cache
- import re
- import spaces
- import torch
-
- # Import HiggsAudio components
- from higgs_audio.serve.serve_engine import HiggsAudioServeEngine
- from higgs_audio.data_types import ChatMLSample, AudioContent, Message
-
- # Global engine instance
- engine = None
-
- # Default model configuration
- DEFAULT_MODEL_PATH = "bosonai/higgs-audio-v2-generation-3B-base"
- DEFAULT_AUDIO_TOKENIZER_PATH = "bosonai/higgs-audio-v2-tokenizer"
- SAMPLE_RATE = 24000
-
- DEFAULT_SYSTEM_PROMPT = (
-     "Generate audio following instruction.\n\n"
-     "<|scene_desc_start|>\n"
-     "Audio is recorded from a quiet room.\n"
-     "<|scene_desc_end|>"
- )
-
- DEFAULT_STOP_STRINGS = ["<|end_of_text|>", "<|eot_id|>"]
-
- # Predefined examples for system and input messages
- PREDEFINED_EXAMPLES = {
-     "voice-clone": {
-         "system_prompt": "",
-         "input_text": "Hey there! I'm your friendly voice twin in the making. Pick a voice preset below or upload your own audio - let's clone some vocals and bring your voice to life! ",
-         "description": "Voice clone to clone the reference audio. Leave the system prompt empty.",
-     },
-     "smart-voice": {
-         "system_prompt": DEFAULT_SYSTEM_PROMPT,
-         "input_text": "The sun rises in the east and sets in the west. This simple fact has been observed by humans for thousands of years.",
-         "description": "Smart voice to generate speech based on the context",
-     },
-     "multispeaker-voice-description": {
-         "system_prompt": "You are an AI assistant designed to convert text into speech.\n"
-         "If the user's message includes a [SPEAKER*] tag, do not read out the tag and generate speech for the following text, using the specified voice.\n"
-         "If no speaker tag is present, select a suitable voice on your own.\n\n"
-         "<|scene_desc_start|>\n"
-         "SPEAKER0: feminine\n"
-         "SPEAKER1: masculine\n"
-         "<|scene_desc_end|>",
-         "input_text": "[SPEAKER0] I can't believe you did that without even asking me first!\n"
-         "[SPEAKER1] Oh, come on! It wasn't a big deal, and I knew you would overreact like this.\n"
-         "[SPEAKER0] Overreact? You made a decision that affects both of us without even considering my opinion!\n"
-         "[SPEAKER1] Because I didn't have time to sit around waiting for you to make up your mind! Someone had to act.",
-         "description": "Multispeaker with different voice descriptions in the system prompt",
-     },
-     "single-speaker-voice-description": {
-         "system_prompt": "Generate audio following instruction.\n\n"
-         "<|scene_desc_start|>\n"
-         "SPEAKER0: He speaks with a clear British accent and a conversational, inquisitive tone. His delivery is articulate and at a moderate pace, and very clear audio.\n"
-         "<|scene_desc_end|>",
-         "input_text": "Hey, everyone! Welcome back to Tech Talk Tuesdays.\n"
-         "It's your host, Alex, and today, we're diving into a topic that's become absolutely crucial in the tech world — deep learning.\n"
-         "And let's be honest, if you've been even remotely connected to tech, AI, or machine learning lately, you know that deep learning is everywhere.\n"
-         "\n"
-         "So here's the big question: Do you want to understand how deep learning works?\n",
-         "description": "Single speaker with voice description in the system prompt",
-     },
-     "single-speaker-zh": {
-         "system_prompt": "Generate audio following instruction.\n\n"
-         "<|scene_desc_start|>\n"
-         "Audio is recorded from a quiet room.\n"
-         "<|scene_desc_end|>",
-         "input_text": "大家好, 欢迎收听本期的跟李沐学AI. 今天沐哥在忙着洗数据, 所以由我, 希格斯主播代替他讲这期视频.\n"
-         "今天我们要聊的是一个你绝对不能忽视的话题: 多模态学习.\n"
-         "那么, 问题来了, 你真的了解多模态吗? 你知道如何自己动手构建多模态大模型吗.\n"
-         "或者说, 你能察觉到我其实是个机器人吗?",
-         "description": "Single speaker speaking Chinese",
-     },
-     "single-speaker-bgm": {
-         "system_prompt": DEFAULT_SYSTEM_PROMPT,
-         "input_text": "[music start] I will remember this, thought Ender, when I am defeated. To keep dignity, and give honor where it's due, so that defeat is not disgrace. And I hope I don't have to do it often. [music end]",
-         "description": "Single speaker with BGM using music tag. This is an experimental feature and you may need to try multiple times to get the best result.",
-     },
- }
-
-
- @lru_cache(maxsize=20)
- def encode_audio_file(file_path):
-     """Encode an audio file to base64."""
-     with open(file_path, "rb") as audio_file:
-         return base64.b64encode(audio_file.read()).decode("utf-8")
-
-
- def get_current_device():
-     """Get the current device."""
-     return "cuda" if torch.cuda.is_available() else "cpu"
-
-
- def load_voice_presets():
-     """Load the voice presets from the voice_examples directory."""
-     try:
-         with open(
-             os.path.join(os.path.dirname(__file__), "voice_examples", "config.json"),
-             "r",
-         ) as f:
-             voice_dict = json.load(f)
-         voice_presets = {k: v["transcript"] for k, v in voice_dict.items()}
-         voice_presets["EMPTY"] = "No reference voice"
-         logger.info(f"Loaded voice presets: {list(voice_presets.keys())}")
-         return voice_presets
-     except FileNotFoundError:
-         logger.warning("Voice examples config file not found. Using empty voice presets.")
-         return {"EMPTY": "No reference voice"}
-     except Exception as e:
-         logger.error(f"Error loading voice presets: {e}")
-         return {"EMPTY": "No reference voice"}
-
-
- def get_voice_preset(voice_preset):
-     """Get the voice path and text for a given voice preset."""
-     voice_path = os.path.join(os.path.dirname(__file__), "voice_examples", f"{voice_preset}.wav")
-     if not os.path.exists(voice_path):
-         logger.warning(f"Voice preset file not found: {voice_path}")
-         return None, "Voice preset not found"
-
-     text = VOICE_PRESETS.get(voice_preset, "No transcript available")
-     return voice_path, text
-
-
- def normalize_chinese_punctuation(text):
-     """
-     Convert Chinese (full-width) punctuation marks to English (half-width) equivalents.
-     """
-     # Mapping of Chinese punctuation to English punctuation
-     chinese_to_english_punct = {
-         "，": ", ",  # comma
-         "。": ".",  # period
-         "：": ":",  # colon
-         "；": ";",  # semicolon
-         "？": "?",  # question mark
-         "！": "!",  # exclamation mark
-         "（": "(",  # left parenthesis
-         "）": ")",  # right parenthesis
-         "【": "[",  # left square bracket
-         "】": "]",  # right square bracket
-         "《": "<",  # left angle quote
-         "》": ">",  # right angle quote
-         "“": '"',  # left double quotation
-         "”": '"',  # right double quotation
-         "‘": "'",  # left single quotation
-         "’": "'",  # right single quotation
-         "、": ",",  # enumeration comma
-         "—": "-",  # em dash
-         "…": "...",  # ellipsis
-         "·": ".",  # middle dot
-         "「": '"',  # left corner bracket
-         "」": '"',  # right corner bracket
-         "『": '"',  # left double corner bracket
-         "』": '"',  # right double corner bracket
-     }
-
-     # Replace each Chinese punctuation with its English counterpart
-     for zh_punct, en_punct in chinese_to_english_punct.items():
-         text = text.replace(zh_punct, en_punct)
-
-     return text
-
-
- def normalize_text(transcript: str):
-     transcript = normalize_chinese_punctuation(transcript)
-     # Other normalizations (e.g., parentheses and other symbols. Will be improved in the future)
-     transcript = transcript.replace("(", " ")
-     transcript = transcript.replace(")", " ")
-     transcript = transcript.replace("°F", " degrees Fahrenheit")
-     transcript = transcript.replace("°C", " degrees Celsius")
-
-     for tag, replacement in [
-         ("[laugh]", "<SE>[Laughter]</SE>"),
-         ("[humming start]", "<SE>[Humming]</SE>"),
-         ("[humming end]", "<SE_e>[Humming]</SE_e>"),
-         ("[music start]", "<SE_s>[Music]</SE_s>"),
-         ("[music end]", "<SE_e>[Music]</SE_e>"),
-         ("[music]", "<SE>[Music]</SE>"),
-         ("[sing start]", "<SE_s>[Singing]</SE_s>"),
-         ("[sing end]", "<SE_e>[Singing]</SE_e>"),
-         ("[applause]", "<SE>[Applause]</SE>"),
-         ("[cheering]", "<SE>[Cheering]</SE>"),
-         ("[cough]", "<SE>[Cough]</SE>"),
-     ]:
-         transcript = transcript.replace(tag, replacement)
-
-     lines = transcript.split("\n")
-     transcript = "\n".join([" ".join(line.split()) for line in lines if line.strip()])
-     transcript = transcript.strip()
-
-     if not any([transcript.endswith(c) for c in [".", "!", "?", ",", ";", '"', "'", "</SE_e>", "</SE>"]]):
-         transcript += "."
-
-     return transcript
-
-
- @spaces.GPU
- def initialize_engine(model_path, audio_tokenizer_path) -> bool:
-     """Initialize the HiggsAudioServeEngine."""
-     global engine
-     try:
-         logger.info(f"Initializing engine with model: {model_path} and audio tokenizer: {audio_tokenizer_path}")
-         engine = HiggsAudioServeEngine(
-             model_name_or_path=model_path,
-             audio_tokenizer_name_or_path=audio_tokenizer_path,
-             device=get_current_device(),
-         )
-         logger.info(f"Successfully initialized HiggsAudioServeEngine with model: {model_path}")
-         return True
-     except Exception as e:
-         logger.error(f"Failed to initialize engine: {e}")
-         return False
-
-
- def check_return_audio(audio_wv: np.ndarray):
-     # check if the audio returned is all silent
-     if np.all(audio_wv == 0):
-         logger.warning("Audio is silent, returning None")
-
-
- def process_text_output(text_output: str):
-     # remove all the continuous <|AUDIO_OUT|> tokens with a single <|AUDIO_OUT|>
-     text_output = re.sub(r"(<\|AUDIO_OUT\|>)+", r"<|AUDIO_OUT|>", text_output)
-     return text_output
-
-
- def prepare_chatml_sample(
-     voice_preset: str,
-     text: str,
-     reference_audio: Optional[str] = None,
-     reference_text: Optional[str] = None,
-     system_prompt: str = DEFAULT_SYSTEM_PROMPT,
- ):
-     """Prepare a ChatMLSample for the HiggsAudioServeEngine."""
-     messages = []
-
-     # Add system message if provided
-     if len(system_prompt) > 0:
-         messages.append(Message(role="system", content=system_prompt))
-
-     # Add reference audio if provided
-     audio_base64 = None
-     ref_text = ""
-
-     if reference_audio:
-         # Custom reference audio
-         audio_base64 = encode_audio_file(reference_audio)
-         ref_text = reference_text or ""
-     elif voice_preset != "EMPTY":
-         # Voice preset
-         voice_path, ref_text = get_voice_preset(voice_preset)
-         if voice_path is None:
-             logger.warning(f"Voice preset {voice_preset} not found, skipping reference audio")
-         else:
-             audio_base64 = encode_audio_file(voice_path)
-
-     # Only add reference audio if we have it
-     if audio_base64 is not None:
-         # Add user message with reference text
-         messages.append(Message(role="user", content=ref_text))
-
-         # Add assistant message with audio content
-         audio_content = AudioContent(raw_audio=audio_base64, audio_url="")
-         messages.append(Message(role="assistant", content=[audio_content]))
-
-     # Add the main user message
-     text = normalize_text(text)
-     messages.append(Message(role="user", content=text))
-
-     return ChatMLSample(messages=messages)
-
-
- @spaces.GPU(duration=120)
- def text_to_speech(
-     text,
-     voice_preset,
-     reference_audio=None,
-     reference_text=None,
-     max_completion_tokens=1024,
-     temperature=1.0,
-     top_p=0.95,
-     top_k=50,
-     system_prompt=DEFAULT_SYSTEM_PROMPT,
-     stop_strings=None,
-     ras_win_len=7,
-     ras_win_max_num_repeat=2,
- ):
-     """Convert text to speech using HiggsAudioServeEngine."""
-     global engine
-
-     if engine is None:
-         initialize_engine(DEFAULT_MODEL_PATH, DEFAULT_AUDIO_TOKENIZER_PATH)
-
-     try:
-         # Prepare ChatML sample
-         chatml_sample = prepare_chatml_sample(voice_preset, text, reference_audio, reference_text, system_prompt)
-
-         # Convert stop strings format
-         if stop_strings is None:
-             stop_list = DEFAULT_STOP_STRINGS
-         else:
-             stop_list = [s for s in stop_strings["stops"] if s.strip()]
-
-         request_id = f"tts-playground-{str(uuid.uuid4())}"
-         logger.info(
-             f"{request_id}: Generating speech for text: {text[:100]}..., \n"
-             f"with parameters: temperature={temperature}, top_p={top_p}, top_k={top_k}, stop_list={stop_list}, "
-             f"ras_win_len={ras_win_len}, ras_win_max_num_repeat={ras_win_max_num_repeat}"
-         )
-         start_time = time.time()
-
-         # Generate using the engine
-         response = engine.generate(
-             chat_ml_sample=chatml_sample,
-             max_new_tokens=max_completion_tokens,
-             temperature=temperature,
-             top_k=top_k if top_k > 0 else None,
-             top_p=top_p,
-             stop_strings=stop_list,
-             ras_win_len=ras_win_len if ras_win_len > 0 else None,
-             ras_win_max_num_repeat=max(ras_win_len, ras_win_max_num_repeat),
-         )
-
-         generation_time = time.time() - start_time
-         logger.info(f"{request_id}: Generated audio in {generation_time:.3f} seconds")
-         gr.Info(f"Generated audio in {generation_time:.3f} seconds")
-
-         # Process the response
-         text_output = process_text_output(response.generated_text)
-
-         if response.audio is not None:
-             # Convert to int16 for Gradio
-             audio_data = (response.audio * 32767).astype(np.int16)
-             check_return_audio(audio_data)
-             return text_output, (response.sampling_rate, audio_data)
-         else:
-             logger.warning("No audio generated")
-             return text_output, None
-
-     except Exception as e:
-         error_msg = f"Error generating speech: {e}"
-         logger.error(error_msg)
-         gr.Error(error_msg)
-         return f"❌ {error_msg}", None
-
-
- def create_ui():
-     my_theme = gr.Theme.load("theme.json")
-
-     # Add custom CSS to disable focus highlighting on textboxes
-     custom_css = """
-     .gradio-container input:focus,
-     .gradio-container textarea:focus,
-     .gradio-container select:focus,
-     .gradio-container .gr-input:focus,
-     .gradio-container .gr-textarea:focus,
-     .gradio-container .gr-textbox:focus,
-     .gradio-container .gr-textbox:focus-within,
-     .gradio-container .gr-form:focus-within,
-     .gradio-container *:focus {
-         box-shadow: none !important;
-         border-color: var(--border-color-primary) !important;
-         outline: none !important;
-         background-color: var(--input-background-fill) !important;
-     }
-
-     /* Override any hover effects as well */
-     .gradio-container input:hover,
-     .gradio-container textarea:hover,
-     .gradio-container select:hover,
-     .gradio-container .gr-input:hover,
-     .gradio-container .gr-textarea:hover,
-     .gradio-container .gr-textbox:hover {
-         border-color: var(--border-color-primary) !important;
-         background-color: var(--input-background-fill) !important;
-     }
-
-     /* Style for checked checkbox */
-     .gradio-container input[type="checkbox"]:checked {
-         background-color: var(--primary-500) !important;
-         border-color: var(--primary-500) !important;
-     }
-     """
-
-     default_template = "smart-voice"
-
-     """Create the Gradio UI."""
-     with gr.Blocks(theme=my_theme, css=custom_css) as demo:
-         gr.Markdown("# Higgs Audio Text-to-Speech Playground")
-
-         # Main UI section
-         with gr.Row():
-             with gr.Column(scale=2):
-                 # Template selection dropdown
-                 template_dropdown = gr.Dropdown(
-                     label="TTS Template",
-                     choices=list(PREDEFINED_EXAMPLES.keys()),
-                     value=default_template,
-                     info="Select a predefined example for system and input messages.",
-                 )
-
-                 # Template description display
-                 template_description = gr.HTML(
-                     value=f'<p style="font-size: 0.85em; color: var(--body-text-color-subdued); margin: 0; padding: 0;"> {PREDEFINED_EXAMPLES[default_template]["description"]}</p>',
-                     visible=True,
-                 )
-
-                 system_prompt = gr.TextArea(
-                     label="System Prompt",
-                     placeholder="Enter system prompt to guide the model...",
-                     value=PREDEFINED_EXAMPLES[default_template]["system_prompt"],
-                     lines=2,
-                 )
-
-                 input_text = gr.TextArea(
-                     label="Input Text",
-                     placeholder="Type the text you want to convert to speech...",
-                     value=PREDEFINED_EXAMPLES[default_template]["input_text"],
-                     lines=5,
-                 )
-
-                 voice_preset = gr.Dropdown(
-                     label="Voice Preset",
-                     choices=list(VOICE_PRESETS.keys()),
-                     value="EMPTY",
-                     interactive=False,  # Disabled by default since default template is not voice-clone
-                     visible=False,
-                 )
-
-                 with gr.Accordion(
-                     "Custom Reference (Optional)", open=False, visible=False
-                 ) as custom_reference_accordion:
-                     reference_audio = gr.Audio(label="Reference Audio", type="filepath")
-                     reference_text = gr.TextArea(
-                         label="Reference Text (transcript of the reference audio)",
-                         placeholder="Enter the transcript of your reference audio...",
-                         lines=3,
-                     )
-
-                 with gr.Accordion("Advanced Parameters", open=False):
-                     max_completion_tokens = gr.Slider(
-                         minimum=128,
-                         maximum=4096,
-                         value=1024,
-                         step=10,
-                         label="Max Completion Tokens",
-                     )
-                     temperature = gr.Slider(
-                         minimum=0.0,
-                         maximum=1.5,
-                         value=1.0,
-                         step=0.1,
-                         label="Temperature",
-                     )
-                     top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top P")
-                     top_k = gr.Slider(minimum=-1, maximum=100, value=50, step=1, label="Top K")
-                     ras_win_len = gr.Slider(
-                         minimum=0,
-                         maximum=10,
-                         value=7,
-                         step=1,
-                         label="RAS Window Length",
-                         info="Window length for repetition avoidance sampling",
-                     )
-                     ras_win_max_num_repeat = gr.Slider(
-                         minimum=1,
-                         maximum=10,
-                         value=2,
-                         step=1,
-                         label="RAS Max Num Repeat",
-                         info="Maximum number of repetitions allowed in the window",
-                     )
-                     # Add stop strings component
-                     stop_strings = gr.Dataframe(
-                         label="Stop Strings",
-                         headers=["stops"],
-                         datatype=["str"],
-                         value=[[s] for s in DEFAULT_STOP_STRINGS],
-                         interactive=True,
-                         col_count=(1, "fixed"),
-                     )
-
-                 submit_btn = gr.Button("Generate Speech", variant="primary", scale=1)
-
-             with gr.Column(scale=2):
-                 output_text = gr.TextArea(label="Model Response", lines=2)
-
-                 # Audio output
-                 output_audio = gr.Audio(label="Generated Audio", interactive=False, autoplay=True)
-
-                 stop_btn = gr.Button("Stop Playback", variant="primary")
-
-         # Example voice
-         with gr.Row(visible=False) as voice_samples_section:
-             voice_samples_table = gr.Dataframe(
-                 headers=["Voice Preset", "Sample Text"],
-                 datatype=["str", "str"],
-                 value=[[preset, text] for preset, text in VOICE_PRESETS.items() if preset != "EMPTY"],
-                 interactive=False,
-             )
-             sample_audio = gr.Audio(label="Voice Sample")
-
-         # Function to play voice sample when clicking on a row
-         def play_voice_sample(evt: gr.SelectData):
-             try:
-                 # Get the preset name from the clicked row
-                 preset_names = [preset for preset in VOICE_PRESETS.keys() if preset != "EMPTY"]
-                 if evt.index[0] < len(preset_names):
-                     preset = preset_names[evt.index[0]]
-                     voice_path, _ = get_voice_preset(preset)
-                     if voice_path and os.path.exists(voice_path):
-                         return voice_path
-                     else:
-                         gr.Warning(f"Voice sample file not found for preset: {preset}")
-                         return None
-                 else:
-                     gr.Warning("Invalid voice preset selection")
-                     return None
-             except Exception as e:
-                 logger.error(f"Error playing voice sample: {e}")
-                 gr.Error(f"Error playing voice sample: {e}")
-                 return None
-
-         voice_samples_table.select(fn=play_voice_sample, outputs=[sample_audio])
-
-         # Function to handle template selection
-         def apply_template(template_name):
-             if template_name in PREDEFINED_EXAMPLES:
-                 template = PREDEFINED_EXAMPLES[template_name]
-                 # Enable voice preset and custom reference only for voice-clone template
-                 is_voice_clone = template_name == "voice-clone"
-                 voice_preset_value = "belinda" if is_voice_clone else "EMPTY"
-                 # Set ras_win_len to 0 for single-speaker-bgm, 7 for others
-                 ras_win_len_value = 0 if template_name == "single-speaker-bgm" else 7
-                 description_text = f'<p style="font-size: 0.85em; color: var(--body-text-color-subdued); margin: 0; padding: 0;"> {template["description"]}</p>'
-                 return (
-                     template["system_prompt"],  # system_prompt
-                     template["input_text"],  # input_text
-                     description_text,  # template_description
-                     gr.update(
-                         value=voice_preset_value, interactive=is_voice_clone, visible=is_voice_clone
-                     ),  # voice_preset (value and interactivity)
-                     gr.update(visible=is_voice_clone),  # custom reference accordion visibility
-                     gr.update(visible=is_voice_clone),  # voice samples section visibility
-                     ras_win_len_value,  # ras_win_len
-                 )
-             else:
-                 return (
-                     gr.update(),
-                     gr.update(),
-                     gr.update(),
-                     gr.update(),
-                     gr.update(),
-                     gr.update(),
-                     gr.update(),
-                 )  # No change if template not found
-
-         # Set up event handlers
-
-         # Connect template dropdown to handler
-         template_dropdown.change(
-             fn=apply_template,
-             inputs=[template_dropdown],
-             outputs=[
-                 system_prompt,
-                 input_text,
-                 template_description,
-                 voice_preset,
-                 custom_reference_accordion,
-                 voice_samples_section,
-                 ras_win_len,
-             ],
-         )
-
-         # Connect submit button to the TTS function
-         submit_btn.click(
-             fn=text_to_speech,
-             inputs=[
-                 input_text,
-                 voice_preset,
-                 reference_audio,
-                 reference_text,
-                 max_completion_tokens,
-                 temperature,
-                 top_p,
-                 top_k,
-                 system_prompt,
-                 stop_strings,
-                 ras_win_len,
-                 ras_win_max_num_repeat,
-             ],
-             outputs=[output_text, output_audio],
-             api_name="generate_speech",
-         )
-
-         # Stop button functionality
-         stop_btn.click(
-             fn=lambda: None,
-             inputs=[],
-             outputs=[output_audio],
-             js="() => {const audio = document.querySelector('audio'); if(audio) audio.pause(); return null;}",
-         )
-
-     return demo
-
-
- def main():
-     """Main function to parse arguments and launch the UI."""
-     global DEFAULT_MODEL_PATH, DEFAULT_AUDIO_TOKENIZER_PATH, VOICE_PRESETS
-
-     parser = argparse.ArgumentParser(description="Gradio UI for Text-to-Speech using HiggsAudioServeEngine")
-     parser.add_argument(
-         "--device",
-         type=str,
-         default="cuda",
-         choices=["cuda", "cpu"],
-         help="Device to run the model on.",
-     )
-     parser.add_argument("--host", type=str, default="0.0.0.0", help="Host for the Gradio interface.")
-     parser.add_argument("--port", type=int, default=7860, help="Port for the Gradio interface.")
-
-     args = parser.parse_args()
-
-     # Update default values if provided via command line
-     VOICE_PRESETS = load_voice_presets()
-
-     # Create and launch the UI
-     demo = create_ui()
-     demo.launch(server_name=args.host, server_port=args.port)
-
-
- if __name__ == "__main__":
-     main()
 
+ """
+ Gradio UI for Text-to-Speech using HiggsAudioServeEngine
+ Adapted: now compatible with Jupyter, Colab, Runpod, etc.,
+ by adding launch_notebook() and flexible path/context handling.
+ """
+
+ import argparse
+ import base64
+ import os
+ import uuid
+ import json
+ from typing import Optional
+ import gradio as gr
+ from loguru import logger
+ import numpy as np
+ import time
+ from functools import lru_cache
+ import re
+ import torch
+
+ # --- Safe import or stub for 'spaces' (only available on Hugging Face Spaces) ---
+ try:
+     import spaces
+ except ImportError:
+     class DummySpaces:
+         # No-op stand-in that supports both @spaces.GPU and @spaces.GPU(duration=...).
+         def __getattr__(self, name):
+             def _decorator(*args, **kwargs):
+                 if len(args) == 1 and callable(args[0]) and not kwargs:
+                     return args[0]  # bare form: return the decorated function unchanged
+                 return lambda f: f  # parameterized form: return a pass-through decorator
+             return _decorator
+     spaces = DummySpaces()
+
+ # Import HiggsAudio components
+ from higgs_audio.serve.serve_engine import HiggsAudioServeEngine
+ from higgs_audio.data_types import ChatMLSample, AudioContent, Message
+
+ # --- Resolve paths relative to this file; in notebooks __file__ may be undefined, so fall back to the CWD ---
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__)) if "__file__" in globals() else os.getcwd()
+
+ # Global engine/voice instances
+ engine = None
+ VOICE_PRESETS = {}
+
+ # Default model configuration
+ DEFAULT_MODEL_PATH = "bosonai/higgs-audio-v2-generation-3B-base"
+ DEFAULT_AUDIO_TOKENIZER_PATH = "bosonai/higgs-audio-v2-tokenizer"
+ SAMPLE_RATE = 24000
+
+ DEFAULT_SYSTEM_PROMPT = (
+     "Generate audio following instruction.\n\n"
+     "<|scene_desc_start|>\n"
+     "Audio is recorded from a quiet room.\n"
+     "<|scene_desc_end|>"
+ )
+
+ DEFAULT_STOP_STRINGS = ["<|end_of_text|>", "<|eot_id|>"]
+
+ # Predefined examples for system and input messages (unchanged from the previous version)
+ PREDEFINED_EXAMPLES = {
+     "voice-clone": {
+         "system_prompt": "",
+         "input_text": "Hey there! I'm your friendly voice twin in the making. Pick a voice preset below or upload your own audio - let's clone some vocals and bring your voice to life! ",
+         "description": "Voice clone to clone the reference audio. Leave the system prompt empty.",
+     },
+     "smart-voice": {
+         "system_prompt": DEFAULT_SYSTEM_PROMPT,
+         "input_text": "The sun rises in the east and sets in the west. This simple fact has been observed by humans for thousands of years.",
+         "description": "Smart voice to generate speech based on the context",
+     },
+     "multispeaker-voice-description": {
+         "system_prompt": "You are an AI assistant designed to convert text into speech.\n"
+         "If the user's message includes a [SPEAKER*] tag, do not read out the tag and generate speech for the following text, using the specified voice.\n"
+         "If no speaker tag is present, select a suitable voice on your own.\n\n"
+         "<|scene_desc_start|>\n"
+         "SPEAKER0: feminine\n"
+         "SPEAKER1: masculine\n"
+         "<|scene_desc_end|>",
+         "input_text": "[SPEAKER0] I can't believe you did that without even asking me first!\n"
+         "[SPEAKER1] Oh, come on! It wasn't a big deal, and I knew you would overreact like this.\n"
+         "[SPEAKER0] Overreact? You made a decision that affects both of us without even considering my opinion!\n"
+         "[SPEAKER1] Because I didn't have time to sit around waiting for you to make up your mind! Someone had to act.",
+         "description": "Multispeaker with different voice descriptions in the system prompt",
+     },
+     "single-speaker-voice-description": {
+         "system_prompt": "Generate audio following instruction.\n\n"
+         "<|scene_desc_start|>\n"
+         "SPEAKER0: He speaks with a clear British accent and a conversational, inquisitive tone. His delivery is articulate and at a moderate pace, and very clear audio.\n"
+         "<|scene_desc_end|>",
+         "input_text": "Hey, everyone! Welcome back to Tech Talk Tuesdays.\n"
+         "It's your host, Alex, and today, we're diving into a topic that's become absolutely crucial in the tech world — deep learning.\n"
+         "And let's be honest, if you've been even remotely connected to tech, AI, or machine learning lately, you know that deep learning is everywhere.\n"
+         "\n"
+         "So here's the big question: Do you want to understand how deep learning works?\n",
+         "description": "Single speaker with voice description in the system prompt",
+     },
+     "single-speaker-zh": {
+         "system_prompt": "Generate audio following instruction.\n\n"
+         "<|scene_desc_start|>\n"
+         "Audio is recorded from a quiet room.\n"
+         "<|scene_desc_end|>",
+         "input_text": "大家好, 欢迎收听本期的跟李沐学AI. 今天沐哥在忙着洗数据, 所以由我, 希格斯主播代替他讲这期视频.\n"
+         "今天我们要聊的是一个你绝对不能忽视的话题: 多模态学习.\n"
+         "那么, 问题来了, 你真的了解多模态吗? 你知道如何自己动手构建多模态大模型吗.\n"
+         "或者说, 你能察觉到我其实是个机器人吗?",
+         "description": "Single speaker speaking Chinese",
+     },
+     "single-speaker-bgm": {
+         "system_prompt": DEFAULT_SYSTEM_PROMPT,
+         "input_text": "[music start] I will remember this, thought Ender, when I am defeated. To keep dignity, and give honor where it's due, so that defeat is not disgrace. And I hope I don't have to do it often. [music end]",
+         "description": "Single speaker with BGM using music tag. This is an experimental feature and you may need to try multiple times to get the best result.",
+     },
+ }
+
+ @lru_cache(maxsize=20)
+ def encode_audio_file(file_path):
+     """Encode an audio file to base64."""
+     with open(file_path, "rb") as audio_file:
+         return base64.b64encode(audio_file.read()).decode("utf-8")
+
+ def get_current_device():
+     """Get the current device."""
+     return "cuda" if torch.cuda.is_available() else "cpu"
+
+ def load_voice_presets():
+     """Load the voice presets from the voice_examples directory."""
+     try:
+         with open(
+             os.path.join(BASE_DIR, "voice_examples", "config.json"),
+             "r",
+         ) as f:
+             voice_dict = json.load(f)
+         voice_presets = {k: v["transcript"] for k, v in voice_dict.items()}
+         voice_presets["EMPTY"] = "No reference voice"
+         logger.info(f"Loaded voice presets: {list(voice_presets.keys())}")
+         return voice_presets
+     except FileNotFoundError:
+         logger.warning("Voice examples config file not found. Using empty voice presets.")
+         return {"EMPTY": "No reference voice"}
+     except Exception as e:
+         logger.error(f"Error loading voice presets: {e}")
+         return {"EMPTY": "No reference voice"}
+
+ def get_voice_preset(voice_preset):
+     """Get the voice path and text for a given voice preset."""
+     voice_path = os.path.join(BASE_DIR, "voice_examples", f"{voice_preset}.wav")
+     if not os.path.exists(voice_path):
+         logger.warning(f"Voice preset file not found: {voice_path}")
+         return None, "Voice preset not found"
+
+     text = VOICE_PRESETS.get(voice_preset, "No transcript available")
+     return voice_path, text
+
+ def normalize_chinese_punctuation(text):
+     """Convert Chinese (full-width) punctuation marks to English (half-width) equivalents."""
+     chinese_to_english_punct = {
+         "，": ", ",  # comma
+         "。": ".",  # period
+         "：": ":",  # colon
+         "；": ";",  # semicolon
+         "？": "?",  # question mark
+         "！": "!",  # exclamation mark
+         "（": "(",  # left parenthesis
+         "）": ")",  # right parenthesis
+         "【": "[",  # left square bracket
+         "】": "]",  # right square bracket
+         "《": "<",  # left angle quote
+         "》": ">",  # right angle quote
+         "“": '"',  # left double quotation
+         "”": '"',  # right double quotation
+         "‘": "'",  # left single quotation
+         "’": "'",  # right single quotation
+         "、": ",",  # enumeration comma
+         "—": "-",  # em dash
+         "…": "...",  # ellipsis
+         "·": ".",  # middle dot
+         "「": '"',  # left corner bracket
+         "」": '"',  # right corner bracket
+         "『": '"',  # left double corner bracket
+         "』": '"',  # right double corner bracket
+     }
+     for zh_punct, en_punct in chinese_to_english_punct.items():
+         text = text.replace(zh_punct, en_punct)
+     return text
+
+ def normalize_text(transcript: str):
+     transcript = normalize_chinese_punctuation(transcript)
+     transcript = transcript.replace("(", " ")
+     transcript = transcript.replace(")", " ")
+     transcript = transcript.replace("°F", " degrees Fahrenheit")
+     transcript = transcript.replace("°C", " degrees Celsius")
+     for tag, replacement in [
+         ("[laugh]", "<SE>[Laughter]</SE>"),
+         ("[humming start]", "<SE>[Humming]</SE>"),
+         ("[humming end]", "<SE_e>[Humming]</SE_e>"),
+         ("[music start]", "<SE_s>[Music]</SE_s>"),
+         ("[music end]", "<SE_e>[Music]</SE_e>"),
+         ("[music]", "<SE>[Music]</SE>"),
+         ("[sing start]", "<SE_s>[Singing]</SE_s>"),
+         ("[sing end]", "<SE_e>[Singing]</SE_e>"),
+         ("[applause]", "<SE>[Applause]</SE>"),
+         ("[cheering]", "<SE>[Cheering]</SE>"),
+         ("[cough]", "<SE>[Cough]</SE>"),
+     ]:
+         transcript = transcript.replace(tag, replacement)
+     lines = transcript.split("\n")
+     transcript = "\n".join([" ".join(line.split()) for line in lines if line.strip()])
+     transcript = transcript.strip()
+     if not any([transcript.endswith(c) for c in [".", "!", "?", ",", ";", '"', "'", "</SE_e>", "</SE>"]]):
+         transcript += "."
+     return transcript
+
+ @spaces.GPU
+ def initialize_engine(model_path, audio_tokenizer_path) -> bool:
+     """Initialize the HiggsAudioServeEngine."""
+     global engine
+     try:
+         logger.info(f"Initializing engine with model: {model_path} and audio tokenizer: {audio_tokenizer_path}")
+         engine = HiggsAudioServeEngine(
+             model_name_or_path=model_path,
+             audio_tokenizer_name_or_path=audio_tokenizer_path,
+             device=get_current_device(),
+         )
+         logger.info(f"Successfully initialized HiggsAudioServeEngine with model: {model_path}")
+         return True
+     except Exception as e:
+         logger.error(f"Failed to initialize engine: {e}")
+         return False
+
+ def check_return_audio(audio_wv: np.ndarray):
+     if np.all(audio_wv == 0):
+         logger.warning("Audio is silent, returning None")
+
+ def process_text_output(text_output: str):
+     text_output = re.sub(r"(<\|AUDIO_OUT\|>)+", r"<|AUDIO_OUT|>", text_output)
+     return text_output
+
+ def prepare_chatml_sample(
+     voice_preset: str,
+     text: str,
+     reference_audio: Optional[str] = None,
+     reference_text: Optional[str] = None,
+     system_prompt: str = DEFAULT_SYSTEM_PROMPT,
+ ):
+     messages = []
+     if len(system_prompt) > 0:
+         messages.append(Message(role="system", content=system_prompt))
+     audio_base64 = None
+     ref_text = ""
+     if reference_audio:
+         audio_base64 = encode_audio_file(reference_audio)
+         ref_text = reference_text or ""
+     elif voice_preset != "EMPTY":
+         voice_path, ref_text = get_voice_preset(voice_preset)
+         if voice_path is None:
+             logger.warning(f"Voice preset {voice_preset} not found, skipping reference audio")
+         else:
+             audio_base64 = encode_audio_file(voice_path)
+     if audio_base64 is not None:
+         messages.append(Message(role="user", content=ref_text))
+         audio_content = AudioContent(raw_audio=audio_base64, audio_url="")
+         messages.append(Message(role="assistant", content=[audio_content]))
+     text = normalize_text(text)
+     messages.append(Message(role="user", content=text))
+     return ChatMLSample(messages=messages)
+
+ @spaces.GPU(duration=120)
+ def text_to_speech(
+     text,
+     voice_preset,
+     reference_audio=None,
+     reference_text=None,
+     max_completion_tokens=1024,
+     temperature=1.0,
+     top_p=0.95,
+     top_k=50,
+     system_prompt=DEFAULT_SYSTEM_PROMPT,
+     stop_strings=None,
+     ras_win_len=7,
+     ras_win_max_num_repeat=2,
+ ):
+     global engine
+     if engine is None:
+         initialize_engine(DEFAULT_MODEL_PATH, DEFAULT_AUDIO_TOKENIZER_PATH)
+     try:
+         chatml_sample = prepare_chatml_sample(voice_preset, text, reference_audio, reference_text, system_prompt)
+         if stop_strings is None:
+             stop_list = DEFAULT_STOP_STRINGS
+         else:
+             stop_list = [s for s in stop_strings["stops"] if s.strip()]
+         request_id = f"tts-playground-{str(uuid.uuid4())}"
+         logger.info(
+             f"{request_id}: Generating speech for text: {text[:100]}..., \n"
+             f"with parameters: temperature={temperature}, top_p={top_p}, top_k={top_k}, stop_list={stop_list}, "
+             f"ras_win_len={ras_win_len}, ras_win_max_num_repeat={ras_win_max_num_repeat}"
+         )
+         start_time = time.time()
+         response = engine.generate(
+             chat_ml_sample=chatml_sample,
+             max_new_tokens=max_completion_tokens,
+             temperature=temperature,
+             top_k=top_k if top_k > 0 else None,
+             top_p=top_p,
+             stop_strings=stop_list,
+             ras_win_len=ras_win_len if ras_win_len > 0 else None,
+             ras_win_max_num_repeat=max(ras_win_len, ras_win_max_num_repeat),
+         )
+         generation_time = time.time() - start_time
+         logger.info(f"{request_id}: Generated audio in {generation_time:.3f} seconds")
+         gr.Info(f"Generated audio in {generation_time:.3f} seconds")
+         text_output = process_text_output(response.generated_text)
+         if response.audio is not None:
+             audio_data = (response.audio * 32767).astype(np.int16)
+             check_return_audio(audio_data)
+             return text_output, (response.sampling_rate, audio_data)
+         else:
+             logger.warning("No audio generated")
+             return text_output, None
+     except Exception as e:
+         error_msg = f"Error generating speech: {e}"
+         logger.error(error_msg)
+         gr.Error(error_msg)
+         return f"❌ {error_msg}", None
+
+ def create_ui():
+     my_theme = gr.Theme.load(os.path.join(BASE_DIR, "theme.json"))
+     custom_css = """
+     .gradio-container input:focus,
+     .gradio-container textarea:focus,
+     .gradio-container select:focus,
+     .gradio-container .gr-input:focus,
+     .gradio-container .gr-textarea:focus,
+     .gradio-container .gr-textbox:focus,
+     .gradio-container .gr-textbox:focus-within,
+     .gradio-container .gr-form:focus-within,
+     .gradio-container *:focus {
+         box-shadow: none !important;
+         border-color: var(--border-color-primary) !important;
+         outline: none !important;
+         background-color: var(--input-background-fill) !important;
+     }
+     .gradio-container input:hover,
+     .gradio-container textarea:hover,
+     .gradio-container select:hover,
+     .gradio-container .gr-input:hover,
+     .gradio-container .gr-textarea:hover,
+     .gradio-container .gr-textbox:hover {
+         border-color: var(--border-color-primary) !important;
+         background-color: var(--input-background-fill) !important;
+     }
+     .gradio-container input[type="checkbox"]:checked {
+         background-color: var(--primary-500) !important;
+         border-color: var(--primary-500) !important;
+     }
+     """
+     default_template = "smart-voice"
+     with gr.Blocks(theme=my_theme, css=custom_css) as demo:
+         gr.Markdown("# Higgs Audio Text-to-Speech Playground")
+         with gr.Row():
+             with gr.Column(scale=2):
+                 template_dropdown = gr.Dropdown(
+                     label="TTS Template",
+                     choices=list(PREDEFINED_EXAMPLES.keys()),
+                     value=default_template,
+                     info="Select a predefined example for system and input messages.",
+                 )
+                 template_description = gr.HTML(
+                     value=f'<p style="font-size: 0.85em; color: var(--body-text-color-subdued); margin: 0; padding: 0;"> {PREDEFINED_EXAMPLES[default_template]["description"]}</p>',
+                     visible=True,
+                 )
+                 system_prompt = gr.TextArea(
+                     label="System Prompt",
+                     placeholder="Enter system prompt to guide the model...",
+                     value=PREDEFINED_EXAMPLES[default_template]["system_prompt"],
+                     lines=2,
+                 )
+                 input_text = gr.TextArea(
+                     label="Input Text",
+                     placeholder="Type the text you want to convert to speech...",
+                     value=PREDEFINED_EXAMPLES[default_template]["input_text"],
+                     lines=5,
+                 )
+                 voice_preset = gr.Dropdown(
+                     label="Voice Preset",
+                     choices=list(VOICE_PRESETS.keys()),
+                     value="EMPTY",
+                     interactive=False,
+                     visible=False,
+                 )
+                 with gr.Accordion(
+                     "Custom Reference (Optional)", open=False, visible=False
+                 ) as custom_reference_accordion:
+                     reference_audio = gr.Audio(label="Reference Audio", type="filepath")
+                     reference_text = gr.TextArea(
+                         label="Reference Text (transcript of the reference audio)",
+                         placeholder="Enter the transcript of your reference audio...",
+                         lines=3,
+                     )
+                 with gr.Accordion("Advanced Parameters", open=False):
+                     max_completion_tokens = gr.Slider(
+                         minimum=128,
+                         maximum=4096,
+                         value=1024,
+                         step=10,
+                         label="Max Completion Tokens",
+                     )
+                     temperature = gr.Slider(
+                         minimum=0.0,
+                         maximum=1.5,
+                         value=1.0,
+                         step=0.1,
+                         label="Temperature",
+                     )
+                     top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top P")
+                     top_k = gr.Slider(minimum=-1, maximum=100, value=50, step=1, label="Top K")
+                     ras_win_len = gr.Slider(
+                         minimum=0,
+                         maximum=10,
+                         value=7,
+                         step=1,
+                         label="RAS Window Length",
+                         info="Window length for repetition avoidance sampling",
+                     )
+                     ras_win_max_num_repeat = gr.Slider(
+                         minimum=1,
+                         maximum=10,
+                         value=2,
+                         step=1,
+                         label="RAS Max Num Repeat",
+                         info="Maximum number of repetitions allowed in the window",
+                     )
+                     stop_strings = gr.Dataframe(
+                         label="Stop Strings",
+                         headers=["stops"],
+                         datatype=["str"],
+                         value=[[s] for s in DEFAULT_STOP_STRINGS],
+                         interactive=True,
+                         col_count=(1, "fixed"),
+                     )
+                 submit_btn = gr.Button("Generate Speech", variant="primary", scale=1)
+             with gr.Column(scale=2):
+                 output_text = gr.TextArea(label="Model Response", lines=2)
+                 output_audio = gr.Audio(label="Generated Audio", interactive=False, autoplay=True)
+                 stop_btn = gr.Button("Stop Playback", variant="primary")
+         with gr.Row(visible=False) as voice_samples_section:
+             voice_samples_table = gr.Dataframe(
+                 headers=["Voice Preset", "Sample Text"],
+                 datatype=["str", "str"],
+                 value=[[preset, text] for preset, text in VOICE_PRESETS.items() if preset != "EMPTY"],
+                 interactive=False,
+             )
+             sample_audio = gr.Audio(label="Voice Sample")
+
+         def play_voice_sample(evt: gr.SelectData):
+             try:
+                 preset_names = [preset for preset in VOICE_PRESETS.keys() if preset != "EMPTY"]
+                 if evt.index[0] < len(preset_names):
+                     preset = preset_names[evt.index[0]]
+                     voice_path, _ = get_voice_preset(preset)
+                     if voice_path and os.path.exists(voice_path):
+                         return voice_path
+                     else:
+                         gr.Warning(f"Voice sample file not found for preset: {preset}")
+                         return None
+                 else:
+                     gr.Warning("Invalid voice preset selection")
+                     return None
+             except Exception as e:
+                 logger.error(f"Error playing voice sample: {e}")
+                 gr.Error(f"Error playing voice sample: {e}")
+                 return None
+
+         voice_samples_table.select(fn=play_voice_sample, outputs=[sample_audio])
+
+         def apply_template(template_name):
+             if template_name in PREDEFINED_EXAMPLES:
+                 template = PREDEFINED_EXAMPLES[template_name]
+                 is_voice_clone = template_name == "voice-clone"
+                 voice_preset_value = "belinda" if is_voice_clone else "EMPTY"
+                 ras_win_len_value = 0 if template_name == "single-speaker-bgm" else 7
+                 description_text = f'<p style="font-size: 0.85em; color: var(--body-text-color-subdued); margin: 0; padding: 0;"> {template["description"]}</p>'
+                 return (
+                     template["system_prompt"],  # system_prompt
+                     template["input_text"],  # input_text
+                     description_text,  # template_description
+                     gr.update(
+                         value=voice_preset_value, interactive=is_voice_clone, visible=is_voice_clone
+                     ),
+                     gr.update(visible=is_voice_clone),
+                     gr.update(visible=is_voice_clone),
+                     ras_win_len_value,
+                 )
+             else:
+                 return (
+                     gr.update(),
+                     gr.update(),
+                     gr.update(),
+                     gr.update(),
+                     gr.update(),
+                     gr.update(),
+                     gr.update(),
+                 )
+
+         template_dropdown.change(
+             fn=apply_template,
+             inputs=[template_dropdown],
+             outputs=[
+                 system_prompt,
+                 input_text,
+                 template_description,
+                 voice_preset,
+                 custom_reference_accordion,
+                 voice_samples_section,
+                 ras_win_len,
+             ],
+         )
+
+         submit_btn.click(
+             fn=text_to_speech,
+             inputs=[
+                 input_text,
+                 voice_preset,
+                 reference_audio,
+                 reference_text,
+                 max_completion_tokens,
+                 temperature,
+                 top_p,
+                 top_k,
+                 system_prompt,
+                 stop_strings,
+                 ras_win_len,
+                 ras_win_max_num_repeat,
+             ],
+             outputs=[output_text, output_audio],
+             api_name="generate_speech",
+         )
+         stop_btn.click(
+             fn=lambda: None,
+             inputs=[],
+             outputs=[output_audio],
+             js="() => {const audio = document.querySelector('audio'); if(audio) audio.pause(); return null;}",
+         )
+     return demo
+
+ # ------ NEW! Notebook/Colab/Runpod launch function ------
+ def launch_notebook(
+     model_path=DEFAULT_MODEL_PATH,
+     audio_tokenizer_path=DEFAULT_AUDIO_TOKENIZER_PATH,
+     device=None,
+     host="127.0.0.1",
+     port=7860,
+     inline=True,
+     share=False,
+     **gradio_kwargs,
+ ):
+     """
+     Launch the Gradio UI inside a notebook, Colab, or script.
+     - If inline=True (default), the UI is embedded in the cell (Jupyter/Colab/Runpod, etc.).
+     - If share=True, Gradio will provide a public URL for the UI.
+     """
+     global VOICE_PRESETS
+     VOICE_PRESETS = load_voice_presets()
+
+     # Optionally initialize the engine up front, or let it lazy-init on first use:
+     # initialize_engine(model_path, audio_tokenizer_path)
+
+     demo = create_ui()
+     # Any other Gradio launch kwargs can be passed through here as well.
+     demo.launch(
+         server_name=host,
+         server_port=port,
+         inline=inline,
+         share=share,
+         **gradio_kwargs,
+     )
+
+ def main():
+     """
+     Main function to parse arguments and launch the UI via the CLI (notebooks should use launch_notebook()).
+     """
+     global DEFAULT_MODEL_PATH, DEFAULT_AUDIO_TOKENIZER_PATH, VOICE_PRESETS
+
+     parser = argparse.ArgumentParser(description="Gradio UI for Text-to-Speech using HiggsAudioServeEngine")
+     parser.add_argument(
+         "--device",
+         type=str,
+         default="cuda",
+         choices=["cuda", "cpu"],
+         help="Device to run the model on.",
+     )
+     parser.add_argument("--host", type=str, default="0.0.0.0", help="Host for the Gradio interface.")
+     parser.add_argument("--port", type=int, default=7860, help="Port for the Gradio interface.")
+
+     args = parser.parse_args()
+     VOICE_PRESETS = load_voice_presets()
+     demo = create_ui()
+     demo.launch(server_name=args.host, server_port=args.port)
+
+ if __name__ == "__main__":
+     main()
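
With this change the Space entrypoint is unchanged (`python app.py` still runs `main()`), while notebook users import the module and call `launch_notebook()`. A minimal usage sketch, assuming `app.py` and its assets (`voice_examples/`, `theme.json`) sit in the working directory and the `higgs_audio` package is installed; the sample text is illustrative:

    # In a Jupyter/Colab/Runpod cell:
    import app

    # Embed the UI in the cell; share=True additionally prints a public Gradio link.
    app.launch_notebook(inline=True, share=False)

    # Alternatively, skip the UI and call the generation function directly.
    # VOICE_PRESETS must be populated first, since non-EMPTY presets read from it;
    # the first call lazy-initializes the engine (downloading the model if needed).
    app.VOICE_PRESETS = app.load_voice_presets()
    text_out, audio = app.text_to_speech(
        "The sun rises in the east and sets in the west.",
        voice_preset="EMPTY",
    )
    if audio is not None:
        sample_rate, waveform = audio  # int16 waveform at the engine's sampling rate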