bluenevus committed on
Commit
33f9554
·
verified ·
1 Parent(s): 7976c43

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -134
app.py CHANGED
@@ -80,108 +80,12 @@ def load_model():
80
  logger.error(f"Error loading model: {str(e)}")
81
  raise
82
 
83
def generate_podcast_script(api_key, content, uploaded_file, duration, num_hosts):
    """Generate a podcast script with the Gemini API.

    Combines pasted text and/or an uploaded document into one prompt,
    asks Gemini for a script, and strips characters the downstream TTS
    cannot handle (angle-bracket emotion tags are preserved).

    Args:
        api_key: Gemini API key.
        content: Optional pasted text (may be empty or None).
        uploaded_file: Optional file-like object; read and decoded as UTF-8.
        duration: Target duration label, e.g. "1-5 min".
        num_hosts: 1 for a monologue, 2 for alternating dialogue.

    Returns:
        The sanitized script text.

    Raises:
        Exception: re-raised after logging any API/decoding failure.
    """
    try:
        genai.configure(api_key=api_key)
        model = genai.GenerativeModel('gemini-2.5-pro-preview-03-25')

        combined_content = content or ""
        if uploaded_file:
            file_content = uploaded_file.read().decode('utf-8')
            combined_content += "\n" + file_content if combined_content else file_content

        prompt = f"""
        Create a podcast script for {'one person' if num_hosts == 1 else 'two people'} discussing:
        {combined_content}

        Duration: {duration}. Include natural speech, humor, and occasional off-topic thoughts.
        Use speech fillers like um, ah. Vary emotional tone.

        Format: {'Monologue' if num_hosts == 1 else 'Alternating dialogue'} without speaker labels.
        Separate {'paragraphs' if num_hosts == 1 else 'lines'} with blank lines.

        Use emotion tags in angle brackets: <laugh>, <sigh>, <chuckle>, <cough>, <sniffle>, <groan>, <yawn>, <gasp>.

        Example: "I can't believe I stayed up all night <yawn> only to find out the meeting was canceled <groan>."

        Ensure content flows naturally and stays on topic. Match the script length to {duration}.
        """

        response = model.generate_content(prompt)
        # Allow apostrophes: the previous filter stripped them, turning
        # contractions like "can't" into "cant" even though the prompt's
        # own example contains contractions.
        return re.sub(r"[^a-zA-Z0-9\s.,?!<>']", '', response.text)
    except Exception as e:
        logger.error(f"Error generating podcast script: {str(e)}")
        raise
115
-
116
def process_prompt(prompt, voice, tokenizer, device):
    """Tokenize a voice-tagged prompt and frame it with the model's
    special human-turn delimiter tokens.

    Returns a ``(input_ids, attention_mask)`` pair, both moved to *device*.
    The layout is: SOH, text tokens, EOT, EOH.
    """
    tagged = f"{voice}: {prompt}"
    text_ids = tokenizer(tagged, return_tensors="pt").input_ids

    # 128259 = start of human; 128009/128260 = end of text / end of human.
    soh = torch.tensor([[128259]], dtype=torch.int64)
    eot_eoh = torch.tensor([[128009, 128260]], dtype=torch.int64)

    framed = torch.cat([soh, text_ids, eot_eoh], dim=1)

    # Single sequence — no padding, so the mask is simply all ones.
    mask = torch.ones_like(framed)

    return framed.to(device), mask.to(device)
