bluenevus committed
Commit f3b14e5 · verified · 1 Parent(s): 1374f14

Update app.py

Files changed (1): app.py (+168 -188)

app.py CHANGED
@@ -1,113 +1,64 @@
- import gradio as gr
- import google.generativeai as genai
- import numpy as np
- import re
+ import spaces
+ from snac import SNAC
  import torch
+ import gradio as gr
  from transformers import AutoModelForCausalLM, AutoTokenizer
  from huggingface_hub import snapshot_download
- import logging
- import os
- import spaces
- import warnings
- from snac import SNAC
-
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
- logger = logging.getLogger(__name__)
-
- warnings.filterwarnings("ignore", category=UserWarning)
- warnings.filterwarnings("ignore", category=RuntimeWarning)
+ from dotenv import load_dotenv
+ load_dotenv()

+ # Check if CUDA is available
  device = "cuda" if torch.cuda.is_available() else "cpu"
- logger.info(f"Using device: {device}")
-
- model = None
- tokenizer = None
- snac_model = None
-
- @spaces.GPU()
- def load_model():
-     global model, tokenizer, snac_model
-
-     device = "cuda" if torch.cuda.is_available() else "cpu"
-
-     print("Loading SNAC model...")
-     snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
-     snac_model = snac_model.to(device)
-
-     print("Loading Orpheus model...")
-     model_name = "canopylabs/orpheus-3b-0.1-ft"
-
-     snapshot_download(
-         repo_id=model_name,
-         allow_patterns=[
-             "config.json",
-             "*.safetensors",
-             "model.safetensors.index.json",
-             "tokenizer.json",
-             "tokenizer_config.json",
-             "special_tokens_map.json",
-             "vocab.json",
-             "merges.txt",
-         ],
-         ignore_patterns=[
-             "optimizer.pt",
-             "pytorch_model.bin",
-             "training_args.bin",
-             "scheduler.pt",
-         ]
-     )
-
-     model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
-     model.to(device)
-     tokenizer = AutoTokenizer.from_pretrained(model_name)
-     print(f"Orpheus model and tokenizer loaded to {device}")
-
- @spaces.GPU()
- def generate_podcast_script(api_key, content, uploaded_file, duration, num_hosts):
-     try:
-         genai.configure(api_key=api_key)
-         model = genai.GenerativeModel('gemini-2.5-pro-preview-03-25')
-
-         combined_content = content or ""
-         if uploaded_file:
-             file_content = uploaded_file.read().decode('utf-8')
-             combined_content += "\n" + file_content if combined_content else file_content
-
-         prompt = f"""
-         Create a podcast script for {'one person' if num_hosts == 1 else 'two people'} discussing:
-         {combined_content}
-
-         Duration: {duration}. Include natural speech, humor, and occasional off-topic thoughts.
-         Use speech fillers like um, ah. Vary emotional tone.
-
-         Format: {'Monologue' if num_hosts == 1 else 'Alternating dialogue'} without speaker labels.
-         Separate {'paragraphs' if num_hosts == 1 else 'lines'} with blank lines.
-
-         Use emotion tags in angle brackets: <laugh>, <sigh>, <chuckle>, <cough>, <sniffle>, <groan>, <yawn>, <gasp>.
-
-         Example: "I can't believe I stayed up all night <yawn> only to find out the meeting was canceled <groan>."
-
-         Ensure content flows naturally and stays on topic. Match the script length to {duration}.
-         """
-
-         response = model.generate_content(prompt)
-         return re.sub(r'[^a-zA-Z0-9\s.,?!<>]', '', response.text)
-     except Exception as e:
-         logger.error(f"Error generating podcast script: {str(e)}")
-         raise
+ print("Loading SNAC model...")
+ snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
+ snac_model = snac_model.to(device)
+
+ model_name = "canopylabs/orpheus-3b-0.1-ft"
+
+ # Download only model config and safetensors
+ snapshot_download(
+     repo_id=model_name,
+     allow_patterns=[
+         "config.json",
+         "*.safetensors",
+         "model.safetensors.index.json",
+     ],
+     ignore_patterns=[
+         "optimizer.pt",
+         "pytorch_model.bin",
+         "training_args.bin",
+         "scheduler.pt",
+         "tokenizer.json",
+         "tokenizer_config.json",
+         "special_tokens_map.json",
+         "vocab.json",
+         "merges.txt",
+         "tokenizer.*"
+     ]
+ )
+
+ model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
+ model.to(device)
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ print(f"Orpheus model loaded to {device}")
+
+ # Process text prompt
  def process_prompt(prompt, voice, tokenizer, device):
      prompt = f"{voice}: {prompt}"
      input_ids = tokenizer(prompt, return_tensors="pt").input_ids

-     start_token = torch.tensor([[128259]], dtype=torch.int64)
-     end_tokens = torch.tensor([[128009, 128260]], dtype=torch.int64)
+     start_token = torch.tensor([[128259]], dtype=torch.int64)  # Start of human
+     end_tokens = torch.tensor([[128009, 128260]], dtype=torch.int64)  # End of text, End of human

-     modified_input_ids = torch.cat([start_token, input_ids, end_tokens], dim=1)
+     modified_input_ids = torch.cat([start_token, input_ids, end_tokens], dim=1)  # SOH SOT Text EOT EOH
+
+     # No padding needed for single input
      attention_mask = torch.ones_like(modified_input_ids)

      return modified_input_ids.to(device), attention_mask.to(device)

+ # Parse output tokens to audio
  def parse_output(generated_ids):
      token_to_find = 128257
      token_to_remove = 128258
@@ -133,47 +84,45 @@ def parse_output(generated_ids):
          trimmed_row = [t - 128266 for t in trimmed_row]
          code_lists.append(trimmed_row)

-     return code_lists[0]
+     return code_lists[0]  # Return just the first one for single sample

+ # Redistribute codes for audio generation
  def redistribute_codes(code_list, snac_model):
-     try:
-         device = next(snac_model.parameters()).device
-
-         layer_1, layer_2, layer_3 = [], [], []
-         for i in range((len(code_list)+1)//7):
-             layer_1.append(code_list[7*i])
-             layer_2.append(code_list[7*i+1]-4096)
-             layer_3.append(code_list[7*i+2]-(2*4096))
-             layer_3.append(code_list[7*i+3]-(3*4096))
-             layer_2.append(code_list[7*i+4]-(4*4096))
-             layer_3.append(code_list[7*i+5]-(5*4096))
-             layer_3.append(code_list[7*i+6]-(6*4096))
-
-         codes = [
-             torch.tensor(layer_1, device=device).unsqueeze(0),
-             torch.tensor(layer_2, device=device).unsqueeze(0),
-             torch.tensor(layer_3, device=device).unsqueeze(0)
-         ]
-
-         audio_hat = snac_model.decode(codes)
-         return audio_hat.detach().squeeze().cpu().numpy()
-     except Exception as e:
-         logger.error(f"Error in redistribute_codes: {e}", exc_info=True)
-         return None
+     device = next(snac_model.parameters()).device  # Get the device of SNAC model
+
+     layer_1 = []
+     layer_2 = []
+     layer_3 = []
+     for i in range((len(code_list)+1)//7):
+         layer_1.append(code_list[7*i])
+         layer_2.append(code_list[7*i+1]-4096)
+         layer_3.append(code_list[7*i+2]-(2*4096))
+         layer_3.append(code_list[7*i+3]-(3*4096))
+         layer_2.append(code_list[7*i+4]-(4*4096))
+         layer_3.append(code_list[7*i+5]-(5*4096))
+         layer_3.append(code_list[7*i+6]-(6*4096))
+
+     # Move tensors to the same device as the SNAC model
+     codes = [
+         torch.tensor(layer_1, device=device).unsqueeze(0),
+         torch.tensor(layer_2, device=device).unsqueeze(0),
+         torch.tensor(layer_3, device=device).unsqueeze(0)
+     ]
+
+     audio_hat = snac_model.decode(codes)
+     return audio_hat.detach().squeeze().cpu().numpy()  # Always return CPU numpy array

+ # Main generation function
  @spaces.GPU()
- def generate_speech(text, voice, temperature, top_p, repetition_penalty, max_new_tokens):
-     global tokenizer, model
-     if tokenizer is None or model is None:
-         print("Model or tokenizer is not initialized. Please ensure the model is properly loaded.")
-         return None
-
+ def generate_speech(text, voice, temperature, top_p, repetition_penalty, max_new_tokens, progress=gr.Progress()):
      if not text.strip():
          return None

      try:
+         progress(0.1, "Processing text...")
          input_ids, attention_mask = process_prompt(text, voice, tokenizer, device)

+         progress(0.3, "Generating speech tokens...")
          with torch.no_grad():
              generated_ids = model.generate(
                  input_ids=input_ids,
@@ -187,7 +136,10 @@ def generate_speech(text, voice, temperature, top_p, repetition_penalty, max_new
                  eos_token_id=128258,
              )

+         progress(0.6, "Processing speech tokens...")
          code_list = parse_output(generated_ids)
+
+         progress(0.8, "Converting to audio...")
          audio_samples = redistribute_codes(code_list, snac_model)

          return (24000, audio_samples)  # Return sample rate and audio
@@ -195,71 +147,99 @@ def generate_speech(text, voice, temperature, top_p, repetition_penalty, max_new
          print(f"Error generating speech: {e}")
          return None

- @spaces.GPU()
- def render_podcast(api_key, script, voice1, voice2, num_hosts):
-     try:
-         lines = [line for line in script.split('\n') if line.strip()]
-         audio_segments = []
-
-         for i, line in enumerate(lines):
-             voice = voice1 if num_hosts == 1 or i % 2 == 0 else voice2
-             result = generate_speech(line, voice, temperature=0.6, top_p=0.95, repetition_penalty=1.1, max_new_tokens=1200)
-             if result is not None:
-                 sample_rate, audio = result
-                 audio_segments.append(audio)
-
-         if not audio_segments:
-             logger.warning("No valid audio segments were generated.")
-             return (24000, np.zeros(24000, dtype=np.float32))
-
-         podcast_audio = np.concatenate(audio_segments)
-         podcast_audio = np.clip(podcast_audio, -1, 1)
-         podcast_audio = (podcast_audio * 32767).astype(np.int16)
-
-         return (24000, podcast_audio)
-     except Exception as e:
-         logger.error(f"Error rendering podcast: {str(e)}")
-         raise
-
- with gr.Blocks() as demo:
-     gr.Markdown("# AI Podcast Generator")
-
-     api_key_input = gr.Textbox(label="Enter your Gemini API Key", type="password")
+ # Examples for the UI
+ examples = [
+     ["Hey there my name is Tara, <chuckle> and I'm a speech generation model that can sound like a person.", "tara", 0.6, 0.95, 1.1, 1200],
+     ["I've also been taught to understand and produce paralinguistic things <sigh> like sighing, or <laugh> laughing, or <yawn> yawning!", "dan", 0.7, 0.95, 1.1, 1200],
+     ["I live in San Francisco, and have, uhm let's see, 3 billion 7 hundred ... <gasp> well, lets just say a lot of parameters.", "leah", 0.6, 0.9, 1.2, 1200],
+     ["Sometimes when I talk too much, I need to <cough> excuse myself. <sniffle> The weather has been quite cold lately.", "leo", 0.65, 0.9, 1.1, 1200],
+     ["Public speaking can be challenging. <groan> But with enough practice, anyone can become better at it.", "jess", 0.7, 0.95, 1.1, 1200],
+     ["The hike was exhausting but the view from the top was absolutely breathtaking! <sigh> It was totally worth it.", "mia", 0.65, 0.9, 1.15, 1200],
+     ["Did you hear that joke? <laugh> I couldn't stop laughing when I first heard it. <chuckle> It's still funny.", "zac", 0.7, 0.95, 1.1, 1200],
+     ["After running the marathon, I was so tired <yawn> and needed a long rest. <sigh> But I felt accomplished.", "zoe", 0.6, 0.95, 1.1, 1200]
+ ]
+
+ # Available voices
+ VOICES = ["tara", "leah", "jess", "leo", "dan", "mia", "zac", "zoe"]
+
+ # Available Emotive Tags
+ EMOTIVE_TAGS = ["`<laugh>`", "`<chuckle>`", "`<sigh>`", "`<cough>`", "`<sniffle>`", "`<groan>`", "`<yawn>`", "`<gasp>`"]
+
+ # Create Gradio interface
+ with gr.Blocks(title="Orpheus Text-to-Speech") as demo:
+     gr.Markdown(f"""
+     # 🎵 [Orpheus Text-to-Speech](https://github.com/canopyai/Orpheus-TTS)
+     Enter your text below and hear it converted to natural-sounding speech with the Orpheus TTS model.
+
+     ## Tips for better prompts:
+     - Add paralinguistic elements like {", ".join(EMOTIVE_TAGS)} or `uhm` for more human-like speech.
+     - Longer text prompts generally work better than very short phrases
+     - Increasing `repetition_penalty` and `temperature` makes the model speak faster.
+     """)
      with gr.Row():
-         content_input = gr.Textbox(label="Paste your content (optional)", lines=4)
-         document_upload = gr.File(label="Upload Document (optional)")
-
-     duration = gr.Radio(["1-5 min", "5-10 min", "10-15 min"], label="Estimated podcast duration", value="1-5 min")
-     num_hosts = gr.Radio([1, 2], label="Number of podcast hosts", value=2)
-
-     voice_options = ["tara", "leah", "jess", "leo", "dan", "mia", "zac", "zoe"]
-     voice1_select = gr.Dropdown(label="Select Voice 1", choices=voice_options, value="tara")
-     voice2_select = gr.Dropdown(label="Select Voice 2", choices=voice_options, value="leo")
-
-     generate_btn = gr.Button("Generate Script")
-     script_output = gr.Textbox(label="Generated Script", lines=10)
-
-     render_btn = gr.Button("Render Podcast")
-     audio_output = gr.Audio(label="Generated Podcast")
-
-     generate_btn.click(generate_podcast_script,
-                        inputs=[api_key_input, content_input, document_upload, duration, num_hosts],
-                        outputs=script_output)
-
-     render_btn.click(render_podcast,
-                      inputs=[api_key_input, script_output, voice1_select, voice2_select, num_hosts],
-                      outputs=audio_output)
-
-     num_hosts.change(lambda x: gr.update(visible=x == 2),
-                      inputs=[num_hosts],
-                      outputs=[voice2_select])
+         with gr.Column(scale=3):
+             text_input = gr.Textbox(
+                 label="Text to speak",
+                 placeholder="Enter your text here...",
+                 lines=5
+             )
+             voice = gr.Dropdown(
+                 choices=VOICES,
+                 value="tara",
+                 label="Voice"
+             )
+
+             with gr.Accordion("Advanced Settings", open=False):
+                 temperature = gr.Slider(
+                     minimum=0.1, maximum=1.5, value=0.6, step=0.05,
+                     label="Temperature",
+                     info="Higher values (0.7-1.0) create more expressive but less stable speech"
+                 )
+                 top_p = gr.Slider(
+                     minimum=0.1, maximum=1.0, value=0.95, step=0.05,
+                     label="Top P",
+                     info="Nucleus sampling threshold"
+                 )
+                 repetition_penalty = gr.Slider(
+                     minimum=1.0, maximum=2.0, value=1.1, step=0.05,
+                     label="Repetition Penalty",
+                     info="Higher values discourage repetitive patterns"
+                 )
+                 max_new_tokens = gr.Slider(
+                     minimum=100, maximum=2000, value=1200, step=100,
+                     label="Max Length",
+                     info="Maximum length of generated audio (in tokens)"
+                 )
+
+             with gr.Row():
+                 submit_btn = gr.Button("Generate Speech", variant="primary")
+                 clear_btn = gr.Button("Clear")
+
+         with gr.Column(scale=2):
+             audio_output = gr.Audio(label="Generated Speech", type="numpy")
+
+     # Set up examples
+     gr.Examples(
+         examples=examples,
+         inputs=[text_input, voice, temperature, top_p, repetition_penalty, max_new_tokens],
+         outputs=audio_output,
+         fn=generate_speech,
+         cache_examples=True,
+     )
+
+     # Set up event handlers
+     submit_btn.click(
+         fn=generate_speech,
+         inputs=[text_input, voice, temperature, top_p, repetition_penalty, max_new_tokens],
+         outputs=audio_output
+     )
+
+     clear_btn.click(
+         fn=lambda: (None, None),
+         inputs=[],
+         outputs=[text_input, audio_output]
+     )

+ # Launch the app
  if __name__ == "__main__":
-     try:
-         print("Loading models...")
-         load_model()  # This function should be defined to load all necessary models
-         print("Models loaded successfully. Launching the interface...")
-         demo.queue().launch(share=False, ssr_mode=False)
-     except Exception as e:
-         print(f"Error during startup: {e}")
+     demo.queue().launch(share=False, ssr_mode=False)
 
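For reviewers who want to exercise the new module-level `generate_speech()` outside the Gradio UI, here is a rough smoke test. It assumes `app.py` is importable locally with its dependencies installed (the SNAC and Orpheus downloads happen at import time), stubs out the `gr.Progress` callback so no UI event is required, and uses `soundfile` for output, which is an assumption rather than a dependency declared by this commit:

```python
# Rough local smoke test for the refactored app.py (assumptions noted above).
# Importing app triggers the module-level model downloads and loading.
import soundfile as sf  # assumed to be installed; not part of this commit

import app  # module under review

result = app.generate_speech(
    "Hello <chuckle> from the refactored Space.",  # emotive tags pass through
    "tara",            # voice
    0.6,               # temperature
    0.95,              # top_p
    1.1,               # repetition_penalty
    1200,              # max_new_tokens
    progress=lambda *args, **kwargs: None,  # stand-in for gr.Progress() outside the UI
)

if result is not None:
    sample_rate, audio = result  # 24 kHz mono float array from the SNAC decode
    sf.write("sample.wav", audio, sample_rate)
```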