129
-
130
def parse_output(generated_ids):
    """Extract the audio code sequence from raw generation output.

    Crops everything up to (and including) the last start-of-audio marker,
    drops end-of-audio tokens, trims each row to whole 7-token frames, and
    shifts the remaining tokens into codebook range.

    Returns the code list for the first (single-sample) row.
    """
    START_OF_AUDIO = 128257
    END_OF_AUDIO = 128258
    CODE_OFFSET = 128266

    hits = (generated_ids == START_OF_AUDIO).nonzero(as_tuple=True)
    if len(hits[1]) > 0:
        # Keep only tokens after the final start-of-audio marker.
        cropped = generated_ids[:, hits[1][-1].item() + 1:]
    else:
        cropped = generated_ids

    code_lists = []
    for row in cropped:
        kept = row[row != END_OF_AUDIO]
        usable = (kept.size(0) // 7) * 7  # whole frames only
        code_lists.append([t - CODE_OFFSET for t in kept[:usable]])

    # Single-sample path: only the first row is needed.
    return code_lists[0]
156
-
157
def redistribute_codes(code_list, snac_model):
    """Split a flat 7-token frame sequence into SNAC's three codebook
    layers and decode them to audio.

    Each 7-code frame contributes 1 code to layer 1, 2 to layer 2 and
    4 to layer 3; each position carries a fixed multiple-of-4096 offset
    that is removed here.

    Args:
        code_list: Flat sequence of codes (length a multiple of 7;
            any trailing partial frame is ignored).
        snac_model: SNAC decoder; decoding happens on its device.

    Returns:
        A 1-D CPU numpy array of audio samples.
    """
    # Build the layer tensors on whatever device the SNAC model lives on.
    device = next(snac_model.parameters()).device

    layer_1 = []
    layer_2 = []
    layer_3 = []
    # Fix: iterate len // 7 full frames. The previous (len + 1) // 7 read
    # one frame too many for lengths of the form 7k + 6, indexing past the
    # end of code_list; for valid multiples of 7 the two are identical.
    for i in range(len(code_list) // 7):
        base = 7 * i
        layer_1.append(code_list[base])
        layer_2.append(code_list[base + 1] - 4096)
        layer_3.append(code_list[base + 2] - (2 * 4096))
        layer_3.append(code_list[base + 3] - (3 * 4096))
        layer_2.append(code_list[base + 4] - (4 * 4096))
        layer_3.append(code_list[base + 5] - (5 * 4096))
        layer_3.append(code_list[base + 6] - (6 * 4096))

    # Move tensors to the same device as the SNAC model.
    codes = [
        torch.tensor(layer_1, device=device).unsqueeze(0),
        torch.tensor(layer_2, device=device).unsqueeze(0),
        torch.tensor(layer_3, device=device).unsqueeze(0),
    ]

    audio_hat = snac_model.decode(codes)
    # Always return a CPU numpy array, regardless of decode device.
    return audio_hat.detach().squeeze().cpu().numpy()
181
-
182
  @spaces.GPU()
183
  def text_to_speech(text, voice, temperature=0.6, top_p=0.95, repetition_penalty=1.1, max_new_tokens=1200):
184
  global model, tokenizer, snac_model
 
 
 
185
  if not text.strip():
186
  return None
187
 
@@ -238,44 +142,11 @@ def render_podcast(api_key, script, voice1, voice2, num_hosts):
238
  logger.error(f"Error rendering podcast: {str(e)}")
239
  raise
240
 
241
# Gradio Interface: script generation + podcast rendering UI.
with gr.Blocks() as demo:
    gr.Markdown("# AI Podcast Generator")

    api_key_input = gr.Textbox(label="Enter your Gemini API Key", type="password")

    # Content can come from pasted text, an uploaded document, or both.
    with gr.Row():
        content_input = gr.Textbox(label="Paste your content (optional)")
        document_upload = gr.File(label="Upload Document (optional)")

    duration = gr.Radio(["1-5 min", "5-10 min", "10-15 min"], label="Estimated podcast duration")
    num_hosts = gr.Radio([1, 2], label="Number of podcast hosts", value=2)

    voice_options = ["tara", "leah", "jess", "leo", "dan", "mia", "zac", "zoe"]
    voice1_select = gr.Dropdown(label="Select Voice 1", choices=voice_options, value="tara")
    voice2_select = gr.Dropdown(label="Select Voice 2", choices=voice_options, value="leo")

    generate_btn = gr.Button("Generate Script")
    script_output = gr.Textbox(label="Generated Script", lines=10)

    render_btn = gr.Button("Render Podcast")
    audio_output = gr.Audio(label="Generated Podcast")

    # Wire up the two actions.
    generate_btn.click(
        fn=generate_podcast_script,
        inputs=[api_key_input, content_input, document_upload, duration, num_hosts],
        outputs=script_output,
    )
    render_btn.click(
        fn=render_podcast,
        inputs=[api_key_input, script_output, voice1_select, voice2_select, num_hosts],
        outputs=audio_output,
    )

    # The second voice picker only makes sense for two-host podcasts.
    num_hosts.change(
        fn=lambda hosts: gr.update(visible=hosts == 2),
        inputs=[num_hosts],
        outputs=[voice2_select],
    )
275
 
276
  if __name__ == "__main__":
277
  try:
278
- load_model()
279
  demo.launch()
280
  except Exception as e:
281
  logger.error(f"Error launching the application: {str(e)}")
 
80
  logger.error(f"Error loading model: {str(e)}")
81
  raise
82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  @spaces.GPU()
84
  def text_to_speech(text, voice, temperature=0.6, top_p=0.95, repetition_penalty=1.1, max_new_tokens=1200):
85
  global model, tokenizer, snac_model
86
+ if model is None or tokenizer is None or snac_model is None:
87
+ load_model()
88
+
89
  if not text.strip():
90
  return None
91
 
 
142
  logger.error(f"Error rendering podcast: {str(e)}")
143
  raise
144
 
145
+ # ... (rest of the code remains the same)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
 
147
  if __name__ == "__main__":
148
  try:
149
+ load_model() # Load models at startup
150
  demo.launch()
151
  except Exception as e:
152
  logger.error(f"Error launching the application: {str(e)}")